<a href="https://colab.research.google.com/github/matsunagalab/mcp-md/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧬 MCP-MD: AI-Powered Molecular Dynamics Agent

**Interactive AI assistant for setting up MD simulations**

This notebook provides a chat interface to interact with the MCP-MD AI agent. Simply describe what you want to simulate, and the agent will:

1. **Analyze** your request and ask clarifying questions
2. **Fetch** structures from PDB/AlphaFold
3. **Prepare** protein + parameterize ligands (GAFF2/AM1-BCC)
4. **Solvate** with water box + ions
5. **Build** Amber topology (tleap)
6. **Simulate** with OpenMM (NPT ensemble)
7. **Visualize** results with interactive 3D viewer

---

## Quick Start

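1. **Run the Setup cells** (Cells 1-2) - installs dependencies (~5-10 min)
2. **Enter your API key** (Cell 3)
3. **Start chatting!** - describe your simulation (Cell 4)
4. **Run the workflow cell** once the agent has generated a SimulationBrief, then the visualization, download, and RMSD cells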

**Example prompts:**
- "Setup MD for PDB 1AKE in water, 1 ns at 300K"
- "I want to simulate lysozyme (PDB 1LYZ) with explicit solvent"
- "Run a short simulation of insulin (PDB 4INS), chain A only"

---
## Setup 1/2: Install condacolab

**⚠️ The runtime will restart after this cell. This is expected!**

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Installing condacolab...")
    print("⚠️ The runtime will restart. Run the next cell after restart.")
    !pip install -q condacolab
    import condacolab
    condacolab.install()
else:
    print("Not running in Colab - skipping condacolab setup")
    print("Make sure you have a conda environment with AmberTools installed.")

---
## Setup 2/2: Install Dependencies

**Run this cell AFTER the runtime restarts.**

Installs AmberTools, OpenMM, RDKit, and project dependencies (~5-10 min)

In [None]:
import sys
import time

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    import condacolab
    condacolab.check()
    
    start_time = time.time()
    
    # Install conda packages (AmberTools + heavy scientific packages)
    print("="*60)
    print("Installing AmberTools + scientific packages via conda...")
    print("This takes ~5-10 minutes. Please wait.")
    print("="*60)
    !conda install -y -c conda-forge ambertools=23 openmm rdkit pdbfixer 2>&1 | tail -20
    print(f"\n✓ Conda packages installed ({time.time() - start_time:.0f}s)")
    
    # Clone repository and install
    print("\nCloning mcp-md repository...")
    !rm -rf /content/mcp-md  # Remove if exists to get fresh copy
    !git clone -q https://github.com/matsunagalab/mcp-md.git /content/mcp-md
    %cd /content/mcp-md
    
    print("Installing Python dependencies (using conda's pip)...")
    # Use conda's pip explicitly to install to the correct environment
    !conda run pip install -q -e .
    !conda run pip install -q langchain langchain-core langgraph langchain-anthropic langchain-mcp-adapters
    !conda run pip install -q py3Dmol mdtraj nest_asyncio
    
    # Set AMBERHOME
    import os
    import json
    import subprocess
    conda_info = json.loads(subprocess.run(['conda', 'info', '--json'], 
                                           capture_output=True, text=True).stdout)
    os.environ["AMBERHOME"] = conda_info.get('default_prefix', '')
    
    # Add paths for mcp_md module (src layout) and servers
    sys.path.insert(0, '/content/mcp-md/src')
    sys.path.insert(0, '/content/mcp-md')
    
    total_time = time.time() - start_time
    print(f"\n" + "="*60)
    print(f"✓ Setup complete! ({total_time/60:.1f} minutes)")
    print("="*60)
    print("\n🎉 You can now proceed to the next cell!")

else:
    # Local development - add src to path
    sys.path.insert(0, './src')
    sys.path.insert(0, '.')
    print("Local environment - dependencies should be pre-installed.")
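Optionally, you can confirm the installation before continuing. The sketch below only checks that a few AmberTools executables are on `PATH` and that `AMBERHOME` is set. The choice of `antechamber` and `parmchk2` is an assumption about which tools the ligand-parameterization step calls; this notebook does not state it.

```python
import os
import shutil

def check_tools(names):
    """Map each executable name to its resolved path on PATH, or None if it is missing."""
    return {name: shutil.which(name) for name in names}

# Assumed set of AmberTools binaries used later: tleap (topology build) and
# antechamber/parmchk2 (GAFF2/AM1-BCC ligand parameterization)
status = check_tools(["tleap", "antechamber", "parmchk2"])
for name, path in status.items():
    print(f"{name:12s} {path or 'MISSING'}")
print("AMBERHOME =", os.environ.get("AMBERHOME") or "(not set)")
```

A `MISSING` entry after Setup 2/2 usually means the conda install did not finish or the runtime was not restarted first.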

---
## API Key Configuration

Enter your Anthropic API key to enable the AI agent.

Get your API key from: https://console.anthropic.com/

In [None]:
import os
import sys
from getpass import getpass

IN_COLAB = 'google.colab' in sys.modules

# Check if API key is already set
if os.environ.get("ANTHROPIC_API_KEY"):
    print("✓ ANTHROPIC_API_KEY is already set")
else:
    if IN_COLAB:
        from google.colab import userdata
        try:
            # Try to get from Colab secrets
            api_key = userdata.get('ANTHROPIC_API_KEY')
            os.environ["ANTHROPIC_API_KEY"] = api_key
            print("✓ ANTHROPIC_API_KEY loaded from Colab secrets")
        except Exception:
            # Secret not found or not accessible: prompt for input instead
            print("Enter your Anthropic API key:")
            api_key = getpass()
            os.environ["ANTHROPIC_API_KEY"] = api_key
            print("✓ ANTHROPIC_API_KEY set")
    else:
        # Local - try to load from .env
        from dotenv import load_dotenv
        load_dotenv()
        if os.environ.get("ANTHROPIC_API_KEY"):
            print("✓ ANTHROPIC_API_KEY loaded from .env")
        else:
            print("Enter your Anthropic API key:")
            api_key = getpass()
            os.environ["ANTHROPIC_API_KEY"] = api_key
            print("✓ ANTHROPIC_API_KEY set")
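The cell above tries the key sources in a fixed order: the environment variable, then Colab secrets (or `.env` locally), then an interactive prompt. That precedence can be captured in a small helper. `resolve_api_key` is hypothetical and not part of mcp-md. Each source is a zero-argument callable, and a source that raises or returns an empty value is skipped.

```python
import os
from typing import Callable, Iterable, Optional

def resolve_api_key(sources: Iterable[Callable[[], Optional[str]]]) -> Optional[str]:
    """Return the first non-empty key produced by the given sources, tried in order."""
    for source in sources:
        try:
            key = source()
        except Exception:
            # An unavailable source (e.g. a missing Colab secret) is skipped
            continue
        if key:
            return key
    return None

# Environment variable first; in Colab, a secrets lookup and getpass could follow
api_key = resolve_api_key([lambda: os.environ.get("ANTHROPIC_API_KEY")])
print("key found" if api_key else "no key yet")
```

Unlike the cell above, this treats an empty string as "not set", so a blank secret falls through to the next source.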

---
## 🤖 MCP-MD AI Agent Chat Interface

Interact with the AI agent to set up your MD simulation!

**How to use:**
1. Type your request in the text box (e.g., "Setup MD for PDB 1AKE")
2. Click **Send**
3. The agent will analyze your request and may ask clarifying questions
4. Answer the questions until the agent has enough information to generate a SimulationBrief
5. Run the **MD Workflow Execution** cell to execute the workflow

**Controls:**
- Click **Clear Chat** to reset the conversation
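Once the conversation is complete, the agent produces a `SimulationBrief`. The workflow cell reads only a handful of its fields and falls back to a default when a field is absent. The sketch below collects those field names and defaults, copied from the workflow cell, in one place. It treats `None` as "not given", which is an assumption: the workflow cell does this only for `pressure_bar` and uses `.get(key, default)` for the rest. The real `SimulationBrief` model may define more fields, and `effective_settings` is a hypothetical helper.

```python
# Defaults used by the MD Workflow Execution cell when a field is absent
WORKFLOW_DEFAULTS = {
    "ph": 7.4,
    "box_padding": 12.0,
    "cubic_box": True,
    "salt_concentration": 0.15,
    "water_model": "tip3p",
    "simulation_time_ns": 0.1,
    "temperature": 300.0,      # K
    "pressure_bar": 1.0,       # bar
}

def effective_settings(brief: dict) -> dict:
    """Merge a brief dict over the workflow defaults, treating None as 'not given'."""
    merged = dict(WORKFLOW_DEFAULTS)
    merged.update({k: v for k, v in brief.items() if v is not None})
    return merged

# Illustrative brief for "Setup MD for PDB 1AKE in water, 1 ns at 300K"
example = {"pdb_id": "1AKE", "simulation_time_ns": 1.0, "temperature": 300.0, "select_chains": None}
print(effective_settings(example))
```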

In [None]:
# ============================================================================
# Chat Interface (ipywidgets-based, similar to GPT-4 chat UI)
# ============================================================================

import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import asyncio
import json
from pathlib import Path
from datetime import datetime

# Import MCP-MD components
from langchain_core.messages import HumanMessage, AIMessage
from mcp_md.clarification_agent import create_clarification_graph
from mcp_md.state_scope import SimulationBrief

# Global state
past_messages = []  # Chat history for display
agent_state = None  # LangGraph state
simulation_brief = None  # Generated brief
graph = None  # Clarification graph
workflow_outputs = {}  # Output files from workflow

# UI widgets
input_box_layout = widgets.Layout(width='100%')
input_box = widgets.Textarea(
    description='You:',
    rows=3,
    layout=input_box_layout,
    placeholder='e.g. "Setup MD for PDB 1AKE in water, 1 ns at 300K"'
)
send_button = widgets.Button(description='Send', button_style='primary')
output_area = widgets.Output()

def send_to_agent(button):
    """Handle send button click - process user message through agent."""
    global past_messages, agent_state, simulation_brief, graph
    
    user_input = input_box.value.strip()
    if not user_input:
        return
    
    # Add user message to history
    past_messages.append({"role": "user", "content": user_input})
    
    # Clear input
    input_box.value = ''
    
    # Show processing state
    with output_area:
        clear_output()
        for msg in past_messages:
            display_message(msg['role'], msg['content'])
        display(HTML('<p style="color: blue;">🔄 Processing...</p>'))
    
    # Process with agent
    async def process():
        global agent_state, simulation_brief, graph
        
        try:
            # Initialize graph if needed
            if graph is None:
                graph = create_clarification_graph()
            
            # Build input state
            if agent_state is None:
                input_state = {"messages": [HumanMessage(content=user_input)]}
            else:
                input_state = {
                    "messages": agent_state.get("messages", []) + [HumanMessage(content=user_input)],
                    "structure_info": agent_state.get("structure_info"),
                }
            
            # Run agent
            result = await graph.ainvoke(input_state)
            agent_state = result
            
            # Extract response
            response = None
            messages = result.get("messages", [])
            if messages:
                last_msg = messages[-1]
                if hasattr(last_msg, 'content') and last_msg.content:
                    response = last_msg.content
            
            # Add agent response to history
            if response:
                past_messages.append({"role": "assistant", "content": response})
            
            # Check for simulation brief
            brief = result.get("simulation_brief")
            if brief:
                simulation_brief = brief
                past_messages.append({
                    "role": "system",
                    "content": f"✅ SimulationBrief generated!\n\nPDB: {brief.pdb_id}\nTemp: {brief.temperature}K\nTime: {brief.simulation_time_ns}ns\n\nRun the next cell to start the MD workflow."
                })
            
        except Exception as e:
            past_messages.append({"role": "system", "content": f"❌ Error: {str(e)}"})
        
        # Update display
        with output_area:
            clear_output()
            for msg in past_messages:
                display_message(msg['role'], msg['content'])
    
    # Run async
    asyncio.ensure_future(process())

def display_message(role, content):
    """Display a single message with styling."""
    if role == 'user':
        color = 'blue'
        label = 'You'
    elif role == 'assistant':
        color = 'green'
        label = 'Agent'
    else:
        color = 'gray'
        label = 'System'
    
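    # Escape '<', '>' and '&' so agent replies are displayed literally instead of parsed as HTML
    from html import escape
    display(HTML(f"<span style='color: {color}; white-space: pre-wrap;'><b>{label}:</b> {escape(str(content))}</span><br><br>"))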

def reset_chat(button):
    """Reset chat state."""
    global past_messages, agent_state, simulation_brief, graph, workflow_outputs
    past_messages = []
    agent_state = None
    simulation_brief = None
    graph = None
    workflow_outputs = {}
    with output_area:
        clear_output()
        display(HTML('<p style="color: gray;">Chat cleared. Start a new conversation!</p>'))

# Reset button
reset_button = widgets.Button(description='Clear Chat', button_style='warning')
reset_button.on_click(reset_chat)

# Connect event handlers
send_button.on_click(send_to_agent)

# Display UI
display(HTML('''
<div style="padding: 15px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); 
            border-radius: 10px; color: white; margin-bottom: 15px;">
    <h3 style="margin: 0;">🧬 MCP-MD AI Agent</h3>
    <p style="margin: 5px 0 0 0;">Describe your MD simulation to get started</p>
</div>
'''))

display(output_area)
display(widgets.HBox([input_box, send_button, reset_button]))

In [None]:
# ============================================================================
# MD Workflow Execution
# ============================================================================
# Run this cell after the chat generates a SimulationBrief

import asyncio
import importlib
from pathlib import Path
from datetime import datetime

if simulation_brief is None:
    print("❌ No simulation brief available. Please complete the chat first.")
else:
    print("🚀 Starting MD workflow...")
    
    # Get brief as dict
    if hasattr(simulation_brief, 'model_dump'):
        brief_dict = simulation_brief.model_dump()
    else:
        brief_dict = simulation_brief
    
    pdb_id = brief_dict.get('pdb_id')
    select_chains = brief_dict.get('select_chains')
    
    if not pdb_id:
        print("❌ No PDB ID specified")
    else:
        # Create output directory
        import sys
        IN_COLAB = 'google.colab' in sys.modules
        if IN_COLAB:
            output_base = Path("/content/mcp-md/output")
        else:
            output_base = Path("./output")
        
        output_dir = output_base / f"{pdb_id}_{datetime.now().strftime('%H%M%S')}"
        output_dir.mkdir(parents=True, exist_ok=True)
        
        async def run_workflow():
            global workflow_outputs
            
            # Step 1: Fetch
            print(f"📥 Step 1/5: Fetching {pdb_id}...")
            import servers.structure_server as structure_module
            importlib.reload(structure_module)
            
            fetch_result = await structure_module.fetch_molecules(pdb_id=pdb_id, source="pdb", prefer_format="pdb")
            if not fetch_result["success"]:
                raise RuntimeError(f"Fetch failed: {fetch_result['errors']}")
            structure_file = fetch_result["file_path"]
            print(f"   ✓ Fetched: {Path(structure_file).name}")
            
            # Step 2: Prepare
            print("🔧 Step 2/5: Preparing complex...")
            complex_result = structure_module.prepare_complex(
                structure_file=structure_file,
                select_chains=select_chains,
                ph=brief_dict.get('ph', 7.4),
                process_proteins=True,
                process_ligands=True,
                run_parameterization=True
            )
            if not complex_result["success"]:
                raise RuntimeError(f"Prepare failed: {complex_result['errors']}")
            merged_pdb = complex_result["merged_pdb"]
            print(f"   ‚úì Prepared: {len(complex_result['proteins'])} protein(s), {len(complex_result['ligands'])} ligand(s)")
            
            # Step 3: Solvate
            print("💧 Step 3/5: Solvating...")
            import servers.solvation_server as solvation_module
            importlib.reload(solvation_module)
            
            solvate_result = solvation_module.solvate_structure(
                pdb_file=str(Path(merged_pdb).resolve()),
                output_dir=str(Path(complex_result["output_dir"]).resolve()),
                output_name="solvated",
                dist=brief_dict.get('box_padding', 12.0),
                cubic=brief_dict.get('cubic_box', True),
                salt=True,
                saltcon=brief_dict.get('salt_concentration', 0.15)
            )
            if not solvate_result["success"]:
                raise RuntimeError(f"Solvate failed: {solvate_result['errors']}")
            print(f"   ✓ Solvated: {solvate_result['statistics'].get('total_atoms', '?')} atoms")
            
            # Step 4: Build Amber
            print("🏗️ Step 4/5: Building Amber topology...")
            import servers.amber_server as amber_module
            importlib.reload(amber_module)
            
            ligand_params = []
            for lig in complex_result.get("ligands", []):
                if lig.get("success") and lig.get("mol2_file"):
                    ligand_params.append({
                        "mol2": lig["mol2_file"],
                        "frcmod": lig["frcmod_file"],
                        "residue_name": lig["ligand_id"][:3].upper()
                    })
            
            amber_result = amber_module.build_amber_system(
                pdb_file=solvate_result["output_file"],
                ligand_params=ligand_params if ligand_params else None,
                box_dimensions=solvate_result.get("box_dimensions"),
                water_model=brief_dict.get('water_model', 'tip3p'),
                output_name="system"
            )
            if not amber_result['success']:
                raise RuntimeError(f"Amber build failed: {amber_result['errors']}")
            parm7_file = amber_result['parm7']
            rst7_file = amber_result['rst7']
            print(f"   ✓ Built: {Path(parm7_file).name}")
            
            # Step 5: Run MD
            sim_time = brief_dict.get('simulation_time_ns', 0.1)
            print(f"🏃 Step 5/5: Running {sim_time} ns simulation...")
            
            import openmm as mm
            from openmm import app, unit
            from openmm.app import AmberPrmtopFile, AmberInpcrdFile, Simulation, DCDReporter, PDBFile
            
            # Select platform
            platform = None
            for name in ['CUDA', 'OpenCL', 'CPU']:
                try:
                    platform = mm.Platform.getPlatformByName(name)
                    print(f"   Using {name} platform")
                    break
                except Exception:
                    continue
            
            if platform is None:
                raise RuntimeError("No OpenMM platform available")
            
            prmtop = AmberPrmtopFile(parm7_file)
            inpcrd = AmberInpcrdFile(rst7_file)
            
            temperature = brief_dict.get('temperature', 300.0) * unit.kelvin
            pressure = (brief_dict.get('pressure_bar') or 1.0) * unit.bar
            timestep = 2.0 * unit.femtoseconds
            
            system = prmtop.createSystem(
                nonbondedMethod=app.PME,
                nonbondedCutoff=10 * unit.angstrom,
                constraints=app.HBonds,
                rigidWater=True
            )
            system.addForce(mm.MonteCarloBarostat(pressure, temperature, 25))
            
            integrator = mm.LangevinMiddleIntegrator(temperature, 1/unit.picosecond, timestep)
            simulation = Simulation(prmtop.topology, system, integrator, platform)
            simulation.context.setPositions(inpcrd.positions)
            if inpcrd.boxVectors:
                simulation.context.setPeriodicBoxVectors(*inpcrd.boxVectors)
            
            simulation.minimizeEnergy(maxIterations=500)
            simulation.context.setVelocitiesToTemperature(temperature)
            
            dcd_file = Path(complex_result["output_dir"]) / "trajectory.dcd"
            total_steps = int(sim_time * 1e6 / 2)  # 1 ns = 1e6 fs; 2 fs per step
            report_interval = max(100, total_steps // 100)
            simulation.reporters.append(DCDReporter(str(dcd_file), report_interval))
            
            simulation.step(total_steps)
            
            final_pdb = Path(complex_result["output_dir"]) / "final_state.pdb"
            state = simulation.context.getState(getPositions=True)
            with open(final_pdb, 'w') as f:
                PDBFile.writeFile(simulation.topology, state.getPositions(), f)
            
            print("   ✓ Complete!")
            
            # Store outputs
            workflow_outputs = {
                'structure_file': structure_file,
                'merged_pdb': merged_pdb,
                'solvated_pdb': solvate_result["output_file"],
                'parm7': parm7_file,
                'rst7': rst7_file,
                'trajectory': str(dcd_file),
                'final_pdb': str(final_pdb),
                'output_dir': complex_result["output_dir"]
            }
            
            print("\n✅ Workflow complete!")
            print(f"Output directory: {complex_result['output_dir']}")
        
        # Run async workflow
        import nest_asyncio
        nest_asyncio.apply()
        asyncio.get_event_loop().run_until_complete(run_workflow())
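The step count in the workflow cell follows from a unit conversion: 1 ns = 10⁶ fs, so with the 2 fs timestep a 1 ns run is 500,000 steps, and `report_interval = max(100, total_steps // 100)` yields roughly 100 DCD frames. The hypothetical helper below mirrors that arithmetic and also recovers the simulation time of each saved frame. It relies on OpenMM reporters writing their first report after `report_interval` steps rather than at step 0.

```python
def md_schedule(sim_time_ns: float, timestep_fs: float = 2.0,
                min_interval: int = 100, target_frames: int = 100):
    """Return (total_steps, report_interval, frame_times_ps) as in the workflow cell."""
    total_steps = int(sim_time_ns * 1e6 / timestep_fs)          # 1 ns = 1e6 fs
    report_interval = max(min_interval, total_steps // target_frames)
    n_frames = total_steps // report_interval
    # Frame i (0-based) is written after (i + 1) * report_interval steps
    frame_times_ps = [(i + 1) * report_interval * timestep_fs / 1000.0
                      for i in range(n_frames)]
    return total_steps, report_interval, frame_times_ps

steps, interval, times = md_schedule(1.0)
print(steps, interval, len(times), times[-1])   # 500000 5000 100 1000.0
```

With the default `simulation_time_ns` of 0.1, the same rule gives 50,000 steps and a frame every 500 steps (1 ps).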

In [None]:
# ============================================================================
# 3D Visualization with py3Dmol (Independent Cell)
# ============================================================================
# Run this cell to visualize the trajectory after workflow completion

import py3Dmol
import mdtraj as md
import numpy as np
import tempfile
from pathlib import Path

if not workflow_outputs or 'trajectory' not in workflow_outputs:
    print("❌ No trajectory available. Run the MD workflow first.")
else:
    traj_file = workflow_outputs['trajectory']
    top_file = workflow_outputs['parm7']
    
    print(f"Loading trajectory: {traj_file}")
    traj = md.load(traj_file, top=top_file)
    print(f"Loaded {traj.n_frames} frames, {traj.n_atoms} atoms")
    
    # Select protein only for visualization
    protein_indices = traj.topology.select('protein')
    if len(protein_indices) > 0:
        traj_protein = traj.atom_slice(protein_indices)
        
        # Sample frames for visualization
        max_frames = 20
        if traj_protein.n_frames > max_frames:
            frame_indices = np.linspace(0, traj_protein.n_frames - 1, max_frames, dtype=int)
            traj_viz = traj_protein[frame_indices]
        else:
            traj_viz = traj_protein
        
        # Write a multi-model PDB for animation; MDTraj's save_pdb emits one MODEL/ENDMDL block per frame
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir) / "trajectory_viz.pdb"
            traj_viz.save_pdb(str(tmp_path))
            traj_pdb = tmp_path.read_text()
        
        # Create animated 3D view
        view = py3Dmol.view(width=800, height=500)
        view.addModelsAsFrames(traj_pdb, 'pdb')
        view.setStyle({'cartoon': {'color': 'spectrum'}})
        view.zoomTo()
        view.animate({'loop': 'forward', 'reps': 0, 'interval': 100})
        
        # traj.time loaded from a DCD file is the frame index, not ps, so report frame counts
        print(f"\n🎬 Trajectory Animation: {traj_viz.n_frames} frames (sampled from {traj.n_frames} saved frames)")
        view.show()
    else:
        print("No protein atoms found in trajectory")

In [None]:
# ============================================================================
# Download Results
# ============================================================================

import sys
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules

if workflow_outputs and workflow_outputs.get('output_dir'):
    output_dir = Path(workflow_outputs['output_dir'])
    
    if output_dir.exists():
        print(f"Output directory: {output_dir}")
        print("\nGenerated files:")
        for f in sorted(output_dir.glob('*')):
            size_kb = f.stat().st_size / 1024
            print(f"  {f.name} ({size_kb:.1f} KB)")
        
        if IN_COLAB:
            from google.colab import files
            import shutil
            
            zip_name = f"{output_dir.name}.zip"
            shutil.make_archive(output_dir.name, 'zip', output_dir)
            print(f"\n📥 Downloading {zip_name}...")
            files.download(zip_name)
        else:
            print(f"\nFiles are in: {output_dir}")
    else:
        print("Output directory not found.")
else:
    print("❌ No workflow outputs available. Run the MD workflow first.")
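`shutil.make_archive` creates the archive relative to the current working directory (here `/content/mcp-md`, after the `%cd` in Setup 2/2) and returns the archive's path. Passing a full base name and using the returned path removes that dependence on the working directory. The following is a minimal sketch; the helper name `zip_output_dir` is hypothetical.

```python
import shutil
import tempfile
from pathlib import Path

def zip_output_dir(output_dir: Path, dest_dir: Path) -> Path:
    """Zip output_dir into dest_dir/<output_dir.name>.zip and return the archive path."""
    base_name = dest_dir / output_dir.name          # make_archive appends ".zip"
    archive = shutil.make_archive(str(base_name), "zip", root_dir=str(output_dir))
    return Path(archive)

# Self-contained demo on a throwaway directory
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "1AKE_120000"
    src.mkdir()
    (src / "system.parm7").write_text("placeholder\n")
    zip_path = zip_output_dir(src, Path(tmp))
    print(zip_path.name, zip_path.exists())   # 1AKE_120000.zip True
```

In Colab this could be combined with the download call as `files.download(str(zip_output_dir(output_dir, Path("/content"))))`.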

In [None]:
# ============================================================================
# RMSD Analysis
# ============================================================================

import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np

if workflow_outputs and workflow_outputs.get('trajectory'):
    traj_file = workflow_outputs['trajectory']
    top_file = workflow_outputs['parm7']
    
    print(f"Loading trajectory: {traj_file}")
    traj = md.load(traj_file, top=top_file)
    
    protein_atoms = traj.topology.select('protein and name CA')
    
    if len(protein_atoms) > 0:
        rmsd = md.rmsd(traj, traj, 0, atom_indices=protein_atoms) * 10  # Angstrom
        
        fig, ax = plt.subplots(figsize=(10, 4))
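        # traj.time loaded from a DCD file is the frame index, not ps, so plot against frame number
        ax.plot(np.arange(traj.n_frames), rmsd, 'b-', linewidth=1)
        ax.set_xlabel('Frame')
        ax.set_ylabel('RMSD (Å)')
        ax.set_title('Cα RMSD vs Frame')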
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
        
        print(f"\nRMSD Statistics:")
        print(f"  Mean:  {np.mean(rmsd):.2f} Å")
        print(f"  Max:   {np.max(rmsd):.2f} Å")
        print(f"  Final: {rmsd[-1]:.2f} Å")
    else:
        print("No protein Cα atoms found")
else:
    print("❌ No trajectory available. Run the MD workflow first.")
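`md.rmsd` reports the RMSD of each frame after optimal rigid-body superposition onto the reference frame. For intuition, the same quantity for a single pair of coordinate sets can be computed with the Kabsch algorithm in plain NumPy. This is a teaching sketch and makes no claim about how MDTraj computes it internally. As in MDTraj, coordinates would be in nm; the cell above multiplies by 10 to report Å.

```python
import numpy as np

def kabsch_rmsd(P: np.ndarray, Q: np.ndarray) -> float:
    """RMSD between (N, 3) coordinate sets P and Q after optimally superposing P onto Q."""
    P = P - P.mean(axis=0)                   # remove translation
    Q = Q - Q.mean(axis=0)
    H = P.T @ Q                              # 3x3 cross-covariance matrix
    U, _, Vt = np.linalg.svd(H)
    d = 1.0 if np.linalg.det(Vt.T @ U.T) >= 0 else -1.0   # avoid a reflection
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T  # optimal rotation mapping P onto Q
    diff = P @ R.T - Q                       # rotated P minus Q, atom by atom
    return float(np.sqrt((diff ** 2).sum() / len(P)))

# Demo: a rotated and translated copy of Q superposes back with RMSD ~ 0
rng = np.random.default_rng(0)
Q = rng.normal(size=(50, 3))
theta = 0.7
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
P = Q @ Rz.T + np.array([1.0, -2.0, 0.5])
print(f"{kabsch_rmsd(P, Q):.6f}")   # 0.000000
```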

---

## Next Steps

1. **Longer simulations**: Modify the simulation time in your chat request
2. **Analysis**: Use MDTraj for RMSD, RMSF, hydrogen bonds, etc.
3. **Different systems**: Try membrane proteins, protein-ligand complexes
4. **Batch processing**: Use `main.py batch` command for automated runs
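As a starting point for the RMSF analysis mentioned in item 2, the per-atom fluctuation is the root-mean-square deviation of each atom from its time-averaged position. The NumPy sketch below assumes the frames have already been superposed onto a reference, for example with MDTraj's `Trajectory.superpose`, because residual global rotation and translation would otherwise inflate the values.

```python
import numpy as np

def rmsf(xyz: np.ndarray) -> np.ndarray:
    """Per-atom RMSF for aligned coordinates of shape (n_frames, n_atoms, 3)."""
    mean_pos = xyz.mean(axis=0)                    # (n_atoms, 3) average structure
    sq_dev = ((xyz - mean_pos) ** 2).sum(axis=2)   # squared displacement per frame and atom
    return np.sqrt(sq_dev.mean(axis=0))            # average over frames, then square root

# Toy example: atom 0 stays fixed, atom 1 oscillates by ±0.1 along x
xyz = np.zeros((4, 2, 3))
xyz[:, 1, 0] = [0.1, -0.1, 0.1, -0.1]
print(rmsf(xyz))
```

In the RMSD cell's terms, this could be applied as `rmsf(traj.xyz[:, protein_atoms, :]) * 10` (Å) after `traj.superpose(traj, 0, atom_indices=protein_atoms)`.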

For more information, see the [GitHub repository](https://github.com/matsunagalab/mcp-md).