In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import pdb
from torch.utils.data import Dataset, DataLoader

# %load_ext autoreload
# %autoreload 2

TENSOR_PRINT_WIDTH = 200  # characters per line when tensors are printed (fewer wrapped rows)
torch.set_printoptions(linewidth=TENSOR_PRINT_WIDTH)

In [2]:
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Width of the RNN hidden state.
hidden_size = 100

# Seed every RNG the notebook draws from: Python's `random`, NumPy (used when
# sampling names) and PyTorch (weight initialisation, DataLoader shuffling),
# so that Restart & Run All reproduces the same run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class DinosDataset(Dataset):
    """Character-level dataset of names, one name per line of a text file.

    The file is lower-cased. The vocabulary is the sorted set of every
    character in it, always including ``'\\n'``, which is the end-of-name
    target. Item ``i`` is a pair ``(x, y)`` built from line ``i``:

    * ``x`` -- float tensor of shape ``(len(line) + 1, vocab_size)``. Row 0 is
      all zeros (the "start" input); row ``t >= 1`` is the one-hot encoding
      of ``line[t - 1]``.
    * ``y`` -- long tensor of shape ``(len(line) + 1,)`` holding, at each
      step, the index of the next character: ``line`` followed by ``'\\n'``.

    Args:
        path: text file to read (default ``'dinos.txt'``).
        encoding: passed to ``open``; ``None`` uses the platform default.
    """

    def __init__(self, path='dinos.txt', encoding=None):
        super().__init__()
        with open(path, encoding=encoding) as f:
            content = f.read().lower()
        # Every target sequence ends with '\n', so it must be in the vocabulary
        # even when the file is a single line with no trailing newline.
        self.vocab = sorted(set(content) | {'\n'})
        self.vocab_size = len(self.vocab)
        self.lines = content.splitlines()
        self.ch_to_idx = {c: i for i, c in enumerate(self.vocab)}
        self.idx_to_ch = dict(enumerate(self.vocab))

    def __getitem__(self, index):
        line = self.lines[index]
        y_str = line + '\n'
        x = torch.zeros([len(y_str), self.vocab_size], dtype=torch.float)
        # Row 0 stays all-zero (no previous character); row t is line[t - 1].
        for t, ch in enumerate(line, 1):
            x[t, self.ch_to_idx[ch]] = 1
        y = torch.tensor([self.ch_to_idx[ch] for ch in y_str], dtype=torch.long)
        return x, y

    def __len__(self):
        return len(self.lines)

In [4]:
# One name per batch: names have different lengths, so they are not stacked.
trn_ds = DinosDataset()
trn_dl = DataLoader(dataset=trn_ds, shuffle=True, batch_size=1)

In [5]:
class RNN(nn.Module):
    """Single-layer Elman RNN cell, stepped one time step per call.

    h_t = tanh(W_hh h_{t-1} + b_hh + W_hx x_t)
    y_t = W_out h_t + b_out        (unnormalised logits)
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        # Submodules are registered in this exact order (hh, hx, output) so that
        # random initialisation and parameter ordering are unchanged.
        self.linear_hh = nn.Linear(hidden_size, hidden_size)
        self.linear_hx = nn.Linear(input_size, hidden_size, bias=False)  # bias lives in linear_hh
        self.linear_output = nn.Linear(hidden_size, output_size)

    def forward(self, h_prev, x):
        """Advance one step; return ``(new_hidden_state, output_logits)``."""
        pre_activation = self.linear_hh(h_prev) + self.linear_hx(x)
        h_next = pre_activation.tanh()
        logits = self.linear_output(h_next)
        return h_next, logits

In [6]:
# Input and output are both over the character vocabulary.
model = RNN(
    input_size=trn_ds.vocab_size,
    hidden_size=hidden_size,
    output_size=trn_ds.vocab_size,
).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [7]:
def print_sample(sample_idxs, idx_to_ch=None):
    """Print a sampled name with its first character upper-cased.

    Args:
        sample_idxs: sequence of vocabulary indices, e.g. the output of
            ``sample``. Nothing is printed if it is empty.
        idx_to_ch: index -> character mapping; defaults to ``trn_ds.idx_to_ch``.

    No newline is appended: ``sample`` always returns indices ending with the
    index of ``'\\n'``, so the name already ends its own line.
    """
    if idx_to_ch is None:
        idx_to_ch = trn_ds.idx_to_ch
    chars = [idx_to_ch[i] for i in sample_idxs]
    if not chars:
        return
    print(chars[0].upper() + ''.join(chars[1:]), end='')

In [8]:
def sample(model, max_length=50, newline_idx=None):
    """Sample one name from ``model``, one character at a time.

    Generation starts from a zero hidden state and an all-zero input. At each
    step the next index is drawn from the softmax of the model's logits, and
    the one-hot encoding of *that sampled* index is fed back as the next input.
    Sampling stops when ``newline_idx`` is drawn or after ``max_length``
    characters; in the latter case ``newline_idx`` is appended, so the result
    always ends with exactly one newline index.

    Args:
        model: an ``RNN``. Its ``linear_hh.in_features`` / ``linear_hx.in_features``
            give the hidden and input sizes, and it runs on the device of its
            parameters. Assumes the output size equals the input (vocab) size.
        max_length: maximum number of characters to sample (default 50).
        newline_idx: index that ends a name; defaults to ``trn_ds.ch_to_idx['\\n']``.

    Returns:
        list of int vocabulary indices.

    Draws from NumPy's global RNG, so seed it for reproducibility. The
    model's train/eval mode is restored before returning.
    """
    if newline_idx is None:
        newline_idx = trn_ds.ch_to_idx['\n']
    dev = next(model.parameters()).device
    n_hidden = model.linear_hh.in_features
    n_input = model.linear_hx.in_features

    was_training = model.training
    model.eval()
    indices = []
    try:
        with torch.no_grad():
            h_prev = torch.zeros([1, n_hidden], dtype=torch.float, device=dev)
            x = torch.zeros([1, n_input], dtype=torch.float, device=dev)
            while len(indices) < max_length:
                h_prev, y_pred = model(h_prev, x)
                probs = torch.softmax(y_pred, dim=1).cpu().numpy().ravel().astype(np.float64)
                # Renormalise in float64 so float32 rounding cannot trip
                # np.random.choice's "probabilities do not sum to 1" check.
                probs /= probs.sum()
                idx = int(np.random.choice(len(probs), p=probs))
                indices.append(idx)
                if idx == newline_idx:
                    break
                # Feed back the character that was actually sampled (not the argmax).
                x = torch.zeros_like(x)
                x[0, idx] = 1
            else:
                # Hit max_length without drawing a newline: terminate the name.
                indices.append(newline_idx)
    finally:
        model.train(was_training)
    return indices

In [9]:
def train_one_epoch(model, loss_fn, optimizer, dl=None, sample_every=100, max_grad_norm=5):
    """Run one pass over the training loader with full backprop-through-time.

    For each ``(x, y)`` batch the RNN is unrolled over every time step from a
    zero hidden state. The per-step losses are summed, gradients are clipped to
    a global norm of ``max_grad_norm``, and one optimizer step is taken. Every
    ``sample_every`` batches a name is sampled and printed, after the update so
    it reflects the current weights.

    Args:
        model: an ``RNN``; hidden size and device are read from it.
        loss_fn: per-step loss, called as ``loss_fn(logits, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        dl: iterable of ``(x, y)`` batches; defaults to ``trn_dl``. Presumably
            x is (batch, steps, vocab) and y is (batch, steps), as produced by
            ``DinosDataset`` with a ``DataLoader``.
        sample_every: print a sample every this many batches; 0/None disables.
        max_grad_norm: gradient-clipping threshold.

    Returns:
        Mean over batches of the summed per-step loss, computed before each
        update. NaN if ``dl`` is empty.
    """
    if dl is None:
        dl = trn_dl
    dev = next(model.parameters()).device
    n_hidden = model.linear_hh.in_features
    total_loss, n_batches = 0.0, 0
    for line_num, (x, y) in enumerate(dl):
        model.train()  # sample() switches to eval mode; make sure we train in train mode
        optimizer.zero_grad()
        x, y = x.to(dev), y.to(dev)
        h_prev = torch.zeros([x.shape[0], n_hidden], dtype=torch.float, device=dev)
        loss = 0
        for t in range(x.shape[1]):
            h_prev, y_pred = model(h_prev, x[:, t])
            loss = loss + loss_fn(y_pred, y[:, t])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
        if sample_every and (line_num + 1) % sample_every == 0:
            print_sample(sample(model))
    return total_loss / n_batches if n_batches else float('nan')

In [10]:
def train(model, loss_fn, optimizer, dataset='dinos', epochs=1):
    """Train ``model`` for ``epochs`` passes via ``train_one_epoch``.

    A banner with the 1-based epoch number is printed before each pass.
    ``dataset`` is accepted for interface compatibility but is not used.
    """
    for epoch_idx in range(epochs):
        epoch_no = epoch_idx + 1
        print(f'{"-"*20} Epoch {epoch_no} {"-"*20}')
        train_one_epoch(model, loss_fn, optimizer)

In [11]:
train(model=model, loss_fn=loss_fn, optimizer=optimizer, epochs=2000)

-------------------- Epoch 1 --------------------
Hdx
Saz
Eqang
Tyapaanaurus
Fuaghaurus
Zptnairus
Irashtsus
Okolatresiurus
Eitwaurus
Tjhaoaitaurus
Lhpsaurus
Uainaorbs
Vaeus
Lfrks
Tbrussnrus
-------------------- Epoch 2 --------------------
Etpnturus
Amrarlerus
Burasaurulaarus
Bururaurutaurus
Inbms
Slsalhuaus
Hlcucaurus
Tureshusus
Suaoaerur
Gucdsaarus
Ibrarturus
Snrusaurus
Tarocaulus
Aeroaarrus
Tcrustnrus
-------------------- Epoch 3 --------------------
Esrmtourus
Jncis
Slsacjuaus
Hmcucaurus
Turiseusus
Smarrc
Sbrhs
Burasaurur
Antusaurus
Surusaurun
Lbouciurcs
Margsaumus
Xrsasourus
Jyroshuhus
Guculterus
-------------------- Epoch 4 --------------------
Lirhlauius
Verusaurushurus
Aarcos
Slsaeoudus
Hmauaaurus
Tureskusus
Sobrodaurus
Aaurasaurus
Antosaurus
Strusaurue
Lbpuaousas
Larasaunus
Xrsassurus
Lyrosoulus
Gucunterus
-------------------- Epoch 5 --------------------
Mkrimaunus
Ugrusaurus
Tnrus
Jnbns
Slsaciudus
Gmbucaurus
Turisiusus
Suaipaurus
Tbatnaonaurus
Kvmulfurus
Snxupiuras
Alrasaunu

## Print training data (used for debugging, you can ignore this)

In [12]:
def print_ds(ds, num_examples=10):
    """Print the first ``num_examples`` ``(x, y)`` pairs of ``ds`` as strings.

    For each example this prints a line of 50 asterisks, then the ``repr`` of
    the target string decoded from ``y``, then the ``repr`` of the input string
    decoded from ``x`` (row 0, the all-zero start input, is skipped). Indices
    are decoded with ``ds.idx_to_ch``.

    Args:
        ds: dataset yielding ``(x, y)`` tensor pairs and exposing ``idx_to_ch``.
        num_examples: how many examples to print; ``None`` prints all of them,
            and a non-positive value prints nothing.
    """
    from itertools import islice

    limit = None if num_examples is None else max(num_examples, 0)
    for x, y in islice(ds, limit):
        print('*' * 50)
        y_str = ''.join(ds.idx_to_ch[idx.item()] for idx in y)
        print(repr(y_str))
        x_str = ''.join(ds.idx_to_ch[row.argmax().item()] for row in x[1:])
        print(repr(x_str))

In [13]:
print_ds(trn_ds, num_examples=5)

**************************************************
'aachenosaurus\n'
'aachenosaurus'
**************************************************
'aardonyx\n'
'aardonyx'
**************************************************
'abdallahsaurus\n'
'abdallahsaurus'
**************************************************
'abelisaurus\n'
'abelisaurus'
**************************************************
'abrictosaurus\n'
'abrictosaurus'
