In [None]:
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.notebook as tqdm

from tokenizers import Tokenizer, models, decoders, trainers, tools, pre_tokenizers

In [None]:
# Byte-level BPE tokenizer: an untrained BPE model combined with the
# byte-level pre-tokenizer and the byte-level decoder.
tokenizer = Tokenizer(models.BPE())
tokenizer.decoder = decoders.ByteLevel()
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()

# Trainer settings: vocabulary budget plus the reserved special tokens.
trainer = trainers.BpeTrainer(
    vocab_size=8192,
    special_tokens=["[PAD]", "[SOS]", "[EOS]", "[MASK]", "[UNK]"],
)

In [None]:
# Fit the tokenizer on the training text, then report the vocabulary size
# it actually ended up with.
tokenizer.train(
    ["blog/5-shakespeare/data/train.txt"],
    trainer=trainer,
)
print(f"Vocab size: {tokenizer.get_vocab_size()}")

# The trainer object is not used again after fitting.
del trainer

In [None]:
# Read both splits with an explicit encoding. Without it, open() falls back to
# the platform's locale encoding, so the same file could decode differently
# (or fail) on another machine.
# NOTE(review): UTF-8 is assumed because the tokenizer was trained directly
# from train.txt — confirm test.txt uses the same encoding.
with open("blog/5-shakespeare/data/train.txt", "r", encoding="utf-8") as f:
    train_corpus = f.read()

with open("blog/5-shakespeare/data/test.txt", "r", encoding="utf-8") as f:
    test_corpus = f.read()

# Encode each split into one flat list of token ids.
train_encoded_corpus = tokenizer.encode(train_corpus).ids
val_encoded_corpus = tokenizer.encode(test_corpus).ids

# The raw strings are no longer needed once encoded.
del train_corpus, test_corpus

In [None]:
# Create dataset
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-token dataset over a flat sequence of token ids.

    Item ``idx`` is the pair ``(corpus[idx : idx + seq_len],
    corpus[idx + 1 : idx + seq_len + 1])``, i.e. the target is the input
    shifted one position to the right.

    Args:
        corpus: Sequence of token ids (anything ``torch.as_tensor`` accepts).
        seq_len: Number of tokens in every input/target window.
    """

    def __init__(self, corpus, seq_len):
        # Convert once up front: slicing a tensor returns a cheap view, whereas
        # converting a Python list slice on every __getitem__ call re-walks
        # seq_len Python ints twice per sample.
        self.corpus = torch.as_tensor(corpus)
        self.seq_len = seq_len

    def __len__(self):
        # One extra token is needed for the shifted target, hence "- seq_len".
        # Clamp at 0 so a corpus shorter than seq_len yields an empty dataset
        # instead of a negative length (which makes len() raise).
        return max(0, len(self.corpus) - self.seq_len)

    def __getitem__(self, idx):
        window_end = idx + self.seq_len
        inputs = self.corpus[idx:window_end]
        targets = self.corpus[idx + 1:window_end + 1]
        return inputs, targets
    

seq_len = 64

# Sliding-window datasets over the encoded training / validation token ids.
train_dataset = Dataset(corpus=train_encoded_corpus, seq_len=seq_len)
val_dataset = Dataset(corpus=val_encoded_corpus, seq_len=seq_len)

# Both loaders shuffle; validation uses a larger batch size than training.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=128,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=True,
    batch_size=256,
)

# The datasets keep their own handle on the token ids, so the notebook-level
# names can be dropped.
del train_encoded_corpus, val_encoded_corpus

In [None]:
# Pick the compute device: CUDA first, then Apple MPS (only when this torch
# build exposes torch.backends.mps), otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Re-running this cell: drop the model, optimizer and scheduler left over from
# a previous run (if any) before new ones are built.
for _stale in ("model", "optimizer", "scheduler"):
    globals().pop(_stale, None)
del _stale


@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model defined in this notebook.

    The defaults of ``vocab_size`` and ``block_size`` are read from the
    notebook-level ``tokenizer`` and ``seq_len`` once, at the moment this class
    body is executed; later changes to those globals do not affect them.
    """

    # Number of token ids; sizes the token embedding and the output projection.
    vocab_size: int = tokenizer.get_vocab_size()
    # Maximum sequence length; sizes the position-embedding table.
    block_size: int = seq_len
    # Width of the embeddings and of every layer's hidden state.
    emb_size: int = 64
    # Used twice: the number of AttentionHead modules per block, and the number
    # of sub-heads each AttentionHead splits its projections into.
    heads: int = 8
    # Number of stacked Block modules.
    num_layers: int = 1
    # Dropout probability for attention (0 disables it).
    attn_dropout: float = 0
    # Feed-forward hidden width as a multiple of emb_size.
    ff_mult: int = 1
    # Dropout probability after the feed-forward sub-layer (0 disables it).
    ff_dropout: float = 0


class AttentionHead(nn.Module):
    """Scaled dot-product self-attention with ``config.heads`` sub-heads.

    The input is projected to queries, keys and values of width
    ``config.emb_size``; each projection is split into ``config.heads`` slices
    of width ``emb_size // heads``, attention is computed per slice, and the
    concatenated result goes through an output projection.

    Args:
        config: Object exposing ``emb_size``, ``heads`` and ``attn_dropout``.
        layer_idx: Index of the enclosing layer (stored, not used in forward).
        head_idx: Index of this module within its layer (stored, not used in
            forward).
        cache_enabled: Accepted and stored for interface compatibility only;
            see the note in ``__init__``.
    """

    def __init__(self, config: "GPTConfig", layer_idx, head_idx, cache_enabled=False):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_idx = head_idx

        # Keys/values used to be cached under the key (B, T, C) — the input
        # *shape*. Any later input with the same shape then silently reused
        # keys/values computed from a different batch (and from older weights),
        # and in eval mode this happened even with cache_enabled=False.
        # forward() now always recomputes them. The two attributes remain so
        # existing callers and attribute accesses keep working.
        self.cache_enabled = cache_enabled
        self.cache = {}

        self.q = nn.Linear(config.emb_size, config.emb_size)
        self.k = nn.Linear(config.emb_size, config.emb_size)
        self.v = nn.Linear(config.emb_size, config.emb_size)

        self.out = nn.Linear(config.emb_size, config.emb_size)

        self.attn_dropout = nn.Dropout(config.attn_dropout)

    def _split_heads(self, projected):
        """Reshape (B, T, C) into (B, heads, T, C // heads)."""
        B, T, C = projected.size()
        heads = self.config.heads
        return projected.view(B, T, heads, C // heads).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, C).
            mask: Optional tensor broadcastable to (B, heads, T, T); positions
                where it equals 1 are excluded from attention.

        Returns:
            Tuple ``(output, attn)``: ``output`` has shape (B, T, C) and
            ``attn`` holds the post-softmax (and post-dropout) attention
            weights with shape (B, heads, T, T).
        """
        B, T, C = x.size()
        head_dim = C // self.config.heads

        q = self._split_heads(self.q(x))
        k = self._split_heads(self.k(x))
        v = self._split_heads(self.v(x))

        # Scaled dot-product scores: (B, heads, T, T).
        attn = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)

        if mask is not None:
            # -inf scores become exactly zero weight after the softmax.
            attn = attn.masked_fill(mask == 1, float("-inf"))

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        # Merge the sub-heads back into a single (B, T, C) tensor.
        x = (attn @ v).transpose(1, 2).contiguous().view(B, T, C)

        return self.out(x), attn


class MaskedMultiHeadAttention(nn.Module):
    """Runs ``config.heads`` AttentionHead modules on the same input and averages their outputs."""

    def __init__(self, config: GPTConfig, layer_idx):
        super().__init__()
        self.config = config

        head_modules = []
        for head_idx in range(config.heads):
            head_modules.append(
                AttentionHead(config, layer_idx=layer_idx, head_idx=head_idx)
            )
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x, mask=None):
        # Each head returns (output, attention weights); only the output is
        # kept. The result has the same shape as the input.
        head_outputs = [head(x, mask=mask)[0] for head in self.heads]
        stacked = torch.stack(head_outputs)
        return torch.mean(stacked, dim=0)


class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then feed-forward, each with a residual connection.

    Dropout modules are only created (and applied) when the corresponding
    probability in ``config`` is greater than zero.
    """

    def __init__(self, config: GPTConfig, layer_idx):
        super().__init__()
        self.config = config

        width = config.emb_size
        hidden = config.ff_mult * config.emb_size

        self.ln1 = nn.LayerNorm(width)
        self.attn = MaskedMultiHeadAttention(config, layer_idx)

        self.ln2 = nn.LayerNorm(width)
        self.ff = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

        if config.ff_dropout > 0:
            self.ff_dropout = nn.Dropout(config.ff_dropout)

        if config.attn_dropout > 0:
            self.attn_dropout = nn.Dropout(config.attn_dropout)

    def forward(self, x, mask=None):
        # Unpacking doubles as a check that the input is 3-dimensional.
        batch, steps, width = x.size()

        # Attention sub-layer.
        residual = x
        x = self.attn(self.ln1(x), mask=mask)
        if hasattr(self, "attn_dropout"):
            x = self.attn_dropout(x)
        x = x + residual

        # Feed-forward sub-layer.
        residual = x
        x = self.ff(self.ln2(x))
        if hasattr(self, "ff_dropout"):
            x = self.ff_dropout(x)

        return x + residual


class GPT(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Args:
        config: Object exposing ``vocab_size``, ``block_size``, ``emb_size``
            and ``num_layers`` (plus whatever ``Block`` reads).
    """

    def __init__(self, config: "GPTConfig"):
        super().__init__()
        self.config = config

        self.token_emb = nn.Embedding(config.vocab_size, config.emb_size)
        self.pos_emb = nn.Embedding(config.block_size, config.emb_size)

        self.blocks = nn.ModuleList(
            [Block(config, layer_idx=i) for i in range(config.num_layers)]
        )

        self.ln = nn.LayerNorm(config.emb_size)
        self.head = nn.Linear(config.emb_size, config.vocab_size, bias=False)

        # tie weights: the output projection shares the token embedding matrix
        self.head.weight = self.token_emb.weight

        # initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear/Embedding weights, zero biases, identity LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map token ids of shape (B, T) to logits of shape (B, T, vocab_size)."""
        B, T = x.size()
        assert T <= self.config.block_size, "Sequence length is longer than block size"

        # Build helper tensors on the input's own device rather than on a
        # notebook-level global, so the model works wherever it has been moved.
        emb = self.token_emb(x)
        # Positions are counted backwards: the last token gets index 0 and the
        # first gets T - 1.
        pe = self.pos_emb(torch.arange(T - 1, -1, step=-1, device=x.device))
        # 1 above the diagonal marks the future positions each token must not see.
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).unsqueeze(0)

        x = emb + pe

        for block in self.blocks:
            x = block(x, mask=mask)

        x = self.ln(x)
        return self.head(x)

    def loss(self, y, y_pred):
        """Cross-entropy between target ids ``y`` (B, T) and logits ``y_pred`` (B, T, V)."""
        # reshape (rather than view) also accepts non-contiguous logits.
        y = y.flatten()
        y_pred = y_pred.reshape(-1, y_pred.size(-1))

        return F.cross_entropy(y_pred, y)

    def get_param_count(self):
        """Total number of elements over all (de-duplicated) parameters."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def generate(
        self,
        start_seq: str,
        max_len: int = 128,
        temperature: float = 1.0,
        top_k: int = 50,
    ):
        """Sample ``max_len`` tokens that continue ``start_seq``.

        Uses the notebook-level ``tokenizer`` to encode the prompt and decode
        the result. The model is switched to eval mode and left there.

        Args:
            start_seq: Prompt text.
            max_len: Number of tokens to sample.
            temperature: Divisor applied to the logits before sampling.
            top_k: Sample only among the ``top_k`` highest-scoring tokens
                (clamped to the vocabulary size).

        Returns:
            Tuple ``(text, token_ids)`` covering the prompt plus the samples.
        """
        self.eval()

        model_device = next(self.parameters()).device
        generated = list(tokenizer.encode(start_seq).ids)
        primer_t = torch.as_tensor(generated, device=model_device).view(1, -1)

        for _ in range(max_len):
            # Keep only the most recent block_size tokens as context.
            if primer_t.size(1) >= self.config.block_size:
                primer_t = primer_t[:, -self.config.block_size :]

            logits = self(primer_t)[:, -1, :] / temperature

            # topk requires k <= vocabulary size.
            k = min(top_k, logits.size(-1))
            top_logits, top_ids = torch.topk(logits, k, dim=-1)
            top_probs = F.softmax(top_logits, dim=-1)

            # multinomial returns a position *within the top-k list*; it has to
            # be mapped back through top_ids to obtain the vocabulary id.
            choice = torch.multinomial(top_probs, num_samples=1)
            next_id = top_ids.gather(-1, choice)

            generated.append(next_id.item())

            primer_t = torch.cat((primer_t, next_id), dim=1)

        return tokenizer.decode(generated), generated


# Build the model from the default configuration and move it to the chosen device.
config = GPTConfig()
model = GPT(config)
model = model.to(device)

# Global step counter, advanced by the training loop.
num_train_steps = 0

print(f"Model has {model.get_param_count():,} parameters")
# Sample from the untrained model as a baseline.
print(model.generate("To be or not to be", max_len=128)[0])

# The model keeps its own reference as model.config.
del config

In [None]:
@torch.no_grad()
def evaluate(model, dataloader):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Switches the model to eval mode (and leaves it there), moves each batch to
    the notebook-level ``device``, and shows a progress bar with the loss of
    the current batch. Every batch contributes equally to the mean.
    """
    model.eval()
    running_total = 0
    progress = tqdm.tqdm(dataloader, desc="Evaluation")
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits = model(inputs)
        batch_loss = model.loss(targets, logits).item()
        running_total += batch_loss
        progress.set_postfix({"loss": batch_loss})
    return running_total / len(dataloader)

In [None]:
# import torch
from torch.optim.optimizer import Optimizer
import math
import torch.distributed as dist
from torch.optim.optimizer import _dispatch_sqrt

# device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Adam_mini(Optimizer):
    """Adam-style optimizer that shares one second-moment estimate per parameter block.

    Every trainable parameter of ``model`` gets its own param group. The update
    rule is chosen from the parameter *name* (substring match):

    * embedding / output names (``embed_tokens``, ``wte``, ``lm_head``):
      element-wise second moment, as in regular Adam;
    * query / key projection names: one second moment per attention head;
    * fused qkv names: one second moment per (head, q/k/v slot);
    * everything else: a single second moment for the whole tensor.

    Parameters whose names match none of the special patterns silently use the
    last rule, so models with a different naming scheme get whole-tensor
    scaling everywhere.
    """

    _EMBEDDING_KEYS = ("embed_tokens", "wte", "lm_head")
    _QK_KEYS = (
        "self_attn.k_proj.weight",
        "self_attn.q_proj.weight",
        "attn.wq.weight",
        "attn.wk.weight",
    )
    _FUSED_QKV_KEYS = ("attn.attn.weight", "attn.qkv.weight")

    def __init__(
            self,
            model=None,
            weight_decay=0.1,
            lr=1,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-8,
            model_sharding=False,
            n_embd=2048,
            n_head=32,
            n_query_groups=None
    ):
        '''
        model: the model you are training.

        weight_decay: decoupled weight-decay coefficient; set to 0 for parameters whose name contains "norm" or "ln_f".

        lr, beta1, beta2, epsilon: step size, moment decay rates and denominator offset, as in Adam.

        model_sharding: set to True if you are using model parallelism with more than 1 GPU, including FSDP and zero_1,2,3 in Deepspeed. Set to False if otherwise.

        n_embd: number of embedding dimensions. Could be unspecified if you are training non-transformer models.

        n_head: number of attention heads. Could be unspecified if you are training non-transformer models.

        n_query_groups: number of query groups in Group query Attention. If not specified, it will be equal to n_head. Could be unspecified if you are training non-transformer models.
        '''

        self.n_embd = n_embd
        self.n_head = n_head
        if n_query_groups is not None:
            self.n_query_groups = n_query_groups
            assert self.n_head % self.n_query_groups == 0
        else:
            self.n_query_groups = self.n_head

        self.model = model
        self.world_size = torch.cuda.device_count()
        self.model_sharding = model_sharding
        if self.model_sharding:
            print("Adam-mini is using model_sharding")

        optim_groups = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue

            group = {"name": name, "params": param}
            if "norm" in name or "ln_f" in name:
                group["weight_decay"] = 0
            else:
                group["weight_decay"] = weight_decay

            # Only query/key projections carry a per-head block size. (The last
            # operand of this test used to be a bare string literal, which is
            # always truthy, so every group was tagged.)
            if self._matches(name, self._QK_KEYS):
                group["parameter_per_head"] = self.n_embd * self.n_embd // self.n_head

            if self._matches(name, self._FUSED_QKV_KEYS):
                group["n_head"] = self.n_head
                group["q_per_kv"] = self.n_head // self.n_query_groups

            optim_groups.append(group)

        defaults = dict(lr=lr, beta1=beta1, beta2=beta2, epsilon=epsilon)

        super(Adam_mini, self).__init__(optim_groups, defaults)

    @staticmethod
    def _matches(name, keys):
        """True when any of ``keys`` occurs as a substring of ``name``."""
        return any(key in name for key in keys)

    @staticmethod
    def _advance_moments(p, state, grad, group, lr, beta1, beta2):
        """Count the step, apply decoupled weight decay and update the first moment.

        Returns ``(bias_correction_1, sqrt(bias_correction_2))`` for this step.
        ``grad`` must already have the same shape as ``state["m"]``.
        """
        state["iteration"] += 1
        if group["weight_decay"] != 0:
            p.data.mul_(1 - lr * group["weight_decay"])

        state["m"].lerp_(grad, 1 - beta1)

        bias_correction_1 = 1 - beta1 ** state["iteration"]
        bias_correction_2 = 1 - beta2 ** state["iteration"]
        return bias_correction_1, math.sqrt(bias_correction_2)

    @staticmethod
    def _apply_blockwise_update(p, state, v, stepsize_shape, lr, epsilon,
                                bias_correction_1, bias_correction_2_sqrt):
        """Scale the first moment by one step size per block and apply it to ``p``."""
        h = (v.sqrt() / bias_correction_2_sqrt).add_(epsilon)
        stepsize = ((1 / bias_correction_1) / h).view(stepsize_shape)

        update = state["m"] * stepsize
        # state["m"] is stored in block layout; bring the update back to p's shape.
        update = update.view(p.size())

        update.mul_(lr)
        p.add_(-update)

    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                beta1 = group["beta1"]
                beta2 = group["beta2"]
                lr = group["lr"]
                name = group["name"]
                epsilon = group["epsilon"]

                for p in group["params"]:
                    state = self.state[p]

                    if self._matches(name, self._EMBEDDING_KEYS):
                        # Element-wise second moment (regular Adam).
                        if p.grad is None:
                            continue
                        if len(state) == 0:
                            state["m"] = torch.zeros_like(p.data).to(torch.float32)
                            state["iteration"] = 0
                            state["v"] = torch.zeros_like(p.data).to(torch.float32)

                        grad = p.grad.data.to(torch.float32)
                        state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)

                        bias_correction_1, bias_correction_2_sqrt = self._advance_moments(
                            p, state, grad, group, lr, beta1, beta2
                        )

                        h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(epsilon)
                        stepsize = lr / bias_correction_1
                        p.addcdiv_(state["m"], h, value=-stepsize)

                    elif self._matches(name, self._QK_KEYS):
                        # One second moment per attention head.
                        if p.grad is None:
                            continue
                        dim = group["parameter_per_head"]
                        if len(state) == 0:
                            state["m"] = torch.zeros_like(p.data).to(torch.float32).view(-1, dim)
                            state["head"] = state["m"].shape[0]
                            state["iteration"] = 0
                            # State lives on the parameter's own device.
                            state["vmean"] = torch.zeros(state["head"], device=p.device)

                        head = state["head"]
                        grad = p.grad.data.to(torch.float32).view(head, dim)

                        tmp_lr = torch.mean(grad * grad, dim=1)
                        state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)

                        bias_correction_1, bias_correction_2_sqrt = self._advance_moments(
                            p, state, grad, group, lr, beta1, beta2
                        )
                        self._apply_blockwise_update(
                            p, state, state["vmean"], (head, 1), lr, epsilon,
                            bias_correction_1, bias_correction_2_sqrt,
                        )

                    elif self._matches(name, self._FUSED_QKV_KEYS):
                        # One second moment per (head, q/k/v slot) of a fused projection.
                        if p.grad is None:
                            continue
                        n_head = group["n_head"]
                        slots = group["q_per_kv"] + 2
                        if len(state) == 0:
                            state["m"] = torch.zeros_like(p.data).to(torch.float32).view(n_head, slots, -1)
                            state["iteration"] = 0
                            state["vmean"] = torch.zeros(n_head, slots, device=p.device)

                        grad = p.grad.data.to(torch.float32).view(n_head, slots, -1)

                        tmp_lr = torch.mean(grad * grad, dim=2)
                        state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)

                        bias_correction_1, bias_correction_2_sqrt = self._advance_moments(
                            p, state, grad, group, lr, beta1, beta2
                        )
                        self._apply_blockwise_update(
                            p, state, state["vmean"], (n_head, slots, 1), lr, epsilon,
                            bias_correction_1, bias_correction_2_sqrt,
                        )

                    else:
                        # A single second moment for the whole tensor.
                        if len(state) == 0:
                            dimension = torch.tensor(
                                p.data.numel(), dtype=torch.float32, device=p.device
                            )
                            reduced = False
                            if (self.world_size > 1) and (self.model_sharding is True):
                                # Sum the shard sizes over all ranks; the tensor
                                # counts as sharded when two or more ranks hold
                                # a non-empty piece.
                                tensor_list = [torch.zeros_like(dimension) for _ in range(self.world_size)]
                                dist.all_gather(tensor_list, dimension)
                                s = 0
                                dimension = 0
                                for d in tensor_list:
                                    if d > 0:
                                        s = s + 1
                                    dimension = dimension + d
                                if s >= 2:
                                    reduced = True

                            state["m"] = torch.zeros_like(p.data).to(torch.float32)
                            state["iteration"] = 0
                            state["reduced"] = reduced
                            state["vmean"] = torch.tensor(0.0, device=p.device)
                            state["dimension"] = dimension.item()

                        if p.grad is None:
                            tmp_lr = torch.tensor(0.0, device=p.device)
                        else:
                            grad = p.grad.data.to(torch.float32)
                            tmp_lr = torch.sum(grad * grad)

                        # Ranks without a gradient still join the collective so
                        # that the all_reduce does not hang.
                        if state["reduced"]:
                            dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM)
                        if p.grad is None:
                            continue

                        tmp_lr = tmp_lr / state["dimension"]

                        bias_correction_1, bias_correction_2_sqrt = self._advance_moments(
                            p, state, grad, group, lr, beta1, beta2
                        )

                        state["vmean"] = (1 - beta2) * tmp_lr + beta2 * state["vmean"]
                        h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(epsilon)

                        stepsize = (1 / bias_correction_1) / h
                        update = state["m"] * stepsize
                        update.mul_(lr)
                        p.add_(-update)

        return loss

In [None]:
# Optimizer sized from the model's own configuration.
optimizer = Adam_mini(
    model=model,
    lr=5e-4,
    weight_decay=0.01,
    n_embd=model.config.emb_size,
    n_head=model.config.heads,
)

# Halve the learning rate when the monitored value stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=5,
)

# Running training loss and latest validation loss, both updated by the training loop.
roll_loss = val_loss = 0

In [None]:
for epoch in range(100):
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)

    # roll_loss stays 0 until at least one training step has run, so validation
    # is skipped before the very first epoch and runs before every later one
    # (including the first epoch after re-running this cell).
    if roll_loss != 0:
        val_loss = evaluate(model, val_loader)
        # ReduceLROnPlateau compares each value with the best one seen so far.
        # It is therefore stepped exactly once per epoch, and always with the
        # validation loss. Stepping it again with the rolling training loss
        # mixed two metrics in one history and consumed patience twice per epoch.
        scheduler.step(val_loss)

    # evaluate() leaves the model in eval mode; switch back before training.
    model.train()
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = model.loss(y, y_pred)
        loss.backward()
        optimizer.step()

        num_train_steps += 1
        # Exponential moving average of the training loss (for display only).
        roll_loss = 0.9 * roll_loss + 0.1 * loss.item()

        pbar.set_postfix_str(f"loss: {roll_loss:.4f}, val_loss: {val_loss:.2e}, steps: {num_train_steps:,}")

In [None]:
# Per-parameter gradient statistics (mean, standard deviation) from the most
# recent backward pass.
for name, param in model.named_parameters():
    grad = param.grad
    grad_mean = grad.mean().item()
    grad_std = grad.std().item()
    print(f"{name}: {grad_mean:4f}, {grad_std:4f}")

In [None]:
# Sample a continuation from the trained model, restricted to the 1000
# highest-scoring tokens at every step.
with torch.no_grad():
    model.eval()
    print(
        model.generate("The Project Gutenberg eBook", max_len=128, top_k=1000)[0]
    )

In [None]:
# Plot position embeddings as a heatmap: one row per position, one column per
# embedding dimension.
import matplotlib.pyplot as plt

pos_emb = model.pos_emb.weight.detach().cpu().numpy()

plt.figure(figsize=(20, 5))
plt.imshow(pos_emb, cmap="RdYlGn", aspect="auto")
plt.colorbar()
plt.xlabel("Embedding Dimension")
plt.ylabel("Position")
plt.title("Position Embeddings")
plt.show()


In [None]:
# Line plot of embedding dimension 0 across all positions.
first_dim = model.pos_emb.weight.detach().cpu().numpy()[:, 0]

plt.figure(figsize=(20, 5))
plt.plot(first_dim)
plt.xlabel("Position")
plt.ylabel("Value")
plt.title("Position Embedding 0")
plt.show()

In [None]:
# plot attention heads: collect the attention weights of every head on one batch
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)

model.eval()
B, T = x.size()

attns = []
with torch.no_grad():
    # Rebuild the model's input pipeline (same embeddings and causal mask as
    # GPT.forward) so the heads see what they see during a normal forward pass.
    emb = model.token_emb(x)
    pe = model.pos_emb(torch.arange(T - 1, -1, step=-1, device=device))
    mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).unsqueeze(0)

    x = emb + pe

    for block in model.blocks:
        normed = block.ln1(x)
        for head in block.attn.heads:
            # head(...) returns (output, attention weights). The weights are the
            # SECOND element; unpacking the first one collects activations, not
            # attention maps.
            _, weights = head(normed, mask=mask)
            # weights is (B, sub-heads, query, key). Keep batch element 0 and
            # lay it out as (key, query, sub-head), so that indexing
            # [:, :, 0] yields a 2-D key-by-query map of sub-head 0.
            attns.append(weights[0].permute(2, 1, 0))
        # Feed the next layer the real block output, not the normalised input.
        x = block(x, mask=mask)

        

In [None]:
plt.figure(figsize=(20, 5))

# One panel for each of the first four collected entries of `attns`.
for panel in range(4):
    plt.subplot(1, 4, panel + 1)
    heatmap = attns[panel][:, :, 0].detach().cpu().numpy()
    plt.imshow(heatmap, aspect="auto", cmap="RdYlGn")
    plt.title(f"Attention Head {panel}")
    plt.xlabel("Query Position")
    plt.ylabel("Key Position")

# With no arguments, colorbar() attaches to the current (last drawn) image.
plt.colorbar()
plt.tight_layout()
plt.show()