<a href="https://colab.research.google.com/github/matsunagalab/mcp-md/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧬 MCP-MD: AI-Powered Molecular Dynamics Agent

**Interactive AI assistant for setting up MD simulations**

This notebook provides a chat interface to interact with the MCP-MD AI agent. Simply describe what you want to simulate, and the agent will:

1. **Analyze** your request and ask clarifying questions
2. **Fetch** structures from PDB/AlphaFold
3. **Prepare** protein + parameterize ligands (GAFF2/AM1-BCC)
4. **Solvate** with water box + ions
5. **Build** Amber topology (tleap)
6. **Simulate** with OpenMM (NPT ensemble)
7. **Visualize** results with interactive 3D viewer
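Step 6 runs for a number of integrator steps derived from the requested simulation time. The sketch below shows that conversion. It assumes the 2 fs default timestep and a target of roughly 100 saved frames, matching the defaults in the chat cell. `md_step_counts` is a hypothetical helper, not part of mcp-md:

```python
def md_step_counts(sim_time_ns: float, timestep_fs: float = 2.0, target_frames: int = 100) -> tuple:
    """Return (total integrator steps, steps between saved trajectory frames)."""
    # 1 ns = 1,000,000 fs, so steps = time_ns * 1e6 / timestep_fs
    total_steps = round(sim_time_ns * 1e6 / timestep_fs)
    # Aim for ~target_frames frames, but save no more often than every 100 steps
    report_interval = max(100, total_steps // target_frames)
    return total_steps, report_interval


steps, interval = md_step_counts(1.0)  # 1 ns with a 2 fs timestep
print(steps, interval)  # 500000 5000
```

Doubling the timestep to 4 fs (for example with hydrogen mass repartitioning) halves the step count for the same simulated time.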

---

## Quick Start

1. **Run the setup cells** (Cells 1 and 2) to install dependencies (~5-10 min); the runtime restarts between them
2. **Enter your API key** (Cell 3)
3. **Start chatting!** Describe your simulation (Cell 4)

**Example prompts:**
- "Setup MD for PDB 1AKE in water, 1 ns at 300K"
- "I want to simulate lysozyme (PDB 1LYZ) with explicit solvent"
- "Run a short simulation of insulin (PDB 4INS), chain A only"

---
## Setup 1/2: Install condacolab

**⚠️ The runtime will restart after this cell. This is expected!**

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Installing condacolab...")
    print("⚠️ The runtime will restart. Run the next cell after restart.")
    !pip install -q condacolab
    import condacolab
    condacolab.install()
else:
    print("Not running in Colab - skipping condacolab setup")
    print("Make sure you are in a conda environment with AmberTools installed.")

---
## Setup 2/2: Install Dependencies

**Run this cell AFTER the runtime restarts.**

Installs AmberTools, OpenMM, RDKit, PDBFixer, and the project's Python dependencies (~5-10 min).
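Before moving on, you can confirm that the install worked with a quick check like the one below. This is a sketch. It assumes the packages import as `openmm`, `rdkit`, and `pdbfixer`, and that AmberTools puts `tleap` and `antechamber` on `PATH`. `check_md_environment` is a hypothetical helper, not part of mcp-md:

```python
import importlib.util
import shutil


def check_md_environment(modules=("openmm", "rdkit", "pdbfixer"), executables=("tleap", "antechamber")):
    """Report which Python modules and command-line tools are available."""
    report = {}
    for name in modules:
        # find_spec returns None when a top-level module cannot be found
        report[name] = importlib.util.find_spec(name) is not None
    for exe in executables:
        # shutil.which searches PATH for the executable
        report[exe] = shutil.which(exe) is not None
    return report


for name, ok in check_md_environment().items():
    print(f"{'✓' if ok else '✗'} {name}")
```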

In [None]:
import sys
import time

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    import condacolab
    condacolab.check()
    
    start_time = time.time()
    
    # Install conda packages (AmberTools + heavy scientific packages)
    print("="*60)
    print("Installing AmberTools + scientific packages via conda...")
    print("This takes ~5-10 minutes. Please wait.")
    print("="*60)
    !conda install -y -c conda-forge ambertools=23 openmm rdkit pdbfixer 2>&1 | tail -20
    print(f"\n✓ Conda packages installed ({time.time() - start_time:.0f}s)")
    
    # Clone repository and install
    print("\nCloning mcp-md repository...")
    !git clone -q https://github.com/matsunagalab/mcp-md.git /content/mcp-md
    %cd /content/mcp-md
    
    print("Installing Python dependencies...")
    !pip install -q -e .
    !pip install -q py3Dmol mdtraj
    
    # Set AMBERHOME
    import os
    import json
    import subprocess
    conda_info = json.loads(subprocess.run(['conda', 'info', '--json'], 
                                           capture_output=True, text=True).stdout)
    os.environ["AMBERHOME"] = conda_info.get('default_prefix', '')
    
    sys.path.insert(0, '/content/mcp-md')
    
    total_time = time.time() - start_time
    print(f"\n" + "="*60)
    print(f"✓ Setup complete! ({total_time/60:.1f} minutes)")
    print("="*60)
    print("\n🎉 You can now proceed to the next cell!")

else:
    # Local development
    sys.path.insert(0, '.')
    print("Local environment - dependencies should be pre-installed.")

---
## API Key Configuration

Enter your Anthropic API key to enable the AI agent.

Get your API key from: https://console.anthropic.com/
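The cell below looks for the key in a fixed order:

1. An existing `ANTHROPIC_API_KEY` environment variable
2. Colab secrets, or a local `.env` file outside Colab
3. An interactive prompt

That precedence can be sketched as a small pure function. `resolve_api_key` is hypothetical, and its callables stand in for `google.colab.userdata.get` and `getpass`:

```python
import os
from typing import Callable, Mapping, Optional


def resolve_api_key(
    env: Mapping[str, str],
    secret_lookup: Optional[Callable[[str], str]] = None,
    prompt: Optional[Callable[[], str]] = None,
    name: str = "ANTHROPIC_API_KEY",
) -> Optional[str]:
    """Return the first key found: environment, then secrets store, then prompt."""
    if env.get(name):
        return env[name]
    if secret_lookup is not None:
        try:
            value = secret_lookup(name)
            if value:
                return value
        except Exception:
            pass  # secret missing or access denied: fall through to the prompt
    return prompt() if prompt is not None else None


key = resolve_api_key(os.environ)  # None if unset and no fallbacks are given
```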

In [None]:
import os
import sys
from getpass import getpass

IN_COLAB = 'google.colab' in sys.modules

# Check if API key is already set
if os.environ.get("ANTHROPIC_API_KEY"):
    print("✓ ANTHROPIC_API_KEY is already set")
else:
    if IN_COLAB:
        from google.colab import userdata
        try:
            # Try to get the key from Colab secrets
            api_key = userdata.get('ANTHROPIC_API_KEY')
            os.environ["ANTHROPIC_API_KEY"] = api_key
            print("✓ ANTHROPIC_API_KEY loaded from Colab secrets")
        except Exception:
            # Secret not found or notebook access not granted: prompt for input
            print("Enter your Anthropic API key:")
            api_key = getpass()
            os.environ["ANTHROPIC_API_KEY"] = api_key
            print("✓ ANTHROPIC_API_KEY set")
    else:
        # Local: try to load from .env
        from dotenv import load_dotenv
        load_dotenv()
        if os.environ.get("ANTHROPIC_API_KEY"):
            print("✓ ANTHROPIC_API_KEY loaded from .env")
        else:
            print("Enter your Anthropic API key:")
            api_key = getpass()
            os.environ["ANTHROPIC_API_KEY"] = api_key
            print("✓ ANTHROPIC_API_KEY set")

---
## 🤖 MCP-MD AI Agent Chat Interface

Interact with the AI agent to set up your MD simulation!

**How to use:**
1. Type your request in the text box (e.g., "Setup MD for PDB 1AKE")
2. Click **Send** or press Enter
3. The agent will analyze your request and may ask clarifying questions
4. Answer the questions until the agent has enough information
5. Click **Run Simulation** to execute the workflow

**Commands:**
- Type `clear` to reset the conversation
- Type `run` to execute the simulation workflow
- Type `help` to list these commands and some example prompts
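Each command must make up the whole message; case and surrounding whitespace are ignored. Any other text goes to the agent. The sketch below captures that rule, with hypothetical handlers standing in for the notebook's button callbacks:

```python
from typing import Callable, Dict, Optional


def dispatch_command(text: str, handlers: Dict[str, Callable[[], str]]) -> Optional[str]:
    """Run a handler if `text` is a bare command; return None for normal chat messages."""
    command = text.strip().lower()
    handler = handlers.get(command)
    return handler() if handler is not None else None


handlers = {
    "clear": lambda: "conversation reset",
    "run": lambda: "workflow started",
    "help": lambda: "commands: clear, run, help",
}
print(dispatch_command("  HELP ", handlers))                # commands: clear, run, help
print(dispatch_command("Setup MD for PDB 1AKE", handlers))  # None
```

Matching the whole message, rather than searching for a keyword, keeps a request like "run a 1 ns simulation" from accidentally triggering the workflow.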

In [None]:
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import asyncio
import json
from pathlib import Path
from datetime import datetime

# Import MCP-MD components
from langchain_core.messages import HumanMessage, AIMessage
from mcp_md.clarification_agent import create_clarification_graph
from mcp_md.state_scope import SimulationBrief

# ============================================================================
# Chat State
# ============================================================================
class ChatState:
    def __init__(self):
        self.messages = []  # Conversation history
        self.agent_state = None  # LangGraph state
        self.simulation_brief = None  # Generated brief
        self.workflow_outputs = {}  # File paths from workflow
        self.graph = None  # Clarification graph
        
    def reset(self):
        self.messages = []
        self.agent_state = None
        self.simulation_brief = None
        self.workflow_outputs = {}

chat_state = ChatState()

# ============================================================================
# UI Components
# ============================================================================

# Chat history display
chat_output = widgets.Output(
    layout=widgets.Layout(
        width='100%',
        height='400px',
        border='1px solid #ccc',
        overflow_y='auto',
        padding='10px'
    )
)

# User input
user_input = widgets.Text(
    placeholder='Type your message here... (e.g., "Setup MD for PDB 1AKE")',
    layout=widgets.Layout(width='80%')
)

# Send button
send_button = widgets.Button(
    description='Send',
    button_style='primary',
    icon='paper-plane',
    layout=widgets.Layout(width='18%')
)

# Run simulation button
run_button = widgets.Button(
    description='Run Simulation',
    button_style='success',
    icon='play',
    disabled=True,
    layout=widgets.Layout(width='49%')
)

# Clear button
clear_button = widgets.Button(
    description='Clear Chat',
    button_style='warning',
    icon='trash',
    layout=widgets.Layout(width='49%')
)

# Status indicator
status_label = widgets.HTML(
    value='<span style="color: gray;">Ready to chat</span>'
)

# Progress bar for long operations
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    description='Progress:',
    bar_style='info',
    layout=widgets.Layout(width='100%', visibility='hidden')
)

# Visualization output
viz_output = widgets.Output(
    layout=widgets.Layout(width='100%', min_height='400px')
)

# ============================================================================
# Chat Display Functions
# ============================================================================

def add_message(role, content, msg_type='text'):
    """Add a message to the chat display."""
    timestamp = datetime.now().strftime('%H:%M')
    
    if role == 'user':
        color = '#007bff'
        bg = '#e3f2fd'
        align = 'right'
        icon = 'üë§'
    elif role == 'agent':
        color = '#28a745'
        bg = '#e8f5e9'
        align = 'left'
        icon = 'ü§ñ'
    else:  # system
        color = '#6c757d'
        bg = '#f8f9fa'
        align = 'center'
        icon = '‚ÑπÔ∏è'
    
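    # Escape HTML, then render **bold**, `code`, and line breaks
    import html as html_lib
    import re
    formatted_content = html_lib.escape(content)
    formatted_content = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', formatted_content)
    formatted_content = re.sub(r'`([^`]+)`', r'<code>\1</code>', formatted_content)
    formatted_content = formatted_content.replace('\n', '<br>')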
    
    html = f'''
    <div style="text-align: {align}; margin: 8px 0;">
        <div style="display: inline-block; max-width: 85%; padding: 10px 15px; 
                    border-radius: 15px; background: {bg}; text-align: left;">
            <span style="font-size: 0.8em; color: {color};">{icon} {timestamp}</span><br>
            <span style="color: #333;">{formatted_content}</span>
        </div>
    </div>
    '''
    
    with chat_output:
        display(HTML(html))
    
    # Store in history
    chat_state.messages.append({'role': role, 'content': content})

def display_brief(brief):
    """Display the simulation brief in a formatted way."""
    if hasattr(brief, 'model_dump'):
        brief_dict = brief.model_dump()
    else:
        brief_dict = brief
    
    html = '''
    <div style="background: #fff3e0; padding: 15px; border-radius: 10px; margin: 10px 0;">
        <h4 style="color: #e65100; margin-top: 0;">📋 Simulation Brief</h4>
        <table style="width: 100%; font-size: 0.9em;">
    '''
    
    important_fields = [
        ('pdb_id', 'PDB ID'),
        ('select_chains', 'Chains'),
        ('temperature', 'Temperature (K)'),
        ('simulation_time_ns', 'Simulation Time (ns)'),
        ('force_field', 'Force Field'),
        ('water_model', 'Water Model'),
        ('box_padding', 'Box Padding (Å)'),
        ('is_membrane', 'Membrane System'),
    ]
    
    for field, label in important_fields:
        value = brief_dict.get(field, 'N/A')
        if value is not None:
            html += f'<tr><td style="padding: 3px;"><b>{label}:</b></td><td>{value}</td></tr>'
    
    html += '</table></div>'
    
    with chat_output:
        display(HTML(html))

# ============================================================================
# Agent Interaction
# ============================================================================

async def process_message(message):
    """Process user message through the clarification agent."""
    global chat_state
    
    # Initialize graph if needed
    if chat_state.graph is None:
        chat_state.graph = create_clarification_graph()
    
    # Build input state
    if chat_state.agent_state is None:
        input_state = {"messages": [HumanMessage(content=message)]}
    else:
        # Continue conversation
        input_state = {
            "messages": chat_state.agent_state.get("messages", []) + [HumanMessage(content=message)],
            "structure_info": chat_state.agent_state.get("structure_info"),
        }
    
    # Run the agent
    result = await chat_state.graph.ainvoke(input_state)
    chat_state.agent_state = result
    
    # Extract response
    messages = result.get("messages", [])
    if messages:
        last_msg = messages[-1]
        if hasattr(last_msg, 'content') and last_msg.content:
            return last_msg.content, result.get("simulation_brief")
    
    return None, result.get("simulation_brief")

# ============================================================================
# Event Handlers
# ============================================================================

def on_send_click(b):
    """Handle send button click."""
    message = user_input.value.strip()
    if not message:
        return
    
    # Handle special commands
    if message.lower() == 'clear':
        user_input.value = ''
        on_clear_click(None)
        return
    elif message.lower() == 'run':
        if not run_button.disabled:
            on_run_click(None)
        else:
            add_message('system', 'No simulation brief ready yet. Keep chatting with the agent.')
        user_input.value = ''
        return
    elif message.lower() == 'help':
        help_text = '''**Available Commands:**
• `clear` - Reset conversation
• `run` - Execute simulation workflow
• `help` - Show this help

**Example Prompts:**
• "Setup MD for PDB 1AKE in water"
• "Simulate lysozyme (1LYZ) for 1 ns at 310K"
• "I want to study insulin, chain A only"'''
        add_message('system', help_text)
        user_input.value = ''
        return
    
    # Clear input
    user_input.value = ''
    
    # Show user message
    add_message('user', message)
    
    # Update status
    status_label.value = '<span style="color: blue;">🔄 Processing...</span>'
    send_button.disabled = True
    
    # Process asynchronously
    async def process_and_display():
        try:
            response, brief = await process_message(message)
            
            if response:
                add_message('agent', response)
            
            if brief:
                chat_state.simulation_brief = brief
                display_brief(brief)
                run_button.disabled = False
                status_label.value = '<span style="color: green;">✓ Ready to run simulation</span>'
                add_message('system', '✅ Simulation brief generated! Click **Run Simulation** to start, or continue chatting to modify.')
            else:
                status_label.value = '<span style="color: gray;">Ready to chat</span>'
            
        except Exception as e:
            add_message('system', f'❌ Error: {str(e)}')
            status_label.value = '<span style="color: red;">Error occurred</span>'
        finally:
            send_button.disabled = False
    
    # Run async in notebook
    asyncio.ensure_future(process_and_display())

def on_clear_click(b):
    """Handle clear button click."""
    chat_state.reset()
    chat_output.clear_output()
    viz_output.clear_output()
    run_button.disabled = True
    status_label.value = '<span style="color: gray;">Ready to chat</span>'
    add_message('system', 'Conversation cleared. Start a new simulation setup!')

def on_run_click(b):
    """Handle run simulation button click."""
    if chat_state.simulation_brief is None:
        add_message('system', '❌ No simulation brief available. Please complete the chat first.')
        return
    
    add_message('system', '🚀 Starting MD workflow...')
    status_label.value = '<span style="color: blue;">🔄 Running workflow...</span>'
    run_button.disabled = True
    send_button.disabled = True
    progress_bar.layout.visibility = 'visible'
    
    async def run_workflow():
        try:
            await execute_md_workflow(chat_state.simulation_brief)
            status_label.value = '<span style="color: green;">✓ Workflow complete!</span>'
        except Exception as e:
            add_message('system', f'❌ Workflow error: {str(e)}')
            status_label.value = '<span style="color: red;">Workflow failed</span>'
        finally:
            progress_bar.layout.visibility = 'hidden'
            send_button.disabled = False
            run_button.disabled = False
    
    asyncio.ensure_future(run_workflow())

def on_input_submit(sender):
    """Handle Enter key in input field."""
    on_send_click(None)

# ============================================================================
# MD Workflow Execution
# ============================================================================

async def execute_md_workflow(brief):
    """Execute the complete MD workflow with visualization."""
    import importlib
    import py3Dmol
    
    # Get brief as dict
    if hasattr(brief, 'model_dump'):
        brief_dict = brief.model_dump()
    else:
        brief_dict = brief
    
    pdb_id = brief_dict.get('pdb_id')
    select_chains = brief_dict.get('select_chains')
    
    if not pdb_id:
        add_message('system', '❌ No PDB ID specified in brief')
        return
    
    # Create output directory
    import sys
    IN_COLAB = 'google.colab' in sys.modules
    if IN_COLAB:
        output_dir = Path("/content/mcp-md/output") / f"{pdb_id}_{datetime.now().strftime('%H%M%S')}"
    else:
        output_dir = Path("./output") / f"{pdb_id}_{datetime.now().strftime('%H%M%S')}"
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # ===== Step 1: Fetch Structure =====
    add_message('system', f'üì• Step 1/5: Fetching {pdb_id} from PDB...')
    progress_bar.value = 10
    
    import servers.structure_server as structure_module
    importlib.reload(structure_module)
    
    fetch_result = await structure_module.fetch_molecules(pdb_id=pdb_id, source="pdb", prefer_format="pdb")
    
    if not fetch_result["success"]:
        raise RuntimeError(f"Fetch failed: {fetch_result['errors']}")
    
    structure_file = fetch_result["file_path"]
    add_message('system', f'✓ Fetched: {Path(structure_file).name} ({fetch_result["num_atoms"]} atoms)')
    
    # Visualize original structure
    with viz_output:
        clear_output()
        with open(structure_file, 'r') as f:
            pdb_content = f.read()
        view = py3Dmol.view(width=800, height=400)
        view.addModel(pdb_content, 'pdb')
        view.setStyle({'cartoon': {'color': 'spectrum'}})
        view.zoomTo()
        print(f"Original structure: {pdb_id}")
        display(view.show())
    
    progress_bar.value = 20
    
    # ===== Step 2: Prepare Complex =====
    add_message('system', f'üîß Step 2/5: Preparing complex (chains: {select_chains or "all"})...')
    
    complex_result = structure_module.prepare_complex(
        structure_file=structure_file,
        select_chains=select_chains,
        ph=brief_dict.get('ph', 7.4),
        process_proteins=True,
        process_ligands=True,
        run_parameterization=True
    )
    
    if not complex_result["success"]:
        raise RuntimeError(f"Prepare failed: {complex_result['errors']}")
    
    merged_pdb = complex_result["merged_pdb"]
    add_message('system', f'✓ Prepared: {len(complex_result["proteins"])} protein(s), {len(complex_result["ligands"])} ligand(s)')
    
    progress_bar.value = 40
    
    # ===== Step 3: Solvate =====
    add_message('system', f'üíß Step 3/5: Solvating ({brief_dict.get("box_padding", 12)} √Ö box)...')
    
    import servers.solvation_server as solvation_module
    importlib.reload(solvation_module)
    
    solvate_result = solvation_module.solvate_structure(
        pdb_file=str(Path(merged_pdb).resolve()),
        output_dir=str(Path(complex_result["output_dir"]).resolve()),
        output_name="solvated",
        dist=brief_dict.get('box_padding', 12.0),
        cubic=brief_dict.get('cubic_box', True),
        salt=True,
        saltcon=brief_dict.get('salt_concentration', 0.15)
    )
    
    if not solvate_result["success"]:
        raise RuntimeError(f"Solvate failed: {solvate_result['errors']}")
    
    solvated_pdb = solvate_result["output_file"]
    box = solvate_result.get("box_dimensions", {})
    add_message('system', f'✓ Solvated: {solvate_result["statistics"].get("total_atoms", "?")} atoms')
    
    progress_bar.value = 55
    
    # ===== Step 4: Build Amber System =====
    add_message('system', f'üèóÔ∏è Step 4/5: Building Amber topology...')
    
    import servers.amber_server as amber_module
    importlib.reload(amber_module)
    
    # Collect ligand params
    ligand_params = []
    for lig in complex_result.get("ligands", []):
        if lig.get("success") and lig.get("mol2_file"):
            ligand_params.append({
                "mol2": lig["mol2_file"],
                "frcmod": lig["frcmod_file"],
                "residue_name": lig["ligand_id"][:3].upper()
            })
    
    amber_result = amber_module.build_amber_system(
        pdb_file=solvate_result["output_file"],
        ligand_params=ligand_params if ligand_params else None,
        box_dimensions=solvate_result.get("box_dimensions"),
        water_model=brief_dict.get('water_model', 'tip3p'),
        output_name="system"
    )
    
    if not amber_result['success']:
        raise RuntimeError(f"Amber build failed: {amber_result['errors']}")
    
    parm7_file = amber_result['parm7']
    rst7_file = amber_result['rst7']
    add_message('system', f'✓ Built: {Path(parm7_file).name}')
    
    progress_bar.value = 70
    
    # ===== Step 5: Run MD Simulation =====
    sim_time = brief_dict.get('simulation_time_ns') or 0.1  # default: 0.1 ns (100 ps) for a quick demo
    add_message('system', f'🏃 Step 5/5: Running {sim_time} ns simulation...')
    
    import openmm as mm
    from openmm import app, unit
    from openmm.app import AmberPrmtopFile, AmberInpcrdFile, Simulation, DCDReporter, PDBFile
    
    # Select platform
    platform = None
    for name in ['CUDA', 'OpenCL', 'CPU']:
        try:
            platform = mm.Platform.getPlatformByName(name)
            break
        except:
            continue
    
    # Load and create system
    prmtop = AmberPrmtopFile(parm7_file)
    inpcrd = AmberInpcrdFile(rst7_file)
    
    # `or` also covers fields that are present in the brief but set to None
    temperature = (brief_dict.get('temperature') or 300.0) * unit.kelvin
    pressure = (brief_dict.get('pressure_bar') or 1.0) * unit.bar  # field is given in bar
    timestep = (brief_dict.get('timestep') or 2.0) * unit.femtoseconds
    
    system = prmtop.createSystem(
        nonbondedMethod=app.PME,
        nonbondedCutoff=10 * unit.angstrom,
        constraints=app.HBonds,
        rigidWater=True
    )
    system.addForce(mm.MonteCarloBarostat(pressure, temperature, 25))
    
    integrator = mm.LangevinMiddleIntegrator(temperature, 1/unit.picosecond, timestep)
    simulation = Simulation(prmtop.topology, system, integrator, platform)
    simulation.context.setPositions(inpcrd.positions)
    if inpcrd.boxVectors:
        simulation.context.setPeriodicBoxVectors(*inpcrd.boxVectors)
    
    # Minimize
    simulation.minimizeEnergy(maxIterations=500)
    simulation.context.setVelocitiesToTemperature(temperature)
    
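    # Set up trajectory output
    dcd_file = Path(complex_result["output_dir"]) / "trajectory.dcd"
    timestep_fs = timestep.value_in_unit(unit.femtoseconds)
    total_steps = int(round(sim_time * 1e6 / timestep_fs))  # 1 ns = 1e6 fs
    report_interval = max(100, total_steps // 100)  # aim for ~100 frames
    
    simulation.reporters.append(DCDReporter(str(dcd_file), report_interval))
    
    # Run in 10 chunks so the progress bar advances; the last chunk takes any remainder
    n_chunks = 10
    steps_per_chunk = total_steps // n_chunks
    for i in range(n_chunks):
        n_steps = steps_per_chunk if i < n_chunks - 1 else total_steps - steps_per_chunk * (n_chunks - 1)
        simulation.step(n_steps)
        progress_bar.value = 70 + (i + 1) * 2.5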
    
    # Save final state
    final_pdb = Path(complex_result["output_dir"]) / "final_state.pdb"
    state = simulation.context.getState(getPositions=True)
    with open(final_pdb, 'w') as f:
        PDBFile.writeFile(simulation.topology, state.getPositions(), f)
    
    add_message('system', f'✓ Simulation complete! Trajectory: {dcd_file.name}')
    progress_bar.value = 95
    
    # Store outputs
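    chat_state.workflow_outputs = {
        'structure_file': structure_file,
        'merged_pdb': merged_pdb,
        'solvated_pdb': solvated_pdb,
        'parm7': parm7_file,
        'rst7': rst7_file,
        'trajectory': str(dcd_file),
        'final_pdb': str(final_pdb),
        'output_dir': complex_result["output_dir"],
        # Time between saved DCD frames, used to build a time axis during analysis
        'frame_interval_ps': (report_interval * timestep).value_in_unit(unit.picosecond),
    }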
    
    # ===== Visualize Trajectory =====
    add_message('system', 'üé¨ Generating trajectory visualization...')
    
    import mdtraj as md
    import numpy as np
    import tempfile
    
    traj = md.load(str(dcd_file), top=parm7_file)
    
    # Select protein only
    protein_indices = traj.topology.select('protein')
    traj_protein = traj.atom_slice(protein_indices)
    
    # Sample frames
    max_frames = 20
    if traj_protein.n_frames > max_frames:
        frame_indices = np.linspace(0, traj_protein.n_frames - 1, max_frames, dtype=int)
        traj_viz = traj_protein[frame_indices]
    else:
        traj_viz = traj_protein
    
    # Write the sampled frames as a single multi-model PDB; for multi-frame
    # trajectories, mdtraj's save_pdb emits one MODEL/ENDMDL block per frame
    with tempfile.NamedTemporaryFile(suffix='.pdb', delete=False) as tmp:
        tmp_path = tmp.name
    traj_viz.save_pdb(tmp_path, force_overwrite=True)
    
    with open(tmp_path, 'r') as f:
        traj_pdb = f.read()
    Path(tmp_path).unlink()
    
    # Display animated trajectory
    with viz_output:
        clear_output()
        view = py3Dmol.view(width=800, height=500)
        view.addModelsAsFrames(traj_pdb, 'pdb')
        view.setStyle({'cartoon': {'color': 'spectrum'}})
        view.zoomTo()
        view.animate({'loop': 'forward', 'reps': 0, 'interval': 100})
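        # mdtraj sets traj.time to frame indices for DCD files, so report the requested length instead
        print(f"Trajectory animation: {traj_viz.n_frames} frames from {sim_time * 1000:.0f} ps of simulation")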
        display(view.show())
    
    progress_bar.value = 100
    
    # Summary
    summary = f'''**Workflow Complete!** 🎉

**Output Directory:** `{complex_result["output_dir"]}`

**Generated Files:**
• Topology: `{Path(parm7_file).name}`
• Coordinates: `{Path(rst7_file).name}`
• Trajectory: `{dcd_file.name}` ({traj.n_frames} frames)
• Final state: `{final_pdb.name}`'''
    
    add_message('system', summary)

# ============================================================================
# Connect Event Handlers
# ============================================================================
send_button.on_click(on_send_click)
clear_button.on_click(on_clear_click)
run_button.on_click(on_run_click)
user_input.on_submit(on_input_submit)

# ============================================================================
# Display UI
# ============================================================================

# Welcome message
with chat_output:
    display(HTML('''
    <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); 
                border-radius: 10px; color: white; margin-bottom: 10px;">
        <h3 style="margin: 0;">🧬 MCP-MD AI Agent</h3>
        <p style="margin: 5px 0 0 0; opacity: 0.9;">Your AI assistant for molecular dynamics simulations</p>
    </div>
    '''))

add_message('agent', '''Hello! I'm your MD simulation assistant. 👋

Tell me what you'd like to simulate, for example:
• "Setup MD for PDB 1AKE in water"
• "I want to simulate lysozyme (1LYZ) for 1 ns"
• "Run a simulation of insulin, chain A only"

I'll analyze the structure and ask any clarifying questions before we begin.''')

# Layout
input_row = widgets.HBox([user_input, send_button], layout=widgets.Layout(margin='10px 0'))
button_row = widgets.HBox([run_button, clear_button], layout=widgets.Layout(margin='5px 0'))

ui = widgets.VBox([
    chat_output,
    input_row,
    button_row,
    status_label,
    progress_bar,
    widgets.HTML('<h4 style="margin: 20px 0 10px 0;">📊 Visualization</h4>'),
    viz_output
])

display(ui)

---
## 📁 Download Results

After the simulation completes, run this cell to download the generated files.
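The download cell packs the whole output directory into a single archive with `shutil.make_archive`. The standalone helper below sketches the same step; `zip_output_dir` is a hypothetical name. It passes an absolute `base_name` so that the archive location does not depend on the current working directory, which the setup cell changes with `%cd`:

```python
import shutil
import tempfile
import zipfile
from pathlib import Path


def zip_output_dir(output_dir: str, dest_dir: str = ".") -> Path:
    """Zip the contents of output_dir into <dest_dir>/<dir name>.zip and return its path."""
    src = Path(output_dir).resolve()
    # make_archive appends ".zip" to base_name itself
    base_name = Path(dest_dir).resolve() / src.name
    archive = shutil.make_archive(str(base_name), "zip", root_dir=str(src))
    return Path(archive)


# Usage with a throwaway directory standing in for a real run
with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "1AKE_120000"
    run_dir.mkdir()
    (run_dir / "system.parm7").write_text("placeholder topology")
    archive = zip_output_dir(str(run_dir), dest_dir=tmp)
    archive_name = archive.name
    with zipfile.ZipFile(archive) as zf:
        members = zf.namelist()

print(archive_name, members)  # 1AKE_120000.zip ['system.parm7']
```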

In [None]:
import sys
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules

if chat_state.workflow_outputs:
    output_dir = chat_state.workflow_outputs.get('output_dir')
    
    if output_dir and Path(output_dir).exists():
        print(f"Output directory: {output_dir}")
        print("\nGenerated files:")
        for f in Path(output_dir).glob('*'):
            size_kb = f.stat().st_size / 1024
            print(f"  {f.name} ({size_kb:.1f} KB)")
        
        if IN_COLAB:
            from google.colab import files
            import shutil
            
            # Create zip file
            zip_name = f"{Path(output_dir).name}.zip"
            shutil.make_archive(Path(output_dir).name, 'zip', output_dir)
            
            print(f"\n📥 Downloading {zip_name}...")
            files.download(zip_name)
        else:
            print(f"\nFiles are in: {output_dir}")
    else:
        print("No output directory found. Run a simulation first.")
else:
    print("No workflow outputs available. Run a simulation first using the chat interface above.")

---
## 📈 Analyze Trajectory (Optional)

Run basic analysis on the generated trajectory.
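The cell below plots Cα RMSD relative to the first frame using `md.rmsd`. That function superposes each frame onto the reference before measuring and returns nanometers, hence the factor of 10 to convert to Å. Internally, MDTraj uses a quaternion-based algorithm. To show the same quantity, here is a sketch of the classic Kabsch superposition in NumPy. `kabsch_rmsd` is a hypothetical helper for illustration, not part of MDTraj or mcp-md:

```python
import numpy as np


def kabsch_rmsd(P: np.ndarray, Q: np.ndarray) -> float:
    """RMSD between two (N, 3) coordinate sets after optimal rigid-body superposition of P onto Q."""
    P = P - P.mean(axis=0)  # remove translation
    Q = Q - Q.mean(axis=0)
    H = P.T @ Q  # 3x3 cross-covariance matrix
    U, S, Vt = np.linalg.svd(H)
    # Flip the last axis if needed so R is a proper rotation, not a reflection
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    D = np.diag([1.0, 1.0, d])
    R = Vt.T @ D @ U.T  # rotation taking centered P onto centered Q
    diff = P @ R.T - Q
    return float(np.sqrt((diff ** 2).sum() / len(P)))


# Usage: a rotated and translated copy should give an RMSD of ~0
rng = np.random.default_rng(0)
ref = rng.normal(size=(50, 3))
theta = 0.7
Rz = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta), np.cos(theta), 0.0],
    [0.0, 0.0, 1.0],
])
moved = ref @ Rz.T + np.array([1.0, -2.0, 0.5])
print(round(kabsch_rmsd(moved, ref), 6))  # 0.0
```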

In [None]:
import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np

if chat_state.workflow_outputs and chat_state.workflow_outputs.get('trajectory'):
    traj_file = chat_state.workflow_outputs['trajectory']
    top_file = chat_state.workflow_outputs['parm7']
    
    print(f"Loading trajectory: {traj_file}")
    traj = md.load(traj_file, top=top_file)
    
    # Select protein for RMSD
    protein_atoms = traj.topology.select('protein and name CA')
    
    if len(protein_atoms) > 0:
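        # Cα RMSD relative to the first frame (mdtraj returns nm; x10 converts to Å)
        rmsd = md.rmsd(traj, traj, 0, atom_indices=protein_atoms) * 10
        
        # mdtraj does not read per-frame times from DCD files (traj.time is the frame index),
        # so rebuild the time axis from the saved frame interval when the workflow recorded one.
        # DCDReporter saves its first frame after one interval, hence the +1.
        frame_interval_ps = chat_state.workflow_outputs.get('frame_interval_ps')
        if frame_interval_ps:
            x, xlabel = (np.arange(traj.n_frames) + 1) * frame_interval_ps, 'Time (ps)'
        else:
            x, xlabel = np.arange(traj.n_frames), 'Frame'
        
        # Plot
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(x, rmsd, 'b-', linewidth=1)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('RMSD (Å)')
        ax.set_title('Cα RMSD relative to the first frame')
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()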
        
        print(f"\nRMSD Statistics:")
        print(f"  Mean:  {np.mean(rmsd):.2f} Å")
        print(f"  Max:   {np.max(rmsd):.2f} Å")
        print(f"  Final: {rmsd[-1]:.2f} Å")
    else:
        print("No protein Cα atoms found for RMSD calculation")
else:
    print("No trajectory available. Run a simulation first using the chat interface.")

---

## Next Steps

1. **Longer simulations**: Modify the simulation time in your chat request
2. **Analysis**: Use MDTraj for RMSD, RMSF, hydrogen bonds, etc.
3. **Different systems**: Try membrane proteins, protein-ligand complexes
4. **Batch processing**: Use `main.py batch` command for automated runs
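For the RMSF item above, here is a minimal NumPy sketch. It assumes the frames have already been superposed (for example with MDTraj's `Trajectory.superpose`), so that only internal motion remains. `rmsf` here is a hypothetical helper. MDTraj coordinates are in nm, so multiply by 10 to get Å:

```python
import numpy as np


def rmsf(xyz: np.ndarray) -> np.ndarray:
    """Per-atom RMSF from (n_frames, n_atoms, 3) coordinates that are already superposed."""
    mean_pos = xyz.mean(axis=0)                    # (n_atoms, 3) average structure
    sq_dev = ((xyz - mean_pos) ** 2).sum(axis=2)   # (n_frames, n_atoms) squared displacements
    return np.sqrt(sq_dev.mean(axis=0))            # (n_atoms,) root-mean-square over frames


# Toy trajectory: atom 0 stays put, atom 1 oscillates ±1 along x
xyz = np.array([
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]],
])
print(rmsf(xyz))  # [0. 1.]
```

With a loaded trajectory, you could first call `traj.superpose(traj, 0, atom_indices=ca)` and then `rmsf(traj.xyz[:, ca]) * 10`, where `ca = traj.topology.select('name CA')`.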

For more information, see the [GitHub repository](https://github.com/matsunagalab/mcp-md).