In [1]:
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
import import_ipynb
from model import Model
from preprocess import Dataset

importing Jupyter notebook from model.ipynb
importing Jupyter notebook from preprocess.ipynb


In [2]:
def train(dataset, model, args, learning_rate=0.0001):
    """Fit ``model`` on ``dataset`` using Adam and cross-entropy loss.

    The dataset is read in order (no shuffling) in batches of
    ``args.batch_size`` for ``args.max_epochs`` epochs. The recurrent state
    is re-created at the start of every epoch and carried from one batch to
    the next within the epoch. One progress dict is printed per batch.

    Args:
        dataset: Source of ``(x, y)`` pairs; handed straight to ``DataLoader``.
        model: Module exposing ``init_state(n)`` and a forward call
            ``model(x, (state_h, state_c))`` that returns
            ``(y_pred, (state_h, state_c))``.
        args: Namespace providing ``batch_size``, ``max_epochs`` and
            ``sequence_length``.
        learning_rate: Step size passed to Adam. Defaults to ``0.0001``,
            the value that used to be hard-coded.

    Returns:
        None. ``model`` is updated in place and left in training mode.
    """
    model.train()

    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(args.max_epochs):
        # Fresh recurrent state for every pass over the data.
        state_h, state_c = model.init_state(args.sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss reads class scores from dim 1, hence the swap
            # of dims 1 and 2 (assumes y_pred is (batch, seq, vocab) -- TODO
            # confirm against Model).
            loss = criterion(y_pred.transpose(1, 2), y)

            # Carry only the state values forward: detaching stops the next
            # batch's backward pass from reaching into this batch's graph.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [3]:
def predict(dataset, model, text, next_words=50):
    """Extend ``text`` by sampling ``next_words`` words from ``model``.

    The prompt is split on single spaces. At step ``i`` the model is fed
    ``words[i:]``; because exactly one word is appended per step, that
    window always has as many words as the prompt. The next word is drawn
    at random from the softmax over the logits at the window's last position.

    Args:
        dataset: Object exposing the ``word_to_index`` and ``index_to_word``
            lookups.
        model: Module exposing ``init_state(n)`` and a forward call
            ``model(x, (state_h, state_c))`` that returns
            ``(y_pred, (state_h, state_c))``. It is switched to eval mode.
        text: Prompt; every word in it must be a key of
            ``dataset.word_to_index``.
        next_words: Number of words to generate.

    Returns:
        list[str]: The prompt words followed by the generated words.

    Raises:
        KeyError: If a prompt word is missing from ``dataset.word_to_index``.
    """
    words = text.split(' ')
    model.eval()

    state_h, state_c = model.init_state(len(words))

    # Pure inference: without no_grad() each step would record an autograd
    # graph, and since the state is fed back in, that graph would keep
    # growing for the whole generation loop.
    with torch.no_grad():
        for i in range(0, next_words):
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            # .cpu() is a no-op for CPU tensors and keeps the NumPy
            # conversion valid if the logits live on another device.
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [4]:
# Hyperparameters are read from the command line. parse_known_args() returns
# unrecognised arguments (e.g. those a notebook kernel is started with)
# instead of exiting with an error.
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=15)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)
parser.add_argument('--seed', type=int, default=0)
args, unknown = parser.parse_known_args()

# Seed both RNGs before anything stochastic runs, so that a fresh
# "Restart & Run All" reproduces the same weight initialisation (torch)
# and the same sampled words (predict() draws with np.random.choice).
torch.manual_seed(args.seed)
np.random.seed(args.seed)

dataset = Dataset(args)
model = Model(dataset)

train(dataset, model, args)


{'epoch': 0, 'batch': 0, 'loss': 8.832392692565918}
{'epoch': 0, 'batch': 1, 'loss': 8.839625358581543}
{'epoch': 0, 'batch': 2, 'loss': 8.837326049804688}
{'epoch': 0, 'batch': 3, 'loss': 8.835883140563965}
{'epoch': 0, 'batch': 4, 'loss': 8.840753555297852}
{'epoch': 0, 'batch': 5, 'loss': 8.83841323852539}
{'epoch': 0, 'batch': 6, 'loss': 8.83777904510498}
{'epoch': 0, 'batch': 7, 'loss': 8.83365249633789}
{'epoch': 0, 'batch': 8, 'loss': 8.839579582214355}
{'epoch': 0, 'batch': 9, 'loss': 8.840751647949219}
{'epoch': 0, 'batch': 10, 'loss': 8.835182189941406}
{'epoch': 0, 'batch': 11, 'loss': 8.842975616455078}
{'epoch': 0, 'batch': 12, 'loss': 8.839540481567383}
{'epoch': 0, 'batch': 13, 'loss': 8.827845573425293}
{'epoch': 0, 'batch': 14, 'loss': 8.832148551940918}
{'epoch': 0, 'batch': 15, 'loss': 8.82927131652832}
{'epoch': 0, 'batch': 16, 'loss': 8.826456069946289}
{'epoch': 0, 'batch': 17, 'loss': 8.833536148071289}
{'epoch': 0, 'batch': 18, 'loss': 8.829416275024414}
{'epoch

{'epoch': 1, 'batch': 63, 'loss': 7.245120525360107}
{'epoch': 1, 'batch': 64, 'loss': 7.258820533752441}
{'epoch': 1, 'batch': 65, 'loss': 7.231175899505615}
{'epoch': 1, 'batch': 66, 'loss': 7.25005578994751}
{'epoch': 1, 'batch': 67, 'loss': 7.094273090362549}
{'epoch': 1, 'batch': 68, 'loss': 7.264418125152588}
{'epoch': 1, 'batch': 69, 'loss': 7.043867588043213}
{'epoch': 1, 'batch': 70, 'loss': 7.334185600280762}
{'epoch': 1, 'batch': 71, 'loss': 7.297176361083984}
{'epoch': 1, 'batch': 72, 'loss': 7.218429088592529}
{'epoch': 1, 'batch': 73, 'loss': 7.187628746032715}
{'epoch': 1, 'batch': 74, 'loss': 7.224567413330078}
{'epoch': 1, 'batch': 75, 'loss': 7.345601558685303}
{'epoch': 1, 'batch': 76, 'loss': 7.208913803100586}
{'epoch': 1, 'batch': 77, 'loss': 7.357769012451172}
{'epoch': 1, 'batch': 78, 'loss': 7.458227634429932}
{'epoch': 1, 'batch': 79, 'loss': 6.8592305183410645}
{'epoch': 1, 'batch': 80, 'loss': 7.087447166442871}
{'epoch': 1, 'batch': 81, 'loss': 7.2556657791

{'epoch': 3, 'batch': 31, 'loss': 6.729982852935791}
{'epoch': 3, 'batch': 32, 'loss': 6.879522323608398}
{'epoch': 3, 'batch': 33, 'loss': 7.062889099121094}
{'epoch': 3, 'batch': 34, 'loss': 6.997822284698486}
{'epoch': 3, 'batch': 35, 'loss': 7.2448625564575195}
{'epoch': 3, 'batch': 36, 'loss': 7.130382061004639}
{'epoch': 3, 'batch': 37, 'loss': 6.961644649505615}
{'epoch': 3, 'batch': 38, 'loss': 7.2773895263671875}
{'epoch': 3, 'batch': 39, 'loss': 7.094320774078369}
{'epoch': 3, 'batch': 40, 'loss': 7.330890655517578}
{'epoch': 3, 'batch': 41, 'loss': 7.022881507873535}
{'epoch': 3, 'batch': 42, 'loss': 7.289077281951904}
{'epoch': 3, 'batch': 43, 'loss': 7.026420593261719}
{'epoch': 3, 'batch': 44, 'loss': 6.946214199066162}
{'epoch': 3, 'batch': 45, 'loss': 7.042489528656006}
{'epoch': 3, 'batch': 46, 'loss': 7.219879627227783}
{'epoch': 3, 'batch': 47, 'loss': 7.516852855682373}
{'epoch': 3, 'batch': 48, 'loss': 6.94767951965332}
{'epoch': 3, 'batch': 49, 'loss': 7.188207149

{'epoch': 4, 'batch': 92, 'loss': 7.110905170440674}
{'epoch': 4, 'batch': 93, 'loss': 6.626267910003662}
{'epoch': 5, 'batch': 0, 'loss': 6.959883689880371}
{'epoch': 5, 'batch': 1, 'loss': 6.950994968414307}
{'epoch': 5, 'batch': 2, 'loss': 6.891692161560059}
{'epoch': 5, 'batch': 3, 'loss': 7.1095662117004395}
{'epoch': 5, 'batch': 4, 'loss': 7.061836242675781}
{'epoch': 5, 'batch': 5, 'loss': 7.034764289855957}
{'epoch': 5, 'batch': 6, 'loss': 7.396126747131348}
{'epoch': 5, 'batch': 7, 'loss': 7.236608982086182}
{'epoch': 5, 'batch': 8, 'loss': 7.1908698081970215}
{'epoch': 5, 'batch': 9, 'loss': 7.084639072418213}
{'epoch': 5, 'batch': 10, 'loss': 7.064618110656738}
{'epoch': 5, 'batch': 11, 'loss': 7.030664920806885}
{'epoch': 5, 'batch': 12, 'loss': 7.10312557220459}
{'epoch': 5, 'batch': 13, 'loss': 7.209578990936279}
{'epoch': 5, 'batch': 14, 'loss': 6.918496131896973}
{'epoch': 5, 'batch': 15, 'loss': 6.989140033721924}
{'epoch': 5, 'batch': 16, 'loss': 6.773491382598877}
{'

{'epoch': 6, 'batch': 60, 'loss': 6.94575834274292}
{'epoch': 6, 'batch': 61, 'loss': 7.082808971405029}
{'epoch': 6, 'batch': 62, 'loss': 7.0870490074157715}
{'epoch': 6, 'batch': 63, 'loss': 7.006938457489014}
{'epoch': 6, 'batch': 64, 'loss': 7.093347549438477}
{'epoch': 6, 'batch': 65, 'loss': 7.011999130249023}
{'epoch': 6, 'batch': 66, 'loss': 7.003612041473389}
{'epoch': 6, 'batch': 67, 'loss': 6.831462383270264}
{'epoch': 6, 'batch': 68, 'loss': 7.017219543457031}
{'epoch': 6, 'batch': 69, 'loss': 6.773460865020752}
{'epoch': 6, 'batch': 70, 'loss': 7.184905529022217}
{'epoch': 6, 'batch': 71, 'loss': 7.107147216796875}
{'epoch': 6, 'batch': 72, 'loss': 7.0181379318237305}
{'epoch': 6, 'batch': 73, 'loss': 7.054732322692871}
{'epoch': 6, 'batch': 74, 'loss': 7.047669887542725}
{'epoch': 6, 'batch': 75, 'loss': 7.200679302215576}
{'epoch': 6, 'batch': 76, 'loss': 7.008828163146973}
{'epoch': 6, 'batch': 77, 'loss': 7.224951267242432}
{'epoch': 6, 'batch': 78, 'loss': 7.318644046

{'epoch': 8, 'batch': 28, 'loss': 7.246639251708984}
{'epoch': 8, 'batch': 29, 'loss': 7.37012243270874}
{'epoch': 8, 'batch': 30, 'loss': 6.725759983062744}
{'epoch': 8, 'batch': 31, 'loss': 6.682334899902344}
{'epoch': 8, 'batch': 32, 'loss': 6.840283393859863}
{'epoch': 8, 'batch': 33, 'loss': 7.006250381469727}
{'epoch': 8, 'batch': 34, 'loss': 6.945496559143066}
{'epoch': 8, 'batch': 35, 'loss': 7.226787567138672}
{'epoch': 8, 'batch': 36, 'loss': 7.109573841094971}
{'epoch': 8, 'batch': 37, 'loss': 6.92486047744751}
{'epoch': 8, 'batch': 38, 'loss': 7.237672805786133}
{'epoch': 8, 'batch': 39, 'loss': 7.051650524139404}
{'epoch': 8, 'batch': 40, 'loss': 7.294072151184082}
{'epoch': 8, 'batch': 41, 'loss': 6.999327182769775}
{'epoch': 8, 'batch': 42, 'loss': 7.25674295425415}
{'epoch': 8, 'batch': 43, 'loss': 6.9730072021484375}
{'epoch': 8, 'batch': 44, 'loss': 6.887473106384277}
{'epoch': 8, 'batch': 45, 'loss': 6.977423191070557}
{'epoch': 8, 'batch': 46, 'loss': 7.183465480804

{'epoch': 9, 'batch': 89, 'loss': 6.987062454223633}
{'epoch': 9, 'batch': 90, 'loss': 7.3673224449157715}
{'epoch': 9, 'batch': 91, 'loss': 6.900260925292969}
{'epoch': 9, 'batch': 92, 'loss': 7.092333793640137}
{'epoch': 9, 'batch': 93, 'loss': 6.56434965133667}
{'epoch': 10, 'batch': 0, 'loss': 6.986889362335205}
{'epoch': 10, 'batch': 1, 'loss': 6.966728687286377}
{'epoch': 10, 'batch': 2, 'loss': 6.910133361816406}
{'epoch': 10, 'batch': 3, 'loss': 7.120316505432129}
{'epoch': 10, 'batch': 4, 'loss': 7.076371192932129}
{'epoch': 10, 'batch': 5, 'loss': 7.067514896392822}
{'epoch': 10, 'batch': 6, 'loss': 7.423311233520508}
{'epoch': 10, 'batch': 7, 'loss': 7.257700443267822}
{'epoch': 10, 'batch': 8, 'loss': 7.2096028327941895}
{'epoch': 10, 'batch': 9, 'loss': 7.09486722946167}
{'epoch': 10, 'batch': 10, 'loss': 7.073594093322754}
{'epoch': 10, 'batch': 11, 'loss': 7.031402587890625}
{'epoch': 10, 'batch': 12, 'loss': 7.106471538543701}
{'epoch': 10, 'batch': 13, 'loss': 7.219233

{'epoch': 11, 'batch': 55, 'loss': 6.971989631652832}
{'epoch': 11, 'batch': 56, 'loss': 7.047425746917725}
{'epoch': 11, 'batch': 57, 'loss': 7.001873016357422}
{'epoch': 11, 'batch': 58, 'loss': 6.966463088989258}
{'epoch': 11, 'batch': 59, 'loss': 7.006446838378906}
{'epoch': 11, 'batch': 60, 'loss': 6.929228782653809}
{'epoch': 11, 'batch': 61, 'loss': 7.072596549987793}
{'epoch': 11, 'batch': 62, 'loss': 7.0669169425964355}
{'epoch': 11, 'batch': 63, 'loss': 6.987858295440674}
{'epoch': 11, 'batch': 64, 'loss': 7.090846538543701}
{'epoch': 11, 'batch': 65, 'loss': 7.000576019287109}
{'epoch': 11, 'batch': 66, 'loss': 6.986774444580078}
{'epoch': 11, 'batch': 67, 'loss': 6.825445175170898}
{'epoch': 11, 'batch': 68, 'loss': 7.003419399261475}
{'epoch': 11, 'batch': 69, 'loss': 6.767592430114746}
{'epoch': 11, 'batch': 70, 'loss': 7.176876544952393}
{'epoch': 11, 'batch': 71, 'loss': 7.105376720428467}
{'epoch': 11, 'batch': 72, 'loss': 7.008915901184082}
{'epoch': 11, 'batch': 73, 

{'epoch': 13, 'batch': 21, 'loss': 7.129179000854492}
{'epoch': 13, 'batch': 22, 'loss': 7.082434177398682}
{'epoch': 13, 'batch': 23, 'loss': 7.162782192230225}
{'epoch': 13, 'batch': 24, 'loss': 7.143493175506592}
{'epoch': 13, 'batch': 25, 'loss': 6.885279178619385}
{'epoch': 13, 'batch': 26, 'loss': 6.870786666870117}
{'epoch': 13, 'batch': 27, 'loss': 6.875373363494873}
{'epoch': 13, 'batch': 28, 'loss': 7.244100093841553}
{'epoch': 13, 'batch': 29, 'loss': 7.367822170257568}
{'epoch': 13, 'batch': 30, 'loss': 6.7208333015441895}
{'epoch': 13, 'batch': 31, 'loss': 6.6758131980896}
{'epoch': 13, 'batch': 32, 'loss': 6.833217144012451}
{'epoch': 13, 'batch': 33, 'loss': 6.9972243309021}
{'epoch': 13, 'batch': 34, 'loss': 6.936469078063965}
{'epoch': 13, 'batch': 35, 'loss': 7.2242865562438965}
{'epoch': 13, 'batch': 36, 'loss': 7.106492042541504}
{'epoch': 13, 'batch': 37, 'loss': 6.9168806076049805}
{'epoch': 13, 'batch': 38, 'loss': 7.2291035652160645}
{'epoch': 13, 'batch': 39, '

{'epoch': 14, 'batch': 80, 'loss': 6.929724216461182}
{'epoch': 14, 'batch': 81, 'loss': 7.111954689025879}
{'epoch': 14, 'batch': 82, 'loss': 7.088320255279541}
{'epoch': 14, 'batch': 83, 'loss': 7.140276908874512}
{'epoch': 14, 'batch': 84, 'loss': 6.9647064208984375}
{'epoch': 14, 'batch': 85, 'loss': 7.141602516174316}
{'epoch': 14, 'batch': 86, 'loss': 6.862828731536865}
{'epoch': 14, 'batch': 87, 'loss': 7.001965045928955}
{'epoch': 14, 'batch': 88, 'loss': 6.906075477600098}
{'epoch': 14, 'batch': 89, 'loss': 6.987839221954346}
{'epoch': 14, 'batch': 90, 'loss': 7.363979816436768}
{'epoch': 14, 'batch': 91, 'loss': 6.89520263671875}
{'epoch': 14, 'batch': 92, 'loss': 7.086781978607178}
{'epoch': 14, 'batch': 93, 'loss': 6.539266586303711}


In [5]:
# Sample a continuation of the prompt with the trained model and print the
# resulting word list (prompt words first, generated words after).
prompt = 'Knock knock. Whos there?'
print(predict(dataset, model, text=prompt))

['Knock', 'knock.', 'Whos', 'there?', 'Max', 'moment.)', 'lack-toes', 'ball?', 'wolves', "treefrog's", 'romance', 'sausage...', 'itself?', 'What', 'was', "couldn't", 'is', 'Block', 'wrong', 'says', 'Why', 'uses', 'Because', 'doctor?', 'told', '"is', 'obvious?', 'out', 'a', 'An', 'From:', 'calluses', 'Rex!', 'pumpkin?', 'Washed', 'is', 'flowers', 'know', 'angle', 'born', 'dish', 'the', 'was', 'deer.', 'did', 'top', 'man,', 'calm', 'a', 'the', 'alright', 'Gorgonzola.', 'FLICK', 'is']
