In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import regex as re
from icecream import ic

In [None]:
# Pick the fastest available backend: CUDA GPU, then Apple's Metal (MPS), then the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

In [None]:
do_download: bool = False
if do_download:
    # Pure-Python download (portable, unlike the `!wget` shell magic). It is skipped when the file
    # already exists, so re-running the cell neither re-downloads nor creates input.txt.1 copies.
    import urllib.request

    data_path = Path("input.txt")
    if not data_path.exists():
        urllib.request.urlretrieve(
            "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
            data_path,
        )

## Different methods to compute a running (cumulative) mean along the T axis
i.e. a weighted sum in which every item up to and including the current position gets the same weight.

In [None]:
# Toy batch: B sequences of T time steps, each step a C-dimensional vector.
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x

In [None]:
# Method 1: explicit loops. xbow[b, t] is the mean of x[b, 0], ..., x[b, t] over time.
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, : t + 1].mean(dim=0)

xbow.shape

In [None]:
# Method 2: the same running mean as one matrix multiply.
# Row t of the lower-triangular matrix has ones in columns 0..t; normalising each row to sum
# to 1 turns the product into an average of the first t + 1 time steps.
wei = torch.tril(torch.ones((T, T)))
print(wei)
row_sums = wei.sum(dim=1, keepdim=True)
wei = wei / row_sums
print("weight", wei)
xbow2 = wei @ x  # (T, T) @ (B, T, C) broadcasts over B ---> (B, T, C)
torch.allclose(xbow, xbow2)

In [None]:
# Method 3: the same weights via softmax. Future positions are set to -inf so exp() maps them
# to 0, and softmax over the remaining zeros gives each visible position weight 1 / (t + 1).
tril = torch.tril(torch.ones((T, T)))
wei = torch.zeros((T, T)).masked_fill(tril == 0, value=float("-inf"))
wei = F.softmax(wei, dim=1)
print("weight", wei)

xbow3 = wei @ x
torch.allclose(xbow, xbow3)

## Attention mechanism

In [None]:
# A single causal self-attention head on random data, written out step by step.
torch.manual_seed(42)
B, T, C = 16, 8, 32
x = torch.rand(B, T, C)

head_size = 32
# Projections are created in the same order (key, query, value) so the seeded weights match.
key = nn.Linear(C, head_size)
query = nn.Linear(C, head_size)
value = nn.Linear(C, head_size)
k, q, v = key(x), query(x), value(x)  # each (B, T, head_size)

# Scaled dot-product scores, masked so position t only attends to positions <= t.
wei = (q @ k.transpose(-2, -1)) * head_size**-0.5  # (B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float("-inf")), dim=-1)

out = wei @ v  # (B, T, head_size)

In [None]:
# transpose vs reshape: on a contiguous tensor like `m`, reshape/view only change how the same
# flat storage is indexed, while transpose swaps axes (see the following cells).

m = torch.tensor(range(1, 65))  # 1-D tensor holding 1..64
print(m)
# Underlying flat storage. This is not the last expression of the cell, so it is not displayed.
m.storage()
b, t, c = 2, 8, 4
mr = m.reshape(b, t, c)  # 2 * 8 * 4 == 64: the same 64 values regrouped as (2, 8, 4)
# mr.storage()
print(mr)

In [None]:
# view() reinterprets the same storage with a new shape. view(b, c, t) is NOT a transpose: the
# values keep their order and are just regrouped into rows of length t.
mv = m.view(b, t, c)
mv.view(b, c, t)

In [None]:
# The (2, 8, 4) reshaped tensor again, for comparison with its transpose in the next cell.
mr

In [None]:
# transpose swaps two axes (here the last two) instead of regrouping values:
# result[b, j, i] == mr[b, i, j], so the shape becomes (2, 4, 8).
mr.transpose(-2, -1)

In [None]:
s = torch.tensor(range(1, 11)).reshape(2, 5)
print(s)
# transpose(0, 1) and transpose(1, 0) are the same operation: swap axes 0 and 1.
print(s.transpose(0, 1))
print(s.transpose(1, 0))
st = s.transpose(1, 0)
# The transpose is a view: its storage still holds s's values in the original 1..10 order.
print(st.storage())
# contiguous() copies the data into new storage laid out in st's own row-major order
# (1, 6, 2, 7, ...).
st.contiguous().storage()

# Data prep

- BPE tokenizer
- character-based tokenizer

In [None]:
# Read the whole Tiny Shakespeare corpus into one string.
text = Path("./input.txt").read_text(encoding="utf-8")

In [None]:
# 90/10 train/test split on characters. Show a short preview instead of dumping the whole
# ~1 MB string into the notebook output.
n = int(0.9 * len(text))
train_txt = text[:n]
test_txt = text[n:]
train_txt[:500]

## BPE tokenizer

In [None]:
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

tokenizer = Tokenizer(models.BPE())
# Byte-level pre-tokenisation (GPT-2 style); pre_tokenizers.WhitespaceSplit() is an alternative.
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# GPT doesn't use a normalizer. This one is built only to demo normalisation and is NOT
# attached to `tokenizer`.
normalizer = normalizers.Sequence(
    [normalizers.NFD(), normalizers.Lowercase(), normalizers.StripAccents()]
)
normalizer.normalize_str(train_txt[:30])

In [None]:
# Show how the byte-level pre-tokenizer splits the first 50 characters, with character offsets.
# Distinct loop names avoid clobbering the notebook-level `t` and `s` used by other cells.
for piece, (start, end) in tokenizer.pre_tokenizer.pre_tokenize_str(train_txt[:50]):
    print(piece, start, end)

In [None]:
# Learn BPE merges on the training split only, up to a 1000-token vocabulary, reserving
# "<|endoftext|>" as a special token.
trainer = trainers.BpeTrainer(special_tokens=["<|endoftext|>"], vocab_size=1000)
tokenizer.train_from_iterator([train_txt], trainer=trainer)
tokenizer.get_vocab_size()

In [None]:
# Attach byte-level post-processing/decoding, then round-trip a short sample.
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
tokenizer.decoder = decoders.ByteLevel()
sample = train_txt[:50]
encodings = tokenizer.encode(sample)
print("tokens: ", encodings.tokens)
print("ids: ", encodings.ids)
print("decoded:", tokenizer.decode(encodings.ids))
# Offsets index into the encoded string, so slicing the sample recovers each token's text.
print(f"offsets:{[sample[start:end] for start, end in encodings.offsets]}")

In [None]:
# Train/validation token streams for the BPE tokenizer: encode the whole corpus once, then
# split the token ids 90/10.
data = torch.tensor(tokenizer.encode(text).ids, dtype=torch.long)
n = int(0.9 * len(data))
train_bpe, val_bpe = data[:n], data[n:]

In [None]:
def get_bpe_batch(split: str = "train", batch_size: int = 16, block_size: int = 8):
    """Sample a random batch of BPE-token windows.

    Args:
        split: "train" selects `train_bpe`; any other value selects `val_bpe`.
        batch_size: Number of windows to sample.
        block_size: Length of each window.

    Returns:
        (x, y): two (batch_size, block_size) tensors on `device`. y is x shifted one token to the
        right, i.e. the next-token targets.
    """
    data = train_bpe if split == "train" else val_bpe
    # Random start offsets; the upper bound leaves room for the shifted target window.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)


get_bpe_batch("train", batch_size=4, block_size=8)  # one small training batch: (x, y)

In [None]:
# Peek at two BPE batches: shapes, raw ids, and the decoded inputs and targets.
for _ in range(2):
    x, y = get_bpe_batch(batch_size=16, block_size=8)
    print(x.shape, y.shape)
    x_ids = x.tolist()
    print(x_ids)
    print(tokenizer.decode_batch(x_ids))
    print(tokenizer.decode_batch(y.tolist()))
    print("-----")

## Char based tokenizer

In [None]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = {idx: ch for ch, idx in stoi.items()}


def encode(s: str) -> list[int]:
    """Map each character of `s` to its integer id via the global `stoi` table (KeyError if unknown)."""
    return list(map(stoi.__getitem__, s))


def decode(l: list[int]) -> str:
    """Inverse of `encode`: concatenate the characters for the given ids via the global `itos`."""
    return "".join(itos[token_id] for token_id in l)

In [None]:
# Round-trip sanity check: decode(encode(s)) should give back s.
decode(encode("j'aime les chats"))

In [None]:
torch.manual_seed(42)
# Encode the full corpus once and split it 90/10 into train/validation character streams
# (one token per character, so len(text) == len(data)).
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(text))
train, val = data[:n], data[n:]
block_size = 8


def get_char_batch(split: str = "train", batch_size: int = 16, block_size: int = 8):
    """Sample a random batch of character windows.

    Args:
        split: "train" selects `train`; any other value selects `val`.
        batch_size: Number of windows to sample.
        block_size: Length of each window.

    Returns:
        (x, y): two (batch_size, block_size) tensors on `device`; y is x shifted by one character.
    """
    source = val if split != "train" else train
    # Latest valid start so that the shifted target window still fits.
    max_start = len(source) - block_size
    starts = torch.randint(max_start, (batch_size,))
    inputs = torch.stack([source[s : s + block_size] for s in starts])
    targets = torch.stack([source[s + 1 : s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)

In [None]:
# Draw one character-level batch; both tensors should be (16, 8).
x, y = get_char_batch(batch_size=16, block_size=8)
x.shape, y.shape

In [None]:
# Select the tokenizer used by the GPT below; vocab_size must match the chosen token ids.
use_bpe: bool = False
if use_bpe:
    vocab_size = tokenizer.get_vocab_size()
    get_batch = get_bpe_batch
else:
    # Derived from the character table instead of hard-coding 65, so another corpus still works.
    # Uses `stoi` rather than `chars` because a later cell rebinds `chars`.
    vocab_size = len(stoi)
    get_batch = get_char_batch

## GPT

In [None]:
class Head(nn.Module):
    """One causal (masked) self-attention head.

    Args:
        head_size: Dimension of the key/query/value projections.
        n_embd: Feature size of the input (last dim of `x`).
        block_size: Maximum sequence length; sizes the causal-mask buffer.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size: int, n_embd: int, block_size: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.smax = nn.Softmax(dim=-1)
        self.register_buffer("tril", tensor=torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend causally over `x`, a (B, T, n_embd) tensor with T <= block_size.

        Returns a (B, T, head_size) tensor where position t only mixes values from positions <= t.
        """
        B, T, C = x.shape
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the key dimension
        # (head_size), not the embedding width C. This keeps the score variance ~1 before softmax.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = self.smax(wei)  # kept as a module so forward hooks can capture attention maps
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    """Run `num_heads` causal `Head`s side by side, concatenate them, and project back to n_embd.

    Args:
        num_heads: Number of parallel heads.
        head_size: Output size of each head.
        block_size: Maximum sequence length (forwarded to each Head).
        n_embd: Input and output feature size.
        dropout: Dropout probability inside the heads and after the projection.
    """

    def __init__(self, num_heads: int, head_size: int, block_size: int, n_embd: int, dropout: float = 0.1) -> None:
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(Head(head_size, block_size=block_size, n_embd=n_embd, dropout=dropout))
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        head_outputs = [head(x) for head in self.heads]
        concatenated = torch.cat(head_outputs, dim=-1)  # (..., num_heads * head_size)
        return self.dropout(self.proj(concatenated))


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back to n_embd, then dropout.

    Args:
        n_embd: Input and output feature size.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, n_embd: int, dropout: float = 0.1) -> None:
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        # A single Sequential named `net` keeps the parameter names net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + attn(ln1(x)), then + ffwd(ln2(.)).

    Args:
        n_embd: Residual-stream width.
        num_heads: Attention heads; each gets n_embd // num_heads features.
        block_size: Maximum sequence length.
        dropout: Dropout probability used in attention and the MLP.
    """

    def __init__(self, n_embd: int, num_heads: int, block_size: int, dropout: float = 0.1):
        super().__init__()
        per_head = n_embd // num_heads
        self.sa = MultiHeadAttention(
            num_heads=num_heads,
            head_size=per_head,
            block_size=block_size,
            n_embd=n_embd,
            dropout=dropout,
        )
        self.ffwd = FeedForward(n_embd, dropout=dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))


class GPTLanguageModel(nn.Module):
    """Decoder-only GPT-style language model.

    Token + learned position embeddings -> `num_layers` transformer `Block`s -> final LayerNorm
    -> linear head producing next-token logits.

    Args:
        vocab_size: Number of token ids.
        num_heads: Attention heads per block.
        n_embd: Embedding / residual-stream width.
        block_size: Maximum context length (size of the position-embedding table).
        num_layers: Number of transformer blocks.
        dropout: Dropout probability used inside the blocks.
    """

    def __init__(
        self,
        vocab_size: int,
        num_heads: int = 4,
        n_embd: int = 32,
        block_size: int = 8,
        num_layers: int = 2,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.block_size = block_size
        self.n_embd = n_embd
        self.num_layers = num_layers
        self.token_emb = torch.nn.Embedding(vocab_size, n_embd)
        self.pos_emb = torch.nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, block_size, dropout=dropout) for _ in range(num_layers)])
        self.layer_norm = nn.LayerNorm(n_embd)
        self.lm_head = torch.nn.Linear(n_embd, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) long tensor of token ids, with T <= block_size.
            targets: Optional (B, T) long tensor of next-token ids.

        Returns:
            (logits, loss). Without `targets`, logits are (B, T, vocab_size) and loss is None.
            With `targets`, logits are flattened to (B * T, vocab_size) and loss is the mean
            cross-entropy.

        Raises:
            ValueError: if T exceeds block_size (there is no position embedding for it).
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        tok_emb = self.token_emb(idx)  # (B, T, n_embd)
        # Build positions on the input's device rather than the notebook-global `device`, so the
        # model works wherever it (and its inputs) live.
        pos_emb = self.pos_emb(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb  # broadcast over the batch -> (B, T, n_embd)
        x = self.blocks(x)
        x = self.layer_norm(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (not view) also accepts non-contiguous logits/targets.
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively append `max_new_tokens` sampled tokens to `idx`, a (B, T) long tensor.

        Each step feeds at most the last `block_size` tokens, samples the next token from the
        softmax of the final position's logits, and appends it. Sampling never needs gradients,
        so no autograd graph is built. Returns the (B, T + max_new_tokens) sequence.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # only the last position predicts the next token

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx


# hyperparams
batch_size = 64  # independent sequences processed in parallel
block_size = 256  # maximum context length for predictions

n_embd = 384
num_heads = 6
num_layers = 6
dropout = 0.2

model = GPTLanguageModel(
    vocab_size,
    num_heads=num_heads,
    num_layers=num_layers,
    n_embd=n_embd,
    block_size=block_size,
    dropout=dropout,
).to(device)

# Smoke test: one forward pass without targets, so the returned loss is None.
xb, yb = get_batch("train")
train_logit, train_loss = model(xb)

n_params = sum(p.numel() for p in model.parameters())
print(n_params / 1e6, "M parameters")
print("it should be about 10M")

In [None]:
@torch.no_grad()
def estimate_loss():
    """Average the loss over `eval_iters` random batches for each split.

    Reads the notebook globals `model`, `get_batch`, `eval_iters`, `batch_size` and `block_size`.
    The model is switched to eval mode (dropout off) for the measurement and back to train mode
    afterwards.

    Returns:
        dict mapping "train" and "val" to a 0-d tensor holding the mean loss.
    """
    model.eval()
    results = {}
    for split in ("train", "val"):
        split_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            xb_eval, yb_eval = get_batch(split, batch_size=batch_size, block_size=block_size)
            _, loss = model(xb_eval, yb_eval)
            split_losses[step] = loss.item()
        results[split] = split_losses.mean()
    model.train()
    return results


max_iters = 7001
eval_iters = 200
eval_interval = 500
learning_rate = 3e-4

model_id = "cop_gpt"
model_version = "0.1"
log_dir = Path("../runs") / f"{model_id}-{model_version}" / f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"

writer = SummaryWriter(log_dir=log_dir)

do_train: bool = False
# Iteration of the checkpoint (gpt-<version>-<iter>.pt) to resume from; 0 trains from scratch.
last_iter = 5000
if do_train and last_iter > 0:
    # Resume only when training. Otherwise the next cell loads the final model, and a missing
    # intermediate checkpoint would needlessly crash Restart & Run All.
    # The checkpoint is a pickled nn.Module, so weights_only=False is required (torch >= 2.6
    # defaults to True). This unpickles arbitrary code: only load checkpoints you created.
    model = torch.load(f"gpt-{model_version}-{last_iter}.pt", map_location=device, weights_only=False)

# Build the optimizer AFTER any checkpoint reload, so it updates the parameters of the model
# that is actually trained. Optimizer state is not checkpointed, so AdamW moments restart on resume.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

if do_train:
    for i in range(last_iter + 1, max_iters):
        # every once in a while: evaluate train/val loss, log, and checkpoint
        if i % eval_interval == 0:
            losses = estimate_loss()
            print(f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            writer.add_scalar("loss/train", losses["train"], i)
            writer.add_scalar("loss/val", losses["val"], i)
            # Weight histograms are expensive to compute and store, so log them at evaluation
            # steps only instead of on every optimisation step.
            for name, weight in model.named_parameters():
                writer.add_histogram(name, weight, i)
            torch.save(model, f"gpt-{model_version}-{i}.pt")

        # sample a batch of data
        xb, yb = get_batch("train", batch_size=batch_size, block_size=block_size)

        # forward, backward, update
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    torch.save(model, f"gpt-{model_version}.pt")

In [None]:
if not do_train:
    # Load the final trained model. The file is a pickled nn.Module, so weights_only=False is
    # needed on torch >= 2.6 (only load files you trust). map_location places it on the current
    # device even if it was saved on another one.
    model = torch.load(f"gpt-{model_version}.pt", map_location=device, weights_only=False)

In [None]:
# Sample text starting from token id 0. eval() turns dropout off; otherwise it would still be
# active, because training and estimate_loss() leave the model in train mode.
model.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
if use_bpe:
    print(tokenizer.decode(model.generate(context, max_new_tokens=50)[0].tolist()))
else:
    print(decode(model.generate(context, max_new_tokens=100)[0].tolist()))

In [None]:
# Export the learned token embeddings to TensorBoard's projector, one label per token id.
embeddings = model.state_dict()["token_emb.weight"].detach().cpu()
if use_bpe:
    labels = [tokenizer.id_to_token(i) or "" for i in range(vocab_size)]
else:
    labels = [itos[i] for i in range(vocab_size)]
# metadata.tsv stores one label per line, so labels such as "\n" would add extra rows and
# misalign every following label; repr(...)[1:-1] escapes control characters.
metadata = [repr(label)[1:-1] for label in labels]

# Add embeddings to the writer
writer.add_embedding(embeddings, metadata=metadata)

# Close the writer
writer.close()

## Understand model behaviour

Capture the activations at various points within the model.

In [None]:
model.get_submodule("blocks")  # the stack of transformer Blocks, shown via its module repr

In [None]:
activation: dict[str, torch.Tensor] = {}  # filled by forward hooks: name -> detached output


def get_activation(name):
    """Return a forward hook that stores the module's output, detached, in `activation[name]`.

    Each call of the hook overwrites the previous entry, so only the most recent forward pass is kept.
    """

    def hook(_module, _inputs, output):
        activation[name] = output.detach()

    return hook


# Register hooks that record each head's output and its softmax attention weights.
# Handles are kept so that re-running this cell first removes the previous hooks instead of
# stacking duplicates on the same modules.
for handle in globals().get("hook_handles", []):
    handle.remove()
hook_handles = []
for b, block in enumerate(model.blocks):
    for i, h in enumerate(block.sa.heads):
        hook_handles.append(h.register_forward_hook(get_activation(f"block_{b}_head_{i}")))
        hook_handles.append(h.smax.register_forward_hook(get_activation(f"block_{b}_smax_{i}")))

hook_handles.append(model.lm_head.register_forward_hook(get_activation("lm_head")))

In [None]:
# Generate a short sequence so the captured attention maps are small enough to label.
# eval() turns dropout off so the recorded activations reflect inference behaviour.
# NOTE: this rebinds `chars` (previously the character vocabulary) to the generated characters.
model.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
gen_toks = model.generate(context, max_new_tokens=8)
chars = list(decode(gen_toks[0].tolist()))
chars

In [None]:
# Plot the block-0 attention map of every head, as captured by the hooks during the last forward
# pass of `generate`. The grid has two columns and enough rows for all heads (3x2 for 6 heads);
# squeeze=False keeps `axes` 2-D even when there is a single head.
num_cols = min(2, num_heads)
num_rows = math.ceil(num_heads / num_cols)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(7, 7), squeeze=False)
axes = axes.flatten()

# The last forward pass saw every generated token except the final one (cropped to the context
# length), so label the axes with exactly those tokens.
att_len = activation["block_0_smax_0"].shape[-1]
labels = chars[:-1][-att_len:]
ticks = np.arange(len(labels))

for i, ax in enumerate(axes):
    if i >= num_heads:
        ax.axis("off")  # unused grid cell
        continue
    att = activation[f"block_0_smax_{i}"][0].cpu().numpy()
    # Shared 0..1 colour scale, so the single colorbar is valid for every head.
    im = ax.imshow(att, cmap="hot", interpolation="nearest", vmin=0.0, vmax=1.0)
    ax.set_title(f"head {i}")
    ax.set_xticks(ticks, labels=labels)
    ax.set_yticks(ticks, labels=labels)

cbar = fig.colorbar(im, ax=axes, orientation="vertical", fraction=0.1, pad=0.05)

In [None]:
# TODO: plot histograms (torch.histogram) of each layer's output distribution, and of the
# gradients to check they are roughly normally distributed.
# For now, list the parameters of block 0's attention heads with their shapes.
for param_name, param in model.blocks[0].sa.heads.named_parameters():
    print(param_name, param.shape)

In [None]:
def display_params(model, layer_name):
    """Show the 2-D tensor `layer_name` from `model.state_dict()` as a heatmap titled with its name."""
    weights = model.state_dict()[layer_name].detach().cpu().numpy()
    fig, ax = plt.subplots()
    ax.imshow(weights, cmap="hot", interpolation="nearest")
    ax.set_title(layer_name)
    plt.show()


# Heatmaps of the token and position embedding tables.
for layer_name in ("token_emb.weight", "pos_emb.weight"):
    display_params(model, layer_name)

## Mixture of experts

https://huggingface.co/blog/AviSoori1x/makemoe-from-scratch

In [None]:
# Expert module
class Expert(nn.Module):
    """One expert of the mixture: a two-layer MLP (n_embd -> 4 * n_embd -> n_embd) with a ReLU in
    between, followed by dropout. Same architecture as the GPT FeedForward.

    Args:
        n_embd: Input and output feature size.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, n_embd: int, dropout: float = 0.1):
        super().__init__()
        hidden_dim = n_embd * 4
        expand = nn.Linear(n_embd, hidden_dim)
        project = nn.Linear(hidden_dim, n_embd)
        # Layer order inside `net` (Linear, ReLU, Linear, Dropout) keeps parameter names net.0.* / net.2.*.
        self.net = nn.Sequential(expand, nn.ReLU(), project, nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

In [None]:
# Understanding how gating/router works, on a toy example:
# n_embed=32, context_length=4, batch_size=2; each token picks its top-2 of 4 experts.
num_experts = 4
top_k = 2
n_embed = 32

mh_output = torch.randn(2, 4, n_embed)  # stand-in for a multi-head attention output
topkgate_linear = nn.Linear(n_embed, num_experts)  # one routing logit per expert

logits = topkgate_linear(mh_output)  # (2, 4, num_experts)
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)  # best top_k experts per token
print(logits)
"top k logits:", top_k_logits, "top k indices:", top_k_indices

In [None]:
# Keep only the top-k logits per token: start from an all -inf tensor (full_like copies the shape
# and dtype) and scatter the winners back into their expert slots, so softmax gives exactly 0
# to every other expert.
zeros = torch.full_like(logits, float("-inf"))
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
gating_output = F.softmax(sparse_logits, dim=-1)  # probabilities over the top-k experts only
sparse_logits, gating_output

In [None]:
# First define the top k router module
class TopkRouter(nn.Module):
    """Route each token to its `top_k` highest-scoring experts.

    Args:
        n_embed: Feature size of the incoming token representations.
        num_experts: Number of experts to score.
        top_k: How many experts each token is routed to.
    """

    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_ouput):
        """Compute sparse routing weights for `mh_ouput` (the multi-head attention output).

        Returns:
            router_output: input's leading dims + (num_experts,); softmax weights that are zero
                outside each token's top-k experts.
            indices: input's leading dims + (top_k,); the selected expert ids.
        """
        # The parameter is spelled `mh_ouput` (kept for keyword callers). It must be used here,
        # not the similarly named notebook global `mh_output`, or the input is silently ignored.
        logits = self.linear(mh_ouput)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        # -inf everywhere except the top-k slots, so softmax gives the other experts weight 0.
        zeros = torch.full_like(logits, float("-inf"))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [None]:
# Try the router on a random (batch=2, tokens=4, n_embd=32) input.
num_experts, top_k, n_embd = 4, 2, 32

mh_output = torch.randn(2, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
# Expect gating_output (2, 4, 4) with two non-zero weights per token, and indices (2, 4, 2).
gating_output.shape, gating_output, indices

In [None]:
# softplus(x) = log(1 + exp(x)) is a smooth, always-positive version of ReLU.
# Named `sp_input` rather than `input` to avoid shadowing the Python builtin.
sp_input = torch.tensor([0.2, 2.3, 10.0, -0.1, -3.2, -10.0])
sp = F.softplus(sp_input)

# Gaussian noise whose per-element scale is the (positive) softplus value.
rand_noise = torch.randn_like(sp_input)
out = rand_noise * sp
sp_input, sp.numpy().round(4), out.numpy().round(4)

In [None]:
# Noisy top-k gating: the top-k router above plus learned, input-dependent noise.
class NoisyTopkRouter(nn.Module):
    """Top-k router with learned Gaussian noise on the routing logits.

    Without noise, training tends to send most tokens to the same few 'favoured' experts. Adding
    Gaussian noise, scaled per token and per expert by softplus(noise_linear(x)), lets different
    experts win the top-k sometimes. This trades exploitation for exploration and helps balance
    the load across experts.

    Args:
        n_embed: Feature size of the incoming token representations.
        num_experts: Number of experts to score.
        top_k: How many experts each token is routed to.
    """

    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        # Clean routing logits, one per expert.
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        # Per-expert noise scale; softplus in forward() keeps it positive.
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        """Return (router_output, indices) for `mh_output`, the multi-head attention output.

        router_output has a trailing num_experts dim and is zero outside each token's top-k;
        indices has a trailing top_k dim with the selected expert ids.
        """
        clean_logits = self.topkroute_linear(mh_output)
        # softplus makes the noise *scale* positive; the noise itself is symmetric around 0.
        noise_scale = F.softplus(self.noise_linear(mh_output))
        noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_scale

        top_vals, indices = noisy_logits.topk(self.top_k, dim=-1)
        masked = torch.full_like(noisy_logits, float("-inf")).scatter(-1, indices, top_vals)
        return F.softmax(masked, dim=-1), indices

In [None]:
# Noisy router on a random (2, 4, 16) input, with 8 experts and top-2 routing.
num_experts, top_k, n_embd = 8, 2, 16

mh_output = torch.randn(2, 4, n_embd)  # Example input
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices

In [None]:
class SparseMoE(nn.Module):
    """Sparse mixture-of-experts layer: each token is processed only by its top-k experts.

    A NoisyTopkRouter chooses `top_k` of `num_experts` Expert MLPs per token. The output is the
    router-weighted sum of those experts' outputs, with the same shape as the input.

    Args:
        n_embed: Input and output feature size.
        num_experts: Number of Expert MLPs.
        top_k: Experts used per token.
        dropout: Dropout probability inside each Expert (defaults to Expert's own default, 0.1).
    """

    def __init__(self, n_embed, num_experts, top_k, dropout: float = 0.1):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed, dropout=dropout) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        """Route `x` (..., n_embed) through its top-k experts; returns a tensor shaped like `x`."""
        gating_output, indices = self.router(x)
        ic(gating_output)
        final_output = torch.zeros_like(x)

        # Flatten all leading dims into one token axis. reshape (not view) also accepts
        # non-contiguous inputs such as a transposed tensor.
        flat_x = x.reshape(-1, x.size(-1))
        flat_gating_output = gating_output.reshape(-1, gating_output.size(-1))

        # Experts are applied one after another, each only to the tokens routed to it.
        for i, expert in enumerate(self.experts):
            # Tokens whose top-k selection includes expert i.
            expert_mask = (indices == i).any(dim=-1)
            ic(expert_mask)
            flat_mask = expert_mask.reshape(-1)

            if flat_mask.any():
                expert_output = expert(flat_x[flat_mask])  # (n_selected, n_embed)

                # Router weight of expert i for each selected token, as a column for broadcasting.
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores  # (n_selected, n_embed)

                # Boolean-mask indexing selects the same tokens, in the same row-major order, as
                # flat_mask did on the flattened input. No squeeze: it would break n_embed == 1.
                final_output[expert_mask] += weighted_output

        return final_output

In [None]:
import torch
import torch.nn as nn

ic.enable()  # show SparseMoE's ic() debug output for this demo; use ic.disable() to silence it

# Let's test this out on a (1, 8, 16) stand-in for a multi-head attention output.
num_experts = 4
top_k = 2
n_embd = 16

mh_output = torch.randn(1, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)

### view layer values

In [None]:
# TODO: add a mechanism to view the attention on each token, see
# https://www.comet.com/site/blog/explainable-ai-for-transformers/