<a href="https://colab.research.google.com/github/matsunagalab/mdzen/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MDZen: AI-Powered Molecular Dynamics Agent

**Interactive AI assistant for setting up MD simulations**

## Workflow

1. **Setup** - Install dependencies (Konda + AmberTools)
2. **Phase 1: Clarification** - Describe your simulation; the AI generates a SimulationBrief
3. **Edit Brief** - Review and customize simulation parameters
4. **Phase 2: Execute** - Run workflow step by step
5. **Visualization** - View trajectory animation with py3Dmol
6. **Download** - Get all generated files

---

## Quick Start

1. Set an **API key** in Colab Secrets (`ANTHROPIC_API_KEY`, `OPENAI_API_KEY`, or `GOOGLE_API_KEY`)
2. Run **Setup** cell (~5-10 min)
3. Run **Phase 1** - describe your simulation
4. **Edit Brief** - customize parameters if needed
5. Run **Phase 2** - click buttons for each step
6. **Visualize** and **Download** results

---
## Setup: Install Konda and Dependencies

**Konda** is a simple wrapper for conda in Google Colab.
- No kernel restart needed (unlike condacolab)
- Uses conda for package installation
- Installation takes ~5-10 minutes

**API Key**: Set one of the following in Colab secrets:
- `ANTHROPIC_API_KEY` for Claude
- `OPENAI_API_KEY` for GPT-4
- `GOOGLE_API_KEY` for Gemini
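
Outside Colab there is no Secrets panel, so the same keys are read from environment variables. A minimal sketch of that lookup, assuming the priority order of the list above; the helper name `detect_provider` is hypothetical and not part of MDZen:

```python
import os

# Environment variable -> provider label, in the priority order listed above.
API_KEYS = {
    "ANTHROPIC_API_KEY": "anthropic",
    "OPENAI_API_KEY": "openai",
    "GOOGLE_API_KEY": "google",
}

def detect_provider(env=None):
    """Return the provider for the first non-empty key, or None if none is set."""
    env = os.environ if env is None else env
    for key_name, provider in API_KEYS.items():
        if env.get(key_name):
            return provider
    return None

print(detect_provider({"OPENAI_API_KEY": "sk-example"}))  # openai
```

Passing a plain dictionary instead of `os.environ` keeps the check easy to try without touching real credentials.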

In [None]:
#@title ▶️ Run Setup (click to expand code)
import sys
import os
import time

IN_COLAB = 'google.colab' in sys.modules

# Detect and set API keys from Colab secrets
detected_provider = None

if IN_COLAB:
    from google.colab import userdata
    api_keys = {
        'ANTHROPIC_API_KEY': 'anthropic',
        'OPENAI_API_KEY': 'openai',
        'GOOGLE_API_KEY': 'google',
    }
    for key_name, provider in api_keys.items():
        try:
            key_value = userdata.get(key_name)
            if key_value:
                os.environ[key_name] = key_value
                if detected_provider is None:
                    detected_provider = provider
                print(f"✓ {key_name} loaded ({provider})")
        except Exception:
            # Secret is missing or notebook access was not granted; try the next key
            pass
    if detected_provider is None:
        print("⚠️ No API key found! Add to Colab Secrets.")
    else:
        print(f"🤖 Using {detected_provider.upper()}")
else:
    for key_name in ['ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'GOOGLE_API_KEY']:
        if os.environ.get(key_name):
            detected_provider = key_name.split('_')[0].lower()
            break

if IN_COLAB:
    start_time = time.time()
    print("\n📦 Installing Konda...")
    !pip install -q konda
    import konda
    konda.install()

    print("📋 Accepting conda ToS...")
    !conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main 2>/dev/null || true
    !conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r 2>/dev/null || true

    print("🐍 Creating Python 3.11 environment...")
    !conda create -n mdzen python=3.11 -y 2>&1 | tail -3

    print("⚗️ Installing AmberTools (5-10 min)...")
    !conda install -n mdzen -y -c conda-forge ambertools=23 openmm rdkit pdbfixer 2>&1 | tail -5
    print(f"✓ Conda packages ({time.time() - start_time:.0f}s)")

    print("📥 Cloning repository...")
    !rm -rf /content/mdzen
    !git clone -q https://github.com/matsunagalab/mdzen.git /content/mdzen
    %cd /content/mdzen

    # Install Python packages explicitly (no -q to see errors)
    print("📦 Installing Python packages...")
    !pip install gradio py3Dmol nest_asyncio matplotlib mdtraj
    !pip install litellm google-adk google-genai anthropic
    !pip install fastmcp "mcp[cli]" pydantic gemmi parmed httpx pdb2pqr propka dimorphite-dl

    # Verify critical packages
    import importlib
    for pkg in ['litellm', 'gradio', 'google.adk', 'anthropic']:
        try:
            # import_module accepts dotted names, so 'google.adk' is imported as-is
            importlib.import_module(pkg)
            print(f"✓ {pkg}")
        except ImportError as e:
            print(f"❌ {pkg}: {e}")

    # Set environment variables
    os.environ["AMBERHOME"] = "/usr/local/envs/mdzen"
    os.environ["MDZEN_CONDA_ENV"] = "mdzen"
    os.environ["PATH"] = f"/usr/local/envs/mdzen/bin:{os.environ['PATH']}"
    
    # Add mdzen to path (instead of pip install -e .)
    sys.path.insert(0, '/content/mdzen/src')
    sys.path.insert(0, '/content/mdzen')

    print(f"\n✅ Setup complete! ({(time.time() - start_time)/60:.1f} min)")
else:
    sys.path.insert(0, './src')
    sys.path.insert(0, '.')
    print("Local environment")

---
## Phase 1: Clarification Chat

Describe your simulation and the AI agent will ask clarifying questions to generate a **SimulationBrief**.

**Example prompts:**
- "Setup MD for PDB 1AKE in water, 1 ns at 300K"
- "I want to simulate lysozyme (PDB 1LYZ) with explicit solvent"
- "Run a short simulation of insulin (PDB 4INS), chain A only"

After the brief is generated, proceed to the next cell to **review and edit** the parameters.
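
For orientation, the brief is a plain dictionary. The sketch below uses the field names and defaults from the editor cell in this notebook; the values are illustrative, and the brief produced by the agent may carry additional fields.

```python
import json

# Illustrative SimulationBrief for the prompt
# "Setup MD for PDB 1AKE in water, 1 ns at 300K".
# Field names and defaults mirror the editor cell; values are examples only.
example_brief = {
    "pdb_id": "1AKE",
    "alphafold_id": None,
    "fasta_sequence": None,
    "ligand_smiles": None,
    "select_chains": None,
    "temperature": 300.0,        # K
    "pressure_bar": 1.0,         # bar
    "simulation_time_ns": 1.0,   # ns
    "water_model": "tip3p",
    "box_padding": 12.0,         # Å
    "salt_concentration": 0.15,  # M
    "force_field": "ff19SB",
    "ensemble": "NPT",
    "cubic_box": True,
    "ph": 7.4,
}

print(json.dumps(example_brief, indent=2))
```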

In [None]:
#@title ▶️ Phase 1: Clarification Chat
import gradio as gr
import asyncio
import nest_asyncio
import json
import sys
from pathlib import Path

nest_asyncio.apply()

# Global State
if 'mdzen_state' not in dir():
    mdzen_state = {"session_id": None, "session_service": None, "session_dir": None, "simulation_brief": None, "workflow_outputs": {}}

def init_session():
    import random, string
    job_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))
    base_dir = Path("/content/mdzen/outputs") if 'google.colab' in sys.modules else Path("./outputs")
    session_dir = base_dir / f"job_{job_id}"
    session_dir.mkdir(parents=True, exist_ok=True)
    mdzen_state["session_id"] = f"job_{job_id}"
    mdzen_state["session_dir"] = str(session_dir)
    return session_dir

def phase1_chat(message, history):
    import traceback
    try:
        loop = asyncio.get_event_loop()
        if mdzen_state["session_dir"] is None: init_session()
        from mdzen.agents.clarification_agent import create_clarification_agent
        from google.adk.runners import Runner
        from google.genai import types
        from mdzen.state.session_manager import create_session_service, initialize_session_state, get_session_state
        
        if mdzen_state["session_service"] is None:
            db_path = Path(mdzen_state["session_dir"]) / "session.db"
            mdzen_state["session_service"] = create_session_service(str(db_path), in_memory=False)
            loop.run_until_complete(initialize_session_state(
                session_service=mdzen_state["session_service"], app_name="mdzen",
                user_id="default", session_id=mdzen_state["session_id"], session_dir=mdzen_state["session_dir"]))
        
        agent, toolsets = create_clarification_agent()
        runner = Runner(app_name="mdzen", agent=agent, session_service=mdzen_state["session_service"])
        user_message = types.Content(role="user", parts=[types.Part(text=message)])
        response_text = ""
        
        async def run_agent():
            nonlocal response_text
            async for event in runner.run_async(user_id="default", session_id=mdzen_state["session_id"], new_message=user_message):
                if event.is_final_response() and event.content:
                    if hasattr(event.content, 'parts'):
                        for part in event.content.parts:
                            if hasattr(part, 'text'): response_text += part.text
                    else: response_text = str(event.content)
        loop.run_until_complete(run_agent())
        
        state = loop.run_until_complete(get_session_state(mdzen_state["session_service"], "mdzen", "default", mdzen_state["session_id"]))
        if state and state.get("simulation_brief"):
            mdzen_state["simulation_brief"] = state["simulation_brief"]
            brief = mdzen_state["simulation_brief"]
            if isinstance(brief, dict):
                response_text += f"\n\n---\n✅ **SimulationBrief Generated!**\n- PDB: {brief.get('pdb_id', 'N/A')}\n- Temperature: {brief.get('temperature', 300)}K\n- Time: {brief.get('simulation_time_ns', 1.0)}ns\n\n**→ Run next cell to edit brief.**"
        yield response_text if response_text else "Processing..."
        for toolset in toolsets: loop.run_until_complete(toolset.close())
    except Exception as e:
        yield f"Error: {e}\n\n{traceback.format_exc()}"

with gr.Blocks() as phase1_demo:
    gr.Markdown("## 🗣️ Describe Your Simulation")
    gr.ChatInterface(
        fn=phase1_chat,
        type="messages",
        examples=["Setup MD for PDB 1AKE in water, 1 ns at 300K", "Simulate lysozyme (PDB 1LYZ) with explicit solvent"],
    )
phase1_demo.launch(share=True, debug=True)

---
## SimulationBrief Editor

Review and edit the simulation parameters before executing the workflow.

**Fields:**
- **Structure**: PDB ID, AlphaFold ID, or FASTA sequence, plus an optional ligand SMILES string
- **Simulation**: Temperature, pressure, simulation time
- **Solvation**: Water model, box padding, salt concentration
- **Options**: Chain selection, force field

Click **Load** to pull in the brief generated in Phase 1, adjust the fields, then click **Save** and proceed to Phase 2.
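
The sliders constrain most values, but a brief edited by hand can drift outside them. A minimal sketch of a range check, assuming the slider limits from the editor cell as bounds; `validate_brief` and `LIMITS` are hypothetical helpers, not part of MDZen:

```python
# Allowed ranges, taken from the slider limits in the editor cell.
LIMITS = {
    "temperature": (250.0, 400.0),        # K
    "simulation_time_ns": (0.01, 100.0),  # ns
    "box_padding": (8.0, 20.0),           # Å
    "salt_concentration": (0.0, 0.5),     # M
}

def validate_brief(brief):
    """Return a list of human-readable problems; an empty list means the brief passes."""
    problems = []
    if not any(brief.get(k) for k in ("pdb_id", "alphafold_id", "fasta_sequence")):
        problems.append("no structure source (pdb_id, alphafold_id, or fasta_sequence)")
    for field, (low, high) in LIMITS.items():
        value = brief.get(field)
        if value is None:
            continue  # a missing value is left for the workflow defaults
        if not low <= float(value) <= high:
            problems.append(f"{field}={value} is outside [{low}, {high}]")
    return problems

print(validate_brief({"pdb_id": "1AKE", "temperature": 450}))
```

Returning a list rather than raising on the first problem lets all issues be shown at once.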

In [None]:
#@title ▶️ SimulationBrief Editor
import gradio as gr
import json

def load_brief():
    # The key exists but is None until Phase 1 finishes, so fall back to an empty dict
    brief = mdzen_state.get("simulation_brief") or {}
    if isinstance(brief, str): brief = json.loads(brief)
    return (brief.get("pdb_id", ""), brief.get("alphafold_id", ""), brief.get("fasta_sequence", ""),
            brief.get("ligand_smiles", ""), brief.get("select_chains", ""), brief.get("temperature", 300),
            brief.get("pressure_bar", 1.0), brief.get("simulation_time_ns", 1.0), brief.get("water_model", "tip3p"),
            brief.get("box_padding", 12.0), brief.get("salt_concentration", 0.15), brief.get("force_field", "ff19SB"))

def save_brief(pdb_id, alphafold_id, fasta_sequence, ligand_smiles, select_chains,
               temperature, pressure_bar, simulation_time_ns, water_model, box_padding, salt_concentration, force_field):
    brief = {"pdb_id": pdb_id or None, "alphafold_id": alphafold_id or None, "fasta_sequence": fasta_sequence or None,
             "ligand_smiles": ligand_smiles or None, "select_chains": select_chains or None,
             "temperature": float(temperature), "pressure_bar": float(pressure_bar), "simulation_time_ns": float(simulation_time_ns),
             "water_model": water_model, "box_padding": float(box_padding), "salt_concentration": float(salt_concentration),
             "force_field": force_field, "ensemble": "NPT", "cubic_box": True, "ph": 7.4}
    mdzen_state["simulation_brief"] = brief
    return f"✅ Brief saved!\n\n```json\n{json.dumps(brief, indent=2)}\n```\n\n**→ Run next cell to execute workflow.**"

with gr.Blocks() as editor_demo:
    gr.Markdown("## ✏️ Edit SimulationBrief")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Structure")
            pdb_id = gr.Textbox(label="PDB ID", placeholder="e.g., 1AKE")
            alphafold_id = gr.Textbox(label="AlphaFold ID", placeholder="e.g., AF-P00533-F1")
            fasta_sequence = gr.Textbox(label="FASTA Sequence", lines=2)
            ligand_smiles = gr.Textbox(label="Ligand SMILES")
            select_chains = gr.Textbox(label="Select Chains", placeholder="e.g., A,B")
        with gr.Column():
            gr.Markdown("### Parameters")
            temperature = gr.Slider(label="Temperature (K)", minimum=250, maximum=400, value=300, step=5)
            pressure_bar = gr.Number(label="Pressure (bar)", value=1.0)
            simulation_time_ns = gr.Slider(label="Simulation Time (ns)", minimum=0.01, maximum=100, value=1.0, step=0.1)
            water_model = gr.Dropdown(label="Water Model", choices=["tip3p", "tip4pew", "opc", "spce"], value="tip3p")
            box_padding = gr.Slider(label="Box Padding (Å)", minimum=8, maximum=20, value=12, step=1)
            salt_concentration = gr.Slider(label="Salt (M)", minimum=0, maximum=0.5, value=0.15, step=0.01)
            force_field = gr.Dropdown(label="Force Field", choices=["ff19SB", "ff14SB", "ff99SB"], value="ff19SB")
    with gr.Row():
        load_btn = gr.Button("📥 Load", variant="secondary")
        save_btn = gr.Button("💾 Save", variant="primary")
    output = gr.Markdown()
    load_btn.click(load_brief, outputs=[pdb_id, alphafold_id, fasta_sequence, ligand_smiles, select_chains,
                                         temperature, pressure_bar, simulation_time_ns, water_model, box_padding, salt_concentration, force_field])
    save_btn.click(save_brief, inputs=[pdb_id, alphafold_id, fasta_sequence, ligand_smiles, select_chains,
                                        temperature, pressure_bar, simulation_time_ns, water_model, box_padding, salt_concentration, force_field], outputs=output)
editor_demo.launch(share=True, debug=True)

---
## Phase 2: Execute Workflow

Execute the MD workflow step by step:

1. **prepare_complex** - Fetch structure and parameterize ligands
2. **solvate** - Add water box and ions
3. **build_topology** - Generate Amber topology files
4. **run_simulation** - Run MD with OpenMM

Click the buttons in order; each step needs the outputs of the previous one. The result of each step appears in the status panel.
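
The cost of the final step follows from simple arithmetic. The simulation cell uses a 2 fs timestep and saves roughly 100 trajectory frames per run, so the step count and save interval can be estimated up front. A sketch of that calculation, where `md_steps` is a hypothetical helper name:

```python
def md_steps(simulation_time_ns, timestep_fs=2.0):
    """Return (total_steps, report_interval) for a run of the given length.

    1 ns = 1e6 fs, so total_steps = time_ns * 1e6 / timestep_fs.
    Frames are saved about 100 times per run, but never more often
    than every 100 steps.
    """
    total_steps = int(simulation_time_ns * 1e6 / timestep_fs)
    report_interval = max(100, total_steps // 100)
    return total_steps, report_interval

total, interval = md_steps(1.0)
print(total, interval, total // interval)  # 500000 5000 100
```

A 1 ns run is therefore 500,000 integration steps, which is worth keeping in mind on a CPU-only runtime.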

In [None]:
#@title ▶️ Phase 2: Execute Workflow
import gradio as gr
import traceback
from pathlib import Path
import asyncio

def run_prepare_complex():
    try:
        brief = mdzen_state.get("simulation_brief")
        if not brief: return "❌ No brief. Run Phase 1 first."
        session_dir = Path(mdzen_state["session_dir"])
        import importlib, servers.structure_server as mod
        importlib.reload(mod)
        pdb_id = brief.get('pdb_id')
        if not pdb_id: return "❌ No PDB ID"
        loop = asyncio.get_event_loop()
        fetch = loop.run_until_complete(mod.fetch_molecules(pdb_id=pdb_id, source="pdb", prefer_format="pdb", output_dir=str(session_dir)))
        if not fetch["success"]: return f"❌ Fetch failed: {fetch.get('errors')}"
        result = mod.prepare_complex(structure_file=fetch["file_path"], select_chains=brief.get('select_chains'), ph=brief.get('ph', 7.4),
                                     process_proteins=True, process_ligands=True, run_parameterization=True, output_dir=str(session_dir))
        if not result["success"]: return f"❌ Prepare failed: {result.get('errors')}"
        mdzen_state["workflow_outputs"].update({"structure_file": fetch["file_path"], "merged_pdb": result["merged_pdb"], "complex_result": result})
        return f"✅ **prepare_complex**\n- Proteins: {len(result['proteins'])}\n- Ligands: {len(result['ligands'])}"
    except Exception as e: return f"❌ {e}\n{traceback.format_exc()}"

def run_solvate():
    try:
        brief, session_dir = mdzen_state.get("simulation_brief"), Path(mdzen_state["session_dir"])
        merged_pdb = mdzen_state["workflow_outputs"].get("merged_pdb")
        if not merged_pdb: return "❌ Run prepare_complex first"
        import importlib, servers.solvation_server as mod
        importlib.reload(mod)
        result = mod.solvate_structure(pdb_file=str(Path(merged_pdb).resolve()), output_dir=str(session_dir), output_name="solvated",
                                       dist=brief.get('box_padding', 12.0), cubic=brief.get('cubic_box', True), salt=True, saltcon=brief.get('salt_concentration', 0.15))
        if not result["success"]: return f"❌ {result.get('errors')}"
        mdzen_state["workflow_outputs"].update({"solvated_pdb": result["output_file"], "box_dimensions": result.get("box_dimensions")})
        stats = result.get('statistics', {})
        return f"✅ **solvate**\n- Atoms: {stats.get('total_atoms', '?')}\n- Water: {stats.get('water_molecules', '?')}"
    except Exception as e: return f"❌ {e}\n{traceback.format_exc()}"

def run_build_topology():
    try:
        brief, session_dir = mdzen_state.get("simulation_brief"), Path(mdzen_state["session_dir"])
        solvated_pdb = mdzen_state["workflow_outputs"].get("solvated_pdb")
        if not solvated_pdb: return "❌ Run solvate first"
        import importlib, servers.amber_server as mod
        importlib.reload(mod)
        ligand_params = [{"mol2": l["mol2_file"], "frcmod": l["frcmod_file"], "residue_name": l["ligand_id"][:3].upper()}
                         for l in mdzen_state["workflow_outputs"].get("complex_result", {}).get("ligands", []) if l.get("success") and l.get("mol2_file")]
        result = mod.build_amber_system(pdb_file=solvated_pdb, ligand_params=ligand_params or None,
                                        box_dimensions=mdzen_state["workflow_outputs"].get("box_dimensions"),
                                        water_model=brief.get('water_model', 'tip3p'), output_name="system", output_dir=str(session_dir))
        if not result['success']: return f"❌ {result.get('errors')}"
        mdzen_state["workflow_outputs"].update({"parm7": result['parm7'], "rst7": result['rst7']})
        return f"✅ **build_topology**\n- {Path(result['parm7']).name}\n- {Path(result['rst7']).name}"
    except Exception as e: return f"❌ {e}\n{traceback.format_exc()}"

def run_simulation():
    try:
        brief, session_dir = mdzen_state.get("simulation_brief"), Path(mdzen_state["session_dir"])
        parm7, rst7 = mdzen_state["workflow_outputs"].get("parm7"), mdzen_state["workflow_outputs"].get("rst7")
        if not parm7 or not rst7: return "❌ Run build_topology first"
        import openmm as mm
        from openmm import app, unit
        from openmm.app import AmberPrmtopFile, AmberInpcrdFile, Simulation, DCDReporter, PDBFile
        platform, platform_name = None, "CPU"
        for name in ['CUDA', 'OpenCL', 'CPU']:
            try: platform, platform_name = mm.Platform.getPlatformByName(name), name; break
            except Exception: pass  # platform not available; try the next one
        prmtop, inpcrd = AmberPrmtopFile(parm7), AmberInpcrdFile(rst7)
        # pressure_bar is given in bar, so attach unit.bar rather than unit.atmosphere
        temp, pres = brief.get('temperature', 300.0) * unit.kelvin, (brief.get('pressure_bar') or 1.0) * unit.bar
        sim_time = brief.get('simulation_time_ns', 0.1)
        system = prmtop.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=10*unit.angstrom, constraints=app.HBonds, rigidWater=True)
        system.addForce(mm.MonteCarloBarostat(pres, temp, 25))
        integrator = mm.LangevinMiddleIntegrator(temp, 1/unit.picosecond, 2.0*unit.femtoseconds)
        sim = Simulation(prmtop.topology, system, integrator, platform)
        sim.context.setPositions(inpcrd.positions)
        if inpcrd.boxVectors is not None: sim.context.setPeriodicBoxVectors(*inpcrd.boxVectors)
        sim.minimizeEnergy(maxIterations=500)
        sim.context.setVelocitiesToTemperature(temp)
        md_dir = session_dir / "md_simulation"; md_dir.mkdir(exist_ok=True)
        dcd_file = md_dir / "trajectory.dcd"
        total_steps = int(sim_time * 1e6 / 2)
        sim.reporters.append(DCDReporter(str(dcd_file), max(100, total_steps // 100)))
        sim.step(total_steps)
        final_pdb = md_dir / "final_state.pdb"
        with open(final_pdb, 'w') as f: PDBFile.writeFile(sim.topology, sim.context.getState(getPositions=True).getPositions(), f)
        mdzen_state["workflow_outputs"].update({"trajectory": str(dcd_file), "final_pdb": str(final_pdb)})
        return f"✅ **run_simulation**\n- Platform: {platform_name}\n- Time: {sim_time} ns\n\n**→ Run Visualization cell**"
    except Exception as e: return f"❌ {e}\n{traceback.format_exc()}"

with gr.Blocks() as phase2_demo:
    gr.Markdown("## ⚙️ Execute Workflow")
    output = gr.Markdown("Click a button to start...")
    with gr.Row():
        btn1 = gr.Button("1️⃣ prepare_complex", variant="primary")
        btn2 = gr.Button("2️⃣ solvate", variant="primary")
        btn3 = gr.Button("3️⃣ build_topology", variant="primary")
        btn4 = gr.Button("4️⃣ run_simulation", variant="primary")
    btn1.click(run_prepare_complex, outputs=output)
    btn2.click(run_solvate, outputs=output)
    btn3.click(run_build_topology, outputs=output)
    btn4.click(run_simulation, outputs=output)
phase2_demo.launch(share=True, debug=True)

---
## Visualization

Run the cell to load the trajectory and animate up to 20 evenly spaced protein frames with py3Dmol.
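
Animating every saved frame makes the viewer sluggish, so the cell keeps at most 20 evenly spaced frames. The selection can be sketched in plain Python as follows; `pick_frames` is a hypothetical helper, and the cell itself uses `np.linspace` for the same purpose:

```python
def pick_frames(n_frames, max_frames=20):
    """Return up to max_frames evenly spaced frame indices, including first and last."""
    if max_frames < 2:
        raise ValueError("max_frames must be at least 2")
    if n_frames <= max_frames:
        return list(range(n_frames))
    last = n_frames - 1
    # Integer arithmetic avoids floating-point rounding surprises.
    return [(i * last) // (max_frames - 1) for i in range(max_frames)]

print(pick_frames(100, max_frames=5))  # [0, 24, 49, 74, 99]
```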

In [None]:
#@title ▶️ Visualization
import py3Dmol
import numpy as np
import tempfile
from pathlib import Path

if not mdzen_state.get("workflow_outputs") or 'trajectory' not in mdzen_state["workflow_outputs"]:
    print("❌ No trajectory. Complete workflow first.")
else:
    import mdtraj as md
    traj = md.load(mdzen_state["workflow_outputs"]['trajectory'], top=mdzen_state["workflow_outputs"]['parm7'])
    protein = traj.atom_slice(traj.topology.select('protein'))
    frames = protein[np.linspace(0, protein.n_frames-1, min(20, protein.n_frames), dtype=int)] if protein.n_frames > 20 else protein
    
    with tempfile.NamedTemporaryFile(suffix='.pdb', delete=False, mode='w') as tmp:
        for i in range(frames.n_frames):
            ftmp = f"{tmp.name}.f{i}.pdb"
            frames[i].save_pdb(ftmp, force_overwrite=True)
            tmp.write(f"MODEL {i+1}\n")
            tmp.write(''.join(l for l in open(ftmp) if not l.startswith(('MODEL','ENDMDL')) and l.strip()))
            tmp.write("ENDMDL\n")
            Path(ftmp).unlink()
        tmp_path = tmp.name
    
    view = py3Dmol.view(width=800, height=500)
    view.addModelsAsFrames(open(tmp_path).read(), 'pdb')
    view.setStyle({'cartoon': {'color': 'spectrum'}})
    view.zoomTo()
    view.animate({'loop': 'forward', 'reps': 0, 'interval': 100})
    Path(tmp_path).unlink()
    print(f"🎬 {frames.n_frames} frames | {traj.time[-1]:.1f} ps")
    view.show()

---
## Download Results

In [None]:
#@title ▶️ Download Results
import sys
from pathlib import Path

if mdzen_state.get("session_dir"):
    session_dir = Path(mdzen_state["session_dir"])
    if session_dir.exists():
        print(f"📂 {session_dir}\n")
        for f in sorted(session_dir.rglob('*')):
            if f.is_file(): print(f"  {f.relative_to(session_dir)} ({f.stat().st_size/1024:.1f} KB)")
        if 'google.colab' in sys.modules:
            from google.colab import files
            import shutil
            shutil.make_archive(str(session_dir), 'zip', session_dir)
            print(f"\n⬇️ Downloading {session_dir.name}.zip...")
            files.download(f"{session_dir}.zip")
    else: print("❌ Session not found")
else: print("❌ Run workflow first")

---

## Next Steps

1. **Longer simulations**: Increase the simulation time in your Phase 1 request or in the brief editor
2. **Analysis**: Use MDTraj for RMSD, RMSF, hydrogen bonds, etc.
3. **Different systems**: Try membrane proteins, protein-ligand complexes
4. **Command line**: Use `main.py run` for local development
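
As a starting point for the analysis item above, backbone RMSD relative to the first frame takes a few lines with MDTraj. This is a sketch under assumptions: the helper names are hypothetical, the file paths are the `trajectory` and `parm7` entries stored in `mdzen_state["workflow_outputs"]`, and MDTraj reports distances in nanometres.

```python
def nm_to_angstrom(values_nm):
    """Convert a sequence of lengths from nanometres to ångströms."""
    return [10.0 * v for v in values_nm]

def backbone_rmsd_angstrom(traj_path, top_path):
    """Backbone RMSD of every frame against frame 0, after optimal superposition."""
    import mdtraj as md  # imported lazily so the helper above works without MDTraj

    traj = md.load(traj_path, top=top_path)
    backbone = traj.topology.select("backbone")
    rmsd_nm = md.rmsd(traj, traj, frame=0, atom_indices=backbone)
    return nm_to_angstrom(rmsd_nm)

# Example call after Phase 2 has finished (paths come from the workflow state):
# outputs = mdzen_state["workflow_outputs"]
# rmsd = backbone_rmsd_angstrom(outputs["trajectory"], outputs["parm7"])
```

Loading the full solvated trajectory can be memory-hungry for long runs; restricting the load to protein atoms is a common refinement.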

For more information, see the [GitHub repository](https://github.com/matsunagalab/mdzen).