In [None]:
import os
import sys
import shlex
import subprocess
import pickle as pkl
from functools import partial
from tqdm import tqdm

import Xponge
from tqdm import tqdm
import numpy as np

import Xponge.forcefield.amber.gaff as gaff
from Xponge.helper import rdkit as xponge_rdkit_helper
from Xponge import Residue,Molecule

from rdkit import Chem

## Build Force Field

In [None]:
### input a mol2 file
# Path of the input ligand mol2. Set the MOL2_FILE environment variable to point
# at a different file without editing the notebook; the default is unchanged.
mol2_file = os.environ.get(
    'MOL2_FILE',
    '/lustre/grp/gyqlab/linxh/TopologyMapper4BGMTest_20250512/out_mol2/EDM/0.mol2',
)

In [None]:
# Residue name of the ligand; also the file prefix of the topology saved under topo/.
LIG_NAME = 'LIG'

# Read the mol2 into an Xponge assignment and convert it to an RDKit molecule.
assign = Xponge.get_assignment_from_mol2(mol2_file)
rdmol = xponge_rdkit_helper.assign_to_rdmol(assign)

### remove Hs && re-add Hs
# Strip every hydrogen and let RDKit add them back (addCoords=True places the new
# hydrogens), then convert back to an Xponge assignment.
assign = xponge_rdkit_helper.rdmol_to_assign(
    Chem.AddHs(Chem.RemoveAllHs(rdmol), addCoords=True)
)

### build GAFF
assign.determine_atom_type('gaff')
assign.calculate_charge('tpacm4')
res_type = assign.to_residuetype(LIG_NAME)
# parmchk2_gaff writes a file named LIG_NAME into the working directory
# (NOTE(review): presumably the frcmod of missing GAFF parameters — confirm).
# It is only an intermediate and is removed in the `finally` below.
gaff.parmchk2_gaff(res_type, LIG_NAME)
try:
    os.makedirs("topo", exist_ok=True)
    # Wrap the residue type in a one-residue Molecule so it can be saved.
    residue = Residue(res_type, name=res_type.name)
    for atom in res_type.atoms:
        residue.Add_Atom(atom)
    mol = Molecule(name=residue.name)
    mol.Add_Residue(residue)
    # Large placeholder box; the SPONGE evaluation runs with -pbc False.
    mol.box_length = [999, 999, 999]
    Xponge.save_sponge_input(mol, os.path.join("topo", LIG_NAME))
finally:
    # Clean up even if building/saving raised; the exists() check keeps a
    # missing file from masking the original exception.
    if os.path.exists(LIG_NAME):
        os.remove(LIG_NAME)

## Evaluate Energy/Forces

### With SPONGE

Energies and forces can be evaluated with SPONGE, a molecular dynamics simulation package.

SPONGE is available [here](https://spongemm.cn/zh/home).

In [None]:
SPONGE_EXEC = '$YOUR_SPONGE_PATH'

# The previous os.system call let the shell expand "$YOUR_SPONGE_PATH"; do that
# expansion explicitly (plus "~"), then split shell-style so SPONGE_EXEC may also
# carry launcher arguments. No shell is involved in running the command.
sponge_cmd = shlex.split(os.path.expanduser(os.path.expandvars(SPONGE_EXEC))) + [
    '-mode', 'minimization',
    '-pbc', 'False',
    '-default_in_file_prefix', f'topo/{LIG_NAME}',
    '-cutoff', '100',
    '-dt', '0.0',
    '-step_limit', '1',
    '-write_information_interval', '1',
    '-mdinfo', 'topo/mdinfo.txt',
    '-frc', 'topo/frc.dat',
    '-mdout', 'topo/mdout.txt',
    '-crd', 'topo/mdcrd.dat',
    '-box', 'topo/mdbox.txt',
    '-rst', 'topo/restart',
]
# check=True raises CalledProcessError on a non-zero exit instead of silently
# letting the next cell read a missing or stale topo/frc.dat.
sponge_result = subprocess.run(sponge_cmd, check=True)

In [None]:
### load the forces SPONGE wrote to topo/frc.dat
# Read as a flat float32 stream and viewed as rows of 3 components.
# NOTE(review): the number of rows depends on how many frames SPONGE wrote — confirm.
force = np.reshape(np.fromfile('topo/frc.dat', dtype=np.float32), (-1, 3))

### With JAX

Alternatively, energies and forces can be evaluated with a JAX function that can be transformed with `jax.jit` and `jax.vmap`.


In [None]:
### convert SPONGE-format topology files into a JAX-readable dictionary
from jax_gaff import convert_sponge_input_to_dict

# Use the same prefix the topology was saved under (topo/<LIG_NAME>) instead of a
# hard-coded 'topo/LIG', so changing LIG_NAME keeps this cell in sync.
ff_params = convert_sponge_input_to_dict(os.path.join('topo', LIG_NAME))

In [None]:
import jax
import jax.numpy as jnp
from jax_gaff import gaff_ene_frc, NMAX_ATOMS, FF_TARGET_SHAPE, parse_num_file

# gaff_ene_frc is a JAX function that supports jit/vmap. Compile it so the name
# `jit_gaff_ene_frc` matches what it holds. You can further wrap it with jax.vmap /
# jax.pmap to evaluate batches of molecules.
jit_gaff_ene_frc = jax.jit(gaff_ene_frc)

In [None]:
### Pad input coordinates and force-field parameters
# (You can skip this step when evaluating a single molecule.)
def _pad_fn(ff_params):
    def _pad_length(arr_shape, target_shape):
        return [(0, tshape - ashape) for ashape, tshape in zip(arr_shape, target_shape)]
    return {
        k: np.pad(v, _pad_length(v.shape, FF_TARGET_SHAPE[k])) for k, v in ff_params.items()
    }

ff_param_dict_pad = _pad_fn(ff_params)

# The [1:-1] slice drops the first and last parsed rows. NOTE(review): presumably
# these are the atom-count header and the box line of SPONGE's coordinate file — confirm.
crds_raw = np.array(parse_num_file(f'topo/{LIG_NAME}_coordinate.txt')[1:-1])
n_atoms = crds_raw.shape[0]
# Fail with a clear message instead of an opaque negative-padding error from np.pad.
if n_atoms > NMAX_ATOMS:
    raise ValueError(
        f'{n_atoms} atoms exceed NMAX_ATOMS={NMAX_ATOMS}; cannot pad coordinates'
    )
# Zero-pad along the atom axis only, up to NMAX_ATOMS rows.
crds = np.pad(crds_raw, ((0, NMAX_ATOMS - n_atoms), (0, 0)))

In [None]:
### Converting inputs to jax-arrays
# Parameter names in the positional order they are handed to the energy/force
# function (kept exactly as before; NOTE(review): verify against jax_gaff).
FF_PARAM_KEYS = (
    'C', 'E', 'S', 'Ex',
    'Ba', 'Bb', 'Bk', 'Br',
    'Aa', 'Ab', 'Ac', 'Ak', 'At',
    'Da', 'Db', 'Dc', 'Dd', 'Dn', 'Dk', 'Dp',
    'Na', 'Nb', 'Nlf', 'Nqf',
)

# Every input gets a leading axis of size 1 via [None, ...].
ff_inputs = tuple(jnp.array(ff_param_dict_pad[name][None, ...]) for name in FF_PARAM_KEYS)
crds_jax = jnp.array(crds[None, ...])

ene, frc = jit_gaff_ene_frc(crds_jax, ff_inputs)