In [5]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
RANDOM_SEED = 1

### Tokenizer Definition

In [2]:
class CustomTokenizer:
    def __init__(self):
        # Define the special tokens
        self.special_tokens = ["Input:", "Target:", "<scratch>", "</scratch>", "A->", "C->", ',', ' ', '\n', '+','1','2','3','4','5','6','7','8','9','0', '<PAD>', '<START>', '<END>']
        self.token_to_id = {token: i for i, token in enumerate(self.special_tokens)}
        self.id_to_token = {i: token for token, i in self.token_to_id.items()}
        self.vocab = set(self.special_tokens)
        self.vocab_size = len(self.vocab)
        self.max_length = 50
    
    def tokenize(self, text):
        """
        Tokenize a given text into a list of tokens, properly handling newlines and special tokens.
        """
        # First preserve newlines by replacing them with a special placeholder
        text = text.replace('\n', ' <NEWLINE> ')
        
        # Handle other special tokens
        for token in self.special_tokens:
            text = text.replace(token, f" {token} ")
        
        # Split on whitespace and filter out empty strings
        tokens = [token for token in text.split() if token]
        
        # Replace newline placeholder back if needed
        tokens = ['\n' if t == '<NEWLINE>' else t for t in tokens]
        
        # Update vocabulary
        self.vocab.update(tokens)
        
        return tokens
    
 
        
    def encode(self, tokens):
        """
        Encode a list of tokens into their corresponding IDs.
        """
        return [self.token_to_id.get(token, len(self.token_to_id)) for token in tokens]
    
    def decode(self, token_ids):
        """
        Decode a list of token IDs back into a text string.
        """
        tokens = [self.id_to_token.get(token_id, "<unk>") for token_id in token_ids]
        return " ".join(tokens)
    
    def pad(self, data_tokenized):
        """
        Pad a list of tokenized data so that all sequences have the same length.
        """
        row = data_tokenized + [self.token_to_id['<PAD>']] * (self.max_length - len(data_tokenized))
        return row

### Generate Data

In [3]:
def equation_to_prompt(a,b, with_answer=True):
    """
    Convert an equation string into a prompt string.
    """
    input_str = f"{a}+{b}"
    scratch = "<scratch>\n"
    carry = 0
    for i, (digit_a, digit_b) in enumerate(zip(str(a)[::-1], str(b)[::-1])):
        sum_digits = int(digit_a) + int(digit_b) + carry
        carry = sum_digits // 10
        scratch += f"A->{sum_digits % 10}, C->{carry}\n"
    
    # Append remaining digits of the longer number
    remaining = str(a)[::-1][len(str(b)):] if len(str(a)) > len(str(b)) else str(b)[::-1][len(str(a)):]

    for digit in remaining:
        sum_digits = int(digit) + carry
        carry = sum_digits // 10
        scratch += f"A->{sum_digits % 10}, C->{carry}\n"
    
    if carry > 0:
        scratch += f"A->{carry}, C->0\n"
    
    scratch += "</scratch>"
    
    # Combine into the target
    target_str = f"{scratch}\n{a + b}<END>" if with_answer else "<scratch>\n"
    
    return f"<START>Input:\n{input_str}\n\nTarget:\n{target_str}"

def generate_training_data(limit=1000):
    dataset = []
    for a in range(limit):
        for b in range(limit):
            example = equation_to_prompt(a, b)
            example_tokenized = tokenizer.encode(tokenizer.tokenize(example))
            example_padded = tokenizer.pad(example_tokenized)
            dataset.append(example_padded)
    return dataset


# Generate the dataset
tokenizer = CustomTokenizer()
data = generate_training_data(limit=200)
train_data, val_data = train_test_split(data, test_size=0.2, shuffle=True, random_state=RANDOM_SEED)
train_data, val_data = torch.tensor(train_data), torch.tensor(val_data)
print(f"Training data: {train_data.shape}")
print(f"Validation data: {val_data.shape}")

Training data: torch.Size([32000, 50])
Validation data: torch.Size([8000, 50])


In [22]:
x = np.array([
    [1,0,0,0],
    [0,1,0,0],
    [1,0,0,0]
])
rev = np.array([
    [0, 0,1],
    [0, 1,0],
    [1,0,0,]
])
# flip the rows of x
np.dot(rev,x)

array([[1, 0, 0, 0],
       [0, 1, 0, 0],
       [1, 0, 0, 0]])

### Model

In [4]:
############################
# POSITIONAL EMBEDDING
############################
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len):
        super().__init__()
        # We'll just use a learnable embedding [max_len, embed_dim].
        self.pos_embed = nn.Embedding(max_len, embed_dim)

    def forward(self, x):
        """
        x: [batch_size, seq_len, embed_dim]
        We'll add a positional embedding for each position in the sequence.
        """
        seq_len = x.size(1)
        # positions = [0, 1, 2, ..., seq_len-1]
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        # pos_embed[positions] -> [1, seq_len, embed_dim]
        return x + self.pos_embed(positions)
    

###################################
# MASKED MULTI-HEAD SELF ATTENTION
###################################
def causal_mask(seq_len):
    """
    Create a lower-triangular mask for causal attention of shape [seq_len, seq_len].
    Entry (i, j) is True if position j is masked out for position i (j > i).
    We'll convert it to a float mask for use in attention.
    """
    # mask[i, j] = True if j > i, i.e. future positions
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Query, Key, Value projection layers
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x: [batch_size, seq_len, embed_dim]
        """
        B, T, C = x.size()

        # Project to Q, K, V
        q = self.q_proj(x)  # [B, T, C]
        k = self.k_proj(x)  # [B, T, C]
        v = self.v_proj(x)  # [B, T, C]

        # Reshape to [B, num_heads, T, head_dim]
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # [B, nh, T, hd]
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # [B, nh, T, hd]
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # [B, nh, T, hd]

        # Compute attention scores: [B, nh, T, T]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Causal mask
        causal = causal_mask(T).to(x.device)  # [T, T]
        attn_scores = attn_scores.masked_fill(causal, float('-inf'))

        # Softmax
        attn_probs = F.softmax(attn_scores, dim=-1)  # [B, nh, T, T]
        attn_probs = self.attn_dropout(attn_probs)

        # Weighted sum
        out = torch.matmul(attn_probs, v)  # [B, nh, T, hd]

        # Recombine heads
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # [B, T, embed_dim]

        # Output projection
        out = self.out_proj(out)
        out = self.resid_dropout(out)

        return out


############################
# TRANSFORMER DECODER BLOCK
############################
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        
        # Feed-Forward Network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # x: [B, T, embed_dim]

        # Self-attention
        x_norm = self.ln1(x)
        attn_out = self.attn(x_norm)
        x = x + attn_out  # residual

        # Feed-forward
        x_norm = self.ln2(x)
        ffn_out = self.ffn(x_norm)
        x = x + ffn_out   # residual

        return x


############################
# GPT
############################
class GPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = PositionalEncoding(embed_dim, max_seq_len)
        self.blocks = nn.ModuleList([
            DecoderBlock(embed_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(embed_dim)  # final layer norm
        self.head = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, idx):
        """
        idx: [B, T] of token indices
        Returns: logits [B, T, vocab_size]
        """
        B, T = idx.shape

        # 1. Embed tokens
        x = self.token_emb(idx)  # [B, T, embed_dim]

        # 2. Add positional embeddings
        x = self.pos_emb(x)      # [B, T, embed_dim]

        # 3. Pass through decoder blocks
        for block in self.blocks:
            x = block(x)

        # 4. Final layer norm
        x = self.ln_f(x)

        # 5. Output head
        logits = self.head(x)    # [B, T, vocab_size]

        return logits


### Define Train Loop

In [6]:
# Example hyperparameters
vocab_size = tokenizer.vocab_size     # size of your vocabulary
embed_dim = 128                             # embedding dimension
num_heads = 4                               # number of attention heads
num_layers = 4                              # number of decoder layers
sequence_length = 50                        # max sequence length in  dataset
dropout = 0.1                               # dropout rate
batch_size = 32
num_batches = len(train_data) // batch_size


train_dataset = TensorDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = TensorDataset(val_data)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = GPT(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    max_seq_len=sequence_length,
    dropout=dropout
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def train_one_epoch(model, data_loader, optimizer):
    model.train()
    total_loss = 0
    i = 0
    for batch in data_loader:
        (x,) = batch
        x = x.to(device)

        # GPT input: x[:, :-1], target: x[:, 1:]
        input_tokens = x[:, :-1]
        target_tokens = x[:, 1:]

        logits = model(input_tokens)  # [B, T-1, vocab_size]
        B, Tm1, V = logits.shape

        # Make tensors contiguous or use .reshape
        logits_2d = logits.contiguous().view(B * Tm1, V)
        targets_1d = target_tokens.contiguous().view(B * Tm1)

        # Cross entropy
        loss = F.cross_entropy(logits_2d, targets_1d)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"BATCH {i}/{num_batches},  LOSS: {loss.item()}")

        
        total_loss += loss.item()
        i += 1

    return total_loss / len(data_loader)

@torch.no_grad()
def evaluate(model, data_loader):
    model.eval()
    total_loss = 0
    for batch in data_loader:
        (x,) = batch
        x = x.to(device)

        input_tokens = x[:, :-1]
        target_tokens = x[:, 1:]
        logits = model(input_tokens)

        B, Tm1, V = logits.shape

        logits_2d = logits.contiguous().view(B * Tm1, V)
        targets_1d = target_tokens.contiguous().view(B * Tm1)
        loss = F.cross_entropy(logits_2d, targets_1d)
        
        total_loss += loss.item()
    return total_loss / len(data_loader)


num_epochs = 2
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_loss = evaluate(model, val_loader)
    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

BATCH 0/1000,  LOSS: 3.2178728580474854
BATCH 100/1000,  LOSS: 0.896584153175354
BATCH 200/1000,  LOSS: 0.6238868236541748
BATCH 300/1000,  LOSS: 0.5347651839256287
BATCH 400/1000,  LOSS: 0.5086193084716797
BATCH 500/1000,  LOSS: 0.4683353006839752
BATCH 600/1000,  LOSS: 0.46341222524642944
BATCH 700/1000,  LOSS: 0.43469005823135376
BATCH 800/1000,  LOSS: 0.40084734559059143
BATCH 900/1000,  LOSS: 0.35798534750938416
Epoch 1 | Train Loss: 0.5907 | Val Loss: 0.3298
BATCH 0/1000,  LOSS: 0.3355371654033661
BATCH 100/1000,  LOSS: 0.33056873083114624
BATCH 200/1000,  LOSS: 0.3253413438796997
BATCH 300/1000,  LOSS: 0.3217582404613495
BATCH 400/1000,  LOSS: 0.319132000207901
BATCH 500/1000,  LOSS: 0.32011303305625916
BATCH 600/1000,  LOSS: 0.32148075103759766
BATCH 700/1000,  LOSS: 0.31413790583610535
BATCH 800/1000,  LOSS: 0.3162676692008972
BATCH 900/1000,  LOSS: 0.3226606845855713
Epoch 2 | Train Loss: 0.3227 | Val Loss: 0.3132


### Generate

In [28]:
@torch.no_grad()
def generate(model, num1, num2, max_length=50):
    # Build the prompt
    prompt = equation_to_prompt(num1, num2, with_answer=False)
    # Tokenize + encode
    tokens = tokenizer.tokenize(prompt)
    start_tokens = tokenizer.encode(tokens)

    model.eval()
    while len(start_tokens) < max_length:
        # Convert current tokens into a tensor; no padding, just the raw sequence
        input_ids = torch.tensor([start_tokens], dtype=torch.long).to(device)
        logits = model(input_ids)  # [1, current_length, vocab_size]
        
        # Take the logits from the last time-step
        next_token_logits = logits[0, -1, :]  # [vocab_size]
        next_token_id = torch.argmax(next_token_logits).item()

        # If the model predicts <PAD> or <END>, we can choose to stop
        # for example, if next_token_id == tokenizer.token_to_id['<PAD>']:
        #     break

        # Append next token to our running sequence
        start_tokens.append(next_token_id)
    
    # Decode the final token list
    answer = tokenizer.decode(start_tokens)
    return answer

In [29]:
generate(model, 2,5)

'<START> Input: \n 2 + 5 \n \n Target: \n <scratch> \n A-> 5 , C-> 0 \n </scratch> \n 5 <END> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>'