In [1]:
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))

In [10]:
import pandas as pd
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.sequence_length = 4
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        train_df = pd.read_csv('data/reddit-cleanjokes.csv')
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )

In [16]:
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader


def train(dataset, model):
    model.train()

    dataloader = DataLoader(dataset, batch_size=256)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        state_h, state_c = model.init_state(4)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1, 2), y)

            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [8]:
def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words

In [18]:
dataset = Dataset()
model = Model(dataset)



train(dataset, model)
res = predict(dataset, model, text='Knock knock. Whos there?')

{'epoch': 0, 'batch': 0, 'loss': 8.833588600158691}
{'epoch': 0, 'batch': 1, 'loss': 8.828499794006348}
{'epoch': 0, 'batch': 2, 'loss': 8.820976257324219}
{'epoch': 0, 'batch': 3, 'loss': 8.823351860046387}
{'epoch': 0, 'batch': 4, 'loss': 8.805675506591797}
{'epoch': 0, 'batch': 5, 'loss': 8.797994613647461}
{'epoch': 0, 'batch': 6, 'loss': 8.803772926330566}
{'epoch': 0, 'batch': 7, 'loss': 8.777909278869629}
{'epoch': 0, 'batch': 8, 'loss': 8.755915641784668}
{'epoch': 0, 'batch': 9, 'loss': 8.698698997497559}
{'epoch': 0, 'batch': 10, 'loss': 8.66344165802002}
{'epoch': 0, 'batch': 11, 'loss': 8.538965225219727}
{'epoch': 0, 'batch': 12, 'loss': 8.469428062438965}
{'epoch': 0, 'batch': 13, 'loss': 8.385149002075195}
{'epoch': 0, 'batch': 14, 'loss': 8.11269474029541}
{'epoch': 0, 'batch': 15, 'loss': 8.022530555725098}
{'epoch': 0, 'batch': 16, 'loss': 7.809468746185303}
{'epoch': 0, 'batch': 17, 'loss': 7.817915439605713}
{'epoch': 0, 'batch': 18, 'loss': 7.666975975036621}
{'epo

{'epoch': 1, 'batch': 61, 'loss': 7.20242977142334}
{'epoch': 1, 'batch': 62, 'loss': 7.1563801765441895}
{'epoch': 1, 'batch': 63, 'loss': 7.08748197555542}
{'epoch': 1, 'batch': 64, 'loss': 7.194957733154297}
{'epoch': 1, 'batch': 65, 'loss': 7.119697093963623}
{'epoch': 1, 'batch': 66, 'loss': 7.115661144256592}
{'epoch': 1, 'batch': 67, 'loss': 6.939568519592285}
{'epoch': 1, 'batch': 68, 'loss': 7.137152194976807}
{'epoch': 1, 'batch': 69, 'loss': 6.87841272354126}
{'epoch': 1, 'batch': 70, 'loss': 7.302093982696533}
{'epoch': 1, 'batch': 71, 'loss': 7.2535834312438965}
{'epoch': 1, 'batch': 72, 'loss': 7.158138275146484}
{'epoch': 1, 'batch': 73, 'loss': 7.221046447753906}
{'epoch': 1, 'batch': 74, 'loss': 7.234049320220947}
{'epoch': 1, 'batch': 75, 'loss': 7.367117881774902}
{'epoch': 1, 'batch': 76, 'loss': 7.136973857879639}
{'epoch': 1, 'batch': 77, 'loss': 7.392655372619629}
{'epoch': 1, 'batch': 78, 'loss': 7.490032196044922}
{'epoch': 1, 'batch': 79, 'loss': 6.82153701782

{'epoch': 3, 'batch': 28, 'loss': 7.102365493774414}
{'epoch': 3, 'batch': 29, 'loss': 7.195680141448975}
{'epoch': 3, 'batch': 30, 'loss': 6.551056385040283}
{'epoch': 3, 'batch': 31, 'loss': 6.4758124351501465}
{'epoch': 3, 'batch': 32, 'loss': 6.565923690795898}
{'epoch': 3, 'batch': 33, 'loss': 6.820531368255615}
{'epoch': 3, 'batch': 34, 'loss': 6.779815196990967}
{'epoch': 3, 'batch': 35, 'loss': 6.977634906768799}
{'epoch': 3, 'batch': 36, 'loss': 6.93181848526001}
{'epoch': 3, 'batch': 37, 'loss': 6.741453170776367}
{'epoch': 3, 'batch': 38, 'loss': 7.057920455932617}
{'epoch': 3, 'batch': 39, 'loss': 6.881134986877441}
{'epoch': 3, 'batch': 40, 'loss': 7.05852746963501}
{'epoch': 3, 'batch': 41, 'loss': 6.757557392120361}
{'epoch': 3, 'batch': 42, 'loss': 7.058585166931152}
{'epoch': 3, 'batch': 43, 'loss': 6.764370441436768}
{'epoch': 3, 'batch': 44, 'loss': 6.701319217681885}
{'epoch': 3, 'batch': 45, 'loss': 6.795321464538574}
{'epoch': 3, 'batch': 46, 'loss': 6.94701766967

{'epoch': 4, 'batch': 89, 'loss': 6.647129535675049}
{'epoch': 4, 'batch': 90, 'loss': 7.103519916534424}
{'epoch': 4, 'batch': 91, 'loss': 6.474244117736816}
{'epoch': 4, 'batch': 92, 'loss': 6.796535015106201}
{'epoch': 4, 'batch': 93, 'loss': 6.105434417724609}
{'epoch': 5, 'batch': 0, 'loss': 6.654656410217285}
{'epoch': 5, 'batch': 1, 'loss': 6.541390419006348}
{'epoch': 5, 'batch': 2, 'loss': 6.56959342956543}
{'epoch': 5, 'batch': 3, 'loss': 6.781586647033691}
{'epoch': 5, 'batch': 4, 'loss': 6.649386882781982}
{'epoch': 5, 'batch': 5, 'loss': 6.6583027839660645}
{'epoch': 5, 'batch': 6, 'loss': 7.186623573303223}
{'epoch': 5, 'batch': 7, 'loss': 6.942822456359863}
{'epoch': 5, 'batch': 8, 'loss': 6.830474376678467}
{'epoch': 5, 'batch': 9, 'loss': 6.813939094543457}
{'epoch': 5, 'batch': 10, 'loss': 6.799866199493408}
{'epoch': 5, 'batch': 11, 'loss': 6.636995315551758}
{'epoch': 5, 'batch': 12, 'loss': 6.788219928741455}
{'epoch': 5, 'batch': 13, 'loss': 6.896061897277832}
{'e

{'epoch': 6, 'batch': 56, 'loss': 6.358236789703369}
{'epoch': 6, 'batch': 57, 'loss': 6.299840450286865}
{'epoch': 6, 'batch': 58, 'loss': 6.204270839691162}
{'epoch': 6, 'batch': 59, 'loss': 6.361382007598877}
{'epoch': 6, 'batch': 60, 'loss': 6.249892711639404}
{'epoch': 6, 'batch': 61, 'loss': 6.387689590454102}
{'epoch': 6, 'batch': 62, 'loss': 6.434951305389404}
{'epoch': 6, 'batch': 63, 'loss': 6.3051042556762695}
{'epoch': 6, 'batch': 64, 'loss': 6.322754383087158}
{'epoch': 6, 'batch': 65, 'loss': 6.351102828979492}
{'epoch': 6, 'batch': 66, 'loss': 6.3917155265808105}
{'epoch': 6, 'batch': 67, 'loss': 6.019638538360596}
{'epoch': 6, 'batch': 68, 'loss': 6.322899341583252}
{'epoch': 6, 'batch': 69, 'loss': 5.997907638549805}
{'epoch': 6, 'batch': 70, 'loss': 6.598304748535156}
{'epoch': 6, 'batch': 71, 'loss': 6.372706413269043}
{'epoch': 6, 'batch': 72, 'loss': 6.315666675567627}
{'epoch': 6, 'batch': 73, 'loss': 6.33037805557251}
{'epoch': 6, 'batch': 74, 'loss': 6.403052806

{'epoch': 8, 'batch': 24, 'loss': 6.301273345947266}
{'epoch': 8, 'batch': 25, 'loss': 6.108038902282715}
{'epoch': 8, 'batch': 26, 'loss': 5.858868598937988}
{'epoch': 8, 'batch': 27, 'loss': 5.933180809020996}
{'epoch': 8, 'batch': 28, 'loss': 6.423654079437256}
{'epoch': 8, 'batch': 29, 'loss': 6.509706497192383}
{'epoch': 8, 'batch': 30, 'loss': 5.773222923278809}
{'epoch': 8, 'batch': 31, 'loss': 5.742982387542725}
{'epoch': 8, 'batch': 32, 'loss': 5.883028507232666}
{'epoch': 8, 'batch': 33, 'loss': 6.184732437133789}
{'epoch': 8, 'batch': 34, 'loss': 6.052111625671387}
{'epoch': 8, 'batch': 35, 'loss': 6.191718101501465}
{'epoch': 8, 'batch': 36, 'loss': 6.113375186920166}
{'epoch': 8, 'batch': 37, 'loss': 6.06368350982666}
{'epoch': 8, 'batch': 38, 'loss': 6.361328125}
{'epoch': 8, 'batch': 39, 'loss': 6.133939743041992}
{'epoch': 8, 'batch': 40, 'loss': 6.284837245941162}
{'epoch': 8, 'batch': 41, 'loss': 5.997401714324951}
{'epoch': 8, 'batch': 42, 'loss': 6.413608074188232}


{'epoch': 9, 'batch': 85, 'loss': 6.061419486999512}
{'epoch': 9, 'batch': 86, 'loss': 5.717243194580078}
{'epoch': 9, 'batch': 87, 'loss': 5.938302516937256}
{'epoch': 9, 'batch': 88, 'loss': 5.784979820251465}
{'epoch': 9, 'batch': 89, 'loss': 5.863468647003174}
{'epoch': 9, 'batch': 90, 'loss': 6.383458614349365}
{'epoch': 9, 'batch': 91, 'loss': 5.692686557769775}
{'epoch': 9, 'batch': 92, 'loss': 5.970125675201416}
{'epoch': 9, 'batch': 93, 'loss': 5.2801194190979}


In [19]:
print(res)

['Knock', 'knock.', 'Whos', 'there?', 'iSurfer', 'add', 'star', 'Twisted', 'needs', 'to', 'too', 'chrome', 'say', 'a', 'bottom', 'will', 'his', 'behind', 'Moleskine', 'Asian?', 'fact,', 'can', 'the', 'interest', 'I', 'a', 'taser...', 'horn?"', 'was', 'itself?', 'say', 'I', 'hear', 'the', 'was', 'ate', 'I', 'said', "'Jallikatu", 'yesterday.', 'promoted?', 'embarrassed?', 'ballpark', 'knock', 'to', 'God', 'A', 'loss', 'I', 'door', 'their', 'flows', 'Like', 'told', 'in', 'which', 'going', 'of', 'Diesel?', 'a', 'Mom', 'cottage"', "I've", 'called?', 'knock!', 'teacher', 'once', 'Because', 'Did', 'for', '9', 'Planet', "how's", 'All', 'headed', '"Stay', 'Why', 'does', 'the', 'Wawa', 'The', 'rules', 'off.', 'one', 'Another', 'but', 'What', 'do', 'you', 'call', 'did', 'a', 'Birthday', 'I', 'an', 'students', 'at', 'a', 'illegal?', 'eye?', 'work', 'routine?', 'Why', 'to']
