In [None]:
from typing import Dict
import torch

import torch.nn as nn
import torch.nn.functional as F

class MockEnergyModel(nn.Module):
    """Toy TorchScript-compatible energy model used to exercise the OpenMM-Torch pipeline.

    The energy is ``10 * sum_i w_i * |r_i - r_mean|^2``, where ``r_mean`` is the mean
    position and ``w_i`` is the sum of node ``i``'s one-hot type encoding. For a valid
    one-hot row that sum is always 1, so the energy is a harmonic pull of every node
    towards the mean position. The weighting step is kept as a hook for
    type-dependent weights.

    Args:
        num_node_types: number of distinct node types; every value in
            ``data['node_types']`` must lie in ``[0, num_node_types)``.
    """

    def __init__(self, num_node_types):
        super().__init__()
        self.num_node_types = num_node_types

    def forward(self, data: Dict[str, torch.Tensor]):
        """Compute the mock energy and store it under ``data['energy']``.

        Args:
            data: must contain ``'pos'`` (float tensor, shape ``[N, D]``) and
                ``'node_types'`` (integer tensor, shape ``[N]``).

        Returns:
            The same ``data`` dict, mutated in place, with the scalar ``'energy'`` added.
        """
        pos = data['pos']  # shape: [N, D]
        # Subtracting the mean position makes the energy invariant to rigid translations.
        pos = pos - pos.mean(dim=0, keepdim=True)
        # F.one_hot only accepts int64 indices; callers may pass int32 type ids.
        node_types = data['node_types'].long()  # shape: [N]
        node_type_onehot = F.one_hot(node_types, num_classes=self.num_node_types).float()  # [N, num_node_types]
        # Per-node weight: sum over the one-hot row (always 1 here; customize for type-dependent weights).
        weights = node_type_onehot.sum(dim=1)  # [N]
        pos_sum = (pos**2).sum(dim=1)  # [N]
        energy = (pos_sum * weights).sum()
        data['energy'] = energy * 10
        return data

# Instantiate the mock model, compile it to TorchScript, and serialize the compiled
# module to "mock_energy_model.pt" in the current working directory.
model = MockEnergyModel(num_node_types=10)
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, "mock_energy_model.pt")

In [None]:
import shutil
import sys
from pathlib import Path

from matplotlib import pyplot as plt
import numpy as np
import openmm as mm
from openmm import unit, Platform
from openmm.app import *
from openmmtorch import TorchForce
from openmm.app.metadynamics import Metadynamics, BiasVariable

from geqmd.geqtrain.compile import load_and_init_model

# --- Simulation Parameters ---
pdb_file = "ala.pdb"  # Input structure, relative to the notebook's working directory
total_steps = 20000  # Total number of simulation steps
log_frequency = 1000   # How often (in steps) to print status to the console
dcd_frequency = 1000   # How often (in steps) to save a trajectory frame
device = "CUDA"  # OpenMM platform name; also passed lower-cased ("cuda") as the torch device

# --- GEqTrain Model Parameters --- #
# Relative to the working directory, like pdb_file: this is the file the mock-model
# cell above writes. A user-specific absolute home path only works on one machine.
model_path = "mock_energy_model.pt"
model_field = "energy"  # Output key to read (MockEnergyModel stores data['energy'])
model_position_unit = unit.angstrom  # Passed to load_and_init_model as the model's position unit
model_energy_unit = unit.kilocalorie_per_mole  # Passed to load_and_init_model as the model's energy unit
# Extra options forwarded unchanged to load_and_init_model (all None here).
kwargs = {
    "atom_names_filters": None,
    "resname_filters": None,
    "neighbor_names_filters": None,
    "neighbor_resname_filters": None
}

# --- Metadynamics Parameters ---
bias_factor = 8.0  # CVs are sampled as if at an effective temperature of T * bias_factor
gaussian_height = 0.1 * unit.kilocalories_per_mole  # Initial height of each deposited Gaussian
gaussian_frequency = 500  # Deposit a Gaussian every this many steps

# --- Load the Initial Structure ---
print("Loading PDB file...")
try:
    pdb = PDBFile(pdb_file)
except FileNotFoundError as err:
    # Raise instead of sys.exit(): inside Jupyter, sys.exit only raises SystemExit,
    # which IPython reports as a terse warning. A real exception stops "Run All"
    # with a readable message and the original traceback attached.
    raise FileNotFoundError(
        f"The file '{pdb_file}' was not found. Please make sure it exists relative to "
        f"the current working directory ({Path.cwd()})."
    ) from err

topology = pdb.topology
positions = pdb.positions

# --- Set up the Force Field ---
print("Setting up the force field...")
forcefield = ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')

# --- Create the OpenMM System ---
# Non-periodic (NoCutoff) Amber14 system with bonds to hydrogen constrained.
system = forcefield.createSystem(
    topology,
    nonbondedMethod=NoCutoff,
    constraints=HBonds,
)
# NOTE: the classical Amber forces are kept, so the ML energy added below is summed
# on top of them. To run on the ML potential alone, strip them first:
# while system.getNumForces() > 0:
#     system.removeForce(0)

# --- Add the GEqTrain model as an OpenMM force ---
model = load_and_init_model(
    model_filename=model_path,
    field=model_field,
    topology=topology,
    positions=positions,
    model_position_unit=model_position_unit,
    model_energy_unit=model_energy_unit,
    device=device.lower(),
    **kwargs,
)
torch_force = TorchForce(model)
system.addForce(torch_force)

# --- Define the Collective Variables (CVs) ---
# Backbone torsions, as 0-indexed atom indices into the loaded topology:
#   Phi (C-N-CA-C): atoms 4, 6, 8, 14
#   Psi (N-CA-C-N): atoms 6, 8, 14, 16
phi_atoms = [4, 6, 8, 14]
psi_atoms = [6, 8, 14, 16]

# The biasWidth should be about 1/3 to 1/2 of the CV's standard deviation
# during an unbiased run. 0.35 rad is a reasonable starting point for torsions.
torsion_bias_width = 0.35


def make_torsion_cv(top, atom_indices, expected_names, bias_width):
    """Build a periodic metadynamics collective variable for one dihedral angle.

    Args:
        top: OpenMM Topology that ``atom_indices`` refer to.
        atom_indices: the four 0-based atom indices defining the torsion.
        expected_names: atom names expected at those indices. The indices are
            hard-coded for one atom ordering, so a mismatch means the CV would
            silently bias the wrong torsion.
        bias_width: Gaussian width, in radians.

    Returns:
        ``(force, bias_variable)``: the CustomTorsionForce evaluating the angle, and
        the periodic BiasVariable on ``[-pi, pi]`` wrapping it.

    Raises:
        ValueError: if the atoms at ``atom_indices`` are not named ``expected_names``.
    """
    atoms = list(top.atoms())
    found_names = [atoms[i].name for i in atom_indices]
    if found_names != list(expected_names):
        raise ValueError(
            f"Torsion atoms {atom_indices} are named {found_names}, expected "
            f"{list(expected_names)}; check that the PDB atom ordering matches the "
            "hard-coded CV indices."
        )
    force = mm.CustomTorsionForce("theta")
    force.addTorsion(*atom_indices)
    variable = BiasVariable(force, minValue=-np.pi, maxValue=np.pi,
                            biasWidth=bias_width, periodic=True)
    return force, variable


phi, phi_var = make_torsion_cv(topology, phi_atoms, ["C", "N", "CA", "C"], torsion_bias_width)
psi, psi_var = make_torsion_cv(topology, psi_atoms, ["N", "CA", "C", "N"], torsion_bias_width)

# --- Configure Metadynamics ---
print("Configuring Metadynamics...")
# Metadynamics saves this instance's bias into biasDir and also loads bias files it
# finds there that were written by other instances. With biasDir='.', leftover files
# from earlier runs would be mixed into this run's bias. Use a dedicated directory
# and empty it first so re-running the notebook starts from a clean bias.
bias_dir = Path("metad_bias")
shutil.rmtree(bias_dir, ignore_errors=True)
bias_dir.mkdir(parents=True, exist_ok=True)
# Must be built before the Simulation/Context: it adds the bias force to `system`.
meta = Metadynamics(system, [phi_var, psi_var], 300.0*unit.kelvin, bias_factor,
                    gaussian_height, gaussian_frequency,
                    saveFrequency=gaussian_frequency, biasDir=str(bias_dir))

# --- Set up the Integrator ---
# Langevin dynamics: 300 K bath, 1/ps friction, 2 fs time step.
integrator = mm.LangevinIntegrator(
    300*unit.kelvin,
    1.0/unit.picosecond,
    2.0*unit.femtoseconds,
)

# --- Create the Simulation Object ---
print("Creating simulation object...")
platform = Platform.getPlatformByName(device)
simulation = Simulation(topology, system, integrator, platform)
simulation.context.setPositions(positions)

# --- Minimize the Energy ---
# Relax the starting structure with OpenMM's default convergence settings.
print("Performing energy minimization...")
simulation.minimizeEnergy()
print("Energy minimization complete.")

# --- Add Reporters to the Simulation ---
trajectory_file = 'trajectory.dcd'  # Single source of truth for the output path
state_reporter = StateDataReporter(
    sys.stdout, log_frequency,
    step=True, potentialEnergy=True, temperature=True,
    progress=True, remainingTime=True, speed=True,
    totalSteps=total_steps,  # required when progress/remainingTime are enabled
    separator='\t',
)
simulation.reporters.append(state_reporter)
simulation.reporters.append(DCDReporter(trajectory_file, dcd_frequency))

# --- Run the Simulation ---
print(f"Starting simulation for {total_steps} steps...")
# meta.step (not simulation.step) so Gaussians are deposited into the bias while integrating.
meta.step(simulation, total_steps)

print("Simulation complete!")
print(f"Trajectory saved to '{trajectory_file}'.")

# --- Plot the free-energy surface ---
# getFreeEnergy() returns a unit-bearing Quantity; strip the units before plotting.
fes_data = meta.getFreeEnergy().value_in_unit(unit.kilojoules_per_mole)
# OpenMM stores the bias grid with CV axes in *reverse* order of the variables passed
# to Metadynamics ([phi_var, psi_var]), so fes_data[i, j] is at (psi_i, phi_j):
# columns (x) are phi and rows (y) are psi. Both CVs were defined on [-pi, pi], so the
# extent maps grid indices to degrees (approximately; it ignores the half-bin offset).
fig, ax = plt.subplots()
image = ax.imshow(
    fes_data,
    origin='lower',
    aspect='auto',
    extent=(-180, 180, -180, 180),
)
ax.set_xlabel('Phi (φ) [deg]')
ax.set_ylabel('Psi (ψ) [deg]')
ax.set_title('Metadynamics free-energy surface')
fig.colorbar(image, ax=ax, label='Free energy (kJ/mol)')
plt.show()