In [None]:
import json
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pprint import pp

In [44]:
training_data = json.load(open("../data/arc-agi_training_challenges.json"))
training_solutions = json.load(open("../data/arc-agi_training_solutions.json"))

In [3]:
training_data['00576224']

{'train': [{'input': [[7, 9], [4, 3]],
   'output': [[7, 9, 7, 9, 7, 9],
    [4, 3, 4, 3, 4, 3],
    [9, 7, 9, 7, 9, 7],
    [3, 4, 3, 4, 3, 4],
    [7, 9, 7, 9, 7, 9],
    [4, 3, 4, 3, 4, 3]]},
  {'input': [[8, 6], [6, 4]],
   'output': [[8, 6, 8, 6, 8, 6],
    [6, 4, 6, 4, 6, 4],
    [6, 8, 6, 8, 6, 8],
    [4, 6, 4, 6, 4, 6],
    [8, 6, 8, 6, 8, 6],
    [6, 4, 6, 4, 6, 4]]}],
 'test': [{'input': [[3, 2], [7, 8]]}]}

In [47]:
training_solutions['00576224']

[[[3, 2, 3, 2, 3, 2],
  [7, 8, 7, 8, 7, 8],
  [2, 3, 2, 3, 2, 3],
  [8, 7, 8, 7, 8, 7],
  [3, 2, 3, 2, 3, 2],
  [7, 8, 7, 8, 7, 8]]]

In [None]:
class Task():
    def __init__(self, task, name):
        self.tasks = [*task['train'], task['test']]
        self.name = name

class ARCDataset(Dataset):
    def __init__(self, data: dict, test_set: bool = False):
        if test_set:
            pass

        else:
            self.data = []
            for name, task in data.items():
                self.data.append(Task(task, name))
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item


In [4]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=1, padding=2))
            input_dim = hidden_dim

    def forward(self, x):
        for layer in self.layers:
            x = F.gelu(layer(x))

        x = torch.mean(x, dim=(2, 3))
        return x

class DiffusionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(DiffusionModel, self).__init__()
        self.linear_layers = nn.ModuleList()

        for _ in range(num_layers):
            self.linear_layers.append(nn.Linear(output_dim, output_dim))

    def get_2d_positional_encoding(self, height, width, d_model, device='cpu'):
        """
        Generate 2D positional encoding for an image of size (height, width).
        Returns a tensor of shape (height, width, d_model).
        """
        if d_model % 4 != 0:
            raise ValueError("d_model must be divisible by 4 for 2D positional encoding.")

        # (H, W, D)
        pe = torch.zeros(height, width, d_model, device=device)
        d_model_half = d_model // 2

        # Positions
        y_pos = torch.arange(height, device=device).unsqueeze(1).float()  # (H,1)
        x_pos = torch.arange(width, device=device).unsqueeze(0).float()   # (1,W)

        # frequency terms
        div_term = torch.exp(
            torch.arange(0, d_model_half, 2, device=device).float()
            * (-(torch.log(torch.tensor(10000.0, device=device)) / d_model_half))
        )  # (D/4,)

        # Angles with explicit last-dim for broadcasting
        y_angles = y_pos.unsqueeze(-1) * div_term  # (H,1,D/4)
        x_angles = x_pos.unsqueeze(-1) * div_term  # (1,W,D/4)

        # Sine/Cosine
        y_sin, y_cos = torch.sin(y_angles), torch.cos(y_angles)  # (H,1,D/4)
        x_sin, x_cos = torch.sin(x_angles), torch.cos(x_angles)  # (1,W,D/4)

        # Assign: Y encodings in first half, X in second half
        pe[:, :, 0:d_model_half:2] = y_sin
        pe[:, :, 1:d_model_half:2] = y_cos
        pe[:, :, d_model_half::2] = x_sin
        pe[:, :, d_model_half + 1::2] = x_cos

        return pe  # (height, width, d_model)

    def forward(self, x):
        return x

In [31]:
arc_dataset = ARCDataset(training_data)
training_loader = DataLoader(arc_dataset, batch_size=32, shuffle=True)



In [41]:
for x in arc_dataset:
    pp(x.tasks)
    break

[{'input': [[7, 9], [4, 3]],
  'output': [[7, 9, 7, 9, 7, 9],
             [4, 3, 4, 3, 4, 3],
             [9, 7, 9, 7, 9, 7],
             [3, 4, 3, 4, 3, 4],
             [7, 9, 7, 9, 7, 9],
             [4, 3, 4, 3, 4, 3]]},
 {'input': [[8, 6], [6, 4]],
  'output': [[8, 6, 8, 6, 8, 6],
             [6, 4, 6, 4, 6, 4],
             [6, 8, 6, 8, 6, 8],
             [4, 6, 4, 6, 4, 6],
             [8, 6, 8, 6, 8, 6],
             [6, 4, 6, 4, 6, 4]]},
 [{'input': [[3, 2], [7, 8]]}]]


In [5]:
dm = DiffusionModel(1, 32, 64, 2)
pe = dm.get_2d_positional_encoding(5, 7, 64)
print(pe.shape)
try:
    dm.get_2d_positional_encoding(5, 7, 62)
except ValueError as e:
    print("caught:", str(e)[:50])

torch.Size([5, 7, 64])
caught: d_model must be divisible by 4 for 2D positional e


In [None]:
# ==== Utilities and Dataset for ARC A->B pairs ====
import math
from typing import List, Tuple, Dict, Any

NUM_COLORS = 10  # ARC colors 0..9

def grid_to_onehot(grid: List[List[int]]) -> torch.Tensor:
    """
    Convert HxW int grid (values in [0,9]) to (C=10, H, W) one-hot float tensor.
    """
    h = len(grid)
    w = len(grid[0]) if h > 0 else 0
    t = torch.tensor(grid, dtype=torch.long)
    oh = F.one_hot(t, num_classes=NUM_COLORS).permute(2, 0, 1).float()  # (C,H,W)
    return oh


def onehot_to_labels(oh: torch.Tensor) -> torch.Tensor:
    """
    (C,H,W) -> (H,W) argmax labels.
    """
    return oh.argmax(dim=0)


class ARCPairsDataset(Dataset):
    """
    Produces (A, B, task_id) from ARC training data where A=input grid, B=output grid.
    """
    def __init__(self, data: Dict[str, Any]):
        self.samples = []
        for task_id, task in data.items():
            for io in task.get('train', []):
                A = io['input']
                B = io['output']
                self.samples.append({
                    'task_id': task_id,
                    'A': A,
                    'B': B,
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        A_oh = grid_to_onehot(s['A'])  # (C,H,W)
        B_oh = grid_to_onehot(s['B'])
        A_lbl = torch.tensor(s['A'], dtype=torch.long)  # (H,W)
        B_lbl = torch.tensor(s['B'], dtype=torch.long)
        return {
            'task_id': s['task_id'],
            'A_oh': A_oh,
            'B_oh': B_oh,
            'A_lbl': A_lbl,
            'B_lbl': B_lbl,
            'A_size': A_lbl.shape,  # (H,W)
            'B_size': B_lbl.shape,
        }


def collate_varsize(batch: List[Dict[str, Any]]):
    """
    Keep variable-sized tensors in lists; stack what's stackable.
    """
    out = {
        'task_id': [b['task_id'] for b in batch],
        'A_oh': [b['A_oh'] for b in batch],
        'B_oh': [b['B_oh'] for b in batch],
        'A_lbl': [b['A_lbl'] for b in batch],
        'B_lbl': [b['B_lbl'] for b in batch],
        'A_size': [b['A_size'] for b in batch],
        'B_size': [b['B_size'] for b in batch],
    }
    return out

In [None]:
# ==== Encoder-Decoder with shared weights ====
class ConvEncoder(nn.Module):
    def __init__(self, in_channels=NUM_COLORS, d_model=128):
        super().__init__()
        # simple CNN to get global embedding
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.proj = nn.Linear(128, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,C,H,W) or (C,H,W)
        if x.dim() == 3:
            x = x.unsqueeze(0)
        h = self.net(x)  # (B,128,H,W)
        h = h.mean(dim=(2, 3))  # (B,128)
        z = self.proj(h)  # (B,d_model)
        z = F.normalize(z, dim=-1)
        return z


class ConvDecoder(nn.Module):
    def __init__(self, out_channels=NUM_COLORS, d_model=128):
        super().__init__()
        # decode with learned constant spatial tokens + convs
        self.d_model = d_model
        self.pos_h = 30
        self.pos_w = 30
        self.pos = nn.Parameter(torch.randn(1, d_model, self.pos_h, self.pos_w) * 0.02)
        self.z_proj = nn.Linear(d_model, 128)
        self.pos_proj = nn.Conv2d(d_model, 128, kernel_size=1)
        self.net = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels, 1),
        )

    def forward(self, z: torch.Tensor, out_size: Tuple[int, int]) -> torch.Tensor:
        # z: (B,d_model), out: (B,C,H,W) logits resized to out_size
        if z.dim() == 1:
            z = z.unsqueeze(0)
        B, D = z.shape
        H, W = out_size
        base = self.pos.expand(B, -1, -1, -1)  # (B,D,h,w)
        base128 = F.relu(self.pos_proj(base))  # (B,128,h,w)
        z_expand = self.z_proj(z).unsqueeze(-1).unsqueeze(-1)  # (B,128,1,1)
        h = base128 + z_expand
        logits = self.net(h)  # (B,C,h,w)
        logits = F.interpolate(logits, size=(H, W), mode='nearest')
        return logits


class DirectionARC(nn.Module):
    def __init__(self, d_model=128):
        super().__init__()
        self.encoder = ConvEncoder(d_model=d_model)
        self.decoder = ConvDecoder(d_model=d_model)

    def encode(self, x_chw: torch.Tensor) -> torch.Tensor:
        return self.encoder(x_chw)

    def decode(self, z: torch.Tensor, out_size: Tuple[int, int]) -> torch.Tensor:
        return self.decoder(z, out_size)  # logits (B,C,H,W)

    def forward(self, A_oh: torch.Tensor, B_oh: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int,int]]:
        # Accept single example (C,H,W)
        H, W = A_oh.shape[-2], A_oh.shape[-1]
        zA = self.encode(A_oh)
        zB = self.encode(B_oh)
        logitsA = self.decode(zA, (H, W))
        logitsB = self.decode(zB, (H, W))
        return zA, zB, logitsB, (H, W)

In [None]:
# ==== Losses: direction contrastive and reconstruction ====

def cosine_sim(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    return (a * b).sum(dim=-1)


def supervised_direction_loss(dirs: torch.Tensor, task_ids: List[str], temp: float = 0.1):
    """
    Supervised contrastive loss on direction vectors. Positives are same-task pairs, negatives different tasks.
    dirs: (N, d) normalized direction vectors
    task_ids: list of length N
    """
    N = dirs.size(0)
    dirs = F.normalize(dirs, dim=-1)
    # map task ids to ints
    uniq = {tid: i for i, tid in enumerate(sorted(set(task_ids)))}
    labels = torch.tensor([uniq[t] for t in task_ids], device=dirs.device)

    sim = dirs @ dirs.t() / temp  # (N,N)
    # mask self
    eye = torch.eye(N, device=dirs.device, dtype=torch.bool)
    sim = sim.masked_fill(eye, float('-inf'))

    loss_sum = 0.0
    count = 0
    for i in range(N):
        pos_mask = (labels == labels[i]) & (~eye[i])  # exclude self
        num_pos = int(pos_mask.sum().item())
        if num_pos == 0:
            continue
        # log-softmax over all j != i
        logits_i = sim[i]  # (N,)
        log_prob = logits_i - torch.logsumexp(logits_i, dim=0)
        loss_i = -log_prob[pos_mask].mean()
        loss_sum = loss_sum + loss_i
        count += 1
    if count == 0:
        return torch.tensor(0.0, device=dirs.device, requires_grad=True)
    return loss_sum / count


def reconstruction_loss(logits: torch.Tensor, target_labels: torch.Tensor):
    """
    Cross-entropy over pixels. logits: (B,C,H,W) or (C,H,W). target: (H,W)
    """
    if logits.dim() == 3:
        logits = logits.unsqueeze(0)
    B, C, H, W = logits.shape
    target = target_labels.view(1, H, W).expand(B, -1, -1)
    return F.cross_entropy(logits, target)

In [None]:
# ==== Minimal training loop (single step sanity check) ====
from collections import defaultdict

device = 'cuda' if torch.cuda.is_available() else 'cpu'

pairs_ds = ARCPairsDataset(training_data)
# We'll create batches of size N and form two views per batch by shuffling indices.
loader = DataLoader(pairs_ds, batch_size=8, shuffle=True, collate_fn=collate_varsize)

model = DirectionARC(d_model=128).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

alpha_dir = 1.0
alpha_rec = 1.0

batch = next(iter(loader))
# Prepare tensors on device; keep variable sizes per item
zA_list, zB_list, zC_list, zD_list = [], [], [], []
rec_logits_list, rec_targets_list = [], []

# We create two groups in the same batch by splitting it in half.
# Assumption: items from same task may repeat across epochs, but for this sanity we pair indices 0..k-1 with k..2k-1.
Bsz = len(batch['A_oh'])
mid = Bsz // 2 if Bsz >= 2 else Bsz
idx1 = list(range(0, mid))
idx2 = list(range(mid, min(Bsz, mid*2)))
if len(idx2) < len(idx1):
    idx1 = idx1[:len(idx2)]

model.train()
opt.zero_grad()

for i, j in zip(idx1, idx2):
    A_oh = batch['A_oh'][i].unsqueeze(0).to(device)  # (1,C,H,W)
    B_oh = batch['B_oh'][i].unsqueeze(0).to(device)
    C_oh = batch['A_oh'][j].unsqueeze(0).to(device)
    D_oh = batch['B_oh'][j].unsqueeze(0).to(device)

    zA = model.encode(A_oh)
    zB = model.encode(B_oh)
    zC = model.encode(C_oh)
    zD = model.encode(D_oh)

    zA_list.append(zA)
    zB_list.append(zB)
    zC_list.append(zC)
    zD_list.append(zD)

    # Reconstruction for B
    H, W = batch['B_size'][i]
    logitsB = model.decode(zB, (H, W))  # (1,C,H,W)
    targetB = batch['B_lbl'][i].to(device)
    rec_logits_list.append(logitsB)
    rec_targets_list.append(targetB)

if len(zA_list) == 0:
    raise RuntimeError('Batch too small to form pairs for direction loss.')

zA = torch.cat(zA_list, dim=0)
zB = torch.cat(zB_list, dim=0)
zC = torch.cat(zC_list, dim=0)
zD = torch.cat(zD_list, dim=0)

loss_dir = direction_contrastive_loss(zA, zB, zC, zD, temp=0.1)

# Reconstruction loss over all collected targets
loss_rec = 0.0
for logits, tgt in zip(rec_logits_list, rec_targets_list):
    loss_rec = loss_rec + reconstruction_loss(logits, tgt)
loss_rec = loss_rec / max(1, len(rec_logits_list))

loss = alpha_dir * loss_dir + alpha_rec * loss_rec
loss.backward()
opt.step()

print({
    'loss': float(loss.item()),
    'loss_dir': float(loss_dir.item()),
    'loss_rec': float(loss_rec.item()),
})