In [2]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import pdb
from torch.utils.data import Dataset, DataLoader
%load_ext autoreload
%autoreload 2
torch.set_printoptions(linewidth=200)

In [3]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear_hh = nn.Linear(hidden_size, hidden_size)
        self.linear_hx = nn.Linear(input_size, hidden_size, bias=False)
        self.linear_output = nn.Linear(hidden_size, output_size)
    
    def forward(self, h_prev, x):
        h = torch.tanh(self.linear_hh(h_prev) + self.linear_hx(x))
        y = self.linear_output(h)
        return h, y

In [14]:
class Name_Generator(Dataset):
    def __init__(self):
        super().__init__()
        self.init_hparams()
        self.getData()
#         self.getModel()

    def init_hparams(self):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.hidden_size = 100
        self.epochs = 10
        self.lr = 1e-2
    
    def getData(self):
        with open('data/dinos.txt') as f:
            content = f.read().lower()
            self.vocab = sorted(set(content))
            self.vocab_size = len(self.vocab)
            self.lines = content.splitlines()
        self.ch_to_idx = {c:i for i, c in enumerate(self.vocab)}
        self.idx_to_ch = {i:c for i, c in enumerate(self.vocab)}
        print(self.ch_to_idx)
        print("lines={}, vocab={}".format(len(self.lines), self.vocab_size))
        
    def getBatchData(self, index):
        line = self.lines[index]
        x_str = ' ' + line
        y_str = line + '\n'
        x = torch.zeros([len(x_str), self.vocab_size], dtype=torch.float)
        y = torch.empty(len(x_str), dtype=torch.long)
        
        y[0] = self.ch_to_idx[y_str[0]]
        for i, (x_ch, y_ch) in enumerate(zip(x_str[1:], y_str[1:]), 1):
            x[i][self.ch_to_idx[x_ch]] = 1
            y[i] = self.ch_to_idx[y_ch]
                    
        return x, y
    
    def getModel(self):
        self.model = RNN(self.vocab_size, self.hidden_size, self.vocab_size).to(self.device)
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
    
    def print_sample(self, sample_idxs):
        print(self.idx_to_ch[sample_idxs[0]].upper(), end='')
        [print(self.idx_to_ch[x], end='') for x in sample_idxs[1:]]
    
    def sample(self):
        self.model.eval()
        word_size=0
        newline_idx = self.ch_to_idx['\n']
        indices = []
        pred_char_idx = -1
        h_prev = torch.zeros([1, self.hidden_size], dtype=torch.float, device=self.device)
        x = h_prev.new_zeros([1, self.vocab_size])
        with torch.no_grad():
            while pred_char_idx != newline_idx and word_size != 50:
                h_prev, y_pred = self.model(h_prev, x)
                softmax_scores = torch.softmax(y_pred, dim=1).cpu().numpy().ravel()
                np.random.seed(np.random.randint(1, 5000))
                idx = np.random.choice(np.arange(self.vocab_size), p=softmax_scores)
                indices.append(idx)

                x = (y_pred == y_pred.max(1)[0]).float()
                pred_char_idx = idx

                word_size += 1

            if word_size == 50:
                indices.append(newline_idx)
        return indices
    
    def train(self):
        for line_num in range(len(self.lines)):
            x, y = self.getBatchData(line_num)
            x = torch.unsqueeze(x, 0)
            y = torch.unsqueeze(y, 0)
            
            self.model.train()
            loss = 0
            self.optimizer.zero_grad()
            h_prev = torch.zeros([1, self.hidden_size], dtype=torch.float, device=self.device)
            x, y = x.to(self.device), y.to(self.device)
            for i in range(x.shape[1]):
                h_prev, y_pred = self.model(h_prev, x[:, i])
                loss += self.loss_fn(y_pred, y[:, i])

#             print(line_num)
            if (line_num+1) % 100 == 0:
                self.print_sample(self.sample())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5) #gradient clipping
            self.optimizer.step()
    
    def fit(self):
        for e in range(1, self.epochs+1):
            print(f'{"-"*20} Epoch {e} {"-"*20}')
            self.train()
    
    def print_ds(self, num_examples=10):
        for i in range(len(self.lines)):
            x, y = self.getBatchData(i)
            print('*'*50)
            x_str, y_str = '', ''
            for idx in y:
                y_str += self.idx_to_ch[idx.item()]
            print(repr(y_str))

            for t in x[1:]:
                x_str += self.idx_to_ch[t.argmax().item()]
            print(repr(x_str))

            if i == num_examples:
                break

In [15]:
obj = Name_Generator()

{'\n': 0, 'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26}
lines=1536, vocab=27
['\n', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
19909
['aachenosaurus', 'aardonyx', 'abdallahsaurus', 'abelisaurus', 'abrictosaurus', 'abrosaurus', 'abydosaurus', 'acanthopholis', 'achelousaurus', 'acheroraptor', 'achillesaurus', 'achillobator', 'acristavus', 'acrocanthosaurus', 'acrotholus', 'actiosaurus', 'adamantisaurus', 'adasaurus', 'adelolophus', 'adeopapposaurus', 'aegyptosaurus', 'aeolosaurus', 'aepisaurus', 'aepyornithomimus', 'aerosteon', 'aetonyxafromimus', 'afrovenator', 'agathaumas', 'aggiosaurus', 'agilisaurus', 'agnosphitys', 'agrosaurus', 'agujaceratops', 'agustinia', 'ahshislepelta', 'airakoraptor', 'ajancingenia', 'ajkaceratops',

In [13]:
obj.fit()

-------------------- Epoch 1 --------------------


AttributeError: 'Name_Generator' object has no attribute 'model'

In [103]:
obj.print_ds(5)

**************************************************
'aachenosaurus\n'
'aachenosaurus'
**************************************************
'aardonyx\n'
'aardonyx'
**************************************************
'abdallahsaurus\n'
'abdallahsaurus'
**************************************************
'abelisaurus\n'
'abelisaurus'
**************************************************
'abrictosaurus\n'
'abrictosaurus'
**************************************************
'abrosaurus\n'
'abrosaurus'
