# Text generation

This tutorial covers using LSTMs on PyTorch for generating text; in this case - pretty lame jokes.

In [5]:
import torch
import pandas as pd
from collections import Counter
import argparse
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

## 1. Dataset

We will train a joke text generator using LSTM networks in PyTorch. For this tutorial, we use Reddit clean jokes dataset to train the network. The dataset has 1623 jokes. To load the data into PyTorch, use PyTorch Dataset class.

In [13]:
data_filepath = 'NEU/reddit-cleanjokes.csv'

In [14]:
class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        args,
    ):
        self.args = args
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        train_df = pd.read_csv(data_filepath)
        text = train_df['Joke'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.args['sequence_length']

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.args['sequence_length']]),
            torch.tensor(self.words_indexes[index+1:index+self.args['sequence_length']+1]),
        )

This Dataset inherits from the PyTorch's torch.utils.data.Dataset class and defines two important methods __len__ and __getitem__. 

load_words function loads the dataset. Unique words are calculated in the dataset to define the size of the network's vocabulary and embedding size. index_to_word and word_to_index converts words to number indexes and visa versa.

## 2. Model

In [4]:
class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))

This is a standard looking PyTorch model. Embedding layer converts word indexes to word vectors. LSTM is the main learnable part of the network - PyTorch implementation has the gating mechanism implemented inside the LSTM cell that can learn long sequences of data.

## 3. Training

In [17]:
def train(dataset, model, args):
    model.train()

    dataloader = DataLoader(dataset, batch_size=args['batch_size'])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(args['max_epochs']):
        state_h, state_c = model.init_state(args['sequence_length'])

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1, 2), y)

            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

Use PyTorch DataLoader and Dataset abstractions to load the jokes data.

Use CrossEntropyLoss as a loss function and Adam as an optimizer with default params. You can tweak it later.

## 4. Text generation

In [8]:
def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words

## 5. Execute predictions

In [11]:
args = {'max_epochs': 10, 'batch_size': 256, 'sequence_length': 4}

In [15]:
dataset = Dataset(args)
model = Model(dataset)

In [18]:
train(dataset, model, args)

{'epoch': 0, 'batch': 0, 'loss': 8.828627586364746}
{'epoch': 0, 'batch': 1, 'loss': 8.834393501281738}
{'epoch': 0, 'batch': 2, 'loss': 8.814643859863281}
{'epoch': 0, 'batch': 3, 'loss': 8.814126968383789}
{'epoch': 0, 'batch': 4, 'loss': 8.803616523742676}
{'epoch': 0, 'batch': 5, 'loss': 8.784758567810059}
{'epoch': 0, 'batch': 6, 'loss': 8.785212516784668}
{'epoch': 0, 'batch': 7, 'loss': 8.758509635925293}
{'epoch': 0, 'batch': 8, 'loss': 8.729255676269531}
{'epoch': 0, 'batch': 9, 'loss': 8.665733337402344}
{'epoch': 0, 'batch': 10, 'loss': 8.574596405029297}
{'epoch': 0, 'batch': 11, 'loss': 8.44304370880127}
{'epoch': 0, 'batch': 12, 'loss': 8.270025253295898}
{'epoch': 0, 'batch': 13, 'loss': 8.215472221374512}
{'epoch': 0, 'batch': 14, 'loss': 7.938713550567627}
{'epoch': 0, 'batch': 15, 'loss': 7.895543098449707}
{'epoch': 0, 'batch': 16, 'loss': 7.701065540313721}
{'epoch': 0, 'batch': 17, 'loss': 7.703535079956055}
{'epoch': 0, 'batch': 18, 'loss': 7.59310245513916}
{'epo

The model predicts the next 100 words after Knock knock. Whos there?

In [19]:
print(predict(dataset, model, text='Knock knock. Whos there?'))

['Knock', 'knock.', 'Whos', 'there?', 'offspring', 'sour', 'tomato', 'thyme.', 'Go', 'been', 'said,', 'five', "he's", 'in?', 'hundred', 'No?', 'mine?', 'when', 'my', 'beans', 'look', 'an', 'if', 'You', 'a', 'Affleckted', 'What', 'who', 'you', 'fluffy?', 'To', 'My', 'Ton', 'over?', 'in', 'the', 'chicken', 'lawnmower', 'I', 'not', 'all', 'swallow', 'take', "what's", 'if', 'through', 'just', 'a', 'fish', 'out', 'so', 'open', 'oh', 'word', 'get', 'a', 'drum', 'and', 'travel', 'My', 'heavy', 'absolutely', 'memory', 'of', 'pillowcase?', 'veloci-raptured.', 'India', '..', 'I', 'saw', 'your', 'dog', 'A', 'stuck', 'heard', 'How', 'is', 'the', 'round', 'zero', 'have', 'get', 'my', 'Italian.', 'Elephino', 'on', 'up', 'it', 'into', 'and', 'Saus', 'on', "'React'", 'Moroccan', 'Porcel', 'line.', 'him', 'harm', 'another', 'an', 'today...', 'bi-polar', 'today...', 'but']
