# 🎯 LangGraph Multi-Agent MCTS Framework

## Production Demo with Trained Neural Meta-Controllers

This notebook demonstrates the **LangGraph Multi-Agent MCTS Framework**, a multi-agent system that combines:

- **LangGraph** for explicit state management and agent orchestration
- **Monte Carlo Tree Search (MCTS)** for strategic planning and exploration
- **Neural Meta-Controllers** (RNN and BERT with LoRA) for intelligent agent routing

### 🧠 Trained Models
- **RNN Meta-Controller**: GRU-based sequential pattern recognition for fast routing
- **BERT with LoRA**: Transformer-based text understanding with parameter-efficient fine-tuning
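
To make "parameter-efficient" concrete, the sketch below counts how many weights LoRA trains for a single square projection matrix. The width of 768 is the hidden size of BERT-base, used here only as a familiar example, and the rank `r = 8` is an assumption for illustration; neither value is taken from this repository's configuration.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """LoRA learns the weight update as B @ A,
    where A is (rank x d_in) and B is (d_out x rank)."""
    return rank * d_in + d_out * rank


def full_params(d_in: int, d_out: int) -> int:
    """Number of weights in the frozen dense matrix."""
    return d_in * d_out


d, r = 768, 8  # r is an illustrative rank (assumption)
lora = lora_trainable_params(d, d, r)
full = full_params(d, d)
print(f"LoRA: {lora:,} trainable vs {full:,} frozen ({lora / full:.2%})")
```

The adapter's size grows linearly with the rank, so the fraction of trainable weights stays small as long as the rank is much smaller than the matrix width.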

### 🤖 Agents
- **HRM (Hierarchical Reasoning Model)**: Decomposes complex problems hierarchically
- **TRM (Tiny Recursive Model)**: Iterative refinement for progressive improvement
- **MCTS**: Monte Carlo Tree Search for optimization and strategic exploration

---

**Repository**: [github.com/ianshank/langgraph_multi_agent_mcts](https://github.com/ianshank/langgraph_multi_agent_mcts)

## 🚀 Step 1: Environment Setup

Clone the repository and install all dependencies. This cell handles:
- Repository cloning
- Dependency installation (including `nest_asyncio` for async support in Colab)
- Path configuration
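
As background on why `nest_asyncio` is installed: Colab and Jupyter already run an event loop, and plain `asyncio` refuses to start a second one inside it. The sketch below reproduces that error with the standard library only. It is meant to be run as an ordinary Python script; `inner` and `outer` are hypothetical names used for illustration.

```python
import asyncio


async def inner() -> int:
    return 42


async def outer() -> str:
    """Try to start a second event loop while one is already running."""
    coro = inner()
    try:
        asyncio.run(coro)  # fails: outer() is already running inside a loop
    except RuntimeError as exc:
        coro.close()  # avoid a "coroutine was never awaited" warning
        return str(exc)
    return "no error"


message = asyncio.run(outer())
print(message)
```

`nest_asyncio.apply()` patches the running loop so that nested calls like this are allowed, which is why the setup cell applies it before any framework code runs.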

In [None]:
import os
import sys

# Repository path
REPO_NAME = "langgraph_multi_agent_mcts"
REPO_PATH = f"/content/{REPO_NAME}"
REPO_URL = "https://github.com/ianshank/langgraph_multi_agent_mcts.git"

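# 1. Clone the repository (clean up if re-running)
print("📦 Cloning repository...")
# Step out of the repo first: deleting the current working directory can break git clone on re-runs
%cd /content
!rm -rf {REPO_PATH}
!git clone {REPO_URL} {REPO_PATH}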

# 2. Change to repo directory
%cd {REPO_PATH}

# 3. Install dependencies from requirements.txt
print("\n📦 Installing dependencies...")
!pip install -q -r requirements.txt

# 4. Install additional Colab-specific packages
print("\n📦 Installing Colab-specific packages...")
!pip install -q nest_asyncio ipywidgets

# 5. Apply nest_asyncio for async support in Jupyter
import nest_asyncio
nest_asyncio.apply()

# 6. Add repo to Python path
sys.path.insert(0, REPO_PATH)

print("\n✅ Setup complete!")
print(f"📁 Working directory: {os.getcwd()}")

## 🔑 Step 2: API Key Configuration (Optional)

The neural meta-controllers use **pre-trained local models**, so API keys are **optional**.

However, if you want to:
- Use LLM-powered agents (HRM/TRM with actual generation)
- Enable LangSmith tracing for debugging
- Use Weights & Biases for experiment tracking

You can configure your API keys below.
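
A quick way to see which optional integrations are active is to check the three environment variables this notebook sets. This is a minimal sketch: the function name and the feature labels are hypothetical, and only the variable names come from the cell below.

```python
import os
from collections.abc import Mapping

# Environment variable names as used in this notebook; the labels are illustrative.
OPTIONAL_KEYS = {
    "llm_agents": "OPENAI_API_KEY",
    "langsmith_tracing": "LANGCHAIN_API_KEY",
    "wandb_tracking": "WANDB_API_KEY",
}


def optional_integrations(env: Mapping[str, str] = os.environ) -> dict[str, bool]:
    """Report which optional integrations have a non-empty key configured."""
    return {feature: bool(env.get(var)) for feature, var in OPTIONAL_KEYS.items()}


# Example with an explicit mapping instead of the real environment
status = optional_integrations({"OPENAI_API_KEY": "sk-example", "WANDB_API_KEY": ""})
print(status)
```

An empty string counts as "not configured", which matches how the cell below treats a missing secret.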

In [None]:
import os
import getpass

def set_key(name, required=False):
    """Set API key from Colab Secrets or manual input."""
    try:
        from google.colab import userdata
        value = userdata.get(name)
        if value:
            print(f"✅ {name} loaded from Colab Secrets")
            return value
    except Exception:
        # Not running in Colab, or the secret is missing or not shared with this notebook
        pass
    
    if required:
        return getpass.getpass(f"Enter your {name}: ")
    else:
        print(f"⚠️ {name} not set (optional)")
        return None

# Optional API Keys
print("🔑 Configuring API Keys...\n")

# OpenAI - for LLM-powered agents (optional)
openai_key = set_key("OPENAI_API_KEY")
if openai_key:
    os.environ["OPENAI_API_KEY"] = openai_key

# LangSmith - for tracing (optional but recommended)
langchain_key = set_key("LANGCHAIN_API_KEY")
if langchain_key:
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_API_KEY"] = langchain_key
    os.environ["LANGCHAIN_PROJECT"] = "langgraph-multi-agent-mcts"

# Weights & Biases - for experiment tracking (optional)
wandb_key = set_key("WANDB_API_KEY")
if wandb_key:
    os.environ["WANDB_API_KEY"] = wandb_key

print("\n✅ API key configuration complete!")

## 🧠 Step 3: Load Trained Neural Meta-Controllers

Initialize the framework with the pre-trained models:
- **RNN Meta-Controller**: Fast, captures sequential patterns (10D features → 3-class routing)
- **BERT with LoRA**: Context-aware text understanding for complex routing decisions
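
"3-class routing" means each controller ends in three scores, one per agent, that are normalized into probabilities. The sketch below shows that last step in plain Python. The logit values and the agent ordering are assumptions for illustration; the real controllers compute their scores inside the model.

```python
import math

AGENTS = ("hrm", "trm", "mcts")  # ordering is an assumption


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route_from_logits(logits: list[float]) -> tuple[str, float, dict[str, float]]:
    """Return the selected agent, its confidence, and all probabilities."""
    probabilities = dict(zip(AGENTS, softmax(logits)))
    agent = max(probabilities, key=probabilities.get)
    return agent, probabilities[agent], probabilities


agent, confidence, probabilities = route_from_logits([2.0, 0.5, 1.0])  # made-up logits
print(agent, f"{confidence:.1%}", probabilities)
```

The returned triple mirrors the `agent`, `confidence`, and `probabilities` fields that the demo cells read from each prediction.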

In [None]:
import torch
from pathlib import Path

# Import meta-controllers
from src.agents.meta_controller.rnn_controller import RNNMetaController
from src.agents.meta_controller.bert_controller_v2 import BERTMetaController
from src.agents.meta_controller.base import MetaControllerFeatures

print("🧠 Initializing Neural Meta-Controllers...\n")

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"💻 Device: {device}")
if device == "cuda":
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# Initialize RNN Controller
print("\n🔄 Loading RNN Meta-Controller...")
rnn_controller = RNNMetaController(name="RNNController", seed=42, device=device)

# Load trained weights
rnn_model_path = Path(REPO_PATH) / "models" / "rnn_meta_controller.pt"
if rnn_model_path.exists():
    checkpoint = torch.load(rnn_model_path, map_location=device, weights_only=True)
    rnn_controller.model.load_state_dict(checkpoint)
    rnn_controller.model.eval()
    print(f"   ✅ Loaded trained weights from {rnn_model_path.name}")
else:
    print(f"   ⚠️ Using untrained model (weights not found)")

# Initialize BERT Controller with LoRA
print("\n🤖 Loading BERT Meta-Controller with LoRA...")
bert_controller = BERTMetaController(name="BERTController", seed=42, device=device, use_lora=True)

# Load trained weights
bert_model_path = Path(REPO_PATH) / "models" / "bert_lora" / "final_model"
if bert_model_path.exists():
    try:
        bert_controller.load_model(str(bert_model_path))
        print(f"   ✅ Loaded trained LoRA weights from {bert_model_path.name}")
    except Exception as e:
        print(f"   ⚠️ Could not load BERT weights: {e}")
else:
    print(f"   ⚠️ Using untrained model (weights not found)")

print("\n✅ Meta-controllers initialized!")

## 🎮 Step 4: Interactive Agent Routing Demo

Try the neural meta-controllers! Enter a query and see how the controllers decide which agent to route it to.
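
When the semantic feature extractor is unavailable, the next cell falls back to a keyword heuristic: each agent starts at 0.5, gains 0.2 if one of its keywords appears, and the three scores are then normalized to sum to 1. The sketch below works through that arithmetic. It uses the same keyword lists but matches whole words rather than substrings, which is a deliberate variation and not what the notebook cell does.

```python
import re

KEYWORDS = {
    "hrm": {"algorithm", "code", "implement", "technical", "system"},
    "trm": {"vs", "versus", "compare", "difference", "better"},
    "mcts": {"optimize", "best", "improve", "maximize", "minimize"},
}


def heuristic_scores(query: str, base: float = 0.5, bonus: float = 0.2) -> dict[str, float]:
    """Score each agent from keyword hits, then normalize so the scores sum to 1."""
    tokens = set(re.findall(r"[a-z]+", query.lower()))
    raw = {
        agent: base + (bonus if tokens & words else 0.0)
        for agent, words in KEYWORDS.items()
    }
    total = sum(raw.values())
    return {agent: score / total for agent, score in raw.items()}


scores = heuristic_scores("Compare B-trees vs LSM-trees for write-heavy workloads")
print({agent: round(score, 3) for agent, score in scores.items()})
```

Here only the comparison keywords match, so TRM gets 0.7 / 1.7 and the other two agents get 0.5 / 1.7 each. Whole-word matching avoids false hits such as "code" inside "decode".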

In [None]:
# Import feature extraction
try:
    from src.agents.meta_controller.feature_extractor import FeatureExtractor, FeatureExtractorConfig
    config = FeatureExtractorConfig.from_env()
    config.device = device
    feature_extractor = FeatureExtractor(config)
    print("✅ Using semantic feature extraction")
except Exception as e:
    feature_extractor = None
    print(f"⚠️ Using heuristic feature extraction: {e}")

def extract_features(query: str, iteration: int = 0, last_agent: str = "none"):
    """Extract features from a query for the meta-controller."""
    if feature_extractor:
        return feature_extractor.extract_features(query, iteration, last_agent)
    
    # Fallback heuristic extraction
    query_length = len(query)
    has_technical = any(word in query.lower() for word in ["algorithm", "code", "implement", "technical", "system"])
    has_comparison = any(word in query.lower() for word in ["vs", "versus", "compare", "difference", "better"])
    has_optimization = any(word in query.lower() for word in ["optimize", "best", "improve", "maximize", "minimize"])
    
    hrm_conf = 0.5 + (0.2 if has_technical else 0)
    trm_conf = 0.5 + (0.2 if has_comparison else 0)
    mcts_conf = 0.5 + (0.2 if has_optimization else 0)
    
    total = hrm_conf + trm_conf + mcts_conf
    return MetaControllerFeatures(
        hrm_confidence=hrm_conf/total,
        trm_confidence=trm_conf/total,
        mcts_value=mcts_conf/total,
        consensus_score=0.6,
        last_agent=last_agent,
        iteration=iteration,
        query_length=query_length,
        has_rag_context=query_length > 50,
        rag_relevance_score=0.7 if query_length > 50 else 0.0,
        is_technical_query=has_technical,
    )

def route_query(query: str, controller_type: str = "rnn"):
    """Route a query using the specified meta-controller."""
    features = extract_features(query)
    
    if controller_type.lower() == "rnn":
        prediction = rnn_controller.predict(features)
    else:
        prediction = bert_controller.predict(features)
    
    return prediction, features

print("✅ Routing functions ready!")

In [None]:
# 🎮 Try it yourself!

# Example queries - try changing these!
example_queries = [
    "What are the key factors when choosing between microservices and monolithic architecture?",
    "How can we optimize a Python application that processes 10GB of log files daily?",
    "Compare B-trees vs LSM-trees for write-heavy workloads",
    "Design a distributed rate limiting system for 100k requests per second",
    "Explain the difference between supervised and unsupervised learning",
]

print("🧠 Neural Meta-Controller Routing Demo")
print("=" * 60)

for i, query in enumerate(example_queries, 1):
    print(f"\n📝 Query {i}: {query[:60]}...")
    print("-" * 60)
    
    # Get predictions from both controllers
    rnn_pred, features = route_query(query, "rnn")
    bert_pred, _ = route_query(query, "bert")
    
    print(f"\n  🔄 RNN Controller:")
    print(f"     Selected Agent: {rnn_pred.agent.upper()}")
    print(f"     Confidence: {rnn_pred.confidence:.1%}")
    print(f"     Probabilities: HRM={rnn_pred.probabilities['hrm']:.1%}, TRM={rnn_pred.probabilities['trm']:.1%}, MCTS={rnn_pred.probabilities['mcts']:.1%}")
    
    print(f"\n  🤖 BERT Controller:")
    print(f"     Selected Agent: {bert_pred.agent.upper()}")
    print(f"     Confidence: {bert_pred.confidence:.1%}")
    print(f"     Probabilities: HRM={bert_pred.probabilities['hrm']:.1%}, TRM={bert_pred.probabilities['trm']:.1%}, MCTS={bert_pred.probabilities['mcts']:.1%}")
    
    # Agreement check
    if rnn_pred.agent == bert_pred.agent:
        print(f"\n  ✅ Controllers AGREE: {rnn_pred.agent.upper()}")
    else:
        print(f"\n  ⚠️ Controllers DISAGREE: RNN={rnn_pred.agent.upper()}, BERT={bert_pred.agent.upper()}")

## 🎲 Step 5: Monte Carlo Tree Search (MCTS) Demo

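Explore MCTS, the strategic planning component that simulates multiple decision paths. The cell below hand-rolls the four phases (selection, expansion, simulation, backpropagation) on a toy action tree built from `MCTSNode` and `MCTSState`.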
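
The selection phase is usually driven by the UCB1 score, which adds an exploration bonus to a child's average value. The sketch below is the textbook formula in plain Python. Whether `MCTSNode.select_child` computes exactly this is an assumption, and the child statistics are made up.

```python
import math


def ucb1(value_sum: float, visits: int, parent_visits: int,
         exploration_weight: float = 1.414) -> float:
    """Average value plus an exploration bonus that shrinks as visits grow."""
    if visits == 0:
        return math.inf  # unvisited children are tried first
    exploitation = value_sum / visits
    exploration = exploration_weight * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration


# (value_sum, visits) per child; illustrative numbers only
children = {"action_A": (6.0, 10), "action_B": (2.4, 3), "action_C": (0.0, 0)}
parent_visits = 13

best = max(children, key=lambda a: ucb1(*children[a], parent_visits))
for action, (value_sum, visits) in children.items():
    print(action, ucb1(value_sum, visits, parent_visits))
print("selected:", best)
```

A larger `exploration_weight` favors rarely visited children; a smaller one favors children with a high average value.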

In [None]:
import random
from src.framework.mcts.core import MCTSNode, MCTSState, MCTSEngine
from src.framework.mcts.config import MCTSConfig, create_preset_config

print("🎲 Monte Carlo Tree Search Demo")
print("=" * 60)

# Create MCTS configuration
config = create_preset_config("BALANCED")
print(f"\n📊 MCTS Configuration (BALANCED preset):")
print(f"   Iterations: {config.iterations}")
print(f"   Exploration Weight: {config.exploration_weight}")
print(f"   Max Depth: {config.max_depth}")

# Define simple action generator for demo
def generate_actions(state: MCTSState) -> list[str]:
    """Generate possible actions from current state."""
    # Count actions taken so far; action names contain "_", so splitting on "_" would overcount
    depth = state.state_id.count("_action_")
    if depth >= 3:
        return []  # Terminal
    return [f"action_{chr(65+i)}" for i in range(3)]  # A, B, C

def evaluate_state(state: MCTSState) -> float:
    """Evaluate state value (simplified for demo)."""
    # Simulate some value based on path taken
    return random.uniform(0.3, 0.9)

def transition(state: MCTSState, action: str) -> MCTSState:
    """Transition to new state."""
    return MCTSState(f"{state.state_id}_{action}")

# Run MCTS
print(f"\n🎯 Running MCTS simulation...")
root = MCTSNode(state=MCTSState("root"))

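def untried_actions(node: MCTSNode) -> list[str]:
    """Actions from this node that do not have a child yet."""
    tried = {child.action for child in node.children}
    return [a for a in generate_actions(node.state) if a not in tried]

iterations = 100  # fixed for the demo; config.iterations holds the preset's value
for i in range(iterations):
    # Selection - descend using UCB1, but only through fully expanded nodes;
    # descending as soon as one child exists would grow a single chain, not a tree
    node = root
    while node.children and not node.terminal and not untried_actions(node):
        node = node.select_child(config.exploration_weight)

    # Expansion - add one untried child if not terminal
    if not node.terminal and node.visits > 0:
        actions = untried_actions(node)
        if actions:
            action = random.choice(actions)
            child_state = transition(node.state, action)
            node = node.add_child(action=action, child_state=child_state)
        else:
            node.terminal = True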
    
    # Simulation - evaluate
    value = evaluate_state(node.state)
    
    # Backpropagation
    while node:
        node.visits += 1
        node.value_sum += value
        node = node.parent

# Results
print(f"\n📊 MCTS Results after {iterations} iterations:")
print(f"   Root visits: {root.visits}")
print(f"   Root value: {root.value:.3f}")
print(f"   Children: {len(root.children)}")

if root.children:
    print(f"\n🏆 Best Actions (by visits):")
    sorted_children = sorted(root.children, key=lambda c: c.visits, reverse=True)
    for i, child in enumerate(sorted_children[:3], 1):
        print(f"   {i}. {child.action}: visits={child.visits}, value={child.value:.3f}")
    
    best = sorted_children[0]
    print(f"\n   ✅ Recommended: {best.action} (most robust selection)")

## 📊 Step 6: Visualize Routing Probabilities

Create a visual comparison of how different queries are routed.
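
Alongside the bar charts, a single number can summarize how often the two controllers pick the same agent. The sketch below computes that agreement rate from two lists of probability rows. The rows are made-up values for illustration, and the agent ordering matches the column order used in the plotting cell.

```python
AGENTS = ("hrm", "trm", "mcts")


def top_agent(probs: list[float]) -> str:
    """Name of the agent with the highest probability in one row."""
    return AGENTS[max(range(len(probs)), key=probs.__getitem__)]


def agreement_rate(rows_a: list[list[float]], rows_b: list[list[float]]) -> float:
    """Fraction of queries where both controllers select the same agent."""
    matches = sum(top_agent(a) == top_agent(b) for a, b in zip(rows_a, rows_b))
    return matches / len(rows_a)


# Illustrative probabilities for three queries (not real model output)
rnn_rows = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7]]
bert_rows = [[0.5, 0.4, 0.1], [0.4, 0.3, 0.3], [0.2, 0.2, 0.6]]

rate = agreement_rate(rnn_rows, bert_rows)
print(f"Controllers agree on {rate:.0%} of queries")
```

With the notebook's arrays, the same function could be called as `agreement_rate(rnn_probs.tolist(), bert_probs.tolist())` once the cell below has run.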

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Test queries for visualization
test_queries = [
    "Implement a binary search tree",
    "Compare PostgreSQL vs MongoDB",
    "Optimize the neural network training",
    "What is machine learning?",
    "Design a caching strategy for APIs",
]

# Collect predictions
rnn_probs = []
bert_probs = []

for query in test_queries:
    rnn_pred, _ = route_query(query, "rnn")
    bert_pred, _ = route_query(query, "bert")
    
    rnn_probs.append([rnn_pred.probabilities['hrm'], 
                      rnn_pred.probabilities['trm'], 
                      rnn_pred.probabilities['mcts']])
    bert_probs.append([bert_pred.probabilities['hrm'], 
                       bert_pred.probabilities['trm'], 
                       bert_pred.probabilities['mcts']])

rnn_probs = np.array(rnn_probs)
bert_probs = np.array(bert_probs)

# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
agents = ['HRM', 'TRM', 'MCTS']
colors = ['#2ecc71', '#3498db', '#e74c3c']
x = np.arange(len(test_queries))
width = 0.25

# RNN Controller
ax1 = axes[0]
for i, (agent, color) in enumerate(zip(agents, colors)):
    ax1.bar(x + i*width, rnn_probs[:, i], width, label=agent, color=color, alpha=0.8)
ax1.set_xlabel('Query')
ax1.set_ylabel('Probability')
ax1.set_title('RNN Meta-Controller Routing')  # no emoji: Matplotlib's default font lacks the glyph
ax1.set_xticks(x + width)
ax1.set_xticklabels([f'Q{i+1}' for i in range(len(test_queries))])
ax1.legend()
ax1.set_ylim(0, 1)
ax1.grid(axis='y', alpha=0.3)

# BERT Controller
ax2 = axes[1]
for i, (agent, color) in enumerate(zip(agents, colors)):
    ax2.bar(x + i*width, bert_probs[:, i], width, label=agent, color=color, alpha=0.8)
ax2.set_xlabel('Query')
ax2.set_ylabel('Probability')
ax2.set_title('BERT Meta-Controller Routing')  # no emoji: Matplotlib's default font lacks the glyph
ax2.set_xticks(x + width)
ax2.set_xticklabels([f'Q{i+1}' for i in range(len(test_queries))])
ax2.legend()
ax2.set_ylim(0, 1)
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.suptitle('Neural Meta-Controller Agent Routing Comparison', y=1.02, fontsize=14, fontweight='bold')
plt.show()

# Print query legend
print("\n📝 Query Legend:")
for i, query in enumerate(test_queries):
    print(f"   Q{i+1}: {query}")

## 🏛️ Step 7: Full Framework Demo with Gradio UI

Launch the complete Gradio interface for interactive exploration.

In [None]:
# Option A: Run the full Gradio app
# This launches the complete UI with all features

print("🏛️ Launching Full Framework UI...")
print("\nNote: This will run the complete Gradio interface.")
print("Click the public URL to access from any device.\n")

# Import and run the app
import importlib.util
spec = importlib.util.spec_from_file_location("app", f"{REPO_PATH}/app.py")
app_module = importlib.util.module_from_spec(spec)

# Executing the module initializes the framework; launch() then starts Gradio
# Uncomment the next two lines to launch the full UI
# spec.loader.exec_module(app_module)
# app_module.demo.launch(share=True, debug=True)

print("ℹ️ To launch the full UI, uncomment the last two lines above and run this cell.")
print("   Or run: !python app.py")

In [None]:
# Option B: Quick Mini Gradio Demo
# A lightweight version for quick testing

import gradio as gr

def process_query_mini(query: str, controller_type: str):
    """Process a query with the selected controller."""
    if not query.strip():
        return "Please enter a query.", "", ""
    
    prediction, features = route_query(query, controller_type.lower())
    
    # Format routing decision
    routing = f"""🧠 **Meta-Controller Decision**

**Selected Agent:** `{prediction.agent.upper()}`
**Confidence:** {prediction.confidence:.1%}

**Routing Probabilities:**
- HRM: {prediction.probabilities['hrm']:.1%} {'█' * int(prediction.probabilities['hrm'] * 20)}
- TRM: {prediction.probabilities['trm']:.1%} {'█' * int(prediction.probabilities['trm'] * 20)}
- MCTS: {prediction.probabilities['mcts']:.1%} {'█' * int(prediction.probabilities['mcts'] * 20)}
"""
    
    # Format features
    features_str = f"""📊 **Extracted Features**

- Query Length: {features.query_length}
- Technical Query: {'Yes' if features.is_technical_query else 'No'}
- Has RAG Context: {'Yes' if features.has_rag_context else 'No'}
- HRM Confidence: {features.hrm_confidence:.3f}
- TRM Confidence: {features.trm_confidence:.3f}
- MCTS Value: {features.mcts_value:.3f}
"""
    
    # Simulated agent response
    agent_responses = {
        "hrm": f"[HRM] Breaking down hierarchically: {query[:80]}...",
        "trm": f"[TRM] Applying iterative refinement: {query[:80]}...",
        "mcts": f"[MCTS] Strategic exploration via tree search: {query[:80]}...",
    }
    response = agent_responses.get(prediction.agent, "Unknown agent")
    
    return response, routing, features_str

# Create mini Gradio interface
with gr.Blocks(title="LangGraph Multi-Agent MCTS - Mini Demo") as mini_demo:
    gr.Markdown("""
    # 🎯 LangGraph Multi-Agent MCTS - Quick Demo
    
    Test the neural meta-controllers with your own queries!
    """)
    
    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your question...",
                lines=3
            )
        with gr.Column(scale=1):
            controller_select = gr.Radio(
                choices=["RNN", "BERT"],
                value="RNN",
                label="Controller Type"
            )
    
    submit_btn = gr.Button("🚀 Process Query", variant="primary")
    
    with gr.Row():
        response_output = gr.Textbox(label="Agent Response", lines=3)
    
    with gr.Row():
        routing_output = gr.Markdown(label="Routing Decision")
        features_output = gr.Markdown(label="Features")
    
    submit_btn.click(
        fn=process_query_mini,
        inputs=[query_input, controller_select],
        outputs=[response_output, routing_output, features_output]
    )

# Launch with share=True for public URL
mini_demo.launch(share=True, debug=True)

## 🎭 Step 8: Training Your Own Models (Optional)

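The cell below trains the RNN meta-controller for a few epochs on synthetic data as a quick demonstration of the pipeline. To train on your own data, substitute your own features and labels for the generated ones.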
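
The training cell splits the data 70% / 15% / 15% into training, validation, and test sets. As a rough illustration of what such a split does, here is a standard-library sketch that shuffles indices and slices them by ratio. It is not the repository's `split_dataset` implementation, and `split_indices` is a hypothetical helper.

```python
import random


def split_indices(n: int, train_ratio: float = 0.7, val_ratio: float = 0.15,
                  seed: int = 42) -> tuple[list[int], list[int], list[int]]:
    """Shuffle indices 0..n-1 and slice them into train, validation, and test."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = round(n * train_ratio)
    n_val = round(n * val_ratio)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains
    return train, val, test


train_idx, val_idx, test_idx = split_indices(200)
print(len(train_idx), len(val_idx), len(test_idx))
```

The test set takes the remainder rather than its own ratio, so the three parts always cover every sample exactly once.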

In [None]:
# Generate synthetic training data
from src.training.data_generator import MetaControllerDataGenerator
from src.training.train_rnn import RNNTrainer

print("🎭 Training Pipeline Demo")
print("=" * 60)

# Generate balanced dataset
print("\n📦 Generating synthetic training data...")
generator = MetaControllerDataGenerator(seed=42)
features_list, labels_list = generator.generate_balanced_dataset(samples_per_class=50)

print(f"   Total samples: {len(features_list)}")
print(f"   Classes: {set(labels_list)}")

# Convert to tensors
X, y = generator.to_tensor_dataset(features_list, labels_list)
print(f"   Feature tensor shape: {X.shape}")
print(f"   Label tensor shape: {y.shape}")

# Split dataset
splits = generator.split_dataset(X, y, train_ratio=0.7, val_ratio=0.15)
print(f"\n📊 Dataset splits:")
print(f"   Training: {splits['X_train'].shape[0]} samples")
print(f"   Validation: {splits['X_val'].shape[0]} samples")
print(f"   Test: {splits['X_test'].shape[0]} samples")

# Quick training demo (3 epochs)
print("\n🏋️ Training RNN model (3 epochs)...")
trainer = RNNTrainer(
    hidden_dim=32,
    num_layers=1,
    dropout=0.1,
    lr=1e-3,
    batch_size=16,
    epochs=3,
    early_stopping_patience=2,
    seed=42,
)

history = trainer.train(
    train_data=(splits["X_train"], splits["y_train"]),
    val_data=(splits["X_val"], splits["y_val"]),
)

print(f"\n✅ Training complete!")
print(f"   Best validation accuracy: {history['best_val_accuracy']:.2%}")
print(f"   Best validation loss: {history['best_val_loss']:.4f}")

# Evaluate on test set
print("\n🧪 Evaluating on test set...")
test_loader = trainer.create_dataloader(splits["X_test"], splits["y_test"], shuffle=False)
results = trainer.evaluate(test_loader)

print(f"   Test accuracy: {results['accuracy']:.2%}")
print(f"   Test loss: {results['loss']:.4f}")

## 🏁 Step 9: Chess MCTS Demo (Bonus)

The framework includes a chess implementation that picks moves with MCTS, the same family of search that AlphaZero combines with a neural network.
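
Once the search budget is spent, a common rule is to play the most-visited root move (often called the "robust child"), breaking ties by average value. The sketch below applies that rule to a dictionary of per-move statistics. The moves and numbers are made up, and whether `ChessMCTS.get_best_move` uses this exact rule is an assumption.

```python
def most_visited(stats: dict[str, dict[str, float]]) -> str:
    """Pick the move with the most visits; break ties by average value."""
    return max(stats, key=lambda move: (stats[move]["visits"], stats[move]["value"]))


# Illustrative statistics only, not real search output
move_stats = {
    "e2e4": {"visits": 21, "value": 0.54},
    "d2d4": {"visits": 17, "value": 0.58},
    "g1f3": {"visits": 12, "value": 0.51},
}

choice = most_visited(move_stats)
print("Selected move:", choice)
```

Visit counts tend to be a steadier signal than raw value at low iteration counts, since a move with a high average from only a few visits may simply have been lucky.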

In [None]:
try:
    import chess
    from src.games.chess.game import ChessGame
    from src.games.chess.mcts import ChessMCTS
    
    print("♚ Chess MCTS Demo")
    print("=" * 60)
    
    # Create a chess game
    game = ChessGame()
    print(f"\n🎯 Initial position:")
    print(game.board)
    
    # Create MCTS player
    mcts = ChessMCTS(game, iterations=50, exploration_weight=1.414)
    
    # Get best move
    print(f"\n🤔 MCTS thinking (50 iterations)...")
    best_move = mcts.get_best_move()
    
    print(f"\n✅ Best move: {best_move}")
    
    # Show move statistics
    print(f"\n📊 Move statistics:")
    for move, stats in mcts.get_move_stats()[:5]:
        print(f"   {move}: visits={stats['visits']}, value={stats['value']:.3f}")
        
except ImportError as e:
    print(f"⚠️ Chess demo not available: {e}")
    print("   Install with: pip install python-chess")

## 📚 Resources & Next Steps

### 📖 Documentation
- [Repository README](https://github.com/ianshank/langgraph_multi_agent_mcts)
- [Architecture Documentation](https://github.com/ianshank/langgraph_multi_agent_mcts/blob/main/docs/langgraph_mcts_architecture.md)

### 🛠️ Key Files
- `app.py` - Main Gradio application
- `src/agents/meta_controller/` - Neural meta-controllers
- `src/framework/mcts/` - MCTS implementation
- `src/training/` - Training pipelines

### 🏃 Run Locally
```bash
git clone https://github.com/ianshank/langgraph_multi_agent_mcts.git
cd langgraph_multi_agent_mcts
pip install -r requirements.txt
python app.py
```

### 👋 Feedback
Open an issue on GitHub or contribute to the project!

In [None]:
print("🎉 Notebook complete!")
print("\nYou've explored:")
print("  ✅ Neural Meta-Controllers (RNN and BERT with LoRA)")
print("  ✅ Monte Carlo Tree Search (MCTS)")
print("  ✅ Agent Routing Visualization")
print("  ✅ Training Pipeline")
print("  ✅ Interactive Gradio Demo")
print("\n🚀 Happy experimenting!")