# MARCO4: Per-Cell Logits Architecture for ARC-AGI

This notebook demonstrates the MARCO4 system using:
- **Per-cell logits** from LLM experts (probabilities for colors 0-9)
- **Dempster-Shafer theory** for evidence combination across experts
- **MCU-driven branching** for uncertain cells
- **A\* search** through Cognitive State Space

## Models Available
- `models/qwen` - Qwen model (BitsAndBytes 4-bit)
- `models/phi3` - Phi-3 model (BitsAndBytes 4-bit)
- `models/gpt-oss` - GPT-OSS model (Mxfp4 pre-quantized)

## 1. Setup and Imports

In [None]:
import sys
import os
import json
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import time

# Add MARCO4 to path
sys.path.insert(0, '.')

# MARCO4 imports
from marco import (
    MCU, Expert, MARCO4Config,
    dempster_combine, compute_belief, compute_plausibility,
    token_probs_to_belief_mass, THETA,
    create_empty_grid, is_complete_grid, grid_to_string,
    ARCProblem, solve_task, evaluate_solution,
    CognitiveStateSpace, BranchStatus
)

print("MARCO4 imported successfully!")

In [None]:
# Check for GPU
import torch
# Report the PyTorch build, then describe CUDA device 0 when one is visible.
print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; show it in GiB with one decimal.
    print(f"GPU Memory: {gpu_props.total_memory / 1024**3:.1f} GB")

## 2. Load LLM Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, BitsAndBytesConfig
import gc

# Models that are already quantized (don't apply additional quantization)
# These models have quantization baked into their config.json
PREQUANTIZED_MODELS = {'gpt-oss'}

# Models that require bfloat16 instead of float16
BFLOAT16_MODELS = {'gpt-oss'}

class LLMManager:
    """Manages multiple LLM models for multi-expert system.

    Loaded models and their tokenizers are kept in two dicts keyed by a short
    model name, so several experts can share one manager.
    """

    def __init__(self, device: str = "auto", use_4bit: bool = True):
        """
        Args:
            device: "auto" resolves to "cuda" when torch reports a GPU and to
                "cpu" otherwise; any other string is stored as given.
            use_4bit: Request BitsAndBytes 4-bit loading. Only honoured when
                the resolved device is exactly "cuda".
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.use_4bit = use_4bit and self.device == "cuda"
        self.models = {}       # model_name -> model
        self.tokenizers = {}   # model_name -> tokenizer
        print(f"LLMManager initialized with device: {self.device}")

    def load_model(self, model_path: str, model_name: Optional[str] = None):
        """Load a tokenizer/model pair from ``model_path`` and register it.

        Args:
            model_path: Location handed to the ``from_pretrained`` loaders.
            model_name: Registry key. Defaults to the last path component of
                ``model_path``. The name also selects special handling via
                PREQUANTIZED_MODELS and BFLOAT16_MODELS.

        Returns:
            The name under which the model was registered.
        """
        model_name = model_name or Path(model_path).name
        print(f"\nLoading {model_name} from: {model_path}")

        # Load tokenizer; fall back to EOS as the padding token when none is set.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check if model is already quantized (e.g., gpt-oss with Mxfp4Config)
        is_prequantized = model_name in PREQUANTIZED_MODELS
        requires_bfloat16 = model_name in BFLOAT16_MODELS

        # For pre-quantized models, report the quantization recorded in the
        # checkpoint config. This is purely informational, so a failure to read
        # the config only produces a warning.
        if is_prequantized:
            try:
                config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                quant_info = getattr(config, 'quantization_config', None)
                if quant_info:
                    quant_method = quant_info.get('quant_method', 'unknown') if isinstance(quant_info, dict) else 'custom'
                    print(f"  Model has built-in quantization: {quant_method}")
            except Exception as e:
                print(f"  Warning: Could not read config: {e}")

        # Determine dtype: half precision on GPU, full precision on CPU.
        if self.device != "cuda":
            torch_dtype = torch.float32
        elif requires_bfloat16:
            torch_dtype = torch.bfloat16
            print("  Using bfloat16 dtype")
        else:
            torch_dtype = torch.float16

        # Build loading kwargs
        load_kwargs = {
            'device_map': "auto" if self.device == "cuda" else None,
            'torch_dtype': torch_dtype,
            'trust_remote_code': True,
            'low_cpu_mem_usage': True,
        }

        # Only apply BitsAndBytes quantization to non-prequantized models
        if self.use_4bit and not is_prequantized:
            load_kwargs['quantization_config'] = BitsAndBytesConfig(
                load_in_4bit=True,
                # Keep the 4-bit compute dtype in step with the model dtype so a
                # bfloat16 model is not silently computed in float16.
                bnb_4bit_compute_dtype=torch_dtype,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )
            print("  Using 4-bit quantization (BitsAndBytes)")
        elif is_prequantized:
            # Don't pass quantization_config - let model use its built-in config
            print("  Loading with built-in quantization config")

        # Load model
        model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

        if self.device == "cpu" and not is_prequantized:
            model = model.to(self.device)

        model.eval()

        # Store
        self.models[model_name] = model
        self.tokenizers[model_name] = tokenizer

        params = sum(p.numel() for p in model.parameters()) / 1e9
        print(f"  ✓ {model_name} loaded: {params:.2f}B parameters")

        return model_name

    def get_model(self, model_name: str):
        """Return ``(model, tokenizer)`` for ``model_name``; each is None if not loaded."""
        return self.models.get(model_name), self.tokenizers.get(model_name)

    def list_models(self):
        """List the names of all loaded models."""
        return list(self.models.keys())

    def unload(self, model_name: Optional[str] = None):
        """Drop the manager's references to one model, or to all models.

        Args:
            model_name: Name to unload. A falsy value unloads every model.
                Unknown names are ignored.

        Memory is only reclaimed once nothing else references the model; for
        example, a TransformerExpert keeps its own ``self.model`` reference.
        """
        if model_name:
            if model_name in self.models:
                del self.models[model_name]
                self.tokenizers.pop(model_name, None)
                print(f"Unloaded {model_name}")
        else:
            # Unload all
            self.models.clear()
            self.tokenizers.clear()
            print("All models unloaded")

        # Collect first: empty_cache() can only return blocks whose tensors
        # have already been freed, so unreachable objects must be destroyed
        # before the CUDA caching allocator is asked to release memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Initialize manager
llm_manager = LLMManager(device="auto", use_4bit=True)

In [None]:
# Model paths (relative to the notebook's working directory)
MODEL_PATHS = {
    'qwen': '../models/qwen',
    'phi3': '../models/phi3',
    'gpt-oss': '../models/gpt-oss'
}

# Try every configured model; each one that loads becomes a separate expert.
# A missing directory or a failed load is reported and skipped, so the
# notebook continues with whatever subset is available.
loaded_models = []
for name, path in MODEL_PATHS.items():
    if not os.path.exists(path):
        print(f"  Model path not found: {path}")
        continue
    try:
        llm_manager.load_model(path, name)
    except Exception as e:
        print(f"  Failed to load {name}: {e}")
    else:
        loaded_models.append(name)

print(f"\nLoaded {len(loaded_models)} models: {loaded_models}")

# Nothing downstream can run without at least one model.
if not loaded_models:
    raise RuntimeError("No models could be loaded. Check model paths.")

## 3. Create LLM-Based Expert

In [None]:
from marco.expert import Expert, CellLogits
import torch.nn.functional as F

class TransformerExpert(Expert):
    """
    Expert that uses a transformer model to output per-cell logits.

    For each cell, outputs log-probabilities for colors 0-9.
    No augmentation - single forward pass per cell.

    Pipeline:
    1. For each cell position, create a prompt asking for the cell value
    2. Get logits for tokens '0'-'9' from the LLM
    3. Apply log_softmax to get log-probabilities
    """

    # Value used for outcomes treated as impossible: the non-chosen colors of
    # an already-filled cell, and colors for which the tokenizer gave no token.
    _IMPOSSIBLE_LOGIT = -100.0

    def __init__(self, expert_id: str, llm_manager: LLMManager, model_name: str,
                 config=None):
        """
        Args:
            expert_id: Identifier passed to the Expert base class.
            llm_manager: Manager that already holds ``model_name``.
            model_name: Key of the model/tokenizer pair to use.
            config: Passed through to the Expert base class.

        Raises:
            ValueError: If ``model_name`` is not loaded in ``llm_manager``.
        """
        super().__init__(expert_id, config)
        self.llm_manager = llm_manager
        self.model_name = model_name

        self.model, self.tokenizer = llm_manager.get_model(model_name)
        if self.model is None:
            raise ValueError(f"Model '{model_name}' not loaded")

        # Cache token IDs for colors 0-9
        self.color_tokens = self._get_color_tokens()
        print(f"  Expert '{expert_id}' using model: {model_name}")

    def _get_color_tokens(self) -> Dict[int, int]:
        """Map each color digit 0-9 to the token id that represents it.

        Colors for which the tokenizer returns no ids are left out of the map.
        """
        tokens = {}
        for color in range(10):
            token_ids = self.tokenizer.encode(str(color), add_special_tokens=False)
            if token_ids:
                # Some tokenizers (SentencePiece-style) may emit a leading
                # word-boundary piece before the digit. Taking the first id
                # would then map every color to that same piece, so use the
                # last id, which is the digit itself.
                tokens[color] = token_ids[-1]
        return tokens

    @staticmethod
    def _cell_char(value) -> str:
        """Render one cell: '.' for an unfilled cell (-1), otherwise its digit."""
        return '.' if value == -1 else str(int(value))

    def _format_grid(self, grid: np.ndarray) -> str:
        """Format grid for prompt - compact representation, one text line per row."""
        return '\n'.join(
            ''.join(self._cell_char(v) for v in row) for row in grid
        )

    def _format_grid_with_marker(self, grid: np.ndarray, target_row: int, target_col: int) -> str:
        """Format grid like ``_format_grid`` but show '?' at the target cell."""
        lines = []
        for i, row in enumerate(grid):
            lines.append(''.join(
                '?' if (i == target_row and j == target_col) else self._cell_char(v)
                for j, v in enumerate(row)
            ))
        return '\n'.join(lines)

    def _create_cell_prompt(self, problem: Any, partial_grid: np.ndarray,
                            row: int, col: int) -> str:
        """Create prompt for predicting a specific cell value.

        The prompt holds fixed instructions, up to three training examples
        from ``problem.train`` (when present), and the current partial output
        grid with the target cell marked. It ends right where the model is
        expected to emit the digit.
        """
        prompt = """You are solving an ARC (Abstraction and Reasoning Corpus) puzzle.

Your task is to identify the TRANSFORMATION PATTERN from the training examples, then apply it to predict the output.

Study how each input transforms to its output:
- Look for patterns like: copying, mirroring, rotation, color replacement, counting, filling regions, etc.
- The same transformation rule applies to all examples.

Grid notation:
- Numbers 0-9 represent colors
- '.' represents unfilled cells (value -1) that need to be determined
- '?' marks the specific cell you must predict

"""

        if hasattr(problem, 'train') and problem.train:
            prompt += "=== Training Examples (learn the pattern) ===\n"
            # Only the first three examples are included to bound prompt length.
            for i, example in enumerate(problem.train[:3], 1):
                inp = np.array(example.get('input', []))
                out = np.array(example.get('output', []))
                prompt += f"\nExample {i}:\nInput:\n{self._format_grid(inp)}\n"
                prompt += f"Output:\n{self._format_grid(out)}\n"

        prompt += "\n=== Test Case ===\n"
        prompt += "Apply the same transformation pattern to fill in the output grid.\n"
        prompt += "Current output grid (cells marked '.' are unfilled, '?' is your target):\n"
        prompt += f"{self._format_grid_with_marker(partial_grid, row, col)}\n\n"
        prompt += "Based on the transformation pattern and any already-filled cells, the value for '?' is: "

        return prompt

    def _get_cell_logits_single(self, prompt: str) -> np.ndarray:
        """
        Get log-probabilities for colors 0-9 given a prompt.

        Runs one forward pass, reads the next-token logits at the last prompt
        position, keeps only the ten color tokens and renormalizes over them.

        Returns:
            Array of shape (10,) with log-probs for each color
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs, use_cache=False)
            logits = outputs.logits[0, -1, :]  # Last position

        # Extract logits for color tokens. Colors without a token start at the
        # "impossible" value rather than 0, so a missing token cannot compete
        # with real logits.
        color_logits = torch.full(
            (10,), self._IMPOSSIBLE_LOGIT, device=logits.device
        )
        for color, token_id in self.color_tokens.items():
            color_logits[color] = logits[token_id]

        # Apply log_softmax over the ten colors only
        log_probs = F.log_softmax(color_logits, dim=0)
        return log_probs.cpu().numpy()

    def get_cell_logits(self, problem: Any, partial_grid: np.ndarray) -> CellLogits:
        """
        Get per-cell logits for the entire grid.

        Single forward pass per cell - no augmentation. Cells that already
        hold a color (value >= 0) are not sent to the model.

        Args:
            problem: Task object; its ``train`` examples feed the prompt.
            partial_grid: 2-D grid where -1 marks an unfilled cell.

        Returns:
            CellLogits wrapping an array of shape (h, w, 10).
        """
        h, w = partial_grid.shape
        logits = np.zeros((h, w, 10))

        for i in range(h):
            for j in range(w):
                if partial_grid[i, j] >= 0:
                    # Already filled - deterministic (100% mass on that color)
                    logits[i, j, :] = self._IMPOSSIBLE_LOGIT
                    logits[i, j, int(partial_grid[i, j])] = 0.0
                else:
                    # Query LLM for this cell
                    prompt = self._create_cell_prompt(problem, partial_grid, i, j)
                    logits[i, j] = self._get_cell_logits_single(prompt)

        return CellLogits(logits=logits)

print("TransformerExpert defined with per-cell logits (no augmentation)")

## 4. Load ARC-AGI Tasks

In [None]:
# SimpleARCLoader is defined in the next cell along with data loading

In [None]:
# Set path to ARC data
# Based on MARCO3 setup, data is at ../marco2/data/training/
ARC_DATA_PATHS = [
    '../marco2/data/training',      # MARCO3 location
    '../marco2/data',               # Alternative
    '/home/ubuntu/marco2/data/training',  # Absolute path on server
    '/lambda/nfs/marco2/data/training',
    '../arc-agi/data',
    os.path.expanduser('~/arc-agi')
]

# Simple loader for flat directory of JSON files
class SimpleARCLoader:
    """Load ARC tasks from a flat directory of JSON files.

    Every ``*.json`` file directly inside ``data_path`` becomes one task dict
    with the keys ``task_id`` (the file stem), ``train`` and ``test``.
    Subdirectories are not searched.
    """

    def __init__(self, data_path: str):
        self.data_path = Path(data_path)
        self.tasks = []
        self._load_tasks()

    def _load_tasks(self):
        """Load all JSON task files, in sorted filename order.

        Files that cannot be read or parsed are reported and skipped so one
        bad file does not abort the whole load.
        """
        for json_file in sorted(self.data_path.glob('*.json')):
            try:
                # Explicit encoding so loading does not depend on the
                # platform's default locale.
                with open(json_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if not isinstance(data, dict):
                    raise ValueError("top-level JSON value is not an object")
            except (OSError, ValueError) as e:
                # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
                print(f"Error loading {json_file.name}: {e}")
                continue
            self.tasks.append({
                'task_id': json_file.stem,
                'train': data.get('train', []),
                'test': data.get('test', [])
            })
        print(f"Loaded {len(self.tasks)} tasks from {self.data_path}")

    def get_task(self, task_id: str):
        """Return the task whose ``task_id`` matches, or None if there is none."""
        return next((task for task in self.tasks if task['task_id'] == task_id), None)

    def get_training_tasks(self):
        """Return the list of all loaded tasks (the list itself, not a copy)."""
        return self.tasks

    def get_sample_tasks(self, n: int = 5, split: str = 'training'):
        """Get a random sample of up to ``n`` distinct tasks.

        Args:
            n: Desired sample size; capped at the number of loaded tasks.
            split: Accepted but not used by this loader.

        Returns:
            List of task dicts; empty when no tasks are loaded or ``n <= 0``.
        """
        count = max(0, min(n, len(self.tasks)))
        if count == 0:
            # Avoid asking NumPy to sample from an empty population.
            return []
        indices = np.random.choice(len(self.tasks), size=count, replace=False)
        return [self.tasks[i] for i in indices]

# Use the first candidate directory that actually yields tasks. A directory
# that exists but holds no *.json files directly inside it is skipped, so an
# empty loader cannot mask a later valid location or the synthetic fallback.
arc_loader = None
for path in ARC_DATA_PATHS:
    if not os.path.isdir(path):
        continue
    print(f"Found ARC data at: {path}")
    candidate_loader = SimpleARCLoader(path)
    if candidate_loader.tasks:
        arc_loader = candidate_loader
        break
    print(f"  No task files directly in {path}; trying next location.")

if arc_loader is None:
    print("ARC data not found at expected locations.")
    print("Tried:", ARC_DATA_PATHS)
    print("\nWill use synthetic demo tasks instead.")

## 5. Configure MARCO4 System

In [None]:
# Create configuration
# Start from the MARCO4Config defaults and override selected fields below.
# The resulting config is shared by every expert and MCU built later.
config = MARCO4Config()

# Pruning thresholds (MCU-level)
config.pruning.mcu_prune_threshold = 0.05  # Prune branches below this combined mass
config.pruning.max_conflict = 0.50         # Max conflict before pruning
config.pruning.no_progress_rounds = 5      # Prune if no progress for N rounds

# Confidence thresholds
config.confidence.solution_threshold = 0.80  # Accept solution above this
config.confidence.high_confidence = 0.30     # Fix cell if belief > this

# Search parameters
config.search.max_iterations = 100
config.search.max_branches = 100             # Max active branches in CSS
config.search.beam_width = 10                # Top branches per iteration

# MCU-driven branching
config.search.branch_threshold = 0.15        # Create branches for candidates above this
config.search.max_branches_per_cell = 3      # Max branches per uncertain cell

# Expert parameters
# num_experts is a placeholder here: it is overwritten with the actual number
# of experts once they have been created from the loaded models.
config.expert.num_experts = 3
# NOTE(review): TransformerExpert in this notebook never reads
# config.expert.temperature - confirm whether the MCU applies it.
config.expert.temperature = 0.7

# Dump the effective configuration for the record.
print("Configuration:")
print(json.dumps(config.to_dict(), indent=2))

## 6. Create MARCO4 System

In [None]:
# Build one TransformerExpert per loaded model, so every expert is backed by
# a different LLM (qwen, phi3, gpt-oss - whichever of them loaded).
experts = [
    TransformerExpert(
        expert_id=f"expert_{model_name}",
        llm_manager=llm_manager,
        model_name=model_name,
        config=config,
    )
    for model_name in loaded_models
]

print(f"\nCreated {len(experts)} experts with different models:")
for exp in experts:
    print(f"  - {exp.expert_id} -> {exp.model_name}")

# Keep the config in step with the number of experts that really exist.
config.expert.num_experts = len(experts)

# The MCU is constructed with parallel=True.
mcu = MCU(experts, config, parallel=True)
print(f"\nMCU initialized with {len(experts)} experts (parallel execution: ON)")

## 7. Visualization Utilities

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Rectangle

# ARC color palette (colors 0-9)
ARC_COLORS = [
    '#000000',  # 0: black
    '#0074D9',  # 1: blue
    '#FF4136',  # 2: red
    '#2ECC40',  # 3: green
    '#FFDC00',  # 4: yellow
    '#AAAAAA',  # 5: gray
    '#F012BE',  # 6: magenta
    '#FF851B',  # 7: orange
    '#7FDBFF',  # 8: cyan
    '#870C25',  # 9: brown
]

# Color for unfilled cells (-1)
UNFILLED_COLOR = '#FFFFFF'  # White background
UNFILLED_PATTERN_COLOR = '#CCCCCC'  # Light gray for pattern

def plot_grid(grid: np.ndarray, title: str = "", ax=None, show_unfilled_count: bool = True):
    """
    Draw one 2-D grid using the ARC color palette.

    Cells equal to -1 (unfilled) are drawn as a white square crossed by two
    light-gray diagonals.

    Args:
        grid: 2-D array of color indices 0-9, with -1 for unfilled cells.
        title: Axes title.
        ax: Axes to draw on; a new 4x4-inch figure is created when None.
        show_unfilled_count: When True and at least one cell is unfilled,
            append "[k/total unfilled]" on a second title line.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))

    n_rows, n_cols = grid.shape

    # One color bin per integer value 0..9 (bin edges at -0.5, 0.5, ..., 9.5).
    palette = mcolors.ListedColormap(ARC_COLORS)
    color_norm = mcolors.BoundaryNorm(np.arange(-0.5, 10.5, 1), palette.N)

    # Unfilled cells become NaN so imshow leaves them to the white background;
    # they are then painted over explicitly below.
    is_unfilled = (grid == -1)
    shown = grid.copy().astype(float)
    shown[is_unfilled] = np.nan

    ax.set_facecolor('white')
    ax.imshow(shown, cmap=palette, norm=color_norm, interpolation='nearest')

    # Overlay a white square with an X on every unfilled cell (row-major order).
    for r, c in np.argwhere(is_unfilled):
        left, right = c - 0.5, c + 0.5
        top, bottom = r - 0.5, r + 0.5
        ax.add_patch(Rectangle(
            (left, top), 1, 1,
            linewidth=0,
            edgecolor='none',
            facecolor=UNFILLED_COLOR,
            zorder=1
        ))
        ax.plot([left, right], [top, bottom],
                color=UNFILLED_PATTERN_COLOR, linewidth=1, zorder=2)
        ax.plot([left, right], [bottom, top],
                color=UNFILLED_PATTERN_COLOR, linewidth=1, zorder=2)

    unfilled_count = int(np.count_nonzero(is_unfilled))
    if show_unfilled_count and unfilled_count > 0:
        title = f"{title}\n[{unfilled_count}/{n_rows * n_cols} unfilled]"

    ax.set_title(title, fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(-0.5, n_cols - 0.5)
    # Inverted y-limits keep row 0 at the top.
    ax.set_ylim(n_rows - 0.5, -0.5)

    # Cell borders, drawn above the cell contents.
    for k in range(n_rows + 1):
        ax.axhline(k - 0.5, color='#888888', linewidth=0.5, zorder=3)
    for k in range(n_cols + 1):
        ax.axvline(k - 0.5, color='#888888', linewidth=0.5, zorder=3)

def _example_grids(example):
    """Split one ARC example into ``(input_grid, output_grid)``.

    Accepts either a dict with 'input'/'output' keys or an
    ``(input, output)`` tuple. ``output_grid`` is None when the example
    carries no output.
    """
    if isinstance(example, tuple):
        inp = example[0] if len(example) > 0 else []
        out = example[1] if len(example) > 1 else None
        return inp, out
    return example.get('input', []), example.get('output')


def plot_task(task: Dict):
    """Plot all examples in a task, one row per example (input | output).

    Training examples come first, then test examples. An example without an
    output gets a "?" placeholder in the right-hand column.

    Args:
        task: Dict with optional 'train', 'test' and 'task_id' keys.
    """
    rows = [("Train", i, ex) for i, ex in enumerate(task.get('train', []), 1)]
    rows += [("Test", i, ex) for i, ex in enumerate(task.get('test', []), 1)]

    if not rows:
        # plt.subplots cannot create a zero-row grid.
        print(f"Task {task.get('task_id', 'Unknown')} has no examples to plot.")
        return

    # squeeze=False keeps axes two-dimensional even for a single row.
    _, axes = plt.subplots(len(rows), 2, figsize=(8, 4 * len(rows)), squeeze=False)

    for row, (kind, idx, example) in enumerate(rows):
        inp, out = _example_grids(example)
        plot_grid(np.array(inp), f"{kind} {idx} Input", axes[row, 0], show_unfilled_count=False)
        if out is not None:
            suffix = " (ground truth)" if kind == "Test" else ""
            plot_grid(np.array(out), f"{kind} {idx} Output{suffix}", axes[row, 1],
                      show_unfilled_count=False)
        else:
            axes[row, 1].text(0.5, 0.5, "?", ha='center', va='center', fontsize=40)
            axes[row, 1].set_title(f"{kind} Output (to predict)")

    plt.suptitle(f"Task: {task.get('task_id', 'Unknown')}")
    plt.tight_layout()
    plt.show()

def plot_solution_comparison(task: Dict, solution: np.ndarray):
    """Compare predicted solution with ground truth for the first test example.

    Draws three panels: test input, predicted grid and expected grid. When
    the task has no expected output the third panel is hidden and no
    correct/incorrect verdict is shown.

    Args:
        task: Dict whose 'test' list holds dicts with 'input' and, optionally,
            'output'.
        solution: Predicted 2-D grid.
    """
    # Fall back to an empty example both when 'test' is missing and when it is
    # an empty list (indexing [] would raise IndexError).
    tests = task.get('test') or [{}]
    test = tests[0]
    test_input = np.array(test.get('input', [[]]))
    expected = np.array(test.get('output', [[]]))

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    if test_input.size > 0:
        plot_grid(test_input, "Test Input", axes[0], show_unfilled_count=False)
    else:
        axes[0].set_title("Test Input")
        axes[0].axis('off')
    plot_grid(solution, "Predicted", axes[1])

    if expected.size > 0:
        plot_grid(expected, "Expected", axes[2], show_unfilled_count=False)
        correct = np.array_equal(solution, expected)
        fig.suptitle(f"{'✓ CORRECT' if correct else '✗ INCORRECT'}", fontsize=14)
    else:
        # Nothing to compare against: hide the otherwise empty frame.
        axes[2].axis('off')

    plt.tight_layout()
    plt.show()

print("Visualization utilities loaded - unfilled cells (-1) shown with X pattern")

## 7b. Live Search Visualization

The MCU now supports a progress callback that lets you visualize the grid evolution during search.

In [None]:
from IPython.display import display, clear_output
from marco.mcu import GridState

class SearchVisualizer:
    """
    Visualizes the evolution of the grid during MCU search.

    Shows:
    - Filled cells with ARC colors
    - Unfilled cells (-1) with X pattern
    - Count of unfilled cells per iteration

    Usage:
        visualizer = SearchVisualizer()
        mcu_with_viz = MCU(experts, config, progress_callback=visualizer.callback)
        result = mcu_with_viz.solve(problem, target_size=(h, w))
        visualizer.show_history()  # Show all iterations
    """

    def __init__(self, live_update: bool = True, max_history: int = 50):
        """
        Args:
            live_update: If True, update display in real-time (requires Jupyter)
            max_history: Maximum number of iterations to store; only the most
                recent ones are kept
        """
        self.live_update = live_update
        self.max_history = max_history
        self.history = []  # One dict per recorded iteration, oldest first
        self.fig = None
        self.axes = None

    def callback(self, iteration, grid_state, partial_grid, best_solution,
                 best_confidence, conflict_level, active_branches, total_branches):
        """Progress callback for MCU.solve().

        Records a snapshot of the current iteration and, when ``live_update``
        is set, redraws the live display. ``grid_state`` is accepted but not
        used. Grids are copied so later in-place changes by the caller do not
        alter the recorded history.
        """
        # Count unfilled cells
        unfilled_cells = int(np.sum(partial_grid == -1))
        total_cells = partial_grid.size
        filled_cells = total_cells - unfilled_cells

        # Store in history
        self.history.append({
            'iteration': iteration,
            'partial_grid': partial_grid.copy(),
            'best_solution': best_solution.copy() if best_solution is not None else None,
            'best_confidence': best_confidence,
            'conflict_level': conflict_level,
            'active_branches': active_branches,
            'total_branches': total_branches,
            'filled_cells': filled_cells,
            'unfilled_cells': unfilled_cells,
            'total_cells': total_cells
        })

        # Trim history if needed. Dropping from the front by count (instead of
        # slicing with a negative index) also works when max_history is 0,
        # where history[-0:] would keep everything.
        overflow = len(self.history) - self.max_history
        if overflow > 0:
            del self.history[:overflow]

        # Live update
        if self.live_update:
            self._update_display()

    def _update_display(self):
        """Redraw the live display from the most recent history record."""
        if not self.history:
            return

        clear_output(wait=True)

        latest = self.history[-1]

        # Create figure with current state
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))

        # Current partial grid - plot_grid shows unfilled cells with X pattern
        plot_grid(latest['partial_grid'],
                  f"Iteration {latest['iteration']}: Partial Grid", axes[0])

        # Best solution so far (or a placeholder if none)
        if latest['best_solution'] is not None:
            plot_grid(latest['best_solution'],
                      f"Best Solution (conf: {latest['best_confidence']:.3f})", axes[1])
        else:
            axes[1].text(0.5, 0.5, "No complete\nsolution yet",
                         ha='center', va='center', fontsize=12)
            axes[1].set_title("Best Solution")
            axes[1].set_facecolor('white')

        # Add stats as text - includes unfilled count
        stats_text = (
            f"Filled: {latest['filled_cells']}/{latest['total_cells']} | "
            f"Unfilled: {latest['unfilled_cells']} | "
            f"Conflict: {latest['conflict_level']:.3f} | "
            f"Branches: {latest['active_branches']}/{latest['total_branches']}"
        )
        fig.suptitle(stats_text, fontsize=10)

        plt.tight_layout()
        plt.show()
        # A new figure is created on every iteration; close it so figures do
        # not pile up on backends that keep them open after show().
        plt.close(fig)

    def show_history(self, step: int = 1):
        """
        Show the evolution of the grid across iterations.

        Args:
            step: Show every Nth iteration (default: 1 = all)
        """
        if not self.history:
            print("No history recorded.")
            return

        # Select iterations to show
        selected = self.history[::step]
        n = len(selected)

        # Create grid of plots, at most four per row
        cols = min(4, n)
        rows = (n + cols - 1) // cols

        # squeeze=False always yields a 2-D axes array, whatever rows/cols are.
        fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)

        for idx, record in enumerate(selected):
            # Title shows iteration; plot_grid appends the unfilled count
            plot_grid(record['partial_grid'], f"Iter {record['iteration']}",
                      axes[idx // cols, idx % cols], show_unfilled_count=True)

        # Hide empty subplots
        for idx in range(n, rows * cols):
            axes[idx // cols, idx % cols].axis('off')

        plt.suptitle("Grid Evolution During Search (X = unfilled cells)", fontsize=14)
        plt.tight_layout()
        plt.show()

    def show_confidence_curve(self):
        """Plot confidence, fill progress, unfilled cells and conflict over iterations."""
        if not self.history:
            print("No history recorded.")
            return

        iterations = [rec['iteration'] for rec in self.history]
        confidences = [rec['best_confidence'] for rec in self.history]
        # Guard the percentage against an empty grid (total_cells == 0).
        fill_pcts = [
            rec['filled_cells'] / rec['total_cells'] * 100 if rec['total_cells'] else 0.0
            for rec in self.history
        ]
        unfilled_counts = [rec['unfilled_cells'] for rec in self.history]
        conflicts = [rec['conflict_level'] for rec in self.history]

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Confidence
        axes[0, 0].plot(iterations, confidences, 'b-', linewidth=2, marker='o', markersize=4)
        axes[0, 0].set_xlabel('Iteration')
        axes[0, 0].set_ylabel('Best Confidence')
        axes[0, 0].set_title('Confidence Evolution')
        axes[0, 0].grid(True, alpha=0.3)

        # Fill percentage
        axes[0, 1].plot(iterations, fill_pcts, 'g-', linewidth=2, marker='o', markersize=4)
        axes[0, 1].set_xlabel('Iteration')
        axes[0, 1].set_ylabel('Cells Filled (%)')
        axes[0, 1].set_title('Grid Completion')
        axes[0, 1].set_ylim(0, 105)
        axes[0, 1].grid(True, alpha=0.3)

        # Unfilled cells count
        axes[1, 0].plot(iterations, unfilled_counts, 'm-', linewidth=2, marker='s', markersize=4)
        axes[1, 0].set_xlabel('Iteration')
        axes[1, 0].set_ylabel('Unfilled Cells (-1)')
        axes[1, 0].set_title('Remaining Unfilled Cells')
        axes[1, 0].grid(True, alpha=0.3)
        # Add horizontal line at 0
        axes[1, 0].axhline(y=0, color='green', linestyle='--', alpha=0.5, label='Complete')
        axes[1, 0].legend()

        # Conflict
        axes[1, 1].plot(iterations, conflicts, 'r-', linewidth=2, marker='o', markersize=4)
        axes[1, 1].set_xlabel('Iteration')
        axes[1, 1].set_ylabel('Conflict Level')
        axes[1, 1].set_title('Expert Conflict')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def reset(self):
        """Clear history for a new search."""
        self.history = []

print("SearchVisualizer defined - tracks unfilled cells (-1) with visual indicators!")

In [None]:
# Example: Solve a tiny demo task while recording the search for visualization
# As written, live_update=False: nothing is drawn during the search and the
# recorded history is plotted afterwards. Set it to True for real-time updates.

visualizer = SearchVisualizer(live_update=False)  # Set True for real-time updates

# Create a separate MCU that reports progress to the visualizer callback
mcu_viz = MCU(experts, config, parallel=True, progress_callback=visualizer.callback)

# Use a demo task for visualization: every 0 becomes 1
viz_task = {
    'task_id': 'viz_demo',
    'train': [
        {'input': [[0, 0], [0, 0]], 'output': [[1, 1], [1, 1]]},
    ],
    'test': [
        {'input': [[0, 0, 0], [0, 0, 0]], 'output': [[1, 1, 1], [1, 1, 1]]}
    ]
}

# Solve
# target_size=(2, 3) matches the shape of the demo task's test output above.
problem = ARCProblem(viz_task)
result = mcu_viz.solve(problem, target_size=(2, 3))

print(f"Solved in {result.iterations} iterations")
print(f"Status: {result.status}, Confidence: {result.confidence:.3f}")

# Show the grid evolution
visualizer.show_history()

# Show metrics curves
visualizer.show_confidence_curve()

## 8. Solve a Single Task

In [None]:
def _grid_shape(grid: List[List[int]]) -> Tuple[int, int]:
    """Return (rows, cols) of a nested-list grid; an empty grid is (0, 0)."""
    return (len(grid), len(grid[0]) if len(grid) else 0)


def _infer_target_size(task: Dict) -> Optional[Tuple[int, int]]:
    """
    Guess the (rows, cols) of the test output grid from the task examples.

    Rules, in order:
      1. No training examples -> None (the solver receives no size hint).
      2. Every training pair keeps its input shape and a test input exists
         -> the test input's shape. This covers shape-preserving tasks whose
         grids vary in size between examples.
      3. Otherwise -> the shape of the first training output.
    """
    train = task.get('train')
    if not train:
        return None

    # Baseline guess: shape of the first training output.
    fallback = _grid_shape(train[0].get('output', [[]]))

    tests = task.get('test')
    test_input = tests[0].get('input') if tests else None
    if test_input is None or len(test_input) == 0:
        return fallback

    for pair in train:
        grid_in, grid_out = pair.get('input'), pair.get('output')
        # An incomplete pair gives no evidence either way; keep the baseline.
        if grid_in is None or grid_out is None:
            return fallback
        if _grid_shape(grid_in) != _grid_shape(grid_out):
            return fallback

    # All training outputs mirror their input shape, so mirror the test input.
    return _grid_shape(test_input)


def solve_arc_task(task: Dict, mcu: "MCU", verbose: bool = True,
                   visualizer: Optional["SearchVisualizer"] = None) -> Dict:
    """
    Solve an ARC task using MARCO4.

    The target grid size is inferred from the examples: when every training
    output has the same shape as its input, the test input's shape is used;
    otherwise the shape of the first training output is used. With no
    training examples, no size hint is passed to the solver.

    Args:
        task: ARC task dict with train/test examples
        mcu: MCU instance. When a visualizer is provided, a new MCU is built
            from mcu.experts, mcu.config and mcu.parallel with the
            visualizer's callback attached, and `mcu` itself is not used
            for solving.
        verbose: Print progress info
        visualizer: Optional SearchVisualizer for tracking grid evolution;
            its history is reset before solving

    Returns:
        Dict with keys 'task_id', 'solution', 'confidence', 'status',
        'iterations', 'branches', 'time' (seconds) and 'evaluation'
        (None unless the first test example has an 'output' and a solution
        was produced).
    """
    task_id = task.get('task_id', 'unknown')
    if verbose:
        print(f"\n{'='*60}")
        print(f"Solving task: {task_id}")
        print(f"{'='*60}")

    # Create ARCProblem
    problem = ARCProblem(task)

    # Infer target size from the examples (None when there is no training data)
    target_size = _infer_target_size(task)
    if verbose and target_size is not None:
        print(f"Target size: {target_size}")

    # If visualizer provided, create MCU with callback
    if visualizer is not None:
        visualizer.reset()  # Clear previous history
        solve_mcu = MCU(
            mcu.experts,
            mcu.config,
            parallel=mcu.parallel,
            progress_callback=visualizer.callback
        )
    else:
        solve_mcu = mcu

    # Solve (wall-clock timing covers only the solve call)
    start_time = time.time()
    result = solve_mcu.solve(problem, target_size=target_size)
    elapsed = time.time() - start_time

    if verbose:
        print("\nResult:")
        print(f"  Status: {result.status}")
        print(f"  Confidence: {result.confidence:.4f}")
        print(f"  Iterations: {result.iterations}")
        print(f"  Branches explored: {result.branches_explored}")
        print(f"  Time: {elapsed:.2f}s")

    # Evaluate if ground truth available (only the first test example is scored)
    evaluation = None
    if task.get('test') and 'output' in task['test'][0]:
        expected = task['test'][0]['output']
        if result.solution is not None:
            evaluation = evaluate_solution(result.solution, expected)
            if verbose:
                print("\nEvaluation:")
                print(f"  Correct: {evaluation['correct']}")
                print(f"  Accuracy: {evaluation['accuracy']:.2%}")

    return {
        'task_id': task_id,
        'solution': result.solution,
        'confidence': result.confidence,
        'status': result.status,
        'iterations': result.iterations,
        'branches': result.branches_explored,
        'time': elapsed,
        'evaluation': evaluation
    }

In [None]:
# Get a sample task.
# `task` is always (re)bound here: if no loader is available, or the loader
# returns no tasks, fall back to a synthetic task so the following cells
# never hit an undefined or stale `task`.
task = None
if arc_loader is not None:
    sample_tasks = arc_loader.get_sample_tasks(n=1, split='training')
    if sample_tasks:
        task = sample_tasks[0]
        print(f"Selected task: {task['task_id']}")

if task is None:
    # Use a simple synthetic task for demo: every 0 becomes a 1 and the
    # output keeps the input's shape.
    task = {
        'task_id': 'demo_fill_ones',
        'train': [
            {'input': [[0, 0], [0, 0]], 'output': [[1, 1], [1, 1]]},
            {'input': [[0, 0, 0], [0, 0, 0]], 'output': [[1, 1, 1], [1, 1, 1]]}
        ],
        'test': [
            {'input': [[0, 0], [0, 0], [0, 0]], 'output': [[1, 1], [1, 1], [1, 1]]}
        ]
    }
    print("Using synthetic demo task")

plot_task(task)

In [None]:
# Create visualizer for tracking search progress
# Set live_update=True to see real-time updates (may slow down solving)
task_visualizer = SearchVisualizer(live_update=False)

# Solve the task with visualization.
# Because a visualizer is passed, solve_arc_task builds a fresh MCU from
# mcu.experts / mcu.config with the visualizer callback attached.
# `result` is the summary dict returned by solve_arc_task.
result = solve_arc_task(task, mcu, verbose=True, visualizer=task_visualizer)

# Visualize solution comparison (skipped when the solver produced no solution)
if result['solution'] is not None:
    plot_solution_comparison(task, result['solution'])

# Show grid evolution during search
print("\n" + "="*60)
print("Grid Evolution During Search")
print("="*60)
task_visualizer.show_history()

# Show metrics curves
task_visualizer.show_confidence_curve()

## 9. Batch Evaluation

In [None]:
def evaluate_batch(tasks: List[Dict], mcu: "MCU", verbose: bool = False) -> Dict:
    """
    Evaluate MARCO4 on a batch of tasks.

    Each task is solved with solve_arc_task. A task that raises is reported,
    recorded as {'task_id': ..., 'error': ...} and does not abort the batch.

    Args:
        tasks: ARC task dicts with train/test examples
        mcu: MCU instance passed through to solve_arc_task
        verbose: Forwarded to solve_arc_task for per-task progress output

    Returns:
        Summary statistics: 'total_tasks', 'evaluated' (tasks that had ground
        truth and a solution), 'correct', 'accuracy' (correct / evaluated,
        0 when nothing was evaluated), 'avg_time' and 'avg_iterations'
        (averaged over tasks that solved without error; 0.0 when there are
        none), and 'results' (the per-task dicts).
    """
    results = []
    correct = 0
    total = 0

    for i, task in enumerate(tasks):
        task_id = task.get('task_id', 'unknown')
        print(f"\nTask {i+1}/{len(tasks)}: {task_id}")

        try:
            result = solve_arc_task(task, mcu, verbose=verbose)
            results.append(result)

            if result['evaluation'] is not None:
                total += 1
                if result['evaluation']['correct']:
                    correct += 1
                    print(f"  ✓ Correct (confidence: {result['confidence']:.3f})")
                else:
                    print(f"  ✗ Incorrect (accuracy: {result['evaluation']['accuracy']:.1%})")
            else:
                print(f"  ? No ground truth (confidence: {result['confidence']:.3f})")

        except Exception as e:
            # Deliberate best-effort: one failing task must not stop the batch.
            print(f"  Error: {e}")
            results.append({'task_id': task_id, 'error': str(e)})

    # Summary
    accuracy = correct / total if total > 0 else 0
    # Error entries carry no 'time'/'iterations'. np.mean([]) would give nan
    # plus a RuntimeWarning, so report 0.0 when nothing succeeded.
    times = [r['time'] for r in results if 'time' in r]
    iteration_counts = [r['iterations'] for r in results if 'iterations' in r]
    avg_time = float(np.mean(times)) if times else 0.0
    avg_iterations = float(np.mean(iteration_counts)) if iteration_counts else 0.0

    summary = {
        'total_tasks': len(tasks),
        'evaluated': total,
        'correct': correct,
        'accuracy': accuracy,
        'avg_time': avg_time,
        'avg_iterations': avg_iterations,
        'results': results
    }

    print(f"\n{'='*60}")
    print("BATCH RESULTS")
    print(f"{'='*60}")
    print(f"Accuracy: {correct}/{total} = {accuracy:.1%}")
    print(f"Avg time: {avg_time:.2f}s")
    print(f"Avg iterations: {avg_iterations:.1f}")

    return summary

In [None]:
# Evaluate on sample of tasks: 5 tasks from the 'training' split.
# `summary` is only bound when ARC data is available (arc_loader is not None).
if arc_loader is not None:
    sample_tasks = arc_loader.get_sample_tasks(n=5, split='training')
    summary = evaluate_batch(sample_tasks, mcu, verbose=False)
else:
    print("No ARC data loaded. Skipping batch evaluation.")

## 10. Analyze D-S Combination

In [None]:
# Demonstrate Dempster-Shafer combination
# (dempster_combine, compute_belief, compute_plausibility and THETA were already
# imported from `marco` at the top of the notebook; the remaining helpers come
# from the submodule directly.)
from marco.dempster_shafer import (
    dempster_combine, compute_belief, compute_plausibility,
    get_pignistic_distribution, get_conflict_level, mass_to_string, THETA
)

print("Dempster-Shafer Combination Demo")
print("="*50)

# Mass functions are written as dicts keyed by a frozenset of colors; the
# masses in each dict below sum to 1.
# NOTE(review): THETA presumably denotes the full frame (total ignorance) —
# confirm against marco.dempster_shafer.

# Expert 1: High confidence in color 1
m1 = {frozenset([1]): 0.8, THETA: 0.2}
print(f"\nExpert 1: {mass_to_string(m1)}")

# Expert 2: Medium confidence in color 1, some mass on the set {1, 2}
m2 = {frozenset([1]): 0.6, frozenset([1, 2]): 0.2, THETA: 0.2}
print(f"Expert 2: {mass_to_string(m2)}")

# Combine
combined = dempster_combine(m1, m2)
print(f"\nCombined: {mass_to_string(combined)}")

# Pignistic probabilities, printed in descending order; entries at or below 0.01 are hidden
pignistic = get_pignistic_distribution(combined)
print(f"\nPignistic probabilities:")
for color, prob in sorted(pignistic.items(), key=lambda x: -x[1]):
    if prob > 0.01:
        print(f"  Color {color}: {prob:.3f}")

# Conflict example: the two experts put 0.9 of their mass on different single colors
print("\n" + "="*50)
print("High Conflict Example")
m3 = {frozenset([1]): 0.9, THETA: 0.1}
m4 = {frozenset([2]): 0.9, THETA: 0.1}
print(f"\nExpert A: {mass_to_string(m3)}")
print(f"Expert B: {mass_to_string(m4)}")

conflict = get_conflict_level(m3, m4)
print(f"\nConflict level K = {conflict:.3f}")

combined_conflict = dempster_combine(m3, m4)
print(f"Combined (with high conflict): {mass_to_string(combined_conflict)}")

## 11. CSS Analysis

In [None]:
# Analyze CSS after solving
# NOTE(review): these statistics come from `mcu` itself. solve_arc_task builds a
# separate MCU whenever a visualizer is passed, so that run is presumably not
# reflected here — verify which solve these numbers describe.
stats = mcu.get_statistics()

# All figures below are read from the 'css' sub-dict of the statistics.
print("Cognitive State Space Statistics")
print("="*50)
print(f"Total branches created: {stats['css']['total_created']}")
print(f"Active branches: {stats['css']['active_branches']}")
print(f"Complete branches: {stats['css']['complete_branches']}")
print(f"Pruned branches: {stats['css']['pruned_branches']}")
print(f"\nMax active mass: {stats['css']['max_active_mass']:.4f}")
print(f"Avg active mass: {stats['css']['avg_active_mass']:.4f}")
print(f"Max complete mass: {stats['css']['max_complete_mass']:.4f}")

## 12. Cleanup

In [None]:
# Unload model to free memory
llm_manager.unload()

# Clear CUDA cache (only when a CUDA device is available; nothing is printed on CPU-only machines)
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory freed")