In [1]:
import os
os.chdir('..')

%load_ext autoreload
%autoreload 2

In [2]:
import datetime

import torch

import torch.nn as nn

from src.consts import *
from src.main import main, setup_torch, get_corpus
from src.model import RNNModel
from src.training import train, evaluate
from src.split_cross_entropy_loss import SplitCrossEntropyLoss

from notebooks.utils import summary, check_cuda_mem

In [3]:
# Notebook-level switch: True wraps the model in CustomDataParallel. The
# USE_DATA_PARALLELIZATION constant also enables it; either one is enough.
# (Spelling kept as-is because the model cell references this exact name.)
use_data_paralellization: bool = False

In [4]:
setup_torch()
device = torch.device("cuda" if USE_CUDA else "cpu")
corpus = get_corpus()
ntokens = len(corpus.dictionary)

# TODO remove these two dataset-specific checks; they pin the sizes of one particular corpus.
assert ntokens == 16127
assert corpus.valid.size()[0] == 216347

# Sanity check: every token id in each split must be a valid index into the vocabulary.
for split_name, split_data in (("train", corpus.train), ("valid", corpus.valid), ("test", corpus.test)):
    assert split_data.max() < ntokens, "{} split contains token ids outside the vocabulary".format(split_name)

model = RNNModel(MODEL_TYPE, ntokens, EMBEDDINGS_SIZE, HIDDEN_UNIT_COUNT, LAYER_COUNT, DROPOUT_PROB,
                 TIED)
# Move the parameters to the target device before any wrapping. Assuming CustomDataParallel
# builds on torch.nn.DataParallel, the wrapped module must already live on device_ids[0].
model.to(device)
if use_data_paralellization or USE_DATA_PARALLELIZATION:
    # NOTE(review): CustomDataParallel has no explicit import in this notebook; presumably it
    # arrives via `from src.consts import *` -- confirm, otherwise this branch raises NameError.
    model = CustomDataParallel(model)

# CrossEntropyLoss holds no parameters or buffers here, so it needs no .to(device).
criterion = nn.CrossEntropyLoss()
# Alternative split loss; unlike CrossEntropyLoss it has to be moved to the device:
# criterion = SplitCrossEntropyLoss(EMBEDDINGS_SIZE, splits=[3800, 30000, 180000], verbose=False)
# criterion.to(device)

summary(model, criterion)

RNNModel(
  (drop): Dropout(p=0.2)
  (encoder): Embedding(16127, 200)
  (rnn): LSTM(200, 200, num_layers=2, dropout=0.2)
  (decoder): Linear(in_features=200, out_features=16127, bias=True)
)

encoder.weight torch.Size([16127, 200])
rnn.weight_ih_l0 torch.Size([800, 200])
rnn.weight_hh_l0 torch.Size([800, 200])
rnn.bias_ih_l0 torch.Size([800])
rnn.bias_hh_l0 torch.Size([800])
rnn.weight_ih_l1 torch.Size([800, 200])
rnn.weight_hh_l1 torch.Size([800, 200])
rnn.bias_ih_l1 torch.Size([800])
rnn.bias_hh_l1 torch.Size([800])
decoder.weight torch.Size([16127, 200])
decoder.bias torch.Size([16127])

Total Parameters: 3,884,727


In [5]:
# Train `model` on `corpus` with `criterion` on `device` (presumably updating `model` in place,
# since later cells save/evaluate the same object). Loss and perplexity are logged every
# 200 batches, as shown in the output below.
# NOTE(review): the captured log stops partway through epoch 1, so this run appears to have
# been interrupted -- re-run to completion before saving or evaluating.
train(model, corpus, criterion, device)

INFO 2019-05-24 16:20:08,472: | epoch   1 |   200/ 2965 batches | lr 20.00 | ms/batch  7.56 | loss  7.01 | ppl  1103.45
INFO 2019-05-24 16:20:09,921: | epoch   1 |   400/ 2965 batches | lr 20.00 | ms/batch  7.24 | loss  6.27 | ppl   528.87
INFO 2019-05-24 16:20:11,370: | epoch   1 |   600/ 2965 batches | lr 20.00 | ms/batch  7.24 | loss  5.93 | ppl   375.06
INFO 2019-05-24 16:20:12,818: | epoch   1 |   800/ 2965 batches | lr 20.00 | ms/batch  7.24 | loss  5.78 | ppl   322.76
INFO 2019-05-24 16:20:14,271: | epoch   1 |  1000/ 2965 batches | lr 20.00 | ms/batch  7.26 | loss  5.70 | ppl   298.95
INFO 2019-05-24 16:20:15,724: | epoch   1 |  1200/ 2965 batches | lr 20.00 | ms/batch  7.26 | loss  5.58 | ppl   266.04
INFO 2019-05-24 16:20:17,176: | epoch   1 |  1400/ 2965 batches | lr 20.00 | ms/batch  7.26 | loss  5.46 | ppl   236.14
INFO 2019-05-24 16:20:18,631: | epoch   1 |  1600/ 2965 batches | lr 20.00 | ms/batch  7.27 | loss  5.45 | ppl   233.80
INFO 2019-05-24 16:20:20,082: | epoch   

In [6]:
# Set to True to checkpoint the trained model to MODEL_FILE_NAME, formatted with the current time.
save_trained_model = False
if save_trained_model:
    timestamp = datetime.datetime.now()
    with open(MODEL_FILE_NAME.format(timestamp), 'wb') as f:
        # Pickles the whole module object (not only its state_dict), so loading it later
        # requires the same model class to be importable.
        torch.save(model, f)

In [6]:
# Set to a checkpoint path to replace the in-memory model with a saved one, e.g.
#   MODEL_FILE_NAME.format(timestamp)   (the file written by the save cell)
#   'models/trained_models/model-2019-05-22 18:36:17.481233.pt'
#   'models/trained_models/model-2019-05-24 14:31:54.449858.pt'
load_model_path = None
if load_model_path is not None:
    with open(load_model_path, 'rb') as f:
        # map_location remaps the stored tensors onto `device`, so a GPU-saved checkpoint
        # also loads on a CPU-only machine.
        # torch.load unpickles arbitrary objects: only load checkpoints you trust.
        # NOTE: on torch >= 2.6 loading a whole pickled model additionally needs weights_only=False.
        model = torch.load(f, map_location=device)
    # After loading, the RNN parameters are not one contiguous chunk of memory;
    # flatten_parameters() makes them contiguous again, which speeds up the forward pass.
    model.rnn.flatten_parameters()

In [6]:
# Evaluation (disabled): computes `val_loss` -- presumably on the validation split, given the
# name -- for the trained or freshly loaded model. Uncomment to run.
# val_loss = evaluate(model, corpus, criterion, device)