In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class CausalSelfAttention(nn.Module):
    """Causal self-attention with optional grouped-query attention (GQA).

    Queries use ``config.n_head`` heads. Keys and values use
    ``config.n_kv_head`` heads, each shared by ``n_head // n_kv_head``
    consecutive query heads. With ``n_kv_head == n_head`` this is plain
    multi-head attention with exactly the same parameter shapes as before.

    Args:
        config: object exposing ``n_head``, ``n_kv_head`` and ``n_embd``.
        layer_idx: index of the owning block; stored on the module but not
            used by the computation.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        # Validate before deriving head_dim so a bad config fails with a clear message.
        assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"
        assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.head_dim = self.n_embd // self.n_head
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        # K/V projections are sized by n_kv_head (previously n_kv_head was ignored).
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x):
        """Attend causally over ``x`` of shape (B, T, C) and return (B, T, C)."""
        B, T, C = x.size()

        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)     # B, H, T, D
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)  # B, Hkv, T, D
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)  # B, Hkv, T, D

        if self.n_kv_head != self.n_head:
            # Share each K/V head across a group of consecutive query heads.
            group = self.n_head // self.n_kv_head
            k = k.repeat_interleave(group, dim=1)  # B, H, T, D
            v = v.repeat_interleave(group, dim=1)  # B, H, T, D

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # B, H, T, D

        # B, H, T, D -> B, T, H, D -> B, T, C
        y = y.transpose(1, 2).contiguous().view(B, T, -1)

        return self.c_proj(y)  # B, T, C

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network with a squared-ReLU activation.

    Expands ``n_embd -> 4 * n_embd``, applies ``relu(h) ** 2``, then projects
    back to ``n_embd``. Neither linear layer has a bias.
    """

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, config.n_embd, bias=False)

    def forward(self, x):
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)


In [None]:
def norm(x):
    return F.rms_norm(x, (x.size(-1),))

In [None]:
class Block(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each added residually.

    Each sub-layer receives ``norm`` of the current residual stream, and its
    output is added back onto that stream.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x):
        attn_out = self.attn(norm(x))
        x = x + attn_out
        return x + self.mlp(norm(x))

In [None]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings (``wte``) plus learned absolute position embeddings
    (``pte``, one row per position up to ``config.sequence_len``) feed a stack
    of ``config.n_layer`` ``Block`` layers. A final RMS norm and an untied
    linear ``lm_head`` then produce vocabulary logits.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "pte": nn.Embedding(config.sequence_len, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    def forward(self, idx, targets=None, loss_reduction='mean'):
        """Run the model on a batch of token ids.

        Args:
            idx: (B, T) integer token ids, with T <= config.sequence_len.
            targets: optional (B, T) next-token ids; entries equal to -1 are
                ignored by the loss.
            loss_reduction: forwarded to ``F.cross_entropy`` as ``reduction``.

        Returns:
            Logits of shape (B, T, vocab_size) when ``targets`` is None.
            Otherwise a ``(loss, logits)`` tuple, with logits cast to float32.

        Raises:
            ValueError: if T exceeds ``config.sequence_len`` (the size of the
                position-embedding table).
        """
        _, T = idx.size()
        if T > self.config.sequence_len:
            raise ValueError(
                f"sequence length {T} exceeds config.sequence_len={self.config.sequence_len}"
            )
        # Build position ids on the same device as the input instead of assuming CUDA.
        pos = torch.arange(T, device=idx.device)
        x = self.transformer.wte(idx) + self.transformer.pte(pos)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x)
        x = norm(x)

        logits = self.lm_head(x)
        if targets is None:
            return logits

        # Compute the loss in float32 for numerical stability.
        logits = logits.float()
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-1,
            reduction=loss_reduction,
        )
        return loss, logits

    @torch.no_grad()
    def generate(self, idx, seq_len):
        """Autoregressively sample ``seq_len`` new tokens and append them to ``idx``.

        Args:
            idx: (B, T0) integer token ids used as the prompt.
            seq_len: number of tokens to sample.

        Returns:
            (B, T0 + seq_len) tensor containing the prompt followed by the samples.
        """
        max_ctx = self.config.sequence_len
        for _ in range(seq_len):
            # pte only has sequence_len rows, so condition on at most the last
            # sequence_len tokens; longer inputs would index past the table.
            idx_cond = idx[:, -max_ctx:]
            logits = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
# Data loading from shakespeare dataset

# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
#!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
# read it in to inspect it
# Load the raw corpus (input.txt) into a single string.
with open('input.txt', mode='r', encoding='utf-8') as fh:
    text = fh.read()

In [None]:
# Number of distinct characters in the corpus (the char-level vocabulary).
vocab_size = len({*text})

# The first 90% of characters are used for training; the rest for validation.
train_size = int(len(text) * 0.9)

In [None]:
# Character-level lookup tables: encode maps char -> id, decode maps id -> char.
# The characters are sorted so ids are deterministic. Iteration order of a raw
# set of str depends on the per-process hash seed, so ids (and any saved model)
# would otherwise change between kernel restarts.
# NOTE: the Byte-Pair Encoding cell later in this notebook defines functions
# named `encode`/`decode`, which replace these dicts once that cell runs.
chars = sorted(set(text))
encode = {ch: i for i, ch in enumerate(chars)}
decode = {i: ch for i, ch in enumerate(chars)}
assert len(encode) == vocab_size

In [None]:
def encoder(text):
    """Convert a string into a list of character ids via the `encode` dict."""
    return [encode[ch] for ch in text]


def decoder(ids):
    """Convert an iterable of character ids back into a string via the `decode` dict."""
    return ''.join([decode[token] for token in ids])

In [None]:
data = encoder(text)
# Complementary slices: validation starts exactly where training ends, so no
# character is dropped between the two splits.
train_data = data[:train_size]
val_data = data[train_size:]
assert len(train_data) + len(val_data) == len(data)

In [None]:
# Context length (tokens per training window) and number of windows per batch.
block_size, batch_size = 256, 64

In [None]:
import torch
# Seed torch's RNG so batch sampling and weight initialization are reproducible.
torch.manual_seed(seed=1443)

def get_data(split='train', device='cuda'):
    """Sample a random batch of (input, target) windows from a data split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.
        device: device the batch is moved to (e.g. 'cuda', 'cuda:0', 'cpu').

    Returns:
        ``(x, y)``: int64 tensors of shape (batch_size, block_size), where each
        row of ``y`` is the window one position after the matching row of ``x``
        (next-token targets).
    """
    source = train_data if split == 'train' else val_data
    # Starts are drawn from [0, len - block_size), so the target window
    # [start + 1, start + block_size + 1) always stays in bounds.
    starts = torch.randint(len(source) - block_size, (batch_size,)).tolist()
    # One tensor construction per batch instead of one per row plus a stack.
    x = torch.tensor([source[i:i + block_size] for i in starts], dtype=torch.long)
    y = torch.tensor([source[i + 1:i + block_size + 1] for i in starts], dtype=torch.long)
    # .to() handles any device string ('cuda:1', 'mps', ...), not just 'cuda'.
    return x.to(device), y.to(device)

In [None]:
# Draw one training batch, e.g. to inspect shapes before training.
train_x, train_y = get_data(split='train')

In [None]:
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyperparameters for `GPT`.

    ``sequence_len`` and ``vocab_size`` default to the ``block_size`` and
    ``vocab_size`` globals, captured when this class is defined.
    """

    sequence_len: int = block_size  # context window; sizes GPT's position table
    vocab_size: int = vocab_size    # number of distinct characters in the corpus
    n_layer: int = 12               # number of transformer Blocks
    n_head: int = 6                 # attention heads per layer
    n_kv_head: int = 6              # read by CausalSelfAttention as n_kv_head
    n_embd: int = 768               # residual stream width

In [None]:
# Pass a config instance. Passing the GPTConfig class only worked because the
# dataclass defaults are class attributes, and it made overriding fields impossible.
model = GPT(GPTConfig())

In [None]:
# Move all parameters to the default CUDA device.
model = model.to('cuda')

In [None]:
# AdamW over all model parameters with a fixed learning rate.
optim = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

In [None]:
n_batches = len(data)//batch_size  # not used below: batches are sampled randomly
n_epochs = 10
steps_per_epoch = 100  # optimizer steps per "epoch"
seq_len = 100          # tokens sampled after each epoch

losses = {}  # epoch -> list of per-step training losses

for epoch in range(n_epochs):
    # Append in place; re-concatenating the list every step was quadratic.
    epoch_losses = losses.setdefault(epoch, [])
    for batch in range(steps_per_epoch):
        optim.zero_grad(set_to_none=True)
        train_x, train_y = get_data('train', device='cuda')
        loss, logits = model(train_x, train_y)
        loss.backward()
        optim.step()
        loss_value = loss.item()  # single host sync; print the float, not the tensor repr
        print(f'Epoch: {epoch}, Batch: {batch}, Loss: {loss_value}')
        epoch_losses.append(loss_value)
    # Sampling needs no autograd graph.
    with torch.no_grad():
        sample = model.generate(idx=torch.zeros((1, 1), dtype=torch.long, device='cuda'), seq_len=seq_len)
    print(decoder(sample[0].tolist()))

In [None]:
# NOTE(review): 2000 tokens exceeds GPTConfig.sequence_len (256). This only
# works if `generate` crops its context to the last sequence_len tokens.
print(
    decoder(
        model.generate(
            idx=torch.zeros((1, 1), dtype=torch.long, device='cuda'),
            seq_len=2000,
        )[0].tolist()
    )
)

In [None]:
# Flatten the per-epoch loss lists into one chronological sequence of step losses.
loss_arr = [value for epoch_losses in losses.values() for value in epoch_losses]

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(loss_arr)
ax.set(
    title="Training loss without PE",
    xlabel="Training step",
    ylabel="Cross-entropy loss",
)
# NOTE(review): GPT adds learned position embeddings (`pte`), so the
# "without PE" title/filename may be stale; confirm before publishing.
fig.savefig("Training loss without PE")

In [None]:
# Byte-Pair Encoding

def get_stats(ids, counts=None):
    """Count occurrences of each adjacent pair of token ids.

    Args:
        ids: sequence of integer token ids.
        counts: optional dict to accumulate into. It is updated in place and
            returned; a fresh dict is used when omitted.

    Returns:
        dict mapping ``(a, b)`` pairs to their number of occurrences.
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        # Default to 0 so the first occurrence of a pair doesn't hit None + 1.
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    """Replace every non-overlapping occurrence of ``pair`` in ``ids`` with ``idx``.

    Scans left to right, so ``(1, 1)`` in ``[1, 1, 1]`` merges only the first two.

    Args:
        ids: sequence of integer token ids.
        pair: ``(first, second)`` pair of ids to merge.
        idx: the single id that replaces each occurrence of ``pair``.

    Returns:
        A new list of ids; ``ids`` is not modified.
    """
    first, second = pair
    n = len(ids)
    new_ids = []
    i = 0
    while i < n:
        # Bounds-check i + 1 so the final element never reads past the end.
        if i < n - 1 and ids[i] == first and ids[i + 1] == second:
            new_ids.append(idx)  # every occurrence gets the same merged id
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

def train(text, num_merges):
    """Learn up to ``num_merges`` byte-pair merges from ``text``.

    Starts from the 256 raw UTF-8 byte values and repeatedly merges the most
    frequent adjacent pair into a new id (256, 257, ...).

    Args:
        text: training string (encoded as UTF-8 bytes).
        num_merges: maximum number of merges to learn. Training stops early if
            fewer than two tokens remain.

    Returns:
        ``(merges, vocab)``: ``merges`` maps ``(id, id)`` pairs to the merged id,
        and ``vocab`` maps every id to the bytes it represents.
    """
    ids = list(text.encode("utf-8"))
    merges = {}
    vocab = {idx: bytes([idx]) for idx in range(256)}

    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            break  # nothing left to merge; max() on an empty dict would raise
        pair = max(stats, key=stats.get)
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
        # The new token's bytes are those of the two merged tokens.
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

    return merges, vocab

def decode(ids, vocab_table=None):
    """Convert BPE token ids back into a string.

    Args:
        ids: iterable of integer token ids.
        vocab_table: dict mapping id -> bytes, as returned by ``train``.
            Defaults to a module-level ``vocab`` global, which must be assigned
            (e.g. ``merges, vocab = train(...)``) before calling without it.

    Returns:
        The decoded string; invalid UTF-8 sequences become U+FFFD.

    NOTE: this function rebinds the global name ``decode``, replacing the
    character-level id -> char dict of the same name used by ``decoder``.
    """
    table = vocab if vocab_table is None else vocab_table
    # vocab values are bytes, so join with a bytes separator.
    text_bytes = b"".join(table[idx] for idx in ids)
    return text_bytes.decode("utf-8", errors="replace")

def encode(text, merges_table=None):
    """Encode ``text`` into BPE token ids by repeatedly applying learned merges.

    At each step, the earliest-learned merge (lowest merged id) present in the
    sequence is applied. Encoding stops when no learned merge applies.

    Args:
        text: string to encode (as UTF-8 bytes).
        merges_table: dict mapping ``(id, id)`` pairs to merged ids, as returned
            by ``train``. Defaults to a module-level ``merges`` global, which
            must be assigned before calling without it.

    Returns:
        List of integer token ids (unmerged bytes keep their raw byte value).

    NOTE: this function rebinds the global name ``encode``, replacing the
    character-level char -> id dict of the same name used by ``encoder``.
    """
    table = merges if merges_table is None else merges_table
    ids = list(text.encode("utf-8"))

    while len(ids) >= 2:
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: table.get(p, float('inf')))
        if pair not in table:
            break  # no learned merge applies; return the ids built so far
        ids = merge(ids, pair, table[pair])
    return ids