# Expert-Only Output Viewer

This notebook shows the raw output from each LLM expert **without** Dempster-Shafer combination or MCU orchestration.

For each expert, you can see:
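- Per-cell probability distributions over colors 0-9 (genuine next-token probabilities only in `direct` mode; `whole_grid` and `cell_reasoning` build fixed placeholder distributions around the parsed answer)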
- The predicted grid based on argmax of probabilities
- Confidence levels for each cell

## 1. Setup and Imports

In [None]:
import sys
import os
import json
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import time

# Add MARCO4 to path
sys.path.insert(0, '.')

# Basic imports from MARCO4 (no MCU/D-S needed)
from marco.utils import create_empty_grid, grid_to_string
from marco.main import ARCProblem, evaluate_solution

print("Imports successful!")

In [None]:
# Report the PyTorch build and, when a CUDA device is visible, its name and total memory
import torch

print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {_gpu_name}")
    print(f"GPU Memory: {_gpu_gib:.1f} GB")

## 2. Load LLM Models

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig
import gc

# Models that are already quantized (don't apply additional quantization)
PREQUANTIZED_MODELS = {'gpt-oss'}
# Models that must be loaded in bfloat16 (rather than float16) on GPU
BFLOAT16_MODELS = {'gpt-oss'}

class LLMManager:
    """Manages multiple LLM models for multi-expert system.

    Models and tokenizers are stored in dicts keyed by a short name so that
    experts can look them up with :meth:`get_model`.
    """

    def __init__(self, device: str = "auto", use_4bit: bool = True):
        """
        Args:
            device: "cuda", "cpu", or "auto" (CUDA when available, else CPU).
            use_4bit: Request BitsAndBytes 4-bit quantization; only honoured on CUDA.
        """
        self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")
        self.use_4bit = use_4bit and self.device == "cuda"
        self.models = {}
        self.tokenizers = {}
        print(f"LLMManager initialized with device: {self.device}")

    def load_model(self, model_path: str, model_name: Optional[str] = None) -> str:
        """Load a causal LM and its tokenizer from ``model_path`` and register them.

        Args:
            model_path: Path passed to ``from_pretrained``.
            model_name: Registry key; defaults to the last component of ``model_path``.

        Returns:
            The name the model was registered under.
        """
        model_name = model_name or Path(model_path).name
        print(f"\nLoading {model_name} from: {model_path}")

        # NOTE: trust_remote_code runs Python shipped with the checkpoint;
        # only point this at trusted model directories.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if tokenizer.pad_token is None:
            # generate() needs a pad id; reuse EOS when the tokenizer defines none
            tokenizer.pad_token = tokenizer.eos_token

        is_prequantized = model_name in PREQUANTIZED_MODELS
        requires_bfloat16 = model_name in BFLOAT16_MODELS

        if self.device == "cuda":
            torch_dtype = torch.bfloat16 if requires_bfloat16 else torch.float16
        else:
            torch_dtype = torch.float32

        load_kwargs = {
            'device_map': "auto" if self.device == "cuda" else None,
            'torch_dtype': torch_dtype,
            'trust_remote_code': True,
            'low_cpu_mem_usage': True,
        }

        if self.use_4bit and not is_prequantized:
            load_kwargs['quantization_config'] = BitsAndBytesConfig(
                load_in_4bit=True,
                # Compute in the same dtype the non-quantized modules were loaded in
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            print("  Using 4-bit quantization (BitsAndBytes)")

        model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
        model.eval()

        self.models[model_name] = model
        self.tokenizers[model_name] = tokenizer

        # For 4-bit models this may count packed storage elements, so it can under-report
        params = sum(p.numel() for p in model.parameters()) / 1e9
        print(f"  ✓ {model_name} loaded: {params:.2f}B parameters")

        return model_name

    def get_model(self, model_name: str):
        """Return ``(model, tokenizer)`` for ``model_name``, or ``(None, None)`` if unknown."""
        return self.models.get(model_name), self.tokenizers.get(model_name)

    def list_models(self):
        """Return the registered model names in load order."""
        return list(self.models.keys())

    def unload(self, model_name: Optional[str] = None) -> None:
        """Drop one model (or every model when ``model_name`` is falsy) and release memory.

        Memory is only actually released when nothing else (e.g. an expert
        object) still references the model.
        """
        if model_name:
            self.models.pop(model_name, None)
            self.tokenizers.pop(model_name, None)
        else:
            self.models.clear()
            self.tokenizers.clear()
        # Collect first so freed tensors go back to the CUDA caching allocator,
        # then hand the now-unused cached blocks back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

llm_manager = LLMManager(device="auto", use_4bit=True)

In [None]:
# Model name -> local checkpoint directory
MODEL_PATHS = {
    'qwen': '../models/qwen',
    'phi3': '../models/phi3',
    'gpt-oss': '../models/gpt-oss'
}

# Load ALL models; a missing directory or a failed load is reported and skipped
loaded_models = []
for name, path in MODEL_PATHS.items():
    if not os.path.exists(path):
        print(f"  Model path not found: {path}")
        continue
    try:
        llm_manager.load_model(path, name)
    except Exception as e:
        print(f"  Failed to load {name}: {e}")
    else:
        loaded_models.append(name)

print(f"\nLoaded {len(loaded_models)} models: {loaded_models}")

## 3. Define Expert (Standalone, No MCU/D-S Dependencies; uses only `create_empty_grid` from MARCO4)

In [None]:
import torch.nn.functional as F
import re

class StandaloneExpert:
    """
    Standalone expert that predicts ARC output grids with a single LLM.

    Supports three modes:
    - 'direct': Per-cell next-token probabilities (fast but limited)
    - 'cell_reasoning': Per-cell with chain-of-thought (slower, more accurate)
    - 'whole_grid': Predict entire grid at once (recommended)

    Only 'direct' returns genuine model probabilities; the other two modes
    build fixed placeholder distributions around the parsed answer.
    """

    # Accepted values for the ``mode`` argument
    MODES = ('direct', 'cell_reasoning', 'whole_grid')

    def __init__(self, expert_id: str, llm_manager: "LLMManager", model_name: str,
                 mode: str = 'whole_grid', max_tokens: int = 500):
        """
        Args:
            expert_id: Identifier used in log output.
            llm_manager: Manager in which ``model_name`` has already been loaded.
            model_name: Registry key of the model to use.
            mode: 'direct', 'cell_reasoning', or 'whole_grid'
            max_tokens: Max tokens to generate per generation call

        Raises:
            ValueError: If ``mode`` is not one of ``MODES`` or the model is not loaded.
        """
        if mode not in self.MODES:
            # An unknown mode used to fall through silently to 'direct'
            raise ValueError(f"Unknown mode '{mode}'; expected one of {self.MODES}")
        self.expert_id = expert_id
        self.model_name = model_name
        self.mode = mode
        self.max_tokens = max_tokens
        self.model, self.tokenizer = llm_manager.get_model(model_name)
        if self.model is None:
            raise ValueError(f"Model '{model_name}' not loaded")

        self.color_tokens = self._get_color_tokens()
        print(f"Expert '{expert_id}' using model: {model_name} (mode: {mode})")

    def _get_color_tokens(self) -> Dict[int, int]:
        """Map each color digit 0-9 to the token id that spells the digit.

        SentencePiece-style tokenizers may encode "5" as [<word-boundary piece>, "5"],
        so the digit is the LAST id; taking the first id could map every
        color onto the same boundary token.
        """
        tokens = {}
        for color in range(10):
            token_ids = self.tokenizer.encode(str(color), add_special_tokens=False)
            if token_ids:
                tokens[color] = token_ids[-1]
        return tokens

    def _format_grid(self, grid: np.ndarray) -> str:
        """Format grid for prompt: one line per row, '.' for -1 (unfilled) cells."""
        rows = []
        for row in grid:
            row_str = ''.join('.' if v == -1 else str(int(v)) for v in row)
            rows.append(row_str)
        return '\n'.join(rows)

    def _format_grid_with_marker(self, grid: np.ndarray, target_row: int, target_col: int) -> str:
        """Format grid with a '?' marker at the cell to predict ('.' for other -1 cells)."""
        rows = []
        for i, row in enumerate(grid):
            row_chars = []
            for j, v in enumerate(row):
                if i == target_row and j == target_col:
                    row_chars.append('?')
                elif v == -1:
                    row_chars.append('.')
                else:
                    row_chars.append(str(int(v)))
            rows.append(''.join(row_chars))
        return '\n'.join(rows)

    def _create_whole_grid_prompt(self, task: Dict, target_size: Tuple[int, int]) -> str:
        """Create prompt for predicting the entire output grid at once.

        Includes every training pair and the first test input.
        """
        h, w = target_size

        prompt = """You are solving an ARC (Abstraction and Reasoning Corpus) puzzle.

Study the training examples to find the transformation pattern from input to output.
Then apply the SAME pattern to the test input to produce the output.

Grid notation: Numbers 0-9 represent colors.

"""

        train = task.get('train', [])
        if train:
            prompt += "=== Training Examples ===\n"
            for i, example in enumerate(train, 1):
                inp = np.array(example.get('input', []))
                out = np.array(example.get('output', []))
                prompt += f"\nExample {i}:\nInput:\n{self._format_grid(inp)}\nOutput:\n{self._format_grid(out)}\n"

        # Test input
        test_input = np.array(task['test'][0]['input'])
        prompt += f"""
=== Test ===
Input:
{self._format_grid(test_input)}

Now apply the same transformation pattern. Output the {h}x{w} grid with one row per line, using only digits 0-9:
Output:
"""

        return prompt

    def _create_cell_prompt(self, task: Dict, partial_grid: np.ndarray,
                            row: int, col: int, with_reasoning: bool = False) -> str:
        """Create prompt for predicting a specific cell (uses at most 3 training pairs)."""
        if with_reasoning:
            prompt = """You are solving an ARC puzzle. Think step-by-step.

Grid notation: 0-9 = colors, '.' = unfilled, '?' = cell to predict

"""
        else:
            prompt = """You are solving an ARC puzzle.

Grid notation: 0-9 = colors, '.' = unfilled, '?' = cell to predict

"""

        train = task.get('train', [])
        if train:
            prompt += "=== Training Examples ===\n"
            for i, example in enumerate(train[:3], 1):
                inp = np.array(example.get('input', []))
                out = np.array(example.get('output', []))
                prompt += f"\nExample {i}:\nInput:\n{self._format_grid(inp)}\nOutput:\n{self._format_grid(out)}\n"

        prompt += f"\n=== Test ===\nCurrent output (predict '?'):\n"
        prompt += f"{self._format_grid_with_marker(partial_grid, row, col)}\n\n"

        if with_reasoning:
            prompt += f"Cell ({row},{col}): Let me analyze the pattern. "
        else:
            prompt += f"The value for '?' is: "

        return prompt

    def _generate_text(self, prompt: str, max_tokens: int = None) -> str:
        """Greedy-decode up to ``max_tokens`` new tokens and return only the generated text."""
        if max_tokens is None:
            max_tokens = self.max_tokens

        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=False,
                # Explicit None avoids sampling-parameter warnings under greedy decoding
                temperature=None,
                top_p=None,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the echoed prompt tokens
        generated = outputs[0][inputs['input_ids'].shape[1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def _get_next_token_probs(self, prompt: str) -> np.ndarray:
        """Return a length-10 float32 distribution over colors 0-9 for the next token.

        The softmax is taken over the ten color logits only. Colors with no
        token id receive probability 0.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs, use_cache=False)
            # Upcast so the softmax is computed in float32 even for fp16/bf16 models
            logits = outputs.logits[0, -1, :].float()

        if not self.color_tokens:
            # No usable digit tokens: nothing to rank, fall back to uniform
            return np.full(10, 0.1, dtype=np.float32)

        # -inf (not 0) for missing colors so they get no probability mass
        color_logits = torch.full((10,), float('-inf'), device=logits.device)
        for color, token_id in self.color_tokens.items():
            color_logits[color] = logits[token_id]

        probs = F.softmax(color_logits, dim=0)
        return probs.cpu().numpy()

    def _parse_grid_from_text(self, text: str, target_size: Tuple[int, int]) -> Optional[np.ndarray]:
        """Parse an ``h x w`` digit grid from generated text.

        Each line is reduced to its ASCII digits. Lines with at least ``w`` digits
        contribute their first ``w`` digits as a row; shorter lines are ignored.
        Heuristic caveat: a prose line containing >= ``w`` digits is also taken as a row.

        Returns:
            An ``(h, w)`` int array built from the first ``h`` rows, or None if
            fewer than ``h`` rows were found.
        """
        h, w = target_size

        grid_lines = []
        for line in text.strip().split('\n'):
            # ASCII only: str.isdigit() also accepts e.g. '²', which int() rejects
            cleaned = ''.join(c for c in line if c in '0123456789')
            if len(cleaned) >= w:
                grid_lines.append(cleaned[:w])

        if len(grid_lines) < h:
            return None
        rows = [[int(ch) for ch in line] for line in grid_lines[:h]]
        return np.array(rows, dtype=int).reshape(h, w)

    def _extract_digit(self, text: str) -> Optional[int]:
        """Extract a single digit answer from free-form reasoning text.

        Tries explicit answer phrasings in order, then falls back to the first
        standalone digit. Returns None if no digit is found.
        """
        text = text.lower()

        # Look for explicit patterns
        patterns = [
            r"the answer is[:\s]*(\d)",
            r"the value is[:\s]*(\d)",
            r"should be[:\s]*(\d)",
            # \b so that words like "this"/"axis" don't count as "is"
            r"\bis[:\s]*(\d)\b",
            r"= ?(\d)",
            r"\*\*(\d)\*\*",
        ]

        for pattern in patterns:
            match = re.search(pattern, text)
            if match:
                return int(match.group(1))

        # Fallback: first digit found
        digits = re.findall(r'\b(\d)\b', text)
        if digits:
            return int(digits[0])

        return None

    def predict_grid(self, task: Dict, target_size: Tuple[int, int],
                     verbose: bool = False) -> Dict:
        """
        Predict the output grid using this expert's mode.

        Returns:
            Dict with 'grid', 'probabilities', 'confidences', and optionally 'raw_output'/'reasoning'
        """
        if self.mode == 'whole_grid':
            return self._predict_whole_grid(task, target_size, verbose)
        # 'cell_reasoning' generates chain-of-thought per cell; 'direct' reads next-token probabilities
        return self._predict_cell_by_cell(task, target_size,
                                          with_reasoning=(self.mode == 'cell_reasoning'),
                                          verbose=verbose)

    def _predict_whole_grid(self, task: Dict, target_size: Tuple[int, int],
                            verbose: bool = False) -> Dict:
        """Predict the entire grid with one generation call.

        Falls back to an all-zeros grid when no grid can be parsed from the output.
        """
        h, w = target_size

        print(f"  Generating complete {h}x{w} grid...")

        prompt = self._create_whole_grid_prompt(task, target_size)
        raw_output = self._generate_text(prompt)

        if verbose:
            print(f"  Raw output:\n{raw_output[:500]}")

        # Parse grid from output
        grid = self._parse_grid_from_text(raw_output, target_size)

        if grid is None:
            print(f"  WARNING: Could not parse grid from output, using zeros")
            grid = np.zeros((h, w), dtype=int)

        # For whole_grid mode, we don't have per-cell probabilities
        # Use uniform probabilities with slight boost for predicted value
        probabilities = np.ones((h, w, 10)) * 0.05
        for i in range(h):
            for j in range(w):
                probabilities[i, j, grid[i, j]] = 0.55
        probabilities = probabilities / probabilities.sum(axis=2, keepdims=True)

        confidences = np.max(probabilities, axis=2)

        return {
            'grid': grid,
            'probabilities': probabilities,
            'confidences': confidences,
            'raw_output': raw_output
        }

    def _predict_cell_by_cell(self, task: Dict, target_size: Tuple[int, int],
                              with_reasoning: bool = False, verbose: bool = False) -> Dict:
        """Predict the grid one cell at a time in row-major order.

        Each predicted value is written into the partial grid shown in the
        prompt for later cells.
        """
        h, w = target_size
        partial_grid = create_empty_grid(h, w)

        probabilities = np.zeros((h, w, 10))
        predicted_grid = np.zeros((h, w), dtype=int)
        confidences = np.zeros((h, w))
        reasoning_texts = []

        mode_name = "cell_reasoning" if with_reasoning else "direct"
        print(f"  Predicting {h}x{w} = {h*w} cells ({mode_name})...")

        for i in range(h):
            for j in range(w):
                prompt = self._create_cell_prompt(task, partial_grid, i, j, with_reasoning)

                if with_reasoning:
                    # Generate reasoning and extract answer
                    reasoning = self._generate_text(prompt, max_tokens=200)
                    answer = self._extract_digit(reasoning)
                    reasoning_texts.append({'cell': (i, j), 'text': reasoning})

                    if verbose:
                        print(f"    ({i},{j}): {reasoning[:80]}... -> {answer}")

                    # Placeholder distribution peaked at the extracted answer
                    # (uniform when no answer could be extracted)
                    probs = np.ones(10) * 0.02
                    if answer is not None and 0 <= answer <= 9:
                        probs[answer] = 0.82
                    probs = probs / probs.sum()
                else:
                    # Direct next-token probabilities
                    probs = self._get_next_token_probs(prompt)

                probabilities[i, j] = probs
                predicted_grid[i, j] = np.argmax(probs)
                confidences[i, j] = np.max(probs)

                partial_grid[i, j] = predicted_grid[i, j]

        result = {
            'grid': predicted_grid,
            'probabilities': probabilities,
            'confidences': confidences,
        }

        if with_reasoning:
            result['reasoning'] = reasoning_texts

        return result

print("StandaloneExpert defined with modes: 'direct', 'cell_reasoning', 'whole_grid'")

## 4. Visualization

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Rectangle

# ARC color palette: ARC_COLORS[k] is the hex color for color id k (0-9)
ARC_COLORS = [
    '#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',  # 0 black, 1 blue, 2 red, 3 green, 4 yellow
    '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25',  # 5 gray, 6 magenta, 7 orange, 8 cyan, 9 brown
]

def plot_grid(grid: np.ndarray, title: str = "", ax=None, show_values: bool = False):
    """Plot an ARC grid using the ARC palette.

    Args:
        grid: 2-D array-like of color ids 0-9. Cells equal to -1 are treated as
            unfilled and drawn as a light-grey X on white.
        title: Axes title.
        ax: Matplotlib Axes to draw into; a new 4x4-inch figure is created if None.
        show_values: Write each filled cell's color id on top of it.

    Returns:
        The Axes that was drawn into.
    """
    # Task JSON stores grids as nested lists; accept those as well as arrays
    grid = np.asarray(grid)
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))

    h, w = grid.shape
    cmap = mcolors.ListedColormap(ARC_COLORS)
    # One bin per integer color so ids 0..9 land exactly on the 10 palette entries
    bounds = np.arange(-0.5, 10.5, 1)
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # Unfilled (-1) cells become NaN so imshow leaves them blank
    display_grid = grid.astype(float)
    display_grid[grid == -1] = np.nan
    ax.set_facecolor('white')

    ax.imshow(display_grid, cmap=cmap, norm=norm, interpolation='nearest')

    # Draw X for unfilled cells
    for i in range(h):
        for j in range(w):
            if grid[i, j] == -1:
                rect = Rectangle((j - 0.5, i - 0.5), 1, 1, facecolor='white', zorder=1)
                ax.add_patch(rect)
                ax.plot([j - 0.5, j + 0.5], [i - 0.5, i + 0.5], color='#CCCCCC', linewidth=1, zorder=2)
                ax.plot([j - 0.5, j + 0.5], [i + 0.5, i - 0.5], color='#CCCCCC', linewidth=1, zorder=2)
            elif show_values:
                # White text on the dark colors (black, brown), black elsewhere
                text_color = 'white' if grid[i, j] in [0, 9] else 'black'
                ax.text(j, i, str(int(grid[i, j])), ha='center', va='center',
                        fontsize=10, color=text_color, fontweight='bold')

    ax.set_title(title, fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(-0.5, w - 0.5)
    ax.set_ylim(h - 0.5, -0.5)

    # Cell borders drawn above the image and the X marks
    for i in range(h + 1):
        ax.axhline(i - 0.5, color='#888888', linewidth=0.5, zorder=3)
    for j in range(w + 1):
        ax.axvline(j - 0.5, color='#888888', linewidth=0.5, zorder=3)

    return ax

def plot_confidence_heatmap(confidences: np.ndarray, title: str = "", ax=None):
    """Draw per-cell confidences as a red-yellow-green heatmap on a fixed [0, 1] scale.

    Each cell is annotated with its value to two decimals. Returns the image
    produced by ``imshow`` so callers can attach a colorbar.
    """
    if ax is None:
        _fig, ax = plt.subplots(figsize=(4, 4))

    n_rows, n_cols = confidences.shape
    image = ax.imshow(confidences, cmap='RdYlGn', vmin=0, vmax=1, interpolation='nearest')

    # Annotate each cell; white text on the low (red) end of the scale
    for r, c in np.ndindex(n_rows, n_cols):
        value = confidences[r, c]
        ax.text(c, r, f'{value:.2f}', ha='center', va='center',
                fontsize=8, color='white' if value < 0.5 else 'black')

    ax.set_title(title, fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])

    for edge in range(n_rows + 1):
        ax.axhline(edge - 0.5, color='#888888', linewidth=0.5)
    for edge in range(n_cols + 1):
        ax.axvline(edge - 0.5, color='#888888', linewidth=0.5)

    return image

def plot_probability_distribution(probs: np.ndarray, row: int, col: int, ax=None):
    """Bar-plot one cell's probabilities over the 10 ARC colors and label the tallest bar.

    Args:
        probs: Length-10 array-like of probabilities for colors 0-9.
        row: Cell row, used only in the title.
        col: Cell column, used only in the title.
        ax: Axes to draw into; a new 6x3-inch figure is created if None.

    Returns:
        The Axes that was drawn into.
    """
    probs = np.asarray(probs)
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 3))

    # Each bar is filled with the color it represents
    ax.bar(range(10), probs, color=ARC_COLORS, edgecolor='black', linewidth=0.5)

    ax.set_xlabel('Color')
    ax.set_ylabel('Probability')
    ax.set_title(f'Cell ({row}, {col}) Probability Distribution')
    ax.set_xticks(range(10))
    ax.set_ylim(0, 1)
    ax.grid(True, alpha=0.3, axis='y')

    # Annotate max
    max_idx = int(np.argmax(probs))
    ax.annotate(f'{probs[max_idx]:.2f}', xy=(max_idx, probs[max_idx]),
                ha='center', va='bottom', fontsize=9)

    return ax

print("Visualization utilities loaded")

## 5. Load ARC Tasks

In [None]:
# Candidate locations for the ARC training tasks, tried in order
ARC_DATA_PATHS = [
    '../marco2/data/training',
    '../marco2/data',
    '/home/ubuntu/marco2/data/training',
    '/lambda/nfs/marco2/data/training',
]

class SimpleARCLoader:
    """Load every ``*.json`` ARC task in one directory into memory.

    Each task is a dict with ``task_id`` (the file stem), ``train`` and ``test``.
    """

    def __init__(self, data_path: str):
        self.data_path = Path(data_path)
        self.tasks = []
        self._load_tasks()

    def _load_tasks(self):
        """Read all JSON files in name order; unreadable files are reported and skipped."""
        failed = []
        for json_file in sorted(self.data_path.glob('*.json')):
            try:
                with open(json_file, 'r') as f:
                    data = json.load(f)
                task = {
                    'task_id': json_file.stem,
                    'train': data.get('train', []),
                    'test': data.get('test', [])
                }
            except Exception as e:  # best-effort: one bad file must not abort the whole load
                failed.append((json_file.name, e))
                continue
            self.tasks.append(task)
        print(f"Loaded {len(self.tasks)} tasks from {self.data_path}")
        if failed:
            name, err = failed[0]
            print(f"  Skipped {len(failed)} unreadable file(s), e.g. {name}: {err}")

    def get_task(self, task_id: str):
        """Return the task with ``task_id``, or None if it was not loaded."""
        return next((task for task in self.tasks if task['task_id'] == task_id), None)

    def get_sample_tasks(self, n: int = 5, seed: Optional[int] = None):
        """Return up to ``n`` distinct tasks chosen at random.

        Args:
            n: Number of tasks to draw (capped at the number loaded).
            seed: If given, draw from a dedicated seeded generator for
                reproducible samples; otherwise use NumPy's global RNG.
        """
        size = min(n, len(self.tasks))
        if size <= 0:
            return []
        if seed is None:
            indices = np.random.choice(len(self.tasks), size=size, replace=False)
        else:
            indices = np.random.default_rng(seed).choice(len(self.tasks), size=size, replace=False)
        return [self.tasks[i] for i in indices]

# Use the first candidate directory that actually yields tasks (e.g. '../marco2/data'
# may exist but keep its JSON files in subdirectories)
arc_loader = None
for path in ARC_DATA_PATHS:
    if not os.path.exists(path):
        continue
    print(f"Found ARC data at: {path}")
    candidate_loader = SimpleARCLoader(path)
    if candidate_loader.tasks:
        arc_loader = candidate_loader
        break
    print("  No tasks found there; trying the next path.")

if arc_loader is None:
    print("ARC data not found. Using demo task.")

## 6. Create Experts

In [None]:
# Create one expert per loaded model
#
# Choose mode:
#   'whole_grid'     - Predict entire grid at once (RECOMMENDED - most natural for LLMs)
#   'cell_reasoning' - Per-cell with chain-of-thought reasoning
#   'direct'         - Per-cell immediate next-token (fast but inaccurate)
#
EXPERT_MODE = 'whole_grid'
MAX_TOKENS = 500  # Tokens to generate (increase for larger grids)

# model name -> expert, built in load order
experts = {
    model_name: StandaloneExpert(
        expert_id=f"expert_{model_name}",
        llm_manager=llm_manager,
        model_name=model_name,
        mode=EXPERT_MODE,
        max_tokens=MAX_TOKENS,
    )
    for model_name in loaded_models
}

print(f"\nCreated {len(experts)} experts: {list(experts)}")
print(f"Mode: {EXPERT_MODE}")
print(f"Max tokens: {MAX_TOKENS}")

## 7. Select a Task

In [None]:
# Pick the task to inspect. Set TASK_ID to a task id (JSON file stem) for a
# reproducible run; otherwise a random task is drawn, falling back to a demo task.
TASK_ID = None

def infer_target_size(task: Dict) -> Tuple[int, int]:
    """Guess the (rows, cols) of the test output grid from the training pairs.

    If every training output has its input's shape, the test output is assumed
    to match the test input's shape (sizes often vary between pairs in such tasks).
    Otherwise the first training output's shape is used, which is correct
    whenever all outputs share one size.
    """
    in_shapes = [np.array(ex['input']).shape for ex in task['train']]
    out_shapes = [np.array(ex['output']).shape for ex in task['train']]
    if in_shapes == out_shapes and task.get('test'):
        return tuple(np.array(task['test'][0]['input']).shape)
    return out_shapes[0]

task = None
if arc_loader is not None:
    if TASK_ID is not None:
        task = arc_loader.get_task(TASK_ID)
        if task is None:
            print(f"Task '{TASK_ID}' not found; sampling a random task instead.")
    if task is None:
        sample_tasks = arc_loader.get_sample_tasks(n=1)
        task = sample_tasks[0] if sample_tasks else None

if task is None:
    # Demo task: fill with 1s
    task = {
        'task_id': 'demo_fill_ones',
        'train': [
            {'input': [[0, 0], [0, 0]], 'output': [[1, 1], [1, 1]]},
            {'input': [[0, 0, 0], [0, 0, 0]], 'output': [[1, 1, 1], [1, 1, 1]]}
        ],
        'test': [
            {'input': [[0, 0], [0, 0], [0, 0]], 'output': [[1, 1], [1, 1], [1, 1]]}
        ]
    }

print(f"Task: {task['task_id']}")
print(f"Training examples: {len(task['train'])}")
print(f"Test examples: {len(task['test'])}")

# Determine target size (the first training output's size is wrong for size-changing tasks)
target_size = infer_target_size(task)
print(f"Target size: {target_size}")

# Show task; squeeze=False keeps `axes` 2-D even with a single training pair
n_train = len(task['train'])
fig, axes = plt.subplots(n_train, 2, figsize=(8, 4 * n_train), squeeze=False)

for i, example in enumerate(task['train']):
    inp = np.array(example['input'])
    out = np.array(example['output'])
    plot_grid(inp, f"Train {i+1} Input", axes[i, 0])
    plot_grid(out, f"Train {i+1} Output", axes[i, 1])

plt.suptitle(f"Task: {task['task_id']}")
plt.tight_layout()
plt.show()

## 8. Get Predictions from Each Expert

In [None]:
# Run every expert on the selected task, keeping its result dict and timing
expert_results = {}
banner = '=' * 50

for model_name, expert in experts.items():
    print(f"\n{banner}")
    print(f"Expert: {model_name}")
    print(banner)

    t_start = time.time()
    result = expert.predict_grid(task, target_size)
    elapsed = time.time() - t_start

    expert_results[model_name] = result

    conf = result['confidences']
    print(f"  Time: {elapsed:.2f}s")
    print(f"  Avg confidence: {conf.mean():.3f}")
    print(f"  Min confidence: {conf.min():.3f}")
    print(f"  Max confidence: {conf.max():.3f}")

print(f"\n\nAll {len(expert_results)} experts have completed predictions.")

## 9. Visualize Expert Outputs

In [None]:
# Show each expert's predicted grid, its confidence map, and (if available) the ground truth
n_experts = len(expert_results)

# Get ground truth if available (evaluation-style tasks may omit the test output)
ground_truth = None
if task['test'] and 'output' in task['test'][0]:
    ground_truth = np.array(task['test'][0]['output'])

if n_experts == 0:
    print("No expert results to visualize.")
else:
    # squeeze=False keeps `axes` 2-D even with a single expert
    fig, axes = plt.subplots(n_experts, 3, figsize=(12, 4 * n_experts), squeeze=False)

    for idx, (model_name, result) in enumerate(expert_results.items()):
        # Predicted grid
        plot_grid(result['grid'], f"{model_name}\nPredicted Grid", axes[idx, 0], show_values=True)

        # Confidence heatmap
        plot_confidence_heatmap(result['confidences'], f"{model_name}\nConfidence", axes[idx, 1])

        # Ground truth comparison (if available)
        if ground_truth is not None:
            plot_grid(ground_truth, "Ground Truth", axes[idx, 2], show_values=True)
            if result['grid'].shape == ground_truth.shape:
                correct = int(np.sum(result['grid'] == ground_truth))
                total = ground_truth.size
                accuracy = correct / total
                axes[idx, 2].set_xlabel(f"Accuracy: {correct}/{total} = {accuracy:.1%}")
            else:
                # Elementwise comparison is meaningless (and can raise) when shapes differ
                axes[idx, 2].set_xlabel(
                    f"Shape mismatch: predicted {result['grid'].shape} vs truth {ground_truth.shape}")
        else:
            axes[idx, 2].text(0.5, 0.5, "No ground truth", ha='center', va='center')
            axes[idx, 2].set_title("Ground Truth")

    plt.suptitle(f"Expert Predictions for Task: {task['task_id']}", fontsize=14)
    plt.tight_layout()
    plt.show()

## 10. View Expert Reasoning / Raw Output (raw text in `whole_grid` mode, per-cell reasoning in `cell_reasoning` mode)

In [None]:
# Print what each expert generated: the full text for whole_grid,
# a preview of the first three cells for cell_reasoning, nothing for direct
rule = '=' * 70
for expert_name, result in expert_results.items():
    print(f"\n{rule}")
    print(f"Expert: {expert_name}")
    print(rule)

    if 'raw_output' in result:
        # whole_grid mode
        divider = "-" * 50
        print("\nRaw generated output:")
        print(divider)
        print(result['raw_output'])
        print(divider)
        continue

    if 'reasoning' not in result:
        print("\n  (direct mode - no reasoning available)")
        continue

    # cell_reasoning mode
    reasoning = result['reasoning']
    print("\nPer-cell reasoning:")
    for entry in reasoning[:3]:
        r, c = entry['cell']
        print(f"\n  Cell {entry['cell']} -> {result['grid'][r, c]}:")
        print(f"    {entry['text'][:200]}...")
    hidden = len(reasoning) - 3
    if hidden > 0:
        print(f"\n  ... and {hidden} more cells")

## 11. Per-Cell Probability Distributions (Selected Expert)

In [None]:
# Per-cell probability bars for one expert (change the index to inspect another)
if not expert_results:
    print("No expert results to plot.")
else:
    selected_expert = list(expert_results.keys())[0]  # First expert
    result = expert_results[selected_expert]

    # Lay out panels from the expert's own probability array so they always match it
    h, w = result['probabilities'].shape[:2]
    # squeeze=False keeps `axes` 2-D for 1xN, Nx1 and 1x1 grids
    fig, axes = plt.subplots(h, w, figsize=(4 * w, 3 * h), squeeze=False)

    for i in range(h):
        for j in range(w):
            ax = axes[i, j]
            probs = result['probabilities'][i, j]

            ax.bar(range(10), probs, color=ARC_COLORS, edgecolor='black', linewidth=0.3)
            ax.set_ylim(0, 1)
            ax.set_xticks(range(10))
            ax.set_xticklabels(range(10), fontsize=7)
            ax.set_title(f'({i},{j}) → {np.argmax(probs)}', fontsize=9)
            ax.tick_params(axis='y', labelsize=7)

    plt.suptitle(f"{selected_expert}: Per-Cell Probability Distributions", fontsize=14)
    plt.tight_layout()
    plt.show()

## 12. Expert Agreement Analysis

In [None]:
# Per-cell agreement between experts: 1.0 = all predict the same color, 0.0 = all differ
if len(expert_results) <= 1:
    print("Need at least 2 experts for agreement analysis.")
else:
    n_exp = len(expert_results)
    predictions = np.stack([r['grid'] for r in expert_results.values()])  # (experts, rows, cols)
    h, w = target_size

    # Distinct colors per cell = 1 + number of value changes along the sorted expert axis
    sorted_preds = np.sort(predictions, axis=0)
    n_distinct = 1 + np.count_nonzero(np.diff(sorted_preds, axis=0), axis=0)
    agreement_map = 1.0 - (n_distinct - 1) / (n_exp - 1)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    heat_ax, pie_ax = axes[0], axes[1]

    # Agreement heatmap
    im = heat_ax.imshow(agreement_map, cmap='RdYlGn', vmin=0, vmax=1, interpolation='nearest')
    heat_ax.set_title('Expert Agreement\n(green = all agree, red = all different)')
    plt.colorbar(im, ax=heat_ax)

    # Label each cell with every expert's prediction, e.g. "3/3/1"
    for i, j in np.ndindex(h, w):
        label = '/'.join(str(int(v)) for v in predictions[:, i, j])
        heat_ax.text(j, i, label, ha='center', va='center', fontsize=8)

    # Summary stats
    total_cells = h * w
    full_agreement = np.sum(agreement_map == 1.0)
    partial_agreement = np.sum((agreement_map > 0) & (agreement_map < 1))
    no_agreement = np.sum(agreement_map == 0)

    pie_ax.pie(
        [full_agreement, partial_agreement, no_agreement],
        labels=['Full Agreement', 'Partial', 'No Agreement'],
        colors=['#2ECC40', '#FFDC00', '#FF4136'],
        autopct='%1.0f%%',
        startangle=90,
    )
    pie_ax.set_title(f'Agreement Summary\n({total_cells} cells total)')

    plt.tight_layout()
    plt.show()

## 13. Cleanup

In [None]:
# Unload models to free memory.
# Each StandaloneExpert holds its own reference to its model (self.model), and the
# loop variable `expert` still points at the last one, so drop those references
# first; otherwise the weights stay alive and nothing is actually freed.
experts.clear()
expert = None
llm_manager.unload()

# Collect before emptying the cache so freed tensors are returned to the allocator first
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("GPU memory freed")