In [51]:
import torch
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm

In [52]:
with open("brackets.txt", "r") as f:
    brackets = f.readlines()
idxs = list(map(lambda line: list(map(lambda val: int(val), line.split())), brackets))
idxs = torch.tensor(idxs, dtype=torch.long)
idxs.shape

torch.Size([60000, 64])

In [53]:
train_val_split = 0.8
train_size = int(train_val_split * len(idxs))

train_idxs = idxs[:train_size]
val_idxs = idxs[train_size:]

train_dataset = torch.utils.data.TensorDataset(train_idxs[:, :-1], train_idxs[:, 1:])
val_dataset = torch.utils.data.TensorDataset(val_idxs[:, :-1], val_idxs[:, 1:])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)

In [54]:
idx_to_char = {
    0: "{",
    1: "(",
    2: "[",
    3: "<",
    4: "}",
    5: ")",
    6: "]",
    7: ">",
    8: "SOS ",
    9: " EOS",
    10: "_"
}
char_to_idx = {v: k for k, v in idx_to_char.items()}
def decode_brackets(brackets):
    brackets = brackets.tolist()
    return "".join([idx_to_char[idx] for idx in brackets])

decode_brackets(idxs[0])

'SOS [<{[[<{{[([[[[<{<<<[({(<<(((<<{}>>)))>>)})]>>>}>]]]])]}}>]]}>] EOS'

In [55]:
import difflib

def show_diff(seq1, seq2):
    diff = difflib.ndiff(seq1, seq2)
    diff = list(diff)
    print("".join(seq1), end="")
    diff = reversed(diff)
    for d in diff:
        if d[0] == "-":
            idx = char_to_idx[d[1:].strip()] + 4
            print(f"\033[91m{idx_to_char[idx]}\033[0m", end="")
        elif d[0] == "+":
            idx = char_to_idx[d[1:].strip()] + 4
            print(f"\033[92m{idx_to_char[idx]}\033[0m", end="")
        else:
            idx = char_to_idx[d.strip()] + 4
            print(idx_to_char[idx], end="")

def compute_diff(seq):
    seq = decode_brackets(seq).strip("_").strip("[SOS] ").strip(" [EOS]")
    print(seq)
    brack_open = seq[:len(seq) // 2]
    id_open = torch.tensor([char_to_idx[char] for char in brack_open], dtype=torch.long)
    id_close = reversed(torch.tensor([char_to_idx[char] for char in seq[id_open.shape[0]:]], dtype=torch.long)) - 4
    show_diff([idx_to_char[idx] for idx in id_open.tolist()], [idx_to_char[idx] for idx in id_close.tolist()])


# sequence1 = "<{[[<{{[([[[[<{<<<[({("
# sequence2 = "{[[<{{[([[[[<{<<<[(((("

# show_diff(sequence1, sequence2)

compute_diff(idxs[0])

<{[[<{{[([[[[<{<<<[({(<<(((<<{}>>)))>>)})]>>>}>]]]])]}}>]]}>
<{[[<{{[([[[[<{<<<[({(<<(((<<{}>>)))>>)})]>>>}>]]]])]}}>]]}>

In [56]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

print(f"Using device: {device}")

@torch.no_grad()
def evaluate(model, data_loader):
    model.eval()
    total_loss = 0
    for inputs, targets in data_loader:
        x, y = inputs.to(device), targets.to(device)

        y_pred = model(x)
        loss = model.loss(y_pred, y)
        total_loss += loss
       
    return total_loss.item() / len(data_loader)

Using device: mps


In [60]:
from torch.nn import functional as F

class SelfAttention(nn.Module):
    kv_cache: dict[str, torch.Tensor] | None

    def __init__(self, n_embd: int, n_head: int, attn_dropout: float = 0.0, is_causal: bool = True):
        super().__init__()
        assert n_embd % n_head == 0, f"Embedding dimension {n_embd} should be divisible by number of heads {n_head}"

        self.kv_cache = None

        self.c_attn = nn.Linear(n_embd, n_embd * 3)
        self.c_proj = nn.Linear(n_embd, n_embd)
        
        self.n_head = n_head
        self.n_embd = n_embd
        self.attn_dropout = attn_dropout
        self.is_causal = is_causal

    def forward(self, x) -> torch.Tensor:
        B, T, C = x.size()
        
        if self.kv_cache is None or self.kv_cache["k"].shape[0] == 0:
            qkv = self.c_attn(x)
            q, k, v = qkv.split(self.n_embd, dim=2)

            # (B, T, C) -> (B, T, n_head, C // n_head) -> (B, n_head, T, C // n_head)
            k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
            v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

            y = (
                F.scaled_dot_product_attention(
                    q, k, v, is_causal=self.is_causal, dropout_p=self.attn_dropout
                )
                .transpose(1, 2)
                .contiguous()
                .view(B, T, C)
            )

            # output projection
            y = self.c_proj(y)

            if self.kv_cache is not None:
                # print("Setting cache")
                self.kv_cache["k"] = k
                self.kv_cache["v"] = v

            return y
        
        else:
            # print("Using cache")
            if self.kv_cache["k"].shape[0] != B:
                self.kv_cache["k"] = self.kv_cache["k"][:B]
                self.kv_cache["v"] = self.kv_cache["v"][:B]
            
            qkv = self.c_attn(x)
            q, _, _ = qkv.split(self.n_embd, dim=2)
            q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

            k = self.kv_cache["k"].to(q.device)
            v = self.kv_cache["v"].to(q.device)
            # print(q.shape, k.shape, v.shape)
            y = (
                F.scaled_dot_product_attention(
                    q, k, v, is_causal=self.is_causal, dropout_p=self.attn_dropout
                )
                .transpose(1, 2)
                .contiguous()
                .view(B, T, C)
            )

            # output projection
            y = self.c_proj(y)
            return y

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, emb_dim, atten_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.emb_dim = emb_dim

        self.heads = nn.ModuleList([
            SelfAttention(emb_dim, num_heads, atten_dropout) for _ in range(num_heads)
        ])

        self.fc = nn.Linear(emb_dim * num_heads, emb_dim)

    def forward(self, x):
        heads = [head(x) for head in self.heads]
        x = torch.cat(heads, dim=-1)
        x = self.fc(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, m=4):
        super().__init__()
        self.attention = nn.Sequential(
            MultiHeadAttention(num_heads, emb_dim),
        )
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

        self.fc = nn.Sequential(
            nn.Linear(emb_dim, m * emb_dim),
            nn.SiLU(),
            nn.Linear(m * emb_dim, emb_dim),
        )

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.fc(self.ln2(x))
        return x

class Model(nn.Module):
    def __init__(self, vocab=11, emb_dim=8, seq_len=64):
        super().__init__()
        self.vocab = vocab
        self.emb_dim = emb_dim
        self.seq_len = seq_len

        self.tok_emb = nn.Embedding(vocab, emb_dim)
        self.pos_emb = nn.Parameter(torch.randn(seq_len, emb_dim))

        self.blocks = nn.Sequential(
            TransformerBlock(emb_dim, num_heads=1, m=1),
            TransformerBlock(emb_dim, num_heads=2, m=1),
            TransformerBlock(emb_dim, num_heads=4, m=2),
            TransformerBlock(emb_dim, num_heads=8, m=2),
        )

        self.lm_head = nn.Linear(emb_dim, vocab, bias=False)
        
        # init weights
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x):
        if x.shape[1] > self.seq_len:
            x = x[:, -self.seq_len:]
        
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb[:tok_emb.shape[1]]
        
        x = tok_emb + pos_emb

        for block in self.blocks:
            x = block(x)

        x = self.lm_head(x)
        return x
    
    def loss(self, y_pred, y):
        return torch.nn.functional.cross_entropy(y_pred.reshape(-1, self.vocab), y.reshape(-1))
    
    @torch.no_grad()
    def generate(self, start: list[int] | torch.Tensor | None = None, max_len: int = 100, temperature: float = 1.0, top_k: int = 0, use_cache: bool = False):
        if start is None:
            start = torch.randint(self.vocab, (1, 1), device=device)
        elif isinstance(start, list):
            start = torch.tensor(start, dtype=torch.long, device=device).unsqueeze(0)

        if use_cache:
            self.toggle_kv_cache(True)

        x = start

        for _ in tqdm(range(max_len)):
            y_pred = self(x)
            y_pred = y_pred[:, -1, :] / temperature
            if top_k > 0:
                y_pred = torch.topk(y_pred, top_k, dim=-1).values
            next_char = torch.multinomial(torch.nn.functional.softmax(y_pred, dim=-1), 1)
            x = torch.cat([x, next_char], dim=1)
        
        self.toggle_kv_cache(False)
        return x
    
    def toggle_kv_cache(self, value: bool):
        if value:
            self.blocks.apply(lambda module: setattr(module, "kv_cache", {"k": torch.empty(0), "v": torch.empty(0)}))
        else:
            self.blocks.apply(lambda module: setattr(module, "kv_cache", None))

model = Model().to(device)
# model.toggle_kv_cache(True)
print(decode_brackets(model.generate(max_len=4, use_cache=True).squeeze()))
print(f"Evaluation loss: {evaluate(model, val_loader):.4f}")
summary(model, (1, 63), dtypes=[torch.long], device=device, depth=2)

100%|██████████| 4/4 [00:00<00:00, 59.03it/s]


}<><_
Evaluation loss: 2.3978


Layer (type:depth-idx)                             Output Shape              Param #
Model                                              [1, 63, 11]               424
├─Embedding: 1-1                                   [1, 63, 8]                88
├─Sequential: 1-2                                  --                        --
│    └─TransformerBlock: 2-1                       [1, 63, 8]                536
│    └─TransformerBlock: 2-2                       [1, 63, 8]                888
│    └─TransformerBlock: 2-3                       [1, 63, 8]                1,728
│    └─TransformerBlock: 2-4                       [1, 63, 8]                3,136
├─Linear: 1-3                                      [1, 63, 11]               88
Total params: 6,888
Trainable params: 6,888
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.01
Input size (MB): 0.00
Forward/backward pass size (MB): 0.34
Params size (MB): 0.03
Estimated Total Size (MB): 0.37

In [61]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)

RuntimeError: `fused=True` requires all the params to be floating point Tensors of supported devices: ['cuda', 'xpu', 'privateuseone'].

In [62]:
%pip show torch

Name: torch
Version: 2.3.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /Users/karan/projects/playground/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: lightning, nn-zoo, pytorch-lightning, torchmetrics, torchvision
Note: you may need to restart the kernel to use updated packages.


In [59]:
if "val_loss" in locals():
    pass
else:
    val_loss = float("inf")


for epoch in range(100):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = model.loss(y_pred, y)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

    val_loss = evaluate(model, val_loader)
    print(f"Validation loss: {val_loss:.4f}")

Epoch 0:  47%|████▋     | 88/188 [00:17<00:20,  4.94it/s, loss=1.68, val_loss=1.36]


KeyboardInterrupt: 

In [50]:
compute_diff(model.generate(max_len=64).squeeze())

100%|██████████| 64/64 [00:00<00:00, 86.20it/s]

><{({{<((<{(<[<<(<{(<{<({<<{<{[(>}}>)]]>>>}>}]>>]>})>]}}])}}]>}





KeyError: -3

In [None]:
import matplotlib.pyplot as plt

# visualize embeddings
with torch.no_grad():
    emb = model.tok_emb.weight.cpu().numpy()
    plt.figure(figsize=(10, 10))
    plt.imshow(emb, cmap="viridis")
    plt.colorbar()
    plt.title("Embeddings")
    plt.xlabel("Embedding dimension")
    plt.ylabel("Vocabulary index")
    plt.yticks(list(idx_to_char.keys()), list(idx_to_char.values()))
    plt.show()