<a href="https://colab.research.google.com/github/el-eshaano/ml/blob/main/Transformers/Decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
# Imports

import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
# Hyperparams
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 2500
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
embd_dims = 64
n_heads = 4
n_layers = 4
dropout = 0.0

torch.manual_seed(1337)
# ------------

<torch._C.Generator at 0x785bab7a58d0>

In [16]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-02-18 11:40:15--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.13’


2024-02-18 11:40:15 (22.2 MB/s) - ‘input.txt.13’ saved [1115394/1115394]



In [17]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [18]:
chars = sorted(list(set(text)))
print(''.join(chars))
vocab_size = len(chars)
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [19]:
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s : [stoi[c] for c in s]
decode = lambda n : ''.join([itos[d] for d in n])

In [20]:
data = torch.tensor(encode(text), dtype=torch.long)

n = int(0.9 * len(data))
train_data = data[:n]
test_data = data[n:]

# Important Concept - `block_size`

The `block_size` specifies the maximum size of the data we will pull for training. It is the maximum number of examples that will be in any subsection of the array that we pull. It represents the maximum context level of our model

So, for example, suppose our array is:
`[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`

Here, if `block_size = 2`, we will pull 3 elements from the array, lets say the first three: `[1, 2, 3]`

This set of 3 elements contains **two** training examples for our transformer,
1. `1 => 2`
2. `1, 2 => 3`

Hence we will always take a subarray of size `block_size + 1`

So we can write
```python
x = test_array[:block_size]
y = test_array[1:block_size+1]
```

Hence, for any given value `t` in the `range(block_size)`
```python
context = x[:t+1]
target = y[t]
```


In [21]:
def get_batch(split="train"):
    data = train_data if split == 'train' else test_data

    ix = torch.randint(len(data) - block_size, (batch_size, )) # function call is torch.randint(low=0, high, size, ...)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    x, y = x.to(device), y.to(device)

    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [22]:
class SelfAttentionHead(nn.Module):

    def __init__(self, head_size):
        super(SelfAttentionHead, self).__init__()

        self.head_size = head_size

        self.queries = nn.Linear(embd_dims, head_size)
        self.keys = nn.Linear(embd_dims, head_size)
        self.values = nn.Linear(embd_dims, head_size)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)


    def forward(self, x):
        B, T, C = x.shape

        k = self.keys(x)
        q = self.queries(x)

        # print(k.shape, q.shape)

        affs = q @ k.transpose(-2, -1) * C**-0.5 # Divide by C^(-0.5) so that affs has unit var
        affs = affs.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # Present should not be able to communicate with future
        affs = F.softmax(affs, dim=-1) # Normalize
        affs = self.dropout(affs)

        v = self.values(x)
        return affs @ v

In [23]:
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size):
        super(MultiHeadAttention, self).__init__()
        self.sa_heads = nn.ModuleList([SelfAttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(embd_dims, embd_dims)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.sa_heads], dim=-1) # Concat along channel dim
        out = self.dropout(self.proj(out))
        return out

In [24]:
class SingleFeedForward(nn.Module):

    def __init__(self, in_size, out_size):

        super(SingleFeedForward, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(in_size, out_size),
            nn.ReLU(),
            nn.Linear(out_size, in_size), # Project back so that it can be added with the residual pathway
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


In [25]:
class TransformerBlock(nn.Module):

    def __init__(self, embd_dims, n_heads):
        super(TransformerBlock, self).__init__()

        head_size = embd_dims // n_heads
        self.sa = MultiHeadAttention(n_heads, head_size)
        self.ffwd = SingleFeedForward(embd_dims, 4 * embd_dims)

        self.ln1 = nn.LayerNorm(embd_dims)
        self.ln2 = nn.LayerNorm(embd_dims)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) # Skip connection with x added
        x = x + self.ffwd(self.ln2(x)) # Another skip connection
        return x

In [26]:
class DecoderTransformer(nn.Module):

    def __init__(self):

        super(DecoderTransformer, self).__init__()

        self.token_embedding = nn.Embedding(vocab_size, embd_dims)
        self.position_embedding = nn.Embedding(block_size, embd_dims)
        self.blocks = nn.Sequential(*[TransformerBlock(embd_dims, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(embd_dims) # layer norm after all the transformer stuff
        self.lm_head = nn.Linear(embd_dims, vocab_size) # output head of language model

    def forward(self, x, targets=None):
        B, T = x.shape

        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=device))

        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)

            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, x, max_new_tokens):

        self.eval()
        for _ in range(max_new_tokens):
            cropped_x = x[:, -block_size:] # Only take block size amount of context
            logits, loss = self(cropped_x)
            logits = logits[:, -1, :] # Only get last token for each batch
            probs = F.softmax(logits, dim=-1)
            next_choice = torch.multinomial(probs, num_samples=1) # Predict the next charatcer given the probabilites
            x = torch.concat((x, next_choice), dim=1)
        self.train()
        return x


In [27]:
model = DecoderTransformer().to(device)
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

0.210497 M parameters
step 0: train loss 4.3090, val loss 4.3112
step 100: train loss 2.6446, val loss 2.6524
step 200: train loss 2.2733, val loss 2.2724
step 300: train loss 1.6392, val loss 1.6524
step 400: train loss 0.9465, val loss 0.9770
step 500: train loss 0.5368, val loss 0.5856
step 600: train loss 0.3680, val loss 0.4092
step 700: train loss 0.2774, val loss 0.3049
step 800: train loss 0.2189, val loss 0.2487
step 900: train loss 0.1868, val loss 0.2246
step 1000: train loss 0.1605, val loss 0.1813
step 1100: train loss 0.1338, val loss 0.1570
step 1200: train loss 0.1272, val loss 0.1423
step 1300: train loss 0.1166, val loss 0.1315
step 1400: train loss 0.1130, val loss 0.1250
step 1500: train loss 0.1013, val loss 0.1150
step 1600: train loss 0.1102, val loss 0.1201
step 1700: train loss 0.1347, val loss 0.1461
step 1800: train loss 0.0989, val loss 0.1085
step 1900: train loss 0.1024, val loss 0.1128
step 2000: train loss 0.1022, val loss 0.1137
step 2100: train loss 0.

In [28]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))


Mooooo horrinchhhcch raanath Ku aisen bobe toe.
Sagr-'t theatitans bar bthie uhqurth bthar dinthoate arice mout fostatu zokou 
Youns-m'of in coing mit nditincueg ireens, hoin lat Hot duov te ande to poman'g trabe!
 lhin dome u.
Hhuce courity:'tug haiss hiw yo nurin's normopete gods'sk:
tink, titthakeno Winso whut eiings touti fouris so nuhireds poou gour; thu the hinteruf ff sor; igre! muf thin maleount ffaf Prisd mo om.
WHKINLuk!
Kuind isa
ardsad this me sto fin couk ay andy iry tome fo mo vouck no tounke mary.
Tou 'not buthm so aten hin.
Wpour bethe thimreand shoun be hes th thutisod,.
Butuch fosomy sstuth sou.
 moun sof thupeings.
Whuespues she oveve imd assce oros ifk ovet soie so urd histe feRil ass:
Whit CINGghatk nike, n idu he neesoxt atou thitheakes agh'scour
'ss m k o with selon mo dous flla nomur ipatrcessingsughild.
MLurd po ton'en'tun thity feo theoou tiund tho nof the sut pe iportheou suth than tthe silse, Bloel ke soun thy a thef bede wat hotiu, cout ir foru;
youkgeg ou