In [None]:
import numpy as np
import matplotlib.pyplot as plt

from core import RNN, Embedding, SoftmaxCrossEntropy

%matplotlib inline

In [None]:
class TinyShakespeareDataset:
    """Character-level text dataset that serves fixed-length (input, target) windows.

    The text is tokenised per character, trimmed so it splits evenly into
    ``batch_size`` contiguous streams, and stored as a ``(batch_size, cols)``
    integer array in ``self.data``. Batch ``i`` takes columns
    ``[i*seq_len, (i+1)*seq_len)`` as inputs and the same window shifted one
    character to the right as targets.
    """

    def __init__(self, path, batch_size, seq_len, encoding="utf-8"):
        """Read ``path`` and build the vocabulary and the batched id array.

        Args:
            path: Path of the text file to load.
            batch_size: Number of parallel streams (rows of ``self.data``).
            seq_len: Number of characters per window returned by ``get_batch``.
            encoding: Text encoding used when reading ``path``.

        Raises:
            ValueError: If ``batch_size`` or ``seq_len`` is less than 1.
        """
        if batch_size < 1 or seq_len < 1:
            raise ValueError(
                f"batch_size and seq_len must be >= 1, got {batch_size} and {seq_len}"
            )

        with open(path, "r", encoding=encoding) as file:
            self.text = file.read()

        # The vocabulary is built from the full text, including any tail characters trimmed below.
        self.chars = sorted(set(self.text))
        self.vocab_size = len(self.chars)
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

        # Trim the tail so the text splits evenly into `batch_size` rows;
        # otherwise the reshape fails whenever len(text) % batch_size != 0.
        chars_per_row = len(self.text) // batch_size
        trimmed = self.text[:chars_per_row * batch_size]

        # Explicit dtype keeps ids integral even when `trimmed` is empty.
        ids = np.array([self.stoi[ch] for ch in trimmed], dtype=np.int64)
        # Row r holds the r-th contiguous slice of the (trimmed) text.
        self.data = ids.reshape(batch_size, chars_per_row)

        # Windows are counted along the time axis (columns), not rows. One
        # column is reserved because targets are shifted one step to the right.
        self.num_batches = max(0, (chars_per_row - 1) // seq_len)
        self.batch_size = batch_size
        self.seq_len = seq_len

    def get_batch(self, i):
        """Return the ``i``-th (inputs, targets) pair.

        Args:
            i: Batch index in ``[0, self.num_batches)``.

        Returns:
            ``(X, Y)``: two ``(batch_size, seq_len)`` integer arrays sliced from
            ``self.data``. ``Y`` is ``X`` shifted one character to the right.

        Raises:
            IndexError: If ``i`` is outside ``[0, self.num_batches)``; slicing
                past the end would otherwise return short or empty windows.
        """
        if not 0 <= i < self.num_batches:
            raise IndexError(f"batch index {i} out of range [0, {self.num_batches})")

        start = i * self.seq_len
        end = start + self.seq_len

        X = self.data[:, start:end]
        Y = self.data[:, start + 1:end + 1]

        return X, Y

In [None]:
def softmax(logits):
    """Numerically stable softmax along axis 2 of a 3-D array.

    ``logits`` is expected to be shaped ``(B, seq_len, out_dim)``; each
    ``(b, t)`` row of the returned array sums to 1.
    """
    # Removing the per-row maximum prevents exp() overflow and cancels in the ratio.
    row_max = np.max(logits, axis=2, keepdims=True)
    exps = np.exp(logits - row_max)
    return exps / exps.sum(axis=2, keepdims=True)

def sample(model: RNN, embedding: Embedding, dataset: TinyShakespeareDataset, start_char: str, length: int, temperature=0.8, rng=None) -> str:
    """Generate ``length`` characters autoregressively, seeded by ``start_char``.

    At each step the previous character's embedding and the carried hidden
    state go through ``model.forward``. The logits are divided by
    ``temperature``, and the next character is drawn from the resulting
    softmax distribution.

    Args:
        model: Recurrent model whose ``forward(x, h)`` returns ``(logits, h)``.
            The logits must hold ``dataset.vocab_size`` values, since they are
            reshaped to ``(vocab_size,)``.
        embedding: Maps a ``(1, 1)`` index array to the model's input.
        dataset: Supplies ``stoi``, ``itos`` and ``vocab_size``.
        start_char: Seed character; must be a key of ``dataset.stoi``. It is
            not included in the returned string.
        length: Number of characters to generate.
        temperature: Positive scale; values below 1 sharpen the distribution,
            values above 1 flatten it.
        rng: Optional ``np.random.Generator`` for reproducible sampling. When
            ``None``, the global ``np.random`` state is used.

    Returns:
        The generated characters, excluding ``start_char``.

    Raises:
        ValueError: If ``temperature`` is not positive.
        KeyError: If ``start_char`` is not in the vocabulary.
    """
    # Reject bad temperatures early: 0 would yield NaN probabilities, and a
    # negative value would silently invert the distribution.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    choose = np.random.choice if rng is None else rng.choice

    current_input = embedding.forward(np.array(dataset.stoi[start_char]).reshape(1, -1)) # Becomes (1, 1, embed_dim)
    h = None  # NOTE(review): assumes model.forward treats None as "use the initial state" — confirm in core.RNN

    generated = []
    for _ in range(length):
        logits, h = model.forward(current_input, h)
        # logits is (B, seq_len, out_dim) = (1, 1, vocab_size)
        probs = softmax(logits / temperature).reshape(dataset.vocab_size) # (vocab_size,)

        next_ix = choose(dataset.vocab_size, p=probs)
        generated.append(dataset.itos[next_ix])

        current_input = embedding.forward(np.array(next_ix).reshape(1, -1)) # Becomes (1, 1, embed_dim)

    # Join once at the end instead of repeated string concatenation.
    return "".join(generated)