In [1]:
import argparse
from datetime import datetime
import numpy as np

from bilm.training import train, load_options_latest_checkpoint, load_vocab
from bilm.data import BidirectionalLMDataset

# Run timestamp (month, day, hour, minute) used to tag this run's checkpoint directory.
now = datetime.now()
date_fmt = now.strftime('%m%d_%H%M')

# Inputs: the vocabulary file handed to load_vocab() and the glob-style
# patterns for the data files. val_prefix is defined here but is not
# referenced by main().
vocab_file = '/media/scatter/scatterdisk/sandbox_temp/data/kakaotalk_sol_elmo/kakaotalk_sol_unique_tokens.txt'
train_prefix = '/media/scatter/scatterdisk/sandbox_temp/data/kakaotalk_sol_elmo/messages/*/*'
val_prefix = '/media/scatter/scatterdisk/sandbox_temp/data/kakaotalk_sol_elmo/val/*/*'

# Output: main() passes this directory to train() as both the save and the log directory.
save_dir = '/media/scatter/scatterdisk/elmo_ckpt_{}'.format(date_fmt)

In [2]:
def main():
    """Build the training options and launch bidirectional LM training.

    Reads the module-level ``vocab_file``, ``train_prefix`` and ``save_dir``
    settings, loads the vocabulary and the training dataset, and hands
    everything to ``bilm.training.train``. ``save_dir`` is passed as both the
    save directory and the log directory.
    """
    # Maximum number of characters per token. The vocabulary loader and the
    # char-CNN options must agree on this, so it is defined once.
    max_word_length = 50

    # load the vocab
    vocab = load_vocab(vocab_file, max_word_length)

    # define the options
    batch_size = 64  # batch size for each GPU
    n_gpus = 2

    # Number of tokens in the training data.
    # For reference, the 1B Word Benchmark has 768648884 tokens.
    # The value below is the token count of the tokenized KakaoTalk data from
    # "Science of Love" (identified_corpus_20180105).
    n_train_tokens = 605918

    options = {
        'bidirectional': True,
        'char_cnn': {
            'activation': 'relu',
            'embedding': {'dim': 16},
            # [filter width, number of filters]
            'filters': [[1, 32],
                        [2, 32],
                        [3, 64],
                        [4, 128],
                        [5, 256],
                        [6, 512],
                        # Was 1204: a digit transposition of the 1024 used by
                        # the reference bilm-tf training configuration.
                        [7, 1024]],
            'max_characters_per_token': max_word_length,
            'n_characters': 261,
            'n_highway': 2
        },

        'dropout': 0.2,
        'lstm': {
            'cell_clip': 3,
            'dim': 1024,
            'n_layers': 2,
            'proj_clip': 5,
            'projection_dim': 256,
            'use_skip_connections': True
        },

        'all_clip_norm_val': 10.0,
        'n_epochs': 20,
        'n_train_tokens': n_train_tokens,
        'batch_size': batch_size,
        'n_tokens_vocab': vocab.size,
        'unroll_steps': 20,
        'n_negative_samples_batch': 4096,
    }

    data = BidirectionalLMDataset(train_prefix,
                                  vocab,
                                  test=False,
                                  shuffle_on_load=True)

    # Checkpoints and logs share the same timestamped directory.
    tf_save_dir = save_dir
    tf_log_dir = save_dir
    train(options, data, n_gpus, tf_save_dir, tf_log_dir)


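# NOTE: legacy command-line entry point, kept for reference only. It does not
# match the current code: main() takes no arguments and reads the module-level
# save_dir / vocab_file / train_prefix settings instead of parsed args.
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--save_dir', help='Location of checkpoint files')
#     parser.add_argument('--vocab_file', help='Vocabulary file')
#     parser.add_argument('--train_prefix', help='Prefix for train files')
#     args = parser.parse_args()

#     main(args)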