In [None]:
import os
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
import string
import time

from core import RNN, Embedding, SoftmaxCrossEntropy

%matplotlib inline

In [None]:
def one_hot_enc(y, num_classes):
    """One-hot encode an integer index array.

    Args:
        y: integer array of class indices, of any shape (the training loop passes
            ``(B, seq_len)``). Values are expected to lie in ``[0, num_classes)``;
            they are not range-checked here.
        num_classes: size of the one-hot dimension (the vocabulary size).

    Returns:
        Float array of shape ``y.shape + (num_classes,)`` with exactly one 1 per position.
    """
    flat = y.ravel()  # 1-D view when possible; the fancy indexing below needs 1-D index arrays
    one_hot = cp.zeros((flat.size, num_classes))

    # Pair row k with column flat[k], setting exactly one entry per position.
    one_hot[cp.arange(flat.size), flat] = 1

    return one_hot.reshape(*y.shape, num_classes)
    

class SherlockDataset:
    """Character-level dataset backed by a single UTF-8 text file.

    The text is restricted to ASCII letters, digits, punctuation, space and newline, and every
    other character is dropped. The vocabulary is built from the *filtered* text. As a result,
    every vocabulary entry actually occurs in the training data and can be meaningfully
    predicted or sampled.

    The encoded stream is cut into ``batch_size`` contiguous rows, so batch ``i + 1`` continues
    each row exactly where batch ``i`` stopped.

    Attributes:
        text: the raw file contents (unfiltered).
        chars: sorted list of the distinct characters kept after filtering.
        vocab_size: ``len(chars)``.
        stoi / itos: character -> index and index -> character lookup tables.
        data: ``(batch_size, chars_per_row)`` array of token indices.
        num_batches: number of ``(X, Y)`` windows ``get_batch`` can return.
        batch_size, seq_len: the constructor arguments.
    """

    def __init__(self, path, batch_size, seq_len):
        """Read, filter and encode the file at ``path``.

        Args:
            path: path to a UTF-8 text file.
            batch_size: number of parallel rows the character stream is split into.
            seq_len: number of time steps per batch.

        Raises:
            ValueError: if ``batch_size`` or ``seq_len`` is not positive, or if the filtered
                text is too short to produce at least one full ``(X, Y)`` batch.
        """
        if batch_size < 1 or seq_len < 1:
            raise ValueError(f"batch_size and seq_len must be positive, got {batch_size} and {seq_len}")

        with open(path, 'r', encoding='utf-8') as file:
            self.text = file.read()

        # A set gives O(1) membership tests; a str would be scanned once per character.
        allowed_chars = frozenset(string.ascii_letters + string.digits + string.punctuation + " \n")
        filtered = [ch for ch in self.text if ch in allowed_chars]

        # Build the vocabulary from the filtered text. Characters that never reach `data` must not
        # occupy output classes that the model could later sample.
        self.chars = sorted(set(filtered))
        self.vocab_size = len(self.chars)
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

        self.data = cp.array([self.stoi[ch] for ch in filtered])  # ndims = 1

        chars_per_row = self.data.size // batch_size
        # Drop the tail that does not fill a whole row, then lay the stream out as rows.
        self.data = self.data[:chars_per_row * batch_size].reshape(batch_size, -1)

        # Y is X shifted by one step, so the last usable window must end one column early.
        self.num_batches = (self.data.shape[1] - 1) // seq_len
        if self.num_batches < 1:
            raise ValueError(
                f"text at {path!r} has {len(filtered)} usable characters, too few for "
                f"batch_size={batch_size} and seq_len={seq_len}"
            )

        self.batch_size = batch_size
        self.seq_len = seq_len

    def get_batch(self, i):
        """Return the ``i``-th ``(X, Y)`` pair of token-index arrays, each ``(batch_size, seq_len)``.

        ``Y`` is ``X`` shifted one step ahead, i.e. the next-character targets.

        Raises:
            IndexError: if ``i`` is outside ``[0, num_batches)``. Without this check, slicing would
                silently return truncated or empty windows.
        """
        if not 0 <= i < self.num_batches:
            raise IndexError(f"batch index {i} out of range [0, {self.num_batches})")

        start = i * self.seq_len
        end = start + self.seq_len

        X = self.data[:, start:end]
        Y = self.data[:, start + 1:end + 1]

        return X, Y

In [None]:
def softmax(logits):
    """Numerically stable softmax over axis 2.

    Intended for ``(B, seq_len, out_dim)`` logits, where axis 2 is the class axis. Returns an
    array of the same shape whose entries along axis 2 are non-negative and sum to 1.
    """
    # softmax(z) == softmax(z - c) for any per-row constant c; subtracting the row max keeps exp() <= 1.
    peak = cp.max(logits, axis=2, keepdims=True)
    weights = cp.exp(logits - peak)
    normaliser = cp.sum(weights, axis=2, keepdims=True)
    return weights / normaliser

def sample(models: list[RNN], embedding: Embedding, dataset: SherlockDataset, start_char: str, length: int, temperature:float = 0.8, time_breaks: float = 0.0) -> str:
    """Autoregressively generate text, one character at a time.

    Each character is streamed to stdout as soon as it is produced, with an optional pause
    between characters. The full generated string is also returned.

    Args:
        models: RNN layers applied in order. The last layer's output is treated as logits over
            ``dataset.vocab_size`` classes.
        embedding: embedding applied to the current character index.
        dataset: supplies the ``stoi`` / ``itos`` lookup tables and ``vocab_size``.
        start_char: seed character; must be a key of ``dataset.stoi``.
        length: number of characters to generate after ``start_char``.
        temperature: logits are divided by this before the softmax; lower values are greedier.
            Must be positive.
        time_breaks: seconds to sleep after each printed character; a non-positive value
            disables the pause.

    Returns:
        ``start_char`` followed by the ``length`` generated characters.

    Raises:
        ValueError: if ``temperature`` is not positive.
        KeyError: if ``start_char`` is not in the vocabulary.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    current_ix = dataset.stoi[start_char]
    # Every layer starts from a None hidden state, as in the training loop.
    h = [None for _ in models]
    output_chars = [start_char]
    print(start_char, end="", flush=True)

    for _ in range(length):
        x = embedding.forward(cp.array([[current_ix]]))  # a batch of one sequence of length one

        for i, model in enumerate(models):
            x, h[i] = model.forward(x, h[i])

        probs_gpu = softmax(x / temperature).reshape(-1)

        # Sample on the host with NumPy. np.random.choice rejects p that does not sum to 1 within
        # ~1.5e-8. If the logits are in lower precision (e.g. float32), the float64 cast alone does
        # not fix the sum, so renormalise after the cast.
        probs_cpu = cp.asnumpy(probs_gpu).astype(np.float64)
        probs_cpu /= probs_cpu.sum()
        next_ix = int(np.random.choice(dataset.vocab_size, p=probs_cpu))

        next_character = dataset.itos[next_ix]
        output_chars.append(next_character)
        current_ix = next_ix

        print(next_character, end="", flush=True)

        if time_breaks > 0:
            time.sleep(time_breaks)

    return "".join(output_chars)

In [None]:
data_path = 'sherlock.txt'
batch_size = 128
seq_len = 128

embed_dim = 64
hidden_dim = 512

epochs = 80
learning_rate = 5e-4
seed = 0

# sample() draws from NumPy's global RNG. The layers in `core` presumably initialise their weights
# from NumPy's or CuPy's global RNG (not visible here), so seed both for reproducible runs.
np.random.seed(seed)
cp.random.seed(seed)

dataset = SherlockDataset(data_path, batch_size, seq_len)
if dataset.num_batches < 1:
    raise ValueError("dataset yields no batches; use a longer text or smaller batch_size/seq_len")

embedding = Embedding(dataset.vocab_size, embed_dim)
# Two stacked RNNs: embed_dim -> hidden_dim, then hidden_dim -> vocab_size logits.
models = [RNN(embed_dim, hidden_dim, hidden_dim), RNN(hidden_dim, hidden_dim, dataset.vocab_size)]
loss_fn = SoftmaxCrossEntropy()

loss_history = []  # per-batch training loss as Python floats, so matplotlib can plot it directly

for epoch in range(epochs):
    # Each epoch restarts every row at column 0, which does not continue the text where the
    # previous epoch's last batch ended. So the carried hidden state is reset here, not just once.
    h_states = [None for _ in models]
    epoch_losses = []

    for batch_index in range(dataset.num_batches):
        X, Y_idx = dataset.get_batch(batch_index)
        Y = one_hot_enc(Y_idx, dataset.vocab_size)  # (B, seq_len, vocab_size)

        # Forward pass: the embedding feeds the first RNN, and each RNN feeds the next. Hidden
        # states carry over between batches because batch i+1 continues each row of batch i.
        activations = [embedding.forward(X)]  # (B, seq_len, embed_dim)
        for i, model in enumerate(models):
            output, h_states[i] = model.forward(activations[-1], h_states[i])
            activations.append(output)

        logits = activations[-1]
        # Assumes forward() returns a scalar (Python float or 0-d array) -- float() handles both
        # and moves a CuPy value to the host.
        loss = float(loss_fn.forward(logits, Y))
        loss_history.append(loss)
        epoch_losses.append(loss)

        # Backward pass: walk the layers in reverse, feeding each layer's input-gradient to the
        # layer below, and finally to the embedding.
        grad = loss_fn.backward()
        for model in reversed(models):
            grad = model.backward(grad)
        embedding.backward(grad)

        for model in models:
            model.step(learning_rate)
        embedding.step(learning_rate)

    mean_loss = sum(epoch_losses) / len(epoch_losses)
    print("=" * 40)
    print(f"Epoch: {epoch + 1} | Mean loss: {mean_loss:.4f} | Last batch loss: {loss:.4f}")
    sample(models, embedding, dataset, "A", 200, 0.8)
    print("\n" + "=" * 40)


In [None]:
fig, ax = plt.subplots()
# float() accepts Python floats as well as 0-d NumPy/CuPy arrays. Matplotlib cannot implicitly
# convert CuPy arrays, so normalise the values before plotting.
ax.plot([float(value) for value in loss_history])
ax.set(title="Training loss", xlabel="Batch", ylabel="Loss")
plt.show()

In [None]:
# Streams 200_000 characters with a 0.02 s pause each: 200_000 * 0.02 s = 4000 s (about 67 min).
# Bind the result so Jupyter does not echo the whole string a second time as the cell's Out[] value.
generated_text = sample(models, embedding, dataset, "H", 200_000, temperature=0.8, time_breaks=0.02)