In [4]:
import torch

input_content = ""
with open("../code/input.txt") as f:
    input_content = f.read()
    
chars = sorted(list(set(input_content)))
VOCAB_SIZE = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
data = torch.tensor(encode(input_content), dtype=torch.long)
n = int(0.9*len(data))
train = data[:n]
val = data[n:]

In [5]:
BLOCK_SIZE = 8
BATCH_SIZE = 4

def get_batch(split):
    data = train if split == "train" else val
    idx = torch.randint((len(data) - BLOCK_SIZE), (BATCH_SIZE,))
    x = torch.stack([data[i:i+BLOCK_SIZE] for i in idx])
    y = torch.stack([data[i+1:i+BLOCK_SIZE+1] for i in idx])
    return x, y
    
batchx, batchy = get_batch("train")

In [25]:
import torch.nn.functional as F

class BigramLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(VOCAB_SIZE, VOCAB_SIZE)
        
    def forward(self, idx, targets=None):
        logits = self.embedding(idx)   # B,T,C
        
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)   # But torch needs B,C,T 
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss
    
    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :] # Take last T from B,T,C
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx
        
bgm = BigramLM()
# check forward pass
out, loss = bgm(batchx, batchy)
# check generation
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bgm.generate(idx, max_tokens=100)[0].tolist()))


 phz!OPyqc&yqMOQJ?QRXyqVoZAqfwnLp
Tn
lJRRl .g,KYW&kA:n3m'lu
 GtNx .qc!gmH: KYx,'VyyQRsALZA3
Cp?n,3aV


In [26]:
# Simple Bigram, only taking T-1 for predicting T

optimizer = torch.optim.AdamW(bgm.parameters(), lr=1e-3)
# experiment optimization
batch_size = 32
for _ in range(10000):
    x, y = get_batch("train")
    logits, loss = bgm(x,y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss)
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bgm.generate(idx, max_tokens=500)[0].tolist()))


3QJuzje AUNQv?GFSgJlzr3-fYaSg
CSNHWbJlfGK:'eGSWrCgD-JritjwU&os;.s
rpoatIRBgZqqetmu,pKq3pn.lSvdvwXd:uEsXovSgnf&YW-q
CSNU:LFs vWbaL3NR $A&y;wdMZHfSZlDqc
dqzgorGsJWvpKG:&HWsA:FcmfAUgfaWVyxkMOTb'l?x3$FtSyRD-j;iRv3&GAj
rbpc.qLE,Uqu'fXCiY:ki TGayqe;.oBWfSZKDsTMl
TX.HC3r
uAb oMYRTQvdwJkJLB.z&cb ,Ty&kCMOPyqN.Du,ua!BpjcgCC3a-tH:a.OOA.qsK IZ,Zbl

TueQehNgnTiasCwJkbi XTyqONmTAkUAIdhl&ynfgft !IjhBzQI,z3AGF.;wpqKzFlbSgPz3vHfNSLvWzN,CSg:JVYWaNC:&YORL&TAOQJWzq agcjl.$vJCSyBEpQjS3uOtB'oxIFQeMY,ovnwurH!qct?ETFYQ


In [27]:
# Self-Attention experiment with BoW  
# Lets try to average the last N chars, not ideal, but lets try

B, T, C = 4, 8, 2
x = torch.randn(B,T,C)

# Approach 1
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True) # (T,T) @ (B,T,C) ==> (B,T,T) @ (B,T,C) = (B, T, C)
result_1 = wei @ x

# Approach 2
tril = torch.tril(torch.ones(T, T))
# Affinities wont be constant -> each present will be influenced diff by diff past words
wei = torch.zeros(T, T)
# future cant communicate with the present
wei = wei.masked_fill(tril == 0, float('-inf')) 
wei = F.softmax(wei, dim=-1)
result_2 = wei @ x
torch.allclose(result_1, result_2)

True

In [75]:
EMBED_SIZE = 32

import torch.nn.functional as F

class Head(torch.nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = torch.nn.Linear(EMBED_SIZE, head_size)
        self.query = torch.nn.Linear(EMBED_SIZE, head_size)
        self.value = torch.nn.Linear(EMBED_SIZE, head_size)
        self.tril = torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE))
    
    def forward(self, x):
        B, T, C = x.shape
        key = self.key(x)
        query = self.query(x)
        wei = query @ key.transpose(-2,-1) * C**-0.5 # B,T,16 @ B, 16, T -> B, T,T -> T,T like square dist matrix
        wei = wei.masked_fill(self.tril[:T, :T] == 0, -float("inf"))
        wei = F.softmax(wei, dim=-1)
        value = self.value(x)
        out = wei @ value
        return out
        
        
class BigramLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
        self.pos_embedding = torch.nn.Embedding(BLOCK_SIZE, EMBED_SIZE)
        # insert self attention head
        self.self_att_head = Head(EMBED_SIZE)
        self.lm_head = torch.nn.Linear(EMBED_SIZE, VOCAB_SIZE)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)   # B,T,C
        pos_emb = self.pos_embedding(torch.arange(T))
        x = tok_emb + pos_emb
        logits = self.lm_head(x)
        
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)   # But torch needs B,C,T 
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss
    
    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            idx_cond = idx[:, -BLOCK_SIZE:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) 
            idx_next = torch.multinomial(probs, num_samples=1) 
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
        
bgm = BigramLM()
# check forward pass
out, loss = bgm(batchx, batchy)
# check generation
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bgm.generate(idx, max_tokens=500)[0].tolist()))


jVHIlYYR3S3FUGuIL
tPofYpw3DILxkWbGJpnfyuLTou&.fDb.mT?WCmpaHdDnbjMSSEbJ'pbGXiPbaUQLAEwDIGHoa3yTDGG3nULANf'akIbabksFuXs3cL NGeBEE
zX,3fIaew.UaxTxnoffxoqLnb-dvT Lwo:JgTmLMg-EEsWPtWCG33;DnIncA dOtlajGTpa'bf:YF
xqL
YYaMjq!SrRfdxGb3fDZRZMy-F abnmxYGuHx3,
m!xmLPbGzAU J3.cEIfhRE-GLBoyYrV?3 kuLPbBpXt:bbGACwTaYJzu'af;suDoHHIafaHRd3;klWE wRbyHfbMb3oyyb;IabHQrnnj3IbbUef3EcOakHJ33MmXiqT'G.,.-naXKbG'jIR?rLsvR:MLNnQenHMYB,GJCG  NP3;,Uyy3mL-uaf3IVCLcAAURLVPVPffBXjS3yUXejb-qVQL?,;uaUHLfm ,mruFmiuPp?xU JPBW &QxAa


In [47]:
# Lets train new BM with Self attention head

# Simple Bigram, only taking T-1 for predicting T
bgm = BigramLM()
optimizer = torch.optim.AdamW(bgm.parameters(), lr=1e-3)
# experiment optimization
batch_size = 32
for _ in range(10000):
    x, y = get_batch("train")
    logits, loss = bgm(x,y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss)
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bgm.generate(idx, max_tokens=500)[0].tolist()))

tensor(2.2859, grad_fn=<NllLossBackward0>)


LO:
BRI he.
I'ccuthes s;
The man jun.

fiud y:
Omveravin.
SAs adl hme wit w areaney uy ormowind thornin'lod
Wheit mer, wifistr pe s CHankilthers I:
Ankon peemyous gevererethea wise'nkfAny, thre R' in;,,
D kicorard,

Acesor the MAnals, h bu, l Iem aralat rd mokndevomfuthouch qmade my

UERDpede t s
Bor t t m se

Whe,

Teavided e, crt tBEYothardifoo, quromend

Ange if teed;

An.
TangheMI be re, f mands theans wig gin,
Bavet lld phened del-dtr n. mbechenthanthld
LI: mer,
Thay tsooor bbli be I ou, w


In [133]:
# Multi-head attention
class MultiAttentionHead(torch.nn.Module):
    def __init__(self, no_of_heads, head_size):
        super().__init__()
        self.heads = torch.nn.ModuleList([Head(head_size) for _ in range(no_of_heads)])
        self.proj = torch.nn.Linear(head_size * no_of_heads, EMBED_SIZE)
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)

In [73]:
# Update BigramLM
NO_OF_HEADS = 4

class BigramLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
        self.pos_embedding = torch.nn.Embedding(BLOCK_SIZE, EMBED_SIZE)
        # insert self attention head
        self.heads = MultiAttentionHead(NO_OF_HEADS, EMBED_SIZE//NO_OF_HEADS)
        self.lm_head = torch.nn.Linear(EMBED_SIZE, VOCAB_SIZE)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)   # B,T,C
        pos_emb = self.pos_embedding(torch.arange(T))
        x = tok_emb + pos_emb
        x = self.heads(x)
        logits = self.lm_head(x)
        
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)   # But torch needs B,C,T 
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss
    
    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            idx_cond = idx[:, -BLOCK_SIZE:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) 
            idx_next = torch.multinomial(probs, num_samples=1) 
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
# Lets train new BM with Self attention head
bgm = BigramLM()
optimizer = torch.optim.AdamW(bgm.parameters(), lr=1e-3)
batch_size = 32
losses = []
for _ in range(10000):
    x, y = get_batch("train")
    logits, loss = bgm(x,y)
    losses.append(loss.item())
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(np.mean(losses))
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bgm.generate(idx, max_tokens=500)[0].tolist()))
# Loss : 2.6064284039855004

In [142]:
# Adding few more components

class FeedForward(torch.nn.Module):
    def __init__(self, EMBED_SIZE):
        super().__init__()
        # * by 4 acc to paper.
        self.net = torch.nn.Sequential(torch.nn.Linear(EMBED_SIZE, 4 * EMBED_SIZE),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(4 * EMBED_SIZE, EMBED_SIZE)) # Projection
    def forward(self, x):
        return self.net(x)

    
class Block(torch.nn.Module):
    def __init__(self, EMBED_SIZE, no_of_heads):
        super().__init__()
        head_size = EMBED_SIZE // no_of_heads
        self.sa = MultiAttentionHead(no_of_heads, head_size)
        self.ffd = FeedForward(EMBED_SIZE)
        self.layer_norm_1 = torch.nn.LayerNorm(EMBED_SIZE)
        self.layer_norm_2 = torch.nn.LayerNorm(EMBED_SIZE)

    def forward(self, x):
        # Residual connections
        # Pre-norm formulation
        x = x + self.sa(self.layer_norm_1(x))
        x = x + self.ffd(self.layer_norm_2(x))
        return x

In [143]:
# Update BigramLM
NO_OF_HEADS = 4
NO_OF_BLOCKS = 4

class BigramLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
        self.pos_embedding = torch.nn.Embedding(BLOCK_SIZE, EMBED_SIZE)
        # insert blocks
        self.blocks = torch.nn.Sequential(*[Block(EMBED_SIZE, 4) for _ in range(NO_OF_BLOCKS)])
        self.lm_head = torch.nn.Linear(EMBED_SIZE, VOCAB_SIZE)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding(idx)   # B,T,C
        pos_emb = self.pos_embedding(torch.arange(T))
        x = tok_emb + pos_emb 
        x = self.blocks(x)
        logits = self.lm_head(x)
        
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)   # But torch needs B,C,T 
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss
    
    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            idx_cond = idx[:, -BLOCK_SIZE:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) 
            idx_next = torch.multinomial(probs, num_samples=1) 
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [144]:
# Lets train new BM with Self attention head
bgm = BigramLM()
optimizer = torch.optim.AdamW(bgm.parameters(), lr=1e-3)
batch_size = 32
losses = []
for _ in range(10000):
    x, y = get_batch("train")
    logits, loss = bgm(x,y)
    losses.append(loss.item())
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(np.mean(losses))
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bgm.generate(idx, max_tokens=500)[0].tolist()))


2.2917728568553923

He dow.

AREN ROVICOLH:
Bursing--lucher undun aved lauke murkemss.

MICIOVWAxh; our me whast mure a shall fore wardy;
Will thitth, O, how, our that poon not
And withest heave.
Wurtame fingume,
Hard is sarlond
Lut infid withou how to i men vill hard you welconder thange is
Bucone thy thhe bu lovertumed:
To but not bacther bon will hand nopsice
souquy iperothe tho.

AgELICIIFLA:
I my vost that he verein.

CAPurmund wito hat ard dupresty Pelloi macork's nour thou a lat hath andand have howsif love,
