In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

# Data prep

In [81]:
class TextLoader:
    def __init__(self, text, test_split=0.9):
        self.vocab = sorted(list(set(text)))
        self.vocab_size = len(self.vocab)
        self.stoi = {ch:i for i,ch in enumerate(self.vocab)}
        self.itos = {i:ch for ch,i in self.stoi.items()}

        self.data = [self.stoi[ch] for ch in text]
        n = int(test_split * len(self.data))
        self.train_data = self.data[:n]
        self.val_data = self.data[n:]
        
    def encode(self, text):
        # return list of char indices
        return [self.stoi[ch] for ch in text]

    def decode(self, tokens):
        # return string for the tokens list
        return ''.join([self.itos[token] for token in tokens])

    def get_batch(self, split, batch_size, block_size):
        data = self.train_data if split=='train' else self.val_data
        ix = torch.randint(0, len(data) - block_size, (batch_size,)) # [2,99, 56, 9000,...]
        x = [data[i : i + block_size] for i in ix]
        y = [data[i+1 : i+1 + block_size] for i in ix]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

In [82]:
text = open('../rnn_lstm/data/shakespeare.txt', 'r').read()
loader = TextLoader(text)

In [83]:
loader

<__main__.TextLoader at 0x136f1ddf0>

In [84]:
print(loader.vocab)
print(loader.vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [85]:
xb, yb = loader.get_batch('train', 2, 4)
print(xb)
print(yb)

tensor([[52, 58, 47, 52],
        [47, 52, 43,  1]])
tensor([[58, 47, 52, 59],
        [52, 43,  1, 57]])


# Model construct

In [None]:
class LSTMCell(nn.Module):
    # prev_layer_hidden_state
    def __init__(self, input_embd, hidden_embd):
        super().__init__()
        """
        wx_gates: [wxf | wxi | wxg | wxo]
        wh_gates: [whf | whi | whg | wxo]
        b_gates : [bf | bi | bg | bo]
        """
        self.wx_gates = nn.Parameter(torch.randn(input_embd, hidden_embd * 4) * 0.01)
        self.wh_gates = nn.Parameter(torch.randn(hidden_embd, hidden_embd * 4) * 0.01)
        self.b_gates = nn.Parameter(torch.zeros(hidden_embd * 4))

    def forward(self, x, h_prev, c_prev):
        # x: (B,n)
        # layer_hidden: (B,hidden_embd)
        # layer_cell: (B,hidden_embd)
        x_gates = x @ self.wx_gates # (b,input_dim) @ (inpyt_embd, hidden_embd*4) -> (b,hidden_embd*4)
        h_gates = h_prev @ self.wh_gates # (b,hidden_embd) @ (hidden_embd, hidden_embd*4) -> (b,hidden_embd*4)
        gates_output = x_gates + h_gates + self.b_gates

        ft, it, gt, ot = gates_output.chunk(4, dim=1)

        ft = torch.sigmoid(ft)
        it = torch.sigmoid(it)
        gt = torch.tanh(gt)
        ot = torch.sigmoid(ot)

        c_t = (ft * c_prev) + (it * gt)
        h_t = ot * torch.tanh(c_t)
        return h_t, c_t

class MultiLayerLSTM(nn.Module):
    def __init__(self, vocab_size, input_embd, hidden_embd, layers, dropout=0.5):
        super().__init__()
        self.layers = layers
        self.hidden_embd = hidden_embd
        self.embedding = nn.Parameter(torch.randn(vocab_size, input_embd) * 0.01)
        self.lstm_layer = nn.ModuleList()
        
        self.lstm_layer.append(LSTMCell(input_embd, hidden_embd))
        for layer in range(1, layers):
            self.lstm_layer.append(LSTMCell(hidden_embd, hidden_embd))
        self.dropout = nn.Dropout(dropout) if dropout is not None and layers > 1 else None

        self.why = nn.Parameter(torch.randn(hidden_embd, vocab_size) * 0.01)
        self.by = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, x):
        B,T = x.shape
        hs = torch.zeros(self.layers, B, self.hidden_embd, device=x.device)
        cs = torch.zeros(self.layers, B, self.hidden_embd, device=x.device)
        logits = []

        emb = self.embedding[x] # (B,T,input_embd)
        
        for t in range(T):
            xt = emb[:, t, :] # (B, input_embd)
            hs_new = torch.zeros_like(hs, device=xt.device)
            cs_new = torch.zeros_like(cs, device=xt.device)
            for layer in range(self.layers):
                h_layer, c_layer = hs[layer], cs[layer]
                cell_layer = self.lstm_layer[layer]
                h_new, c_new = cell_layer(xt, h_layer, c_layer)
                hs_new[layer] = h_new
                cs_new[layer] = c_new
                if layer < self.layers - 1 and self.dropout is not None:
                    xt = self.dropout(h_new)
                else:
                    xt = h_new
            hs = hs_new
            cs = cs_new
            yt = hs[-1] @ self.why + self.by # (B, vocab_size)
            logits.append(yt)
        # now logits: T elements of shape (B, vocab_size)
        logits = torch.stack(logits, dim=1) # (B, T, vocab_size)
        return logits

In [95]:
n_embd = 8
n_hidden = 32
device = 'mps'

model = MultiLayerLSTM(loader.vocab_size, input_embd=n_embd, hidden_embd=n_hidden, layers=2)
model.to(device)
print(sum(p.numel() for p in model.parameters()))

for p in model.parameters():
    p.requires_grad = True

16233


In [96]:
type(model)

__main__.MultiLayerLSTM

In [97]:
for name, param in model.named_parameters():
    print(name, param.shape)

embedding torch.Size([65, 8])
why torch.Size([32, 65])
by torch.Size([65])
lstm_layer.0.wx_gates torch.Size([8, 128])
lstm_layer.0.wh_gates torch.Size([32, 128])
lstm_layer.0.b_gates torch.Size([128])
lstm_layer.1.wx_gates torch.Size([32, 128])
lstm_layer.1.wh_gates torch.Size([32, 128])
lstm_layer.1.b_gates torch.Size([128])


In [115]:
n_embd = 64
n_hidden = 64
device = 'mps'
batch_size = 32
block_size = 128

model = MultiLayerLSTM(loader.vocab_size, input_embd=n_embd, hidden_embd=n_hidden, layers=2)
model.to(device)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6} MILION params')

for p in model.parameters():
    p.requires_grad = True

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

max_iters = 20000
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

for i in range(max_iters):
    x, y = loader.get_batch('train', batch_size=batch_size, block_size=block_size)
    x, y = x.to(device), y.to(device)
    B,T = x.shape
    # forward pass
    logits = model(x)
    loss = F.cross_entropy(logits.view(B*T, -1), y.view(B*T))

    # backward pass
    for p in model.parameters():
        p.grad = None
    loss.backward()

    # update weights
    optim.step()

    # validation
    if i % 200 == 0:
        with torch.no_grad():
            x_val, y_val = loader.get_batch('val', batch_size=512, block_size=block_size)
            x_val, y_val = x_val.to(device), y_val.to(device)
            logits = model(x_val)
            val_loss = F.cross_entropy(logits.view(-1, loader.vocab_size), y_val.view(-1))
        # early-stopping
        if val_loss < best_val_loss - 1e-4: # small delta to be considered
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve > patience:
            print(f'Early stop @ epoch {i}. \n__________________Best validation loss = {best_val_loss:.4f}')
            break
    if i % 100 == 0:
        print(f'Iteration {i} | train loss = {loss.item():.4f} | val loss = {val_loss.item():.4f}')

0.074433 MILION params
Iteration 0 | train loss = 4.1744 | val loss = 4.1733
Iteration 100 | train loss = 3.3609 | val loss = 4.1733
Iteration 200 | train loss = 3.3098 | val loss = 3.3687
Iteration 300 | train loss = 3.1805 | val loss = 3.3687
Iteration 400 | train loss = 3.0694 | val loss = 3.0908
Iteration 500 | train loss = 2.9918 | val loss = 3.0908
Iteration 600 | train loss = 2.8405 | val loss = 2.8516
Iteration 700 | train loss = 2.7169 | val loss = 2.8516
Iteration 800 | train loss = 2.6768 | val loss = 2.6684
Iteration 900 | train loss = 2.5953 | val loss = 2.6684
Iteration 1000 | train loss = 2.5885 | val loss = 2.5853
Iteration 1100 | train loss = 2.5803 | val loss = 2.5853
Iteration 1200 | train loss = 2.5681 | val loss = 2.5234
Iteration 1300 | train loss = 2.4880 | val loss = 2.5234
Iteration 1400 | train loss = 2.4933 | val loss = 2.4724
Iteration 1500 | train loss = 2.4619 | val loss = 2.4724
Iteration 1600 | train loss = 2.4093 | val loss = 2.4229
Iteration 1700 | tra

KeyboardInterrupt: 

In [116]:
@torch.no_grad()
def generate(model, stoi, itos, block_size, prompt=None, device='cpu', max_new_tokens=500, out_path='generated.txt'):
    model.eval()
    if not prompt:
        idx = torch.tensor([[0]], dtype=torch.long, device=device)
    else:
        idx = torch.tensor([loader.stoi[ch] for ch in prompt], dtype=torch.long, device=device)
    generated_chars = []
    
    for _ in range(max_new_tokens):
        idx_cropped = idx[:, -block_size:] # (b,T)
        logits = model(idx_cropped) # (b,T,vocab_size)
        logits = logits[0][-1] # (vocab_size,) vector from last time step
        probs = F.softmax(logits, dim=-1)
        next_idx = torch.multinomial(probs, 1).item()
        generated_chars.append(loader.itos[next_idx])
    full_text = ''.join(generated_chars)
    print(full_text)
    # with open(out_path, 'w', encoding='utf-8') as fp:
    #     fp.write(full_text)
    model.train()

In [117]:
generate(model, loader.stoi, loader.itos, block_size=block_size, device='mps', max_new_tokens=500)

WBIBHSHwLHI:TgDy'y:AwI
ucKRLLTlurHWAMGtE
W'HS';OSWSlSI;p3FsPI;CoiIEyHrMb
HPR
aTcZKVM3ACI sLTE3VWgaOO-E3DRIIIAANTAb-HQoI UNVAlRVHIDEdCMHT
BFHLcHOlGPSW
-HObsGaxbJHIQTOOJRMCHBH,CHPMPIILKCCM'WNGiHWntU3IfmFBO
bEFKlIEILOq,Hp:HHENHoPv WI
pOnOFHngN
POOIRNCBHDiVaUIAjyEIWIcFMBnOGOLEBIP
muK
wbhCGHtJYeCEeVREiL-KHGE IaEWEMKTgiNIaENLTEvCMLF I
W
VrBlmtWCWuN FEAVVAb
WCRhnVIRIaCiSLiI V3fBTAw
jnv
bMTIBaMtL
LktFFhYSwFIIDDFEN'SBSNYSZHUAulSwJBIrMaOUF
GBMmOH ,M&rWCI
AKQDgauD
DS,MCeWT
MAOEoSKtKLRACNHEGkFETwHwGTINTDAHT
