In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import math
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


## Tokenizer and Data Loader


In [None]:
from tokenizer import ARCTokenizer

In [None]:
# Initialize tokenizer
tokenizer = ARCTokenizer()
print(f"Vocabulary size: {tokenizer.vocab_size}")
# NOTE(review): labelled "Special tokens" but prints tokenizer.token_to_id;
# verify that mapping holds only the special tokens and not the whole vocabulary.
print(f"Special tokens: {tokenizer.token_to_id}")

# Test tokenizer: round-trip a small 2x3 grid through grid_to_tokens / tokens_to_grid.
test_grid = [[1, 2, 3], [4, 5, 6]]
tokens = tokenizer.grid_to_tokens(test_grid)
print(f"\nTest grid: {test_grid}")
print(f"Tokens: {tokens}")
# (2, 3) matches test_grid's 2 rows x 3 columns; presumably (height, width) — confirm in tokenizer.py.
print(f"Back to grid: {tokenizer.tokens_to_grid(tokens, (2, 3))}")

## Token converter (enrich with position info)

In [None]:
from token_converter import TokenTo3DConverter
# Initialize converter. It is passed to ARCTorchDataset below as token_converter;
# presumably it turns token ids into the [value, x, y] vectors the model consumes
# — confirm in token_converter.py.
token_converter = TokenTo3DConverter(tokenizer)


## Data Loader with Augmentation


In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import math
from tqdm import tqdm
import warnings
from dataset import ARCDataset, ARCTorchDataset 
warnings.filterwarnings('ignore')

# Re-select the compute device (same rule as the first cell: CUDA if visible,
# else CPU) so this cell also works when run on its own.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:
# Load datasets
# Paths are relative to the notebook's working directory.
print("Loading datasets...")
train_dataset = ARCDataset(
    challenges_path='arc-agi_training_challenges.json',
    solutions_path='arc-agi_training_solutions.json'
)

# The test split is loaded without a solutions_path, so it presumably has no
# ground-truth outputs — confirm how ARCDataset handles that.
test_dataset = ARCDataset(
    challenges_path='arc-agi_test_challenges.json'
)

print(f"Training challenges: {len(train_dataset.get_all_challenges())}")
print(f"Test challenges: {len(test_dataset.get_all_challenges())}")

# Create PyTorch datasets
train_torch_dataset = ARCTorchDataset(train_dataset, tokenizer, token_converter=token_converter)
test_torch_dataset = ARCTorchDataset(test_dataset, tokenizer, token_converter=token_converter)

# NOTE(review): the "(with augmentation)" label assumes ARCTorchDataset augments
# by default; no augmentation flag is passed here — confirm in dataset.py.
print(f"\nTraining samples (with augmentation): {len(train_torch_dataset)}")
print(f"Test samples: {len(test_torch_dataset)}")

# Test data loading: inspect the first training sample and the keys it exposes.
sample = train_torch_dataset[0]
print(f"\nSample data:")
print(f"Sample ID: {sample['sample_id']}")
print(f"Challenge ID: {sample['challenge_id']}")
print(f"Input sequence length: {len(sample['input'])}")
print(f"Target sequence length: {len(sample['target'])}")
print(f"Input dims: {sample['input_dims']}")
print(f"Output dims: {sample['output_dims']}")
print(f"Test input dims: {sample['test_input_dims']}")
print(f"Test output dims: {sample['test_output_dims']}")

In [None]:
# Reference for reading the raw sequence below: special-token ids, presumably
# mirroring ARCTokenizer's attributes (verify against tokenizer.py).
# self.PAD_TOKEN = 10
# self.SOS_TOKEN = 11  # Start of sequence
# self.EOS_TOKEN = 12  # End of sequence
# self.TRAIN_TOKEN = 13  # Start of training example
# self.TEST_TOKEN = 14  # Start of test example
# self.INPUT_TOKEN = 15  # Start of input grid
# self.OUTPUT_TOKEN = 16  # Start of output grid
# self.NEWLINE_TOKEN = 17  # Grid separator (], [)
# Display the first sample's full input sequence; re-enable [0:300] to shorten the output.
train_torch_dataset[0]['input']#[0:300]

In [None]:
train_torch_dataset[0]['target'][0:300]  # first 300 entries of the first sample's target sequence

In [None]:
len(train_torch_dataset), len(test_torch_dataset)  # (train sample count, test sample count)

## Autoregressive Dataset for LLM Training

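For autoregressive training, we need to:
1. Concatenate input + target into one sequence
2. Create labels shifted by 1 position (next token prediction)
3. Use causal masking so the model can't see future tokens

Note: the exploded dataset used below implements this differently. Each exploded sample pairs an input sequence with a single next-token target (`target_value`). The loss is computed only at the last non-padding position, so the model applies a padding mask but no causal mask.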

**Important**: This is NOT data leakage! During training, the model learns to predict the next token given previous tokens. During inference, we'll use the same autoregressive generation process.


In [None]:
from exploded_dataset import ARCExplodedDataset

In [None]:
# Create exploded datasets from existing ARCTorchDataset
# sequence_length=5400 is presumably the padded length of each exploded sample;
# it must not exceed the model's max_seq_length (also 5400 in the model cell),
# because NextTokenPredictor slices its positional encoding to the input length.
print("Creating exploded training dataset...")
train_exploded_dataset = ARCExplodedDataset(train_torch_dataset, tokenizer, sequence_length=5400)

In [None]:
# Spot-check: for the first 10 exploded samples, print positions 218-227 of the
# 3D input next to the sample's target.
# NOTE(review): this reads 'target_vector', while collate_fn reads 'target_value';
# confirm ARCExplodedDataset provides both keys.
for i in range(10):
    print(train_exploded_dataset[i]['input_3d'][218:228], train_exploded_dataset[i]['target_vector'])

## Split into train / validation sets (by challenge)

In [None]:
# Collate function for the train/validation DataLoaders built in the split cell.
def collate_fn(batch):
    """
    Stack a list of exploded samples into batch tensors.

    Args:
        batch: List of sample dicts with keys 'input_3d' and 'attention_mask'
            (tensors of identical shape across samples) and 'target_value'
            (presumably a scalar int or 0-d tensor — the training loop treats
            the stacked result as shape [batch_size]).

    Returns:
        Dict with 'input_3d' and 'attention_mask' stacked along a new leading
        batch dimension, and 'target_values' as a torch.long tensor.
    """
    input_3d = torch.stack([item['input_3d'] for item in batch])
    # torch.as_tensor reuses an existing tensor instead of copy-constructing it
    # (torch.tensor(tensor) makes a copy and emits a UserWarning).
    target_values = torch.stack(
        [torch.as_tensor(item['target_value'], dtype=torch.long) for item in batch]
    )
    attention_mask = torch.stack([item['attention_mask'] for item in batch])

    return {
        'input_3d': input_3d,
        'target_values': target_values,
        'attention_mask': attention_mask,
    }

In [None]:
# Split dataset by challenge_id (ensures no data leakage): every sample of a given
# challenge lands in exactly one of train / validation.
import random
from collections import defaultdict
from torch.utils.data import Subset

# Group sample indices by challenge_id.
# NOTE(review): this builds every exploded sample just to read its challenge_id;
# if ARCExplodedDataset can expose ids without materializing tensors, use that.
# The loop variable is named exploded_sample so it does not overwrite the
# `sample` inspected in the data-loading cell.
challenge_to_indices = defaultdict(list)
for idx in range(len(train_exploded_dataset)):
    exploded_sample = train_exploded_dataset[idx]
    challenge_to_indices[exploded_sample['challenge_id']].append(idx)

# Get unique challenge IDs (insertion order = first appearance in the dataset).
challenge_ids = list(challenge_to_indices.keys())
print(f"Total challenges: {len(challenge_ids)}")

# Shuffle and split challenges (not individual samples). Seeding the global
# `random` module keeps the split reproducible across runs.
random.seed(42)
random.shuffle(challenge_ids)

train_ratio = 0.8
split_idx = int(len(challenge_ids) * train_ratio)
train_challenge_ids = set(challenge_ids[:split_idx])
val_challenge_ids = set(challenge_ids[split_idx:])

print(f"Train challenges: {len(train_challenge_ids)}")
print(f"Val challenges: {len(val_challenge_ids)}")

# Both splits must be non-empty: an empty loader makes training/evaluation meaningless.
assert train_challenge_ids and val_challenge_ids, (
    f"Split produced an empty side ({len(train_challenge_ids)} train / "
    f"{len(val_challenge_ids)} val challenges); need at least 2 challenges."
)
assert train_challenge_ids.isdisjoint(val_challenge_ids)

# Collect indices for each split
train_indices = []
val_indices = []

for challenge_id, indices in challenge_to_indices.items():
    if challenge_id in train_challenge_ids:
        train_indices.extend(indices)
    else:
        val_indices.extend(indices)

# Every sample is assigned to exactly one split.
assert len(train_indices) + len(val_indices) == len(train_exploded_dataset)

print(f"\nTrain samples: {len(train_indices)}")
print(f"Val samples: {len(val_indices)}")

# Create subset datasets
train_dataset_split = Subset(train_exploded_dataset, train_indices)
val_dataset_split = Subset(train_exploded_dataset, val_indices)

# Create DataLoaders
batch_size = 8

train_loader = DataLoader(
    train_dataset_split,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset_split,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")


# Modeling 

In [None]:
import wandb
# Authenticate with Weights & Biases without embedding a secret in the notebook.
# Per the wandb docs, wandb.login() with no key uses the WANDB_API_KEY environment
# variable or stored netrc credentials, and prompts interactively if neither exists.
# NOTE(review): the API key that was hardcoded here has been exposed in the
# notebook and its history. Revoke and rotate it.
wandb.login()

In [None]:
# Re-select the compute device (same rule as earlier cells: CUDA if visible,
# else CPU) so the modeling section can be run on its own.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
class NextTokenPredictor(nn.Module):
    """
    Transformer encoder that produces token logits from 3D vectors [value, x, y].

    Each position is embedded as token_embedding(value) + Linear([x, y]) plus a
    fixed sinusoidal positional encoding. The result is passed through a stack of
    nn.TransformerEncoderLayer blocks with a key-padding mask only (no causal
    mask) and projected to vocab_size logits at every position.

    Args:
        vocab_size: Number of distinct token values (embedding rows / output classes).
        d_model: Model width.
        nhead: Attention heads per layer (must divide d_model).
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward block.
        max_seq_length: Longest sequence forward() accepts.
        dropout: Dropout probability used inside the encoder layers.
    """

    def __init__(self, vocab_size=18, d_model=16, nhead=8, num_layers=4,
                 dim_feedforward=1024, max_seq_length=5400, dropout=0.1):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        # One embedding row per token value (0 .. vocab_size-1).
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Projects the raw [x, y] coordinates into the model dimension.
        self.coord_projection = nn.Linear(2, d_model)

        # Fixed (non-learned) sinusoidal positional encoding, [max_seq_length, d_model].
        # It is registered as a buffer so model.to(device) moves it with the
        # parameters. A plain tensor attribute stays on the CPU and makes forward()
        # fail on CUDA. persistent=False keeps it out of state_dict(), the same
        # checkpoint layout as when it was a plain attribute.
        self.register_buffer(
            'pos_encoding',
            self.create_sinusoidal_positional_encoding(max_seq_length, d_model),
            persistent=False,
        )

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output projection to vocab
        self.output_proj = nn.Linear(d_model, vocab_size)

        # NOTE(review): defined but not applied in forward(); the encoder layers
        # apply their own dropout. Kept so existing attribute access keeps working.
        self.dropout = nn.Dropout(dropout)

    def create_sinusoidal_positional_encoding(self, max_len, d_model):
        """
        Create sinusoidal positional encoding (no learnable parameters).

        Even columns hold sin(pos * w_k) and odd columns hold cos(pos * w_k),
        with w_k = 10000^(-2k / d_model).

        Args:
            max_len: Maximum sequence length
            d_model: Model dimension (odd values are supported)

        Returns:
            [max_len, d_model] tensor with positional encodings
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices: sin
        # For odd d_model there is one fewer odd column than frequencies, so
        # trim div_term to match the number of cosine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # Odd indices: cos

        return pe  # [max_len, d_model]

    def forward(self, input_3d, attention_mask=None):
        """
        Args:
            input_3d: [batch_size, seq_len, 3] - [value, x, y] vectors
            attention_mask: [batch_size, seq_len] - 1 for real tokens, 0 for padding

        Returns:
            logits: [batch_size, seq_len, vocab_size] - logits for each position

        Raises:
            ValueError: if seq_len exceeds max_seq_length (no positional encoding
                exists for the extra positions).
        """
        _, seq_len, _ = input_3d.shape
        if seq_len > self.max_seq_length:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_length {self.max_seq_length}"
            )

        # Channel 0 is the token id; channels 1-2 are the (x, y) coordinates.
        token_values = input_3d[:, :, 0].long()
        coordinates = input_3d[:, :, 1:3].float()

        # Token embedding + coordinate embedding + positional encoding.
        x = self.token_embedding(token_values) + self.coord_projection(coordinates)
        x = x + self.pos_encoding[:seq_len].unsqueeze(0)

        # PyTorch's key-padding mask is True where a position must be ignored.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None

        x = self.transformer(x, src_key_padding_mask=padding_mask)  # [batch_size, seq_len, d_model]

        return self.output_proj(x)  # [batch_size, seq_len, vocab_size]


In [None]:
# Create model
# d_model=256 with nhead=8 gives 32 dims per head. max_seq_length must be >= the
# input length fed to the model, since forward() slices pos_encoding to seq_len.
model = NextTokenPredictor(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    nhead=8,
    num_layers=4,
    dim_feedforward=1024,
    max_seq_length=5400,
    dropout=0.1
).to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Model device: {next(model.parameters()).device}")

# Loss and optimizer
# CrossEntropyLoss on raw logits: train_epoch/evaluate score one position per sample.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# Learning rate scheduler
# NOTE(review): T_max=100, but the training cell runs num_epochs = 3 and calls
# scheduler.step() once per epoch, so the cosine schedule barely moves the LR
# off 1e-4. Align T_max with num_epochs if a full decay is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)


In [None]:
# Training loop
def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Train model for one pass over dataloader with a single-position objective.

    The model produces logits at every position. Only the logits at the last
    non-padding position of each sequence are scored against
    batch['target_values'].

    Args:
        model: Module called as model(input_3d=..., attention_mask=...) that
            returns [batch_size, seq_len, vocab_size] logits.
        dataloader: Iterable of batch dicts with keys 'input_3d',
            'target_values' and 'attention_mask'.
        criterion: Loss taking (logits [batch_size, vocab_size], targets [batch_size]).
        optimizer: Optimizer over model's parameters.
        device: Device every batch tensor is moved to.

    Returns:
        (avg_loss, accuracy): mean per-batch loss and accuracy in percent.
        Both are NaN when the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc="Training")
    for batch in pbar:
        # Move to device
        input_3d = batch['input_3d'].to(device)  # [batch_size, seq_len, 3]
        target_values = batch['target_values'].to(device)  # [batch_size]
        attention_mask = batch['attention_mask'].to(device)  # [batch_size, seq_len]

        # Forward pass
        optimizer.zero_grad()
        logits = model(input_3d=input_3d, attention_mask=attention_mask)  # [batch_size, seq_len, vocab_size]

        # Index of the last real token in each row; its logits predict the target.
        # Assumes right padding (real tokens first) — TODO confirm against the dataset.
        # .long() lets float/int masks serve as an index tensor. clamp(min=0)
        # stops an all-padding row from wrapping to index -1.
        last_idx = (attention_mask.sum(dim=1).long() - 1).clamp(min=0)
        row_idx = torch.arange(input_3d.size(0), device=logits.device)
        last_logits = logits[row_idx, last_idx]  # [batch_size, vocab_size]

        # Compute loss
        loss = criterion(last_logits, target_values)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

        # Metrics
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        predictions = last_logits.argmax(dim=1)
        correct += (predictions == target_values).sum().item()
        total += target_values.size(0)

        # Update progress bar
        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })

    # An empty loader previously raised ZeroDivisionError; report NaN instead.
    if num_batches == 0 or total == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    accuracy = 100 * correct / total
    return avg_loss, accuracy


In [None]:
# Evaluation function
def evaluate(model, dataloader, criterion, device):
    """
    Evaluate model on dataloader without gradient tracking.

    Uses the same scoring as train_epoch: only the logits at the last
    non-padding position of each sequence are compared to
    batch['target_values'].

    Args:
        model: Module called as model(input_3d=..., attention_mask=...) that
            returns [batch_size, seq_len, vocab_size] logits.
        dataloader: Iterable of batch dicts with keys 'input_3d',
            'target_values' and 'attention_mask'.
        criterion: Loss taking (logits [batch_size, vocab_size], targets [batch_size]).
        device: Device every batch tensor is moved to.

    Returns:
        (avg_loss, accuracy): mean per-batch loss and accuracy in percent.
        Both are NaN when the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_3d = batch['input_3d'].to(device)
            target_values = batch['target_values'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = model(input_3d=input_3d, attention_mask=attention_mask)

            # Last real-token index per row. Assumes right padding — TODO confirm.
            # .long() allows float masks; clamp avoids wrapping to -1 on all-padding rows.
            last_idx = (attention_mask.sum(dim=1).long() - 1).clamp(min=0)
            row_idx = torch.arange(input_3d.size(0), device=logits.device)
            last_logits = logits[row_idx, last_idx]

            loss = criterion(last_logits, target_values)
            total_loss += loss.item()
            num_batches += 1

            predictions = last_logits.argmax(dim=1)
            correct += (predictions == target_values).sum().item()
            total += target_values.size(0)

    # An empty loader previously raised ZeroDivisionError; report NaN instead.
    if num_batches == 0 or total == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    accuracy = 100 * correct / total
    return avg_loss, accuracy

In [None]:
# Updated training loop with validation and wandb logging
num_epochs = 3
print(f"\nStarting training for {num_epochs} epochs...")
print("=" * 60)

best_val_acc = 0.0

# wandb.log() needs an active run. With wandb.init() commented out, the first
# wandb.log() call in the loop failed, so the run is started here.
wandb.init(
    name='test',
    project="arc-next-token-prediction",
    config={
        "vocab_size": tokenizer.vocab_size,
        "d_model": model.d_model,
        "nhead": model.transformer.layers[0].self_attn.num_heads,
        "num_layers": len(model.transformer.layers),
        "batch_size": batch_size,
        "learning_rate": optimizer.param_groups[0]['lr'],
        "max_seq_length": model.max_seq_length,
        "num_epochs": num_epochs,
    }
)

try:
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 60)

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        # Record the LR this epoch trained with before the scheduler advances it.
        epoch_lr = optimizer.param_groups[0]['lr']

        # Update learning rate
        scheduler.step()

        # Log to wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc,
            "learning_rate": epoch_lr
        })

        # Print results
        print(f"\nEpoch {epoch + 1} Results:")
        print(f"  Train - Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
        print(f"  Val   - Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
        print(f"  Learning rate: {epoch_lr:.6f}")

        # Track best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(f"  ✓ New best validation accuracy: {best_val_acc:.2f}%")
            # Optionally save model checkpoint
            # torch.save(model.state_dict(), 'best_model.pt')

        print()

    print("\n" + "=" * 60)
    print("Training completed!")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print("=" * 60)
finally:
    # Close the wandb run even if training is interrupted or raises.
    wandb.finish()
