In [8]:
import argparse

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

from torch.nn import functional as F

import requests

import numpy as np
import pandas as pd

import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

from collections import Counter


In [9]:
# Pick the compute device: the GPU when PyTorch can see one, the CPU otherwise.
# NOTE(review): no visible cell uses `device` — the model and batches stay on the CPU
# unless something moves them with `.to(device)`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [10]:
class Model(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> per-word vocabulary logits.

    NOTE(review): ``nn.LSTM`` is built without ``batch_first=True``, so it reads dim 0 of
    its input as time and dim 1 as batch. The ``train``/``predict`` helpers in this
    notebook feed ``(rows, sequence_length)`` index tensors and size the hidden state with
    ``init_state(sequence_length)``, i.e. the DataLoader row axis is consumed as time.
    That is self-consistent but probably not what was intended; switching to
    ``batch_first=True`` must be done together with changing the callers' ``init_state``
    argument to the batch size.

    Args:
        dataset: any object with a ``uniq_words`` sequence (the vocabulary); only its
            length is used, to size the embedding table and the output layer.
        lstm_size: hidden size of every LSTM layer.
        embedding_dim: size of the word embeddings fed into the first LSTM layer.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers; only applied when
            ``num_layers > 1`` (PyTorch warns and ignores it otherwise).
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embeddings, so its input width is the embedding size
            # (previously wired to lstm_size, which only worked because both were 128).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout if self.num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Run one forward pass.

        Args:
            x: integer tensor of word indices (2-D as used in this notebook).
            prev_state: ``(h, c)`` tuple, e.g. from :meth:`init_state`.

        Returns:
            ``(logits, state)`` where ``logits`` has the shape of ``x`` plus a trailing
            vocabulary dimension, and ``state`` is the LSTM's new ``(h, c)``.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length, device=None):
        """Return zero ``(h, c)`` tensors of shape ``(num_layers, sequence_length, lstm_size)``.

        ``sequence_length`` fills the LSTM's batch slot of the hidden state (dim 1); see
        the class note on why the callers pass the sequence length here. ``device``
        optionally places the tensors (default: PyTorch's default device).
        """
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))

In [11]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a CSV column of texts.

    All texts in ``text_column`` are joined with single spaces and split on ``' '``
    (so runs of spaces yield empty-string tokens). Item ``i`` is the pair
    ``(words[i : i+L], words[i+1 : i+L+1])`` as index tensors, with ``L`` equal to
    ``sequence_length``.

    Args:
        sequence_length: window length ``L``; must be >= 1.
        csv_path: CSV file to read (defaults to the original hard-coded path).
        text_column: column holding the texts.

    Raises:
        ValueError: if ``sequence_length < 1``.
    """

    def __init__(
        self,
        sequence_length,
        csv_path='data/reddit-cleanjokes.csv',
        text_column='Joke',
    ):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.sequence_length = sequence_length
        self.csv_path = csv_path
        self.text_column = text_column

        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Read ``csv_path`` and return all texts in ``text_column`` as one token list."""
        train_df = pd.read_csv(self.csv_path)
        text = train_df[self.text_column].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Vocabulary ordered by descending frequency; ties keep first-occurrence order."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Clamp at 0: a corpus shorter than the window has no complete (x, y) pair,
        # and a negative __len__ makes len() raise.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        n_items = len(self)
        if index < 0:
            index += n_items
        # Out-of-range indices would otherwise silently return a truncated target window.
        if not 0 <= index < n_items:
            raise IndexError(f"index out of range for dataset of length {n_items}")
        start, stop = index, index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[start:stop]),
            torch.tensor(self.words_indexes[start + 1:stop + 1]),
        )

In [12]:
def train(dataset, model, sequence_length, batch_size, max_epochs, lr=0.001, log_every=1):
    """Train ``model`` with Adam on next-word cross-entropy.

    Args:
        dataset: map-style dataset yielding ``(x, y)`` index-tensor pairs.
        model: module with ``forward(x, (h, c)) -> (logits, (h, c))`` and
            ``init_state(n) -> (h, c)``.
        sequence_length: passed to ``model.init_state`` at the start of each epoch.
        batch_size: DataLoader batch size (no shuffling, last batch may be smaller).
        max_epochs: number of passes over ``dataset``.
        lr: Adam learning rate.
        log_every: print a progress dict every ``log_every`` batches; 0 disables it.

    Returns:
        list[float]: mean training loss of each epoch (NaN for an epoch with no batches).
    """
    model.train()
    # Follow the model: batches and hidden state go wherever its parameters live.
    device = next(model.parameters()).device

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    epoch_losses = []
    for epoch in range(max_epochs):
        state_h, state_c = (s.to(device) for s in model.init_state(sequence_length))
        loss_sum, n_batches = 0.0, 0

        for batch, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss wants class scores on dim 1, so move the vocab axis there.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Keep the hidden state for the next batch but cut the autograd graph,
            # so backprop never reaches into earlier batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            loss_sum += batch_loss
            n_batches += 1
            if log_every and batch % log_every == 0:
                print({'epoch': epoch, 'batch': batch, 'loss': batch_loss})

        epoch_losses.append(loss_sum / n_batches if n_batches else float('nan'))

    return epoch_losses

In [13]:
def predict(dataset, model, text, next_words=100):
    """Sample ``next_words`` words to follow the prompt ``text``.

    The prompt is split on single spaces. Each step feeds the most recent
    ``len(prompt)`` words to the model and samples the next word from the softmax of
    the logits at the last position.

    Args:
        dataset: object with ``word_to_index`` and ``index_to_word`` mappings.
        model: module with ``forward(x, (h, c))`` and ``init_state(n)``.
        text: prompt; every space-separated token must be in the vocabulary.
        next_words: number of words to sample.

    Returns:
        list[str]: the prompt words followed by the sampled words.

    Raises:
        KeyError: listing the prompt words missing from ``dataset.word_to_index``.
    """
    model.eval()
    device = next(model.parameters()).device

    words = text.split(' ')
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f"prompt words not in vocabulary: {unknown}")

    window = len(words)
    state_h, state_c = (s.to(device) for s in model.init_state(window))

    # Inference only: no autograd graph is needed for sampling.
    with torch.no_grad():
        for _ in range(next_words):
            # Fixed-size sliding window over the most recent `window` words.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[-window:]]],
                             device=device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            # float64 + renormalisation so np.random.choice's "p must sum to 1" check
            # cannot trip on float32 rounding over a large vocabulary.
            p = torch.softmax(last_word_logits.double(), dim=0).cpu().numpy()
            p /= p.sum()
            word_index = np.random.choice(len(p), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [14]:
# Training configuration.
max_epochs = 10
batch_size = 256
sequence_length = 4
SEED = 42

# Seed every RNG in play (weight init and dropout via torch, word sampling in
# `predict` via np.random) so Restart & Run All reproduces the same run.
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = Dataset(sequence_length)
model = Model(dataset)

train(dataset, model, sequence_length, batch_size, max_epochs)
print(predict(dataset, model, text='Knock knock. Whos there?'))

{'epoch': 0, 'batch': 0, 'loss': 8.841975212097168}
{'epoch': 0, 'batch': 1, 'loss': 8.83491039276123}
{'epoch': 0, 'batch': 2, 'loss': 8.825891494750977}
{'epoch': 0, 'batch': 3, 'loss': 8.818588256835938}
{'epoch': 0, 'batch': 4, 'loss': 8.80907154083252}
{'epoch': 0, 'batch': 5, 'loss': 8.79863452911377}
{'epoch': 0, 'batch': 6, 'loss': 8.798529624938965}
{'epoch': 0, 'batch': 7, 'loss': 8.776673316955566}
{'epoch': 0, 'batch': 8, 'loss': 8.75554370880127}
{'epoch': 0, 'batch': 9, 'loss': 8.684842109680176}
{'epoch': 0, 'batch': 10, 'loss': 8.622735977172852}
{'epoch': 0, 'batch': 11, 'loss': 8.473854064941406}
{'epoch': 0, 'batch': 12, 'loss': 8.343722343444824}
{'epoch': 0, 'batch': 13, 'loss': 8.260093688964844}
{'epoch': 0, 'batch': 14, 'loss': 7.991740703582764}
{'epoch': 0, 'batch': 15, 'loss': 7.9348063468933105}
{'epoch': 0, 'batch': 16, 'loss': 7.761424541473389}
{'epoch': 0, 'batch': 17, 'loss': 7.699403285980225}
{'epoch': 0, 'batch': 18, 'loss': 7.582327365875244}
{'epoc

In [15]:
# Draw another stochastic continuation of the knock-knock prompt (100 words).
print(predict(dataset, model, text='Knock knock. Whos there?', next_words=100))

['Knock', 'knock.', 'Whos', 'there?', 'Bulls', 'snatchers?', '/r/tumblr', 'a', "Smith's", 'Just', "ain't", 'run', 'getting', 'lure', 'and', 'men', 'then', 'out.', 'meet', 'sharp,', 'age', 'so', 'boots?', 'photos?', 'then', 'My', '6:30', 'A', 'Loki.', 'What', 'do', 'call', 'the', '"Don\'t', "I'm", 'Likes', 'years...', 'Two', 'every', 'planed', 'chickens', 'their', "I've", 'He', 'getting', 'made', 'it', 'his', 'Jamaican', 'then', 'got', 'Necronomnomnomicon.', 'to', 'the', 'psychologist?', 'They', "I'm", 'slice', 'The', 'here!"', 'stringed', 'A', 'Engagement,', 'in', 'the', 'rocks', '#ThugLife', 'of', 'be', 'sure', 'of', 'you', 'documentary', 'What', 'the', 'casket', 'and', 'down', 'a', "'th'", 'This', 'surrounded', 'What', 'do', 'you', 'call', 'a', 'though,', 'that', 'table.', 'not', '"Stay', 'second', "here's", 'no', 'woman.', 'Velcro', 'say', 'to', 'the', 'Last', 'Platypus', 'walking', 'Well,']


In [16]:
# Sample a 100-word continuation for a second prompt.
print(predict(dataset, model, text='Is there anyone here?', next_words=100))

['Is', 'there', 'anyone', 'here?', 'paw."', 'earing', 'could', 'doing', 'Will', 'We', 'What', 'to', 'hear', 'a', 'you', 'and', '"breakfast', 'to', '7?', 'they', 'be', 'ocean', 'with', 'know!', 'there?', 'terminator', 'Autumn', 'pasta', 'nose', 'but', 'against', 'when', '(not', 'was', 'Knock...', 'favorite', 'blood.', 'Coming', 'decide', 'spud?', 'in', 'he', 'his', 'mis-steak.', 'through', 'me', 'always', "don't", 'a', 'tomato,', 'from', 'An', 'Scholar', 'drivers', 'Because', 'I', "Lee's", 'hit', 'ended', 'of', 'her', 'turkey', '"It\'s', 'nose...', 'What', 'did', 'will', 'a', 'lives', 'BAAAACH', 'She', 'down', 'master', 'operates', 'he', 'a', 'dome?', 'nose', 'to', 'a', 'Satin!', 'idea)', 'it', 'have', 'are', 'brown', 'organ', 'worth', 'But', 'I', 'he', 'I', 'the', '2"X4"\'s', 'bulb?', 'and', 'a', '(I', 'of', 'presents...', 'who', 'Which', 'somewhere.', 'When']
