In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional
import math
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    torch.cuda.empty_cache()

AUGMENTATION=True
EVAL_COLOR_DIST=False


Using device: cpu


# helper classes

In [3]:
class ARCTokenizer:
    """Tokenizer for ARC challenges with special tokens for structure"""
    
    def __init__(self):
        # Value tokens (0-9)
        self.value_tokens = list(range(10))
        
        # Special tokens
        self.PAD_TOKEN = 10
        self.SOS_TOKEN = 11  # Start of sequence
        self.EOS_TOKEN = 12  # End of sequence
        self.TRAIN_TOKEN = 13  # Start of training example
        self.TEST_TOKEN = 14  # Start of test example
        self.INPUT_TOKEN = 15  # Start of input grid
        self.OUTPUT_TOKEN = 16  # Start of output grid
        self.NEWLINE_TOKEN = 17  # Grid separator (], [)
        
        self.vocab_size = 18
        
        # Token mappings
        self.token_to_id = {
            'PAD': self.PAD_TOKEN,
            'SOS': self.SOS_TOKEN,
            'EOS': self.EOS_TOKEN,
            'TRAIN': self.TRAIN_TOKEN,
            'TEST': self.TEST_TOKEN,
            'INPUT': self.INPUT_TOKEN,
            'OUTPUT': self.OUTPUT_TOKEN,
            'NEWLINE': self.NEWLINE_TOKEN
        }
    
    def grid_to_tokens(self, grid: List[List[int]]) -> List[int]:
        """Convert 2D grid to token sequence"""
        if not grid or not grid[0]:
            return []
        
        tokens = []
        for i, row in enumerate(grid):
            for j, value in enumerate(row):
                tokens.append(value)  # Just the value, position will be encoded separately
            if i < len(grid) - 1:  # Add newline between rows (except last)
                tokens.append(self.NEWLINE_TOKEN)
        
        return tokens
    
    def tokens_to_grid(self, tokens: List[int], target_shape: Tuple[int, int]) -> List[List[int]]:
        """Convert token sequence back to 2D grid"""
        h, w = target_shape
        grid = [[0 for _ in range(w)] for _ in range(h)]
        
        # Filter out special tokens and newlines
        values = [t for t in tokens if t < 10]  # Only keep value tokens (0-9)
        
        idx = 0
        for i in range(h):
            for j in range(w):
                if idx < len(values):
                    grid[i][j] = values[idx]
                    idx += 1
        
        return grid
    
    def create_input_sequence(self, train_examples: List[Dict], test_input: List[List[int]]) -> List[int]:
        """Create input sequence from training examples and test input"""
        sequence = [self.SOS_TOKEN]
        
        # Add training examples (exactly 2)
        for i, example in enumerate(train_examples[:2]):
            sequence.append(self.TRAIN_TOKEN)
            
            # Add input
            sequence.append(self.INPUT_TOKEN)
            input_tokens = self.grid_to_tokens(example['input'])
            sequence.extend(input_tokens)
            
            # Add output
            sequence.append(self.OUTPUT_TOKEN)
            output_tokens = self.grid_to_tokens(example['output'])
            sequence.extend(output_tokens)
        
        # Add test input
        sequence.append(self.TEST_TOKEN)
        sequence.append(self.INPUT_TOKEN)
        test_tokens = self.grid_to_tokens(test_input)
        sequence.extend(test_tokens)
        
        return sequence
    
    def create_target_sequence(self, target_grid: List[List[int]]) -> List[int]:
        """Create target sequence for training"""
        sequence = [self.SOS_TOKEN]
        sequence.append(self.OUTPUT_TOKEN)
        target_tokens = self.grid_to_tokens(target_grid)
        sequence.extend(target_tokens)
        sequence.append(self.EOS_TOKEN)
        return sequence
    
    def pad_sequence(self, sequence: List[int], max_length: int) -> List[int]:
        """Pad sequence to max_length"""
        if len(sequence) > max_length:
            return sequence[:max_length]
        return sequence + [self.PAD_TOKEN] * (max_length - len(sequence))




In [4]:
class TokenTo3DConverter:
    """Converts token sequences to 3D vectors [value, x, y] with coordinate information"""
    
    def __init__(self, tokenizer: ARCTokenizer):
        self.tokenizer = tokenizer
    
    def tokens_to_3d(self, 
                     tokens: List[int],
                     input_dims: List[Tuple[int, int]],
                     output_dims: List[Tuple[int, int]],
                     test_input_dims: Tuple[int, int],
                     test_output_dims: Optional[Tuple[int, int]] = None,
                     is_target: bool = False) -> torch.Tensor:
        """
        Convert token sequence to 3D vectors [value, x, y]
        
        Args:
            tokens: List of token IDs
            input_dims: List of (height, width) for training input grids
            output_dims: List of (height, width) for training output grids
            test_input_dims: (height, width) for test input grid
            is_target: If True, this is a target sequence (starts with OUTPUT_TOKEN)
        
        Returns:
            Tensor of shape [seq_len, 3] where each row is [value, x, y]
            Special tokens have x=-1, y=-1
        """
        result = []
        
        # Track current grid context
        current_grid_type = None  # 'train_input', 'train_output', 'test_input'
        current_grid_idx = 0
        current_row = 0
        current_col = 0
        current_grid_dims = None
        
        i = 0
        while i < len(tokens):
            token = tokens[i]
            
            # Handle special tokens that change context
            if token == self.tokenizer.SOS_TOKEN:
                result.append([token, -1, -1])
                i += 1
                continue
            elif token == self.tokenizer.EOS_TOKEN:
                result.append([token, -1, -1])
                i += 1
                continue
            elif token == self.tokenizer.PAD_TOKEN:
                result.append([token, -1, -1])
                i += 1
                continue
            elif token == self.tokenizer.TRAIN_TOKEN:
                current_grid_type = None
                current_grid_idx = 0
                result.append([token, -1, -1])
                i += 1
                continue
            elif token == self.tokenizer.TEST_TOKEN:
                current_grid_type = None
                result.append([token, -1, -1])
                i += 1
                continue
            elif token == self.tokenizer.INPUT_TOKEN:
                # Determine which input grid we're in
                if is_target:
                    # In target sequence, INPUT_TOKEN shouldn't appear
                    result.append([token, -1, -1])
                    i += 1
                    continue
                
                if current_grid_type is None:
                    # First INPUT after TRAIN - this is training input
                    if current_grid_idx < len(input_dims):
                        current_grid_dims = input_dims[current_grid_idx]
                        current_grid_type = 'train_input'
                elif current_grid_type == 'train_output':
                    # INPUT after OUTPUT in training - next training example
                    current_grid_idx += 1
                    if current_grid_idx < len(input_dims):
                        current_grid_dims = input_dims[current_grid_idx]
                        current_grid_type = 'train_input'
                elif current_grid_type is None:
                    # INPUT after TEST - this is test input
                    current_grid_dims = test_input_dims
                    current_grid_type = 'test_input'
                
                current_row = 0
                current_col = 0
                result.append([token, -1, -1])
                i += 1
                continue
            elif token == self.tokenizer.OUTPUT_TOKEN:
                # Determine which output grid we're in
                if current_grid_type == 'train_input':
                    # OUTPUT after INPUT in training
                    if current_grid_idx < len(output_dims):
                        current_grid_dims = output_dims[current_grid_idx]
                        current_grid_type = 'train_output'
                elif current_grid_type is None:
                    # OUTPUT at start (for target sequence) or after TEST
                    if is_target:
                        # For target sequence, use test_output_dims if available
                        if test_output_dims is not None:
                            current_grid_dims = test_output_dims
                        elif len(output_dims) > 0:
                            current_grid_dims = output_dims[0]  # Fallback to first output dims
                        else:
                            current_grid_dims = (1, 1)  # Default fallback
                    elif len(output_dims) > 0:
                        current_grid_dims = output_dims[0]
                    current_grid_type = 'train_output'
                
                current_row = 0
                current_col = 0
                result.append([token, -1, -1])
                i += 1
                continue
            elif token == self.tokenizer.NEWLINE_TOKEN:
                # Move to next row
                if current_grid_dims is not None:
                    current_row += 1
                    current_col = 0
                result.append([token, -1, -1])
                i += 1
                continue
            elif token < 10:  # Value token (0-9)
                # This is a grid value - add coordinates
                if current_grid_dims is not None:
                    h, w = current_grid_dims
                    # Clamp to valid ranges
                    row = min(current_row, h - 1)
                    col = min(current_col, w - 1)
                    result.append([token, col, row])  # x=col, y=row
                    
                    # Move to next column
                    current_col += 1
                else:
                    # No grid context, treat as special
                    result.append([token, -1, -1])
                i += 1
            else:
                # Unknown token, treat as special
                result.append([token, -1, -1])
                i += 1
        
        return torch.tensor(result, dtype=torch.int8)

In [5]:
class ARCDataset:
    """Dataset class for ARC challenges with data augmentation"""
    
    def __init__(self, challenges_path: str, solutions_path: str = None):
        self.challenges_path = challenges_path
        self.solutions_path = solutions_path
        
        # Load challenges
        with open(challenges_path, 'r') as f:
            self.challenges = json.load(f)
        
        # Load solutions if provided
        self.solutions = None
        if solutions_path:
            with open(solutions_path, 'r') as f:
                self.solutions = json.load(f)
    
    def get_challenge_data(self, challenge_id: str) -> Dict:
        """Get data for a specific challenge"""
        challenge = self.challenges[challenge_id]
        
        # Get training examples
        train_examples = challenge.get('train', [])
        
        # Get test examples
        test_examples = challenge.get('test', [])
        
        # Get solution if available
        solution = None
        if self.solutions and challenge_id in self.solutions:
            solution = self.solutions[challenge_id][0]  # First solution
        
        return {
            'train_examples': train_examples,
            'test_examples': test_examples,
            'solution': solution,
            'challenge_id': challenge_id
        }
    
    def get_all_challenges(self) -> List[str]:
        """Get list of all challenge IDs"""
        return list(self.challenges.keys())
    
    def create_augmented_samples(self, challenge_id: str) -> List[Dict]:
        """Create a single training sample from a challenge with 2 train examples and 1 test input"""
        data = self.get_challenge_data(challenge_id)
        
        samples = []
        
        # Always use first 2 training examples and first test input
        if len(data['train_examples']) >= 2:
            train_examples = data['train_examples'][:2]  # First 2 training examples
            test_input = data['test_examples'][0]['input'] if data['test_examples'] else []
            test_output = data['solution']
            
            samples.append({
                'train_examples': train_examples,
                'test_input': test_input,
                'test_output': test_output,
                'challenge_id': challenge_id,
                'sample_id': challenge_id  # Simple ID, no augmentation suffix
            })
        
        return samples

In [6]:
class ARCTorchDataset(Dataset):
    """PyTorch Dataset for ARC challenges"""
    
    def __init__(self, arc_dataset: ARCDataset, tokenizer: ARCTokenizer, 
                 token_converter = None):  # Optional converter
        self.arc_dataset = arc_dataset
        self.tokenizer = tokenizer
        self.token_converter = token_converter  # Optional converter for 3D vectors
        
        # Create all samples with augmentation
        self.samples = []
        for challenge_id in arc_dataset.get_all_challenges():
            samples = arc_dataset.create_augmented_samples(challenge_id)
            self.samples.extend(samples)
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Create input sequence
        input_seq = self.tokenizer.create_input_sequence(
            sample['train_examples'], 
            sample['test_input']
        )
        
        # Create target sequence
        if sample['test_output']:
            target_seq = self.tokenizer.create_target_sequence(sample['test_output'])
        else:
            # Create dummy target for test data
            target_seq = [self.tokenizer.SOS_TOKEN, self.tokenizer.EOS_TOKEN]
        
        # Pad sequences
        input_seq = self.tokenizer.pad_sequence(input_seq, 5400)   # 5 * 30x30 + bunch of extra tokens + possible target 30x30= 6*30x30
        target_seq = self.tokenizer.pad_sequence(target_seq, 1000) # max 30x30 + punch of extra tokens
        
        # Calculate dimensions
        input_dims = []
        output_dims = []
        
        for example in sample['train_examples']:
            input_dims.append((len(example['input']), len(example['input'][0]) if example['input'] else 0))
            output_dims.append((len(example['output']), len(example['output'][0]) if example['output'] else 0))
        
        test_input_dims = (len(sample['test_input']), len(sample['test_input'][0]) if sample['test_input'] else 0)
        test_output_dims = (len(sample['test_output']), len(sample['test_output'][0]) if sample['test_output'] else 0)
        
        # Convert to 3D vectors if converter is provided
        if self.token_converter is not None:
            input_3d = self.token_converter.tokens_to_3d(
                input_seq,
                input_dims,
                output_dims,
                test_input_dims,
                test_output_dims=test_output_dims,
                is_target=False
            )
            target_3d = self.token_converter.tokens_to_3d(
                target_seq,
                input_dims,
                output_dims,
                test_input_dims,
                test_output_dims=test_output_dims,
                is_target=True
            )
            return {
                'input': input_3d,  # Shape: [seq_len, 3] - [value, x, y]
                'target': target_3d,  # Shape: [seq_len, 3] - [value, x, y]
                'input_tokens': torch.tensor(input_seq, dtype=torch.int8),  # Keep original tokens too (int8 for memory efficiency)
                'target_tokens': torch.tensor(target_seq, dtype=torch.int8),  # Keep original tokens too (int8 for memory efficiency)
                'sample_id': sample['sample_id'],
                'challenge_id': sample['challenge_id'],
                'input_dims': input_dims,
                'output_dims': output_dims,
                'test_input_dims': test_input_dims,
                'test_output_dims': test_output_dims
            }
        else:
            # Return original token format
            return {
                'input': torch.tensor(input_seq, dtype=torch.int8),  # int8 for memory efficiency
                'target': torch.tensor(target_seq, dtype=torch.int8),  # int8 for memory efficiency
                'sample_id': sample['sample_id'],
                'challenge_id': sample['challenge_id'],
                'input_dims': input_dims,
                'output_dims': output_dims,
                'test_input_dims': test_input_dims,
                'test_output_dims': test_output_dims
            }

In [7]:
class ARCExplodedDataset(Dataset):
    """
    Explodes ARCTorchDataset into trainable samples.
    
    Takes each sample from ARCTorchDataset and creates multiple training samples:
    - Sample 0: input → predict target[0]
    - Sample 1: input + target[0] → predict target[1]
    - Sample 2: input + target[0:2] → predict target[2]
    - etc.
    
    Expects both input and target to be in 3D vector format [value, x, y].
    When adding target tokens:
    1. Loop through input sequence and replace first PAD token with target token
    2. If no PAD token found, append to end and remove first token
    """
    
    def __init__(self, torch_dataset: ARCTorchDataset, tokenizer: ARCTokenizer, sequence_length: int = 5400):
        self.torch_dataset = torch_dataset
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        
        # Create all exploded samples
        self.exploded_samples = []
        
        print(f"Exploding {len(torch_dataset)} base samples...")
        for base_idx in tqdm(range(len(torch_dataset))):
            base_sample = torch_dataset[base_idx]
            
            # Get input and target as 3D tensors [seq_len, 3]
            input_3d = base_sample['input']  # Shape: [max_length, 3]
            target_3d = base_sample['target']  # Shape: [max_length, 3]
            
            # Find actual length of input (before padding)
            # PAD token has value = PAD_TOKEN (10), x = -1, y = -1
            input_actual_len = 0
            for i in range(input_3d.shape[0]):
                if input_3d[i, 0].item() == self.tokenizer.PAD_TOKEN:
                    break
            input_actual_len = i-1
            
            target_actual_len = 0
            for i in range(target_3d.shape[0]):
                if target_3d[i, 0].item() == self.tokenizer.PAD_TOKEN:
                    break
            target_actual_len = i-1
            
            #print(input_actual_len, target_actual_len)
            
            target_vectors = target_3d[0:target_actual_len]
            
            
            # Optimized version - remove unnecessary cloning and use input_actual_len directly
            # Replace the target_vectors collection and loop in cell 9 with this:

            # In the target_vectors collection (around line 43-50):
            # Change: target_vectors.append(target_3d[i].clone())
            # To:     target_vectors.append(target_3d[i])  # No clone needed

            # In the loop (around line 58-67):
            # Replace the entire loop with this optimized version:

            # Start with full input sequence (we'll modify it in place)
            current_seq = input_3d.clone()
            for i, target_vector in enumerate(target_vectors):
                # Calculate position where we should place this target token
                # Start from input_actual_len and add i (position in target sequence)
                target_pos = input_actual_len + i
                
                if i>0:
                    # first target vector is not added to the input sequence
                    if target_pos < self.sequence_length:
                        # Check if position has a PAD token
                        if current_seq[target_pos, 0].item() == self.tokenizer.PAD_TOKEN:
                            # Replace PAD token with target vector
                            current_seq[target_pos] = target_vectors[i-1]
                        else:
                            #print("Sequence is full - append and remove from beginning", target_pos, current_seq.shape)
                            # Sequence is full - append and remove from beginning
                            current_seq = torch.cat([current_seq[1:], target_vectors[i-1].unsqueeze(0)], dim=0)
                    else:
                        current_seq = torch.cat([current_seq[1:], target_vectors[i-1].unsqueeze(0)], dim=0)

                # Store exploded sample
                exploded_sample = {
                    'input_3d': current_seq.clone(),
                    'target_vector': target_vector.clone(),  # Clone here since we store it separately
                    'target_position': i,
                    'base_sample_idx': base_idx,
                    'base_sample_id': base_sample.get('sample_id', f'sample_{base_idx}'),
                    'challenge_id': base_sample.get('challenge_id', ''),
                    'input_dims': base_sample.get('input_dims', []),
                    'output_dims': base_sample.get('output_dims', []),
                    'test_input_dims': base_sample.get('test_input_dims', (0, 0)),
                    'test_output_dims': base_sample.get('test_output_dims', (0, 0)),
                }

                self.exploded_samples.append(exploded_sample)
        
        print(f"Created {len(self.exploded_samples)} exploded samples from {len(torch_dataset)} base samples")
    
    def __len__(self):
        return len(self.exploded_samples)
    
    def __getitem__(self, idx):
        sample = self.exploded_samples[idx]
        
        input_3d = sample['input_3d']  # Shape: [max_length, 3]
        target_vector = sample['target_vector']  # Shape: [3]
        
        # Create attention mask (1 for non-padding, 0 for padding)
        attention_mask = (input_3d[:, 0] != self.tokenizer.PAD_TOKEN).long()
        
        return {
            'input_3d': input_3d,  # [max_length, 3] - full 3D vectors
            'target_vector': target_vector,  # [3] - target as 3D vector
            'target_value': target_vector[0].item(),  # Just the value token for convenience
            'attention_mask': attention_mask,  # [max_length]
            'target_position': sample['target_position'],
            'base_sample_idx': sample['base_sample_idx'],
            'base_sample_id': sample['base_sample_id'],
            'challenge_id': sample['challenge_id'],
            'input_dims': sample['input_dims'],
            'output_dims': sample['output_dims'],
            'test_input_dims': sample['test_input_dims'],
            'test_output_dims': sample['test_output_dims'],
        }



In [8]:
# augmentation during training
def apply_random_color_mapping(sample, apply_probability=1.0):
    """
    Apply a random color permutation (0-9) to a sample.
    OPTIMIZED: Uses vectorized PyTorch operations instead of Python loops.
    
    Args:
        sample: Dict with 'input_3d' and 'target_vector'
        apply_probability: Probability of applying augmentation (1.0 = always, 0.5 = 50% chance)
    
    Returns:
        Augmented sample with permuted colors
    """
    if np.random.random() > apply_probability:
        return sample  # Skip augmentation
    
    # Create random permutation of colors 0-9
    permuted_colors = torch.randperm(10, dtype=torch.int8)  # Vectorized permutation
    
    # Create lookup table: mapping[old_color] = new_color
    # For colors 0-9, use permuted mapping; for special tokens (10+), keep original
    mapping = torch.arange(18, dtype=torch.int8)  # Default: identity mapping
    mapping[:10] = permuted_colors  # Apply permutation to colors 0-9
    
    # Apply to input_3d (vectorized - much faster!)
    input_3d = sample['input_3d'].clone()
    color_values = input_3d[:, 0].int()  # Extract color values
    input_3d[:, 0] = mapping[color_values]  # Apply mapping in one operation
    
    # Apply to target_vector
    target_vector = sample['target_vector'].clone()
    target_color = target_vector[0].int()
    target_vector[0] = mapping[target_color]
    
    # Create augmented sample
    augmented_sample = sample.copy()
    augmented_sample['input_3d'] = input_3d
    augmented_sample['target_vector'] = target_vector
    augmented_sample['target_value'] = int(target_vector[0].item())
    
    return augmented_sample

# Dataset wrapper for live color augmentation
class AugmentedDataset(Dataset):
    """
    Wrapper that applies random color permutation augmentation on-the-fly.
    """
    def __init__(self, base_dataset, apply_probability=1.0):
        self.base_dataset = base_dataset
        self.apply_probability = apply_probability
    
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        sample = self.base_dataset[idx]
        return apply_random_color_mapping(sample, self.apply_probability)


# Alternative: Augment in collate_fn (even simpler)
def collate_fn_with_augmentation(batch, apply_probability=1.0):
    """
    Collate function that applies color augmentation to each sample.
    
    Usage:
        train_loader = DataLoader(
            train_dataset_split,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=lambda b: collate_fn_with_augmentation(b, apply_probability=1.0),
            num_workers=4
        )
    """
    # Apply augmentation to each sample
    augmented_batch = [apply_random_color_mapping(sample, apply_probability) for sample in batch]
    
    # Original collate logic
    input_3d = torch.stack([item['input_3d'] for item in augmented_batch])
    target_values = torch.stack([torch.tensor(item['target_value'], dtype=torch.long) for item in augmented_batch])
    attention_mask = torch.stack([item['attention_mask'] for item in augmented_batch])
    
    return {
        'input_3d': input_3d,
        'target_values': target_values,
        'attention_mask': attention_mask
    }

# creating the dataset

In [10]:
PATH='/kaggle/input/arc-prize-2025/'

# Load datasets
print("Loading datasets...")
train_dataset = ARCDataset(
    challenges_path=PATH+'arc-agi_training_challenges.json',
    solutions_path=PATH+'arc-agi_training_solutions.json'
)

test_dataset = ARCDataset(
    challenges_path=PATH+'arc-agi_test_challenges.json'
)

print(f"Training challenges: {len(train_dataset.get_all_challenges())}")
print(f"Test challenges: {len(test_dataset.get_all_challenges())}")

# Create PyTorch datasets
tokenizer = ARCTokenizer()
token_converter = TokenTo3DConverter(tokenizer)
train_torch_dataset = ARCTorchDataset(train_dataset, tokenizer, token_converter=token_converter)
test_torch_dataset = ARCTorchDataset(test_dataset, tokenizer, token_converter=token_converter)

print(f"\nTraining samples (with augmentation): {len(train_torch_dataset)}")
print(f"Test samples: {len(test_torch_dataset)}")

# Test data loading
sample = train_torch_dataset[0]
print(f"\nSample data:")
print(f"Sample ID: {sample['sample_id']}")
print(f"Challenge ID: {sample['challenge_id']}")
print(f"Input sequence length: {len(sample['input'])}")
print(f"Target sequence length: {len(sample['target'])}")
print(f"Input dims: {sample['input_dims']}")
print(f"Output dims: {sample['output_dims']}")
print(f"Test input dims: {sample['test_input_dims']}")
print(f"Test output dims: {sample['test_output_dims']}")

# Create exploded datasets from existing ARCTorchDataset
print("Creating exploded training dataset...")
train_exploded_dataset = ARCExplodedDataset(train_torch_dataset, tokenizer, sequence_length=5400)

Loading datasets...
Training challenges: 1000
Test challenges: 240

Training samples (with augmentation): 1000
Test samples: 240

Sample data:
Sample ID: 00576224
Challenge ID: 00576224
Input sequence length: 5400
Target sequence length: 1000
Input dims: [(2, 2), (2, 2)]
Output dims: [(6, 6), (6, 6)]
Test input dims: (2, 2)
Test output dims: (6, 6)
Creating exploded training dataset...
Exploding 1000 base samples...


100%|██████████| 1000/1000 [00:24<00:00, 40.38it/s]


Created 204169 exploded samples from 1000 base samples


In [12]:
# augmentation check
if AUGMENTATION:
    train_dataset_augmented = AugmentedDataset(train_exploded_dataset, apply_probability=1.0)


if EVAL_COLOR_DIST:
    from collections import Counter, defaultdict
    def analyze_color_distribution(dataset):
        """Analyze the distribution of target colors (0-9) in the dataset"""
        color_counts = Counter()
        
        print("Analyzing color distribution...")
        for idx in tqdm(range(len(dataset))):
            sample = dataset[idx]
            target_value = sample['target_value']
            # Only count color tokens (0-9), ignore special tokens
            color_counts[target_value] += 1
        
        return color_counts
    aug_counts = analyze_color_distribution(train_dataset_augmented)
    original_counts = analyze_color_distribution(train_exploded_dataset)
    print(original_counts)
    print(aug_counts) 
    # result, looks a lot better
    # Counter({0: 86966, 8: 19131, 1: 17683, 4: 14558, 7: 13704, 3: 12134, 2: 11829, 5: 7086, 6: 5145, 9: 3042})
    # Counter({2: 19417, 0: 19327, 4: 19150, 7: 19150, 8: 19144, 5: 19062, 3: 19055, 9: 19004, 6: 18995, 1: 18974})

if AUGMENTATION:
    train_exploded_dataset = train_dataset_augmented

In [None]:
# Create DataLoader
def collate_fn(batch):
    """Collate function for batching"""
    input_3d = torch.stack([item['input_3d'] for item in batch])
    target_values = torch.stack([torch.tensor(item['target_value'], dtype=torch.long) for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    
    return {
        'input_3d': input_3d,
        'target_values': target_values,
        'attention_mask': attention_mask
    }
    
# Split dataset by challenge_id (ensures no data leakage)
# Create TWO validation sets: tiny (frequent) and full (accurate)
import random
from collections import defaultdict
from torch.utils.data import Subset

# Group samples by challenge_id
challenge_to_indices = defaultdict(list)
for idx in range(len(train_exploded_dataset)):
    sample = train_exploded_dataset[idx]
    challenge_id = sample['challenge_id']
    challenge_to_indices[challenge_id].append(idx)

# Get unique challenge IDs
challenge_ids = list(challenge_to_indices.keys())
print(f"Total challenges: {len(challenge_ids)}")

# Shuffle and split challenges (not individual samples)
random.seed(42)
random.shuffle(challenge_ids)

train_ratio = 0.8
split_idx = int(len(challenge_ids) * train_ratio)
train_challenge_ids = set(challenge_ids[:split_idx])
val_challenge_ids_all = set(challenge_ids[split_idx:])

# Split validation challenges into tiny and full
val_challenge_ids_list = list(val_challenge_ids_all)
random.shuffle(val_challenge_ids_list)
tiny_val_ratio = 0.1  # 10% of validation challenges for tiny set
tiny_split_idx = int(len(val_challenge_ids_list) * tiny_val_ratio)
tiny_val_challenge_ids = set(val_challenge_ids_list[:tiny_split_idx])
full_val_challenge_ids = set(val_challenge_ids_list[tiny_split_idx:])

print(f"Train challenges: {len(train_challenge_ids)}")
print(f"Tiny val challenges: {len(tiny_val_challenge_ids)}")
print(f"Full val challenges: {len(full_val_challenge_ids)}")

# Collect indices for each split
train_indices = []
tiny_val_indices = []
full_val_indices = []

for challenge_id, indices in challenge_to_indices.items():
    if challenge_id in train_challenge_ids:
        train_indices.extend(indices)
    elif challenge_id in tiny_val_challenge_ids:
        tiny_val_indices.extend(indices)
    elif challenge_id in full_val_challenge_ids:
        full_val_indices.extend(indices)

print(f"\nTrain samples: {len(train_indices)}")
print(f"Tiny val samples: {len(tiny_val_indices)}")
print(f"Full val samples: {len(full_val_indices)}")

# Create subset datasets
train_dataset_split = Subset(train_exploded_dataset, train_indices)
tiny_val_dataset_split = Subset(train_exploded_dataset, tiny_val_indices)
full_val_dataset_split = Subset(train_exploded_dataset, full_val_indices)

# Create DataLoaders
batch_size = 16

train_loader = DataLoader(
    train_dataset_split,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4
)

# Tiny validation loader (for frequent checks)
tiny_val_loader = DataLoader(
    tiny_val_dataset_split,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=1
)

# Full validation loader (for accurate metrics)
full_val_loader = DataLoader(
    full_val_dataset_split,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=1
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Tiny val batches: {len(tiny_val_loader)}")
print(f"Full val batches: {len(full_val_loader)}")

Total challenges: 1000
Train challenges: 800
Tiny val challenges: 20
Full val challenges: 180

Train samples: 161456
Tiny val samples: 5083
Full val samples: 37630

Train batches: 631
Tiny val batches: 20
Full val batches: 147


# enriched attention

In [None]:
class SparseEnhancedAttentionLayerTopK(nn.Module):
    """
    Custom transformer layer that:
    1. Computes full attention matrix Q @ K^T
    2. Extracts upper triangular (without diagonal, i < j)
    3. Applies ReLU to sparsify (only keep positive values)
    4. Selects top seq_len pairs by attention score (always exactly seq_len pairs per batch)
    5. Enriches non-zero attention scores with original vectors [attn_score, vec_i, vec_j]
    6. Bottlenecks enhanced vectors back to sequence length
    
    This bypasses the V network entirely and learns direct transformations.
    """
    
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1, 
                 bottleneck_hidden=None):
        """
        Args:
            d_model: Model dimension
            dim_feedforward: Feedforward network hidden dimension
            dropout: Dropout rate
            bottleneck_hidden: Hidden dimension for bottleneck MLP (defaults to dim_feedforward)
        """
        super().__init__()
        self.d_model = d_model
        
        # Query and Key projections (no V projection needed!)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        # Bottleneck MLP: processes [attn_score, vec_i, vec_j] -> d_model
        # Input: [attn_score (1) + vec_i (d_model) + vec_j (d_model)] = d_model*2 + 1
        bottleneck_hidden = bottleneck_hidden or dim_feedforward
        self.bottleneck = nn.Sequential(
            nn.Linear(d_model * 2 + 1, bottleneck_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck_hidden, d_model),
            nn.Dropout(dropout)
        )
        
        # Feedforward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        
        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, src_key_padding_mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            src_key_padding_mask: [batch_size, seq_len] - True for padding tokens
        
        Returns:
            [batch_size, seq_len, d_model]
        """
        residual = x
        x = self.norm1(x)
        
        batch_size, seq_len, _ = x.shape
        
        # Compute Q, K (no V needed!)
        Q = self.w_q(x)  # [batch_size, seq_len, d_model]
        K = self.w_k(x)  # [batch_size, seq_len, d_model]
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)
        # [batch, seq_len, seq_len]
        
        # Apply padding mask if provided
        if src_key_padding_mask is not None:
            mask = src_key_padding_mask.unsqueeze(1)  # [batch, 1, seq_len]
            scores = scores.masked_fill(mask, float('-inf'))
            mask = src_key_padding_mask.unsqueeze(2)  # [batch, seq_len, 1]
            scores = scores.masked_fill(mask, float('-inf'))
        
        # Extract upper triangular WITHOUT diagonal (strictly upper: i < j)
        triu_indices = torch.triu_indices(seq_len, seq_len, offset=1, device=x.device)
        i_indices = triu_indices[0]  # [num_pairs]
        j_indices = triu_indices[1]  # [num_pairs]
        
        # Extract upper triangular attention scores
        upper_tri_scores = scores[:, i_indices, j_indices]  # [batch, num_pairs]
        
        # Apply ReLU to sparsify (set negatives to 0)
        sparse_scores = torch.relu(upper_tri_scores)  # [batch, num_pairs]
        
        # Process each batch separately
        all_processed = []
        
        for b in range(batch_size):
            # Get this batch's sparse scores
            batch_sparse_scores = sparse_scores[b]  # [num_pairs]
            
            # Find non-zero pairs for this batch
            batch_non_zero_mask = batch_sparse_scores > 0  # [num_pairs]
            num_non_zero = batch_non_zero_mask.sum().item()
            
            if num_non_zero == 0:
                # No non-zero pairs - create zero vectors
                batch_processed = torch.zeros(seq_len, self.d_model, device=x.device)
            else:
                # Get non-zero scores and their pair indices
                non_zero_scores_b = batch_sparse_scores[batch_non_zero_mask]  # [num_non_zero]
                non_zero_pair_indices = torch.where(batch_non_zero_mask)[0]  # [num_non_zero]
                
                # Select exactly seq_len pairs (top-k)
                if num_non_zero >= seq_len:
                    # Select top seq_len pairs
                    topk_scores, topk_local_indices = torch.topk(non_zero_scores_b, k=seq_len)
                    selected_pair_indices = non_zero_pair_indices[topk_local_indices]
                    selected_scores = topk_scores
                else:
                    # Fewer than seq_len pairs available - use all and pad with zeros
                    selected_pair_indices = non_zero_pair_indices
                    selected_scores = non_zero_scores_b
                    
                    # Pad to seq_len
                    num_to_pad = seq_len - num_non_zero
                    # Use first available pairs as placeholders (will have zero scores)
                    pad_indices = torch.arange(num_to_pad, device=x.device) % len(non_zero_pair_indices)
                    pad_pair_indices = non_zero_pair_indices[pad_indices]
                    selected_pair_indices = torch.cat([selected_pair_indices, pad_pair_indices], dim=0)
                    selected_scores = torch.cat([
                        selected_scores,
                        torch.zeros(num_to_pad, device=x.device)
                    ], dim=0)
                
                # Get sequence indices for selected pairs
                selected_i = i_indices[selected_pair_indices]  # [seq_len]
                selected_j = j_indices[selected_pair_indices]  # [seq_len]
                
                # Gather original vectors for this batch's pairs
                vec_i = x[b, selected_i, :]  # [seq_len, d_model]
                vec_j = x[b, selected_j, :]  # [seq_len, d_model]
                
                # Create enhanced vectors for this batch
                enhanced = torch.cat([
                    selected_scores.unsqueeze(-1),  # [seq_len, 1]
                    vec_i,  # [seq_len, d_model]
                    vec_j   # [seq_len, d_model]
                ], dim=-1)  # [seq_len, d_model*2 + 1]
                
                # Process through bottleneck (per batch)
                batch_processed = self.bottleneck(enhanced)  # [seq_len, d_model]
            
            all_processed.append(batch_processed)
        
        # Stack batches
        attn_output = torch.stack(all_processed, dim=0)  # [batch_size, seq_len, d_model]
        
        # Output projection
        attn_output = self.w_o(attn_output)
        x = self.dropout(attn_output)  # No residual connection
        
        # Feedforward
        residual = x
        x = self.norm2(x)
        x = residual + self.ffn(x)
        
        return x

In [None]:
# Sparse Enhanced Attention Layer
class SparseEnhancedAttentionLayer(nn.Module):
    """
    Custom transformer layer that:
    1. Computes full attention matrix Q @ K^T
    2. Extracts upper triangular (without diagonal, i < j)
    3. Applies ReLU to sparsify (only keep positive values)
    4. Enriches non-zero attention scores with original vectors [attn_score, vec_i, vec_j]
    5. Bottlenecks enhanced vectors back to sequence length
    
    This bypasses the V network entirely and learns direct transformations.
    """
    
    def __init__(self, d_model, dim_feedforward=2048, dropout=0.1, 
                 bottleneck_hidden=None):
        super().__init__()
        self.d_model = d_model
        
        # Query and Key projections (no V projection needed!)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        # Bottleneck MLP: processes [attn_score, vec_i, vec_j] -> d_model
        # Input: [attn_score (1) + vec_i (d_model) + vec_j (d_model)] = d_model*2 + 1
        bottleneck_hidden = bottleneck_hidden or dim_feedforward
        self.bottleneck = nn.Sequential(
            nn.Linear(d_model * 2 + 1, bottleneck_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck_hidden, d_model),
            nn.Dropout(dropout)
        )
        
        # Feedforward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout)
        )
        
        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, src_key_padding_mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            src_key_padding_mask: [batch_size, seq_len] - True for padding tokens
        
        Returns:
            [batch_size, seq_len, d_model]
        """
        residual = x
        x = self.norm1(x)
        
        batch_size, seq_len, _ = x.shape
        
        # Compute Q, K (no V needed!)
        Q = self.w_q(x)  # [batch_size, seq_len, d_model]
        K = self.w_k(x)  # [batch_size, seq_len, d_model]
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_model)
        # [batch, seq_len, seq_len]
        
        # Apply padding mask if provided
        if src_key_padding_mask is not None:
            mask = src_key_padding_mask.unsqueeze(1)  # [batch, 1, seq_len]
            scores = scores.masked_fill(mask, float('-inf'))
            mask = src_key_padding_mask.unsqueeze(2)  # [batch, seq_len, 1]
            scores = scores.masked_fill(mask, float('-inf'))
        
        # Extract upper triangular WITHOUT diagonal (strictly upper: i < j)
        # Get indices for upper triangular (excluding diagonal)
        triu_indices = torch.triu_indices(seq_len, seq_len, offset=1, device=x.device)
        i_indices = triu_indices[0]  # [num_pairs]
        j_indices = triu_indices[1]  # [num_pairs]
        
        # Extract upper triangular attention scores
        upper_tri_scores = scores[:, i_indices, j_indices]  # [batch, num_pairs]
        
        # Apply ReLU to sparsify (set negatives to 0)
        sparse_scores = torch.relu(upper_tri_scores)  # [batch, num_pairs]
        
        # Find non-zero pairs (sparse attention)
        # For each batch, find which pairs have non-zero attention
        non_zero_mask = sparse_scores > 0  # [batch, num_pairs]
        
                # Vectorized: process all non-zero pairs across all batches at once
        # Get total number of non-zero pairs across all batches
        total_non_zeros = non_zero_mask.sum().item()
        
        if total_non_zeros == 0:
            # No non-zero attention in any batch, output zeros
            attn_output = torch.zeros(batch_size, seq_len, self.d_model, device=x.device)
        else:
            # Get batch indices and pair indices for all non-zero pairs
            batch_indices, pair_indices = torch.where(non_zero_mask)
            # batch_indices: [total_non_zeros] - which batch each pair belongs to
            # pair_indices: [total_non_zeros] - which pair index (in i_indices/j_indices)
            
            # Get the actual sequence indices for these pairs
            non_zero_i = i_indices[pair_indices]  # [total_non_zeros]
            non_zero_j = j_indices[pair_indices]  # [total_non_zeros]
            
            # Get attention scores for non-zero pairs
            non_zero_scores = sparse_scores[batch_indices, pair_indices]  # [total_non_zeros]
            
            # Gather original vectors for all non-zero pairs (fully vectorized)
            vec_i = x[batch_indices, non_zero_i, :]  # [total_non_zeros, d_model]
            vec_j = x[batch_indices, non_zero_j, :]  # [total_non_zeros, d_model]
            
            # Create enhanced vectors: [attn_score, vec_i, vec_j]
            enhanced = torch.cat([
                non_zero_scores.unsqueeze(-1),  # [total_non_zeros, 1]
                vec_i,  # [total_non_zeros, d_model]
                vec_j   # [total_non_zeros, d_model]
            ], dim=-1)  # [total_non_zeros, d_model*2 + 1]
            
            # Process through bottleneck (all at once, fully vectorized)
            processed = self.bottleneck(enhanced)  # [total_non_zeros, d_model]
            
            # Aggregate back to sequence length using index_add_ (vectorized per batch)
            # Initialize output tensor
            attn_output = torch.zeros(batch_size, seq_len, self.d_model, device=x.device)
            
            # Use index_add_ for each batch (still need batch loop for indexing, but bottleneck is vectorized)
            for b in range(batch_size):
                batch_mask = batch_indices == b
                if batch_mask.any():
                    batch_non_zero_i = non_zero_i[batch_mask]
                    batch_non_zero_j = non_zero_j[batch_mask]
                    batch_processed = processed[batch_mask]
                    
                    # Sum contributions where position appears as i
                    attn_output[b].index_add_(0, batch_non_zero_i, batch_processed)
                    # Also sum contributions where position appears as j
                    attn_output[b].index_add_(0, batch_non_zero_j, batch_processed)
        
        # Stack batch outputs
        attn_output = torch.stack(batch_outputs, dim=0)  # [batch, seq_len, d_model]
        
        # Output projection
        attn_output = self.w_o(attn_output)
        #x = residual + self.dropout(attn_output)
        
        # Feedforward
        residual = x
        x = self.norm2(x)
        x = residual + self.ffn(x)
        
        return x


In [15]:
class CustomTransformerEncoder(nn.Module):
    """
    Custom transformer encoder using CNN-filtered attention layers
    """
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])
        
    def forward(self, x, src_key_padding_mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            src_key_padding_mask: [batch_size, seq_len] - True for padding tokens
        """
        for layer in self.layers:
            x = layer(x, src_key_padding_mask)
        return x

# Model

In [19]:
import wandb
wandb.login(key='9c6d131f5fcedb96565fa31f4680c2da83ea07d5')

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmatthiaskargl[0m ([33mmatthiaskargl-personal[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
class NextTokenPredictor(nn.Module):
    """
    Transformer model to predict next token from 3D vectors [value, x, y].
    """
    
    def __init__(self, vocab_size=18, d_model=16, nhead=8, num_layers=4, 
                 dim_feedforward=1024, max_seq_length=5400, dropout=0.1, unet_base_channels=32, unet_num_downsample=3):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        
        # Embedding for token values (0-17)
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Projection for x, y coordinates (add coordinate information)
        self.coord_projection = nn.Linear(2, d_model)  # [x, y] -> d_model
        
        # Positional encoding (learned)
        #self.pos_encoding = nn.Parameter(torch.randn(max_seq_length, d_model) * 0.02)
        #self.pos_encoding = self.create_sinusoidal_positional_encoding(max_seq_length, d_model)
        pos_encoding = self.create_sinusoidal_positional_encoding(max_seq_length, d_model)
        self.register_buffer('pos_encoding', pos_encoding)
        
        
        USE_CUSTOM_ATTENTION = True  # Set to False for standard transformer
        
        if USE_CUSTOM_ATTENTION:
            # Custom sparse enhanced attention layer
            encoder_layer = SparseEnhancedAttentionLayerTopK(
                d_model=d_model,
                max_pairs_per_batch=seq_len,
                dim_feedforward=dim_feedforward,
                #dropout=dropout,
            )
            self.transformer = CustomTransformerEncoder(encoder_layer, num_layers=num_layers)
        else:
            # Standard transformer encoder
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True
            )
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Output projection to vocab
        self.output_proj = nn.Linear(d_model, vocab_size)
        
        self.dropout = nn.Dropout(dropout)
    
    def create_sinusoidal_positional_encoding(self, max_len, d_model):
        """
        Create sinusoidal positional encoding (no learnable parameters).
        
        Args:
            max_len: Maximum sequence length
            d_model: Model dimension
        
        Returns:
            [max_len, d_model] tensor with positional encodings
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd indices: cos
        
        return pe  # [max_len, d_model]

        
    def forward(self, input_3d, attention_mask=None):
        """
        Args:
            input_3d: [batch_size, seq_len, 3] - [value, x, y] vectors
            attention_mask: [batch_size, seq_len] - 1 for real tokens, 0 for padding
        
        Returns:
            logits: [batch_size, seq_len, vocab_size] - logits for each position
        """
        batch_size, seq_len, _ = input_3d.shape
        
        # Extract components
        token_values = input_3d[:, :, 0].long()  # [batch_size, seq_len] - token values
        coordinates = input_3d[:, :, 1:3].float()  # [batch_size, seq_len, 2] - x, y
        
        # Embed tokens
        token_emb = self.token_embedding(token_values)  # [batch_size, seq_len, d_model]
        
        # Add coordinate information
        coord_emb = self.coord_projection(coordinates)  # [batch_size, seq_len, d_model]
        
        # Combine token and coordinate embeddings
        x = token_emb + coord_emb  # [batch_size, seq_len, d_model]
        
        # Add positional encoding
        x = x + self.pos_encoding[:seq_len].unsqueeze(0)  # [batch_size, seq_len, d_model]
        
        #x = self.dropout(x)
        if attention_mask is not None:
            padding_mask = (attention_mask == 0).bool()  # True for padding, False for real tokens
        else:
            padding_mask = None
        
        # Apply transformer
        x = self.transformer(x, src_key_padding_mask=padding_mask)  # [batch_size, seq_len, d_model]
        
        # Get logits for all positions
        logits = self.output_proj(x)  # [batch_size, seq_len, vocab_size]
        
        return logits


In [21]:
# Create model
model = NextTokenPredictor(
    vocab_size=tokenizer.vocab_size,
    d_model=16,
    nhead=1,
    num_layers=1,
    dim_feedforward=128,
    max_seq_length=5400,
    dropout=0.1,
    unet_base_channels=2,
    unet_num_downsample=2,
).to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Model device: {next(model.parameters()).device}")

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# Learning rate scheduler
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)


Model created with 1,933,876 parameters
Model device: cpu


# Training

In [None]:
# Evaluation function
import os

def save_checkpoint(model, optimizer, epoch, batch_idx, val_loss, val_acc, train_loss, train_acc, 
                   checkpoint_dir='checkpoints', is_best=False):
    """
    Save model checkpoint
    
    Args:
        model: The model to save
        optimizer: The optimizer to save
        epoch: Current epoch number
        batch_idx: Current batch index
        val_loss: Validation loss
        val_acc: Validation accuracy
        train_loss: Training loss
        train_acc: Training accuracy
        checkpoint_dir: Directory to save checkpoints
        is_best: Whether this is the best model so far
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    checkpoint = {
        'epoch': epoch,
        'batch': batch_idx,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_accuracy': val_acc,
        'train_loss': train_loss,
        'train_accuracy': train_acc,
    }
    
    # Save regular checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch{epoch}_batch{batch_idx}.pt')
    torch.save(checkpoint, checkpoint_path)
    
    # Save best model if applicable
    if is_best:
        best_path = os.path.join(checkpoint_dir, 'best_model.pt')
        torch.save(checkpoint, best_path)
        print(f"  ✓ Best model saved: {best_path}")
    
    return checkpoint_path
    
def evaluate(model, dataloader, criterion, device):
    """Evaluate model on dataset"""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    last_logits = None  # Initialize
    last_targets = None  # Initialize
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_3d = batch['input_3d'].to(device)
            target_values = batch['target_values'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            logits = model(input_3d=input_3d, attention_mask=attention_mask)
            
            batch_size = input_3d.size(0)
            seq_lengths = attention_mask.sum(dim=1) - 1
            last_logits = logits[torch.arange(batch_size), seq_lengths]
            last_targets = target_values
            
            loss = criterion(last_logits, target_values)
            total_loss += loss.item()
            
            predictions = last_logits.argmax(dim=1)
            correct += (predictions == target_values).sum().item()
            total += target_values.size(0)
    if last_logits is not None and last_targets is not None:
        print(f"Sample predictions: {last_logits[0:3].argmax(dim=1)}, targets: {last_targets[0:3]}")
    else:
        print("No predictions or targets available")
    
    
    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy

# Updated train_epoch with periodic validation during training
def train_epoch(model, train_dataloader, tiny_val_loader, full_val_loader, criterion, optimizer, device, 
                log_every_n_batches=2, tiny_val_every_n_batches=10, full_val_every_n_batches=200):
    """
    Train for one epoch with periodic validation using two validation sets
    
    Args:
        model: The model to train
        train_dataloader: Training data loader
        tiny_val_loader: Tiny validation loader (for frequent checks)
        full_val_loader: Full validation loader (for accurate metrics)
        criterion: Loss function
        optimizer: Optimizer
        device: Device to run on
        log_every_n_batches: Log to wandb every N batches
        tiny_val_every_n_batches: Run tiny validation every N batches (default: 10)
        full_val_every_n_batches: Run full validation every N batches (default: 200)
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    best_val_los=1e6
    
    pbar = tqdm(train_dataloader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        # Move to device
        input_3d = batch['input_3d'].to(device)  # [batch_size, seq_len, 3]
        target_values = batch['target_values'].to(device)  # [batch_size]
        attention_mask = batch['attention_mask'].to(device)  # [batch_size, seq_len]
        
        # Forward pass
        optimizer.zero_grad()
        logits = model(input_3d=input_3d, attention_mask=attention_mask)  # [batch_size, seq_len, vocab_size]
        
        # Get logits for the last non-padding position (where we predict)
        # Find last non-padding position for each sequence
        batch_size = input_3d.size(0)
        seq_lengths = attention_mask.sum(dim=1) - 1  # -1 because we want the position before the target
        last_logits = logits[torch.arange(batch_size), seq_lengths]  # [batch_size, vocab_size]
        
        # Compute loss
        loss = criterion(last_logits, target_values)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()
        
        # Metrics
        total_loss += loss.item()
        predictions = last_logits.argmax(dim=1)
        batch_correct = (predictions == target_values).sum().item()
        correct += batch_correct
        total += target_values.size(0)
        batch_acc = 100 * batch_correct / target_values.size(0)
        
        # Log to wandb every N batches (default: every other batch)
        if batch_idx % log_every_n_batches == 0:
            wandb.log({
                "batch_loss": loss.item(),
                "batch_accuracy": batch_acc,
                "running_accuracy": 100 * correct / total,
            })
        
        # Tiny validation (frequent, quick check)
        if (batch_idx + 1) % tiny_val_every_n_batches == 0:
            tiny_val_loss, tiny_val_acc = evaluate(model, tiny_val_loader, criterion, device)
            
            # Log tiny validation metrics
            wandb.log({
                "tiny_val_loss": tiny_val_loss,
                "tiny_val_accuracy": tiny_val_acc,
                "train_batch": batch_idx + 1,
            })
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * correct / total:.2f}%',
                'tiny_val': f'{tiny_val_acc:.1f}%'
            })
            
            model.train()  # Switch back to training mode
        
        # Full validation (less frequent, more accurate)
        if (batch_idx + 1) % full_val_every_n_batches == 0:
            full_val_loss, full_val_acc = evaluate(model, full_val_loader, criterion, device)
            is_better=full_val_loss<best_val_los
            if is_better:
                best_val_los=full_val_loss

            save_checkpoint(model, optimizer, epoch, batch_idx, full_val_loss, full_val_acc, total_loss, batch_acc, 
                   checkpoint_dir='checkpoints', is_best=is_better)
            
            # Log full validation metrics
            wandb.log({
                "full_val_loss": full_val_loss,
                "full_val_accuracy": full_val_acc,
                "train_batch": batch_idx + 1,
            })
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * correct / total:.2f}%',
                'full_val': f'{full_val_acc:.1f}%'
            })
            
            model.train()  # Switch back to training mode
        
        # Update progress bar (if no validation was run this batch)
        if (batch_idx + 1) % tiny_val_every_n_batches != 0 and (batch_idx + 1) % full_val_every_n_batches != 0:
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * correct / total:.2f}%'
            })
    
    avg_loss = total_loss / len(train_dataloader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy


In [None]:
# Updated training loop with validation and wandb logging
num_epochs = 1
print(f"\nStarting training for {num_epochs} epochs...")
print("=" * 60)

best_val_acc = 0.0

# Validation frequencies
tiny_val_every_n_batches = 1   # Tiny validation every 10 batches (~30 seconds)
full_val_every_n_batches = 5  # Full validation every 200 batches (~5-10 minutes)

wandb.init(
    name='test-aug-colors-unet_try0',
    project="arc-next-token-prediction",
    config={
        "vocab_size": tokenizer.vocab_size,
        "d_model": model.d_model,
        #"nhead": model.transformer.layers[0].self_attn.num_heads,
        "num_layers": len(model.transformer.layers),
        "batch_size": batch_size,
        "learning_rate": optimizer.param_groups[0]['lr'],
        "max_seq_length": model.max_seq_length,
    }
)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    print("-" * 60)

    train_loss, train_acc = train_epoch(
        model, 
        train_loader, 
        tiny_val_loader,  # Tiny validation for frequent checks
        full_val_loader,  # Full validation for accurate metrics
        criterion, 
        optimizer, 
        device,
        tiny_val_every_n_batches=tiny_val_every_n_batches,
        full_val_every_n_batches=full_val_every_n_batches
    )
    
    # Update learning rate
    scheduler.step()
    
    # Log to wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
        "learning_rate": optimizer.param_groups[0]['lr']
    })
    
    # Print results
    print(f"\nEpoch {epoch + 1} Results:")
    print(f"  Train - Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
    print(f"  Val   - Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
    print(f"  Learning rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Track best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f"  ✓ New best validation accuracy: {best_val_acc:.2f}%")
        # Optionally save model checkpoint
        # torch.save(model.state_dict(), 'best_model.pt')
    
print("\n" + "=" * 60)
print("Training completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print("=" * 60)

wandb.finish()



Starting training for 1 epochs...



Epoch 1/1
------------------------------------------------------------


Training:   0%|          | 0/631 [00:00<?, ?it/s]