# Tiny-ONN-ARC


In [None]:
import torch
import torch.nn.functional as F
import json
from pathlib import Path
from collections import defaultdict
import sys

# Add the tiny_onn_arc experiment directory to sys.path so that its config, model and data modules can be imported
sys.path.append('../exp/tiny_onn_arc')
from config import TinyOnnArcConfig
from model import TinyOnnForArcReconstruction
from data import ArcDataset, pad_grid, augment_grid


In [1]:
# 2. Define the model configuration and load the model
config = TinyOnnArcConfig()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = TinyOnnForArcReconstruction(config).to(device)

# Path to the pretrained weights.
# NOTE(review): this path is relative to the current working directory, while
# the sys.path entry used for the imports is relative to a sibling directory
# ('../exp/...') -- confirm both resolve from wherever the notebook is run.
checkpoint_path = Path("exp/ARC-Killer.pt")
if not checkpoint_path.exists():
    # No checkpoint on disk: keep going with the freshly initialised weights.
    print(f"Warning: Checkpoint {checkpoint_path} not found. Using randomly initialized model.")
else:
    # The checkpoint is expected to be a dict holding the weights under
    # the 'model_state_dict' key.
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    print(f"Loaded model from {checkpoint_path}")

# Switch to inference mode; as the last expression of the cell this also
# displays the module.
model.eval()


NameError: name 'TinyOnnArcConfig' is not defined

In [None]:
# 3. Data-processing helper functions
# Token id appended after every grid row when a grid is flattened into a sequence.
newline_token = 10
# Token id used as the fill value and as the "empty" marker when decoded grids
# are cropped.
# NOTE(review): 0 is presumably also a legitimate cell value, so padding and
# real zero cells cannot be told apart -- confirm this is acceptable.
pad_token = 0 # For submission, we assume 0 is padding

def to_grid_and_crop(seq: torch.Tensor, original_h: int, original_w: int) -> list[list[int]]:
    """Decode a flat token sequence into a grid cropped to its non-padding content.

    The newline tokens are dropped, the remaining tokens are laid out row-major
    on a 30x30 canvas (truncated or right-padded with ``pad_token`` as needed),
    and the smallest rectangle containing every non-``pad_token`` cell is
    returned.

    Args:
        seq: Token ids to decode. Boolean-mask indexing flattens the input, so
            the tokens are consumed in row-major order.
            NOTE(review): assumes the sequence encodes rows of exactly 30
            pixels each (plus a newline) -- confirm against the training format.
        original_h: Unused; kept for interface compatibility.
        original_w: Unused; kept for interface compatibility.

    Returns:
        The cropped grid as nested lists, or ``[[0]]`` when the sequence holds
        no non-padding pixel.
    """
    max_dim = 30  # side length of the square canvas the sequence is laid out on
    capacity = max_dim * max_dim

    # Drop only the row separators. Pad tokens must stay where they are: they
    # share the value 0 with ordinary cells, and removing them from the flat
    # sequence would shift every later pixel and destroy the 2D layout.
    pixel_seq = seq[seq != newline_token]

    if pixel_seq.numel() == 0:
        return [[0]]  # minimal valid grid

    # Fit the sequence to the canvas: cut any overflow, then right-pad.
    pixel_seq = pixel_seq[:capacity]
    padded_pixel_seq = F.pad(pixel_seq, (0, capacity - pixel_seq.numel()), "constant", pad_token)
    temp_grid = padded_pixel_seq.reshape(max_dim, max_dim)

    # Bounding box of the non-padding cells.
    content = temp_grid != pad_token
    row_idx = torch.where(content.any(dim=1))[0]
    col_idx = torch.where(content.any(dim=0))[0]

    if row_idx.numel() == 0 or col_idx.numel() == 0:  # every cell is padding
        return [[0]]

    min_r, max_r = int(row_idx.min()), int(row_idx.max())
    min_c, max_c = int(col_idx.min()), int(col_idx.max())

    return temp_grid[min_r : max_r + 1, min_c : max_c + 1].tolist()

def preprocess_input_grid(grid: list[list[int]], max_h: int = 30, max_w: int = 30) -> torch.Tensor:
    """Flatten an input grid into the 1D token sequence fed to the model.

    The grid is passed through ``pad_grid`` with ``max_h`` / ``max_w`` and a
    newline token is appended after every resulting row. No augmentation is
    applied, since this is the evaluation path.

    Args:
        grid: The input grid as nested lists of ints.
        max_h: Height forwarded to ``pad_grid``.
        max_w: Width forwarded to ``pad_grid``.

    Returns:
        A 1D tensor made of each row followed by one newline token.
    """
    padded_rows = pad_grid(grid, max_h, max_w)
    row_end = torch.tensor([newline_token], dtype=torch.long)

    # Interleave rows and separators, then join everything in a single pass.
    pieces = []
    for row in padded_rows:
        pieces.append(row)
        pieces.append(row_end)
    return torch.cat(pieces)


In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tiny-ONN-ARC\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import json\n",
    "from pathlib import Path\n",
    "from collections import defaultdict\n",
    "import sys\n",
    "\n",
    "# Add parent directory to path to import TinyOnnArcConfig, TinyOnnForArcReconstruction, ArcDataset\n",
    "sys.path.append('../exp/tiny_onn_arc')\n",
    "from config import TinyOnnArcConfig\n",
    "from model import TinyOnnForArcReconstruction\n",
    "from data import ArcDataset, pad_grid, augment_grid\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'TinyOnnArcConfig' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m      1\u001b[39m \u001b[38;5;66;03m# 2. 定义模型配置和加载模型\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m config = \u001b[43mTinyOnnArcConfig\u001b[49m()\n\u001b[32m      3\u001b[39m device = \u001b[33m\"\u001b[39m\u001b[33mcuda\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m torch.cuda.is_available() \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mcpu\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m      5\u001b[39m model = TinyOnnForArcReconstruction(config).to(device)\n",
      "\u001b[31mNameError\u001b[39m: name 'TinyOnnArcConfig' is not defined"
     ]
    }
   ],
   "source": [
    "# 2. 定义模型配置和加载模型\n",
    "config = TinyOnnArcConfig()\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "model = TinyOnnForArcReconstruction(config).to(device)\n",
    "checkpoint_path = Path(\"exp/ARC-Killer.pt\") # 预训练权重路径\n",
    "if checkpoint_path.exists():\n",
    "    ckpt = torch.load(checkpoint_path, map_location=device)\n",
    "    model.load_state_dict(ckpt['model_state_dict'])\n",
    "    print(f\"Loaded model from {checkpoint_path}\")\n",
    "else:\n",
    "    print(f\"Warning: Checkpoint {checkpoint_path} not found. Using randomly initialized model.\")\n",
    "\n",
    "model.eval()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. 数据处理辅助函数\n",
    "newline_token = 10\n",
    "pad_token = 0 # For submission, we assume 0 is padding\n",
    "\n",
    "def to_grid_and_crop(seq: torch.Tensor, original_h: int, original_w: int) -> list[list[int]]:\n",
    "    # Remove newline tokens and pad tokens (0s) from the sequence\n",
    "    pixel_seq = seq[seq != newline_token]\n",
    "    pixel_seq = pixel_seq[pixel_seq != pad_token]\n",
    "\n",
    "    # If the sequence is empty after removing special tokens, return an empty 1x1 grid or handle as error\n",
    "    if pixel_seq.numel() == 0:\n",
    "        return [[0]] # Return a minimal valid grid\n",
    "\n",
    "    # Attempt to infer the grid dimensions based on the original output dimensions\n",
    "    # This is a heuristic and might need refinement based on model behavior\n",
    "    # For now, we'll try to reshape to original_h x original_w, then crop\n",
    "    \n",
    "    # First, try to reshape to a square or a rectangle that fits the number of pixels\n",
    "    # This is a simplification. A more robust solution might involve trying various factors\n",
    "    # or using a separate model to predict output dimensions.\n",
    "    \n",
    "    # For ARC, output grids are often small and compact. We'll try to find the smallest bounding box.\n",
    "    # Reshape to a large enough grid (e.g., 30x30) and then find the actual content bounds.\n",
    "    max_dim = 30 # Max possible dimension for ARC grids\n",
    "    padded_pixel_seq = F.pad(pixel_seq, (0, max_dim * max_dim - pixel_seq.numel()), \"constant\", pad_token)\n",
    "    temp_grid = padded_pixel_seq.view(max_dim, max_dim)\n",
    "\n",
    "    # Find the bounding box of non-padding pixels\n",
    "    rows = torch.any(temp_grid != pad_token, dim=1)\n",
    "    cols = torch.any(temp_grid != pad_token, dim=0)\n",
    "\n",
    "    if not torch.any(rows) or not torch.any(cols): # If all pixels are padding\n",
    "        return [[0]]\n",
    "\n",
    "    min_r, max_r = torch.where(rows)[0].min(), torch.where(rows)[0].max()\n",
    "    min_c, max_c = torch.where(cols)[0].min(), torch.where(cols)[0].max()\n",
    "\n",
    "    cropped_grid = temp_grid[min_r : max_r + 1, min_c : max_c + 1]\n",
    "\n",
    "    return cropped_grid.tolist()\n",
    "\n",
    "def preprocess_input_grid(grid: list[list[int]], max_h: int = 30, max_w: int = 30) -> torch.Tensor:\n",
    "    input_tensor = pad_grid(grid, max_h, max_w)\n",
    "    # No augmentation for evaluation\n",
    "    \n",
    "    input_rows = [torch.cat((row, torch.tensor([newline_token], dtype=torch.long))) for row in input_tensor]\n",
    "    input_seq = torch.cat(input_rows)\n",
    "    return input_seq\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. 实现推理逻辑\n",
    "def predict_arc_output(model: TinyOnnForArcReconstruction, input_grid: list[list[int]], max_len: int = 1861) -> list[list[int]]:\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        input_seq = preprocess_input_grid(input_grid).to(model.model.embeddings.position_ids.device)\n",
    "        \n",
    "        # Autoregressive generation\n",
    "        generated_sequence = input_seq.clone()\n",
    "        \n",
    "        # Determine the start of the output sequence based on input length\n",
    "        output_start_idx = len(input_seq)\n",
    "\n",
    "        for _ in range(max_len - len(input_seq)): # Generate up to max_len tokens\n",
    "            if generated_sequence.numel() >= config.max_position_embeddings:\n",
    "                break\n",
    "            \n",
    "            logits, _ = model(input_ids=generated_sequence.unsqueeze(0))\n",
    "            next_token_logits = logits[0, -1, :]\n",
    "            next_token = torch.argmax(next_token_logits).item()\n",
    "            \n",
    "            generated_sequence = torch.cat((generated_sequence, torch.tensor([next_token], device=generated_sequence.device)))\n",
    "            \n",
    "            # Stop if we generate a newline token and the sequence length is reasonable\n",
    "            # This heuristic might need adjustment\n",
    "            if next_token == newline_token and generated_sequence.numel() > output_start_idx + 1:\n",
    "                # Check if we have generated enough rows to form a grid\n",
    "                # A simple heuristic: if we have at least one full row (width + newline)\n",
    "                # and the last token was a newline, we can consider stopping.\n",
    "                # This is a simplification; a more robust approach would involve predicting output dimensions.\n",
    "                # For now, we'll rely on the to_grid_and_crop to handle the final shape.\n",
    "                pass # Continue generating until max_len or a more robust stop condition\n",
    "        \n",
    "        # Extract only the generated output part\n",
    "        predicted_output_seq = generated_sequence[output_start_idx:]\n",
    "        \n",
    "        # Convert to 2D grid and crop for attempt_1\n",
    "        # For attempt_1, we'll use a simpler to_grid that might include padding to 30x30\n",
    "        # For now, let's use to_grid_and_crop for both, as it's more robust for ARC submission.\n",
    "        # If the model is trained to output 30x30 padded grids, this needs adjustment.\n",
    "        # Based on train.py's visualize_predictions, it pads to 30x30, so we should too for attempt_1\n",
    "        \n",
    "        # Re-implement to_grid from train.py for attempt_1\n",
    "        def to_grid_for_attempt1(seq: torch.Tensor, h: int = 30, w: int = 30) -> list[list[int]]:\n",
    "            pixel_seq = seq[seq != newline_token]\n",
    "            pixel_seq = pixel_seq[pixel_seq != -100] # Remove -100 pad token from labels\n",
    "            \n",
    "            num_pixels = pixel_seq.numel()\n",
    "            if num_pixels > h * w:\n",
    "                pixel_seq = pixel_seq[:h*w]\n",
    "            elif num_pixels < h * w:\n",
    "                pixel_seq = F.pad(pixel_seq, (0, h * w - num_pixels), \"constant\", 0)\n",
    "            \n",
    "            return pixel_seq.view(h, w).tolist()\n",
    "            \n",
    "        attempt1_grid = to_grid_for_attempt1(predicted_output_seq)\n",
    "        attempt2_grid = to_grid_and_crop(predicted_output_seq, 0, 0) # original_h, original_w are not used in to_grid_and_crop\n",
    "        \n",
    "        return attempt1_grid, attempt2_grid\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "{\n",
    " \"cells\": [\n",
    "  {\n",
    "   \"cell_type\": \"markdown\",\n",
    "   \"metadata\": {},\n",
    "   \"source\": [\n",
    "    \"# Kaggle ARC Prize 2025 - Tiny-ONN-ARC Submission Notebook\"\n",
    "   ]\n",
    "  },\n",
    "  {\n",
    "   \"cell_type\": \"code\",\n",
    "   \"execution_count\": null,\n",
    "   \"metadata\": {},\n",
    "   \"outputs\": [],\n",
    "   \"source\": [\n",
    "    \"import torch\\n\",\n",
    "    \"import torch.nn as nn\\n\",\n",
    "    \"import torch.nn.functional as F\\n\",\n",
    "    \"import json\\n\",\n",
    "    \"from pathlib import Path\\n\",\n",
    "    \"from collections import defaultdict\\n\",\n",
    "    \"import sys\\n\",\n",
    "    \"import math\\n\",\n",
    "    \"from typing import Any\\n\",\n",
    "    \"from einops import rearrange\\n\",\n",
    "    \"from torch.utils.checkpoint import checkpoint\\n\"\n",
    "   ]\n",
    "  },\n",
    "  {\n",
    "   \"cell_type\": \"code\",\n",
    "   \"execution_count\": null,\n",
    "   \"metadata\": {},\n",
    "   \"outputs\": [],\n",
    "   \"source\": [\n",
    "    \"# TinyOnnArcConfig class (copied from exp/tiny_onn_arc/config.py)\\n\",\n",
    "    \"class TinyOnnArcConfig:\\n\",\n",
    "    \"    model_type = \\\"tiny_onn_arc\\\"\\n\",\n",
    "    \"\\n\",\n",
    "    \"    vocab_size: int = 12\\n\",\n",
    "    \"    mask_token_id: int = 11\\n\",\n",
    "    \"    hidden_size: int = 256\\n\",\n",
    "    \"    num_hidden_layers: int = 16\\n\",\n",
    "    \"    max_position_embeddings: int = 1861\\n\",\n",
    "    \"    type_vocab_size: int = 2\\n\",\n",
    "    \"\\n\",\n",
    "    \"    # DynSMHA specific\\n\",\n",
    "    \"    max_attention_experts: int = 96\\n\",\n",
    "    \"    min_attention_experts: int = 32\\n\",\n",
    "    \"    head_dim: int = 24\\n\",\n",
    "    \"\\n\",\n",
    "    \"    # DynMoE specific\\n\",\n",
    "    \"    max_moe_experts: int = 64\\n\",\n",
    "    \"    min_moe_experts: int = 32\\n\",\n",
    "    \"    intermediate_size: int = 16\\n\",\n",
    "    \"\\n\",\n",
    "    \"    # Loss weights\\n\",\n",
    "    \"    w_ce_smha: float = 1.0\\n\",\n",
    "    \"    w_kl_smha: float = 1.0\\n\",\n",
    "    \"    w_aux_smha: float = 1.0\\n\",\n",
    "    \"    w_ce_moe: float = 1.0\\n\",\n",
    "    \"    w_kl_moe: float = 1.0\\n\",\n",
    "    \"    w_aux_moe: float = 1.0\\n\",\n",
    "    \"\\n\",\n",
    "    \"    # Predictive Integrity Score specific\\n\",\n",
    "    \"    pi_alpha: float = 64.0\\n\",\n",
    "    \"    pi_gamma: float = 0.5\\n\",\n",
    "    \"    \\n\",\n",
    "    \"    def __init__(self, **kwargs):\\n\",\n",
    "    \"        for key, value in kwargs.items():\\n\",\n",
    "    \"            setattr(self, key, value)\\n\"\n",
    "   ]\n",
    "  },\n",
    "  {\n",
    "   \"cell_type\": \"code\",\n",
    "   \"execution_count\": null,\n",
    "   \"metadata\": {},\n",
    "   \"outputs\": [],\n",
    "   \"source\": [\n",
    "    \"# Model related classes (copied from exp/tiny_onn_arc/model.py)\\n\",\n",
    "    \"ExpertID = tuple[str, int, int]\\n\",\n",
    "    \"\\n\",\n",
    "    \"class STEFunction(torch.autograd.Function):\\n\",\n",
    "    \"    @staticmethod\\n\",\n",
    "    \"    def forward(ctx: Any, scores: torch.Tensor) -> torch.Tensor:\\n\",\n",
    "    \"        return (scores > 0).to(scores.dtype)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    @staticmethod\\n\",\n",
    "    \"    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:\\n\",\n",
    "    \"        return grad_output\\n\",\n",
    "    \"\\n\",\n",
    "    \"@torch.jit.script\\n\",\n",
    "    \"def _gating_logic(\\n\",\n",
    "    \"    hidden_states: torch.Tensor, sim_matrix: torch.Tensor, gates: torch.Tensor, max_experts: int, min_experts: int\\n\",\n",
    "    \") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\\n\",\n",
    "    \"    b, t, c = hidden_states.shape\\n\",\n",
    "    \"    flat_hidden_states = hidden_states.view(b * t, c)\\n\",\n",
    "    \"    logits = torch.matmul(F.normalize(flat_hidden_states, dim=-1), F.normalize(sim_matrix, dim=0)) - torch.sigmoid(\\n\",\n",
    "    \"        gates\\n\",\n",
    "    \"    )\\n\",\n",
    "    \"    gated_logits = F.relu(logits)\\n\",\n",
    "    \"    activation_mask = STEFunction.apply(gated_logits)\\n\",\n",
    "    \"    inactive_mask = torch.sum(activation_mask, dim=1) == 0\\n\",\n",
    "    \"    if inactive_mask.any():\\n\",\n",
    "    \"        inactive_logits = logits[inactive_mask]\\n\",\n",
    "    \"        fallback_indices = torch.topk(inactive_logits, min_experts, dim=-1).indices\\n\",\n",
    "    \"        inactive_b_indices = torch.where(inactive_mask)[0]\\n\",\n",
    "    \"        activation_mask.index_put_(\\n\",\n",
    "    \"            (inactive_b_indices.unsqueeze(1).expand(-1, min_experts), fallback_indices),\\n\",\n",
    "    \"            torch.tensor(1.0, device=hidden_states.device, dtype=activation_mask.dtype),\\n\",\n",
    "    \"        )\\n\",\n",
    "    \"    gated_logits_masked = torch.where(\\n\",\n",
    "    \"        activation_mask > 0,\\n\",\n",
    "    \"        gated_logits,\\n\",\n",
    "    \"        torch.tensor(-torch.inf, dtype=gated_logits.dtype, device=gated_logits.device),\\n\",\n",
    "    \"    )\\n\",\n",
    "    \"    return F.softmax(gated_logits_masked, dim=-1), logits, activation_mask, gated_logits\\n\",\n",
    "    \"\\n\",\n",
    "    \"class GatingNetwork(nn.Module):\\n\",\n",
    "    \"    def __init__(self, config: TinyOnnArcConfig, max_experts: int, min_experts: int):\\n\",\n",
    "    \"        super().__init__()\\n\",\n",
    "    \"        self.config = config\\n\",\n",
    "    \"        self.max_experts = max_experts\\n\",\n",
    "    \"        self.min_experts = min_experts\\n\",\n",
    "    \"        self.sim_matrix = nn.Parameter(torch.randn(config.hidden_size, max_experts))\\n\",\n",
    "    \"        self.gates = nn.Parameter(torch.zeros(max_experts))\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\\n\",\n",
    "    \"        active_expert_probs, logits, activation_mask, gated_logits = _gating_logic(\\n\",\n",
    "    \"            hidden_states, self.sim_matrix, self.gates, self.max_experts, self.min_experts\\n\",\n",
    "    \"        )\\n\",\n",
    "    \"        return active_expert_probs, {\\\"logits\\\": logits, \\\"activation_mask\\\": activation_mask, \\\"gated_logits\\\": gated_logits}\\n\",\n",
    "    \"\\n\",\n",
    "    \"class DynSMHALayer(nn.Module):\\n\",\n",
    "    \"    def __init__(self, config: TinyOnnArcConfig, is_causal: bool = False):\\n\",\n",
    "    \"        super().__init__()\\n\",\n",
    "    \"        self.config = config\\n\",\n",
    "    \"        self.max_experts = config.max_attention_experts\\n\",\n",
    "    \"        self.is_causal = is_causal\\n\",\n",
    "    \"        self.gating_network = GatingNetwork(config, self.max_experts, config.min_attention_experts)\\n\",\n",
    "    \"        self.q_proj = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.head_dim))\\n\",\n",
    "    \"        self.k_proj = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.head_dim))\\n\",\n",
    "    \"        self.v_proj = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.head_dim))\\n\",\n",
    "    \"        self.o_proj = nn.Parameter(torch.empty(self.max_experts, config.head_dim, config.hidden_size))\\n\",\n",
    "    \"\\n\",\n",
    "    \"        for i in range(self.max_experts):\\n\",\n",
    "    \"            nn.init.xavier_uniform_(self.q_proj[i])\\n\",\n",
    "    \"            nn.init.xavier_uniform_(self.k_proj[i])\\n\",\n",
    "    \"            nn.init.xavier_uniform_(self.v_proj[i])\\n\",\n",
    "    \"            nn.init.xavier_uniform_(self.o_proj[i])\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward_gating(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\\n\",\n",
    "    \"        return self.gating_network(hidden_states)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward_main(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:\\n\",\n",
    "    \"        B, T, C = hidden_states.shape\\n\",\n",
    "    \"        routing_weights_reshaped = rearrange(routing_weights, \\\"(b t) e -> b t e\\\", b=B)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        q_experts = torch.einsum(\\\"btc,ech->bteh\\\", hidden_states, self.q_proj)\\n\",\n",
    "    \"        k_experts = torch.einsum(\\\"btc,ech->bteh\\\", hidden_states, self.k_proj)\\n\",\n",
    "    \"        v_experts = torch.einsum(\\\"btc,ech->bteh\\\", hidden_states, self.v_proj)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        q_agg = torch.einsum(\\\"bteh,bte->bth\\\", q_experts, routing_weights_reshaped)\\n\",\n",
    "    \"        k_agg = torch.einsum(\\\"bteh,bte->bth\\\", k_experts, routing_weights_reshaped)\\n\",\n",
    "    \"        v_agg = torch.einsum(\\\"bteh,bte->bth\\\", v_experts, routing_weights_reshaped) \\n\",\n",
    "    \"\\n\",\n",
    "    \"        q = rearrange(q_agg, \\\"b t h -> b 1 t h\\\")\\n\",\n",
    "    \"        k = rearrange(k_agg, \\\"b t h -> b 1 t h\\\")\\n\",\n",
    "    \"        v = rearrange(v_agg, \\\"b t h -> b 1 t h\\\")\\n\",\n",
    "    \"\\n\",\n",
    "    \"        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)\\n\",\n",
    "    \"        attn_output = rearrange(attn_output, \\\"b 1 t h -> b t h\\\")\\n\",\n",
    "    \"\\n\",\n",
    "    \"        output_experts = torch.einsum(\\\"bth,ehc->btec\\\", attn_output, self.o_proj)\\n\",\n",
    "    \"        final_output = torch.einsum(\\\"btec,bte->btc\\\", output_experts, routing_weights_reshaped)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        cache = {\\\"final_output\\\": final_output, \\\"routing_weights\\\": routing_weights, \\\"B\\\": B, \\\"T\\\": T}\\n\",\n",
    "    \"        return final_output, cache\\n\",\n",
    "    \"\\n\",\n",
    "    \"class DynamicMoELayer(nn.Module):\\n\",\n",
    "    \"    def __init__(self, config: TinyOnnArcConfig):\\n\",\n",
    "    \"        super().__init__()\\n\",\n",
    "    \"        self.config = config\\n\",\n",
    "    \"        self.max_experts = config.max_moe_experts\\n\",\n",
    "    \"        self.gating_network = GatingNetwork(config, self.max_experts, config.min_moe_experts)\\n\",\n",
    "    \"        self.w1 = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.intermediate_size))\\n\",\n",
    "    \"        self.w2 = nn.Parameter(torch.empty(self.max_experts, config.intermediate_size, config.hidden_size))\\n\",\n",
    "    \"        for i in range(self.max_experts):\\n\",\n",
    "    \"            nn.init.kaiming_uniform_(self.w1[i], a=math.sqrt(5))\\n\",\n",
    "    \"            nn.init.kaiming_uniform_(self.w2[i], a=math.sqrt(5))\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward_gating(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\\n\",\n",
    "    \"        return self.gating_network(hidden_states)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward_main(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, gate_cache: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, Any]]:\\n\",\n",
    "    \"        B, T, C = hidden_states.shape\\n\",\n",
    "    \"        routing_weights_reshaped = rearrange(routing_weights, \\\"(b t) e -> b t e\\\", b=B)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        intermediate_experts = F.gelu(torch.einsum(\\\"btc,eci->btei\\\", hidden_states, self.w1))\\n\",\n",
    "    \"        output_experts = torch.einsum(\\\"btei,eic->btec\\\", intermediate_experts, self.w2)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        final_output = torch.einsum(\\\"btec,bte->btc\\\", output_experts, routing_weights_reshaped)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        cache = {\\\"final_output\\\": final_output, \\\"gate_cache\\\": gate_cache, \\\"routing_weights\\\": routing_weights, \\\"B\\\": B, \\\"T\\\": T, \\\"normed_hs\\\": hidden_states, \\\"layer\\\": self}\\n\",\n",
    "    \"        return final_output, cache\\n\",\n",
    "    \"\\n\",\n",
    "    \"class Block(nn.Module):\\n\",\n",
    "    \"    def __init__(self, config: TinyOnnArcConfig, layer_index: int):\\n\",\n",
    "    \"        super().__init__()\\n\",\n",
    "    \"        self.layer_index = layer_index\\n\",\n",
    "    \"        self.ln1 = nn.LayerNorm(config.hidden_size)\\n\",\n",
    "    \"        self.smha_layer = DynSMHALayer(config, is_causal=True)\\n\",\n",
    "    \"        self.ln2 = nn.LayerNorm(config.hidden_size)\\n\",\n",
    "    \"        self.moe_layer = DynamicMoELayer(config)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[ExpertID, Any]]:\\n\",\n",
    "    \"        residual = hidden_states\\n\",\n",
    "    \"        normed_hs_smha = self.ln1(hidden_states)\\n\",\n",
    "    \"        smha_routing_weights, smha_gate_cache = self.smha_layer.forward_gating(normed_hs_smha)\\n\",\n",
    "    \"        B, T, C = hidden_states.shape\\n\",\n",
    "    \"        smha_routing_weights_flat = smha_routing_weights.view(B * T, -1)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        def smha_checkpointed_fn(hs_norm: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:\\n\",\n",
    "    \"            return self.smha_layer.forward_main(hs_norm, smha_routing_weights_flat)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        attn_output, smha_cache = checkpoint(smha_checkpointed_fn, normed_hs_smha, use_reentrant=False)\\n\",\n",
    "    \"        smha_cache[\\\"gate_cache\\\"] = smha_gate_cache\\n\",\n",
    "    \"        smha_cache[\\\"normed_hs\\\"] = normed_hs_smha\\n\",\n",
    "    \"        smha_cache[\\\"layer\\\"] = self.smha_layer\\n\",\n",
    "    \"        hidden_states = residual + attn_output\\n\",\n",
    "    \"\\n\",\n",
    "    \"        residual = hidden_states\\n\",\n",
    "    \"        normed_hs_moe = self.ln2(hidden_states)\\n\",\n",
    "    \"        moe_routing_weights, moe_gate_cache = self.moe_layer.forward_gating(normed_hs_moe)\\n\",\n",
    "    \"        moe_routing_weights_flat = moe_routing_weights.view(B * T, -1)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        def moe_checkpointed_fn(hs_norm: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:\\n\",\n",
    "    \"            return self.moe_layer.forward_main(hs_norm, moe_routing_weights_flat, moe_gate_cache)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        moe_output, moe_cache = checkpoint(moe_checkpointed_fn, normed_hs_moe, use_reentrant=False)\\n\",\n",
    "    \"        moe_cache[\\\"normed_hs\\\"] = normed_hs_moe\\n\",\n",
    "    \"        hidden_states = residual + moe_output\\n\",\n",
    "    \"\\n\",\n",
    "    \"        block_cache = {(\\\"smha\\\", self.layer_index, 0): smha_cache, (\\\"moe\\\", self.layer_index, 0): moe_cache}\\n\",\n",
    "    \"        return hidden_states, block_cache\\n\",\n",
    "    \"\\n\",\n",
    "    \"class AutoregressiveEmbedding(nn.Module):\\n\",\n",
    "    \"    def __init__(self, config: TinyOnnArcConfig):\\n\",\n",
    "    \"        super().__init__()\\n\",\n",
    "    \"        self.tok_embed = nn.Embedding(config.vocab_size, config.hidden_size)\\n\",\n",
    "    \"        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)\\n\",\n",
    "    \"        self.register_buffer(\\\"position_ids\\\", torch.arange(config.max_position_embeddings))\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:\\n\",\n",
    "    \"        seq_len = input_ids.size(1)\\n\",\n",
    "    \"        pos_ids = self.position_ids[:seq_len]\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        tok_embeds = self.tok_embed(input_ids)\\n\",\n",
    "    \"        pos_embeds = self.pos_embed(pos_ids)\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        return tok_embeds + pos_embeds\\n\",\n",
    "    \"\\n\",\n",
    "    \"class TinyOnnModel(nn.Module):\\n\",\n",
    "    \"    def __init__(self, config: TinyOnnArcConfig):\\n\",\n",
    "    \"        super().__init__()\\n\",\n",
    "    \"        self.config = config\\n\",\n",
    "    \"        self.embeddings = AutoregressiveEmbedding(config)\\n\",\n",
    "    \"        self.layers = nn.ModuleList([Block(config, i) for i in range(config.num_hidden_layers)])\\n\",\n",
    "    \"        self.final_ln = nn.LayerNorm(config.hidden_size)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward(self, input_ids: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, dict[ExpertID, Any]]:\\n\",\n",
    "    \"        hidden_states = self.embeddings(input_ids)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        flat_forward_cache: dict[ExpertID, Any] = {}\\n\",\n",
    "    \"        for layer in self.layers:\\n\",\n",
    "    \"            hidden_states, block_cache = layer(hidden_states)\\n\",\n",
    "    \"            flat_forward_cache.update(block_cache)\\n\",\n",
    "    \"            \\n\",\n",
    "    \"        hidden_states = self.final_ln(hidden_states)\\n\",\n",
    "    \"        return hidden_states, flat_forward_cache\\n\",\n",
    "    \"\\n\",\n",
    "    \"class TinyOnnForArcReconstruction(nn.Module):\\n\",\n",
    "    \"    def __init__(self, config: TinyOnnArcConfig):\\n\",\n",
    "    \"        super().__init__()\\n\",\n",
    "    \"        self.model = TinyOnnModel(config)\\n\",\n",
    "    \"        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def forward(self, input_ids: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, dict[ExpertID, Any]]:\\n\",\n",
    "    \"        hidden_states, flat_forward_cache = self.model(input_ids=input_ids, **kwargs)\\n\",\n",
    "    \"        final_logits = self.lm_head(hidden_states)\\n\",\n",
    "    \"        return final_logits, flat_forward_cache\\n\"\n",
    "   ]\n",
    "  },\n",
    "  {\n",
    "   \"cell_type\": \"code\",\n",
    "   \"execution_count\": null,\n",
    "   \"metadata\": {},\n",
    "   \"outputs\": [],\n",
    "   \"source\": [\n",
    "    \"# Data processing functions (copied from exp/tiny_onn_arc/data.py, simplified for evaluation)\\n\",\n",
    "    \"def pad_grid(grid: list[list[int]], max_h: int, max_w: int) -> torch.Tensor:\\n\",\n",
    "    \"    grid_tensor = torch.tensor(grid, dtype=torch.long)\\n\",\n",
    "    \"    h, w = grid_tensor.shape\\n\",\n",
    "    \"\\n\",\n",
    "    \"    pad_h = max_h - h\\n\",\n",
    "    \"    pad_w = max_w - w\\n\",\n",
    "    \"\\n\",\n",
    "    \"    if pad_h > 0 or pad_w > 0:\\n\",\n",
    "    \"        grid_tensor = F.pad(grid_tensor, (0, pad_w, 0, pad_h), \\\"constant\\\", 0)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    return grid_tensor\\n\",\n",
    "    \"\\n\",\n",
    "    \"def augment_grid(grid: torch.Tensor, flip_lr: bool, flip_ud: bool, rot_k: int) -> torch.Tensor:\\n\",\n",
    "    \"    # No augmentation for evaluation, return original grid\\n\",\n",
    "    \"    return grid\\n\",\n",
    "    \"\\n\",\n",
    "    \"class ArcDataset(torch.utils.data.Dataset):\\n\",\n",
    "    \"    def __init__(self, task_files: list[Path], use_test_pairs: bool = False):\\n\",\n",
    "    \"        self.max_h = 30\\n\",\n",
    "    \"        self.max_w = 30\\n\",\n",
    "    \"        self.use_test_pairs = use_test_pairs\\n\",\n",
    "    \"        self.samples = []\\n\",\n",
    "    \"        for task_file in task_files:\\n\",\n",
    "    \"            with open(task_file, \\\"r\\\") as f:\\n\",\n",
    "    \"                task = json.load(f)\\n\",\n",
    "    \"                pairs = task[\\\"test\\\"] if self.use_test_pairs else task[\\\"train\\\"]\\n\",\n",
    "    \"                for pair in pairs:\\n\",\n",
    "    \"                    self.samples.append(pair)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def __len__(self) -> int:\\n\",\n",
    "    \"        return len(self.samples)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    def __getitem__(self, idx: int):\\n\",\n",
    "    \"        pair = self.samples[idx]\\n\",\n",
    "    \"\\n\",\n",
    "    \"        input_tensor = pad_grid(pair[\\\"input\\\"], self.max_h, self.max_w)\\n\",\n",
    "    \"        # No output_tensor needed for evaluation, as we predict it\\n\",\n",
    "    \"\\n\",\n",
    "    \"        # No augmentation for evaluation\\n\",\n",
    "    \"        aug_input = augment_grid(input_tensor, False, False, 0)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        newline_token = 10\\n\",\n",
    "    \"        input_rows = [torch.cat((row, torch.tensor([newline_token], dtype=torch.long))) for row in aug_input]\\n\",\n",
    "    \"        input_seq = torch.cat(input_rows)\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        # For evaluation, we only need the input sequence\\n\",\n",
    "    \"        return input_seq\\n\",\n",
    "    \"\\n\",\n",
    "    \"def collate_fn(batch):\\n\",\n",
    "    \"    # Custom collate_fn for evaluation, as we only have input_seq\\n\",\n",
    "    \"    batch = [b for b in batch if b is not None]\\n\",\n",
    "    \"    if not batch:\\n\",\n",
    "    \"        return None\\n\",\n",
    "    \"    return torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)\\n\"\n",
    "   ]\n",
    "  },\n",
    "  {\n",
    "   \"cell_type\": \"code\",\n",
    "   \"execution_count\": null,\n",
    "   \"metadata\": {},\n",
    "   \"outputs\": [],\n",
    "   \"source\": [\n",
    "    \"# Model loading and inference logic\\n\",\n",
    "    \"config = TinyOnnArcConfig()\\n\",\n",
    "    \"device = \\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\"\\n\",\n",
    "    \"\\n\",\n",
    "    \"model = TinyOnnForArcReconstruction(config).to(device)\\n\",\n",
    "    \"checkpoint_path = Path(\\\"exp/ARC-Killer.pt\\\") # Pre-trained weights path\\n\",\n",
    "    \"if checkpoint_path.exists():\\n\",\n",
    "    \"    ckpt = torch.load(checkpoint_path, map_location=device)\\n\",\n",
    "    \"    model.load_state_dict(ckpt['model_state_dict'])\\n\",\n",
    "    \"    print(f\\\"Loaded model from {checkpoint_path}\\\")\\n\",\n",
    "    \"else:\\n\",\n",
    "    \"    print(f\\\"Warning: Checkpoint {checkpoint_path} not found. Using randomly initialized model.\\\")\\n\",\n",
    "    \"\\n\",\n",
    "    \"model.eval()\\n\",\n",
    "    \"\\n\",\n",
    "    \"newline_token = 10\\n\",\n",
    "    \"pad_token = 0 # For submission, we assume 0 is padding\\n\",\n",
    "    \"\\n\",\n",
    "    \"def to_grid_and_crop(seq: torch.Tensor) -> list[list[int]]:\\n\",\n",
    "    \"    # Remove newline tokens and pad tokens (0s) from the sequence\\n\",\n",
    "    \"    pixel_seq = seq[seq != newline_token]\\n\",\n",
    "    \"    pixel_seq = pixel_seq[pixel_seq != pad_token]\\n\",\n",
    "    \"\\n\",\n",
    "    \"    # If the sequence is empty after removing special tokens, return an empty 1x1 grid or handle as error\\n\",\n",
    "    \"    if pixel_seq.numel() == 0:\\n\",\n",
    "    \"        return [[0]] # Return a minimal valid grid\\n\",\n",
    "    \"\\n\",\n",
    "    \"    # Reshape to a large enough grid (e.g., 30x30) and then find the actual content bounds.\\n\",\n",
    "    \"    max_dim = 30 # Max possible dimension for ARC grids\\n\",\n",
    "    \"    padded_pixel_seq = F.pad(pixel_seq, (0, max_dim * max_dim - pixel_seq.numel()), \\\"constant\\\", pad_token)\\n\",\n",
    "    \"    temp_grid = padded_pixel_seq.view(max_dim, max_dim)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    # Find the bounding box of non-padding pixels\\n\",\n",
    "    \"    rows = torch.any(temp_grid != pad_token, dim=1)\\n\",\n",
    "    \"    cols = torch.any(temp_grid != pad_token, dim=0)\\n\",\n",
    "    \"\\n\",\n",
    "    \"    if not torch.any(rows) or not torch.any(cols): # If all pixels are padding\\n\",\n",
    "    \"        return [[0]]\\n\",\n",
    "    \"\\n\",\n",
    "    \"    min_r, max_r = torch.where(rows)[0].min(), torch.where(rows)[0].max()\\n\",\n",
    "    \"    min_c, max_c = torch.where(cols)[0].min(), torch.where(cols)[0].max()\\n\",\n",
    "    \"\\n\",\n",
    "    \"    cropped_grid = temp_grid[min_r : max_r + 1, min_c : max_c + 1]\\n\",\n",
    "    \"\\n\",\n",
    "    \"    return cropped_grid.tolist()\\n\",\n",
    "    \"\\n\",\n",
    "    \"def to_grid_for_attempt1(seq: torch.Tensor, h: int = 30, w: int = 30) -> list[list[int]]:\\n\",\n",
    "    \"    pixel_seq = seq[seq != newline_token]\\n\",\n",
    "    \"    pixel_seq = pixel_seq[pixel_seq != -100] # Remove -100 pad token from labels\\n\",\n",
    "    \"    \\n\",\n",
    "    \"    num_pixels = pixel_seq.numel()\\n\",\n",
    "    \"    if num_pixels > h * w:\\n\",\n",
    "    \"        pixel_seq = pixel_seq[:h*w]\\n\",\n",
    "    \"    elif num_pixels < h * w:\\n\",\n",
    "    \"        pixel_seq = F.pad(pixel_seq, (0, h * w - num_pixels), \\\"constant\\\", 0)\\n\",\n",
    "    \"    \\n\",\n",
    "    \"    return pixel_seq.view(h, w).tolist()\\n\",\n",
    "    \"\\n\",\n",
    "    \"def predict_arc_output(model: TinyOnnForArcReconstruction, input_grid: list[list[int]], max_len: int = 1861) -> tuple[list[list[int]], list[list[int]]]:\\n\",\n",
    "    \"    model.eval()\\n\",\n",
    "    \"    with torch.no_grad():\\n\",\n",
    "    \"        input_seq = preprocess_input_grid(input_grid).to(model.model.embeddings.position_ids.device)\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        generated_sequence = input_seq.clone()\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        output_start_idx = len(input_seq)\\n\",\n",
    "    \"\\n\",\n",
    "    \"        for _ in range(max_len - len(input_seq)): # Generate up to max_len tokens\\n\",\n",
    "    \"            if generated_sequence.numel() >= config.max_position_embeddings:\\n\",\n",
    "    \"                break\\n\",\n",
    "    \"            \\n\",\n",
    "    \"            logits, _ = model(input_ids=generated_sequence.unsqueeze(0))\\n\",\n",
    "    \"            next_token_logits = logits[0, -1, :]\\n\",\n",
    "    \"            next_token = torch.argmax(next_token_logits).item()\\n\",\n",
    "    \"            \\n\",\n",
    "    \"            generated_sequence = torch.cat((generated_sequence, torch.tensor([next_token], device=generated_sequence.device)))\\n\",\n",
    "    \"            \\n\",\n",
    "    \"            # Stop if we generate a newline token and the sequence length is reasonable\\n\",\n",
    "    \"            # This heuristic might need adjustment\\n\",\n",
    "    \"            if next_token == newline_token and generated_sequence.numel() > output_start_idx + 1:\\n\",\n",
    "    \"                # A simple heuristic: if we have generated at least one full row (width + newline)\\n\",\n",
    "    \"                # and the last token was a newline, we can consider stopping.\\n\",\n",
    "    \"                # For now, we'll rely on the to_grid_and_crop to handle the final shape.\\n\",\n",
    "    \"                pass\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        predicted_output_seq = generated_sequence[output_start_idx:]\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        attempt1_grid = to_grid_for_attempt1(predicted_output_seq)\\n\",\n",
    "    \"        attempt2_grid = to_grid_and_crop(predicted_output_seq)\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        return attempt1_grid, attempt2_grid\\n\"\n",
    "   ]\n",
    "  },\n",
    "  {\n",
    "   \"cell_type\": \"code\",\n",
    "   \"execution_count\": null,\n",
    "   \"metadata\": {},\n",
    "   \"outputs\": [],\n",
    "   \"source\": [\n",
    "    \"# Main execution logic\\n\",\n",
    "    \"submission = defaultdict(list)\\n\",\n",
    "    \"eval_data_path = Path(\\\"data/arc-agi_evaluation_challenges.json\\\") # Updated path\\n\",\n",
    "    \"\\n\",\n",
    "    \"if not eval_data_path.exists():\\n\",\n",
    "    \"    # Fallback for Kaggle environment where data might be in /kaggle/input\\n\",\n",
    "    \"    eval_data_path = Path(\\\"/kaggle/input/arc-agi-2/data/evaluation/arc-agi_evaluation_challenges.json\\\")\\n\",\n",
    "    \"    if not eval_data_path.exists():\\n\",\n",
    "    \"        print(f\\\"Error: Evaluation data not found at {eval_data_path}. Please check the path.\\\")\\n\",\n",
    "    \"        sys.exit(1)\\n\",\n",
    "    \"\\n\",\n",
    "    \"with open(eval_data_path, \\\"r\\\") as f:\\n\",\n",
    "    \"    evaluation_tasks = json.load(f)\\n\",\n",
    "    \"\\n\",\n",
    "    \"for task_id, task_data in evaluation_tasks.items():\\n\",\n",
    "    \"    task_predictions = []\\n\",\n",
    "    \"    for test_pair in task_data[\\\"test\\\"]:\\n\",\n",
    "    \"        input_grid = test_pair[\\\"input\\\"]\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        # Predict output\\n\",\n",
    "    \"        attempt1_output, attempt2_output = predict_arc_output(model, input_grid)\\n\",\n",
    "    \"        \\n\",\n",
    "    \"        task_predictions.append({\\\"attempt_1\\\": attempt1_output, \\\"attempt_2\\\": attempt2_output})\\n\",\n",
    "    \"    submission[task_id] = task_predictions\\n\",\n",
    "    \"\\n\",\n",
    "    \"# Save submission.json\\n\",\n",
    "    \"submission_path = Path(\\\"submission.json\\\")\\n\",\n",
    "    \"with open(submission_path, \\\"w\\\") as f:\\n\",\n",
    "    \"    json.dump(submission, f, indent=2)\\n\",\n",
    "    \"\\n\",\n",
    "    \"print(f\\\"Submission file saved to {submission_path}\\\")\\n\"\n",
    "   ]\n",
    "  }\n",
    " ],\n",
    " \"metadata\": {\n",
    "  \"kernelspec\": {\n",
    "   \"display_name\": \"Python 3\",\n",
    "   \"language\": \"python\",\n",
    "   \"name\": \"python3\"\n",
    "  },\n",
    "  \"language_info\": {\n",
    "   \"codemirror_mode\": {\n",
    "    \"name\": \"ipython\",\n",
    "    \"version\": 3\n",
    "   },\n",
    "   \"file_extension\": \".py\",\n",
    "   \"mimetype\": \"text/x-python\",\n",
    "   \"name\": \"python\",\n",
    "   \"nbconvert_exporter\": \"python\",\n",
    "   \"pygments_lexer\": \"ipython3\",\n",
    "   \"version\": \"3.9.18\"\n",
    "  }\n",
    " },\n",
    " \"nbformat\": 4,\n",
    " \"nbformat_minor\": 4\n",
    "}\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}


In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Kaggle ARC Prize 2025 - Tiny-ONN-ARC Submission Notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import json\n",
    "from pathlib import Path\n",
    "from collections import defaultdict\n",
    "import sys\n",
    "import math\n",
    "from typing import Any\n",
    "from einops import rearrange\n",
    "from torch.utils.checkpoint import checkpoint\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TinyOnnArcConfig class (copied from exp/tiny_onn_arc/config.py)\n",
    "class TinyOnnArcConfig:\n",
    "    model_type = \"tiny_onn_arc\"\n",
    "\n",
    "    vocab_size: int = 12\n",
    "    mask_token_id: int = 11\n",
    "    hidden_size: int = 256\n",
    "    num_hidden_layers: int = 16\n",
    "    max_position_embeddings: int = 1861\n",
    "    type_vocab_size: int = 2\n",
    "\n",
    "    # DynSMHA specific\n",
    "    max_attention_experts: int = 96\n",
    "    min_attention_experts: int = 32\n",
    "    head_dim: int = 24\n",
    "\n",
    "    # DynMoE specific\n",
    "    max_moe_experts: int = 64\n",
    "    min_moe_experts: int = 32\n",
    "    intermediate_size: int = 16\n",
    "\n",
    "    # Loss weights\n",
    "    w_ce_smha: float = 1.0\n",
    "    w_kl_smha: float = 1.0\n",
    "    w_aux_smha: float = 1.0\n",
    "    w_ce_moe: float = 1.0\n",
    "    w_kl_moe: float = 1.0\n",
    "    w_aux_moe: float = 1.0\n",
    "\n",
    "    # Predictive Integrity Score specific\n",
    "    pi_alpha: float = 64.0\n",
    "    pi_gamma: float = 0.5\n",
    "    \n",
    "    def __init__(self, **kwargs):\n",
    "        for key, value in kwargs.items():\n",
    "            setattr(self, key, value)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model related classes (copied from exp/tiny_onn_arc/model.py)\n",
    "ExpertID = tuple[str, int, int]\n",
    "\n",
    "class STEFunction(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx: Any, scores: torch.Tensor) -> torch.Tensor:\n",
    "        return (scores > 0).to(scores.dtype)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:\n",
    "        return grad_output\n",
    "\n",
    "@torch.jit.script\n",
    "def _gating_logic(\n",
    "    hidden_states: torch.Tensor, sim_matrix: torch.Tensor, gates: torch.Tensor, max_experts: int, min_experts: int\n",
    ") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
    "    b, t, c = hidden_states.shape\n",
    "    flat_hidden_states = hidden_states.view(b * t, c)\n",
    "    logits = torch.matmul(F.normalize(flat_hidden_states, dim=-1), F.normalize(sim_matrix, dim=0)) - torch.sigmoid(\n",
    "        gates\n",
    "    )\n",
    "    gated_logits = F.relu(logits)\n",
    "    activation_mask = STEFunction.apply(gated_logits)\n",
    "    inactive_mask = torch.sum(activation_mask, dim=1) == 0\n",
    "    if inactive_mask.any():\n",
    "        inactive_logits = logits[inactive_mask]\n",
    "        fallback_indices = torch.topk(inactive_logits, min_experts, dim=-1).indices\n",
    "        inactive_b_indices = torch.where(inactive_mask)[0]\n",
    "        activation_mask.index_put_(\n",
    "            (inactive_b_indices.unsqueeze(1).expand(-1, min_experts), fallback_indices),\n",
    "            torch.tensor(1.0, device=hidden_states.device, dtype=activation_mask.dtype),\n",
    "        )\n",
    "    gated_logits_masked = torch.where(\n",
    "        activation_mask > 0,\n",
    "        gated_logits,\n",
    "        torch.tensor(-torch.inf, dtype=gated_logits.dtype, device=gated_logits.device),\n",
    "    )\n",
    "    return F.softmax(gated_logits_masked, dim=-1), logits, activation_mask, gated_logits\n",
    "\n",
    "class GatingNetwork(nn.Module):\n",
    "    def __init__(self, config: TinyOnnArcConfig, max_experts: int, min_experts: int):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.max_experts = max_experts\n",
    "        self.min_experts = min_experts\n",
    "        self.sim_matrix = nn.Parameter(torch.randn(config.hidden_size, max_experts))\n",
    "        self.gates = nn.Parameter(torch.zeros(max_experts))\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n",
    "        active_expert_probs, logits, activation_mask, gated_logits = _gating_logic(\n",
    "            hidden_states, self.sim_matrix, self.gates, self.max_experts, self.min_experts\n",
    "        )\n",
    "        return active_expert_probs, {\"logits\": logits, \"activation_mask\": activation_mask, \"gated_logits\": gated_logits}\n",
    "\n",
    "class DynSMHALayer(nn.Module):\n",
    "    def __init__(self, config: TinyOnnArcConfig, is_causal: bool = False):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.max_experts = config.max_attention_experts\n",
    "        self.is_causal = is_causal\n",
    "        self.gating_network = GatingNetwork(config, self.max_experts, config.min_attention_experts)\n",
    "        self.q_proj = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.head_dim))\n",
    "        self.k_proj = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.head_dim))\n",
    "        self.v_proj = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.head_dim))\n",
    "        self.o_proj = nn.Parameter(torch.empty(self.max_experts, config.head_dim, config.hidden_size))\n",
    "\n",
    "        for i in range(self.max_experts):\n",
    "            nn.init.xavier_uniform_(self.q_proj[i])\n",
    "            nn.init.xavier_uniform_(self.k_proj[i])\n",
    "            nn.init.xavier_uniform_(self.v_proj[i])\n",
    "            nn.init.xavier_uniform_(self.o_proj[i])\n",
    "\n",
    "    def forward_gating(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n",
    "        return self.gating_network(hidden_states)\n",
    "\n",
    "    def forward_main(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:\n",
    "        B, T, C = hidden_states.shape\n",
    "        routing_weights_reshaped = rearrange(routing_weights, \"(b t) e -> b t e\", b=B)\n",
    "\n",
    "        q_experts = torch.einsum(\"btc,ech->bteh\", hidden_states, self.q_proj)\n",
    "        k_experts = torch.einsum(\"btc,ech->bteh\", hidden_states, self.k_proj)\n",
    "        v_experts = torch.einsum(\"btc,ech->bteh\", hidden_states, self.v_proj)\n",
    "\n",
    "        q_agg = torch.einsum(\"bteh,bte->bth\", q_experts, routing_weights_reshaped)\n",
    "        k_agg = torch.einsum(\"bteh,bte->bth\", k_experts, routing_weights_reshaped)\n",
    "        v_agg = torch.einsum(\"bteh,bte->bth\", v_experts, routing_weights_reshaped) \n",
    "\n",
    "        q = rearrange(q_agg, \"b t h -> b 1 t h\")\n",
    "        k = rearrange(k_agg, \"b t h -> b 1 t h\")\n",
    "        v = rearrange(v_agg, \"b t h -> b 1 t h\")\n",
    "\n",
    "        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)\n",
    "        attn_output = rearrange(attn_output, \"b 1 t h -> b t h\")\n",
    "\n",
    "        output_experts = torch.einsum(\"bth,ehc->btec\", attn_output, self.o_proj)\n",
    "        final_output = torch.einsum(\"btec,bte->btc\", output_experts, routing_weights_reshaped)\n",
    "\n",
    "        cache = {\"final_output\": final_output, \"routing_weights\": routing_weights, \"B\": B, \"T\": T}\n",
    "        return final_output, cache\n",
    "\n",
    "class DynamicMoELayer(nn.Module):\n",
    "    def __init__(self, config: TinyOnnArcConfig):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.max_experts = config.max_moe_experts\n",
    "        self.gating_network = GatingNetwork(config, self.max_experts, config.min_moe_experts)\n",
    "        self.w1 = nn.Parameter(torch.empty(self.max_experts, config.hidden_size, config.intermediate_size))\n",
    "        self.w2 = nn.Parameter(torch.empty(self.max_experts, config.intermediate_size, config.hidden_size))\n",
    "        for i in range(self.max_experts):\n",
    "            nn.init.kaiming_uniform_(self.w1[i], a=math.sqrt(5))\n",
    "            nn.init.kaiming_uniform_(self.w2[i], a=math.sqrt(5))\n",
    "\n",
    "    def forward_gating(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n",
    "        return self.gating_network(hidden_states)\n",
    "\n",
    "    def forward_main(self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, gate_cache: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict[str, Any]]:\n",
    "        B, T, C = hidden_states.shape\n",
    "        routing_weights_reshaped = rearrange(routing_weights, \"(b t) e -> b t e\", b=B)\n",
    "\n",
    "        intermediate_experts = F.gelu(torch.einsum(\"btc,eci->btei\", hidden_states, self.w1))\n",
    "        output_experts = torch.einsum(\"btei,eic->btec\", intermediate_experts, self.w2)\n",
    "\n",
    "        final_output = torch.einsum(\"btec,bte->btc\", output_experts, routing_weights_reshaped)\n",
    "\n",
    "        cache = {\"final_output\": final_output, \"gate_cache\": gate_cache, \"routing_weights\": routing_weights, \"B\": B, \"T\": T, \"normed_hs\": hidden_states, \"layer\": self}\n",
    "        return final_output, cache\n",
    "\n",
    "class Block(nn.Module):\n",
    "    def __init__(self, config: TinyOnnArcConfig, layer_index: int):\n",
    "        super().__init__()\n",
    "        self.layer_index = layer_index\n",
    "        self.ln1 = nn.LayerNorm(config.hidden_size)\n",
    "        self.smha_layer = DynSMHALayer(config, is_causal=True)\n",
    "        self.ln2 = nn.LayerNorm(config.hidden_size)\n",
    "        self.moe_layer = DynamicMoELayer(config)\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict[ExpertID, Any]]:\n",
    "        residual = hidden_states\n",
    "        normed_hs_smha = self.ln1(hidden_states)\n",
    "        smha_routing_weights, smha_gate_cache = self.smha_layer.forward_gating(normed_hs_smha)\n",
    "        B, T, C = hidden_states.shape\n",
    "        smha_routing_weights_flat = smha_routing_weights.view(B * T, -1)\n",
    "\n",
    "        def smha_checkpointed_fn(hs_norm: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:\n",
    "            return self.smha_layer.forward_main(hs_norm, smha_routing_weights_flat)\n",
    "\n",
    "        attn_output, smha_cache = checkpoint(smha_checkpointed_fn, normed_hs_smha, use_reentrant=False)\n",
    "        smha_cache[\"gate_cache\"] = smha_gate_cache\n",
    "        smha_cache[\"normed_hs\"] = normed_hs_smha\n",
    "        smha_cache[\"layer\"] = self.smha_layer\n",
    "        hidden_states = residual + attn_output\n",
    "\n",
    "        residual = hidden_states\n",
    "        normed_hs_moe = self.ln2(hidden_states)\n",
    "        moe_routing_weights, moe_gate_cache = self.moe_layer.forward_gating(normed_hs_moe)\n",
    "        moe_routing_weights_flat = moe_routing_weights.view(B * T, -1)\n",
    "\n",
    "        def moe_checkpointed_fn(hs_norm: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:\n",
    "            return self.moe_layer.forward_main(hs_norm, moe_routing_weights_flat, moe_gate_cache)\n",
    "\n",
    "        moe_output, moe_cache = checkpoint(moe_checkpointed_fn, normed_hs_moe, use_reentrant=False)\n",
    "        moe_cache[\"normed_hs\"] = normed_hs_moe\n",
    "        hidden_states = residual + moe_output\n",
    "\n",
    "        block_cache = {(\"smha\", self.layer_index, 0): smha_cache, (\"moe\", self.layer_index, 0): moe_cache}\n",
    "        return hidden_states, block_cache\n",
    "\n",
    "class AutoregressiveEmbedding(nn.Module):\n",
    "    def __init__(self, config: TinyOnnArcConfig):\n",
    "        super().__init__()\n",
    "        self.tok_embed = nn.Embedding(config.vocab_size, config.hidden_size)\n",
    "        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)\n",
    "        self.register_buffer(\"position_ids\", torch.arange(config.max_position_embeddings))\n",
    "\n",
    "    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:\n",
    "        seq_len = input_ids.size(1)\n",
    "        pos_ids = self.position_ids[:seq_len]\n",
    "        \n",
    "        tok_embeds = self.tok_embed(input_ids)\n",
    "        pos_embeds = self.pos_embed(pos_ids)\n",
    "        \n",
    "        return tok_embeds + pos_embeds\n",
    "\n",
    "class TinyOnnModel(nn.Module):\n",
    "    def __init__(self, config: TinyOnnArcConfig):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.embeddings = AutoregressiveEmbedding(config)\n",
    "        self.layers = nn.ModuleList([Block(config, i) for i in range(config.num_hidden_layers)])\n",
    "        self.final_ln = nn.LayerNorm(config.hidden_size)\n",
    "\n",
    "    def forward(self, input_ids: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, dict[ExpertID, Any]]:\n",
    "        hidden_states = self.embeddings(input_ids)\n",
    "\n",
    "        flat_forward_cache: dict[ExpertID, Any] = {}\n",
    "        for layer in self.layers:\n",
    "            hidden_states, block_cache = layer(hidden_states)\n",
    "            flat_forward_cache.update(block_cache)\n",
    "            \n",
    "        hidden_states = self.final_ln(hidden_states)\n",
    "        return hidden_states, flat_forward_cache\n",
    "\n",
    "class TinyOnnForArcReconstruction(nn.Module):\n",
    "    def __init__(self, config: TinyOnnArcConfig):\n",
    "        super().__init__()\n",
    "        self.model = TinyOnnModel(config)\n",
    "        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n",
    "\n",
    "    def forward(self, input_ids: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor, dict[ExpertID, Any]]:\n",
    "        hidden_states, flat_forward_cache = self.model(input_ids=input_ids, **kwargs)\n",
    "        final_logits = self.lm_head(hidden_states)\n",
    "        return final_logits, flat_forward_cache\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data processing functions (copied from exp/tiny_onn_arc/data.py, simplified for evaluation)\n",
    "def pad_grid(grid: list[list[int]], max_h: int, max_w: int) -> torch.Tensor:\n",
    "    grid_tensor = torch.tensor(grid, dtype=torch.long)\n",
    "    h, w = grid_tensor.shape\n",
    "\n",
    "    pad_h = max_h - h\n",
    "    pad_w = max_w - w\n",
    "\n",
    "    if pad_h > 0 or pad_w > 0:\n",
    "        grid_tensor = F.pad(grid_tensor, (0, pad_w, 0, pad_h), \"constant\", 0)\n",
    "\n",
    "    return grid_tensor\n",
    "\n",
    "def augment_grid(grid: torch.Tensor, flip_lr: bool, flip_ud: bool, rot_k: int) -> torch.Tensor:\n",
    "    # No augmentation for evaluation, return original grid\n",
    "    return grid\n",
    "\n",
    "class ArcDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, task_files: list[Path], use_test_pairs: bool = False):\n",
    "        self.max_h = 30\n",
    "        self.max_w = 30\n",
    "        self.use_test_pairs = use_test_pairs\n",
    "        self.samples = []\n",
    "        for task_file in task_files:\n",
    "            with open(task_file, \"r\") as f:\n",
    "                task = json.load(f)\n",
    "                pairs = task[\"test\"] if self.use_test_pairs else task[\"train\"]\n",
    "                for pair in pairs:\n",
    "                    self.samples.append(pair)\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx: int):\n",
    "        pair = self.samples[idx]\n",
    "\n",
    "        input_tensor = pad_grid(pair[\"input\"], self.max_h, self.max_w)\n",
    "        # No output_tensor needed for evaluation, as we predict it\n",
    "\n",
    "        # No augmentation for evaluation\n",
    "        aug_input = augment_grid(input_tensor, False, False, 0)\n",
    "\n",
    "        newline_token = 10\n",
    "        input_rows = [torch.cat((row, torch.tensor([newline_token], dtype=torch.long))) for row in aug_input]\n",
    "        input_seq = torch.cat(input_rows)\n",
    "        \n",
    "        # For evaluation, we only need the input sequence\n",
    "        return input_seq\n",
    "\n",
    "def collate_fn(batch):\n",
    "    # Custom collate_fn for evaluation, as we only have input_seq\n",
    "    batch = [b for b in batch if b is not None]\n",
    "    if not batch:\n",
    "        return None\n",
    "    return torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model loading and inference logic\n",
    "config = TinyOnnArcConfig()\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "model = TinyOnnForArcReconstruction(config).to(device)\n",
    "checkpoint_path = Path(\"exp/ARC-Killer.pt\") # Pre-trained weights path\n",
    "if checkpoint_path.exists():\n",
    "    ckpt = torch.load(checkpoint_path, map_location=device)\n",
    "    model.load_state_dict(ckpt['model_state_dict'])\n",
    "    print(f\"Loaded model from {checkpoint_path}\")\n",
    "else:\n",
    "    print(f\"Warning: Checkpoint {checkpoint_path} not found. Using randomly initialized model.\")\n",
    "\n",
    "model.eval()\n",
    "\n",
    "newline_token = 10\n",
    "pad_token = 0 # For submission, we assume 0 is padding\n",
    "\n",
    "def to_grid_and_crop(seq: torch.Tensor) -> list[list[int]]:\n",
    "    # Remove newline tokens and pad tokens (0s) from the sequence\n",
    "    pixel_seq = seq[seq != newline_token]\n",
    "    pixel_seq = pixel_seq[pixel_seq != pad_token]\n",
    "\n",
    "    # If the sequence is empty after removing special tokens, return an empty 1x1 grid or handle as error\n",
    "    if pixel_seq.numel() == 0:\n",
    "        return [[0]] # Return a minimal valid grid\n",
    "\n",
    "    # Reshape to a large enough grid (e.g., 30x30) and then find the actual content bounds.\n",
    "    max_dim = 30 # Max possible dimension for ARC grids\n",
    "    padded_pixel_seq = F.pad(pixel_seq, (0, max_dim * max_dim - pixel_seq.numel()), \"constant\", pad_token)\n",
    "    temp_grid = padded_pixel_seq.view(max_dim, max_dim)\n",
    "\n",
    "    # Find the bounding box of non-padding pixels\n",
    "    rows = torch.any(temp_grid != pad_token, dim=1)\n",
    "    cols = torch.any(temp_grid != pad_token, dim=0)\n",
    "\n",
    "    if not torch.any(rows) or not torch.any(cols): # If all pixels are padding\n",
    "        return [[0]]\n",
    "\n",
    "    min_r, max_r = torch.where(rows)[0].min(), torch.where(rows)[0].max()\n",
    "    min_c, max_c = torch.where(cols)[0].min(), torch.where(cols)[0].max()\n",
    "\n",
    "    cropped_grid = temp_grid[min_r : max_r + 1, min_c : max_c + 1]\n",
    "\n",
    "    return cropped_grid.tolist()\n",
    "\n",
    "def to_grid_for_attempt1(seq: torch.Tensor, h: int = 30, w: int = 30) -> list[list[int]]:\n",
    "    \"\"\"Force the predicted pixel tokens into a fixed h x w grid.\n",
    "\n",
    "    Newline tokens and -100 label padding are removed; the remaining tokens are\n",
    "    cut to the first h * w entries or right-filled with 0 up to that length,\n",
    "    then reshaped row-major.\n",
    "\n",
    "    Args:\n",
    "        seq: Token tensor holding the predicted output sequence.\n",
    "        h: Number of rows of the returned grid.\n",
    "        w: Number of columns of the returned grid.\n",
    "\n",
    "    Returns:\n",
    "        The h x w grid as nested lists.\n",
    "    \"\"\"\n",
    "    capacity = h * w\n",
    "\n",
    "    # Drop row separators and the -100 pad token used in labels.\n",
    "    is_pixel = (seq != newline_token) & (seq != -100)\n",
    "    flat = seq[is_pixel][:capacity]\n",
    "\n",
    "    shortfall = capacity - flat.numel()\n",
    "    if shortfall > 0:\n",
    "        flat = F.pad(flat, (0, shortfall), \"constant\", 0)\n",
    "\n",
    "    return flat.view(h, w).tolist()\n",
    "\n",
    "def predict_arc_output(model: TinyOnnForArcReconstruction, input_grid: list[list[int]], max_len: int = 1861) -> tuple[list[list[int]], list[list[int]]]:\n",
    "    \"\"\"Greedily decode an output sequence for one input grid and return two grid attempts.\n",
    "\n",
    "    The input grid is tokenised with preprocess_input_grid (defined outside this\n",
    "    function) and extended one argmax token at a time until the total length\n",
    "    reaches max_len or the module-level config.max_position_embeddings.\n",
    "\n",
    "    Args:\n",
    "        model: Model called as model(input_ids=...) and returning (logits, _).\n",
    "        input_grid: The test input grid of a task.\n",
    "        max_len: Upper bound on the total sequence length (prompt + generated).\n",
    "\n",
    "    Returns:\n",
    "        (attempt1_grid, attempt2_grid): the generated tokens shaped by\n",
    "        to_grid_for_attempt1 and by to_grid_and_crop respectively.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        device = model.model.embeddings.position_ids.device\n",
    "        input_seq = preprocess_input_grid(input_grid).to(device)\n",
    "        output_start_idx = len(input_seq)\n",
    "        generated_sequence = input_seq.clone()\n",
    "\n",
    "        # NOTE(review): there is no early-stop condition (e.g. an end-of-sequence\n",
    "        # token), so decoding always runs to the length limit - confirm intended.\n",
    "        for _ in range(max_len - output_start_idx):\n",
    "            if generated_sequence.numel() >= config.max_position_embeddings:\n",
    "                break\n",
    "\n",
    "            logits, _ = model(input_ids=generated_sequence.unsqueeze(0))\n",
    "            # Greedy choice for the last position; keepdim keeps the token as a\n",
    "            # 1-element tensor on the same device, avoiding a host round-trip per step.\n",
    "            next_token = logits[0, -1, :].argmax(dim=-1, keepdim=True)\n",
    "            generated_sequence = torch.cat((generated_sequence, next_token))\n",
    "\n",
    "        # Everything after the prompt is the predicted output.\n",
    "        predicted_output_seq = generated_sequence[output_start_idx:]\n",
    "\n",
    "        attempt1_grid = to_grid_for_attempt1(predicted_output_seq)\n",
    "        attempt2_grid = to_grid_and_crop(predicted_output_seq)\n",
    "\n",
    "        return attempt1_grid, attempt2_grid\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main execution logic: predict every evaluation test input and write submission.json\n",
    "submission = defaultdict(list)\n",
    "\n",
    "# Local data directory first, then the Kaggle input mount as a fallback.\n",
    "candidate_paths = [\n",
    "    Path(\"data/arc-agi_evaluation_challenges.json\"),\n",
    "    Path(\"/kaggle/input/arc-agi-2/data/evaluation/arc-agi_evaluation_challenges.json\"),\n",
    "]\n",
    "eval_data_path = next((p for p in candidate_paths if p.exists()), None)\n",
    "if eval_data_path is None:\n",
    "    # Raise instead of sys.exit(1): the message names every path that was tried,\n",
    "    # and the failure surfaces as a normal traceback in the notebook.\n",
    "    tried = \", \".join(str(p) for p in candidate_paths)\n",
    "    raise FileNotFoundError(f\"Evaluation data not found. Tried: {tried}. Please check the path.\")\n",
    "\n",
    "with open(eval_data_path, \"r\", encoding=\"utf-8\") as f:\n",
    "    evaluation_tasks = json.load(f)\n",
    "\n",
    "for task_id, task_data in evaluation_tasks.items():\n",
    "    task_predictions = []\n",
    "    for test_pair in task_data[\"test\"]:\n",
    "        input_grid = test_pair[\"input\"]\n",
    "\n",
    "        # One entry per test input, holding both attempts.\n",
    "        attempt1_output, attempt2_output = predict_arc_output(model, input_grid)\n",
    "        task_predictions.append({\"attempt_1\": attempt1_output, \"attempt_2\": attempt2_output})\n",
    "    submission[task_id] = task_predictions\n",
    "\n",
    "# Save submission.json\n",
    "submission_path = Path(\"submission.json\")\n",
    "with open(submission_path, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(submission, f, indent=2)\n",
    "\n",
    "print(f\"Submission file saved to {submission_path}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
