<a href="https://colab.research.google.com/github/matsunagalab/mcp-md/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MCP-MD: AI-Powered Molecular Dynamics Setup

This notebook demonstrates the complete MD simulation workflow using MCP-MD:

**Workflow Steps:**
1. **Fetch** - Download structure from RCSB PDB
2. **Prepare** - Clean protein + parameterize ligands (GAFF2/AM1-BCC)
3. **Solvate** - Add water box + ions (packmol-memgen)
4. **Build** - Generate Amber topology (tleap)
5. **Simulate** - Run MD with OpenMM (NPT ensemble)
6. **Visualize** - Interactive 3D trajectory animation

**Test Case: 1AKE (Adenylate Kinase)**
- Two chains (A and B), each with a bound AP5 ligand
- Chain A extraction + ligand parameterization
- Short NPT simulation (100 ps)

---

## Quick Start

1. **Cell 1**: Install condacolab (triggers runtime restart)
2. **Cell 2**: Install dependencies (run after restart)
3. **Cell 3+**: Run workflow cells in order

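**Runtime**: For faster simulation, switch to a GPU runtime via `Runtime > Change runtime type` before running Cell 1. Changing the runtime type later starts a fresh runtime, so setup would have to be repeated.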

---
## Setup 1/2: Install condacolab

**The runtime will restart after this cell. This is expected!**

In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Installing condacolab...")
    print("The runtime will restart. Run the next cell after restart.")
    !pip install -q condacolab
    import condacolab
    condacolab.install()
else:
    print("Not running in Colab - skipping condacolab setup")
    print("Make sure you have a conda environment with AmberTools installed.")

---
## Setup 2/2: Install Dependencies

**Run this cell AFTER the runtime restarts.**

Installs AmberTools, OpenMM, RDKit, and project dependencies (~5-10 min)

In [None]:
import sys
import time

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    import condacolab
    condacolab.check()
    
    start_time = time.time()
    
    # Install conda packages (AmberTools + heavy scientific packages)
    print("="*60)
    print("Installing AmberTools + scientific packages via conda...")
    print("This takes ~5-10 minutes.")
    print("="*60)
    !conda install -y -c conda-forge ambertools=23 openmm rdkit pdbfixer 2>&1 | tail -20
    print(f"Conda packages installed ({time.time() - start_time:.0f}s)")
    
    # Clone repository and install
    print("\nCloning mcp-md repository...")
    !git clone -q https://github.com/matsunagalab/mcp-md.git /content/mcp-md
    %cd /content/mcp-md
    
    print("Installing Python dependencies...")
    !pip install -q -e .
    !pip install -q py3Dmol mdtraj
    
    sys.path.insert(0, '/content/mcp-md')
    
    total_time = time.time() - start_time
    print(f"\nSetup complete! ({total_time/60:.1f} minutes)")
    print("="*60)

else:
    # Local development
    sys.path.insert(0, '.')
    print("Local environment - dependencies should be pre-installed.")

---
## Verify Installation

In [None]:
import shutil

print("Checking dependencies...")
print()

# Python packages
print("Python Packages:")
packages = ["rdkit", "pdbfixer", "openmm", "py3Dmol", "mdtraj", "dimorphite_dl"]
for pkg in packages:
    try:
        __import__(pkg)
        print(f"  {pkg}")
    except ImportError:
        print(f"  {pkg} (MISSING)")

# AmberTools
print("\nAmberTools:")
for tool in ["antechamber", "tleap", "pdb4amber", "packmol-memgen"]:
    path = shutil.which(tool)
    print(f"  {tool}" if path else f"  {tool} (MISSING)")

# OpenMM platforms
print("\nOpenMM Platforms:")
import openmm as mm
for i in range(mm.Platform.getNumPlatforms()):
    name = mm.Platform.getPlatform(i).getName()
    gpu = " (GPU)" if name in ['CUDA', 'OpenCL'] else ""
    print(f"  {name}{gpu}")

---
## Configuration

Configure the target structure and output directory.

In [None]:
import sys
from pathlib import Path

# Target structure
PDB_ID = "1AKE"           # Adenylate kinase (chains A and B)
SELECT_CHAINS = ["A"]     # Chain A (protein + AP5 ligand)

# Output directory
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    OUTPUT_DIR = Path("/content/mcp-md/output")
else:
    OUTPUT_DIR = Path("./output")
OUTPUT_DIR.mkdir(exist_ok=True)

print(f"Target: PDB {PDB_ID}, Chain(s) {SELECT_CHAINS}")
print(f"Output: {OUTPUT_DIR}")
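
Before fetching, a quick sanity check can catch typos in `PDB_ID` and `SELECT_CHAINS`. The sketch below is a hypothetical helper (`check_config` is not part of MCP-MD). It assumes classic four-character PDB IDs (a digit 1-9 followed by three letters or digits) and single-character chain IDs as stored in PDB-format files, so extended IDs and multi-character mmCIF chain names would need a different check.

```python
import re

# Classic PDB ID: one digit 1-9 followed by three alphanumerics (e.g. "1AKE").
# Assumption: extended IDs and multi-character mmCIF chain names are out of scope.
PDB_ID_PATTERN = re.compile(r"[1-9][A-Za-z0-9]{3}")


def check_config(pdb_id: str, chains: list) -> tuple:
    """Hypothetical helper: validate and normalize PDB_ID / SELECT_CHAINS."""
    if not PDB_ID_PATTERN.fullmatch(pdb_id):
        raise ValueError(f"Not a four-character PDB ID: {pdb_id!r}")
    if not chains:
        raise ValueError("Select at least one chain")
    for chain in chains:
        # PDB-format files store the chain ID in a single column (column 22);
        # chain IDs are case-sensitive, so they are not upper-cased
        if len(chain) != 1:
            raise ValueError(f"Chain IDs must be single characters: {chain!r}")
    return pdb_id.upper(), list(chains)


print(check_config("1ake", ["A"]))  # → ('1AKE', ['A'])
```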

---
# Step 1: Fetch Structure from PDB

Download the structure from the RCSB PDB database.
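
`fetch_molecules` handles the download in this notebook, and its internals may differ from what follows. For reference, a direct download from RCSB could look like the sketch below. It assumes the public `files.rcsb.org/download` endpoint and uses only the standard library; `rcsb_url` and `download_structure` are illustrative names. Some very large entries are distributed only as mmCIF, in which case `fmt="cif"` is needed.

```python
import urllib.request
from pathlib import Path

# RCSB file-download endpoint; "pdb" (legacy format) and "cif" (mmCIF) are supported here
RCSB_DOWNLOAD_URL = "https://files.rcsb.org/download/{pdb_id}.{ext}"


def rcsb_url(pdb_id: str, fmt: str = "pdb") -> str:
    """Build the download URL for a PDB entry in PDB or mmCIF format."""
    if fmt not in ("pdb", "cif"):
        raise ValueError(f"Unsupported format: {fmt!r}")
    return RCSB_DOWNLOAD_URL.format(pdb_id=pdb_id.upper(), ext=fmt)


def download_structure(pdb_id: str, out_dir: Path, fmt: str = "pdb") -> Path:
    """Save the entry into out_dir (requires network access)."""
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"{pdb_id.upper()}.{fmt}"
    urllib.request.urlretrieve(rcsb_url(pdb_id, fmt), str(target))
    return target


print(rcsb_url("1ake"))  # → https://files.rcsb.org/download/1AKE.pdb
```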

In [None]:
import importlib
import servers.structure_server as structure_module
importlib.reload(structure_module)

fetch_molecules = structure_module.fetch_molecules

print(f"Fetching {PDB_ID} from RCSB PDB...")
fetch_result = await fetch_molecules(pdb_id=PDB_ID, source="pdb", prefer_format="pdb")

if fetch_result["success"]:
    structure_file = fetch_result["file_path"]
    print(f"  File: {Path(structure_file).name}")
    print(f"  Atoms: {fetch_result['num_atoms']}")
    print(f"  Chains: {fetch_result['chains']}")
else:
    raise RuntimeError(f"Failed: {fetch_result['errors']}")

### Visualize: Original Structure

In [None]:
import py3Dmol

with open(structure_file, 'r') as f:
    pdb_content = f.read()

view = py3Dmol.view(width=800, height=500)
view.addModel(pdb_content, 'pdb')

# Protein: cartoon, colored by chain
view.setStyle({'chain': 'A'}, {'cartoon': {'color': 'blue'}})
view.setStyle({'chain': 'B'}, {'cartoon': {'color': 'orange'}})

# Ligands: sticks
view.setStyle({'resn': 'AP5'}, {'stick': {'color': 'green', 'radius': 0.3}})
view.addResLabels({'resn': 'AP5'}, {'fontSize': 12, 'fontColor': 'white',
                                     'backgroundColor': 'green', 'backgroundOpacity': 0.8})

view.zoomTo()
print(f"Original structure: Chain A (blue), Chain B (orange), AP5 ligands (green)")
view.show()

---
# Step 2: Prepare Complex

Clean protein + parameterize ligands with GAFF2/AM1-BCC.

This step:
1. Extracts selected chain(s)
2. Repairs missing atoms (PDBFixer)
3. Protonates at pH 7.4
4. Parameterizes ligands (antechamber)
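
The chain-selection step works on fixed-column PDB records: the chain identifier sits in column 22 of every `ATOM`/`HETATM` line. The sketch below illustrates only that selection, not the PDBFixer repair, protonation, or antechamber stages, and it is not the code inside `prepare_complex`. `extract_chains` is a hypothetical name, and dropping crystallographic waters (`HOH`) by default is an assumption of this sketch.

```python
def extract_chains(pdb_lines, chains, keep_water=False):
    """Keep ATOM/HETATM records whose chain ID is in `chains` (fixed-column PDB format)."""
    selected = []
    for line in pdb_lines:
        record = line[:6].strip()           # columns 1-6: record name
        if record not in ("ATOM", "HETATM") or len(line) < 22:
            continue
        if line[21] not in chains:          # column 22: chain identifier
            continue
        resname = line[17:20].strip()       # columns 18-20: residue name
        if resname == "HOH" and not keep_water:
            continue                        # assumption: crystallographic waters are dropped
        selected.append(line.rstrip("\n"))
    selected.append("END")
    return selected
```

Applied to the fetched file (for example `extract_chains(Path(structure_file).read_text().splitlines(), SELECT_CHAINS)`), it would keep the chain A protein atoms plus any `HETATM` ligand records assigned to chain A.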

In [None]:
prepare_complex = structure_module.prepare_complex

print(f"Preparing complex (Chain {SELECT_CHAINS})...")
print("This may take 2-5 minutes for ligand parameterization.")
print()

complex_result = prepare_complex(
    structure_file=structure_file,
    select_chains=SELECT_CHAINS,
    ph=7.4,
    process_proteins=True,
    process_ligands=True,
    run_parameterization=True
)

if complex_result["success"]:
    output_dir = Path(complex_result["output_dir"])
    print(f"Output: {output_dir}")
    
    print(f"\nProteins ({len(complex_result['proteins'])}):")
    for p in complex_result["proteins"]:
        status = "" if p["success"] else " (FAILED)"
        print(f"  Chain {p['chain_id']}: {Path(p['output_file']).name}{status}")
    
    print(f"\nLigands ({len(complex_result['ligands'])}):")
    for lig in complex_result["ligands"]:
        if lig["success"]:
            print(f"  {lig['ligand_id']}: charge={lig['net_charge']}")
            print(f"    mol2: {Path(lig['mol2_file']).name}")
            print(f"    frcmod: {Path(lig['frcmod_file']).name}")
        else:
            print(f"  {lig.get('ligand_id', '?')}: FAILED")
    
    merged_pdb = complex_result["merged_pdb"]
    print(f"\nMerged: {Path(merged_pdb).name}")
else:
    raise RuntimeError(f"Failed: {complex_result['errors']}")

### Visualize: Prepared Complex

In [None]:
with open(merged_pdb, 'r') as f:
    pdb_content = f.read()

view = py3Dmol.view(width=800, height=500)
view.addModel(pdb_content, 'pdb')

# Protein: cartoon spectrum
AMINO_ACIDS = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU',
               'GLY', 'HIS', 'HID', 'HIE', 'HIP', 'ILE', 'LEU', 'LYS',
               'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL']
view.setStyle({'resn': AMINO_ACIDS}, {'cartoon': {'color': 'spectrum'}})

# Ligands: sticks
ligand_resnames = [lig['ligand_id'] for lig in complex_result['ligands'] if lig['success']]
for resn in ligand_resnames:
    view.setStyle({'resn': resn}, {'stick': {'color': 'green', 'radius': 0.3}})
    view.addResLabels({'resn': resn}, {'fontSize': 12, 'fontColor': 'white',
                                        'backgroundColor': 'green', 'backgroundOpacity': 0.8})

view.zoomTo()
print(f"Prepared complex: Protein (spectrum), Ligands {ligand_resnames} (green)")
view.show()

---
# Step 3: Solvate Structure

Add an explicit water box (cubic, 12 Å padding around the solute) and ions (0.15 M NaCl).
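
The padding and salt concentration translate into a box size and an approximate ion count via N ≈ c · N_A · V, with V in litres (1 Å³ = 10⁻²⁷ L). This is back-of-the-envelope arithmetic under stated assumptions: the solute extents below are made-up numbers, and packmol-memgen's actual algorithm may differ (for example, by basing the ion count on the solvent volume rather than the full box, and by adding counter-ions to neutralize the net charge). The helper names are hypothetical.

```python
AVOGADRO = 6.02214076e23           # particles per mol
LITRE_PER_CUBIC_ANGSTROM = 1e-27   # 1 Å^3 = 1e-30 m^3 = 1e-27 L


def cubic_box_edge(solute_extent, padding):
    """Edge (Å) of a cubic box leaving `padding` Å on both sides of the largest solute dimension."""
    return max(solute_extent) + 2.0 * padding


def salt_pairs(box_edge, molar):
    """Approximate Na+/Cl- pairs for a `molar` concentration over the whole box volume."""
    volume_litre = box_edge ** 3 * LITRE_PER_CUBIC_ANGSTROM
    return round(molar * AVOGADRO * volume_litre)


edge = cubic_box_edge((56.0, 48.0, 40.0), padding=12.0)  # hypothetical solute extents (Å)
print(edge, salt_pairs(edge, 0.15))  # → 80.0 46
```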

In [None]:
import os
import json
import subprocess

# Set AMBERHOME for packmol-memgen (Colab workaround)
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    conda_info = json.loads(subprocess.run(['conda', 'info', '--json'], 
                                           capture_output=True, text=True).stdout)
    for env in conda_info.get('envs', []):
        if 'mcp-md' in env or env == conda_info.get('default_prefix'):
            os.environ["AMBERHOME"] = env
            break

import servers.solvation_server as solvation_module
importlib.reload(solvation_module)

solvate_structure = solvation_module.solvate_structure

print("Solvating structure (cubic box, 12 A padding, 0.15 M NaCl)...")
solvate_result = solvate_structure(
    pdb_file=str(Path(merged_pdb).resolve()),
    output_dir=str(output_dir.resolve()),
    output_name="solvated",
    dist=12.0,
    cubic=True,
    salt=True,
    saltcon=0.15
)

if solvate_result["success"]:
    solvated_pdb = solvate_result["output_file"]
    box = solvate_result.get("box_dimensions", {})
    stats = solvate_result.get("statistics", {})
    print(f"  Output: {Path(solvated_pdb).name}")
    print(f"  Atoms: {stats.get('total_atoms', 'N/A')}")
    if box:
        print(f"  Box: {box.get('box_a', 0):.1f} x {box.get('box_b', 0):.1f} x {box.get('box_c', 0):.1f} A")
else:
    raise RuntimeError(f"Failed: {solvate_result['errors']}")

### Visualize: Solvated System

In [None]:
with open(solvated_pdb, 'r') as f:
    pdb_content = f.read()

view = py3Dmol.view(width=900, height=600)
view.addModel(pdb_content, 'pdb')

# Protein: cartoon
view.setStyle({'resn': AMINO_ACIDS}, {'cartoon': {'color': 'spectrum'}})

# Water: small spheres
view.setStyle({'resn': ['WAT', 'HOH']}, {'sphere': {'radius': 0.15, 'color': 'lightblue'}})

# Ions: spheres
view.setStyle({'resn': ['NA', 'Na+']}, {'sphere': {'radius': 0.8, 'color': 'purple'}})
view.setStyle({'resn': ['CL', 'Cl-']}, {'sphere': {'radius': 0.8, 'color': 'yellow'}})

# Ligands: sticks
for resn in ligand_resnames:
    view.setStyle({'resn': resn}, {'stick': {'color': 'green', 'radius': 0.3}})

view.zoomTo()
print(f"Solvated: Protein (spectrum), Water (blue dots), Na+ (purple), Cl- (yellow)")
view.show()

---
# Step 4: Build Amber System

Generate Amber topology (parm7) and coordinates (rst7) using tleap.
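
`build_amber_system` generates the tleap input internally, and its exact contents are not shown in this notebook. A plausible input for this kind of system is sketched below. The force-field choices (`leaprc.protein.ff14SB`, `leaprc.gaff2`, `leaprc.water.tip3p`) are assumptions, and `leap_script` is a hypothetical helper; the ligand dictionaries use the same keys as the `ligand_params` list built in the cell below.

```python
def leap_script(pdb_file, ligands, box, out_prefix="system",
                protein_ff="leaprc.protein.ff14SB", water_ff="leaprc.water.tip3p"):
    """Compose a tleap input for a solvated protein-ligand PDB (force fields are assumptions)."""
    lines = [f"source {protein_ff}", "source leaprc.gaff2", f"source {water_ff}"]
    for lig in ligands:
        # frcmod supplies missing GAFF2 terms; the mol2 defines a unit named after the residue
        lines.append(f"loadamberparams {lig['frcmod']}")
        lines.append(f"{lig['residue_name']} = loadmol2 {lig['mol2']}")
    lines.append(f"mol = loadpdb {pdb_file}")
    a, b, c = box
    lines.append(f"set mol box {{{a:.3f} {b:.3f} {c:.3f}}}")  # periodic box edges in Å
    lines.append(f"saveamberparm mol {out_prefix}.parm7 {out_prefix}.rst7")
    lines.append("quit")
    return "\n".join(lines) + "\n"


script = leap_script("solvated.pdb",
                     [{"mol2": "AP5.mol2", "frcmod": "AP5.frcmod", "residue_name": "AP5"}],
                     box=(80.0, 80.0, 80.0))
print(script)
```

Defining each ligand unit under its residue name (`AP5 = loadmol2 ...`) is what lets `loadpdb` match the `AP5` residues in the PDB to the GAFF2 template. Saved as `leap.in`, such a file would be run with `tleap -f leap.in` to write the parm7/rst7 pair.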

In [None]:
import servers.amber_server as amber_module
importlib.reload(amber_module)

build_amber_system = amber_module.build_amber_system

# Collect ligand parameters
ligand_params = []
for lig in complex_result.get("ligands", []):
    if lig.get("success") and lig.get("mol2_file"):
        ligand_params.append({
            "mol2": lig["mol2_file"],
            "frcmod": lig["frcmod_file"],
            "residue_name": lig["ligand_id"][:3].upper()
        })

print(f"Building Amber system ({len(ligand_params)} ligand(s))...")
amber_result = build_amber_system(
    pdb_file=solvate_result["output_file"],
    ligand_params=ligand_params if ligand_params else None,
    box_dimensions=solvate_result.get("box_dimensions"),
    water_model="tip3p",
    output_name="system"
)

if amber_result['success']:
    parm7_file = amber_result['parm7']
    rst7_file = amber_result['rst7']
    print(f"  Topology: {Path(parm7_file).name}")
    print(f"  Coordinates: {Path(rst7_file).name}")
else:
    raise RuntimeError(f"Failed: {amber_result['errors']}")

---
# Step 5: Run MD Simulation

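Run an NPT simulation with OpenMM:
- Energy minimization (up to 500 iterations)
- NPT equilibration (5 ps)
- NPT production (100 ps)

Parameters: 300 K, 1 atm (Monte Carlo barostat), 2 fs timestep, Langevin friction 1/ps, PME with 10 Å cutoff, hydrogen-bond constraints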
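
The step counts in the cell below follow from dividing each simulated time by the timestep (steps = time / Δt). A small helper, hypothetical and not part of MCP-MD, makes that conversion explicit and rejects durations that are not a whole number of steps:

```python
def md_schedule(timestep_fs, equil_ps, prod_ps, report_ps):
    """Convert simulated times (ps) into integer step counts for a fixed timestep (fs)."""
    def to_steps(ps):
        n = ps * 1000.0 / timestep_fs
        if abs(n - round(n)) > 1e-9:
            raise ValueError(f"{ps} ps is not a whole number of {timestep_fs} fs steps")
        return round(n)

    equil, prod, report = to_steps(equil_ps), to_steps(prod_ps), to_steps(report_ps)
    return {
        "equil_steps": equil,
        "prod_steps": prod,
        "report_interval": report,
        # In the Step 5 cell the DCD reporter is attached before equilibration,
        # so the trajectory records frames from both phases
        "dcd_frames": (equil + prod) // report,
    }


print(md_schedule(timestep_fs=2.0, equil_ps=5, prod_ps=100, report_ps=1))
# → {'equil_steps': 2500, 'prod_steps': 50000, 'report_interval': 500, 'dcd_frames': 105}
```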

In [None]:
import sys
import time
import openmm as mm
from openmm import app, unit
from openmm.app import AmberPrmtopFile, AmberInpcrdFile, Simulation, DCDReporter, StateDataReporter, PDBFile

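def select_platform():
    """Return the fastest platform that can actually create a Context on this machine."""
    for name in ['CUDA', 'OpenCL', 'CPU']:
        try:
            platform = mm.Platform.getPlatformByName(name)
            # A plugin can load even when no usable device exists, so probe with a 1-particle Context
            probe_system = mm.System()
            probe_system.addParticle(1.0)
            probe_integrator = mm.VerletIntegrator(0.001)
            probe_context = mm.Context(probe_system, probe_integrator, platform)
            del probe_context, probe_integrator
            return platform, name
        except Exception:
            continue
    # The Reference platform is always available
    return mm.Platform.getPlatformByName('Reference'), 'Reference'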

# Simulation parameters
temperature = 300 * unit.kelvin
pressure = 1 * unit.atmosphere
timestep = 2 * unit.femtoseconds
minimize_steps = 500
equil_steps = 2500      # 5 ps
prod_steps = 50000      # 100 ps
report_interval = 500   # Every 1 ps

platform, platform_name = select_platform()
print(f"Platform: {platform_name}")
print(f"Simulation: {(equil_steps + prod_steps) * 2 / 1000:.0f} ps total")
print()

# Load system
print("[1/5] Loading Amber files...")
prmtop = AmberPrmtopFile(parm7_file)
inpcrd = AmberInpcrdFile(rst7_file)
print(f"  Atoms: {prmtop.topology.getNumAtoms()}")

# Create system
print("[2/5] Creating system...")
system = prmtop.createSystem(
    nonbondedMethod=app.PME,
    nonbondedCutoff=10 * unit.angstrom,
    constraints=app.HBonds,
    rigidWater=True
)
system.addForce(mm.MonteCarloBarostat(pressure, temperature, 25))

integrator = mm.LangevinMiddleIntegrator(temperature, 1/unit.picosecond, timestep)
simulation = Simulation(prmtop.topology, system, integrator, platform)
simulation.context.setPositions(inpcrd.positions)
if inpcrd.boxVectors:
    simulation.context.setPeriodicBoxVectors(*inpcrd.boxVectors)

# Minimize
print("[3/5] Energy minimization...")
t0 = time.time()
simulation.minimizeEnergy(maxIterations=minimize_steps)
print(f"  Done ({time.time() - t0:.1f}s)")

simulation.context.setVelocitiesToTemperature(temperature)

# Setup reporters
dcd_file = output_dir / "trajectory.dcd"
log_file = output_dir / "simulation.log"
simulation.reporters.append(DCDReporter(str(dcd_file), report_interval))
simulation.reporters.append(StateDataReporter(
    str(log_file), report_interval,
    step=True, time=True, potentialEnergy=True, temperature=True, speed=True
))
simulation.reporters.append(StateDataReporter(
    sys.stdout, report_interval * 10,
    step=True, time=True, temperature=True, speed=True, remainingTime=True,
    totalSteps=equil_steps + prod_steps
))

# Equilibration
print(f"[4/5] Equilibration ({equil_steps * 2 / 1000:.0f} ps)...")
t0 = time.time()
simulation.step(equil_steps)
print(f"  Done ({time.time() - t0:.1f}s)")

# Production
print(f"[5/5] Production ({prod_steps * 2 / 1000:.0f} ps)...")
t0 = time.time()
simulation.step(prod_steps)
print(f"  Done ({time.time() - t0:.1f}s)")

# Save final state
final_pdb = output_dir / "final_state.pdb"
state = simulation.context.getState(getPositions=True)
with open(final_pdb, 'w') as f:
    PDBFile.writeFile(simulation.topology, state.getPositions(), f)

print(f"\nSimulation complete!")
print(f"  Trajectory: {dcd_file.name}")
print(f"  Final state: {final_pdb.name}")

---
# Step 6: Visualize Trajectory

Interactive 3D animation of the MD trajectory (protein + ligand only).
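
The cell below keeps protein and ligand atoms, subsamples at most 20 frames, and stitches per-frame PDB text into one multi-model file for `addModelsAsFrames`. A standard-library sketch of the last two steps follows. The helper names are hypothetical, and keeping only `ATOM`/`HETATM`/`TER` records (dropping `CONECT`) is a simplification of this sketch, so py3Dmol would have to infer bonds from interatomic distances.

```python
def sample_frames(n_frames, max_frames):
    """Evenly spaced frame indices that always include the first and last frame."""
    if n_frames <= max_frames:
        return list(range(n_frames))
    # Integer arithmetic avoids float round-off; similar to np.linspace(..., dtype=int)
    return [i * (n_frames - 1) // (max_frames - 1) for i in range(max_frames)]


def to_multimodel(frame_pdbs):
    """Wrap single-frame PDB texts in MODEL/ENDMDL blocks, one model per frame."""
    out = []
    for number, text in enumerate(frame_pdbs, start=1):
        out.append(f"MODEL     {number:>4}")  # model serial number in columns 11-14
        for line in text.splitlines():
            if line.startswith(("ATOM", "HETATM", "TER")):
                out.append(line)  # per-frame MODEL/ENDMDL/CONECT/END records are dropped
        out.append("ENDMDL")
    out.append("END")
    return "\n".join(out) + "\n"


print(sample_frames(105, 20)[:5], sample_frames(105, 20)[-1])  # → [0, 5, 10, 16, 21] 104
```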

In [None]:
import numpy as np
import tempfile
import mdtraj as md

print("Loading trajectory...")
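# Topology from final_state.pdb, written by OpenMM from the simulated topology, so its atom
# order matches the DCD (tleap's atom order may differ from the solvated input PDB)
traj = md.load(str(dcd_file), top=str(final_pdb))
frame_ps = report_interval * timestep.value_in_unit(unit.picoseconds)
print(f"  Frames: {traj.n_frames} (one every {frame_ps:g} ps, equilibration included)")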

# Select protein + ligand (exclude water/ions)
protein_indices = traj.topology.select('protein')
lig_indices = []
ligand_resnames = set()
standard_res = {'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU',
                'GLY', 'HIS', 'HID', 'HIE', 'HIP', 'ILE', 'LEU', 'LYS',
                'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
                'WAT', 'HOH', 'NA', 'CL', 'Na+', 'Cl-'}
for residue in traj.topology.residues:
    if residue.name not in standard_res:
        lig_indices.extend([atom.index for atom in residue.atoms])
        ligand_resnames.add(residue.name)
lig_indices = np.array(lig_indices) if lig_indices else np.array([], dtype=int)

keep_indices = np.unique(np.concatenate([protein_indices, lig_indices])) if len(lig_indices) > 0 else protein_indices
traj_subset = traj.atom_slice(keep_indices)

# Sample frames for visualization
max_frames = 20
if traj_subset.n_frames > max_frames:
    frame_indices = np.linspace(0, traj_subset.n_frames - 1, max_frames, dtype=int)
    traj_viz = traj_subset[frame_indices]
else:
    traj_viz = traj_subset

print(f"  Visualization: {traj_viz.n_atoms} atoms, {traj_viz.n_frames} frames")

# Write multi-model PDB
with tempfile.NamedTemporaryFile(suffix='.pdb', delete=False, mode='w') as tmp:
    tmp_path = tmp.name

with open(tmp_path, 'w') as f:
    for i in range(traj_viz.n_frames):
        frame_tmp = tmp_path + f".frame{i}.pdb"
        traj_viz[i].save_pdb(frame_tmp, force_overwrite=True)
        with open(frame_tmp, 'r') as ff:
            content = ff.read()
        f.write(f"MODEL     {i + 1}\n")
        for line in content.split('\n'):
            if not line.startswith('MODEL') and not line.startswith('ENDMDL') and line.strip():
                f.write(line + '\n')
        f.write("ENDMDL\n")
        Path(frame_tmp).unlink()

with open(tmp_path, 'r') as f:
    pdb_content = f.read()
Path(tmp_path).unlink()

# Create animated view
view = py3Dmol.view(width=800, height=600)
view.addModelsAsFrames(pdb_content, 'pdb')

aa_list = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU',
           'GLY', 'HIS', 'HID', 'HIE', 'HIP', 'ILE', 'LEU', 'LYS',
           'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL']
view.setStyle({'resn': aa_list}, {'cartoon': {'color': 'spectrum'}})

for resn in sorted(ligand_resnames):
    view.setStyle({'resn': resn}, {'stick': {'color': 'green', 'radius': 0.3}})
    view.addResLabels({'resn': resn}, {'fontSize': 12, 'fontColor': 'white',
                                        'backgroundColor': 'green', 'backgroundOpacity': 0.8})

view.zoomTo()
view.animate({'loop': 'forward', 'reps': 0, 'interval': 100})

print(f"\nAnimation: {traj_viz.n_frames} frames")
print(f"Protein (spectrum), Ligands {list(ligand_resnames)} (green)")
view.show()

---
# Summary

Complete MD workflow executed successfully!

In [None]:
print("="*60)
print("MCP-MD WORKFLOW COMPLETE")
print("="*60)

print(f"\nTarget: PDB {PDB_ID}, Chain(s) {SELECT_CHAINS}")
print(f"Output: {output_dir}")

print(f"\nGenerated Files:")
files = [
    ("Structure", Path(structure_file).name),
    ("Merged complex", Path(merged_pdb).name),
    ("Solvated system", Path(solvated_pdb).name),
    ("Amber topology", Path(parm7_file).name),
    ("Amber coordinates", Path(rst7_file).name),
    ("Trajectory", dcd_file.name),
    ("Final state", final_pdb.name),
]
for label, fname in files:
    print(f"  {label}: {fname}")

print(f"\nLigand Parameters:")
for lig in complex_result.get("ligands", []):
    if lig.get("success"):
        print(f"  {lig['ligand_id']}: {Path(lig['mol2_file']).name}")

print(f"\nWorkflow Complete!")

---

## Next Steps

1. **Longer simulations**: Increase `prod_steps` for longer production runs (e.g., `prod_steps = 500000` gives 1 ns at the 2 fs timestep)
2. **Analysis**: Use MDTraj for RMSD, RMSF, hydrogen bonds, etc.
3. **Different systems**: Change `PDB_ID` and `SELECT_CHAINS` for your target
4. **Membrane systems**: Use `embed_in_membrane` instead of `solvate_structure`
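
For the RMSD analysis mentioned above, the quantity is RMSD = sqrt((1/N) Σ |r_i - r_i,ref|²). MDTraj's `md.rmsd(target, reference, frame=0, atom_indices=...)` computes it after optimally superimposing each frame on the reference and reports nanometres. The pure-Python sketch below only illustrates the formula and assumes the two coordinate sets are already aligned; `rmsd` here is an illustrative function, not the MDTraj one.

```python
import math


def rmsd(coords_a, coords_b):
    """Root-mean-square deviation of two equal-length (x, y, z) lists, without superposition."""
    if not coords_a or len(coords_a) != len(coords_b):
        raise ValueError("coordinate lists must be non-empty and of equal length")
    squared = sum(
        (xa - xb) ** 2 + (ya - yb) ** 2 + (za - zb) ** 2
        for (xa, ya, za), (xb, yb, zb) in zip(coords_a, coords_b)
    )
    return math.sqrt(squared / len(coords_a))


reference = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0)]
shifted = [(x, y + 1.0, z) for x, y, z in reference]  # every atom moved by 1 Å
print(rmsd(reference, shifted))  # → 1.0
```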

For more information, see the [GitHub repository](https://github.com/matsunagalab/mcp-md).