In [None]:
import torch
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm

In [None]:
import urllib.request
# Download the Tiny Shakespeare corpus. The response is used as a context
# manager so the underlying HTTP connection is closed as soon as the body
# has been read.
DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
with urllib.request.urlopen(DATA_URL) as response:
    raw_text = response.read().decode("utf-8")

# Character-level vocabulary: every distinct character, in sorted order, and
# the two lookup tables between characters and their integer ids.
vocab = sorted(set(raw_text))
char_to_idx = {char: idx for idx, char in enumerate(vocab)}
idx_to_char = dict(enumerate(vocab))

# `text` is the whole corpus as a 1-D tensor of character ids. The raw string
# stays available as `raw_text` instead of being overwritten.
text = torch.tensor([char_to_idx[char] for char in raw_text], dtype=torch.long)
print(f"Vocab size: {len(vocab):,}")

# Context length of the model (number of input positions) and the fraction
# of the corpus held out for validation.
seq_len = 64
split = 0.2

train_len = int((1 - split) * len(text))

# Contiguous split: the first (1 - split) of the corpus trains, the rest validates.
train_text = text[:train_len]
val_text = text[train_len:]
print(f"Train size: {len(train_text):,}")
print(f"Val size: {len(val_text):,}")

# Each row carries seq_len + 1 tokens because the training and evaluation
# code feeds row[:-1] as input and row[1:] as target. With rows of only
# seq_len tokens the inputs are seq_len - 1 long, so the last positional
# embedding row is never trained even though the model uses it once the
# context window is full during generation.
chunk_len = seq_len + 1

# Drop the trailing remainder that does not fill a whole row, then reshape.
train_text = train_text[:len(train_text) // chunk_len * chunk_len].view(-1, chunk_len)
val_text = val_text[:len(val_text) // chunk_len * chunk_len].view(-1, chunk_len)
print(f"Train size: {train_text.shape}")
print(f"Val size: {val_text.shape}")

# Worker / memory settings shared by the training and validation loaders.
_loader_opts = dict(num_workers=2, persistent_workers=True, pin_memory=True)

# Training batches: 256 rows each, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=torch.utils.data.TensorDataset(train_text),
    batch_size=256,
    shuffle=True,
    **_loader_opts,
)

# Validation batches: 128 rows each, always in the same order.
val_loader = torch.utils.data.DataLoader(
    dataset=torch.utils.data.TensorDataset(val_text),
    batch_size=128,
    shuffle=False,
    **_loader_opts,
)

In [None]:
def encode(text: str) -> torch.Tensor:
    """Map a string to a 1-D ``torch.long`` tensor of character ids.

    Each character is looked up in the module-level ``char_to_idx`` table;
    a character missing from that table raises ``KeyError``.
    """
    ids = list(map(char_to_idx.__getitem__, text))
    return torch.tensor(ids, dtype=torch.long)

def decode(tensor: torch.Tensor) -> str:
    """Turn a tensor of character ids back into a string.

    The tensor is flattened first, so a 0-dim tensor (which is what
    ``.squeeze()`` produces for a single token) decodes to one character and
    a multi-dimensional tensor decodes in row-major order.

    Ids are looked up in the module-level ``idx_to_char`` table; an id
    missing from that table raises ``KeyError``.
    """
    # reshape(-1) always yields a 1-D tensor, so tolist() is always a flat list.
    indices = tensor.reshape(-1).tolist()
    return "".join(idx_to_char[idx] for idx in indices)

# Sanity check: encoding then decoding must reproduce the input exactly.
sample = "Hello, World!"
encoded = encode(sample)
decoded = decode(encoded)
assert decoded == sample

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using device: {device}")

@torch.no_grad()
def evaluate(model, data_loader):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    Every batch is a 1-tuple holding a 2-D tensor of token ids. All columns
    but the last are fed to the model and the same columns shifted left by
    one are the targets, scored with ``model.loss``.

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: module exposing ``__call__`` and ``loss(y_pred, y)``.
        data_loader: iterable of 1-tuples of id tensors; it does not need to
            support ``len()``.

    Returns:
        The average of the per-batch losses as a Python float.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()

    # Accumulate on-device and call .item() once at the end, so the loop does
    # not force a host/device sync per batch.
    total_loss = None
    num_batches = 0
    for (inputs,) in data_loader:
        inputs = inputs.to(device)

        x = inputs[:, :-1]
        y = inputs[:, 1:]

        loss = model.loss(model(x), y)
        total_loss = loss if total_loss is None else total_loss + loss
        num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate() received an empty data_loader")

    return total_loss.item() / num_batches

In [None]:
from nn_zoo.models.components import SelfAttention


class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent ``SelfAttention`` modules and merges them.

    Every head receives the same input. The head outputs are concatenated
    along the last dimension and projected by a linear layer whose input
    width is ``emb_dim * num_heads`` and whose output width is ``emb_dim``.

    NOTE(review): this presumes each ``SelfAttention`` returns ``emb_dim``
    features on its last axis. Its signature and masking behaviour are
    defined in ``nn_zoo`` and are not visible here, so confirm.

    Args:
        num_heads: number of ``SelfAttention`` modules to create; also passed
            through to each of them as their second argument.
        emb_dim: feature width of the input and of the projected output.
        atten_dropout: forwarded to each ``SelfAttention`` as its third argument.
    """

    def __init__(self, num_heads, emb_dim, atten_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.emb_dim = emb_dim

        head_modules = []
        for _ in range(num_heads):
            head_modules.append(SelfAttention(emb_dim, num_heads, atten_dropout))
        self.heads = nn.ModuleList(head_modules)

        # Projects the concatenated head outputs back to the embedding width.
        self.fc = nn.Linear(emb_dim * num_heads, emb_dim)

    def forward(self, x):
        """Apply every head to ``x``, concatenate on the last axis, project."""
        per_head = tuple(head(x) for head in self.heads)
        merged = torch.cat(per_head, dim=-1)
        return self.fc(merged)

class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer, then feed-forward sub-layer.

    Each sub-layer is applied to a layer-normalised copy of its input and
    the result is added back onto the un-normalised input.

    Args:
        emb_dim: feature width of the input and output.
        num_heads: number of heads handed to ``MultiHeadAttention``.
        m: expansion factor; the feed-forward hidden width is ``m * emb_dim``.
    """

    def __init__(self, emb_dim, num_heads, m=4):
        super().__init__()
        # Kept inside nn.Sequential so parameter names stay "attention.0.*".
        self.attention = nn.Sequential(MultiHeadAttention(num_heads, emb_dim))
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

        hidden_dim = m * emb_dim
        self.fc = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.attention(self.ln1(x))
        x = x + attended
        transformed = self.fc(self.ln2(x))
        return x + transformed

class Model(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.tok_emb = nn.Embedding(len(vocab), emb_dim)
        self.pos_emb = nn.Parameter(torch.randn(seq_len, emb_dim))

        self.blocks = nn.Sequential(
            TransformerBlock(emb_dim, num_heads=4, m=4),
            TransformerBlock(emb_dim, num_heads=4, m=4),
            TransformerBlock(emb_dim, num_heads=4, m=4),
            TransformerBlock(emb_dim, num_heads=4, m=4),
        )

        self.lm_head = nn.Linear(emb_dim, len(vocab), bias=False)
        
        # init weights
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x):
        if x.shape[1] > seq_len:
            x = x[:, -seq_len:]

        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb[:tok_emb.shape[1]]
        
        x = tok_emb + pos_emb

        for block in self.blocks:
            x = block(x)

        x = self.lm_head(x)
        return x
    
    def loss(self, y_pred, y):
        return torch.nn.functional.cross_entropy(y_pred.reshape(-1, y_pred.shape[-1]), y.reshape(-1), ignore_index=-1)
    
    @torch.no_grad()
    def generate(self, start: list[int] | torch.Tensor | None = None, max_len: int = 100, temperature: float = 1.0, top_k: int = 0):
        if start is None:
            start = torch.randint(len(vocab), (1, 1), device=device)
        elif isinstance(start, list):
            start = torch.tensor(start, dtype=torch.long, device=device).unsqueeze(0)
        x = start

        for _ in tqdm(range(max_len)):
            y_pred = self(x)
            y_pred = y_pred[:, -1, :] / temperature
            if top_k > 0:
                y_pred = torch.topk(y_pred, top_k, dim=-1).values
            next_char = torch.multinomial(torch.nn.functional.softmax(y_pred, dim=-1), 1)
            x = torch.cat([x, next_char], dim=1)
        
        return x
    
# Build the model with an embedding width of 64 and move it to the chosen device.
model = Model(emb_dim=64).to(device)

# Validation loss of the freshly initialised model, as a baseline for training.
print(f"Evaluation loss: {evaluate(model, val_loader):.4f}")

# Layer-by-layer summary for a dummy batch of 64 sequences of 63 token ids.
summary(model, input_size=(64, 63), dtypes=[torch.long], device=device, depth=4)

In [None]:
# Adam over every model parameter with a fixed learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Keep the validation loss from a previous run of this cell if there is one,
# so the progress bar shows a meaningful value before the first epoch ends.
if "val_loss" not in locals():
    val_loss = float("inf")

pbar = tqdm(range(1000))
for epoch in pbar:
    model.train()

    for (inputs,) in train_loader:
        inputs = inputs.to(device)

        # Next-character prediction: targets are the inputs shifted left by one.
        x, y = inputs[:, :-1], inputs[:, 1:]

        optimizer.zero_grad()

        y_pred = model(x)
        loss = model.loss(y_pred, y)
        loss.backward()

        optimizer.step()
        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

    # Refresh the validation loss once per epoch (this leaves the model in
    # eval mode until the next model.train() call above).
    val_loss = evaluate(model, val_loader)

In [None]:
# Sample 128 new characters from a random start token and show them as text.
sample_ids = model.generate(max_len=128)
print(decode(sample_ids.squeeze()))

In [None]:
# Print the mean, standard deviation and norm of each named parameter's
# gradient, as left behind by the most recent backward pass.
print("Gradient statistics")
print(f"{'Name':30} {'Mean':<7} {'Std':<7} {'Norm':<7}")
for name, param in model.named_parameters():
    grad = param.grad
    if grad is None:
        # No backward pass has reached this parameter yet (e.g. this cell was
        # run before training); report that instead of crashing on None.
        print(f"{name:30} (no gradient)")
        continue
    print(f"{name:30} {grad.mean():.5f} {grad.std():.5f} {grad.norm():.5f}")

# TODO: clip gradient norm
    

In [None]:
@torch.no_grad()
def clamp_gradients(model, max_value: float):
    """Clamp every gradient element of ``model`` to ``[-max_value, max_value]``.

    The clamp is applied in place, element by element; it is not a rescaling
    of the overall gradient norm. Parameters that have no gradient yet
    (``.grad is None``) are skipped.

    Args:
        model: module whose parameters' gradients are clamped.
        max_value: symmetric bound applied to each gradient element.
    """
    for param in model.parameters():
        if param.grad is None:
            continue
        param.grad.clamp_(-max_value, max_value)

In [None]:
# Bound every stored gradient element to the range [-1e-3, 1e-3].
clamp_gradients(model, max_value=1e-3)