In [1]:
import torch
import torch.nn as nn
import random

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
class RNN_base(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, activation):
        super().__init__()
        sigma = 0.01
        self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size)*sigma)
        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size)*sigma) 
        self.b_h = nn.Parameter(torch.randn(hidden_size)*sigma)
        self.activation = activation
        self.W_hq = nn.Parameter(torch.randn(hidden_size, output_size)*sigma)
        self.b_q = nn.Parameter(torch.randn(output_size))
    
    def forward(self, X, state=None):
        if state is None:
            state = torch.zeros(X.shape[1], self.W_hh.shape[0], device=X.device)
        outputs = []
        states = []
        for Xt in X:  #X is the input with shape (seq_len, batch_size, input_size)
            state = self.activation(Xt @ self.W_xh + state @ self.W_hh + self.b_h)
            states.append(state)
            output = state @ self.W_hq + self.b_q
            outputs.append(output)
        return outputs, states    

    @property
    def device(self):
        return next(self.parameters()).device

In [4]:
ds_text = ''
with open('tinyshakespeare.txt', 'r') as f:
    ds_text = f.read()
    ds_text = ds_text.lower() # convert to lowercase to make training easier
print("dataset size:", len(ds_text))


dataset size: 1115394


In [5]:
vocab = sorted(set(ds_text))
print(f"Vocabulary size: {len(vocab)}")
char_to_idx = {char:idx for idx, char in enumerate(vocab)}
idx_to_char = {idx:char for char, idx in char_to_idx.items()}

Vocabulary size: 39


In [6]:
corpus_indices = torch.tensor([char_to_idx[char] for char in ds_text])
corpus_indices

tensor([18, 21, 30,  ..., 19,  8,  0])

In [7]:
encoded_text = torch.zeros(len(ds_text), len(vocab))
for i, char in enumerate(ds_text):
    encoded_text[i, char_to_idx[char]] = 1
encoded_text

#Or in an optimized way
#encoded_text = nn.functional.one_hot(corpus_indices).float()    

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]])

In [8]:
class DataLoader():
    def __init__(self, corpus_indices, encoded_text, seq_length, batch_size):
        self.corpus_indices = corpus_indices
        self.encoded_text = encoded_text
        self.seq_length = seq_length
        self.batch_size = batch_size
    
    def __iter__(self):
        num_examples = (self.encoded_text.shape[0] - 1) // self.seq_length #-1 to avoid OOB for Y (j+1, j+seq_length+1) 
        #starting points
        example_indices = list(range(0, num_examples * self.seq_length, self.seq_length))
        random.shuffle(example_indices)
        
        for i in range(0, len(example_indices), self.batch_size):
            batch_indices = example_indices[i : i + self.batch_size]
            # if the batch is smaller than the batch size, we drop it
            if len(batch_indices) < self.batch_size:
                continue
            
            X = torch.stack([self.encoded_text[j : j + self.seq_length] for j in batch_indices])
            Y = torch.stack([self.corpus_indices[j + 1 : j + self.seq_length + 1] for j in batch_indices])
            
            # Transpose to (seq_length, batch_size, vocab_size)
            yield X.transpose(0, 1), Y.transpose(0, 1)

    def __len__(self):
        num_examples = (self.encoded_text.shape[0] - 1) // self.seq_length
        return num_examples // self.batch_size


In [9]:
class Training():
    def __init__(self, model, data_loader, optimizer, epochs, grad_clip=1.0):
        self.model = model
        self.data_loader = data_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optimizer
        self.epochs = epochs
        self.grad_clip = grad_clip

    def predict(self, prefix, num_preds, vocab, char_to_idx, idx_to_char):
        state = None
        outputs = [char_to_idx[prefix[0]]]
        def get_input():
            return torch.tensor([outputs[-1]], device=self.model.device).reshape(1, 1)
        
        # Warm up with prefix
        for char in prefix[1:]:
            X = nn.functional.one_hot(get_input(), num_classes=len(vocab)).float()
            _, states = self.model(X, state)
            state = states[-1]
            outputs.append(char_to_idx[char])
            
        # Predict num_preds characters
        for _ in range(num_preds):
            X = nn.functional.one_hot(get_input(), num_classes=len(vocab)).float()
            y, states = self.model(X, state)
            state = states[-1]
            outputs.append(int(y[0].argmax(dim=1).item()))
            
        return ''.join([idx_to_char[i] for i in outputs])

    def train(self):
        self.model.to(self.model.device)
        for epoch in range(1, self.epochs + 1):
            total_loss = 0
            for X, Y in self.data_loader:
                X, Y = X.to(self.model.device), Y.to(self.model.device)
                
                y_hat, _ = self.model(X)
                y_hat = torch.stack(y_hat).reshape(-1, y_hat[0].shape[-1])
                Y = Y.reshape(-1)
                
                loss = self.criterion(y_hat, Y.long())
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.optimizer.step()
                
                total_loss += loss.item()
            
            print(f"Epoch {epoch} loss: {total_loss / len(self.data_loader)}")
            print(f"Sample: {self.predict('the ', 50, vocab, char_to_idx, idx_to_char)}")

In [10]:
num_hiddens = 512
seq_length = 64
batch_size = 32
lr = 0.5
epochs = 20

data_loader = DataLoader(corpus_indices, encoded_text, seq_length, batch_size)
model = RNN_base(len(vocab), num_hiddens, len(vocab), nn.Tanh())
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=lr)
trainer = Training(model, data_loader, optimizer, epochs)
trainer.train()

Epoch 1 loss: 2.732228200225269
Sample: the the the the the the the the the the the the the th
Epoch 2 loss: 2.340498979039052
Sample: the sore the the the the the the the the the the the t
Epoch 3 loss: 2.2516424173817917
Sample: the sore the hare the hare the hare the hare the hare 
Epoch 4 loss: 2.1821843223536717
Sample: the the have the soust of and the soust of and the sou
Epoch 5 loss: 2.1141527842949417
Sample: the sould the sould the sould the sould the sould the 
Epoch 6 loss: 2.0527414335485767
Sample: the the hard the serenter hard the warde the warde the
Epoch 7 loss: 1.999350335010711
Sample: the the the the have and the have and the have and the
Epoch 8 loss: 1.952020079116611
Sample: the the sting the stall the stall the stall the stall 
Epoch 9 loss: 1.907769898281378
Sample: the that i was he do be the with the bean the with the
Epoch 10 loss: 1.8670533341081703
Sample: the good marriess of the surse of the seaven and sees 
Epoch 11 loss: 1.8294957137721426
Sample: th