In [1]:
import torch
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm

In [2]:
with open("brackets.txt", "r") as f:
    brackets = f.readlines()
idxs = list(map(lambda line: list(map(lambda val: int(val), line.split())), brackets))
idxs = torch.tensor(idxs, dtype=torch.long)
idxs.shape

torch.Size([60000, 64])

In [9]:
train_val_split = 0.8
train_size = int(train_val_split * len(idxs))

train_idxs = idxs[:train_size]
val_idxs = idxs[train_size:]

train_dataset = torch.utils.data.TensorDataset(train_idxs)
val_dataset = torch.utils.data.TensorDataset(val_idxs)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False)

In [26]:
idx_to_char = {
    0: "{",
    1: "(",
    2: "[",
    3: "<",
    4: "}",
    5: ")",
    6: "]",
    7: ">",
    8: " MASK ",
    9: "_"
}
char_to_idx = {v: k for k, v in idx_to_char.items()}
def decode_brackets(brackets):
    brackets = brackets.tolist()
    return "".join([idx_to_char[idx] for idx in brackets])

decode_brackets(idxs[0])

'[<{[[<{{[([[[[<{<<<[({(<<(((<<{{}}>>)))>>)})]>>>}>]]]])]}}>]]}>]'

In [27]:
import difflib

def show_diff(seq1, seq2):
    diff = difflib.ndiff(seq1, seq2)
    diff = list(diff)
    print("".join(seq1), end="")
    diff = reversed(diff)
    for d in diff:
        if d[0] == "-":
            idx = char_to_idx[d[1:].strip()] + 4
            print(f"\033[91m{idx_to_char[idx]}\033[0m", end="")
        elif d[0] == "+":
            idx = char_to_idx[d[1:].strip()] + 4
            print(f"\033[92m{idx_to_char[idx]}\033[0m", end="")
        else:
            idx = char_to_idx[d.strip()] + 4
            print(idx_to_char[idx], end="")

def compute_diff(seq):
    seq = decode_brackets(seq).strip("_").strip("[SOS] ").strip(" [EOS]")
    print(seq)
    brack_open = seq[:len(seq) // 2]
    id_open = torch.tensor([char_to_idx[char] for char in brack_open], dtype=torch.long)
    id_close = reversed(torch.tensor([char_to_idx[char] for char in seq[id_open.shape[0]:]], dtype=torch.long)) - 4
    show_diff([idx_to_char[idx] for idx in id_open.tolist()], [idx_to_char[idx] for idx in id_close.tolist()])


compute_diff(idxs[0])

<{[[<{{[([[[[<{<<<[({(<<(((<<{{}}>>)))>>)})]>>>}>]]]])]}}>]]}>
<{[[<{{[([[[[<{<<<[({(<<(((<<{{}}>>)))>>)})]>>>}>]]]])]}}>]]}>

In [12]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

print(f"Using device: {device}")

@torch.no_grad()
def evaluate(model, data_loader):
    model.eval()
    total_loss = 0
    for inputs, in data_loader:
        x = inputs.to(device)

        # mask input
        y = x.clone()
        idx = torch.randint(1, x.shape[1] - 1, (1,))
        x[:, idx:] = 9
        
        y_pred = model(x)
        loss = F.cross_entropy(y_pred[:, :-1].reshape(-1, y_pred.shape[-1]), y[:, 1:].reshape(-1))
        total_loss += loss
       
    return total_loss.item() / len(data_loader)

Using device: mps


In [19]:
from torch.nn import functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, emb_dim, atten_dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.emb_dim = emb_dim

        self.heads = nn.ModuleList([
            SelfAttention(emb_dim, num_heads, atten_dropout) for _ in range(num_heads)
        ])

        self.fc = nn.Linear(emb_dim * num_heads, emb_dim)

    def forward(self, x):
        heads = [head(x) for head in self.heads]
        x = torch.cat(heads, dim=-1)
        x = self.fc(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, m=4):
        super().__init__()
        self.attention = nn.Sequential(
            MultiHeadAttention(num_heads, emb_dim),
        )
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)

        self.fc = nn.Sequential(
            nn.Linear(emb_dim, m * emb_dim),
            nn.SiLU(),
            nn.Linear(m * emb_dim, emb_dim),
        )

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.fc(self.ln2(x))
        return x

class Model(nn.Module):
    def __init__(self, vocab=8, emb_dim=4, seq_len=64):
        super().__init__()
        self.vocab = vocab
        self.emb_dim = emb_dim
        self.seq_len = seq_len

        self.tok_emb = nn.Embedding(vocab, emb_dim)
        self.pos_emb = nn.Parameter(torch.randn(seq_len, emb_dim))

        self.blocks = nn.Sequential(
            TransformerBlock(emb_dim, num_heads=2, m=2),
            TransformerBlock(emb_dim, num_heads=2, m=2),
            TransformerBlock(emb_dim, num_heads=2, m=2),
            TransformerBlock(emb_dim, num_heads=2, m=2),
        )

        self.lm_head = nn.Linear(emb_dim, vocab, bias=False)
        
        # init weights
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x):
        if x.shape[1] > self.seq_len:
            x = x[:, -self.seq_len:]
        
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb[:tok_emb.shape[1]]
        
        x = tok_emb + pos_emb

        for block in self.blocks:
            x = block(x)

        x = self.lm_head(x)
        return x
    
    @torch.no_grad()
    def generate(self, start: list[int] | torch.Tensor | None = None, max_len: int = 100, temperature: float = 1.0, top_k: int = 0, use_cache: bool = False):
        if start is None:
            start = torch.randint(self.vocab, (1, 1), device=device)
        elif isinstance(start, list):
            start = torch.tensor(start, dtype=torch.long, device=device).unsqueeze(0)

        if use_cache:
            self.toggle_kv_cache(True)

        x = start

        for _ in tqdm(range(max_len)):
            y_pred = self(x)
            y_pred = y_pred[:, -1, :] / temperature
            if top_k > 0:
                y_pred = torch.topk(y_pred, top_k, dim=-1).values
            next_char = torch.multinomial(torch.nn.functional.softmax(y_pred, dim=-1), 1)
            x = torch.cat([x, next_char], dim=1)
        
        self.toggle_kv_cache(False)
        return x
    
    def toggle_kv_cache(self, value: bool):
        if value:
            self.blocks.apply(lambda module: setattr(module, "kv_cache", {"k": torch.empty(0), "v": torch.empty(0)}))
        else:
            self.blocks.apply(lambda module: setattr(module, "kv_cache", None))

model = Model().to(device)
# model.toggle_kv_cache(True)
print(decode_brackets(model.generate(max_len=64, use_cache=False).squeeze()))
print(f"Evaluation loss: {evaluate(model, val_loader):.4f}")
summary(model, (1, 63), dtypes=[torch.long], device=device, depth=2)

100%|██████████| 64/64 [00:05<00:00, 11.08it/s]


}>>(()[<}{){>[[>(({[)<}<)[)<(<}>]){}<<>}<[}((}}<{>>])]][<<>><]<}>
Evaluation loss: 2.0815


Layer (type:depth-idx)                             Output Shape              Param #
Model                                              [1, 63, 8]                224
├─Embedding: 1-1                                   [1, 63, 4]                32
├─Sequential: 1-2                                  --                        --
│    └─TransformerBlock: 2-1                       [1, 63, 4]                288
│    └─TransformerBlock: 2-2                       [1, 63, 4]                288
│    └─TransformerBlock: 2-3                       [1, 63, 4]                288
│    └─TransformerBlock: 2-4                       [1, 63, 4]                288
├─Linear: 1-3                                      [1, 63, 8]                32
Total params: 1,440
Trainable params: 1,440
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.00
Input size (MB): 0.00
Forward/backward pass size (MB): 0.12
Params size (MB): 0.00
Estimated Total Size (MB): 0.12

In [20]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [28]:
if "val_loss" in locals():
    pass
else:
    val_loss = float("inf")

def loss_fn(y_pred, y):
    # masked loss
    


for epoch in range(100):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, in pbar:
        x = x.to(device)

        # mask input
        y = x.clone()
        idx = torch.randint(1, x.shape[1] - 1, (1,))
        x[:, idx] = 8
        
        y_pred = model(x)

        optimizer.zero_grad()
        y_pred = model(x)
        loss = F.cross_entropy(y_pred[:, :-1].reshape(-1, y_pred.shape[-1]), y[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

    val_loss = evaluate(model, val_loader)
    print(f"Validation loss: {val_loss:.4f}")

Epoch 0:  38%|███▊      | 71/188 [00:05<00:09, 12.25it/s, loss=1.43, val_loss=1.71]


KeyboardInterrupt: 

In [29]:
decode_brackets(x[3])

'{{<(({[{(<<({{< MASK {{[{{{<<[{<[[{{<>}}]]>}]>>}}}]}}]>}})>>)}]}))>}}'

In [None]:
x 

In [None]:
import matplotlib.pyplot as plt

# visualize embeddings
with torch.no_grad():
    emb = model.tok_emb.weight.cpu().numpy()
    plt.figure(figsize=(10, 10))
    plt.imshow(emb, cmap="viridis")
    plt.colorbar()
    plt.title("Embeddings")
    plt.xlabel("Embedding dimension")
    plt.ylabel("Vocabulary index")
    plt.yticks(list(idx_to_char.keys()), list(idx_to_char.values()))
    plt.show()