In [1]:
import numpy as np
import random
import torch as pt
import torch.nn as nn
import torch.optim as optim
import re
from torch.utils.data import Dataset, DataLoader

# Load IPython's autoreload extension; mode 2 re-imports every module before
# each cell executes, so edits to local .py files are picked up live.
%load_ext autoreload
%autoreload 2

# Allow up to 200 characters per line when printing tensors.
pt.set_printoptions(linewidth=200)

In [2]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if pt.cuda.is_available():
    device = pt.device("cuda:0")
else:
    device = pt.device("cpu")

# Width of the hidden/cell state vectors used by the model and by the
# zero-state initialisation in the sampling and training cells.
hidden_size = 100

In [3]:
class ShakespeareDataset(Dataset):
    """Character-level next-character dataset built from a plain-text corpus.

    The file is read in full and lower-cased. It is split into "poems"
    wherever two or more whitespace characters occur in a row, and only
    fragments longer than ``min_len`` characters are kept.

    Attributes:
        vocab: sorted list of the distinct characters in the corpus (plus
            ``'\\n'``, which is always needed as the final target).
        vocab_size: ``len(vocab)``.
        poems: list of poem strings, one per dataset item.
        ch_to_idx / idx_to_ch: character <-> vocabulary-index lookups.
    """

    def __init__(self, path='shakespeare.txt', min_len=100):
        """Load and index the corpus.

        Args:
            path: text file to read; defaults to ``'shakespeare.txt'``.
            min_len: poems with ``len(poem) <= min_len`` are discarded.
        """
        super().__init__()
        with open(path) as f:
            content = f.read().lower()
        # '\n' is appended as the last target of every poem in __getitem__, so
        # it must be in the vocabulary even if the corpus never contains it.
        self.vocab = sorted(set(content) | {'\n'})
        self.vocab_size = len(self.vocab)
        # Raw string: '\s' in a normal literal is an invalid escape sequence.
        self.poems = [poem for poem in re.split(r'\s{2,}', content) if len(poem) > min_len]
        self.ch_to_idx = {c: i for i, c in enumerate(self.vocab)}
        self.idx_to_ch = {i: c for i, c in enumerate(self.vocab)}

    def __getitem__(self, index):
        """Return ``(x, y)`` for the poem at ``index``.

        Returns:
            x: float tensor of shape ``(len(poem), vocab_size)``; row ``i`` is
               the one-hot encoding of character ``i``.
            y: long tensor of shape ``(len(poem),)``; entry ``i`` is the
               vocabulary index of character ``i + 1``, with ``'\\n'`` as the
               target that follows the last character.
        """
        poem = self.poems[index]
        n = len(poem)
        # Targets are the inputs shifted left by one, terminated by a newline.
        # The slice keeps x and y the same length even for an empty poem.
        targets = (poem[1:] + '\n')[:n]

        x_idx = pt.tensor([self.ch_to_idx[ch] for ch in poem], dtype=pt.long)
        y = pt.tensor([self.ch_to_idx[ch] for ch in targets], dtype=pt.long)

        # Set the one-hot entries with a single indexed assignment instead of
        # a Python loop over characters.
        x = pt.zeros([n, self.vocab_size], dtype=pt.float)
        x[pt.arange(n), x_idx] = 1

        return x, y

    def __len__(self):
        """Number of poems kept after length filtering."""
        return len(self.poems)

In [4]:
# Build the corpus once, then iterate over it in a freshly shuffled order each
# epoch. No batch_size is given, so DataLoader's default applies.
trn_ds = ShakespeareDataset()
trn_dl = DataLoader(dataset=trn_ds, shuffle=True)

In [5]:
class LSTM(nn.Module):
    """Two stacked, hand-written LSTM cells advanced one time step per call.

    Each layer is a ``ModuleList`` of five ``nn.Linear`` modules, in this
    order: forget gate, update gate, candidate cell, output gate, and a
    projection from the hidden state to the layer's output. Layer 1 emits a
    ``hidden_size``-wide vector that feeds layer 2; layer 2 emits
    ``output_size`` values.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        first_layer = self.make_lstm_layer(input_size, hidden_size, hidden_size)
        self.l1 = nn.ModuleList(first_layer)
        second_layer = self.make_lstm_layer(hidden_size, hidden_size, output_size)
        self.l2 = nn.ModuleList(second_layer)

    def forward(self, c_prev_1, h_prev_1, c_prev_2, h_prev_2, x):
        """Advance both layers by one step.

        Takes the previous cell/hidden state of each layer plus the current
        input and returns ``(c_1, h_1, c_2, h_2, y)``: the new states of layer
        1 and layer 2 followed by layer 2's projected output.
        """
        c_1, h_1, between_layers = self.forward_layer(self.l1, c_prev_1, h_prev_1, x)
        c_2, h_2, y = self.forward_layer(self.l2, c_prev_2, h_prev_2, between_layers)
        return c_1, h_1, c_2, h_2, y

    def forward_layer(self, l, c_prev, h_prev, x):
        """Run one LSTM cell step for the layer ``l``.

        ``l`` is indexed 0..4 as (forget, update, candidate, output,
        projection). Returns the new cell state, the new hidden state, and the
        projection of the new hidden state.
        """
        forget_lin, update_lin, cand_lin, out_lin, proj = (l[k] for k in range(5))

        # All four gates read the input concatenated with the previous hidden
        # state along dimension 1.
        gate_input = pt.cat((x, h_prev), dim=1)
        keep = pt.sigmoid(forget_lin(gate_input))
        write = pt.sigmoid(update_lin(gate_input))
        expose = pt.sigmoid(out_lin(gate_input))
        candidate = pt.tanh(cand_lin(gate_input))

        c = keep * c_prev + write * candidate
        h = expose * pt.tanh(c)
        return c, h, proj(h)

    def make_lstm_layer(self, input_size, hidden_size, output_size):
        """Create the five linear maps of one layer.

        Returns a plain list: four gate layers mapping
        ``input_size + hidden_size -> hidden_size`` followed by one projection
        mapping ``hidden_size -> output_size``.
        """
        fan_in = input_size + hidden_size
        gates = [nn.Linear(fan_in, hidden_size) for _ in range(4)]
        return gates + [nn.Linear(hidden_size, output_size)]

In [6]:
# Input and output widths both equal the vocabulary size: the model receives a
# one-hot character and emits one score per character.
model = LSTM(
    input_size=trn_ds.vocab_size,
    hidden_size=hidden_size,
    output_size=trn_ds.vocab_size,
)
model = model.to(device)

# Cross-entropy over the per-character scores, optimised with plain SGD.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-2)

In [7]:
def print_sample(sample_idxs, idx_to_ch=None):
    """Print the text spelled by a sequence of vocabulary indices.

    The characters are written to stdout back to back with no separator and
    no trailing newline.

    Args:
        sample_idxs: iterable of vocabulary indices (Python or NumPy ints).
        idx_to_ch: mapping from index to character. Defaults to the training
            dataset's ``trn_ds.idx_to_ch``.
    """
    if idx_to_ch is None:
        idx_to_ch = trn_ds.idx_to_ch
    # One join + one print instead of a list comprehension run for its side
    # effects, which built and discarded a list of None values.
    print(''.join(idx_to_ch[i] for i in sample_idxs), end='')

In [8]:
def sample(model, max_chars=1000, max_lines=5):
    """Generate a sequence of character indices from ``model``.

    Starting from zero LSTM state and a random seed character, repeatedly
    feeds the last sampled character back into the model and draws the next
    one from the softmax of the model's output. Sampling stops once
    ``max_chars`` characters have been produced or ``max_lines`` newline
    characters have been sampled, whichever comes first.

    Args:
        model: module called as
            ``model(c_prev_1, h_prev_1, c_prev_2, h_prev_2, x)`` and returning
            the four updated states plus the output scores.
        max_chars: upper bound on sampled characters (seed included).
        max_lines: number of sampled newlines that ends generation.

    Returns:
        List of vocabulary indices (Python ints). The list always ends with
        the newline index: one is appended when generation stopped on the
        character limit without ending in a newline.

    Notes:
        The model is put in eval mode for sampling and restored to whatever
        mode it was in before the call, even if sampling raises.
    """
    was_training = model.training
    model.eval()
    try:
        with pt.no_grad():
            c_prev_1 = pt.zeros([1, hidden_size], dtype=pt.float, device=device)
            h_prev_1 = pt.zeros_like(c_prev_1)
            c_prev_2 = pt.zeros_like(c_prev_1)
            h_prev_2 = pt.zeros_like(c_prev_1)

            newline_char_idx = trn_ds.ch_to_idx['\n']

            # Index 0 is never used as the seed character.
            # NOTE(review): presumably because '\n' sorts first in the vocab
            # and a sample should not open with a newline - confirm.
            idx = random.randint(1, trn_ds.vocab_size - 1)
            x = c_prev_1.new_zeros([1, trn_ds.vocab_size])
            x[0, idx] = 1
            sampled_indexes = [idx]
            num_lines = 0

            while len(sampled_indexes) < max_chars and num_lines < max_lines:
                c_prev_1, h_prev_1, c_prev_2, h_prev_2, y_pred = model(
                    c_prev_1, h_prev_1, c_prev_2, h_prev_2, x
                )

                # Draw from the global NumPy RNG as-is. The previous per-step
                # reseed (from only 4999 possible seeds) collapsed the RNG
                # state and overwrote any seed the user had set.
                probs = pt.softmax(y_pred, 1).cpu().numpy().ravel()
                idx = int(np.random.choice(trn_ds.vocab_size, p=probs))
                sampled_indexes.append(idx)

                # One-hot encode the sampled character as the next input.
                x = pt.zeros_like(x)
                x[0, idx] = 1

                if idx == newline_char_idx:
                    num_lines += 1

            # Terminate the text with a newline when the character limit cut
            # generation short (the old check compared against a stale 50).
            if sampled_indexes[-1] != newline_char_idx:
                sampled_indexes.append(newline_char_idx)
    finally:
        model.train(was_training)

    return sampled_indexes

In [9]:
def train_one_epoch(model, loss_fn, optimizer):
    """Run one pass over ``trn_dl``, taking one optimizer step per batch.

    For every batch the LSTM state is reset to zeros, the sequence is fed one
    time step at a time, the per-step losses are summed, and a single
    backward pass and optimizer step follow. After every 50th batch a sample
    generated by the model is printed.
    """
    for batches_seen, (x, y) in enumerate(trn_dl, start=1):
        model.train()
        optimizer.zero_grad()

        # Fresh zero cell/hidden state for both layers at the start of a poem.
        c_prev_1, h_prev_1, c_prev_2, h_prev_2 = (
            pt.zeros([1, hidden_size], dtype=pt.float, device=device)
            for _ in range(4)
        )

        x = x.to(device)
        y = y.to(device)

        # Walk the sequence along dimension 1, accumulating the per-step loss.
        loss = 0
        for x_t, y_t in zip(x.unbind(1), y.unbind(1)):
            c_prev_1, h_prev_1, c_prev_2, h_prev_2, y_pred = model(
                c_prev_1, h_prev_1, c_prev_2, h_prev_2, x_t
            )
            loss += loss_fn(y_pred, y_t)

        loss.backward()
        optimizer.step()

        if batches_seen % 50 == 0:
            print_sample(sample(model))

In [10]:
def train(model, loss_fn, optimizer, epochs=1):
    """Train ``model`` for ``epochs`` full passes over the training data.

    Prints a 1-based ``Epoch:<n>`` header before each pass and delegates the
    pass itself to ``train_one_epoch``.
    """
    e = 0
    while e < epochs:
        e += 1
        print(f'Epoch:{e}')
        train_one_epoch(model=model, loss_fn=loss_fn, optimizer=optimizer)

In [12]:
# Full training run: 1000 passes over the corpus with the model, loss and
# optimizer built in the earlier cells.
train(model=model, loss_fn=loss_fn, optimizer=optimizer, epochs=1000)