In [8]:
from ai.datamodules.text import TextDataModule
from ai.constants import REPO_DIR
from ai.misc_utils import get_vocab_size
import matplotlib.pyplot as plt
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
from ai.models.nlp.transformer import Transformer as Transformer2
fname = f'{REPO_DIR}/data/shakespeare.txt'
# fname = f'{REPO_DIR}/data/names.txt'

In [31]:
config = Transformer2.Config(
    vocab_size = get_vocab_size(fname),
    block_size = 32,
    # mlp_size:int = 64
    n_embd = 256,
    n_heads = 4,
    n_blocks = 4,
    dropout=0
)
batch_size = 256

In [32]:
dm = TextDataModule.Config(fname=fname, block_size=config.block_size, batch_size=batch_size, 
                           num_workers=0).i()
dm.prepare_data()

[32m2024-02-08 14:12:17.179[0m | [1mINFO    [0m | [36mai.datamodules.text[0m:[36mprepare_data[0m:[36m76[0m - [1mCreating datasets...[0m


In [35]:
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        c = config
        self.config = config
        self.q = nn.Linear(c.n_embd, c.n_embd)
        self.k = nn.Linear(c.n_embd, c.n_embd)
        self.v = nn.Linear(c.n_embd, c.n_embd)

        attn_mask = torch.ones(c.block_size, c.block_size)
        attn_mask.masked_fill(~torch.tril(attn_mask).to(bool), -torch.inf)
        self.register_buffer('attn_mask', attn_mask)
    
    def forward(self, x):
        c = self.config
        B=x.shape[0]
        T=c.block_size
        C=c.n_embd
        nh=c.n_heads

        
        q=self.q(x) # B,T,C
        k=self.k(x) # B,T,C
        v=self.v(x) # B,T,C
        q=q.view(B,T,nh,C//nh).permute(0,2,1,3) # B,nh,T,C//nh
        k=k.view(B,T,nh,C//nh).permute(0,2,3,1) # B,nh,C//nh,T
        v=v.view(B,T,nh,C//nh).permute(0,2,1,3) # B,nh,T,C//nh
        
        attn = q@k * (math.sqrt(1/(C//nh))) # B,nh,T,T
        attn *= self.attn_mask
        attn = torch.softmax(attn, dim=-1) # B,nh,T,T
        o = attn@v # B,nh,T,C//nh
        
        return o.permute(0,2,1,3).contiguous().view(B,T,C)

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        c = config
        self.l1 = nn.Linear(c.n_embd, c.n_embd*4)
        self.l2 = nn.Linear(c.n_embd*4, c.n_embd)
        
    def forward(self, x):
        o = self.l1(x)
        o = torch.relu(o)
        o = self.l2(o)
        return o

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        c = config
        self.config = config

        self.attn = Attention(config)
        self.ln1 = nn.LayerNorm(c.n_embd)

        self.ln2 = nn.LayerNorm(c.n_embd)
        
        self.ff = FeedForward(c)

    def forward(self, x):
        o = x + self.attn(self.ln1(x))
        o = x + self.ff(self.ln2(o))
        return o

class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        c = config
        self.config=config
        self.token_emb = nn.Embedding(c.vocab_size, c.n_embd)
        self.pos_emb = nn.Embedding(c.block_size, c.n_embd)
        self.blocks = nn.ModuleList([Block(config) for _ in range(c.n_blocks)])
        self.ln_final = nn.LayerNorm(c.n_embd)
        self.proj = nn.Linear(c.n_embd, c.vocab_size)
        
        self.register_buffer('pos_id', torch.arange(0, c.block_size))
    
    def forward(self, x):
        tok_emb = self.token_emb(x)
        pos_emb = self.pos_emb(self.pos_id)
        o = pos_emb + tok_emb
        
        for block in self.blocks:
            o = block(o)
        self.ln_final(o)
        o = self.proj(o)
        return o


# B = batch_size
T = config.block_size
model = Transformer(config)


In [36]:
dl = dm.train_dataloader()

In [None]:
lr=.01

step=0
for epoch in range(4):
    for X,y in iter(dl):
        step+=1
        B = X.shape[0]
    
        # forward
        logits = model(X)
        loss = F.cross_entropy(logits.view(B*T,config.vocab_size), y.view(B*T))
        if step % 100 == 0:
            print(f'e{epoch}|s{step}: loss={loss:.2f}')
    
        for param in model.parameters():
            param.grad = None
        loss.backward()
        for param in model.parameters():
            if param.grad is not None:
                param.data -= param.grad * lr

e0|s100: loss=3.03
e0|s200: loss=2.84


In [30]:
x = X.clone()
result = []
for i in range(30):
    logits = model(x)
    probas = torch.softmax(logits, -1)
    # greedy
    # token_ids = probas[:,-1].argmax(-1)
    token_ids = torch.multinomial(probas[:,-1], num_samples=1)
    x = x.roll(-1)
    x[:,-1] = token_ids.flatten()
    result.append(token_ids.flatten())


prompt = X[0]
answer = result[0]
print(''.join(dm.ds_train.decode(prompt)))
print('*'*72)
print(''.join(dm.ds_train.decode(answer)))

r'd he hath not,
But basely yiel
************************************************************************
 WbawsriositdhnrehywkoBn,.Aawaa
Xhydd Sk e aa:,wf antus;tn  e nmstrhlnygh  da emnattedrv,hhDmrur aghVsrfn? hdrnofpG syrht  hou tPyo fe Mwprl
 tc olIdhHgw.,
 ecrMeosVteBaauVsUo:'npw o  :n t    g. wnsef f ddi oua , qNlMasSh fo oensT d
o
 ofratFrfols  ltdhU e


In [22]:
# dm.ds_train.decode(answer)