In [None]:
from src.tokenizer import preprocess_tinyshakespare

preprocess_tinyshakespare()

In [1]:
from src.dataset import TinyShakespeareDataset
    

train_ds = TinyShakespeareDataset('train')
val_ds = TinyShakespeareDataset('val')
test_ds = TinyShakespeareDataset('test')

  self.tokens = torch.load(f"{path}_tokens.pt")
  self.labels = torch.load(f"{path}_labels.pt")
  self.offsets = torch.load(f"{path}_offsets.pt")


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer


class FFLayer(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        dff = emb_dim * 8 // 3
        self.W1 = nn.Linear(emb_dim, dff)
        self.W2 = nn.Linear(dff, emb_dim)
        self.W3 = nn.Linear(emb_dim, dff)

    def forward(self, x):
        y = self.silu(self.W1(x))
        y = y * self.W3(x)
        y = self.W2(y)
        return y
    
    @staticmethod
    def silu(x):
        return x * F.sigmoid(x)

class RMSNorm(nn.Module):
    def __init__(self, emb_size, eps=1e-5):
        super().__init__()
        self.g = nn.Parameter(torch.ones(emb_size))
        self.eps = eps

    def forward(self, x):
        in_dtype = x.dtype
        rms = torch.sqrt(torch.mean(x.pow(2)) + self.eps)
        x = (x / rms) * self.g
        return x.to(in_dtype)

class CausalSelfAttentionBlock(nn.Module):
    def __init__(self, input_dim, emb_dim=512, max_seq_len=1050, n_heads=8, p_dropout=0.1):
        super().__init__()
        # self.Wq = nn.Linear(input_dim, emb_dim, bias=True)
        # self.Wk = nn.Linear(input_dim, emb_dim, bias=True)
        # self.Wv = nn.Linear(input_dim, emb_dim, bias=True)
        self.W = nn.Linear(input_dim, emb_dim * 3, bias=True)
        self.Wo = nn.Linear(emb_dim, emb_dim)
        self.ff = FFLayer(emb_dim)
        self.ln_mha = RMSNorm(emb_dim)
        self.ln_ff = RMSNorm(emb_dim)        
        self.scale = 1 / np.sqrt(emb_dim)
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        attn_mask = torch.tril(torch.ones(max_seq_len, max_seq_len)) == 0
        self.register_buffer("_causal_mask", attn_mask)


    def forward(self, x):
        y = self.mha(x)
        y_prenorm = y + x

        y = self.mlp(y_prenorm)
        y = y_prenorm + y
        
        return y
    
    def mlp(self, x):
        y = self.ln_ff(x)
        y = self.ff(y)
        y = F.dropout(y, p=self.p_dropout)
    
        return y
    
    def mha(self, x):
        b, t, e = x.shape
        s = e // self.n_heads
        
        # x = self.ln_mha(x)
        # q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        # q = q.view(b, t, self.n_heads, s).transpose(1, 2)
        # k = k.view(b, t, self.n_heads, s).transpose(1, 2)
        # v = v.view(b, t, self.n_heads, s).transpose(1, 2)

        x = self.ln_mha(x)
        qkv = self.W(x)
        qkv = qkv.view(b, t, self.n_heads, 3 * s)
        qkv = qkv.transpose(1, 2)
        q, k, v = qkv.split(s, dim=-1)

        attn = q @ k.transpose(-1, -2) * self.scale
        attn = attn.masked_fill(self._causal_mask[:t, :t], -torch.inf)
        attn = F.softmax(attn, dim=-1)
        attn = F.dropout(attn, p=self.p_dropout)
        y = attn @ v
        y = y.transpose(-2, -3).reshape(b, t, e)
        y = self.Wo(y)
        y = F.dropout(y, p=self.p_dropout)

        return y

        
class NanoGPT(nn.Module):
    def __init__(
            self, 
            vocab_size=8124, 
            emb_dim=512, 
            attn_blocks=4, 
            max_seq_len=1024, 
            n_heads=16, 
            p_dropout=0.1, 
            tokenizer_path="tokenizer.json"
        ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.attn = nn.Sequential(
            *[CausalSelfAttentionBlock(emb_dim, emb_dim, max_seq_len, n_heads, p_dropout) for _ in range(attn_blocks)]
        )
        self.mlp = nn.Linear(emb_dim, vocab_size)
        self.layer_norm = RMSNorm(emb_dim)
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.p_dropout = p_dropout
        
        self.register_buffer("_device_tracker", torch.empty(0))
        pe = self._compute_pe(max_seq_len, emb_dim)
        self.register_buffer('pe', pe)

    @property
    def device(self):
        return self._device_tracker.device

    def forward(self, x):
        _, t = x.shape
        x = self.emb(x)
        x = x + self.pe[:, :t, :]
        x = F.dropout(x, p=self.p_dropout)
        x = self.attn(x)
        x = self.layer_norm(x)
        x = self.mlp(x)
        return x
    
    @torch.no_grad
    def run_inference(self, text_input, tau=1.0, k=10):
        x = self.tokenizer.encode(text_input).ids[:-1]
        eos_id = self.tokenizer.token_to_id("<eos>")
        next_token = -1
        cur_iter = 0
        max_iter = 128

        while (next_token != eos_id and cur_iter < max_iter):
            x_tensor = torch.tensor(x).to(self.device)
            logits = self.forward(x_tensor.unsqueeze(0))
            q = F.softmax(logits / tau, dim=-1)[:, -1]
            topk = q.topk(k=k)
            next_token_index = topk.values.multinomial(1).item()
            next_token = topk.indices[0, next_token_index]
            x += [next_token]
            cur_iter += 1
        
        return self.tokenizer.decode(x)
        
    def _compute_pe(self, max_seq_len, emb_dim):
        pe = torch.zeros(max_seq_len, emb_dim)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * (-np.log(10000.0) / emb_dim))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0) 

        return pe

In [3]:
import numpy as np


class GPT2LRScheduler:
    def __init__(self, optim, max_lr=1e-3, min_lr=3e-5, warmup_steps=2000, total_steps=200000):
        self.optim = optim
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.warmup_steps = max(1, int(warmup_steps))
        self.total_steps = max(self.warmup_steps + 1, int(total_steps))
        self.step = 0

    def adjust_lr(self):
        self.step += 1
        s = self.step

        if s <= self.warmup_steps:
            # linear warmup: 0 -> max_lr
            lr = self.max_lr * (s / self.warmup_steps)
        else:
            # cosine decay: max_lr -> min_lr
            progress = (s - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            progress = min(max(progress, 0.0), 1.0)
            cosine = 0.5 * (1.0 + np.cos(np.pi * progress))
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine

        for pg in self.optim.param_groups:
            pg["lr"] = lr

        return lr

In [4]:
import numpy as np
import torch.nn.functional as F
import torch
from functools import partial
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from torchinfo import summary
from tqdm import tqdm


# def clip_gradient(grad):
#     max_grad_norm = 

def collate_batch_fn(args, pad_token):
    x = [t[0] for t in args]
    y = [t[1] for t in args]
    xlen = len(x)
    padded_xy = pad_sequence(x + y, padding_value=pad_token, batch_first=True)
    padded_x = padded_xy[:xlen]
    padded_y = padded_xy[xlen:]

    return padded_x, padded_y

lrs = []

def train_epoch(model, dloader, optim, lr_scheduler, device):
    cum_loss =  0.0

    for x, y in tqdm(dloader):
        lr_scheduler.adjust_lr()
        x, y = x.to(device), y.to(device)
        optim.zero_grad()
        y_pred = model(x)
        loss = F.cross_entropy(y_pred.transpose(1, 2), y)
        loss.backward()

        optim.step()

        cum_loss += loss.item() * dloader.batch_size

        for param_group in optim.param_groups:
            lrs.append(param_group['lr'])
    
    cum_loss /= (len(dloader) * batch_size)

    return cum_loss

def test_epoch(model, dloader, device):
    cum_loss =  0.0

    for x, y in dloader:
        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            y_pred = model(x)
            loss = F.cross_entropy(y_pred.transpose(1, 2), y)

        cum_loss += loss.item() * dloader.batch_size
    
    cum_loss /= (len(dloader) * batch_size)

    return cum_loss

torch.manual_seed(2131)

batch_size = 8
num_epochs = 10
tokenizer_path = "tokenizer.json"

tokenizer = Tokenizer.from_file(tokenizer_path)
pad_token = tokenizer.token_to_id("<pad>")
vocab_size = tokenizer.get_vocab_size()
device = torch.device('cuda')

collate_fn = partial(collate_batch_fn, pad_token=pad_token)
train_dloader = DataLoader(train_ds, batch_size, shuffle=True, collate_fn=collate_fn)
val_dloader = DataLoader(val_ds, batch_size, shuffle=False, collate_fn=collate_fn)
test_dloader = DataLoader(test_ds, batch_size, shuffle=False, collate_fn=collate_fn)

model = NanoGPT()

# Gradient clipping
# for p in model.parameters():
#     p.register_hook(lambda grad: )

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

optim = AdamW(model.parameters())
total_steps = len(train_dloader) * num_epochs
warmup_steps = total_steps * 0.05
# lr_scheduler = LRScheduler(optim, warmup_steps=warmup_steps)
lr_scheduler = GPT2LRScheduler(optim, warmup_steps=warmup_steps, total_steps=total_steps)
model = model.to(device)
print(summary(model))

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_dloader, optim, lr_scheduler, device)
    val_loss = test_epoch(model, val_dloader, device)
    perplexity = np.exp(val_loss)

    print(f"Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}, Val Perplexity: {perplexity}")

test_loss = test_epoch(model, test_dloader, device)
perplexity = np.exp(test_loss)
print(f"Test Loss: {test_loss}, Test Perplexity: {perplexity}")
torch.save(model.module.state_dict(), "weights.pt")
# torch.save(model.state_dict(), "weights.pt")


Layer (type:depth-idx)                        Param #
DataParallel                                  --
├─NanoGPT: 1-1                                --
│    └─Embedding: 2-1                         4,159,488
│    └─Sequential: 2-2                        --
│    │    └─CausalSelfAttentionBlock: 3-1     3,151,530
│    │    └─CausalSelfAttentionBlock: 3-2     3,151,530
│    │    └─CausalSelfAttentionBlock: 3-3     3,151,530
│    │    └─CausalSelfAttentionBlock: 3-4     3,151,530
│    └─Linear: 2-3                            4,167,612
│    └─RMSNorm: 2-4                           512
Total params: 20,933,732
Trainable params: 20,933,732
Non-trainable params: 0


100%|██████████| 59/59 [00:12<00:00,  4.62it/s]


Epoch: 0, Train Loss: 5.497071646027646, Val Loss: 4.470759391784668, Val Perplexity: 87.42308618586773


100%|██████████| 59/59 [00:11<00:00,  5.23it/s]


Epoch: 1, Train Loss: 4.380050853147345, Val Loss: 4.115286350250244, Val Perplexity: 61.269756361106374


100%|██████████| 59/59 [00:11<00:00,  5.31it/s]


Epoch: 2, Train Loss: 4.031834772077658, Val Loss: 3.9667584896087646, Val Perplexity: 52.81305899640893


100%|██████████| 59/59 [00:11<00:00,  5.28it/s]


Epoch: 3, Train Loss: 3.714634656906128, Val Loss: 3.9460620880126953, Val Perplexity: 51.731252087931566


100%|██████████| 59/59 [00:11<00:00,  5.26it/s]


Epoch: 4, Train Loss: 3.4584456985279663, Val Loss: 3.938840389251709, Val Perplexity: 51.35901029565926


100%|██████████| 59/59 [00:11<00:00,  5.27it/s]


Epoch: 5, Train Loss: 3.300766977213197, Val Loss: 3.9603688716888428, Val Perplexity: 52.47667954079065


100%|██████████| 59/59 [00:11<00:00,  5.27it/s]


Epoch: 6, Train Loss: 3.0874060089305297, Val Loss: 4.019615650177002, Val Perplexity: 55.67970123032022


100%|██████████| 59/59 [00:11<00:00,  5.25it/s]


Epoch: 7, Train Loss: 3.0891660189224504, Val Loss: 4.021398544311523, Val Perplexity: 55.7790607905246


100%|██████████| 59/59 [00:11<00:00,  5.22it/s]


Epoch: 8, Train Loss: 2.9379237910448492, Val Loss: 4.0247979164123535, Val Perplexity: 55.96899722298665


100%|██████████| 59/59 [00:11<00:00,  5.24it/s]


Epoch: 9, Train Loss: 2.888797848911609, Val Loss: 4.046222686767578, Val Perplexity: 57.18105783254657
Test Loss: 4.70983007975987, Test Perplexity: 111.03329149912854


In [8]:
model = NanoGPT()
model.load_state_dict(torch.load("weights.pt", weights_only=True))
out = model.run_inference("Let")
print(out)

Let him:
Ay, I am going with me,
That I do,
My lord.
My lord, my lord.
You must:
So I'll not.
DUKE VINCENTIO:
DUKE VINCENTIO:
DUKE VINCENTIO:
I do, I have done, I will.
I know you?
And I know's the people, my lord, sir, and I had I had been
You will.
And you are:
ISABELLA:
DUKE VINCENTIO:
LUCIO:
LUCIO:
To your husband
LUCIO:
DUKE VINCENTIO:
I'll do not a very heart.
