In [29]:
import textwrap
from os.path import join, expanduser, exists
from urllib.error import URLError
from urllib.request import urlopen

In [15]:
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchtext.data import Field

In [3]:
from core.text import TextDataset
from core.loop import Loop, Stepper
from core.iterators import SequenceIterator
from core.schedule import CosineAnnealingLR
from core.callbacks import EarlyStopping, Checkpoint, Logger

In [4]:
def set_random_seed(state=1):
    """Seed the NumPy, PyTorch CPU and PyTorch CUDA generators with ``state``."""
    np.random.seed(state)
    torch.manual_seed(state)
    torch.cuda.manual_seed(state)

In [5]:
# Seed every random number generator once, before any model or data is created.
RANDOM_STATE = 1
set_random_seed(RANDOM_STATE)

## Dataset Downloading

In [None]:
def download(url, download_path, expected_size):
    """Fetch a UTF-8 text file over HTTP and store it at ``download_path``.

    Nothing is downloaded when ``download_path`` already exists. Every failure
    that is handled here (unreachable URL, non-200 status, unexpected payload
    size) is reported with ``print`` and leaves the file system untouched.

    Args:
        url: Address of the file to fetch.
        download_path: Local destination; missing parent folders are created.
        expected_size: Exact number of bytes the payload must contain.

    Raises:
        UnicodeDecodeError: If the payload is not valid UTF-8.
    """
    from os import makedirs
    from os.path import dirname

    if exists(download_path):
        print('The file was already downloaded')
        return

    try:
        r = urlopen(url)
    except URLError as e:
        print(f'Cannot download the data. Error: {e}')
        return

    # Close the response as soon as the payload has been read (or rejected).
    with r:
        if r.status != 200:
            print(f'HTTP Error: {r.status}')
            return
        data = r.read()

    if len(data) != expected_size:
        print(f'Invalid downloaded array size: {len(data)}')
        return

    text = data.decode(encoding='utf-8')

    # A fresh machine may not have the target folder yet.
    parent = dirname(download_path)
    if parent:
        makedirs(parent, exist_ok=True)

    # Write with the same encoding the payload was decoded with, instead of the
    # platform default, so non-ASCII characters survive on every OS.
    with open(download_path, 'w', encoding='utf-8') as file:
        file.write(text)

    print(f'Downloaded: {download_path}')

In [None]:
# Remote location of the raw text corpus that is passed to download().
URL = 'https://s3.amazonaws.com/text-datasets/nietzsche.txt'

In [None]:
# Local destination of the raw corpus and the exact payload size (in bytes)
# that download() accepts.
# NOTE(review): the file name is an assumption. The training cells read from
# 'trn' and 'val' sub-folders of the same root, so the text presumably still
# has to be split into those folders - confirm.
PATH = expanduser(join('~', 'data', 'fastai', 'nietzsche', 'nietzsche.txt'))
EXPECTED_SIZE = 600901
download(URL, PATH, EXPECTED_SIZE)

## Model Training

In [6]:
# Dataset locations: the training and validation folders are the 'trn' and
# 'val' sub-folders of ROOT, which lives under the current user's home directory.
ROOT = expanduser(join('~', 'data', 'fastai', 'nietzsche'))
TRAIN_DIR = join(ROOT, 'trn')
VALID_DIR = join(ROOT, 'val')
# Use the first CUDA device when CUDA is available, otherwise the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [7]:
class RNN(nn.Module):
    """Stateful recurrent model: embedding -> recurrent layer -> linear.

    The recurrent hidden state is stored on the instance and carried over from
    one ``forward`` call to the next. It is detached after every call, so
    back-propagation never reaches into earlier batches.

    Args:
        vocab_size: Number of rows of the embedding table and width of the
            output layer.
        n_factors: Embedding dimension.
        batch_size: Batch size used to allocate the initial hidden state.
        n_hidden: Hidden size of the recurrent layer.
        n_recurrent: Number of stacked recurrent layers.
        architecture: Recurrent module class, instantiated as
            ``architecture(n_factors, n_hidden, num_layers=n_recurrent)``.
        device: Device the module and its hidden state are moved to.
    """

    def __init__(self, vocab_size, n_factors, batch_size, n_hidden,
                 n_recurrent=1, architecture=nn.RNN, device=DEVICE):
        # Initialise nn.Module before any attribute is assigned.
        super().__init__()

        self.vocab_size = vocab_size
        self.n_hidden = n_hidden
        self.n_recurrent = n_recurrent
        self.device = device
        self.batch_size = batch_size

        self.embed = nn.Embedding(vocab_size, n_factors)
        self.rnn = architecture(n_factors, n_hidden, num_layers=n_recurrent)
        self.out = nn.Linear(n_hidden, vocab_size)
        # init_hidden() already places the state on `device`.
        self.hidden_state = self.init_hidden(batch_size)
        self.to(device)

    def forward(self, batch):
        """Return log-probabilities for every position of ``batch``.

        Args:
            batch: Tensor of token indices; dimension 1 is treated as the
                batch dimension.

        Returns:
            Log-softmax outputs reshaped to ``(-1, vocab_size)``.
        """
        bs = batch.size(1)
        if bs != self.batch_size:
            # The stored state no longer fits the incoming batch: start afresh.
            self.hidden_state = self.init_hidden(bs)
            self.batch_size = bs
        embeddings = self.embed(batch)
        rnn_outputs, h = self.rnn(embeddings, self.hidden_state)
        # Keep the values for the next call but cut the autograd history.
        self.hidden_state = truncate_history(h)
        linear = self.out(rnn_outputs)
        return F.log_softmax(linear, dim=-1).view(-1, self.vocab_size)

    def init_hidden(self, batch_size):
        """Create a zero hidden state on ``self.device`` for ``batch_size`` sequences.

        Returns:
            A ``(h_0, c_0)`` tuple when the recurrent layer is an ``nn.LSTM``
            (the form ``nn.LSTM`` documents for its initial state), otherwise
            a single tensor. Every tensor has shape
            ``(n_recurrent, batch_size, n_hidden)``.
        """
        shape = (self.n_recurrent, batch_size, self.n_hidden)
        if isinstance(self.rnn, nn.LSTM):
            # An LSTM needs both a hidden state and a cell state.
            return (torch.zeros(*shape).to(self.device),
                    torch.zeros(*shape).to(self.device))
        return torch.zeros(*shape).to(self.device)

In [8]:
def truncate_history(v):
    """Detach a hidden state from the autograd graph.

    Args:
        v: A tensor, or an arbitrarily nested iterable of tensors (for example
            the ``(h, c)`` pair returned by an LSTM).

    Returns:
        The detached tensor, or a tuple mirroring the structure of ``v`` whose
        tensors are all detached.
    """
    # isinstance (rather than an exact type comparison) also recognises Tensor
    # subclasses such as nn.Parameter; those would otherwise be iterated row by
    # row and come back as a tuple.
    if isinstance(v, torch.Tensor):
        return v.detach()
    return tuple(truncate_history(x) for x in v)

In [9]:
def generate_text(model, field, seed, n=500):
    """Extend ``seed`` by ``n`` tokens sampled from ``model``.

    At every step the whole text produced so far is numericalized with
    ``field`` and fed to ``model``; the next token is drawn at random from the
    distribution predicted for the last position.

    Args:
        model: Callable returning log-probabilities with one row per position.
        field: Object providing ``numericalize`` and ``vocab.itos``.
        seed: Initial text.
        n: Number of tokens to append.

    Returns:
        ``seed`` followed by the ``n`` generated tokens.
    """
    string = seed
    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for _ in range(n):
            indexes = field.numericalize(string)
            predictions = model(indexes.transpose(0, 1))
            # The model returns log-probabilities; multinomial needs probabilities.
            probabilities = predictions[-1].exp()
            sampled = torch.multinomial(probabilities, 1).item()
            string += field.vocab.itos[sampled]
    return string

In [10]:
def pretty_print(text, width=80):
    """Print ``text`` wrapped so that no line is longer than ``width`` characters."""
    wrapped_lines = textwrap.wrap(text, width=width)
    print(*wrapped_lines, sep='\n')

In [40]:
def show_text(model, field, seed):
    """Generate text from ``seed`` with ``model`` and print it wrapped."""
    generated = generate_text(model, field, seed)
    pretty_print(generated)

In [11]:
def create_dataset(bptt, batch_size, min_freq):
    """Build the text dataset from TRAIN_DIR / VALID_DIR.

    Returns a ``(dataset, field)`` pair, where ``field`` is the torchtext Field
    that lower-cases the text and tokenizes it with ``list``.
    """
    def make_iterator(seq):
        # Every sequence is wrapped in an iterator sharing bptt and batch size.
        return SequenceIterator(seq, bptt, batch_size)

    text_field = Field(lower=True, tokenize=list)
    text_dataset = TextDataset(text_field, min_freq)
    text_dataset.build(
        train=TRAIN_DIR,
        valid=VALID_DIR,
        iterator_factory=make_iterator)
    return text_dataset, text_field

In [31]:
# Data hyper-parameters handed to create_dataset().
batch_size = 64  # forwarded to SequenceIterator; also read as a global inside train()
bptt = 8  # forwarded to SequenceIterator; presumably the sequence length per step - confirm
min_freq = 5  # forwarded to TextDataset; presumably the minimum token frequency - confirm

In [32]:
# Build the dataset together with the Field that is later used for text generation.
dataset, field = create_dataset(bptt, batch_size, min_freq)

In [42]:
def train(dataset, field, arch=nn.LSTM, n_epochs=100,
          n_factors=50, n_hidden=256, n_recurrent=1,
          callbacks=None):
    """Train an ``RNN`` model on ``dataset`` and return it.

    The model is optimised with RMSprop (lr=1e-3) under a cosine-annealing
    schedule whose cycle length is ``dataset['train'].total_iters``, using
    ``F.nll_loss`` as the criterion.

    Args:
        dataset: Object exposing ``vocab_size`` and the ``'train'`` / ``'valid'``
            splits.
        field: Not used by this function; kept so existing calls keep working.
        arch: Recurrent module class passed to ``RNN`` as ``architecture``.
        n_epochs: Number of epochs passed to the training loop.
        n_factors: Embedding dimension.
        n_hidden: Hidden size of the recurrent layer.
        n_recurrent: Number of stacked recurrent layers.
        callbacks: Optional callbacks passed to ``Loop.run``.

    Returns:
        The trained model. When one of ``callbacks`` exposes a non-``None``
        ``best_model`` attribute, the weights stored there are loaded first.

    NOTE(review): ``batch_size`` is read from the notebook's global namespace.
    """
    model = RNN(
        dataset.vocab_size,
        n_factors,
        batch_size,
        n_hidden,
        n_recurrent,
        architecture=arch)

    optimizer = optim.RMSprop(model.parameters(), lr=1e-3)
    cycle_length = dataset['train'].total_iters
    scheduler = CosineAnnealingLR(optimizer, t_max=cycle_length)
    stepper = Stepper(model, optimizer, scheduler, F.nll_loss)

    loop = Loop(stepper)
    loop.run(train_data=dataset['train'],
             valid_data=dataset['valid'],
             callbacks=callbacks,
             epochs=n_epochs)

    # Take the best weights from the callbacks that were actually passed in,
    # rather than from a global `checkpoint` that may be undefined or left over
    # from an earlier run with a different architecture.
    best_model = next(
        (cb.best_model for cb in (callbacks or ())
         if getattr(cb, 'best_model', None) is not None),
        None)
    if best_model is not None:
        model.load_state_dict(torch.load(best_model))
    return model

In [43]:
# Train with the defaults of train(): nn.LSTM architecture, up to 100 epochs.
# NOTE(review): the log below stops after 13 epochs, which suggests the recorded
# run used logging / early-stopping callbacks that this call does not pass - confirm.
model = train(dataset, field)

Epoch    1: train - 1.6699 valid - 1.6919
Epoch    2: train - 1.5128 valid - 1.5669
Epoch    3: train - 1.4296 valid - 1.4882
Epoch    4: train - 1.3588 valid - 1.4431
Epoch    5: train - 1.3095 valid - 1.4151
Epoch    6: train - 1.2804 valid - 1.3920
Epoch    7: train - 1.2800 valid - 1.4024
Epoch    8: train - 1.2463 valid - 1.3784
Epoch    9: train - 1.2186 valid - 1.3833
Epoch   10: train - 1.1936 valid - 1.3592
Epoch   11: train - 1.1684 valid - 1.3507
Epoch   12: train - 1.1480 valid - 1.3528
Epoch   13: train - 1.1339 valid - 1.3554


In [44]:
# Sample text from the trained model, starting from the seed 'For thos'.
show_text(model, field, 'For thos')

For those new had to at the whoth, selhomity), be, of sumpareriations and
according to be suspition from a proved to immediately and found, of inness in
the free spirit, folly, a magner; it is defers in any christians, also, has
solit; he finds nowadays in its possible for it who is thus necessarily
subjects, indiffained, which to senwerful not the "falsisest she herd of
indispensed.--loationing and by no god construction of a dimestion of expedients
and father, indestructive, knowledge up to beethovers


In [None]:
# Train one model per recurrent architecture and keep all of them, keyed by
# class name ('RNN', 'GRU', 'LSTM'). `model` still ends up bound to the model
# trained last, as before.
models = {}
for arch in (nn.RNN, nn.GRU, nn.LSTM):
    model = train(dataset, field, arch)
    models[arch.__name__] = model
    
    

In [None]:
# Callbacks meant to be passed as train(..., callbacks=callbacks).
# NOTE(review): `checkpoint` is not defined in any visible cell, and this cell
# sits after the train() calls above. Presumably a Checkpoint instance has to be
# created first and this cell moved before training - confirm.
callbacks = [EarlyStopping(patience=3), Logger(), checkpoint]