In [1]:
import os
os.chdir('..')

%load_ext autoreload
%autoreload 2

In [2]:
import datetime

import torch

import torch.nn as nn

from src.consts import *
from src.main import main, setup_torch, get_corpus
from src.model import RNNModel
from src.training import train, evaluate
from src.split_cross_entropy_loss import SplitCrossEntropyLoss
from src.utils import summary, check_cuda_mem, get_latest_model_file
from src.custom_data_parallel import CustomDataParallel
from src.parallel import DataParallelCriterion
from src.wsc_parser import generate_df_from_json
from src.winograd_schema_challenge import find_missing_wsc_words_in_corpus_vocab

In [3]:
# Project-level torch initialisation (helper imported from src.main).
setup_torch()

# Index of the primary GPU; it is also placed first in the data-parallel
# device list built in the model cell.
main_gpu_index = 0

# Pick the primary CUDA device when USE_CUDA is truthy, otherwise the CPU.
if USE_CUDA:
    device = torch.device("cuda:{}".format(main_gpu_index))
else:
    device = torch.device("cpu")

# Load the corpus through the project helper (imported from src.main).
corpus = get_corpus()

In [4]:
# Display the number of entries in the corpus dictionary (the vocabulary size
# later passed to the model as `ntokens`).
len(corpus.dictionary)

111550

In [5]:
# Sanity-check the loaded corpus before building the model.

# The first dimension of the validation split is expected to have exactly this
# length; a different value means a different corpus was loaded.
assert corpus.valid.size()[0] == 11606861, (
    "unexpected corpus.valid size: {}".format(corpus.valid.size()[0])
)

# The largest value in every split must stay below the dictionary size.
# NOTE(review): presumably the values are vocabulary indices fed to the
# embedding layer - confirm against the corpus loader.
for _split_name in ("train", "valid", "test"):
    assert getattr(corpus, _split_name).max() < len(corpus.dictionary), (
        "corpus.{} max value {} is not below the dictionary size {}".format(
            _split_name,
            getattr(corpus, _split_name).max(),
            len(corpus.dictionary),
        )
    )

In [6]:
# Build `df` with the project helper generate_df_from_json (imported from
# src.wsc_parser). Presumably this is the Winograd Schema Challenge data as a
# DataFrame - confirm in the helper's module.
df = generate_df_from_json()

In [7]:
# Per the helper's name: list the words from the WSC data (`df`) that are not
# present in the corpus vocabulary.
# english=False presumably selects the non-English text of the schemas -
# TODO confirm against the helper's implementation.
# NOTE(review): the whole result is displayed; consider showing only its length
# or a slice to keep the notebook output short.
find_missing_wsc_words_in_corpus_vocab(df, corpus, english=False)

['chatbots',
 'retirei',
 'Timmy',
 'protegemos',
 'largá-las',
 'acudi-lo',
 'usei',
 'gargalhou',
 'reconfortou',
 'abaixou',
 'balançavam',
 'constrangidos',
 'limpado',
 'assisti-la',
 'removi',
 'peixinho',
 'alicate',
 'acomodada',
 'malabarista',
 'acomodou',
 'compassivo',
 'malabarismos',
 'gameboy',
 'Huntertropic',
 'transando',
 'Janie',
 'custeou',
 'certificou-se',
 'Laputa',
 'willow-towered',
 'indiscreta',
 'Terpsichore',
 'kate',
 'enfiei',
 'respondê-la',
 'adorou',
 'GrWQWu8JyC',
 'respondentes',
 'barman',
 'pega-pega',
 'Kamtchatka',
 'Shatov',
 'ingrato',
 'entrevistaram',
 'newsletter',
 'muletas',
 'beliche',
 'Kirilov',
 'wrestles',
 'empático',
 'tapinha',
 'contou-lhe',
 'intrometida',
 'examplo',
 'melancias',
 'latir',
 'Check',
 'sacou',
 'Arqueologistas',
 'doíam',
 'esnobes',
 'repeti-la',
 '4:30',
 'Alongando',
 'tentei',
 'carreguei',
 'Canopy',
 'convidou-a',
 'treinadora',
 'punimos',
 'assoviando',
 'servi',
 'Loebner',
 'torradeira',
 'compreendê-

In [8]:
# Notebook-local switch for data parallelism. The model cell wraps the model
# and criterion when this flag OR the USE_DATA_PARALLELIZATION constant is
# truthy, and the flag is also passed to train().
# NOTE(review): "paralellization" is misspelled, but later cells refer to this
# exact name, so it must be renamed everywhere at once or not at all.
use_data_paralellization = True

In [9]:
# Vocabulary size: one embedding row / decoder output per dictionary entry.
ntokens = len(corpus.dictionary)

# Build the RNN language model from the project constants and move it to the
# primary device.
model = RNNModel(
    MODEL_TYPE,
    ntokens,
    EMBEDDINGS_SIZE,
    HIDDEN_UNIT_COUNT,
    LAYER_COUNT,
    DROPOUT_PROB,
    TIED,
).to(device)
criterion = nn.CrossEntropyLoss()

# Optionally spread model and loss over every visible GPU. The notebook flag is
# tested first, so the project constant is only consulted when the flag is falsy.
if use_data_paralellization or USE_DATA_PARALLELIZATION:
    cuda_devices = list(range(torch.cuda.device_count()))

    # The primary GPU goes first, followed by the remaining devices in order.
    device_ids = [main_gpu_index]
    device_ids.extend(cuda_devices[:main_gpu_index])
    device_ids.extend(cuda_devices[main_gpu_index + 1:])

    model = CustomDataParallel(model, device_ids=device_ids)
    criterion = DataParallelCriterion(criterion, device_ids=device_ids)

# No optimizer is created here; None is handed to train() as-is.
optimizer = None

# Print the model structure and parameter shapes.
summary(model, criterion)

CustomDataParallel(
  (model): DataParallelModel(
    (module): RNNModel(
      (drop): Dropout(p=0.2)
      (encoder): Embedding(111550, 200)
      (rnn): LSTM(200, 200, num_layers=2, dropout=0.2)
      (decoder): Linear(in_features=200, out_features=111550, bias=True)
    )
  )
)

model.module.encoder.weight torch.Size([111550, 200])
model.module.rnn.weight_ih_l0 torch.Size([800, 200])
model.module.rnn.weight_hh_l0 torch.Size([800, 200])
model.module.rnn.bias_ih_l0 torch.Size([800])
model.module.rnn.bias_hh_l0 torch.Size([800])
model.module.rnn.weight_ih_l1 torch.Size([800, 200])
model.module.rnn.weight_hh_l1 torch.Size([800, 200])
model.module.rnn.bias_ih_l1 torch.Size([800])
model.module.rnn.bias_hh_l1 torch.Size([800])
model.module.decoder.weight torch.Size([111550, 200])
model.module.decoder.bias torch.Size([111550])

Total Parameters: 23,064,750


In [10]:
# Start training.
# The model and criterion are wrapped for data parallelism when the notebook
# flag OR the project constant is truthy, so the same combined condition is
# passed to train(); passing the notebook flag alone would disagree with the
# wrapping whenever the flag is False but the constant is True.
# NOTE(review): this assumes the last positional argument of train() means
# "data parallelism is active" - confirm against src.training.train.
train(model, corpus, criterion, optimizer, device,
      use_data_paralellization or USE_DATA_PARALLELIZATION)

INFO 2019-06-17 18:42:07,307: | epoch   1 |   200/ 1856 batches | lr 20.00 | ms/batch 247.55 | loss  8.51 | ppl  4967.97
INFO 2019-06-17 18:42:55,106: | epoch   1 |   400/ 1856 batches | lr 20.00 | ms/batch 238.99 | loss  7.87 | ppl  2614.08
INFO 2019-06-17 18:43:42,959: | epoch   1 |   600/ 1856 batches | lr 20.00 | ms/batch 239.27 | loss  7.73 | ppl  2272.82
INFO 2019-06-17 18:44:30,961: | epoch   1 |   800/ 1856 batches | lr 20.00 | ms/batch 240.00 | loss  7.63 | ppl  2062.42
INFO 2019-06-17 18:45:19,077: | epoch   1 |  1000/ 1856 batches | lr 20.00 | ms/batch 240.58 | loss  7.55 | ppl  1892.01
INFO 2019-06-17 18:46:07,157: | epoch   1 |  1200/ 1856 batches | lr 20.00 | ms/batch 240.40 | loss  7.49 | ppl  1786.53
INFO 2019-06-17 18:46:55,293: | epoch   1 |  1400/ 1856 batches | lr 20.00 | ms/batch 240.68 | loss  7.46 | ppl  1735.63
INFO 2019-06-17 18:47:43,390: | epoch   1 |  1600/ 1856 batches | lr 20.00 | ms/batch 240.48 | loss  7.42 | ppl  1668.99
INFO 2019-06-17 18:48:31,472: | 