# 🎯 LangGraph Multi-Agent MCTS Framework

## Production Demo with Trained Neural Meta-Controllers

This notebook demonstrates the **LangGraph Multi-Agent MCTS Framework**, a multi-agent system that combines:

- **LangGraph** for explicit state management and agent orchestration
- **Monte Carlo Tree Search (MCTS)** for strategic planning and exploration
- **Neural Meta-Controllers** (RNN and BERT with LoRA) for intelligent agent routing

### 🧠 Trained Models
- **RNN Meta-Controller**: GRU-based sequential pattern recognition for fast routing
- **BERT with LoRA**: Transformer-based text understanding with parameter-efficient fine-tuning

### 🤖 Agents
- **HRM (Hierarchical Reasoning Model)**: Decomposes complex problems hierarchically
- **TRM (Tiny Recursive Model)**: Iterative refinement for progressive improvement
- **MCTS**: Monte Carlo Tree Search for optimization and strategic exploration

---

**Repository**: [github.com/ianshank/langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts)

## 🚀 Step 1: Environment Setup

Clone the repository and install all dependencies. This cell handles:
- Repository cloning (any previous clone is removed with `shutil.rmtree` after a path-safety check)
- Dependency installation (including `nest_asyncio` for async support in Colab)
- Path configuration
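The setup cell only deletes directories under `/content` or `/tmp`, and how that check is written matters: a raw string-prefix test also accepts sibling paths such as `/contentfoo`. The standalone sketch below compares the two approaches. The helper names are hypothetical, and `PurePosixPath` is used so that no filesystem access is needed. This assumes the path has already been resolved (the cell calls `.resolve()` first), because a pure path does not collapse `..`.

```python
from pathlib import PurePosixPath


def naive_is_allowed(path: str, prefix: str) -> bool:
    """String-prefix check: also accepts siblings such as /contentfoo."""
    return path.startswith(prefix)


def is_strictly_inside(path: str, prefix: str) -> bool:
    """Component-wise check: True only for paths strictly below prefix.

    Assumes both paths are already resolved (no '..' components).
    """
    return PurePosixPath(prefix) in PurePosixPath(path).parents


for candidate in ["/content/langgraph_multi_agent_mcts", "/contentfoo", "/content"]:
    print(candidate,
          "naive:", naive_is_allowed(candidate, "/content"),
          "strict:", is_strictly_inside(candidate, "/content"))
```

The component-wise check also refuses to delete the allowed root itself, since a path is never among its own parents.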

### Why `nest_asyncio`?
Google Colab's Jupyter kernel already runs an asyncio event loop. Code that calls `asyncio.run()` or
`loop.run_until_complete()` from inside that running loop fails with a `RuntimeError`, which breaks
async agent code written for a plain Python process. `nest_asyncio` patches the asyncio module to allow
nested (re-entrant) event loop calls, so async agent operations can run inside Colab's existing loop.
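To see the failure that `nest_asyncio` works around, run this standalone sketch as a plain Python script, outside Colab. The outer `asyncio.run()` plays the role of Colab's kernel loop, and the inner call fails because a loop is already running. After `nest_asyncio.apply()`, the inner call is allowed, which is the purpose of the patch. The function names are illustrative only.

```python
import asyncio


async def fetch_answer() -> str:
    """Stand-in for an async agent call."""
    await asyncio.sleep(0)
    return "done"


async def code_inside_running_loop() -> str:
    """Mimics notebook code that calls asyncio.run() while a loop is already running."""
    coro = fetch_answer()
    try:
        return asyncio.run(coro)
    except RuntimeError as exc:
        coro.close()  # the coroutine never started; close it to avoid a "never awaited" warning
        return f"RuntimeError: {exc}"


# The outer asyncio.run() starts a loop; the inner asyncio.run() is rejected.
result = asyncio.run(code_inside_running_loop())
print(result)
```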

In [None]:
"""Environment setup cell - clone repository and install dependencies."""
from __future__ import annotations

import os
import shutil
import sys
from pathlib import Path
from typing import Final

# =============================================================================
# CONFIGURATION CONSTANTS
# =============================================================================
COLAB_CONTENT_DIR: Final[str] = "/content"
REPO_NAME: Final[str] = "langgraph_multi_agent_mcts"
REPO_PATH: Final[str] = f"{COLAB_CONTENT_DIR}/{REPO_NAME}"
REPO_URL: Final[str] = "https://github.com/ianshank/langgraph_multi_agent_mcts.git"

# =============================================================================
# SETUP FUNCTIONS
# =============================================================================

def safe_remove_directory(path: Path) -> None:
    """Safely remove a directory with validation.
    
    Args:
        path: Path to the directory to remove.
    
    Raises:
        ValueError: If path is outside allowed locations.
    """
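    path = path.resolve()
    allowed_prefixes = [Path(COLAB_CONTENT_DIR).resolve(), Path("/tmp").resolve()]
    
    # Compare path components, not raw strings: a string-prefix test would also
    # accept siblings such as /contentfoo. Requiring a proper parent also refuses
    # to delete the allowed roots themselves.
    if not any(prefix in path.parents for prefix in allowed_prefixes):
        raise ValueError(f"Cannot delete directory outside allowed paths: {path}")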
    
    if path.exists() and path.is_dir():
        print(f"   Removing existing directory: {path}")
        shutil.rmtree(path)

# =============================================================================
# MAIN SETUP
# =============================================================================

print("📦 Setting up repository...")
repo_dir = Path(REPO_PATH)
safe_remove_directory(repo_dir)

# Clone repository
print("📦 Cloning repository...")
!git clone {REPO_URL} {REPO_PATH}

# Change to repo directory
%cd {REPO_PATH}

# Install dependencies
print("\n📦 Installing dependencies...")
!pip install -q -r requirements.txt

# Install Colab-specific packages
# - nest_asyncio: Enables nested event loops (required for LangGraph in Colab)
# - ipywidgets: Interactive widgets for Jupyter
# - matplotlib: Visualization library for routing probability charts
print("\n📦 Installing Colab-specific packages...")
!pip install -q nest_asyncio ipywidgets matplotlib

# Apply nest_asyncio for async support
import nest_asyncio
nest_asyncio.apply()

# Add repo to Python path
sys.path.insert(0, REPO_PATH)

print("\n✅ Setup complete!")
print(f"📁 Working directory: {os.getcwd()}")

## 🔑 Step 2: API Key Configuration (Optional)

The neural meta-controllers use **pre-trained local models**, so API keys are **optional**.

Configure your API keys below only if you want to:
- Use LLM-powered agents (HRM/TRM with actual generation)
- Enable LangSmith tracing for debugging
- Use Weights & Biases for experiment tracking

In [None]:
"""API key configuration with secure handling."""
from __future__ import annotations

import getpass
import os
from typing import Optional


def set_key(name: str, required: bool = False) -> Optional[str]:
    """Look up an API key from the environment, Colab Secrets, or manual input.
    
    Checks for an already-exported environment variable first, then Google
    Colab's secret manager, and finally prompts with hidden input if required.
    
    Args:
        name: The environment variable name for the API key.
        required: If True, prompt user for input when not found elsewhere.
    
    Returns:
        The API key value or None if not set and not required.
    
    Examples:
        >>> key = set_key("OPENAI_API_KEY")
        >>> if key:
        ...     os.environ["OPENAI_API_KEY"] = key
    """
    # An already-exported environment variable takes precedence
    existing = os.environ.get(name)
    if existing:
        print(f"✅ {name} already set in environment")
        return existing
    
    # Try Colab Secrets next
    try:
        from google.colab import userdata
    except ImportError:
        userdata = None  # Not running in Colab environment
    
    if userdata is not None:
        try:
            value = userdata.get(name)
        except Exception:
            # Colab raises its own exception types when a secret is missing
            # or the notebook has not been granted access to it.
            value = None
        if value:
            print(f"✅ {name} loaded from Colab Secrets")
            return value
    
    if required:
        return getpass.getpass(f"Enter your {name}: ")
    
    print(f"⚠️ {name} not set (optional)")
    return None


# Configure API Keys
print("🔑 Configuring API Keys...\n")

# OpenAI - for LLM-powered agents (optional)
openai_key = set_key("OPENAI_API_KEY")
if openai_key:
    os.environ["OPENAI_API_KEY"] = openai_key

# LangSmith - for tracing (optional but recommended)
langchain_key = set_key("LANGCHAIN_API_KEY")
if langchain_key:
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_API_KEY"] = langchain_key
    os.environ["LANGCHAIN_PROJECT"] = "langgraph-multi-agent-mcts"

# Weights & Biases - for experiment tracking (optional)
wandb_key = set_key("WANDB_API_KEY")
if wandb_key:
    os.environ["WANDB_API_KEY"] = wandb_key

print("\n✅ API key configuration complete!")

## 🧠 Step 3: Load Trained Neural Meta-Controllers

Initialize the framework with the pre-trained models:
- **RNN Meta-Controller**: Fast, captures sequential patterns (10D features → 3-class routing)
- **BERT with LoRA**: Context-aware text understanding for complex routing decisions
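The "10D features → 3-class routing" shorthand means that the ten `MetaControllerFeatures` fields are flattened into ten numbers, which are then mapped to probabilities over HRM, TRM, and MCTS. The sketch below shows only that input/output contract, under several assumptions:
- A plain dict stands in for `MetaControllerFeatures`.
- The numeric encoding of `last_agent` and the length scaling are hypothetical.
- A single linear layer plus softmax with toy weights replaces the trained GRU or BERT model.

```python
import math
from typing import Dict, List, Sequence

AGENTS = ("hrm", "trm", "mcts")
# Hypothetical encoding of the categorical last_agent field as one number.
LAST_AGENT_CODE = {"none": 0.0, "hrm": 1.0, "trm": 2.0, "mcts": 3.0}


def to_vector(f: Dict[str, object]) -> List[float]:
    """Flatten the ten feature fields into ten floats (hypothetical encoding)."""
    return [
        float(f["hrm_confidence"]),
        float(f["trm_confidence"]),
        float(f["mcts_value"]),
        float(f["consensus_score"]),
        LAST_AGENT_CODE[str(f["last_agent"])],
        float(f["iteration"]),
        float(f["query_length"]) / 1000.0,  # assumed length normalization
        1.0 if f["has_rag_context"] else 0.0,
        float(f["rag_relevance_score"]),
        1.0 if f["is_technical_query"] else 0.0,
    ]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route(vector: Sequence[float], weights: Sequence[Sequence[float]],
          biases: Sequence[float]) -> Dict[str, float]:
    """Linear layer (3 x 10) followed by softmax over the three agents."""
    logits = [sum(w * x for w, x in zip(row, vector)) + b
              for row, b in zip(weights, biases)]
    return dict(zip(AGENTS, softmax(logits)))


# Toy weights: each agent's logit reads only its own confidence feature.
toy_weights = [[4.0 if col == row else 0.0 for col in range(10)] for row in range(3)]
example_features = {
    "hrm_confidence": 0.41, "trm_confidence": 0.29, "mcts_value": 0.29,
    "consensus_score": 0.6, "last_agent": "none", "iteration": 0,
    "query_length": 30, "has_rag_context": False, "rag_relevance_score": 0.0,
    "is_technical_query": True,
}
probs = route(to_vector(example_features), toy_weights, [0.0, 0.0, 0.0])
print(max(probs, key=probs.get), probs)
```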

In [None]:
"""Load trained neural meta-controllers with secure model loading."""
from __future__ import annotations

import pickle
from pathlib import Path
from typing import Any, Dict, Tuple

import torch

# Local imports
from src.agents.meta_controller.rnn_controller import RNNMetaController
from src.agents.meta_controller.bert_controller_v2 import BERTMetaController
from src.agents.meta_controller.base import MetaControllerFeatures

# =============================================================================
# CONSTANTS
# =============================================================================
MIN_PYTORCH_VERSION: Tuple[int, int] = (1, 13)  # First release with weights_only support

print("🧠 Initializing Neural Meta-Controllers...\n")

# Detect device
device: str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"💻 Device: {device}")
if device == "cuda":
    print(f"   GPU: {torch.cuda.get_device_name(0)}")


def _torch_version() -> Tuple[int, int]:
    """Parse (major, minor) from torch.__version__, e.g. '2.1.0+cu121' -> (2, 1)."""
    major, minor = str(torch.__version__).split("+")[0].split(".")[:2]
    return int(major), int(minor)


def load_torch_weights_secure(path: Path, map_location: str) -> Dict[str, Any]:
    """Load PyTorch weights securely with weights_only=True.
    
    weights_only=True restricts unpickling to tensors, primitive types, and
    containers, which prevents arbitrary code execution from malicious
    pickle files. Requires PyTorch >= 1.13.0.
    
    Args:
        path: Path to the model weights file (.pt or .pth).
        map_location: Device to load the model onto ('cpu' or 'cuda').
    
    Returns:
        The loaded state dictionary.
    
    Raises:
        RuntimeError: If PyTorch version is too old for secure loading.
        FileNotFoundError: If the model file doesn't exist.
        pickle.UnpicklingError: If the file contains objects that
            weights_only=True refuses to load.
    
    Note:
        Unrestricted unpickling can execute arbitrary code, so never call
        torch.load() with weights_only=False on untrusted files.
    """
    if _torch_version() < MIN_PYTORCH_VERSION:
        raise RuntimeError(
            f"PyTorch {torch.__version__} does not support weights_only=True; "
            "upgrade to >= 1.13.0 (the repo requires torch>=2.1.0)."
        )
    if not path.exists():
        raise FileNotFoundError(f"Model file not found: {path}")
    
    # No fallback to an unrestricted torch.load()
    return torch.load(path, map_location=map_location, weights_only=True)


# =============================================================================
# LOAD RNN CONTROLLER
# =============================================================================
print("\n🔄 Loading RNN Meta-Controller...")
rnn_controller = RNNMetaController(name="RNNController", seed=42, device=device)

rnn_model_path = Path(REPO_PATH) / "models" / "rnn_meta_controller.pt"
if rnn_model_path.exists():
    try:
        checkpoint = load_torch_weights_secure(rnn_model_path, map_location=device)
        rnn_controller.model.load_state_dict(checkpoint)
        rnn_controller.model.eval()
        print(f"   ✅ Loaded trained weights from {rnn_model_path.name}")
    except (RuntimeError, pickle.UnpicklingError) as e:
        # RuntimeError: PyTorch too old or state-dict mismatch;
        # UnpicklingError: file rejected by weights_only=True
        print(f"   ⚠️ Could not load RNN weights: {e}")
        print("      Using untrained model.")
else:
    print(f"   ⚠️ Weights not found at {rnn_model_path}")
    print("      Using untrained model.")

# =============================================================================
# LOAD BERT CONTROLLER
# =============================================================================
print("\n🤖 Loading BERT Meta-Controller with LoRA...")
bert_controller = BERTMetaController(
    name="BERTController", seed=42, device=device, use_lora=True
)

bert_model_path = Path(REPO_PATH) / "models" / "bert_lora" / "final_model"
if bert_model_path.exists():
    try:
        bert_controller.load_model(str(bert_model_path))
        print(f"   ✅ Loaded trained LoRA weights from {bert_model_path.name}")
    except OSError as e:  # IOError is an alias of OSError in Python 3
        print(f"   ⚠️ File I/O error: {e}")
        print("      Check that model files are not corrupted. Using untrained model.")
    except RuntimeError as e:
        print(f"   ⚠️ PyTorch runtime error: {e}")
        print("      This may be due to a CUDA/model architecture mismatch. Using untrained model.")
    except ValueError as e:
        print(f"   ⚠️ Configuration error: {e}")
        print("      Model may be incompatible with current BERT/LoRA config. Using untrained model.")
else:
    print(f"   ⚠️ Weights not found at {bert_model_path}")
    print("      Using untrained model.")

print("\n✅ Meta-controllers initialized (check warnings above for untrained fallbacks).")

## 🎮 Step 4: Interactive Agent Routing Demo

Try the neural meta-controllers! The cells below route a set of example queries through both controllers and show which agent each one picks; edit `EXAMPLE_QUERIES` to try your own.
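The heuristic fallback in the next cell flags queries by keyword, and a plain substring test over-matches: `"api"` fires on "rapid", and `"code"` fires on "decode". The standalone sketch below contrasts substring matching with a word-start regex. The helper names are hypothetical and not part of the repository.

```python
import re
from typing import Iterable


def substring_match(text: str, keywords: Iterable[str]) -> bool:
    """True if any keyword appears anywhere in text (plain substring test)."""
    lowered = text.lower()
    return any(word in lowered for word in keywords)


def word_start_match(text: str, keywords: Iterable[str]) -> bool:
    """True if any keyword appears at the start of a word (regex word boundary)."""
    lowered = text.lower()
    return any(re.search(rf"\b{re.escape(word)}", lowered) for word in keywords)


keywords = ["api", "code"]
for query in ["Build a rapid prototype", "Decode the payload", "Document the APIs"]:
    print(f"{query!r}: substring={substring_match(query, keywords)}, "
          f"word_start={word_start_match(query, keywords)}")
```

The word-start form still matches inflections such as "apis", but it misses compounds such as "subclass". Exact whole-word matching (`\bword\b`) is stricter still.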

In [None]:
"""Feature extraction and routing functions with comprehensive type hints."""
from __future__ import annotations

import re
from typing import List, Optional, Tuple, Final

# =============================================================================
# KEYWORD CONSTANTS FOR HEURISTIC FEATURE EXTRACTION
# =============================================================================
TECHNICAL_KEYWORDS: Final[List[str]] = [
    "algorithm", "code", "implement", "technical", "system",
    "function", "class", "method", "api", "database",
]

COMPARISON_KEYWORDS: Final[List[str]] = [
    "vs", "versus", "compare", "difference", "better",
    "pros", "cons", "tradeoff", "advantage", "disadvantage",
]

OPTIMIZATION_KEYWORDS: Final[List[str]] = [
    "optimize", "best", "improve", "maximize", "minimize",
    "efficient", "performance", "faster", "reduce", "scale",
]

# Input validation constants
MAX_QUERY_LENGTH: Final[int] = 2000
QUERY_DISPLAY_LENGTH: Final[int] = 60

# =============================================================================
# FEATURE EXTRACTOR INITIALIZATION
# =============================================================================
_feature_extractor: Optional["FeatureExtractor"] = None

try:
    from src.agents.meta_controller.feature_extractor import (
        FeatureExtractor,
        FeatureExtractorConfig,
    )
    # Named fe_config so it is not confused with the MCTS `config` defined later
    fe_config = FeatureExtractorConfig.from_env()
    fe_config.device = device
    _feature_extractor = FeatureExtractor(fe_config)
    print("✅ Using semantic feature extraction")
except ImportError as e:
    print(f"⚠️ Feature extractor module not found: {e}")
    print("   Falling back to heuristic feature extraction.")
except AttributeError as e:
    print(f"⚠️ Feature extractor configuration error: {e}")
    print("   Check that FeatureExtractorConfig has from_env() method.")
except RuntimeError as e:
    print(f"⚠️ Feature extractor initialization failed: {e}")
    print("   This may be due to missing embedding models.")

# =============================================================================
# FEATURE EXTRACTION FUNCTIONS
# =============================================================================

def _extract_heuristic_features(
    query: str,
    iteration: int,
    last_agent: str,
) -> MetaControllerFeatures:
    """Extract features using keyword-based heuristics.
    
    Args:
        query: The input query text.
        iteration: Current routing iteration (for multi-turn).
        last_agent: Name of previously selected agent.
    
    Returns:
        MetaControllerFeatures instance with heuristic-based values.
    """
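    query_lower = query.lower()
    query_length = len(query)
    
    def _has_keyword(keywords: List[str]) -> bool:
        # Match at the start of a word: "api" no longer fires on "rapid" or
        # "code" on "decode", while inflections such as "apis" still match.
        return any(re.search(rf"\b{re.escape(word)}", query_lower) for word in keywords)
    
    has_technical = _has_keyword(TECHNICAL_KEYWORDS)
    has_comparison = _has_keyword(COMPARISON_KEYWORDS)
    has_optimization = _has_keyword(OPTIMIZATION_KEYWORDS)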
    
    # Calculate confidence scores based on keyword presence
    hrm_conf = 0.5 + (0.2 if has_technical else 0.0)
    trm_conf = 0.5 + (0.2 if has_comparison else 0.0)
    mcts_conf = 0.5 + (0.2 if has_optimization else 0.0)
    
    # Normalize to sum to 1
    total = hrm_conf + trm_conf + mcts_conf
    
    return MetaControllerFeatures(
        hrm_confidence=hrm_conf / total,
        trm_confidence=trm_conf / total,
        mcts_value=mcts_conf / total,
        consensus_score=0.6,
        last_agent=last_agent,
        iteration=iteration,
        query_length=query_length,
        has_rag_context=query_length > 50,
        rag_relevance_score=0.7 if query_length > 50 else 0.0,
        is_technical_query=has_technical,
    )


def extract_features(
    query: str,
    iteration: int = 0,
    last_agent: str = "none",
) -> MetaControllerFeatures:
    """Extract features from a query for the meta-controller.
    
    Uses semantic embeddings if available, otherwise falls back to
    keyword-based heuristic extraction.
    
    Args:
        query: The input query text.
        iteration: Current routing iteration (for multi-turn).
        last_agent: Name of previously selected agent.
    
    Returns:
        MetaControllerFeatures instance for routing decision.
    
    Examples:
        >>> features = extract_features("Implement a binary search tree")
        >>> features.is_technical_query  # heuristic path: matches "implement"
        True
    """
    if _feature_extractor is not None:
        return _feature_extractor.extract_features(query, iteration, last_agent)
    return _extract_heuristic_features(query, iteration, last_agent)


def validate_query(query: str) -> Tuple[bool, str]:
    """Validate query input for safety and length.
    
    Args:
        query: The query string to validate.
    
    Returns:
        Tuple of (is_valid, error_message). If valid, error_message is empty.
    """
    if not query or not query.strip():
        return False, "Please enter a query."
    
    query = query.strip()
    
    if len(query) > MAX_QUERY_LENGTH:
        return False, f"Query too long. Maximum {MAX_QUERY_LENGTH} characters."
    
    return True, ""


def route_query(
    query: str,
    controller_type: str = "rnn",
) -> Tuple["MetaControllerPrediction", MetaControllerFeatures]:
    """Route a query using the specified meta-controller.
    
    Args:
        query: The input query to route.
        controller_type: Either 'rnn' or 'bert'.
    
    Returns:
        Tuple of (prediction, features).
    
    Raises:
        ValueError: If controller_type is not 'rnn' or 'bert'.
    """
    features = extract_features(query)
    
    controller_type_lower = controller_type.lower()
    if controller_type_lower == "rnn":
        prediction = rnn_controller.predict(features)
    elif controller_type_lower == "bert":
        prediction = bert_controller.predict(features)
    else:
        raise ValueError(f"Invalid controller_type: {controller_type}")
    
    return prediction, features


print("✅ Routing functions ready!")
print(f"   Technical keywords: {len(TECHNICAL_KEYWORDS)}")
print(f"   Comparison keywords: {len(COMPARISON_KEYWORDS)}")
print(f"   Optimization keywords: {len(OPTIMIZATION_KEYWORDS)}")

In [None]:
"""Interactive routing demo with example queries."""
from __future__ import annotations

from typing import List, Final

# =============================================================================
# CONSTANTS
# =============================================================================
SEPARATOR_LENGTH: Final[int] = 60

EXAMPLE_QUERIES: Final[List[str]] = [
    "What are the key factors when choosing between microservices and monolithic architecture?",
    "How can we optimize a Python application that processes 10GB of log files daily?",
    "Compare B-trees vs LSM-trees for write-heavy workloads",
    "Design a distributed rate limiting system for 100k requests per second",
    "Explain the difference between supervised and unsupervised learning",
]


def print_controller_prediction(name: str, prediction: "MetaControllerPrediction") -> None:
    """Print formatted controller prediction results.
    
    Args:
        name: Display name for the controller.
        prediction: The prediction result to display.
    """
    print(f"\n  {name}:")
    print(f"     Selected Agent: {prediction.agent.upper()}")
    print(f"     Confidence: {prediction.confidence:.1%}")
    probs = prediction.probabilities
    print(f"     Probabilities: HRM={probs['hrm']:.1%}, "
          f"TRM={probs['trm']:.1%}, MCTS={probs['mcts']:.1%}")


# =============================================================================
# DEMO EXECUTION
# =============================================================================
print("🧠 Neural Meta-Controller Routing Demo")
print("=" * SEPARATOR_LENGTH)

for i, query in enumerate(EXAMPLE_QUERIES, 1):
    ellipsis = "..." if len(query) > QUERY_DISPLAY_LENGTH else ""
    print(f"\n📝 Query {i}: {query[:QUERY_DISPLAY_LENGTH]}{ellipsis}")
    print("-" * SEPARATOR_LENGTH)
    
    # Get predictions from both controllers
    rnn_pred, _ = route_query(query, "rnn")
    bert_pred, _ = route_query(query, "bert")
    
    print_controller_prediction("🔄 RNN Controller", rnn_pred)
    print_controller_prediction("🤖 BERT Controller", bert_pred)
    
    # Agreement check
    if rnn_pred.agent == bert_pred.agent:
        print(f"\n  ✅ Controllers AGREE: {rnn_pred.agent.upper()}")
    else:
        print(f"\n  ⚠️ Controllers DISAGREE: "
              f"RNN={rnn_pred.agent.upper()}, BERT={bert_pred.agent.upper()}")

## 🎲 Step 5: Monte Carlo Tree Search (MCTS) Demo

Explore the MCTS engine, the strategic planning component that simulates multiple decision paths.

This demo builds an `MCTSConfig` from the framework's BALANCED preset and uses a seeded RNG so that results are reproducible.
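Selection in the cell below descends the tree with `select_child(exploration_weight)`. The framework's exact scoring rule is not shown in this notebook. The textbook UCB1 rule that the cell's comments refer to scores child $j$ as $\bar{X}_j + C\sqrt{\ln N / n_j}$, where:
- $\bar{X}_j$ is the child's mean value.
- $n_j$ is the child's visit count.
- $N$ is the parent's visit count.
- $C$ is the exploration weight.

The following is a standalone sketch of that rule, not the framework's implementation.

```python
import math


def ucb1(value_sum: float, visits: int, parent_visits: int, c: float) -> float:
    """UCB1 score: mean value plus an exploration bonus that shrinks with visits."""
    if visits == 0:
        return math.inf  # unvisited children are tried before any revisit
    mean = value_sum / visits
    return mean + c * math.sqrt(math.log(parent_visits) / visits)


# Same mean value (0.5), different visit counts: the less-visited child scores higher.
print(ucb1(5.0, 10, 100, 1.414), ucb1(50.0, 100, 100, 1.414))
```

The bonus term is what lets MCTS trade off exploiting high-value branches against revisiting rarely explored ones. Larger $C$ favors exploration.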

In [None]:
"""MCTS demonstration with proper configuration and seeded RNG."""
from __future__ import annotations

import math
import random
from typing import List, Final

from src.framework.mcts.core import MCTSNode, MCTSState
from src.framework.mcts.config import MCTSConfig, ConfigPreset, create_preset_config

# =============================================================================
# CONSTANTS
# =============================================================================
MCTS_SEED: Final[int] = 42
MCTS_DEMO_ITERATIONS: Final[int] = 100
MCTS_VALUE_MIN: Final[float] = 0.3
MCTS_VALUE_MAX: Final[float] = 0.9
NUM_ACTIONS: Final[int] = 3
ASCII_OFFSET_A: Final[int] = 65  # ord('A')
TOP_ACTIONS_DISPLAY: Final[int] = 3

print("🎲 Monte Carlo Tree Search Demo")
print("=" * SEPARATOR_LENGTH)

# Initialize seeded RNG for reproducibility
rng = random.Random(MCTS_SEED)

# Create MCTS configuration using framework's preset
# Note: create_preset_config takes ConfigPreset enum, not string
config: MCTSConfig = create_preset_config(ConfigPreset.BALANCED)

print("\n📊 MCTS Configuration (BALANCED preset):")
print(f"   Preset Iterations: {config.num_iterations} (this demo runs {MCTS_DEMO_ITERATIONS})")
print(f"   Exploration Weight (C): {config.exploration_weight}")
print(f"   Max Tree Depth: {config.max_tree_depth}")
print(f"   Random Seed: {MCTS_SEED}")


# =============================================================================
# MCTS DOMAIN FUNCTIONS
# =============================================================================

STATE_PATH_SEPARATOR: Final[str] = "/"  # joins applied actions into a state_id path


def generate_actions(state: MCTSState) -> List[str]:
    """Generate possible actions from current state.
    
    Generates NUM_ACTIONS actions until the state is config.max_tree_depth
    actions below the root, then returns an empty list (terminal state).
    
    Args:
        state: Current MCTS state to generate actions from.
    
    Returns:
        List of action strings, or empty list if terminal.
    
    Note:
        In a real application, this would generate domain-specific actions
        based on the problem state (e.g., chess moves, planning steps).
    """
    # Depth = number of actions applied since the root. Splitting on "_" would
    # miscount, because the action names themselves contain underscores.
    depth = state.state_id.count(STATE_PATH_SEPARATOR)
    if depth >= config.max_tree_depth:
        return []  # Terminal state
    return [f"action_{chr(ASCII_OFFSET_A + i)}" for i in range(NUM_ACTIONS)]


def evaluate_state(state: MCTSState) -> float:
    """Evaluate state value using seeded RNG for reproducibility.
    
    This is a placeholder evaluation function for demonstration purposes.
    In production, this would use a trained neural network or domain-specific
    heuristics to evaluate the state.
    
    Args:
        state: The MCTS state to evaluate.
    
    Returns:
        Value estimate in range [MCTS_VALUE_MIN, MCTS_VALUE_MAX].
    
    Note:
        Uses a dedicated seeded random.Random instance rather than the
        module-level random.uniform(), so other code drawing random numbers
        cannot change this sequence.
    """
    return rng.uniform(MCTS_VALUE_MIN, MCTS_VALUE_MAX)


def transition(state: MCTSState, action: str) -> MCTSState:
    """Transition to new state by applying action.
    
    Args:
        state: Current state.
        action: Action to apply.
    
    Returns:
        New MCTSState after applying the action.
    """
    return MCTSState(f"{state.state_id}{STATE_PATH_SEPARATOR}{action}")


def run_mcts_simulation(
    root: MCTSNode,
    iterations: int,
    exploration_weight: float,
) -> MCTSNode:
    """Run MCTS simulation for specified iterations.
    
    Executes the four MCTS phases: Selection, Expansion, Simulation,
    and Backpropagation.
    
    Args:
        root: Root node to start simulation from.
        iterations: Number of simulation iterations.
        exploration_weight: UCB1 exploration constant.
    
    Returns:
        Updated root node after simulation.
    """
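    for _ in range(iterations):
        # Selection - descend via select_child (UCB1) while every legal action
        # at the current node already has a child (node is fully expanded)
        node = root
        while (
            not node.terminal
            and node.children
            and len(node.children) >= len(generate_actions(node.state))
        ):
            node = node.select_child(exploration_weight)
        
        # Expansion - add one child for an action not yet tried at this node
        if not node.terminal:
            actions = generate_actions(node.state)
            tried = {child.action for child in node.children}
            untried = [a for a in actions if a not in tried]
            if untried:
                action = rng.choice(untried)
                child_state = transition(node.state, action)
                node = node.add_child(action=action, child_state=child_state)
            elif not actions:
                node.terminal = True
        
        # Simulation - evaluate the selected or newly added node
        value = evaluate_state(node.state)
        
        # Backpropagation - update the node and all its ancestors
        while node is not None:
            node.visits += 1
            node.value_sum += value
            node = node.parent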
    
    return root


# =============================================================================
# RUN SIMULATION
# =============================================================================
print(f"\n🎯 Running MCTS simulation ({MCTS_DEMO_ITERATIONS} iterations)...")

root = MCTSNode(state=MCTSState("root"))
root = run_mcts_simulation(root, MCTS_DEMO_ITERATIONS, config.exploration_weight)

# Results
print(f"\n📊 MCTS Results:")
print(f"   Root visits: {root.visits}")
print(f"   Root value: {root.value:.3f}")
print(f"   Children expanded: {len(root.children)}")

import math  # used for the UCB1 scores printed below

if root.children:
    print("\n🏆 Best Actions (by visit count, the most robust choice):")
    sorted_children = sorted(root.children, key=lambda c: c.visits, reverse=True)
    
    for i, child in enumerate(sorted_children[:TOP_ACTIONS_DISPLAY], 1):
        # Textbook UCB1 for display: mean value + C * sqrt(ln N_parent / n_child).
        # Every expanded child was visited at least once, so n_child >= 1.
        ucb1 = child.value + config.exploration_weight * math.sqrt(
            math.log(root.visits) / child.visits
        )
        print(f"   {i}. {child.action}: "
              f"visits={child.visits}, value={child.value:.3f}, UCB1={ucb1:.3f}")
    
    best = sorted_children[0]
    print(f"\n   ✅ Recommended: {best.action} (highest visit count)")
    print(f"   💡 Reproducible with seed={MCTS_SEED}")

## 📊 Step 6: Visualize Routing Probabilities

Create a visual comparison of how different queries are routed.
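The chart cell below places agent `i`'s bars at `x + i * BAR_WIDTH` and the tick labels at `x + BAR_WIDTH`. With three agents, that tick lands under the middle bar. If the agent list grows, a centering rule such as the following keeps each group aligned with its tick. The helper `bar_offsets` is a hypothetical sketch, not part of the repository, and it assumes matplotlib's default `align='center'` for `ax.bar`.

```python
from typing import List


def bar_offsets(n_series: int, bar_width: float) -> List[float]:
    """Offsets from each tick position that center n_series side-by-side bars.

    Each offset is the center of one bar relative to the tick, matching
    matplotlib's default align='center' for bar().
    """
    middle = (n_series - 1) / 2
    return [(i - middle) * bar_width for i in range(n_series)]


print(bar_offsets(3, 0.25))  # [-0.25, 0.0, 0.25]
print(bar_offsets(4, 0.25))  # [-0.375, -0.125, 0.125, 0.375]
```

With these offsets, you would draw series `i` at `x + offsets[i]` and set the ticks at `x` directly.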

In [None]:
"""Visualization of routing probabilities with refactored chart creation."""
from __future__ import annotations

from typing import List, Final, Tuple

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

# =============================================================================
# VISUALIZATION CONSTANTS
# =============================================================================
FIGURE_WIDTH: Final[int] = 14
FIGURE_HEIGHT: Final[int] = 6
BAR_WIDTH: Final[float] = 0.25
BAR_ALPHA: Final[float] = 0.8
PROBABILITY_MIN: Final[float] = 0.0
PROBABILITY_MAX: Final[float] = 1.0
GRID_ALPHA: Final[float] = 0.3

AGENT_NAMES: Final[List[str]] = ['HRM', 'TRM', 'MCTS']
AGENT_COLORS: Final[List[str]] = ['#2ecc71', '#3498db', '#e74c3c']

VISUALIZATION_QUERIES: Final[List[str]] = [
    "Implement a binary search tree",
    "Compare PostgreSQL vs MongoDB",
    "Optimize the neural network training",
    "What is machine learning?",
    "Design a caching strategy for APIs",
]


def create_routing_chart(
    ax: plt.Axes,
    probs_arr: NDArray[np.floating],
    title: str,
) -> None:
    """Create a routing probability bar chart on the given axes.
    
    Args:
        ax: Matplotlib axes to draw on.
        probs_arr: Array of shape (n_queries, 3) with probabilities.
        title: Chart title.
    """
    x = np.arange(len(probs_arr))
    
    for i, (agent, color) in enumerate(zip(AGENT_NAMES, AGENT_COLORS)):
        ax.bar(
            x + i * BAR_WIDTH,
            probs_arr[:, i],
            BAR_WIDTH,
            label=agent,
            color=color,
            alpha=BAR_ALPHA,
        )
    
    ax.set_xlabel('Query')
    ax.set_ylabel('Probability')
    ax.set_title(title)
    ax.set_xticks(x + BAR_WIDTH)
    ax.set_xticklabels([f'Q{i+1}' for i in range(len(probs_arr))])
    ax.legend()
    ax.set_ylim(PROBABILITY_MIN, PROBABILITY_MAX)
    ax.grid(axis='y', alpha=GRID_ALPHA)


def collect_predictions(
    queries: List[str],
) -> Tuple[NDArray[np.floating], NDArray[np.floating]]:
    """Collect routing predictions for all queries.
    
    Args:
        queries: List of query strings.
    
    Returns:
        Tuple of (rnn_probs, bert_probs) arrays.
    """
    rnn_probs: List[List[float]] = []
    bert_probs: List[List[float]] = []
    
    for query in queries:
        rnn_pred, _ = route_query(query, "rnn")
        bert_pred, _ = route_query(query, "bert")
        
        rnn_probs.append([
            rnn_pred.probabilities['hrm'],
            rnn_pred.probabilities['trm'],
            rnn_pred.probabilities['mcts'],
        ])
        bert_probs.append([
            bert_pred.probabilities['hrm'],
            bert_pred.probabilities['trm'],
            bert_pred.probabilities['mcts'],
        ])
    
    return np.array(rnn_probs), np.array(bert_probs)


# =============================================================================
# CREATE VISUALIZATION
# =============================================================================
rnn_probs_arr, bert_probs_arr = collect_predictions(VISUALIZATION_QUERIES)

fig, axes = plt.subplots(1, 2, figsize=(FIGURE_WIDTH, FIGURE_HEIGHT))

create_routing_chart(axes[0], rnn_probs_arr, 'RNN Meta-Controller Routing')
create_routing_chart(axes[1], bert_probs_arr, 'BERT Meta-Controller Routing')

# Add the title before tight_layout so the layout reserves space for it
fig.suptitle(
    'Neural Meta-Controller Agent Routing Comparison',
    fontsize=14,
    fontweight='bold',
)
fig.tight_layout()
plt.show()

# Print query legend
print("\n📝 Query Legend:")
for i, query in enumerate(VISUALIZATION_QUERIES):
    print(f"   Q{i+1}: {query}")

## 🏛️ Step 7: Full Framework Demo with Gradio UI

Launch the complete Gradio interface for interactive exploration.

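**Security Note**: Debug mode is disabled (`debug=False`) so that internal details are not exposed in a shared demo. Note that `share=True` creates a public URL that anyone with the link can open.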

In [None]:
"""Full Gradio app launcher (optional)."""
from __future__ import annotations

import importlib.util
from pathlib import Path

print("🏛️ Launching Full Framework UI...")
print("\nNote: This will run the complete Gradio interface.")
print("Click the public URL to access from any device.\n")

# Import and run the app
# Uncomment the following lines to launch:

# app_path = Path(REPO_PATH) / "app.py"
# spec = importlib.util.spec_from_file_location("app", app_path)
# app_module = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(app_module)
# app_module.demo.launch(share=True, debug=False)  # debug=False for security

print("ℹ️ To launch the full UI, uncomment the lines above and run this cell.")
print("   Or run: !python app.py")

In [None]:
"""Quick mini Gradio demo with input validation and secure settings."""
from __future__ import annotations

from typing import Tuple, Final

import gradio as gr

# =============================================================================
# CONSTANTS
# =============================================================================
TEXTBOX_LINES: Final[int] = 3
RESPONSE_TRUNCATE_LENGTH: Final[int] = 80
BAR_CHART_MAX_LENGTH: Final[int] = 20


def format_probability_bar(name: str, prob: float) -> str:
    """Format a probability with a visual bar chart.
    
    Args:
        name: Agent name to display.
        prob: Probability value (0.0 to 1.0).
    
    Returns:
        Formatted string with percentage and visual bar.
    """
    bar = '█' * int(prob * BAR_CHART_MAX_LENGTH)
    return f"- {name}: {prob:.1%} {bar}"


def process_query_mini(
    query: str,
    controller_type: str,
) -> Tuple[str, str, str]:
    """Process a query with the selected controller.
    
    Args:
        query: User's input query.
        controller_type: 'RNN' or 'BERT'.
    
    Returns:
        Tuple of (response, routing_info, features_info).
    """
    # Validate input
    is_valid, error_msg = validate_query(query)
    if not is_valid:
        return error_msg, "", ""
    
    try:
        prediction, features = route_query(query.strip(), controller_type.lower())
    except Exception as e:
        # Log details to the notebook output only; don't leak internals to UI users
        print(f"⚠️ route_query failed: {e!r}")
        return "⚠️ Error processing query. Please try again.", "", ""
    
    # Format routing decision
    routing = f"""🧠 **Meta-Controller Decision**

**Selected Agent:** `{prediction.agent.upper()}`
**Confidence:** {prediction.confidence:.1%}

**Routing Probabilities:**
{format_probability_bar('HRM', prediction.probabilities['hrm'])}
{format_probability_bar('TRM', prediction.probabilities['trm'])}
{format_probability_bar('MCTS', prediction.probabilities['mcts'])}
"""
    
    # Format features
    features_str = f"""📊 **Extracted Features**

- Query Length: {features.query_length}
- Technical Query: {'Yes' if features.is_technical_query else 'No'}
- Has RAG Context: {'Yes' if features.has_rag_context else 'No'}
- HRM Confidence: {features.hrm_confidence:.3f}
- TRM Confidence: {features.trm_confidence:.3f}
- MCTS Value: {features.mcts_value:.3f}
"""
    
    # Simulated agent response
    agent_responses = {
        "hrm": f"[HRM] Breaking down hierarchically: {query[:RESPONSE_TRUNCATE_LENGTH]}...",
        "trm": f"[TRM] Applying iterative refinement: {query[:RESPONSE_TRUNCATE_LENGTH]}...",
        "mcts": f"[MCTS] Strategic exploration: {query[:RESPONSE_TRUNCATE_LENGTH]}...",
    }
    response = agent_responses.get(prediction.agent, "Unknown agent")
    
    return response, routing, features_str


# =============================================================================
# CREATE GRADIO INTERFACE
# =============================================================================
with gr.Blocks(title="LangGraph Multi-Agent MCTS - Mini Demo") as mini_demo:
    gr.Markdown("""
    # 🎯 LangGraph Multi-Agent MCTS - Quick Demo
    
    Test the neural meta-controllers with your own queries!
    """)
    
    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your question...",
                lines=TEXTBOX_LINES,
                max_lines=TEXTBOX_LINES * 2,
            )
        with gr.Column(scale=1):
            controller_select = gr.Radio(
                choices=["RNN", "BERT"],
                value="RNN",
                label="Controller Type",
            )
    
    submit_btn = gr.Button("🚀 Process Query", variant="primary")
    
    with gr.Row():
        response_output = gr.Textbox(
            label="Agent Response",
            lines=TEXTBOX_LINES,
        )
    
    with gr.Row():
        routing_output = gr.Markdown(label="Routing Decision")
        features_output = gr.Markdown(label="Features")
    
    submit_btn.click(
        fn=process_query_mini,
        inputs=[query_input, controller_select],
        outputs=[response_output, routing_output, features_output],
    )

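# Launch with a public share link.
# show_error=False (the default) keeps tracebacks out of the browser UI;
# debug=False lets the cell finish instead of blocking the notebook.
mini_demo.launch(share=True, debug=False, show_error=False)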
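
The routing panel shows the selected agent, its confidence, and each agent's probability as a bar of `int(prob * BAR_CHART_MAX_LENGTH)` blocks. The sketch below isolates that presentation logic so it can be checked without Gradio or a trained model. It assumes, hypothetically, that the selected agent is the arg-max of the probabilities and that its confidence equals the winning probability; the real meta-controller may compute `prediction.confidence` differently. `pick_agent`, `bar`, and `BAR_WIDTH` are illustrative names only.

```python
"""Hypothetical, dependency-free sketch of the routing-summary logic."""
from __future__ import annotations

from typing import Dict, Tuple

BAR_WIDTH = 20  # mirrors BAR_CHART_MAX_LENGTH in the cell above


def pick_agent(probabilities: Dict[str, float]) -> Tuple[str, float]:
    """Return (agent, confidence) for the most probable agent.

    Assumption: confidence is simply the winning probability.
    Ties go to the first key in insertion order (max() keeps the first).
    """
    if not probabilities:
        raise ValueError("probabilities must not be empty")
    total = sum(probabilities.values())
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"probabilities sum to {total:.4f}, expected 1.0")
    agent = max(probabilities, key=probabilities.__getitem__)
    return agent, probabilities[agent]


def bar(prob: float, width: int = BAR_WIDTH) -> str:
    """One block per 1/width of probability, truncated with int() like the cell above."""
    return "█" * int(prob * width)


demo_probs = {"hrm": 0.25, "trm": 0.25, "mcts": 0.5}
demo_agent, demo_confidence = pick_agent(demo_probs)
print(f"Selected: {demo_agent.upper()} ({demo_confidence:.1%})")
for name, p in demo_probs.items():
    print(f"- {name.upper()}: {p:.1%} {bar(p)}")
```

Because `bar` truncates, a probability of 0.99 draws 19 blocks rather than 20; rounding instead would make small probabilities more visible at the cost of bars that can overstate the value.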

## 🎭 Step 8: Training Your Own Models (Optional)

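The cell below trains the RNN meta-controller end to end on synthetic, class-balanced data from `MetaControllerDataGenerator`: it generates features, splits them into training, validation, and test sets, trains with early stopping, and evaluates on the held-out test set. To train on your own data, substitute your own examples in the same feature format for the generated ones.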
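
The training cell draws `SAMPLES_PER_CLASS = 50` examples per routing class and splits them 70/15/15 with `generator.split_dataset`, whose internals this notebook does not show. As a hypothetical illustration only, here is a seeded, stratified split in plain Python: each class is shuffled and cut separately so every partition keeps the same class balance, and the test set takes whatever remains after `int()` truncation so no sample is dropped. The class names `hrm`, `trm`, and `mcts` are assumed from the agent list.

```python
"""Hypothetical seeded, stratified train/val/test split (standard library only)."""
from __future__ import annotations

import random
from collections import defaultdict
from typing import Dict, List, Sequence


def stratified_split(
    labels: Sequence[str],
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    seed: int = 42,
) -> Dict[str, List[int]]:
    """Return sample indices for 'train', 'val' and 'test', split per class."""
    if train_ratio <= 0 or val_ratio < 0 or train_ratio + val_ratio >= 1:
        raise ValueError("need train_ratio > 0, val_ratio >= 0, and their sum < 1")

    rng = random.Random(seed)  # local RNG: reproducible, no global state touched
    by_class: Dict[str, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    splits: Dict[str, List[int]] = {"train": [], "val": [], "test": []}
    for label in sorted(by_class):  # fixed class order keeps results deterministic
        idxs = by_class[label]
        rng.shuffle(idxs)
        n_train = int(len(idxs) * train_ratio)  # int() truncates toward zero
        n_val = int(len(idxs) * val_ratio)
        splits["train"].extend(idxs[:n_train])
        splits["val"].extend(idxs[n_train:n_train + n_val])
        splits["test"].extend(idxs[n_train + n_val:])  # remainder: nothing dropped
    return splits


demo_labels = ["hrm"] * 50 + ["trm"] * 50 + ["mcts"] * 50
demo_splits = stratified_split(demo_labels)
print({name: len(idx) for name, idx in demo_splits.items()})
```

scikit-learn's `train_test_split(..., stratify=y)` applies the same idea to a single two-way split.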

In [None]:
"""Training pipeline demo with proper error handling."""
from __future__ import annotations

from typing import Final

from src.training.data_generator import MetaControllerDataGenerator
from src.training.train_rnn import RNNTrainer

# =============================================================================
# TRAINING CONFIGURATION CONSTANTS
# =============================================================================
SAMPLES_PER_CLASS: Final[int] = 50
TRAIN_RATIO: Final[float] = 0.7
VAL_RATIO: Final[float] = 0.15

# Model hyperparameters
RNN_HIDDEN_DIM: Final[int] = 32
RNN_NUM_LAYERS: Final[int] = 1
RNN_DROPOUT: Final[float] = 0.1
LEARNING_RATE: Final[float] = 1e-3
BATCH_SIZE: Final[int] = 16
TRAINING_EPOCHS: Final[int] = 3
EARLY_STOPPING_PATIENCE: Final[int] = 2
TRAINING_SEED: Final[int] = 42

print("🎭 Training Pipeline Demo")
print("=" * SEPARATOR_LENGTH)

# =============================================================================
# GENERATE DATA
# =============================================================================
print("\n📦 Generating synthetic training data...")
generator = MetaControllerDataGenerator(seed=TRAINING_SEED)
features_list, labels_list = generator.generate_balanced_dataset(
    samples_per_class=SAMPLES_PER_CLASS
)

print(f"   Total samples: {len(features_list)}")
print(f"   Classes: {set(labels_list)}")

# Convert to tensors
X, y = generator.to_tensor_dataset(features_list, labels_list)
print(f"   Feature tensor shape: {X.shape}")
print(f"   Label tensor shape: {y.shape}")

# Split dataset
splits = generator.split_dataset(X, y, train_ratio=TRAIN_RATIO, val_ratio=VAL_RATIO)
print("\n📊 Dataset splits:")
print(f"   Training: {splits['X_train'].shape[0]} samples")
print(f"   Validation: {splits['X_val'].shape[0]} samples")
print(f"   Test: {splits['X_test'].shape[0]} samples")

# =============================================================================
# TRAIN MODEL
# =============================================================================
print(f"\n🏋️ Training RNN model ({TRAINING_EPOCHS} epochs)...")

trainer = RNNTrainer(
    hidden_dim=RNN_HIDDEN_DIM,
    num_layers=RNN_NUM_LAYERS,
    dropout=RNN_DROPOUT,
    lr=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    epochs=TRAINING_EPOCHS,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    seed=TRAINING_SEED,
)

try:
    history = trainer.train(
        train_data=(splits["X_train"], splits["y_train"]),
        val_data=(splits["X_val"], splits["y_val"]),
    )
    
    print("\n✅ Training complete!")
    print(f"   Best validation accuracy: {history['best_val_accuracy']:.2%}")
    print(f"   Best validation loss: {history['best_val_loss']:.4f}")
    
    # Evaluate on test set
    print("\n🧪 Evaluating on test set...")
    test_loader = trainer.create_dataloader(
        splits["X_test"], splits["y_test"], shuffle=False
    )
    results = trainer.evaluate(test_loader)
    
    print(f"   Test accuracy: {results['accuracy']:.2%}")
    print(f"   Test loss: {results['loss']:.4f}")

except RuntimeError as e:
    print(f"\n❌ Training failed: {e}")
    print("   Consider reducing batch size or checking GPU memory.")
except ValueError as e:
    print(f"\n❌ Data validation error: {e}")
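
`RNNTrainer` is configured with `early_stopping_patience=2`, but its training loop is not shown here. The sketch below illustrates the usual meaning of patience: stop once the validation loss has failed to improve for `patience` consecutive epochs, while remembering the best epoch so its weights can be restored. `EarlyStopping` is a hypothetical class, not the repository's implementation, and the optional `min_delta` threshold is an added illustration.

```python
"""Hypothetical patience-based early stopping (not RNNTrainer's actual code)."""
from __future__ import annotations

import math
from dataclasses import dataclass, field


@dataclass
class EarlyStopping:
    """Signal a stop after `patience` epochs without a val-loss improvement."""

    patience: int = 2
    min_delta: float = 0.0  # minimum decrease that counts as an improvement
    best_loss: float = field(default=math.inf, init=False)
    best_epoch: int = field(default=-1, init=False)
    bad_epochs: int = field(default=0, init=False)

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True to stop training."""
        if val_loss < self.best_loss - self.min_delta:
            # New best: a real trainer would also checkpoint the weights here
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(epoch, val_loss):
        print(f"stop at epoch {epoch}; best epoch {stopper.best_epoch}")
        break
# → stop at epoch 3; best epoch 1
```

Under this counting, with `TRAINING_EPOCHS = 3` and patience 2 the earliest possible stop is the third and final epoch, so early stopping cannot shorten this demo run; it only pays off with a larger `TRAINING_EPOCHS`.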

## 🏁 Step 9: Chess MCTS Demo (Bonus)

The framework includes a chess implementation driven by MCTS, the same tree-search family that AlphaZero builds on.

In [None]:
"""Chess MCTS demo with clear installation instructions."""
from __future__ import annotations

from typing import Final

# =============================================================================
# CHESS DEMO CONSTANTS
# =============================================================================
CHESS_MCTS_ITERATIONS: Final[int] = 50
CHESS_EXPLORATION_WEIGHT: Final[float] = 1.414  # sqrt(2)
TOP_MOVES_DISPLAY: Final[int] = 5

try:
    import chess
    from src.games.chess.game import ChessGame
    from src.games.chess.mcts import ChessMCTS
    
    print("♚ Chess MCTS Demo")
    print("=" * SEPARATOR_LENGTH)
    
    # Create a chess game
    game = ChessGame()
    print("\n🎯 Initial position:")
    print(game.board)
    
    # Create MCTS player
    mcts = ChessMCTS(
        game,
        iterations=CHESS_MCTS_ITERATIONS,
        exploration_weight=CHESS_EXPLORATION_WEIGHT,
    )
    
    # Get best move
    print(f"\n🤔 MCTS thinking ({CHESS_MCTS_ITERATIONS} iterations)...")
    best_move = mcts.get_best_move()
    
    print(f"\n✅ Best move: {best_move}")
    
    # Show move statistics
    print(f"\n📊 Top {TOP_MOVES_DISPLAY} move statistics:")
    for move, stats in mcts.get_move_stats()[:TOP_MOVES_DISPLAY]:
        print(f"   {move}: visits={stats['visits']}, value={stats['value']:.3f}")

except ImportError as e:
    print("⚠️ Chess demo not available")
    print(f"   Error: {e}")
    print("\n💡 To enable the chess demo:")
    print("   1. Run: !pip install chess  (the package formerly published as python-chess)")
    print("   2. Restart the runtime: Runtime > Restart session")
    print("   3. Re-run this cell")
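
The exploration weight `CHESS_EXPLORATION_WEIGHT = 1.414` is approximately √2, the classic constant in the UCB1 rule used by UCT-style MCTS: score = Q/N + c·√(ln N_parent / N), where Q is a child's total backed-up value and N its visit count. Whether `ChessMCTS` uses exactly this rule is an assumption. The sketch below, with hypothetical names (`Child`, `ucb1`, `select_child`, `most_visited_move`), shows the selection score and the common convention of playing the most-visited move once the search budget is spent, which is why visit counts are a natural statistic to report.

```python
"""Hypothetical UCB1 selection as used in UCT-style MCTS (sketch)."""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import List


@dataclass
class Child:
    move: str
    visits: int = 0
    value_sum: float = 0.0  # total backed-up result, from the parent's viewpoint


def ucb1(child: Child, parent_visits: int, c: float = math.sqrt(2)) -> float:
    """Mean value (exploitation) + c * sqrt(ln N_parent / N_child) (exploration)."""
    if child.visits == 0:
        return math.inf  # every move is tried at least once
    exploit = child.value_sum / child.visits
    explore = c * math.sqrt(math.log(parent_visits) / child.visits)
    return exploit + explore


def select_child(children: List[Child], c: float = math.sqrt(2)) -> Child:
    """Pick the child to descend into during the selection phase."""
    parent_visits = sum(ch.visits for ch in children)
    return max(children, key=lambda ch: ucb1(ch, parent_visits, c))


def most_visited_move(children: List[Child]) -> str:
    """After search, play the most-visited move (commonly preferred over best mean)."""
    return max(children, key=lambda ch: ch.visits).move


children = [Child("e2e4", 10, 6.0), Child("d2d4", 10, 5.0), Child("g1f3")]
print(select_child(children).move)  # → g1f3 (unvisited moves score +inf)
```

A larger `c` spreads visits across more moves; with `c = 0` the search greedily follows the best mean value found so far.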

## 📚 Resources & Next Steps

### 📖 Documentation
- [Repository README](https://github.com/ianshank/langgraph_multi_agent_mcts)
- [Architecture Documentation](https://github.com/ianshank/langgraph_multi_agent_mcts/blob/main/docs/langgraph_mcts_architecture.md)

### 🛠️ Key Files
- `app.py` - Main Gradio application
- `src/agents/meta_controller/` - Neural meta-controllers
- `src/framework/mcts/` - MCTS implementation
- `src/training/` - Training pipelines

### 🏃 Run Locally
```bash
git clone https://github.com/ianshank/langgraph_multi_agent_mcts.git
cd langgraph_multi_agent_mcts
pip install -r requirements.txt
python app.py
```

### 👋 Feedback
Open an issue on GitHub or contribute to the project!

In [None]:
"""Notebook completion summary."""
print("🎉 Notebook complete!")
print("\nYou've explored:")
print("  ✅ Neural Meta-Controllers (RNN and BERT with LoRA)")
print("  ✅ Monte Carlo Tree Search (MCTS) with MCTSEngine")
print("  ✅ Agent Routing Visualization")
print("  ✅ Training Pipeline")
print("  ✅ Interactive Gradio Demo (public share link, errors hidden from the UI)")
print("  ✅ Chess MCTS Demo (bonus)")
print("\n🚀 Happy experimenting!")