# Models

In [None]:
# Math mode
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt

# Character-level vocabulary: the ten digits, '+', '=', and a space used as padding.
chars = "0123456789+= "
# itos maps token id -> character; stoi is its inverse (character -> token id).
itos = dict(enumerate(chars))
stoi = {symbol: token_id for token_id, symbol in itos.items()}

class Config:
    """Hyperparameters and target device shared by every module in this file."""

    vocab_size = len(chars)  # one token per character in `chars`
    n_embed = 256            # embedding / residual-stream width
    n_heads = 4              # attention heads per attention layer
    block_size = 16          # padded length of one problem string
    dropout = 0.05

    # Prefer CUDA, then Apple-silicon MPS, then CPU. The MPS probe goes through
    # torch.backends.mps, which is the documented availability check;
    # torch.mps.is_available() does not exist in many PyTorch releases and
    # would raise AttributeError on machines without CUDA.
    if torch.cuda.is_available():
        device = 'cuda'
    elif torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'

# The name is rebound to a single shared instance; the rest of the file reads
# settings as `Config.<attr>`.
Config = Config()
print("Using device:", Config.device)

# Standard transformer components
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(Config.n_embed, head_size, bias=False)
        self.query = nn.Linear(Config.n_embed, head_size, bias=False)
        self.value = nn.Linear(Config.n_embed, head_size, bias=False)
        # Lower-triangular matrix of ones used as the causal mask. Registered
        # as a buffer so it follows the module across devices without being a
        # trainable parameter.
        self.register_buffer(
            "tril", torch.tril(torch.ones(Config.block_size, Config.block_size))
        )

    def forward(self, x):
        """Attend over the sequence dimension of `x`.

        Args:
            x: Tensor of shape (B, T, n_embed) with T <= block_size.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        k = self.key(x)    # (B,T,hs)
        q = self.query(x)  # (B,T,hs)
        v = self.value(x)  # (B,T,hs)

        # Scaled dot-product scores. The scale is 1/sqrt(d_k), where d_k is the
        # key width (head_size) -- not the width of the incoming embedding.
        scale = k.shape[-1] ** -0.5
        wei = (q @ k.transpose(-2, -1)) * scale  # (B,T,hs) @ (B,hs,T) -> (B,T,T)

        # Causal mask: position t may attend to positions <= t only.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        # Normalise each row into attention weights (no dropout is applied).
        wei = F.softmax(wei, dim=-1)  # (B,T,T)

        # Weighted aggregation of the values.
        out = wei @ v  # (B,T,T) @ (B,T,hs) -> (B,T,hs)
        return out
    

class FeedForward(nn.Module):
    """Per-position MLP: widen to 4x the embedding size, apply ReLU, project back.

    Args:
        n_embed: Width of both the input and the output features.
    """

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        expand = nn.Linear(n_embed, hidden)
        contract = nn.Linear(hidden, n_embed)
        self.net = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the MLP to the last dimension of `x`; the shape is unchanged."""
        transformed = self.net(x)
        return transformed
    
class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then a feed-forward MLP.

    Each sub-layer sees a layer-normalised copy of its input and its output is
    added back onto the residual stream.
    """

    def __init__(self):
        super().__init__()
        self.sa = MultiHeadAttention(
            Config.n_heads, Config.n_embed // Config.n_heads
        )
        self.ffwd = FeedForward(Config.n_embed)
        self.ln1 = nn.LayerNorm(Config.n_embed)
        self.ln2 = nn.LayerNorm(Config.n_embed)

    def forward(self, x):
        """Return `x` after the attention and MLP residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed
    
    
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, followed by a linear projection.

    Args:
        num_heads: Number of `Head` modules to create.
        head_size: Output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(Head(head_size))
        self.heads = nn.ModuleList(heads)
        # The projection is n_embed -> n_embed, so the concatenated head
        # outputs must be n_embed wide (num_heads * head_size == n_embed).
        self.proj = nn.Linear(Config.n_embed, Config.n_embed)

    def forward(self, x):
        """Concatenate every head's output on the last axis and project it."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.proj(merged)
        
        
class RecurrentGPT(nn.Module):
    """GPT-style language model with one weight-shared transformer block.

    Instead of stacking distinct layers, the same attention + feed-forward
    pair (and the same two layer norms) is applied `recur_depth` times, so the
    depth can be chosen per forward call.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(Config.vocab_size, Config.n_embed)
        self.position_embedding_table = nn.Embedding(Config.block_size, Config.n_embed)

        # Shared block: these modules are reused on every recurrence step.
        self.shared_sa = MultiHeadAttention(Config.n_heads, Config.n_embed // Config.n_heads)
        self.shared_ffwd = FeedForward(Config.n_embed)
        self.ln1 = nn.LayerNorm(Config.n_embed)
        self.ln2 = nn.LayerNorm(Config.n_embed)

        self.ln_f = nn.LayerNorm(Config.n_embed)  # final layer norm
        self.lm_head = nn.Linear(Config.n_embed, Config.vocab_size)

    def forward(self, idx, targets=None, recur_depth=1):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: Optional (B, T) integer tensor of next-token ids.
            recur_depth: Number of times the shared block is applied.

        Returns:
            A `(logits, loss)` tuple. Without `targets`, `logits` has shape
            (B, T, vocab_size) and `loss` is None. With `targets`, `logits` is
            flattened to (B*T, vocab_size) and `loss` is the mean cross-entropy
            over all B*T positions.
        """
        B, T = idx.shape

        # 1. Embeddings. Position ids are created on the input's own device so
        # the model works wherever it and its inputs live, rather than being
        # tied to the global Config.device.
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C), broadcast over the batch

        # 2. Recurrent depth loop: rather than iterating over a stack of
        # distinct layers, re-apply the one shared block.
        for _ in range(recur_depth):
            x = x + self.shared_sa(self.ln1(x))
            x = x + self.shared_ffwd(self.ln2(x))

        # 3. Final layer norm and vocabulary head.
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        loss = None
        if targets is not None:
            # cross_entropy expects (N, classes) logits and (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    



In [None]:


# GENERATOR (2 Digits ONLY)
def generate_math_batch(batch_size=32):
    """Build a batch of random two-digit addition problems for next-token training.

    Each sample is the string "a+b=sum" (a and b drawn uniformly from 10..99),
    right-padded with spaces to `Config.block_size` characters and encoded
    with `stoi`.

    Args:
        batch_size: Number of problems to generate.

    Returns:
        A `(inputs, targets)` pair of long tensors on `Config.device`, each of
        shape (batch_size, block_size - 1). `targets` is the same sequence as
        `inputs` shifted one position to the left.
    """
    samples = []
    for _ in range(batch_size):
        # Two-digit operands, e.g. 50+50=100
        a = random.randint(10, 99)
        b = random.randint(10, 99)

        padded = f"{a}+{b}={a+b}".ljust(Config.block_size, ' ')
        token_ids = [stoi[ch] for ch in padded]
        samples.append(torch.tensor(token_ids, dtype=torch.long))

    # Drop the last token for the inputs and the first token for the targets.
    x_batch = torch.stack([sample[:-1] for sample in samples])
    y_batch = torch.stack([sample[1:] for sample in samples])
    return x_batch.to(Config.device), y_batch.to(Config.device)

# TRAINING LOOP
# Build the weight-shared model on the configured device and optimise all of
# its parameters with AdamW.
model = RecurrentGPT().to(Config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) # Standard LR

print("Starting Training (2-3 Digit Addition)...")
print("We need Loss < 0.1 for this to work.\n")

for step in range(3001):
    # A fresh random batch is generated every step; there is no fixed dataset.
    xb, yb = generate_math_batch(64)
    # Sample how many times the shared block is applied this step (1 to 8),
    # so the same weights are trained across that whole range of depths.
    train_depth = random.randint(1, 8) 
    
    logits, loss = model(xb, targets=yb, recur_depth=train_depth)
    
    # Clear old gradients, backpropagate, then take one optimizer step.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # Log the current batch loss (and the depth it was computed at) every 500 steps.
    if step % 500 == 0:
        print(f"Step {step}: Loss {loss.item():.4f} (Depth {train_depth})")


# EVALUATION 
# Spot check: run one random problem through the model at depth 8 and print
# the greedy (argmax) prediction at every position next to the input.
print("\nCHECKING PREDICTIONS ")
xb, yb = generate_math_batch(1)
with torch.no_grad():
    logits, _ = model(xb, recur_depth=8)
    preds = torch.argmax(logits, dim=2)

# Decode token ids back to characters for display.
input_str = "".join([itos[i.item()] for i in xb[0]])
pred_str = "".join([itos[i.item()] for i in preds[0]])
# Note: the prediction at position t is the model's guess for the token that
# FOLLOWS input position t, so the Pred line reads one character to the left
# of the Input line. No alignment is applied when printing.
print(f"Input: {input_str}")
print(f"Pred:  {pred_str}")


print('\nLets see accuracy at different depths:')
# EVALUATION (SMART VERSION)
def evaluate_smart(depth, total=200):
    """Return the fraction of fresh problems whose answer is predicted exactly.

    The model is run once (teacher-forced) on `total` newly generated
    problems. Only the positions after the '=' in the target are scored, so
    errors on the unpredictable random operands in the prompt are ignored.

    The answer region is located from the TARGET's '=' position and compared
    position by position. Splitting the prediction on its own first '=' would
    misalign the comparison whenever the model emits a stray '=' inside the
    prompt region, and would silently drop junk emitted after a second '='.
    The trailing padding is part of the region, so it must be predicted too.

    Args:
        depth: Number of recurrent passes (`recur_depth`) used for inference.
        total: Number of problems to evaluate on.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: If `total` is less than 1.
    """
    if total < 1:
        raise ValueError("total must be a positive integer")

    xb, yb = generate_math_batch(total)

    with torch.no_grad():
        logits, _ = model(xb, recur_depth=depth)

    preds = torch.argmax(logits, dim=2)

    # Convert each tensor to nested Python lists once, instead of calling
    # .item() on every single token.
    target_rows = yb.tolist()
    pred_rows = preds.tolist()
    eq_id = stoi['=']

    correct = 0
    for target_ids, pred_ids in zip(target_rows, pred_rows):
        if eq_id not in target_ids:
            # No answer region to score; counted as incorrect.
            continue
        answer_start = target_ids.index(eq_id) + 1
        if target_ids[answer_start:] == pred_ids[answer_start:]:
            correct += 1

    return correct / total

print("Calculating Accuracy on Answer Only (Ignoring Prompt Errors)...")
# Single source of truth for the evaluated depths; the bar labels below are
# derived from it so the chart cannot drift out of sync with the sweep.
eval_depths = [1, 2, 4, 8]
usable_accuracies = []
for d in eval_depths:
    acc = evaluate_smart(d)
    usable_accuracies.append(acc*100)
    print(f"Depth {d:2d}: Accuracy = {acc*100:.1f}%")

# PLOTTING RESULTS
depths = [f'Depth {depth_value}' for depth_value in eval_depths]
accuracy = usable_accuracies
# Fixed palette assigned by bar position (it does not depend on the results).
colors = ['#ff9999', '#66b3ff', '#99ff99', '#66b3ff']

plt.figure(figsize=(8, 5))
bars = plt.bar(depths, accuracy, color=colors, edgecolor='black')

plt.title('Impact of Recurrent Depth on Arithmetic Accuracy', fontsize=14, fontweight='bold')
plt.ylabel('Accuracy (%)', fontsize=12)
plt.ylim(0, 110)  # headroom above 100% for the value labels

for bar in bars:
    height = bar.get_height()
    # Round for display: the raw float can print as e.g. 56.99999999999999.
    plt.text(bar.get_x() + bar.get_width()/2.0, height + 2, f'{height:.1f}%',
             ha='center', va='bottom', fontsize=12, fontweight='bold')

# Save before show(): the figure may be cleared once the window is closed.
plt.savefig('demo_results.png', dpi=300)
print("Chart saved as demo_results.png!")
plt.show()