In [None]:
!pip install ase mace-torch

In [None]:
from __future__ import annotations
import math, random
import numpy as np

from ase import Atoms, units
from ase.build import molecule
from ase.geometry import wrap_positions
from ase.io import write
from ase.md import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution


# ---- System / run configuration -------------------------------------------
# Note: 256 H2O in a 24 Å cube is roughly 0.55 g/cm^3, i.e. well below the
# density of liquid water; shrink BOX_LEN if a denser start is wanted.

# Number of H2O molecules to insert.
N_WATERS = 256
# Side length of the cubic periodic box, in Å.
BOX_LEN = 24.0
# Seed for the Python RNG that drives molecule placement.
SEED = 7
# Upper bound on random insertion attempts before giving up.
MAX_TRIALS = 50_000
# File the initial (pre-MD) structure is written to.
OUTFILE = "water_random_init.extxyz"

# Minimum allowed inter-molecular distances (Å) per species pair; a candidate
# molecule with any atom closer than these to an existing atom is rejected.
CUT_OO, CUT_OH, CUT_HH = 2.4, 1.6, 1.2


def random_rotation_matrix(rng: random.Random) -> np.ndarray:
    """Return a random 3x3 rotation matrix built from a random unit quaternion.

    Three uniform variates are drawn from ``rng`` (in order) and mapped to a
    unit quaternion ``(x, y, z, w)`` whose squared components sum to one; the
    quaternion is then expanded into its rotation matrix.

    Parameters
    ----------
    rng : random.Random
        Source of randomness; exactly three ``rng.random()`` calls are made.

    Returns
    -------
    np.ndarray
        Array of shape (3, 3) holding the rotation matrix.
    """
    u_split = rng.random()
    u_theta = rng.random()
    u_phi = rng.random()

    # Radii of the two quaternion component pairs and their angles.
    radius_a = math.sqrt(1 - u_split)
    radius_b = math.sqrt(u_split)
    theta = 2 * math.pi * u_theta
    phi = 2 * math.pi * u_phi

    x = radius_a * math.sin(theta)
    y = radius_a * math.cos(theta)
    z = radius_b * math.sin(phi)
    w = radius_b * math.cos(phi)

    # Pairwise products reused by the matrix entries below.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    row0 = [1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)]
    row1 = [2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)]
    row2 = [2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)]
    return np.array([row0, row1, row2])

def passes_cutoffs(new: Atoms, existing: Atoms, box_len: float) -> bool:
    """Check that no atom of ``new`` is too close to any atom of ``existing``.

    Distances use the minimum-image convention for a cubic periodic box of
    side ``box_len``. The thresholds are the module-level ``CUT_OO``,
    ``CUT_OH`` and ``CUT_HH`` constants; any atom whose symbol is not ``'O'``
    is treated as hydrogen.

    Parameters
    ----------
    new : Atoms
        Candidate molecule to insert.
    existing : Atoms
        Atoms already placed in the box.
    box_len : float
        Side length of the cubic box (same length unit as the positions).

    Returns
    -------
    bool
        True if every pair distance is at or above its cutoff (or if
        ``existing`` is empty), False otherwise.
    """
    if len(existing) == 0:
        return True

    # Loop-invariant data: fetch once instead of once per candidate atom
    # (get_positions() copies the whole coordinate array on every call).
    placed_pos = existing.get_positions()
    placed_sym = np.array(existing.get_chemical_symbols())
    placed_is_o = placed_sym == 'O'
    placed_is_h = placed_sym == 'H'

    for sym, pos in zip(new.get_chemical_symbols(), new.get_positions()):
        delta = placed_pos - pos
        delta -= np.rint(delta / box_len) * box_len  # minimum image
        dist = np.linalg.norm(delta, axis=1)

        # Pick the (vs-oxygen, vs-hydrogen) thresholds for this atom type.
        if sym == 'O':
            cut_vs_o, cut_vs_h = CUT_OO, CUT_OH
        else:  # treated as 'H'
            cut_vs_o, cut_vs_h = CUT_OH, CUT_HH

        if np.any(dist[placed_is_o] < cut_vs_o):
            return False
        if np.any(dist[placed_is_h] < cut_vs_h):
            return False
    return True

def place_water_random(rng: random.Random, box_len: float) -> Atoms:
    """Create one H2O molecule with random orientation and position.

    The molecule is rotated about its centre of mass by a random rotation,
    translated by a random vector with each component in ``[0, box_len)``,
    and its atomic positions are then wrapped into the cubic periodic cell.

    Parameters
    ----------
    rng : random.Random
        Random source; used first for the rotation, then for the shift.
    box_len : float
        Side length of the cubic cell.

    Returns
    -------
    Atoms
        The water molecule with ``pbc=True`` and a cubic cell of side
        ``box_len``.
    """
    water = molecule("H2O")
    rotation = random_rotation_matrix(rng)

    # Coordinates relative to the centre of mass, so the rotation is about it.
    centered = water.get_positions() - water.get_center_of_mass()
    offset = np.array([rng.random() * box_len for _ in range(3)])
    water.set_positions(centered @ rotation.T + offset)

    water.set_pbc(True)
    water.set_cell([box_len] * 3)

    # Wrapping acts on the atomic positions, so atoms of a molecule near a
    # face may end up on opposite sides of the box.
    wrapped = wrap_positions(water.get_positions(), water.cell, pbc=[1, 1, 1])
    water.set_positions(wrapped)
    return water

def build_box(n_waters=N_WATERS, box_len=BOX_LEN, seed=SEED, max_trials=MAX_TRIALS) -> Atoms:
    """Fill a cubic periodic box with water by random sequential insertion.

    Candidate molecules are generated one at a time and kept only if they
    pass the distance cutoffs against everything already placed.

    Parameters
    ----------
    n_waters : int
        Number of molecules to place.
    box_len : float
        Side length of the cubic box.
    seed : int
        Seed for the ``random.Random`` instance driving the placement.
    max_trials : int
        Maximum number of insertion attempts.

    Returns
    -------
    Atoms
        The filled box.

    Raises
    ------
    RuntimeError
        If fewer than ``n_waters`` molecules were placed within
        ``max_trials`` attempts.
    """
    rng = random.Random(seed)
    box = Atoms(pbc=True, cell=[box_len, box_len, box_len])

    n_placed = 0
    n_tried = 0
    while True:
        if not (n_placed < n_waters and n_tried < max_trials):
            break
        n_tried += 1

        candidate = place_water_random(rng, box_len)
        if not passes_cutoffs(candidate, box, box_len):
            continue  # overlap: discard and try again

        box += candidate
        n_placed += 1
        # Progress report every 16 accepted molecules.
        if n_placed % 16 == 0:
            print(f"  placed {n_placed}/{n_waters} (trials {n_tried})")

    if n_placed < n_waters:
        raise RuntimeError(f"Only placed {n_placed}/{n_waters}. Increase BOX_LEN or relax cutoffs.")
    return box

def get_mace_calc(device: str | None = None):
    """Load pretrained MACE-MP calculator."""
    from mace.calculators import mace_mp
    try:
        return mace_mp(model="small", device=device)
    except Exception:
        print("[info] 'small' not available → falling back to 'medium'.")
        return mace_mp(model="medium", device=device)

def run_md(atoms: Atoms, T=300.0, dt_fs=0.5, steps=2000, friction=1e-3,
           traj="water_md.xyz", log_interval=50, device: str | None = None,
           overwrite: bool = True):
    """Run Langevin (NVT) molecular dynamics on ``atoms`` with a MACE calculator.

    A calculator from ``get_mace_calc`` is attached to ``atoms``, velocities
    are initialised from a Maxwell-Boltzmann distribution at ``T``, and the
    current structure is appended to ``traj`` every ``log_interval`` steps.
    ``atoms`` is modified in place.

    Parameters
    ----------
    atoms : Atoms
        System to propagate.
    T : float
        Target temperature in K.
    dt_fs : float
        Time step in femtoseconds.
    steps : int
        Number of MD steps.
    friction : float
        Friction coefficient passed to ``Langevin`` unchanged.
    traj : str
        Output trajectory file.
    log_interval : int
        Write a frame every this many steps.
    device : str or None
        Forwarded to ``get_mace_calc``.
    overwrite : bool
        If True (default), delete an existing ``traj`` file before the run so
        the trajectory contains only frames from this run. Set to False to
        keep appending to an existing file.
    """
    import os

    # Frames are written with append=True, so without this a re-run would mix
    # frames from several independent runs in one file.
    if overwrite and isinstance(traj, (str, os.PathLike)) and os.path.exists(traj):
        os.remove(traj)

    atoms.calc = get_mace_calc(device=device)
    MaxwellBoltzmannDistribution(atoms, temperature_K=T)
    dyn = Langevin(atoms, dt_fs * units.fs, temperature_K=T, friction=friction)

    def dump():
        write(traj, atoms, append=True)

    dyn.attach(dump, interval=log_interval)
    print(f"[MD] NVT @ {T} K, dt={dt_fs} fs, steps={steps}, saving→{traj}")
    dyn.run(steps)
    print("[MD] done.")


def main() -> Atoms:
    """Build the random water box, save it, and run a short MD.

    Returns
    -------
    Atoms
        The system after the MD run.
    """
    print(f"Building random water box: N={N_WATERS}, L={BOX_LEN} Å, seed={SEED}")
    atoms = build_box(N_WATERS, BOX_LEN, SEED, MAX_TRIALS)
    print(f"Writing initial structure to {OUTFILE}")
    write(OUTFILE, atoms)

    # Attach MACE and run a short MD:
    run_md(atoms, T=700.0, dt_fs=0.5, steps=100, friction=1e-3,
           traj="water_md.xyz", log_interval=5, device=None)  # set device="cuda" if available
    return atoms


def _running_in_jupyter_kernel() -> bool:
    """Return True when this code is executing inside an ipykernel session."""
    import sys
    return "ipykernel" in sys.modules


# Inside a notebook kernel __name__ is also "__main__", which would make this
# definitions cell build the box and run MD on its own, duplicating the
# step-by-step cells that follow. Only auto-run when used as a plain script.
if __name__ == "__main__" and not _running_in_jupyter_kernel():
    atoms = main()


In [None]:
# Generate the initial configuration by random insertion, then save it so the
# starting structure can be inspected or reused.
print(f"Building random water box: N={N_WATERS}, L={BOX_LEN} Å, seed={SEED}")
atoms = build_box(
    n_waters=N_WATERS,
    box_len=BOX_LEN,
    seed=SEED,
    max_trials=MAX_TRIALS,
)

print(f"Writing initial structure to {OUTFILE}")
write(OUTFILE, atoms)


In [None]:
from ase.visualize import view

# The default viewer launches the external ASE GUI, which does not show up
# inside a notebook or on a headless/remote kernel. The x3d viewer returns an
# HTML object that is rendered inline as this cell's output.
view(atoms, viewer="x3d")

In [None]:
# Short NVT run with the MACE calculator; frames go to water_md.xyz.
# Set device="cuda" if a GPU is available.
run_md(
    atoms,
    T=700.0,
    dt_fs=0.5,
    steps=100,
    friction=1e-3,
    traj="water_md.xyz",
    log_interval=5,
    device=None,
)

In [None]:
from ase.io import read
import numpy as np
import matplotlib.pyplot as plt

md_water_traj = read("water_md.xyz", index=":")

# Per-frame potential energy and force statistics; NaN marks "not available".
n_frames = len(md_water_traj)
energies = np.full(n_frames, np.nan)
mean_forces = np.full(n_frames, np.nan)
max_forces = np.full(n_frames, np.nan)

for i, frame in enumerate(md_water_traj):
    # Energy: prefer info['energy'] if present, otherwise ask the calculator
    # attached by the reader.
    e = frame.info.get("energy") if isinstance(frame.info, dict) else None
    if e is None:
        try:
            e = frame.get_potential_energy()
        except Exception:
            e = np.nan
    energies[i] = e

    # Forces: the reader may expose them either as a per-atom array or through
    # an attached calculator, so check both. Looking only at frame.arrays can
    # leave the force curves empty even though forces were written.
    f = frame.arrays.get("forces")
    if f is None and frame.calc is not None:
        try:
            f = frame.get_forces()
        except Exception:
            f = None
    if f is not None:
        mags = np.linalg.norm(f, axis=1)
        mean_forces[i] = mags.mean()
        max_forces[i] = mags.max()

# Energy (top) and force magnitudes (bottom) versus frame index.
steps = np.arange(n_frames)

fig, (ax_e, ax_f) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
ax_e.plot(steps, energies, marker='o', ms=3)
ax_e.set_ylabel("Potential energy (eV)")
ax_e.grid(True)

ax_f.plot(steps, mean_forces, marker='o', ms=3, label="mean |F|")
ax_f.plot(steps, max_forces, marker='o', ms=3, label="max |F|")
ax_f.set_ylabel("Force (eV/Å)")
ax_f.set_xlabel("Frame")
ax_f.legend()
ax_f.grid(True)

plt.tight_layout()
plt.show()
