In [None]:
import logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

from collections import Counter

In [2]:
import pickle
import torch
import torch.optim as optim

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [4]:
if device.type == 'cuda':
    cuda_device = -1
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

Tesla K80
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [5]:
import re
from typing import Dict, List, Tuple, Set

import torch
import torch.optim as optim
from allennlp.common.file_utils import cached_path
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import DataLoader, AllennlpDataset
from allennlp.data.samplers import BucketBatchSampler
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token, CharacterTokenizer
from allennlp.data.vocabulary import Vocabulary, DEFAULT_PADDING_TOKEN
from allennlp.models import Model
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.trainer import GradientDescentTrainer

In [6]:
EMBEDDING_SIZE = 32
HIDDEN_SIZE = 256
BATCH_SIZE = 128

In [7]:
def read_dataset(all_chars: Set[str]=None) -> List[List[Token]]:
    """Read a plan text file and return character-tokenized sentences."""
    tokenizer = CharacterTokenizer()
    sentences = []
    with open(cached_path('https://s3.amazonaws.com/realworldnlpbook/data/tatoeba/sentences.eng.10k.txt')) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            line = re.sub(' +', ' ', line)
            tokens = tokenizer.tokenize(line)
            if all_chars:
                tokens = [token for token in tokens if token.text in all_chars]
            sentences.append(tokens)

    return sentences

In [8]:
def tokens_to_lm_instance(tokens: List[Token],
                          token_indexers: Dict[str, TokenIndexer]):
    tokens = list(tokens)   # shallow copy
    tokens.insert(0, Token(START_SYMBOL))
    tokens.append(Token(END_SYMBOL))

    input_field = TextField(tokens[:-1], token_indexers)
    output_field = TextField(tokens[1:], token_indexers)
    return Instance({'input_tokens': input_field,
                     'output_tokens': output_field})

In [9]:
class RNNLanguageModel(Model):
    def __init__(self,
                 embedder: TextFieldEmbedder,
                 hidden_size: int,
                 max_len: int,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)

        self.embedder = embedder

        self.rnn = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(EMBEDDING_SIZE, HIDDEN_SIZE, batch_first=True))

        self.hidden2out = torch.nn.Linear(in_features=self.rnn.get_output_dim(),
                                          out_features=vocab.get_vocab_size('tokens'))
        self.hidden_size = hidden_size
        self.max_len = max_len

    def forward(self, input_tokens, output_tokens):
        mask = get_text_field_mask(input_tokens)
        embeddings = self.embedder(input_tokens)
        rnn_hidden = self.rnn(embeddings, mask)
        out_logits = self.hidden2out(rnn_hidden)

        """
        THIS IS LIKELY NOT HOW I SHOULD FIX THIS BUT IT WAS THE BEST
        I COULD DO TO GET THIS WORKING

        At this stage, `output_tokens` looks like this (module the specific token indices):

        {'tokens': {'tokens': tensor([[16, 45,  5,  ...,  0,  0,  0],
                [51, 56, 48,  ...,  0,  0,  0],
                [44, 54,  2,  ...,  0,  0,  0],
                ...,
                [14, 54,  7,  ...,  0,  0,  0],
                [10, 48, 22,  ...,  0,  0,  0],
                [51, 36, 56,  ..., 58,  0,  0]])}}

        which seems like it's being double indexed somehow.

        Thus, calling output_tokens = output_tokens["tokens"] to unnest `tokens`
        resolves this in an unideal way.
        """
        output_tokens = output_tokens["tokens"]
        loss = sequence_cross_entropy_with_logits(out_logits, output_tokens['tokens'], mask)

        return {'loss': loss}

    def generate(self) -> Tuple[List[Token], torch.tensor]:

        start_symbol_idx = self.vocab.get_token_index(START_SYMBOL, 'tokens')
        end_symbol_idx = self.vocab.get_token_index(END_SYMBOL, 'tokens')
        padding_symbol_idx = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN, 'tokens')

        log_likelihood = 0.
        words = []
        state = (torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size))

        word_idx = start_symbol_idx

        for i in range(self.max_len):
            tokens = torch.tensor([[word_idx]])

            embeddings = self.embedder({'tokens': tokens})
            output, state = self.rnn._module(embeddings, state)
            output = self.hidden2out(output)

            log_prob = torch.log_softmax(output[0, 0], dim=0)

            dist = torch.exp(log_prob)

            word_idx = start_symbol_idx

            while word_idx in {start_symbol_idx, padding_symbol_idx}:
                word_idx = torch.multinomial(
                    dist, num_samples=1, replacement=False).item()

            log_likelihood += log_prob[word_idx]

            if word_idx == end_symbol_idx:
                break

            token = Token(text=self.vocab.get_token_from_index(word_idx, 'tokens'))
            words.append(token)

        return words, log_likelihood

In [10]:
all_chars = {END_SYMBOL, START_SYMBOL}
all_chars.update("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,!?'-")

In [11]:
train_set = read_dataset(all_chars)

In [12]:
token_counts = {char: 1 for char in all_chars}
vocab = Vocabulary({'tokens': token_counts})

token_indexers = {'tokens': SingleIdTokenIndexer()}
instances = [tokens_to_lm_instance(tokens, token_indexers)
             for tokens in train_set]

In [13]:
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_SIZE)
embedder = BasicTextFieldEmbedder({"tokens": token_embedding})

model = RNNLanguageModel(embedder=embedder,
                         hidden_size=HIDDEN_SIZE,
                         max_len=80,
                         vocab=vocab)

In [14]:
dataset = AllennlpDataset(instances, vocab)
data_loader = DataLoader(dataset,
                         batch_size=BATCH_SIZE)

optimizer = optim.Adam(model.parameters(), lr=5.e-3)

In [15]:
#model = model.cuda(cuda_device)
trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=data_loader,
    num_epochs=10)

#trainer.train()

with open("./text_generation.th", 'rb') as f:
    model.load_state_dict(torch.load(f))

In [16]:
def predict(text: str, model: Model) -> float:
    tokenizer = CharacterTokenizer()
    tokens = tokenizer.tokenize(text)
    
    token_indexers = {'tokens': SingleIdTokenIndexer()}
    instance = tokens_to_lm_instance(tokens, token_indexers)
    output = model.forward_on_instance(instance)
    print(output)

In [None]:
with open("./text_generation.th", 'wb') as f:
    torch.save(model.state_dict(), f)

In [17]:
predict('The trip to the beach was ruined by bad weather.', model)

{'loss': 9.363358}


In [18]:
predict('The trip to the beach was ruined by bad dogs.', model)

{'loss': 9.456648}


In [19]:
predict('by weather was trip my bad beach the ruined to.', model)

{'loss': 9.827207}


In [20]:
for _ in range(50):
    tokens, _ = model.generate()
    print(''.join(token.text for token in tokens))

RuntimeError: Could not run 'aten::values' with arguments from the 'CPUTensorId' backend. 'aten::values' is only available for these backends: [SparseCPUTensorId, SparseCUDATensorId, VariableTensorId].