In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
from torch.utils.data import Dataset, DataLoader

%load_ext autoreload
%autoreload 2

torch.set_printoptions(linewidth=200)

In [2]:
hidden_size = 100  # width of the RNN hidden state


class DinosDataset(Dataset):
    """Character-level dataset of names, one name per line of a text file.

    Each item is an ``(x, y)`` pair for next-character prediction:

    * ``x`` is a ``(len(line) + 1, vocab_size)`` float tensor of one-hot rows.
      Row 0 is all zeros and acts as the "start of name" input; row ``i``
      (``i >= 1``) is the one-hot encoding of ``line[i - 1]``.
    * ``y`` is a ``(len(line) + 1,)`` long tensor of target indices: the
      characters of ``line`` followed by the index of ``'\\n'`` (end of name).

    Args:
        path: Text file to read. Its contents are lower-cased; the vocabulary
            is every distinct character in the file plus ``'\\n'``.
    """

    def __init__(self, path='dinos.txt'):
        super().__init__()
        with open(path) as f:
            content = f.read().lower()
        # '\n' is always the end-of-name target, so make sure it is in the
        # vocabulary even if the file has no trailing newline.
        self.vocab = sorted(set(content) | {'\n'})
        self.vocab_size = len(self.vocab)
        self.lines = content.splitlines()
        self.ch_to_idx = {ch: i for i, ch in enumerate(self.vocab)}
        self.idx_to_ch = {i: ch for i, ch in enumerate(self.vocab)}

    def __getitem__(self, index):
        line = self.lines[index]
        # Inputs are shifted right by one: an all-zero start row, then the line.
        x = torch.zeros([len(line) + 1, self.vocab_size], dtype=torch.float)
        for i, ch in enumerate(line, 1):
            x[i, self.ch_to_idx[ch]] = 1
        # Targets are the line itself followed by the newline end marker.
        y = torch.tensor([self.ch_to_idx[ch] for ch in line + '\n'], dtype=torch.long)
        return x, y

    def __len__(self):
        return len(self.lines)

In [3]:
class RNN(nn.Module):
    """Single-layer Elman RNN cell with a linear read-out.

    One call advances the recurrence by one step::

        a_next = tanh(linearax(x) + linearaa(a))
        y      = linearay(a_next)

    ``y`` is returned as raw logits; no softmax is applied here.

    Args:
        hidden_size: Width of the hidden state ``a``.
        input_size: Width of each input vector ``x``.
        output_size: Width of the logits ``y``.
    """

    def __init__(self, hidden_size, input_size, output_size):
        super().__init__()
        self.linearax = nn.Linear(input_size, hidden_size)   # input  -> hidden
        self.linearaa = nn.Linear(hidden_size, hidden_size)  # hidden -> hidden
        self.linearay = nn.Linear(hidden_size, output_size)  # hidden -> output

    def forward(self, x, a):
        """Return ``(a_next, y)`` for input ``x`` and previous hidden state ``a``."""
        pre_activation = self.linearax(x) + self.linearaa(a)
        hidden = torch.tanh(pre_activation)
        return hidden, self.linearay(hidden)

In [4]:
def sample(model, ds=None, max_len=50):
    """Generate one name from ``model`` by sampling one character at a time.

    Starting from an all-zero input and an all-zero hidden state, the model is
    stepped repeatedly. Each next character index is drawn from
    ``softmax(logits)`` and fed back as a one-hot input. Generation stops when
    the newline index is drawn or after ``max_len`` characters.

    Args:
        model: Module with ``linearax``/``linearaa``/``linearay`` ``nn.Linear``
            layers (as the ``RNN`` class has), called as
            ``model(x, a) -> (a_next, logits)``.
        ds: Object providing ``ch_to_idx``; defaults to the global ``trn_ds``.
        max_len: Maximum number of characters to sample.

    Returns:
        List of character indices, terminated by exactly one newline index.
        The model's train/eval mode is restored before returning.
    """
    if ds is None:
        ds = trn_ds
    newline_idx = ds.ch_to_idx['\n']
    input_size = model.linearax.in_features
    hidden_size = model.linearaa.in_features
    output_size = model.linearay.out_features

    was_training = model.training
    model.eval()
    indices = []
    x = torch.zeros(input_size)
    a = torch.zeros(hidden_size)
    try:
        with torch.no_grad():
            while len(indices) < max_len:
                a, logits = model(x, a)
                probs = F.softmax(logits, dim=-1).reshape(-1).double().numpy()
                # Renormalise in float64 so np.random.choice's sum-to-1 check
                # cannot trip on float32 rounding error.
                probs /= probs.sum()
                idx = int(np.random.choice(output_size, p=probs))
                indices.append(idx)
                if idx == newline_idx:
                    break
                x = torch.zeros(input_size)
                x[idx] = 1
    finally:
        model.train(was_training)

    # Hit the length cap without drawing a newline: terminate the name once.
    if not indices or indices[-1] != newline_idx:
        indices.append(newline_idx)
    return indices

def print_sample(sample_idxs, ds=None):
    """Print a sampled name with its first character upper-cased.

    Args:
        sample_idxs: Character indices, e.g. as returned by ``sample``. A
            trailing newline index is printed as a line break. An empty
            sequence prints nothing.
        ds: Object providing ``idx_to_ch``; defaults to the global ``trn_ds``.
    """
    if ds is None:
        ds = trn_ds
    text = ''.join(ds.idx_to_ch[i] for i in sample_idxs)
    # Only the first character is capitalised; the rest is printed verbatim.
    print(text[:1].upper() + text[1:], end='')

In [5]:
def train_one_epoch(model, loss_fn, optimizer, dl=None, sample_every=100):
    """Train ``model`` for one pass over ``dl``, backpropagating through each whole sequence.

    For every batch, the following steps run in order:

    1. The hidden state starts at zero.
    2. The model is unrolled over all time steps and the per-step losses are summed.
    3. Gradients are clipped to a global norm of 5.
    4. One optimizer step is taken.

    Args:
        model: Recurrent cell called as ``model(x_t, a) -> (a_next, logits)``;
            its hidden width is read from ``model.linearaa.in_features``.
        loss_fn: Criterion applied per step as ``loss_fn(logits.view(1, -1), y[:, t])``.
        optimizer: Optimizer over ``model``'s parameters.
        dl: Iterable of ``(x, y)`` batches of size 1. ``x`` is indexed as
            ``x[0][t]`` and ``y`` as ``y[:, t]``. Defaults to the global ``trn_dl``.
        sample_every: Print a generated name (via ``sample``/``print_sample``)
            every this many batches; ``0`` or ``None`` disables sampling.

    Returns:
        Mean per-batch summed loss over the epoch (``0.0`` if ``dl`` is empty).
    """
    if dl is None:
        dl = trn_dl
    hidden_size = model.linearaa.in_features
    total_loss = 0.0
    n_batches = 0

    for line_num, (x, y) in enumerate(dl):
        # sample() may switch the model to eval mode; switch back every batch.
        model.train()
        optimizer.zero_grad()

        a = torch.zeros(hidden_size)
        loss = 0
        for t in range(x.shape[1]):
            a, y_pred = model(x[0][t], a)
            loss += loss_fn(y_pred.view(1, -1), y[:, t])

        if sample_every and (line_num + 1) % sample_every == 0:
            print_sample(sample(model))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    return total_loss / n_batches if n_batches else 0.0
        

In [6]:
trn_ds = DinosDataset()
trn_dl = DataLoader(trn_ds, batch_size=1, shuffle=True)


def train(trn_ds, trn_dl, epochs=1, hidden_size=100, lr=1e-3):
    """Build a fresh ``RNN`` sized to ``trn_ds``'s vocabulary and train it with Adam.

    Args:
        trn_ds: Dataset whose ``vocab_size`` sets the model's input and output width.
        trn_dl: Loader for ``trn_ds``. NOTE(review): ``train_one_epoch`` is
            called without a loader here, so it iterates the module-level
            ``trn_dl``. That is the same object in this notebook.
        epochs: Number of passes over the data.
        hidden_size: Width of the RNN hidden state.
        lr: Adam learning rate.

    Returns:
        The trained model.
    """
    model = RNN(hidden_size, trn_ds.vocab_size, trn_ds.vocab_size)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for e in range(1, epochs + 1):
        print(f'{"-"*20} Epoch {e} {"-"*20}')
        train_one_epoch(model, loss_fn, optimizer)
    return model

In [7]:
# Start training.
# Seed every RNG the run touches so Restart-&-Run-All is reproducible:
# torch drives weight init and DataLoader shuffling; numpy drives name sampling.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
train(trn_ds, trn_dl, epochs=5)

-------------------- Epoch 1 --------------------




Harw
Ftksdo
Jxylnrshur
Sgurlot
Miiclphaaus
Eurusaurus

Yubros
Derpsaurus
Ackaapandukq
Nggltoneotnatrus
Gichoposaurus
Dicendovsaurus
Ncedtanaongdosturus
Brancosaurus
-------------------- Epoch 2 --------------------
Lopdonoerastus
Hanoestuon
Tanachiaurosaurus
Hervina
Telangctosy
Yixaitis
Copando
Rrtovaaurus
Plassacompopma
Sakaueur
Urenchidulum
Topmoria
Nntelonatan
Chyusaurosterdishusaurus
Khhiondamana
-------------------- Epoch 3 --------------------
Pliairosaurus
Lerabosaurus
Dongaurosaurus
Cempovodan
Yenaalrin
Jumymamars
Pintiosaurus
Llamisaurus
Aurestochonosan
Huragosaurus
Arakotelaus


KeyboardInterrupt: 