# Word level text generation
### Imports, utils, and classes

In [1]:
import os
import re
import time
import requests
import shutil
import zipfile
from collections import Counter
from glob import glob
from itertools import chain

import joblib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from textblob import TextBlob
from torch.utils.data import Dataset, DataLoader

def get_embeddings(token_lists, embedding_file_path,
                   min_vocab_size=0, max_vocab_size=None):
    """
    Build embedding mappings based on `token_lists`.

    Four meta-tokens are reserved at the start of the id space:
    `<PAD>` (0), `<EOS>` (1), `<START>` (2) and `<UNK>` (3). Each of
    their embedding rows is filled with the constant value of its id.
    All other tokens get ids in sorted order. A token without a
    pretrained vector keeps a random (uniform in [0, 1)) row.

    Args:
        token_lists: A sequence of token sequences to fit the embeddings
            on.
        embedding_file_path: The full path to a pretrained embeddings
            file in whitespace-separated text format (`word v1 v2 ...`,
            one word per line, as in GloVe). The embedding size is
            inferred from the file.
        min_vocab_size: The words on the first `min_vocab_size` lines of
            the embeddings file are always included, even if they don't
            appear in `token_lists`. Defaults to `0`.
        max_vocab_size: Maximum number of the most common words from the
            data set to include. Defaults to `None` (all of them).

    Returns:
        token_to_vec: A `dict` mapping word tokens to the numpy vectors
            loaded from the embeddings file.
        token_to_id: A `dict` mapping word tokens (and meta-tokens) to
            embedding ids.
        embedding_matrix: A numpy `ndarray` of shape
            `(len(vocab) + 4, embedding_size)` that maps embedding ids to
            embedding vectors.

    """
    vocab_counter = Counter(chain.from_iterable(token_lists))
    most_common_counts = vocab_counter.most_common(max_vocab_size)
    data_vocab = {token for token, _ in most_common_counts}

    token_to_vec = {}
    # Embedding files such as GloVe are UTF-8; don't rely on the
    # platform's default encoding.
    with open(embedding_file_path, encoding='utf-8') as f:
        for line_num, line in enumerate(f):
            values = line.split()  # Splits on any whitespace.
            if not values:
                continue  # Tolerate blank lines (e.g. a trailing one).
            word = values[0]
            if min_vocab_size < line_num + 1 and word not in data_vocab:
                continue
            token_to_vec[word] = np.asarray(values[1:], dtype='float32')

    # Infer the vector width from the file rather than assuming 200. Fall
    # back to 200 when no vectors were loaded, matching the GloVe file
    # this notebook uses.
    if token_to_vec:
        embedding_size = len(next(iter(token_to_vec.values())))
    else:
        embedding_size = 200

    vocab = data_vocab | set(token_to_vec.keys())
    rand_state = np.random.RandomState(42)
    token_to_id = {'<PAD>': 0, '<EOS>': 1, '<START>': 2, '<UNK>': 3}
    num_meta_tokens = len(token_to_id)
    embedding_matrix = rand_state.rand(len(vocab) + num_meta_tokens,
                                       embedding_size)
    # Give each meta-token a constant row equal to its id, so <PAD> is
    # the zero vector.
    for i in range(num_meta_tokens):
        embedding_matrix[i] = i
    vocab = sorted(vocab)  # Sort for consistent ids.
    for i, token in enumerate(vocab):
        word_id = i + num_meta_tokens
        if token in token_to_vec:
            embedding_matrix[word_id] = token_to_vec[token]
        token_to_id[token] = word_id

    return token_to_vec, token_to_id, embedding_matrix


def tokens_to_ids(token_lists, token_to_id, max_sequence_length=None,
                  add_eos=False, skip_unknown=False):
    """
    Convert token sequences into a zero-padded matrix of embedding ids.

    Args:
        token_lists: A sequence of token sequences.
        token_to_id: A `dict` map between word tokens and embedding ids.
            Must contain `'<UNK>'` unless `skip_unknown` is `True`, and
            `'<EOS>'` if `add_eos` is `True`.
        max_sequence_length: The maximum number of ids kept per row.
            Longer rows are truncated. Shorter rows are padded with `0`
            (the `<PAD>` id assigned by `get_embeddings`). Defaults to
            `None`, meaning the length of the longest row.
        add_eos: Whether or not to append an `<EOS>` id to each row.
            Defaults to `False`.
        skip_unknown: Whether or not to drop tokens missing from
            `token_to_id`. If `False`, unknown tokens are assigned the
            `<UNK>` meta-token. Defaults to `False`.

    Returns:
        id_matrix: An integer array with shape
            `(len(token_lists), max_sequence_length)` containing the
            embedding ids.

    """
    id_lists = []
    for token_list in token_lists:
        token_ids = []
        for token in token_list:
            if token in token_to_id:
                token_ids.append(token_to_id[token])
            elif not skip_unknown:
                # Unknown tokens become <UNK> unless asked to drop them.
                token_ids.append(token_to_id['<UNK>'])
        if add_eos:
            token_ids.append(token_to_id['<EOS>'])
        id_lists.append(token_ids)

    if max_sequence_length is None:
        # `default=0` keeps an empty `token_lists` from raising.
        max_sequence_length = max((len(ids) for ids in id_lists), default=0)

    id_matrix = np.zeros((len(id_lists), max_sequence_length), dtype=int)
    for row, ids in enumerate(id_lists):
        ids = ids[:max_sequence_length]
        id_matrix[row, :len(ids)] = ids

    return id_matrix


def get_train_val_test_indices(num_rows, val_ratio=0.25, test_ratio=0.25, seed=42):
    """
    Randomly partition `range(num_rows)` into train/val/test indices.

    The validation and test sizes are `int(num_rows * ratio)` (rounded
    down). Every remaining index goes to the training set. The shuffle is
    seeded, so the same arguments always produce the same split.

    Returns:
        A `(train_indices, val_indices, test_indices)` tuple of numpy
        integer arrays.

    """
    val_end = int(num_rows * val_ratio)
    test_end = val_end + int(num_rows * test_ratio)
    shuffled = np.random.RandomState(seed).permutation(range(num_rows))
    # np.split cuts at the given offsets: [:val_end], [val_end:test_end], [test_end:].
    val_indices, test_indices, train_indices = np.split(shuffled, [val_end, test_end])
    return train_indices, val_indices, test_indices


def text_to_tokens(text, seq_length=100):
    """
    Chop a string into non-overlapping, fixed-length token sequences.

    The text is lower-cased and newlines become a literal `NEWLINE`
    token, so the model can learn to predict line breaks. Em dashes and
    hyphens are surrounded with spaces so TextBlob's tokenizer emits them
    as separate tokens (GloVe has embeddings for them). Trailing tokens
    that don't fill a whole sequence are dropped.

    Args:
        text: The `string` to chop up.
        seq_length: The number of tokens in each sequence.

    Returns:
        X_tokens: A `list` of token tuples, each `seq_length` long.
        y_tokens: A `list` of token tuples shaped like `X_tokens`, each
            shifted one position to the right, so `y_tokens[i][j]` is the
            token that follows `X_tokens[i][j]`.

    """
    prepared = (text.lower()
                .replace('\n', ' NEWLINE ')
                .replace('—', ' — ')
                .replace('-', ' - '))
    tokens = tuple(TextBlob(prepared).tokens)

    # Hold back one token so the final target sequence has a successor.
    usable = (len(tokens) - 1) // seq_length * seq_length
    starts = range(0, usable, seq_length)
    X_tokens = [tokens[start:start + seq_length] for start in starts]
    y_tokens = [tokens[start + 1:start + seq_length + 1] for start in starts]
    return X_tokens, y_tokens


def ids_to_tokens(ids, id_to_word):
    """
    Map each token id in `ids` to its token using `id_to_word`.

    Raises `KeyError` for an id that `id_to_word` doesn't contain.
    """
    return list(map(id_to_word.__getitem__, ids))


def id_lists_to_token_lists(id_lists, id_to_word):
    """Map every row of token ids in `id_lists` to a list of tokens."""
    return [[id_to_word[token_id] for token_id in row] for row in id_lists]


def save_checkpoint(model, metric_value, file_path, higher_is_better=True):
    """
    Save `model`'s state to `file_path` if `metric_value` beats the stored one.

    The metric is saved inside the state dict under the
    `'checkpoint_metric'` key. `load_checkpoint` removes it before
    loading.

    Args:
        model: A `torch.nn.Module` whose `state_dict()` is saved.
        metric_value: The metric for the current state: a Python number
            or a single-element tensor.
        file_path: Where the previous checkpoint is read from and the new
            one is written to.
        higher_is_better: If `True`, a larger metric is an improvement;
            otherwise a smaller one is. Defaults to `True`.

    Returns:
        `True` if a checkpoint was written, `False` otherwise.

    """
    # Store a plain float. Comparisons then yield real bools, and the
    # saved value carries no device or autograd history.
    metric_value = float(metric_value)
    try:
        # `map_location` lets a checkpoint saved from GPU tensors be read
        # on a CPU-only machine.
        prev_best = torch.load(file_path, map_location='cpu')['checkpoint_metric']
    except (FileNotFoundError, KeyError):
        prev_best = None
    improved = (prev_best is None
                or (higher_is_better and prev_best < metric_value)
                or (not higher_is_better and prev_best > metric_value))
    if not improved:
        return False
    state = model.state_dict()
    state['checkpoint_metric'] = metric_value
    torch.save(state, file_path)
    return True


def load_checkpoint(model, file_path):
    """
    Load a checkpoint (as written by `save_checkpoint`) into `model`.

    `save_checkpoint` stores the checkpoint's performance metric in the
    state dict under `'checkpoint_metric'`. PyTorch's `load_state_dict`
    rejects unknown keys, so that entry is removed first when present.
    A plain state dict without the entry also loads.

    Args:
        model: The `torch.nn.Module` to load the weights into.
        file_path: Path of the saved checkpoint.

    """
    # `map_location='cpu'` makes GPU-saved checkpoints readable anywhere.
    # `load_state_dict` then copies the values onto the model's own device.
    state_dict = torch.load(file_path, map_location='cpu')
    state_dict.pop('checkpoint_metric', None)
    model.load_state_dict(state_dict)


def compute_loss_and_accuracy(y_pred, y_true, criterion=None):
    """
    Helper function to compute the loss and accuracy of predictions.

    Args:
        y_pred: A PyTorch tensor of log-probabilities with shape
            `(batch_size, seq_length, vocab_size)`, e.g. the output of
            `NextWordPredictor`.
        y_true: A PyTorch integer tensor of target ids with shape
            `(batch_size, seq_length)`.
        criterion: A loss function called as
            `criterion(y_pred_flat, y_true_flat)`. Defaults to `None`,
            which uses `F.nll_loss` (mean reduction), the functional
            equivalent of the `nn.NLLLoss()` used in this notebook.

    Returns:
        loss: A tensor containing the loss.
        accuracy: A float containing the accuracy of the predictions,
            assuming we predict the most likely word.

    """
    # Flatten `y_pred` from `(batch_size, seq_length, vocab_size)` to
    # `(batch_size * seq_length, vocab_size)` and `y_true` to
    # `(batch_size * seq_length,)`, the shapes the loss expects. `reshape`
    # (unlike `view`) also accepts non-contiguous tensors.
    y_pred_flat = y_pred.reshape(-1, y_pred.size(-1))
    y_true_flat = y_true.reshape(-1)
    loss_fn = F.nll_loss if criterion is None else criterion
    loss = loss_fn(y_pred_flat, y_true_flat)
    # Compute accuracy in torch so it works on tensors on any device.
    # sklearn's `accuracy_score` needs CPU data.
    correct = y_pred_flat.argmax(dim=-1) == y_true_flat
    accuracy = correct.float().mean().item()
    return loss, accuracy


class BooksDataset(Dataset):

    def __init__(self, books_path, embeddings_path, sequence_length=50, val_ratio=0, test_ratio=0):
        """
        A `Dataset` subclass that loads and prepares a dataset from
        all the text files in a directory.

        Because this class implements `__len__` and `__getitem__`, its
        instances are sequences that `DataLoader` can index into to
        generate mini-batches. Only the training split is exposed that
        way. The validation and test splits are kept as attributes
        (`X_val_ids`, `y_val_ids`, `X_test_ids`, `y_test_ids`).

        Args:
            books_path: The path to the directory containing the text files.
            embeddings_path: The path to the pretrained embeddings file.
            sequence_length: The number of tokens in each sequence.
            val_ratio: The proportion of the sequences to set aside for
                validation.
            test_ratio: The proportion of the sequences to set aside for
                testing.

        """
        super().__init__()

        # Read every text document in `books_path` and join them into a
        # single string. `glob` returns files in arbitrary order. Sorting
        # makes the concatenation, and thus the sequences and the split,
        # reproducible.
        glob_path = os.path.join(books_path, '*.txt')
        texts = []
        for file_path in sorted(glob(glob_path)):
            # The books use curly quotes, so don't rely on the platform's
            # default encoding.
            with open(file_path, encoding='utf-8') as f:
                text = f.read().lower()
            # Keep lines that are at least 50 characters long or that
            # start with a quotation mark. Other lines are likely chapter
            # titles, et cetera. The alternation is grouped so `^` and `$`
            # apply to both branches. Otherwise a quote anywhere in a
            # short line would match from the quote to the end of the line.
            lines = re.findall(r'^(?:.{50,}|["“”].*)$', text, flags=re.MULTILINE)
            texts.append('\n'.join(line.strip() for line in lines))
        text = '\n'.join(texts)

        # Cut the text into `sequence_length` sized chunks.
        X_tokens, y_tokens = text_to_tokens(text, sequence_length)

        # Generate the embedding matrix and associated maps.
        _, self.token_to_id, self.embedding_matrix = get_embeddings(
            X_tokens, embeddings_path)
        self.id_to_token = {token_id: token
                            for token, token_id in self.token_to_id.items()}

        # PyTorch wants these as LongTensors.
        X_ids = torch.LongTensor(tokens_to_ids(X_tokens, self.token_to_id))
        y_ids = torch.LongTensor(tokens_to_ids(y_tokens, self.token_to_id))

        # Do a (seeded) train/val/test split over whole sequences.
        train_indices, val_indices, test_indices = get_train_val_test_indices(
            len(X_tokens), val_ratio, test_ratio)
        self.X_train_ids = X_ids[train_indices]
        self.y_train_ids = y_ids[train_indices]
        self.X_val_ids = X_ids[val_indices]
        self.y_val_ids = y_ids[val_indices]
        self.X_test_ids = X_ids[test_indices]
        self.y_test_ids = y_ids[test_indices]

    def __len__(self):
        """Return the number of training sequences."""
        return len(self.X_train_ids)

    def __getitem__(self, index):
        """Return the `(input_ids, target_ids)` pair for training sequence `index`."""
        return self.X_train_ids[index], self.y_train_ids[index]


class NextWordPredictor(nn.Module):

    def __init__(self, embedding_matrix, recurrent_layers=1,
                 train_embeddings=False, recur_size=256):
        """
        A neural net that predicts the next word at each time step.

        Args:
            embedding_matrix: A torch tensor, numpy array, or nested
                sequence where each row is the embedding vector of a
                word. It is converted to float32.
            recurrent_layers: The number of LSTM layers to use. Defaults
                to `1`.
            train_embeddings: Whether or not to further train the
                embedding weights or fix them in place. Defaults to
                `False`.
            recur_size: The number of dimensions in recurrent state.
                Defaults to `256`.

        """
        super().__init__()

        # Create embedding layer and initialize its weights to the passed
        # in embedding_matrix. The LSTM and Linear weights are float32, so
        # the embeddings must match.
        embedding_matrix = torch.as_tensor(embedding_matrix, dtype=torch.float32)
        vocab_size, embedding_size = embedding_matrix.shape
        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        # Freezing has to be set on the weight Parameter itself. Setting
        # `requires_grad` on the Module only creates an unused attribute,
        # and the embeddings would still be trained.
        self.embeddings.weight = nn.Parameter(embedding_matrix,
                                              requires_grad=train_embeddings)

        # Setup LSTM layer. `recur_state` is carried between `forward`
        # calls; `None` makes the LSTM start from zeros.
        self.recur_state = None
        self.recur_size = recur_size
        self.recurrent_layers = recurrent_layers
        self.lstm = nn.LSTM(embedding_size, recur_size,
                            num_layers=recurrent_layers, batch_first=True)

        # Setup fully connected word predictor.
        self.word_predictor = nn.Linear(recur_size, vocab_size)

    def init_recur_state(self, batch_size):
        """
        Return an all-zero hidden state for the recurrent layer.

        The tensors are created on the same device and with the same
        dtype as the model's parameters, so this also works after
        `model.to(device)`.

        Args:
             batch_size (int): The number of training examples in
                each mini-batch.

        Returns:
            (tuple): A `(h, c)` tuple of torch tensors each with the shape
                `(num_recur_layers, batch_size, recur_size)`.

        """
        reference = self.word_predictor.weight
        shape = (self.recurrent_layers, batch_size, self.recur_size)
        return (reference.new_zeros(shape), reference.new_zeros(shape))

    def forward(self, x):
        """
        For each word, predict the log-probabilities of the next word.

        Updates `self.recur_state` with the LSTM's final state, so
        consecutive calls continue from where the previous one stopped.

        Args:
            x: A torch tensor with shape `(batch_size, seq_length)`.
                Its elements are word indices that map to word vectors
                in the `embedding_matrix`.

        Returns:
            A torch tensor of log-probabilities with shape
            `(batch_size, seq_length, vocab_size)`.

        """
        batch_size = x.size(0)
        seq_length = x.size(1)
        # Look up the word indices, replacing them with their embedding
        # vectors: `(batch_size, seq_length)` ->
        # `(batch_size, seq_length, embedding_size)`.
        x = self.embeddings(x)
        # LSTM: `(batch_size, seq_length, embedding_size)` ->
        # `(batch_size, seq_length, recur_size)`.
        x, self.recur_state = self.lstm(x, self.recur_state)
        # Flatten to `(batch_size * seq_length, recur_size)`. `reshape`
        # copies if the LSTM output is not contiguous, which is what
        # `contiguous().view(...)` did before.
        x = x.reshape(-1, self.recur_size)
        # Score every vocabulary word for each position:
        # `(batch_size * seq_length, vocab_size)`.
        x = self.word_predictor(x)
        # Log-softmax over the vocabulary gives log-probabilities, the
        # input NLL loss expects.
        x = F.log_softmax(x, dim=1)
        # Restore `(batch_size, seq_length, vocab_size)`.
        return x.view(batch_size, seq_length, -1)

### Data downloading

In [2]:
DATA_URL = 'https://www.dropbox.com/s/5xeq6ijqfw9xlh9/data.zip?dl=1'
SAVE_PATH = 'data.zip'

if not os.path.exists('data'):
    # Download the data zip file. Fail loudly on an HTTP error instead of
    # saving an error page as `data.zip`. The timeout stops a stalled
    # connection from hanging the notebook forever.
    with requests.get(DATA_URL, stream=True, timeout=60) as response:
        response.raise_for_status()
        # `response.raw` is the undecoded byte stream. Ask urllib3 to undo
        # any Content-Encoding (e.g. gzip) so the file is the real zip.
        response.raw.decode_content = True
        with open(SAVE_PATH, 'wb') as f:
            shutil.copyfileobj(response.raw, f)
    # Unzip the file into the current directory.
    with zipfile.ZipFile(SAVE_PATH, 'r') as z:
        z.extractall()
    # Delete the macOS metadata folder (if the archive had one) and the
    # zip file.
    shutil.rmtree('__MACOSX', ignore_errors=True)
    os.remove(SAVE_PATH)

### Configuration options

In [28]:
DATA_DIR = 'data'

# Input locations.
BOOKS_PATH = f'{DATA_DIR}/texts'  # Directory of *.txt books.
EMEDDINGS_PATH = f'{DATA_DIR}/glove.6B.200d.txt'  # Pretrained GloVe vectors.

# Cached artifacts.
CHECKPOINT_PATH = f'{DATA_DIR}/checkpoint.pickle'
DATASET_SAVE_PATH = f'{DATA_DIR}/dataset.pickle'
MODEL_SAVE_PATH = f'{DATA_DIR}/model.pickle'

# Model and training settings.
HIDDEN_LAYERS = 1  # Number of stacked LSTM layers.
TRAIN_EMBEDDINGS = True  # Keep updating the pretrained vectors while training.
USE_CHECKPOINTING = False  # Needs a non-empty validation split.
NUM_EPOCHS = 30
BATCH_SIZE = 64
LOG_FREQ = 100  # In mini-batches.
LOAD_CACHED_DATASET = True  # Turn off if you use your own data.

### Model and dataset setup

In [29]:
if not LOAD_CACHED_DATASET or not os.path.exists(DATASET_SAVE_PATH):
    # Build the dataset from the raw texts and cache it for next time.
    books_dataset = BooksDataset(BOOKS_PATH, EMEDDINGS_PATH, val_ratio=0, test_ratio=0)
    joblib.dump(books_dataset, DATASET_SAVE_PATH)
else:
    # NOTE: joblib.load unpickles the file. Only load caches you created.
    books_dataset = joblib.load(DATASET_SAVE_PATH)

# Serve shuffled mini-batches of (input ids, target ids) pairs.
dataloader = DataLoader(books_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [36]:
model = NextWordPredictor(
    books_dataset.embedding_matrix,
    recurrent_layers=HIDDEN_LAYERS,
    train_embeddings=TRAIN_EMBEDDINGS,
)
# The model outputs log-probabilities (log_softmax), so negative
# log-likelihood is the matching loss. Together they form cross entropy.
criterion = nn.NLLLoss()
optimizer = torch.optim.RMSprop(model.parameters())

### Training loop

In [37]:
# Checkpointing selects the model by validation loss, so it needs
# validation data. Check before training instead of after the first
# (multi-minute) epoch.
if USE_CHECKPOINTING and not books_dataset.X_val_ids.size(0):
    raise RuntimeError("Can't checkpoint without validation data!")

for epoch in range(NUM_EPOCHS):
    print(f'Starting epoch {epoch}.')
    epoch_start_time = time.time()
    running_loss = 0.0
    running_accuracy = 0.0

    # Iterating over a DataLoader instance iterates over the whole
    # dataset in `batch_size` sized chunks.
    for i, (batch_X, batch_y) in enumerate(dataloader):
        # Switch to training mode. This also undoes the `model.eval()`
        # used for validation below.
        model.train()
        # Clear the gradients from previous mini-batch.
        optimizer.zero_grad()
        # Reset the recurrent state. We don't want to retain this
        # between batches as the text segments aren't going to be
        # contiguous.
        model.recur_state = model.init_recur_state(batch_X.size(0))
        # Predict the next word at each time step. `y_pred` has shape
        # `(batch_size, seq_length, vocab_size)`.
        y_pred = model(batch_X)
        loss, accuracy = compute_loss_and_accuracy(y_pred, batch_y)
        # Calculate gradients.
        loss.backward()
        # Update weights.
        optimizer.step()

        # Periodically print the loss and prediction accuracy. Usually
        # with a language model you'd also show perplexity, but as it's
        # a function of our cross entropy loss and not that intuitive,
        # I've elected not to.
        running_loss += loss.item()
        running_accuracy += accuracy
        if i % LOG_FREQ == LOG_FREQ - 1:
            average_loss = running_loss / LOG_FREQ
            average_accuracy = running_accuracy / LOG_FREQ
            print(f'Mini-batch: {i + 1}/{len(dataloader)} '
                  f'Loss: {average_loss:.5f} Accuracy: {average_accuracy:.5f}')
            running_loss = 0.0
            running_accuracy = 0.0

    # Log elapsed_time for the epoch.
    elapsed_time = time.time() - epoch_start_time
    print(f'Epoch {epoch} completed in {elapsed_time // 60:.0f} minutes '
          f'{elapsed_time % 60:.0f} seconds.\n')

    # If checkpointing, check whether the validation loss is lower than
    # at the previous checkpoint. If so, save the model's state. This is
    # less useful here, since we also want the model to memorize the
    # dataset, but on other tasks it saves a state of the model before
    # it starts to overfit.
    if USE_CHECKPOINTING:
        # Evaluate the whole validation set as one batch. `no_grad` must
        # be *called*: the bare class is not a context manager, so
        # `with torch.no_grad:` fails.
        model.eval()
        with torch.no_grad():
            model.recur_state = model.init_recur_state(
                books_dataset.X_val_ids.size(0))
            y_pred = model(books_dataset.X_val_ids)
            val_loss, val_accuracy = compute_loss_and_accuracy(
                y_pred, books_dataset.y_val_ids)
        val_loss = val_loss.item()
        print(f'Val Loss: {val_loss:.5f} Val Accuracy: {val_accuracy:.5f}')
        if save_checkpoint(model, val_loss, CHECKPOINT_PATH, higher_is_better=False):
            print('Validation set performance improved, saving checkpoint.\n')
        else:
            print('Validation set performance did NOT improve.\n')

# If checkpointing, load the model state with the lowest validation loss.
if USE_CHECKPOINTING:
    load_checkpoint(model, CHECKPOINT_PATH)

Starting epoch 0.
Mini-batch: 100/409 Loss: 6.17821 Accuracy: 0.14781
Mini-batch: 200/409 Loss: 5.26532 Accuracy: 0.19967
Mini-batch: 300/409 Loss: 5.08472 Accuracy: 0.21158
Mini-batch: 400/409 Loss: 4.98249 Accuracy: 0.21980
Epoch 0 completed in 4 minutes 16 seconds.

Starting epoch 1.
Mini-batch: 100/409 Loss: 4.57262 Accuracy: 0.23451
Mini-batch: 200/409 Loss: 4.54959 Accuracy: 0.23888
Mini-batch: 300/409 Loss: 4.55398 Accuracy: 0.23779
Mini-batch: 400/409 Loss: 4.54612 Accuracy: 0.24056
Epoch 1 completed in 3 minutes 57 seconds.

Starting epoch 2.
Mini-batch: 100/409 Loss: 4.07988 Accuracy: 0.26506
Mini-batch: 200/409 Loss: 4.15897 Accuracy: 0.25848
Mini-batch: 300/409 Loss: 4.17456 Accuracy: 0.26027
Mini-batch: 400/409 Loss: 4.21322 Accuracy: 0.25613
Epoch 2 completed in 4 minutes 0 seconds.

Starting epoch 3.
Mini-batch: 100/409 Loss: 3.73409 Accuracy: 0.29273
Mini-batch: 200/409 Loss: 3.82359 Accuracy: 0.28371
Mini-batch: 300/409 Loss: 3.89030 Accuracy: 0.27799
Mini-batch: 400/4

### Optional save/restore

In [None]:
# OPTIONAL: Persist the whole model object (architecture and weights)
# with joblib so it can be reloaded later without re-training.
joblib.dump(value=model, filename=MODEL_SAVE_PATH)

In [None]:
# OPTIONAL: Restore the model and dataset objects saved with joblib.
# NOTE: joblib.load unpickles its input, which can run arbitrary code.
# Only load files you created yourself.
model = joblib.load(filename=MODEL_SAVE_PATH)
books_dataset = joblib.load(filename=DATASET_SAVE_PATH)

### Text generation

In [45]:
SEED = "Harry stepped onto Y.T.'s plank."
NUM_TO_GENERATE = 400

# Tokenize the lower-cased seed with TextBlob and map it to a
# `(1, seed_length)` tensor of ids.
seed_tokens = [tuple(TextBlob(SEED.lower()).tokens)]
seed_ids = torch.LongTensor(tokens_to_ids(seed_tokens, books_dataset.token_to_id))

with torch.no_grad():
    model.eval()
    # Start from an empty recurrent state, then read the whole seed in
    # one pass. The LSTM state now summarizes the seed.
    model.recur_state = model.init_recur_state(1)
    y_pred = model(seed_ids)
    # Greedy decoding: take the most likely word after the seed's last
    # token. Each prediction is an id tensor of shape `(1, 1)`.
    generated = [y_pred[:, -1:, :].argmax(-1)]
    # Feed each prediction back in one word at a time. The model keeps
    # its recurrent state between calls.
    while len(generated) < NUM_TO_GENERATE:
        generated.append(model(generated[-1]).argmax(-1))

# Combine the seed ids with the generated ids and map them to tokens.
combined_ids = seed_ids[0].tolist() + [step.item() for step in generated]
combined_tokens = ids_to_tokens(combined_ids, books_dataset.id_to_token)

# Join with spaces and turn NEWLINE tokens back into paragraph breaks.
generated_text = ' '.join(combined_tokens).replace(' NEWLINE ', '\n\n')
print(generated_text)

harry stepped onto y.t . 's plank . she is not too scared of the people who feel anymore .

“ okay , sasha . you ’ re infected with your lawyer , we ’ re all dead rumor ? ”

“ i ’ m not asking you to do , ” randy says . “ i ’ m not sure if you like. ”

“ if you ’ re going to sponsor a stable currency , ” goto dengo says . “ we are going to be a major drum heap drum into the seafloor , which is now calls to the states .

“ what is the site ? ”

“ yes , sir. ”

“ i ’ m not allowed to do this. ” the lieutenant ’ s smile turns around , and opens up his briefcase . “ i ’ m not going to be back. ”

“ i ’ m not asking you to do , ” avi says . “ i ’ m going to be troublesome . i ’ m going to get the fuck out of jail. ”

“ sir ! yes , sir ! ”

“ i ’ m not asking you to justify the , ” says the skipper . “ i ’ m not sure if you ’ re not limited to congratulate your passport control over the convoy , i ’ ll be wanting to make a better sense , ” he says .

“ i ’ ll be fucked ! ” nina says . “ i ’ 