In [1]:
class CharTokenizer:
    """Character-level tokenizer with a dedicated token for runs of indentation spaces.

    Vocabulary layout:
      * IDs 0-3 are the special tokens ``<pad>``, ``<bos>``, ``<eos>``, ``<indent>``.
      * Every distinct character seen in ``texts`` follows, in sorted order.

    While encoding, every full group of ``indent_spaces`` consecutive spaces becomes
    one ``<indent>`` token and the leftover spaces of the run stay literal ``" "``
    tokens. ``decode`` expands ``<indent>`` back to ``indent_spaces`` spaces, so
    ``decode(encode(t)) == t`` for any ``t`` built from vocabulary characters.
    """

    def __init__(self, texts, indent_spaces=4):
        """Build the vocabulary from an iterable of strings.

        Args:
            texts: iterable of strings whose characters form the vocabulary.
            indent_spaces: number of spaces represented by one ``<indent>`` token.

        Raises:
            ValueError: if ``indent_spaces`` is smaller than 1 (a non-positive
                width would make ``encode`` loop forever).
        """
        if indent_spaces < 1:
            raise ValueError(f"indent_spaces must be >= 1, got {indent_spaces}")
        self.indent_spaces = indent_spaces

        # Special tokens get fixed IDs: padding, beginning of sequence, end of sequence, indent.
        self.special_tokens = ["<pad>", "<bos>", "<eos>", "<indent>"]
        self.stoi = {tok: i for i, tok in enumerate(self.special_tokens)}  # string -> id
        self.itos = {i: tok for tok, i in self.stoi.items()}  # id -> string

        # Collect every distinct character across all texts.
        chars = set()
        for text in texts:
            chars.update(text)

        # Regular characters are numbered right after the special tokens; sorting
        # keeps the ID assignment deterministic across runs.
        offset = len(self.stoi)
        for i, ch in enumerate(sorted(chars)):
            self.stoi[ch] = i + offset
            self.itos[i + offset] = ch

        self.vocab_size = len(self.stoi)

    def _char_id(self, ch):
        """Return the ID of a single character, with a readable error if it is unknown."""
        try:
            return self.stoi[ch]
        except KeyError:
            raise KeyError(f"character {ch!r} is not in the tokenizer vocabulary") from None

    def encode(self, text, add_special_tokens=True):
        """Convert ``text`` into a list of token IDs.

        Every run of spaces (anywhere in the text, not only at line start) is
        emitted as ``run_length // indent_spaces`` ``<indent>`` tokens followed
        by ``run_length % indent_spaces`` literal space tokens.

        Args:
            text: string to encode.
            add_special_tokens: wrap the result in ``<bos>`` ... ``<eos>``.

        Returns:
            list[int] of token IDs.

        Raises:
            KeyError: if ``text`` contains a character absent from the vocabulary.
        """
        ids = []
        if add_special_tokens:
            ids.append(self.stoi["<bos>"])

        n = len(text)
        i = 0
        while i < n:
            if text[i] == " ":
                # Measure the whole run of spaces, then split it into indent units
                # plus leftover single spaces.
                start = i
                while i < n and text[i] == " ":
                    i += 1
                indents, leftover = divmod(i - start, self.indent_spaces)
                ids.extend([self.stoi["<indent>"]] * indents)
                # Look up " " only when leftover spaces exist, so runs that are an
                # exact multiple of indent_spaces still encode with a vocabulary
                # that has no plain-space token.
                if leftover:
                    ids.extend([self._char_id(" ")] * leftover)
            else:
                ids.append(self._char_id(text[i]))
                i += 1

        if add_special_tokens:
            ids.append(self.stoi["<eos>"])

        return ids

    def decode(self, ids):
        """Convert token IDs back into text.

        ``<pad>``, ``<bos>`` and ``<eos>`` are dropped, ``<indent>`` expands to
        ``indent_spaces`` spaces, and IDs missing from the vocabulary are skipped.

        Args:
            ids: iterable of integer IDs. Each element is passed through ``int()``,
                so a 1-D ``torch.Tensor`` or numpy array works as well as a list.

        Returns:
            The decoded string.
        """
        skipped = {"<pad>", "<bos>", "<eos>"}
        pieces = []
        for raw_id in ids:
            # int() normalises tensor/numpy scalars, which would otherwise never
            # match a dict key (tensors hash by identity, not by value).
            token = self.itos.get(int(raw_id), "")
            if token in skipped:
                continue
            if token == "<indent>":
                pieces.append(" " * self.indent_spaces)
            else:
                pieces.append(token)
        return "".join(pieces)


In [2]:
import pandas as pd
# Path to the paired buggy/fixed code dataset; change it here rather than inline.
DATA_CSV = "code_bug_fix_pairs.csv"
df = pd.read_csv(DATA_CSV)

In [3]:
import re
def clean_code_logic(text, marker="# Sample ID"):
    """Strip a code snippet and drop everything from ``marker`` onward.

    Args:
        text: raw cell value. Anything that is not a ``str`` (e.g. NaN for a
            missing CSV cell) is treated as empty.
        marker: substring whose first occurrence starts trailing metadata to
            discard. Defaults to ``"# Sample ID"``; an empty marker cuts nothing.

    Returns:
        The text before the first ``marker`` (or the whole text if the marker is
        absent), with surrounding whitespace removed.
    """
    if not isinstance(text, str):
        return ""
    if marker:
        # partition keeps everything before the first occurrence, or the whole
        # string when the marker is absent.
        text = text.partition(marker)[0]
    return text.strip()

# --- Step 1: Clean the DataFrame first ---
print("Cleaning data and building custom vocabulary...")

# Each raw code column gets a cleaned twin (metadata trailer removed, whitespace stripped).
df["buggy_clean"], df["fixed_clean"] = (
    df[column].apply(clean_code_logic) for column in ("buggy_code", "fixed_code")
)

# --- Step 2: Gather cleaned snippets into one list (all buggy, then all fixed) ---
# Column-wise .tolist() is much faster than walking rows with iterrows().
texts = [*df["buggy_clean"].tolist(), *df["fixed_clean"].tolist()]
print(f"Collected {len(texts)} cleaned code snippets.")

# --- Step 3: Build the character vocabulary from every cleaned snippet ---
tokenizer = CharTokenizer(texts)
print("Vocab size:", tokenizer.vocab_size)
print(list(tokenizer.stoi.items()))



Cleaning data and building custom vocabulary...
Collected 2000 cleaned code snippets.
Vocab size: 49
[('<pad>', 0), ('<bos>', 1), ('<eos>', 2), ('<indent>', 3), ('\n', 4), (' ', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('0', 13), ('1', 14), ('2', 15), ('3', 16), ('4', 17), ('5', 18), (':', 19), ('=', 20), ('>', 21), ('F', 22), ('H', 23), ('M', 24), ('T', 25), ('[', 26), (']', 27), ('_', 28), ('a', 29), ('b', 30), ('c', 31), ('d', 32), ('e', 33), ('f', 34), ('g', 35), ('h', 36), ('i', 37), ('l', 38), ('m', 39), ('n', 40), ('o', 41), ('p', 42), ('r', 43), ('s', 44), ('t', 45), ('u', 46), ('w', 47), ('x', 48)]


In [4]:
# Show one example end to end so the token IDs can be inspected by eye.
sample = df.iloc[0]["buggy_clean"]
print("ORIGINAL:")
print(repr(sample))

encoded = tokenizer.encode(sample)
decoded = tokenizer.decode(encoded)

print("ENCODED:")
print(encoded[:50])  # print first 50 tokens
print("DECODED:")
print(repr(decoded))

assert decoded == sample

# A single sample may contain no indentation at all (row 0 above has none), so
# it never exercises the <indent> path; check the round trip on every snippet.
mismatches = [t for t in texts if tokenizer.decode(tokenizer.encode(t)) != t]
assert not mismatches, f"{len(mismatches)} snippets do not round-trip, e.g. {mismatches[0]!r}"
print("Yippee reversiblityy")

ORIGINAL:
'x = [1, 2, 3]\nprint x'
ENCODED:
[1, 48, 5, 20, 5, 26, 14, 11, 5, 15, 11, 5, 16, 27, 4, 42, 43, 37, 40, 45, 5, 48, 2]
DECODED:
'x = [1, 2, 3]\nprint x'
Yippee reversiblityy


In [5]:
import math
from typing import Dict, List, Optional, Tuple

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


# ----------------------------
# Positional encoding + embed
# ----------------------------
class SinusoidalPositionalEncoding(nn.Module):
    """Adds fixed sine/cosine positional encodings to its input.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = 10000^(-2k / d_model)``. The table is precomputed for ``max_len``
    positions and registered as the buffer ``pe`` of shape ``[max_len, d_model]``.
    Odd ``d_model`` values are supported.
    """

    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
        # One frequency per sin/cos pair: ceil(d_model / 2) values.
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer cosine column than frequencies;
        # trim the last frequency so the shapes match (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + pe[:T]`` broadcast over the batch.

        Expects ``x`` shaped ``[B, T, d_model]`` (T is read from dim 1).

        Raises:
            ValueError: if ``T`` exceeds the precomputed ``max_len``.
        """
        T = x.size(1)
        max_len = self.pe.size(0)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds max_len {max_len} of the positional encoding")
        return x + self.pe[:T].unsqueeze(0)


class PreTransformerInputs(nn.Module):
    """Token embedding followed by sinusoidal positional encoding.

    Maps integer token IDs to ``d_model``-dimensional vectors; ``pad_id`` is
    passed as the embedding's ``padding_idx``.
    """

    def __init__(self, vocab_size: int, d_model: int, pad_id: int, max_len: int = 4096):
        super().__init__()
        # Submodules are registered in the same order (emb, then pos) so the
        # parameter/buffer order and state_dict keys are unchanged.
        self.emb = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=pad_id,
        )
        self.pos = SinusoidalPositionalEncoding(d_model=d_model, max_len=max_len)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and add the positional encodings."""
        token_vectors = self.emb(input_ids)
        return self.pos(token_vectors)


# ----------------------------
# Masks
# ----------------------------
def make_key_padding_mask(input_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Boolean padding mask in the ``nn.Transformer`` convention.

    True  = this position is PAD (ignore it)
    False = real token
    Shape: same as ``input_ids`` ([B, T] for a padded batch)
    """
    is_pad = input_ids.eq(pad_id)
    return is_pad


def make_causal_mask(tgt_len: int, device: torch.device) -> torch.Tensor:
    """Boolean look-ahead mask for decoder self-attention.

    Entry ``[i, j]`` is True (blocked, can't attend) exactly when ``j > i``,
    i.e. a query position may not look at later key positions.
    Shape: [tgt_len, tgt_len]
    """
    positions = torch.arange(tgt_len, device=device)
    # Broadcast row index (dim 1 of the column vector) against column index.
    return positions.unsqueeze(0) > positions.unsqueeze(1)


def pad_1d(seqs: List[List[int]], pad_id: int) -> torch.Tensor:
    """Right-pad variable-length integer sequences into one ``long`` tensor.

    Args:
        seqs: sequences of token IDs (lists, tuples, or any iterable of ints).
        pad_id: value written into the padded tail positions.

    Returns:
        ``torch.long`` tensor of shape ``[len(seqs), longest_length]``. An empty
        ``seqs`` yields shape ``[0, 0]`` instead of raising.
    """
    # Normalise to lists so `+` padding works for tuples and other iterables.
    rows = [list(s) for s in seqs]
    if not rows:
        # max() over an empty batch would raise ValueError.
        return torch.empty((0, 0), dtype=torch.long)
    max_len = max(len(r) for r in rows)
    return torch.tensor([r + [pad_id] * (max_len - len(r)) for r in rows], dtype=torch.long)


def teacher_forcing_shift(tgt_ids_padded: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a padded target batch into decoder inputs and next-token labels.

    tgt_ids_padded: [B, T]
    returns:
      tgt_in  [B, T-1]  -- every position except the last
      tgt_out [B, T-1]  -- every position except the first (labels for tgt_in)
    Both outputs are made contiguous copies-or-views.
    """
    decoder_input = tgt_ids_padded[:, :-1]
    decoder_target = tgt_ids_padded[:, 1:]
    return decoder_input.contiguous(), decoder_target.contiguous()


# ----------------------------
# Dataset using your notebook columns
# ----------------------------
class BugFixCharDataset(Dataset):
    """(buggy, fixed) snippet pairs, tokenized lazily on each access.

    Reads the ``buggy_clean`` (source) and ``fixed_clean`` (target) columns of
    ``df``. Each item is ``(src_ids, tgt_ids)``, both encoded with
    ``add_special_tokens=True``: boundaries help the encoder, and the target
    needs ``<bos>``/``<eos>`` for the teacher-forcing shift.
    """

    def __init__(self, df: pd.DataFrame, tokenizer):
        self.src_texts, self.tgt_texts = (
            df[column].astype(str).tolist() for column in ("buggy_clean", "fixed_clean")
        )
        self.tok = tokenizer

    def __len__(self):
        return len(self.src_texts)

    def __getitem__(self, idx):
        source_text = self.src_texts[idx]
        target_text = self.tgt_texts[idx]
        return (
            self.tok.encode(source_text, add_special_tokens=True),
            self.tok.encode(target_text, add_special_tokens=True),
        )


def collate_for_transformer(batch, pad_id: int, device: Optional[torch.device] = None):
    """Pad a list of (src_ids, tgt_ids) pairs and build the attention masks.

    Args:
        batch: list of ``(src_ids, tgt_ids)`` pairs of variable-length ID lists.
        pad_id: padding token ID, used both for padding and for the padding masks.
        device: where to place the returned tensors. ``None`` (the default) keeps
            them on the CPU so the training loop can move each batch itself;
            PyTorch advises against returning CUDA tensors from multi-process
            data loading (``num_workers > 0``), and CPU batches can use
            ``pin_memory``.

    Returns:
        dict with ``src_ids`` [B, Tsrc], ``tgt_in_ids`` / ``tgt_out_ids``
        [B, Ttgt-1], ``src_key_padding_mask`` [B, Tsrc] and
        ``tgt_key_padding_mask`` [B, Ttgt-1] (True = PAD), and
        ``tgt_causal_mask`` [Ttgt-1, Ttgt-1] (True = blocked).
    """
    src_list, tgt_list = zip(*batch)

    src_ids = pad_1d(list(src_list), pad_id=pad_id)  # [B, Tsrc]
    tgt_ids = pad_1d(list(tgt_list), pad_id=pad_id)  # [B, Ttgt]
    if device is not None:
        src_ids = src_ids.to(device)
        tgt_ids = tgt_ids.to(device)

    tgt_in_ids, tgt_out_ids = teacher_forcing_shift(tgt_ids)  # [B, Ttgt-1] each

    src_key_padding_mask = make_key_padding_mask(src_ids, pad_id=pad_id)      # [B, Tsrc]
    tgt_key_padding_mask = make_key_padding_mask(tgt_in_ids, pad_id=pad_id)   # [B, Ttgt-1]
    # Build the causal mask wherever the targets live, with or without `device`.
    tgt_causal_mask = make_causal_mask(tgt_in_ids.size(1), device=tgt_in_ids.device)  # [Ttgt-1, Ttgt-1]

    return {
        "src_ids": src_ids,
        "tgt_in_ids": tgt_in_ids,
        "tgt_out_ids": tgt_out_ids,
        "src_key_padding_mask": src_key_padding_mask,
        "tgt_key_padding_mask": tgt_key_padding_mask,
        "tgt_causal_mask": tgt_causal_mask,
    }


# ----------------------------
# Example wiring (after you already ran your notebook cleaning + tokenizer build)
# ----------------------------
if __name__ == "__main__":
    # Assume df already has: buggy_clean, fixed_clean
    # and tokenizer = CharTokenizer(texts) already exists.

    # Seed the global torch RNG so the shuffled batch (DataLoader draws its
    # sampler seed from it) and the embedding initialisation are reproducible.
    SEED = 0
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    PAD_ID = tokenizer.stoi["<pad>"]
    vocab_size = tokenizer.vocab_size

    dataset = BugFixCharDataset(df, tokenizer)

    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=lambda b: collate_for_transformer(b, pad_id=PAD_ID, device=device),
    )

    # Optional: build embeddings+pos enc outputs (ready to feed into blocks)
    d_model = 256
    prep = PreTransformerInputs(vocab_size=vocab_size, d_model=d_model, pad_id=PAD_ID).to(device)

    batch = next(iter(loader))

    src_x = prep(batch["src_ids"])       # [B, Tsrc, d_model]
    tgt_x = prep(batch["tgt_in_ids"])    # [B, Ttgt-1, d_model]

    print("src_ids:", batch["src_ids"].shape)
    print("tgt_in_ids:", batch["tgt_in_ids"].shape)
    print("tgt_out_ids:", batch["tgt_out_ids"].shape)
    print("src_x:", src_x.shape)
    print("tgt_x:", tgt_x.shape)
    print("src_key_padding_mask:", batch["src_key_padding_mask"].shape)
    print("tgt_key_padding_mask:", batch["tgt_key_padding_mask"].shape)
    print("tgt_causal_mask:", batch["tgt_causal_mask"].shape)


src_ids: torch.Size([32, 76])
tgt_in_ids: torch.Size([32, 76])
tgt_out_ids: torch.Size([32, 76])
src_x: torch.Size([32, 76, 256])
tgt_x: torch.Size([32, 76, 256])
src_key_padding_mask: torch.Size([32, 76])
tgt_key_padding_mask: torch.Size([32, 76])
tgt_causal_mask: torch.Size([76, 76])
