# 🧬 Prolix: Explicit Solvent MD with Water Visualization

This notebook runs an MD simulation of the 1UAO protein in **explicit TIP3P water** using PME electrostatics, then visualizes the trajectory with water oxygens shown as low-opacity (alpha = 0.25) light-blue spheres.

> **Note**: Run this on a GPU or TPU runtime — PME compilation and the simulation itself are slow on CPU!

## 1. Setup Environment

In [None]:
# Clone repositories
!git clone https://github.com/maraxen/prolix.git
!git clone https://github.com/maraxen/priox.git prolix/priox

# Install dependencies
%cd prolix
!pip install -q uv
!uv pip install -e . --system
!uv pip install -e priox --system
!pip install -q py2Dmol biotite array_record

In [None]:
# Verify JAX backend. The explicit-solvent PME run below is meant for a
# GPU/TPU runtime, so flag a CPU backend up front instead of discovering it
# mid-simulation.
import jax

print(f"JAX devices: {jax.devices()}")
backend = jax.default_backend()
print(f"Backend: {backend}")
if backend == "cpu":
    print("WARNING: JAX is running on CPU. Switch to a GPU/TPU runtime for practical simulation times.")

## 2. Load Solvated System

In [None]:
import jax.numpy as jnp
from jax import random
import numpy as np
import biotite.structure as struc

from prolix import simulate
from prolix.visualization import TrajectoryReader, save_trajectory_html
from priox.md.bridge.core import parameterize_system
from priox.physics.force_fields.loader import load_force_field
from priox.io.parsing import biotite as parsing_biotite

# Load pre-solvated PDB, keeping the solvent (remove_solvent=False) so the
# waters are part of the simulated system.
pdb_path = "data/pdb/1UAO_solvated_tip3p.pdb"
atom_array = parsing_biotite.load_structure_with_hydride(pdb_path, model=1, remove_solvent=False)

# Extract the periodic box from CRYST1. biotite stores box vectors as the rows
# of a (3, 3) matrix, or (n_models, 3, 3) for a stack.
box = atom_array.box
if box is None:
    raise ValueError(f"{pdb_path} has no CRYST1 record; a periodic box is required for PME.")
box = np.asarray(box)
if box.ndim == 3:
    box = box[0]
# The simulation spec takes only three edge lengths, so a triclinic box would
# silently be treated as orthorhombic. Refuse it explicitly.
off_diagonal = box - np.diag(np.diag(box))
if not np.allclose(off_diagonal, 0.0, atol=1e-3):
    raise ValueError("Only orthorhombic boxes are supported (box_size keeps just the diagonal).")
box_size = jnp.array(np.diag(box))
print(f"System: {len(atom_array)} atoms")
print(f"Box: {box_size} Å")

In [None]:
# Build topology for parameterization: one force-field residue name per
# residue, plus the flattened atom names and per-residue atom counts that
# parameterize_system consumes.
NON_PROTEIN_RESIDUES = {"WAT", "NA", "CL"}
WATER_ALIASES = {"HOH", "TIP3"}  # renamed to "WAT"

residues = []
names_per_residue = []

# With add_exclusive_stop=True the array ends with len(atom_array), so
# consecutive entries give each residue's [start, stop) atom range.
res_bounds = struc.get_residue_starts(atom_array, add_exclusive_stop=True)
for start_idx, stop_idx in zip(res_bounds[:-1], res_bounds[1:]):
    res_name = str(atom_array.res_name[start_idx])
    if res_name in WATER_ALIASES:
        res_name = "WAT"
    residues.append(res_name)
    names_per_residue.append(atom_array.atom_name[start_idx:stop_idx].tolist())

# N/C terminal naming for protein. It is applied to the first/last *protein*
# residue rather than to residue 0, so it stays consistent even if waters or
# ions precede the protein in the PDB:
#   - N-terminus: atom "H" -> "H1" and an "N" prefix on the residue name
#   - C-terminus: a "C" prefix on the residue name
protein_idx = [i for i, r in enumerate(residues) if r not in NON_PROTEIN_RESIDUES]
if protein_idx:
    n_term, c_term = protein_idx[0], protein_idx[-1]
    names_per_residue[n_term] = ["H1" if n == "H" else n for n in names_per_residue[n_term]]
    residues[n_term] = "N" + residues[n_term]
    residues[c_term] = "C" + residues[c_term]

atom_names = [name for names in names_per_residue for name in names]
atom_counts = [len(names) for names in names_per_residue]
assert sum(atom_counts) == len(atom_names) == len(atom_array), "topology does not cover every atom"

print(f"Residues: {len(residues)} ({len(protein_idx)} protein)")

## 3. Parameterize & Run Simulation

In [None]:
# Parameterize the solvated system with the ff14SB force field. The inputs are
# the residue names, flattened atom names and per-residue atom counts from the
# topology cell. Water uses the TIP3P model with rigid_water=True.
force_field_path = "data/force_fields/ff14SB.eqx"
ff = load_force_field(force_field_path)
system_params = parameterize_system(
    ff,
    residues,
    atom_names,
    atom_counts,
    water_model="TIP3P",
    rigid_water=True,
)
print("Parameterization complete!")

In [None]:
# Run explicit solvent simulation: 50 ps total, with a frame saved every
# 0.001 ns (1 ps) to the ArrayRecord file read by the visualization cells.
#
# Precision notes: JAX computes in float32 by default (and TPUs use float32),
# so the settings below favour stability with PME:
#   - a 1.0 fs timestep instead of 2.0 fs
#   - a PME grid of 64 instead of 48
positions = jnp.array(atom_array.coord)
key = random.PRNGKey(42)
spec_options = dict(
    total_time_ns=0.05,        # 50 ps
    step_size_fs=1.0,          # conservative for float32 + PME
    save_interval_ns=0.001,    # 1 ps between saved frames
    accumulate_steps=200,      # smaller buffer for memory efficiency
    save_path="1uao_explicit_traj.array_record",
    temperature_k=300.0,
    gamma=1.0,
    box=box_size,
    use_pbc=True,
    use_neighbor_list=True,
    pme_grid_size=64,          # finer mesh for better accuracy
)
spec = simulate.SimulationSpec(**spec_options)

print("Running 50ps explicit solvent simulation...")
final_state = simulate.run_simulation(
    system_params=system_params, r_init=positions, spec=spec, key=key
)
final_energy = final_state.potential_energy
print(f"Done! Final energy: {final_energy:.2f} kcal/mol")

## 4. Visualize with Water (Low-Opacity Blue Spheres)

In [None]:
# Generate HTML visualization with custom water styling. Water oxygens are
# drawn as small, mostly transparent light-blue spheres so the protein cartoon
# stays visible.
# The viewer reads the original PDB, not the renamed topology, so the selector
# lists every water residue name the topology cell accepts (HOH and TIP3, both
# mapped to "WAT"). It also lists "OH2", the CHARMM-style TIP3 oxygen name,
# alongside "O".
water_resnames = ['WAT', 'HOH', 'TIP3']
water_oxygen_names = ['O', 'OH2']
custom_styles = [
    ({'resn': water_resnames, 'atom': water_oxygen_names},
     {'sphere': {'radius': 0.5, 'color': 'lightblue', 'opacity': 0.25}}),
]

save_trajectory_html(
    trajectory="1uao_explicit_traj.array_record",
    pdb_path=pdb_path,
    output_path="1uao_explicit_viz.html",
    stride=2,
    style="cartoon",
    title="1UAO Explicit Solvent (50ps)",
    custom_styles=custom_styles
)
print("Saved 1uao_explicit_viz.html")

In [None]:
# Display inline (requires py2Dmol). Only the optional py2Dmol import is
# guarded. Errors raised by prolix or by view_trajectory itself should surface
# rather than be misreported as "py2Dmol not available".
try:
    import py2Dmol  # noqa: F401 -- required by view_trajectory
except ImportError:
    print("py2Dmol not available. Download 1uao_explicit_viz.html instead.")
else:
    from prolix.visualization import view_trajectory

    view_trajectory(
        "1uao_explicit_traj.array_record",
        pdb_path,
        stride=5,
    )

In [None]:
# Download the HTML file. google.colab exists only inside Colab. Elsewhere,
# report the local path instead of crashing with ImportError.
try:
    from google.colab import files
except ImportError:
    print("Not running in Google Colab; the visualization is saved at 1uao_explicit_viz.html")
else:
    files.download('1uao_explicit_viz.html')