In [1]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import math
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cpu


## Tokenizer and Data Loader


In [2]:
from tokenizer import ARCTokenizer

In [3]:
# Initialize tokenizer
# `tokenizer` is reused by every later cell (converter, datasets, model vocab size).
tokenizer = ARCTokenizer()
print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"Special tokens: {tokenizer.token_to_id}")

# Test tokenizer
# Round-trip sanity check: grid -> flat token list -> grid.
test_grid = [[1, 2, 3], [4, 5, 6]]
tokens = tokenizer.grid_to_tokens(test_grid)
print(f"\nTest grid: {test_grid}")
print(f"Tokens: {tokens}")
# (2, 3) matches test_grid's 2 rows x 3 columns; presumably the (rows, cols)
# shape needed to rebuild the grid -- confirm against ARCTokenizer.
print(f"Back to grid: {tokenizer.tokens_to_grid(tokens, (2, 3))}")

Vocabulary size: 18
Special tokens: {'PAD': 10, 'SOS': 11, 'EOS': 12, 'TRAIN': 13, 'TEST': 14, 'INPUT': 15, 'OUTPUT': 16, 'NEWLINE': 17}

Test grid: [[1, 2, 3], [4, 5, 6]]
Tokens: [1, 2, 3, 17, 4, 5, 6]
Back to grid: [[1, 2, 3], [4, 5, 6]]


## Token converter (enrich with position info)

In [4]:
from token_converter import TokenTo3DConverter
# Initialize converter
token_converter = TokenTo3DConverter(tokenizer)


## Data Loader with Augmentation


In [5]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import math
from tqdm import tqdm
import warnings
from dataset import ARCDataset, ARCTorchDataset 
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cpu


In [6]:
# Load datasets
print("Loading datasets...")
# Training split: challenges plus their solutions.
train_dataset = ARCDataset(
    challenges_path='arc-agi_training_challenges.json',
    solutions_path='arc-agi_training_solutions.json'
)

# Test split: only a challenges file is passed here (no solutions_path).
test_dataset = ARCDataset(
    challenges_path='arc-agi_test_challenges.json'
)

print(f"Training challenges: {len(train_dataset.get_all_challenges())}")
print(f"Test challenges: {len(test_dataset.get_all_challenges())}")

# Create PyTorch datasets
# Both wrappers share the same tokenizer and token_converter so the
# [value, x, y] encoding is identical for train and test.
train_torch_dataset = ARCTorchDataset(train_dataset, tokenizer, token_converter=token_converter)
test_torch_dataset = ARCTorchDataset(test_dataset, tokenizer, token_converter=token_converter)

print(f"\nTraining samples (with augmentation): {len(train_torch_dataset)}")
print(f"Test samples: {len(test_torch_dataset)}")

# Test data loading
# Each item is a dict; the keys read below are the ones this notebook relies on.
sample = train_torch_dataset[0]
print(f"\nSample data:")
print(f"Sample ID: {sample['sample_id']}")
print(f"Challenge ID: {sample['challenge_id']}")
print(f"Input sequence length: {len(sample['input'])}")
print(f"Target sequence length: {len(sample['target'])}")
print(f"Input dims: {sample['input_dims']}")
print(f"Output dims: {sample['output_dims']}")
print(f"Test input dims: {sample['test_input_dims']}")
print(f"Test output dims: {sample['test_output_dims']}")

Loading datasets...
Training challenges: 400
Test challenges: 100

Training samples (with augmentation): 614
Test samples: 152

Sample data:
Sample ID: 007bbfb7_orig
Challenge ID: 007bbfb7
Input sequence length: 5400
Target sequence length: 1000
Input dims: [(3, 3), (3, 3)]
Output dims: [(9, 9), (9, 9)]
Test input dims: (3, 3)
Test output dims: (9, 9)


In [7]:
# Special-token legend (matches the tokenizer.token_to_id mapping printed earlier):
#   10 = PAD      padding
#   11 = SOS      start of sequence
#   12 = EOS      end of sequence
#   13 = TRAIN    start of a training example
#   14 = TEST     start of the test example
#   15 = INPUT    start of an input grid
#   16 = OUTPUT   start of an output grid
#   17 = NEWLINE  grid row separator (the "], [" between rows)
# Rows are [value, x, y]; the rows shown for special tokens carry -1 in both coordinate slots.
train_torch_dataset[0]['input']  # append [0:300] to preview only the first 300 rows

tensor([[11, -1, -1],
        [13, -1, -1],
        [15, -1, -1],
        ...,
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1]])

In [8]:
train_torch_dataset[0]['target'][0:300]

tensor([[11, -1, -1],
        [16, -1, -1],
        [ 7,  0,  0],
        [ 0,  1,  0],
        [ 7,  2,  0],
        [ 0,  3,  0],
        [ 0,  4,  0],
        [ 0,  5,  0],
        [ 7,  6,  0],
        [ 0,  7,  0],
        [ 7,  8,  0],
        [17, -1, -1],
        [ 7,  0,  1],
        [ 0,  1,  1],
        [ 7,  2,  1],
        [ 0,  3,  1],
        [ 0,  4,  1],
        [ 0,  5,  1],
        [ 7,  6,  1],
        [ 0,  7,  1],
        [ 7,  8,  1],
        [17, -1, -1],
        [ 7,  0,  2],
        [ 7,  1,  2],
        [ 0,  2,  2],
        [ 0,  3,  2],
        [ 0,  4,  2],
        [ 0,  5,  2],
        [ 7,  6,  2],
        [ 7,  7,  2],
        [ 0,  8,  2],
        [17, -1, -1],
        [ 7,  0,  3],
        [ 0,  1,  3],
        [ 7,  2,  3],
        [ 0,  3,  3],
        [ 0,  4,  3],
        [ 0,  5,  3],
        [ 7,  6,  3],
        [ 0,  7,  3],
        [ 7,  8,  3],
        [17, -1, -1],
        [ 7,  0,  4],
        [ 0,  1,  4],
        [ 7,  2,  4],
        [ 

In [9]:
len(train_torch_dataset), len(test_torch_dataset)

(614, 152)

## Autoregressive Dataset for LLM Training

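For autoregressive training, we need to:
1. Concatenate input + target into one sequence
2. Create next-token-prediction labels: each prefix of that sequence becomes its own "exploded" sample whose label is the token that follows the prefix (the label is the sequence shifted by 1 position)
3. Hide future tokens from the model: in each exploded sample everything after the prefix is replaced by PAD and excluded through the attention mask, which plays the role of causal masking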

**Important**: This is NOT data leakage! During training, the model learns to predict the next token given previous tokens. During inference, we'll use the same autoregressive generation process.


In [10]:
from exploded_dataset import ARCExplodedDataset

In [12]:
# Create exploded datasets from existing ARCTorchDataset
# sequence_length=5400 equals the "Input sequence length: 5400" reported for the
# base samples and the max_seq_length=5400 given to the model below; keep the
# three in sync. (Presumably the padded length of every exploded input -- confirm
# against ARCExplodedDataset.)
print("Creating exploded training dataset...")
train_exploded_dataset = ARCExplodedDataset(train_torch_dataset, tokenizer, sequence_length=5400)

Creating exploded training dataset...
Exploding 614 base samples...


100%|██████████| 614/614 [00:04<00:00, 150.94it/s]

Created 88372 exploded samples from 614 base samples





In [13]:
# Peek at rows 218-227 of the first ten exploded samples next to their targets.
# In the output, each successive sample reveals one more token of the target
# sequence, and 'target_vector' is the [value, x, y] row that comes next.
# NOTE(review): the training cells read 'target_value' (not 'target_vector');
# presumably that is the scalar token id of the same target -- confirm.
for i in range(10):
    print(train_exploded_dataset[i]['input_3d'][218:228], train_exploded_dataset[i]['target_vector'])

tensor([[ 7,  1,  2],
        [ 0,  2,  2],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1]]) tensor([11, -1, -1])
tensor([[ 7,  1,  2],
        [ 0,  2,  2],
        [11, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1]]) tensor([16, -1, -1])
tensor([[ 7,  1,  2],
        [ 0,  2,  2],
        [11, -1, -1],
        [16, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1]]) tensor([7, 0, 0])
tensor([[ 7,  1,  2],
        [ 0,  2,  2],
        [11, -1, -1],
        [16, -1, -1],
        [ 7,  0,  0],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1],
        [10, -1, -1]]) tensor([0, 1, 0])
tensor([[ 7,  1,  2],
        [ 0,  2,

# Modeling 

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
class NextTokenPredictor(nn.Module):
    """
    Transformer model to predict next token from 3D vectors [value, x, y].

    Each position is embedded as the sum of
      * a learned embedding of the token value,
      * a linear projection of the (x, y) coordinates, and
      * a learned positional encoding for the sequence index,
    then passed through a stack of ``nn.TransformerEncoderLayer`` blocks and
    projected to vocabulary logits.

    No causal mask is applied inside the model; the only masking is the
    key-padding mask derived from ``attention_mask``.

    Args:
        vocab_size: Number of distinct token ids (size of the embedding table
            and of the output logits).
        d_model: Width of all embeddings and transformer layers.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        max_seq_length: Longest sequence the learned positional encoding
            covers; longer inputs are rejected in ``forward``.
        dropout: Dropout probability used on the embeddings and inside the
            encoder layers.
    """

    def __init__(self, vocab_size=18, d_model=256, nhead=8, num_layers=4,
                 dim_feedforward=1024, max_seq_length=5400, dropout=0.1):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        # Embedding for token values (ids 0 .. vocab_size - 1)
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Projection for x, y coordinates (add coordinate information)
        self.coord_projection = nn.Linear(2, d_model)  # [x, y] -> d_model

        # Positional encoding (learned), one row per sequence index
        self.pos_encoding = nn.Parameter(torch.randn(max_seq_length, d_model) * 0.02)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output projection to vocab
        self.output_proj = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input_3d, attention_mask=None):
        """
        Args:
            input_3d: [batch_size, seq_len, 3] - [value, x, y] vectors
            attention_mask: [batch_size, seq_len] - 1 for real tokens, 0 for padding
                (optional; when omitted every position is attended to)

        Returns:
            logits: [batch_size, seq_len, vocab_size] - logits for each position

        Raises:
            ValueError: if ``input_3d`` is not 3-dimensional, or if its
                sequence length exceeds ``max_seq_length``.
        """
        if input_3d.dim() != 3:
            raise ValueError(
                f"input_3d must have shape [batch_size, seq_len, 3], got {tuple(input_3d.shape)}"
            )
        seq_len = input_3d.size(1)
        # The learned positional table has only max_seq_length rows; fail with a
        # clear message instead of an opaque broadcasting error further down.
        if seq_len > self.max_seq_length:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_length {self.max_seq_length}"
            )

        # Extract components
        token_values = input_3d[:, :, 0].long()  # [batch_size, seq_len] - token values
        coordinates = input_3d[:, :, 1:3].float()  # [batch_size, seq_len, 2] - x, y

        # Embed tokens
        token_emb = self.token_embedding(token_values)  # [batch_size, seq_len, d_model]

        # Add coordinate information
        coord_emb = self.coord_projection(coordinates)  # [batch_size, seq_len, d_model]

        # Combine token and coordinate embeddings
        x = token_emb + coord_emb  # [batch_size, seq_len, d_model]

        # Add positional encoding (broadcast over the batch dimension)
        x = x + self.pos_encoding[:seq_len].unsqueeze(0)  # [batch_size, seq_len, d_model]

        x = self.dropout(x)

        # Convert the attention mask to the format expected by the transformer.
        # src_key_padding_mask uses the OPPOSITE convention to ours:
        #   PyTorch: True  = ignore this key (padding), False = attend
        #   ours:    1     = real token,                0     = padding
        # so the padding mask is True exactly where attention_mask is 0.
        if attention_mask is not None:
            padding_mask = (attention_mask == 0)  # True for padding, False for real tokens
        else:
            padding_mask = None

        # Apply transformer
        x = self.transformer(x, src_key_padding_mask=padding_mask)  # [batch_size, seq_len, d_model]

        # Get logits for all positions
        logits = self.output_proj(x)  # [batch_size, seq_len, vocab_size]

        return logits


In [None]:
# Build the network, then move it to the selected device.
# max_seq_length=5400 mirrors the sequence_length used for the exploded dataset.
model = NextTokenPredictor(
    vocab_size=tokenizer.vocab_size,
    d_model=256, nhead=8, num_layers=4,
    dim_feedforward=1024,
    max_seq_length=5400,
    dropout=0.1,
)
model = model.to(device)

print(f"Model created with {sum(map(torch.Tensor.numel, model.parameters())):,} parameters")
print(f"Model device: {next(iter(model.parameters())).device}")

# Objective and optimiser
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=0.01)

# Cosine learning-rate decay from lr down to eta_min over T_max scheduler steps.
# NOTE(review): the training cell steps this once per epoch for 3 epochs, so with
# T_max=100 the rate barely moves -- confirm that is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100, eta_min=1e-6)


In [None]:
# Create DataLoader
def collate_fn(batch):
    """Merge a list of exploded samples into one batch dictionary.

    Each element of ``batch`` is a dict providing ``'input_3d'`` and
    ``'attention_mask'`` tensors plus a ``'target_value'`` token id. The
    per-sample entries are stacked along a new leading batch dimension.

    Returns:
        dict with keys ``'input_3d'``, ``'target_values'`` (dtype long) and
        ``'attention_mask'``.
    """
    def column(key):
        # Gather one field from every sample, keeping batch order.
        return [example[key] for example in batch]

    stacked_inputs = torch.stack(column('input_3d'))

    long_targets = [torch.tensor(value, dtype=torch.long) for value in column('target_value')]
    stacked_targets = torch.stack(long_targets)

    stacked_masks = torch.stack(column('attention_mask'))

    collated = {}
    collated['input_3d'] = stacked_inputs
    collated['target_values'] = stacked_targets
    collated['attention_mask'] = stacked_masks
    return collated

# Batches of exploded samples for training; collate_fn stacks the per-sample
# tensors and renames 'target_value' to the batched 'target_values'.
batch_size = 8
train_loader = DataLoader(
    train_exploded_dataset,
    batch_size=batch_size,
    shuffle=True,  # reshuffle the exploded samples every epoch
    collate_fn=collate_fn,
    num_workers=0  # Set to 0 for notebooks
)

# len(train_loader) is the number of batches per epoch, not the number of samples.
print(f"Created DataLoader with {len(train_loader)} batches (batch_size={batch_size})")


In [None]:
# Training loop
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch.

    For every batch the model is run on the whole padded sequence, but only
    the logits at the last non-padding position of each sequence are scored
    against the target token.

    Args:
        model: Module called as ``model(input_3d=..., attention_mask=...)``
            and returning logits of shape [batch_size, seq_len, vocab_size].
        dataloader: Iterable of dicts with keys ``'input_3d'``,
            ``'target_values'`` and ``'attention_mask'``.
        criterion: Loss taking ``(logits [batch_size, vocab_size], targets [batch_size])``.
        optimizer: Optimizer over ``model.parameters()``.
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses and the percentage
        of correctly predicted targets. ``(0.0, 0.0)`` if the dataloader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(dataloader, desc="Training")
    for batch in pbar:
        # Move to device
        input_3d = batch['input_3d'].to(device)  # [batch_size, seq_len, 3]
        target_values = batch['target_values'].to(device)  # [batch_size]
        attention_mask = batch['attention_mask'].to(device)  # [batch_size, seq_len]

        # Forward pass
        optimizer.zero_grad()
        logits = model(input_3d=input_3d, attention_mask=attention_mask)  # [batch_size, seq_len, vocab_size]

        # Index of the last non-padding position of each sequence (number of
        # real tokens minus one). Cast to long so the result is a valid index
        # tensor even when the mask is stored as float or bool.
        last_positions = attention_mask.long().sum(dim=1) - 1
        # Build the row index on the same device as the logits.
        rows = torch.arange(logits.size(0), device=logits.device)
        last_logits = logits[rows, last_positions]  # [batch_size, vocab_size]

        # Compute loss
        loss = criterion(last_logits, target_values)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

        # Metrics
        total_loss += loss.item()
        num_batches += 1
        predictions = last_logits.argmax(dim=1)
        correct += (predictions == target_values).sum().item()
        total += target_values.size(0)

        # Update progress bar (guard against a zero-sized batch)
        running_acc = 100 * correct / total if total else 0.0
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{running_acc:.2f}%'
        })

    # An empty dataloader would otherwise raise ZeroDivisionError here.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    avg_loss = total_loss / num_batches
    accuracy = 100 * correct / total
    return avg_loss, accuracy

# Train for a few epochs
num_epochs = 3
print(f"\nStarting training for {num_epochs} epochs...")
print("=" * 60)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    avg_loss, accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
    # The scheduler is stepped once per epoch (not per batch).
    # NOTE(review): the scheduler was created with T_max=100, so 3 epochs cover
    # only a small part of the cosine schedule -- confirm that is intended.
    scheduler.step()
    
    print(f"Epoch {epoch + 1} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    # Printed after scheduler.step(), so this is the rate for the NEXT epoch.
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}")

print("\nTraining completed!")


In [None]:
# Evaluation function
def evaluate(model, dataloader, criterion, device):
    """Evaluate model on dataset.

    Runs the model in eval mode without gradient tracking and scores, for
    each sequence, the logits at its last non-padding position against the
    target token.

    Args:
        model: Module called as ``model(input_3d=..., attention_mask=...)``
            and returning logits of shape [batch_size, seq_len, vocab_size].
        dataloader: Iterable of dicts with keys ``'input_3d'``,
            ``'target_values'`` and ``'attention_mask'``.
        criterion: Loss taking ``(logits [batch_size, vocab_size], targets [batch_size])``.
        device: Device the batch tensors are moved to.

    Returns:
        (avg_loss, accuracy): mean of the per-batch losses and the percentage
        of correctly predicted targets. ``(0.0, 0.0)`` if the dataloader
        yields no batches.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_3d = batch['input_3d'].to(device)
            target_values = batch['target_values'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = model(input_3d=input_3d, attention_mask=attention_mask)

            # Index of the last non-padding position of each sequence; cast to
            # long so it is a valid index tensor whatever dtype the mask has.
            last_positions = attention_mask.long().sum(dim=1) - 1
            # Build the row index on the same device as the logits.
            rows = torch.arange(logits.size(0), device=logits.device)
            last_logits = logits[rows, last_positions]

            loss = criterion(last_logits, target_values)
            total_loss += loss.item()
            num_batches += 1

            predictions = last_logits.argmax(dim=1)
            correct += (predictions == target_values).sum().item()
            total += target_values.size(0)

    # An empty dataloader would otherwise raise ZeroDivisionError here.
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    avg_loss = total_loss / num_batches
    accuracy = 100 * correct / total
    return avg_loss, accuracy

# Test on a few samples
# Quick smoke test: run the model on the first exploded training sample and
# compare the predicted next token with the stored target.
print("\n" + "=" * 60)
print("Testing model on a few samples...")
model.eval()

with torch.no_grad():
    sample = train_exploded_dataset[0]
    input_3d = sample['input_3d'].unsqueeze(0).to(device)  # [1, seq_len, 3]
    attention_mask = sample['attention_mask'].unsqueeze(0).to(device)  # [1, seq_len]
    # Normalise to a plain int so the dict lookup and the == comparison below
    # behave the same whether the dataset stores an int or a 0-d tensor.
    target_value = int(sample['target_value'])

    logits = model(input_3d=input_3d, attention_mask=attention_mask)
    # Index of the last non-padding position; int() keeps it a valid index
    # even if the mask is stored as a float tensor.
    seq_len = int(attention_mask.sum().item()) - 1
    prediction = logits[0, seq_len].argmax().item()

    # Derive id -> name from the tokenizer (token_to_id maps name -> id)
    # instead of hard-coding a second copy of the special-token table.
    token_names = {token_id: name for name, token_id in tokenizer.token_to_id.items()}

    print(f"Sample prediction:")
    print(f"  Target token: {target_value} ({token_names.get(target_value, target_value)})")
    print(f"  Predicted token: {prediction} ({token_names.get(prediction, prediction)})")
    print(f"  Correct: {prediction == target_value}")
