In [17]:
import torch
from torch import nn
import pandas as pd

class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear layer over the vocabulary.

    Args:
        dataset: object exposing ``uniq_words``; only ``len(dataset.uniq_words)``
            is used, as the vocabulary size.
        lstm_size: hidden size of every LSTM layer (default 128).
        embedding_dim: size of the word embeddings (default 128).
        num_layers: number of stacked LSTM layers (default 3).
        dropout: dropout probability passed to ``nn.LSTM`` (default 0.2).

    NOTE(review): the LSTM is created without ``batch_first``, so per the
    PyTorch documentation it reads its input as ``(seq, batch, feature)`` and
    its state as ``(num_layers, batch, hidden)``. ``train`` in this file feeds
    index tensors coming straight from a ``DataLoader``; if those are
    ``(batch, seq)``, the LSTM iterates over the DataLoader batch dimension.
    This is left unchanged here because fixing it requires changing the
    callers at the same time.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embedding vectors, so its input width must be
            # the embedding size, not the hidden size.
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: integer tensor of word indexes.
            prev_state: ``(h, c)`` tuple as produced by ``init_state`` or by a
                previous call to ``forward``.

        Returns:
            ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a
            trailing vocabulary dimension and ``state`` is the new ``(h, c)``.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` state of shape ``(num_layers, sequence_length, lstm_size)``.

        The tensors are created on the same device and with the same dtype as
        the model's parameters so the state can be fed to ``forward`` directly.
        """
        reference = self.fc.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, dtype=reference.dtype, device=reference.device),
                torch.zeros(shape, dtype=reference.dtype, device=reference.device))

In [18]:
import torch
import pandas as pd
from collections import Counter

class Dataset(torch.utils.data.Dataset):
    """Sliding-window dataset of word indexes built from the jokes CSV.

    Each item is a pair ``(x, y)`` of index tensors where ``y`` is ``x``
    shifted one word to the right, i.e. the next-word targets for ``x``.

    Args:
        args: object with a ``sequence_length`` attribute (the window size).
    """

    def __init__(
        self,
        args,
    ):
        self.args = args
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Vocabulary lookups in both directions; index 0 is the most frequent word.
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        # The whole corpus encoded as a flat list of vocabulary indexes.
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read the ``Joke`` column of the CSV and return the corpus as a list of tokens.

        Jokes are joined with a single space and split on a single space, so
        consecutive spaces in the text produce empty-string tokens.
        """
        train_df = pd.read_csv('data/reddit-cleanjokes.csv')
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words of the corpus, most frequent first."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of full windows that also have a complete shifted target.

        Clamped at zero: a corpus shorter than ``sequence_length`` would
        otherwise yield a negative length, which makes ``len()`` raise.
        """
        return max(0, len(self.words_indexes) - self.args.sequence_length)

    def __getitem__(self, index):
        """Return ``(input_window, target_window)`` starting at ``index``."""
        window = self.args.sequence_length
        return (
            torch.tensor(self.words_indexes[index:index + window]),
            torch.tensor(self.words_indexes[index + 1:index + window + 1]),
        )

In [19]:
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from model import Model
from dataset import Dataset

def train(dataset, model, args):
    """Fit ``model`` on ``dataset`` with Adam (lr=0.001) and cross-entropy loss.

    Args:
        dataset: map-style dataset yielding ``(x, y)`` index tensors.
        model: module providing ``init_state(n)`` and
            ``forward(x, state) -> (logits, (h, c))``.
        args: object with ``batch_size``, ``max_epochs`` and ``sequence_length``.

    The recurrent state is reset at the start of every epoch and carried from
    batch to batch within an epoch, detached so gradients do not flow across
    batches. A dict with the epoch, batch number and loss is printed per batch.
    """
    model.train()

    loader = DataLoader(dataset, batch_size=args.batch_size)
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(args.max_epochs):
        state = model.init_state(args.sequence_length)

        for batch, (x, y) in enumerate(loader):
            adam.zero_grad()

            y_pred, (hidden, cell) = model(x, state)
            # CrossEntropyLoss wants the class dimension in position 1.
            loss = loss_fn(y_pred.transpose(1, 2), y)

            loss.backward()
            adam.step()

            # Cut the graph so the next batch does not backpropagate into this one.
            state = (hidden.detach(), cell.detach())

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

In [20]:
def predict(dataset, model, text, next_words=100):
    """Generate ``next_words`` words that continue ``text`` by sampling from ``model``.

    Args:
        dataset: object with ``word_to_index`` and ``index_to_word`` mappings.
        model: module providing ``init_state(n)`` and
            ``forward(x, state) -> (logits, (h, c))``.
        text: prompt; it is split on single spaces and every resulting token
            must be present in ``dataset.word_to_index``.
        next_words: number of words to append (default 100).

    Returns:
        The prompt tokens followed by the generated words, as a list of str.

    Raises:
        KeyError: if the prompt contains a token that is not in the vocabulary.
    """
    model.eval()

    words = text.split(' ')
    # Fail early with a readable message instead of a bare KeyError mid-loop.
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f"prompt contains words missing from the vocabulary: {unknown}")

    state_h, state_c = model.init_state(len(words))

    # Inference only: without no_grad the carried state would keep the autograd
    # graph of every previous step alive for the whole generation.
    with torch.no_grad():
        for i in range(0, next_words):
            # words grows by one per step, so words[i:] is always the most
            # recent len(prompt) tokens.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
            # Sample from the predicted distribution rather than taking the argmax.
            word_index = int(np.random.choice(len(last_word_logits), p=p))
            words.append(dataset.index_to_word[word_index])

    return words

In [21]:
# Hyperparameters are exposed as command-line flags; argparse stores them as
# args.max_epochs, args.batch_size and args.sequence_length.
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)
# parse_known_args() is used instead of parse_args() so that unrecognised
# arguments are returned in `unknown` instead of aborting the program
# (presumably the extra arguments a notebook kernel is started with — TODO confirm).
args, unknown = parser.parse_known_args()
# NOTE(review): an earlier cell runs `from model import Model` and
# `from dataset import Dataset`, which rebinds these names to the module
# versions rather than the classes defined in the notebook — confirm which is intended.
dataset = Dataset(args)
model = Model(dataset)

# Train with the parsed settings, then sample a continuation of the prompt.
# predict() looks every prompt token up in dataset.word_to_index, so each
# space-separated token of the prompt must occur in the corpus.
train(dataset, model, args)
print(predict(dataset, model, text='Knock knock. Whos there?'))

{'epoch': 0, 'batch': 0, 'loss': 8.849908828735352}
{'epoch': 0, 'batch': 1, 'loss': 8.842696189880371}
{'epoch': 0, 'batch': 2, 'loss': 8.843104362487793}
{'epoch': 0, 'batch': 3, 'loss': 8.825246810913086}
{'epoch': 0, 'batch': 4, 'loss': 8.824419021606445}
{'epoch': 0, 'batch': 5, 'loss': 8.812392234802246}
{'epoch': 0, 'batch': 6, 'loss': 8.809162139892578}


KeyboardInterrupt: 

In [35]:
#print(predict(dataset, model, text='Knock knock. Whos there?'))

['Knock', 'knock.', 'Whos', 'there?', 'Barely', 'instead', 'about', '"fire"', 'beef...', 'Even', 'can', 'seen', 'people', 'in', 'the', 'tree', 'one', 'Well,', 'the', 'ball', 'Yeah,', 'he', 'would', 'interest.', 'as', 'one', 'something', 'where', 'they', 'are', 'those', 'but', "he's", 'playground?', 'of', 'Bagpiper,', 'need', 'any', 'serious.', 'and', 'my', 'cow.', "I'm", 'sharks', 'grammar...', 'but', 'all', 'like', '2', 'you', 'cereal...', 'not', 'he', 'go.', 'you!', "I'm", 'idea)', 'like', 'heard', 'to', 'drum', 'at', 'car', 'at', 'his', 'big', 'friend?', 'bear.*', 'dd/mm/yy', 'What', 'do', 'you', 'call', 'when', 'the', 'heck', 'for', 'carrot.', 'What', 'kind', 'of', 'write?', 'lamp', "doesn't", 'PARSLEYMONIOUS', 'Are', 'you', 'buy', 'back', 'to', 'pop', 'Eggs', 'Now', 'done', 'on', 'be', 'shoes', 'over', 'ate', 'emergency!', 'Where', 'do', 'Japan,', 'criticize']


In [13]:
def load_words(path='data/reddit-cleanjokes.csv'):
    """Load the jokes CSV and return its text as a flat list of tokens.

    Args:
        path: location of a CSV file with a ``Joke`` column. Defaults to the
            path the notebook originally hard-coded.

    Returns:
        list of str: all jokes joined with single spaces and then split on
        single spaces (consecutive spaces therefore yield empty-string tokens).
    """
    train_df = pd.read_csv(path)
    text = train_df['Joke'].str.cat(sep=' ')
    return text.split(' ')

In [33]:
#load_words()

In [34]:
#dir(dataset)
#dataset.words