In [1]:
import sys
sys.path.append("..")

import torch

from torchmd.forces import Forces
from torchmd.systems import System
from moleculekit.molecule import Molecule
from torchmd.parameters import Parameters

import numpy as np
import time
import json
import tqdm

from module.torchmd import tagged_forcefield
from module.torchforcefield import TorchForceField

import os

2024-12-03 17:46:02,008 - numexpr.utils - INFO - Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
2024-12-03 17:46:02,009 - numexpr.utils - INFO - NumExpr defaulting to 16 threads.


In [2]:
version_major, version_minor = torch.__version__.split(".")[:2]
assert int(version_major) >= 2 and int(version_minor) >= 4, "Compilation only works with pytorch 2.4 or greater"

In [3]:
def calculate_forces(
        coords,
        box,
        system,
        forces_obj):
    """Evaluate ``forces_obj`` on every frame and collect energies and forces.

    Args:
        coords: Tensor indexed by frame along its first axis. Each frame is
            reshaped to ``[1, -1, 3]`` before being evaluated.
        box: Per-frame box tensor that can be viewed as rows of 9 values
            (only entries 0, 4 and 8 of each row are used), or ``None`` to
            use an all-zero box for every frame.
        system: Object providing ``set_box``, ``box`` and ``forces``. The
            forces are read from ``system.forces`` after each evaluation
            (presumably ``compute`` fills that buffer - confirm against the
            force implementation).
        forces_obj: Object with ``compute(positions, box, forces)`` returning
            an indexable energy; element 0 is accumulated.

    Returns:
        Tuple ``(prior_energy, prior_forces)``: the sum of the per-frame
        energies, and a CPU tensor holding one ``[-1, 3]`` force array per
        frame with dtype ``system.forces.dtype``.
    """
    if box is not None:
        # Take the diagonal
        diagonal = box.reshape(-1, 9)[:, [0, 4, 8]]
        # Then reformat to TorchMD's expected shape
        box_full = diagonal.reshape(diagonal.shape[0], 3, 1)
    else:
        box_full = torch.zeros(coords.shape[0], 3, 1)

    prior_forces = []
    prior_energy = 0.0
    for i in tqdm.tqdm(range(0, coords.shape[0]), dynamic_ncols=True):
        co = coords[i]
        system.set_box(box_full[i])
        Epot = forces_obj.compute(co.reshape([1, -1, 3]), system.box, system.forces)
        fr = (
            system.forces.detach().cpu().numpy().astype(np.float32).reshape([-1, 3])
        )
        prior_energy += Epot[0]
        prior_forces.append(fr)

    # Merge the per-frame arrays into one ndarray first: handing PyTorch a
    # Python list of ndarrays takes an extremely slow element-wise path (and
    # emits a UserWarning saying so).
    prior_forces = torch.as_tensor(np.array(prior_forces), dtype=system.forces.dtype)
    return prior_energy, prior_forces

def make_forces(
    coords_npz,
    box_npz,
    psf,
    prior_path,
    prior_params_path,
    device="cpu",
):
    """Compute reference energies and forces with the TorchMD ``Forces`` class.

    Args:
        coords_npz: Path passed to ``np.load`` for the coordinate frames.
        box_npz: Path passed to ``np.load`` for the boxes; a falsy value
            means no box is used.
        psf: Topology file handed to ``Molecule``.
        prior_path: Force field file handed to ``tagged_forcefield.create``.
        prior_params_path: JSON file providing the ``"forceterms"`` and
            ``"exclusions"`` entries.
        device: Anything accepted by ``torch.device``.

    Returns:
        Whatever ``calculate_forces`` returns for the loaded frames.
    """
    target = torch.device(device)
    dtype = torch.float
    n_replicas = 1

    # Read the force term / exclusion settings for the prior.
    with open(prior_params_path, 'r') as handle:
        settings = json.load(handle)
    terms, excluded = settings["forceterms"], settings["exclusions"]

    molecule = Molecule(psf)
    n_atoms = molecule.numAtoms

    frames = np.load(coords_npz)
    box_tensor = (
        torch.tensor(np.load(box_npz), dtype=dtype).to(target) if box_npz else None
    )
    frames = torch.tensor(frames, dtype=dtype).to(target)

    # Zero-filled placeholders used only to initialise the System object.
    zero_velocities = torch.zeros(n_replicas, n_atoms, 3)
    zero_positions = torch.zeros(n_atoms, 3, n_replicas)

    forcefield = tagged_forcefield.create(molecule, prior_path)
    params = Parameters(forcefield, molecule, terms, precision=dtype, device=target)

    md_system = System(n_atoms, n_replicas, dtype, target)
    md_system.set_positions(zero_positions)
    md_system.set_velocities(zero_velocities)

    evaluator = Forces(params, terms=terms, exclusions=excluded)

    return calculate_forces(frames, box_tensor, md_system, evaluator)

def graph_forward(module, box, data, repeats=20):
    """Capture ``module.forward`` in a CUDA graph and replace it with a replay wrapper.

    Args:
        module: Object exposing ``device`` and ``forward(coords, box, forces_out)``;
            its ``forward`` attribute is overwritten with the replay wrapper.
        box: Example box. ``None`` or an all-zero tensor means the graph is
            captured with ``None`` as the box argument.
        data: Example coordinates defining the shape of the static input and
            output buffers.
        repeats: Number of warm-up calls issued on a side stream before capture.

    Returns:
        None. After this call ``module.forward(co, box, fo_out)`` copies its
        inputs into the static buffers, replays the graph, copies the static
        force buffer into ``fo_out`` and returns the static energy buffer.
    """
    # clone() matters: torch.as_tensor returns `data` itself when it already
    # lives on module.device, and every replay overwrites this buffer. Without
    # a private copy the caller's tensor (e.g. the first coordinate frame)
    # would be clobbered by later frames.
    static_in = torch.as_tensor(data, device=module.device).clone()
    if box is not None and not torch.all(box == 0):
        static_box = torch.ones((3,), device=module.device)
    else:
        static_box = None
    static_out = torch.zeros_like(data, device=module.device)
    static_pots_out = torch.zeros(1, device=module.device)

    # Warm up on a side stream before capturing.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(repeats):
            static_pots_out[:] = module.forward(static_in, static_box, static_out)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_pots_out[:] = module.forward(static_in, static_box, static_out)

    def eval_graph(co, box, fo_out):
        # Refresh the static inputs, replay, then hand the forces back.
        static_in.copy_(co)
        if static_box is not None:
            static_box.copy_(box)
        g.replay()
        fo_out.copy_(static_out)
        return static_pots_out

    module.forward = eval_graph

def make_forces_tff(
    coords_npz,
    box_npz,
    psf,
    prior_path,
    prior_params_path,
    device="cpu",
    compile=False,
    cuda_graph=True
):
    """Compute energies and forces with ``TorchForceField``.

    Args:
        coords_npz: Path passed to ``np.load`` for the coordinate frames.
        box_npz: Path passed to ``np.load`` for the boxes; a falsy value
            means no box is used.
        psf: Topology file handed to ``Molecule``.
        prior_path: Force field file handed to ``TorchForceField``.
        prior_params_path: JSON file providing the ``"forceterms"`` and
            ``"exclusions"`` entries.
        device: Anything accepted by ``torch.device``.
        compile: When true, wrap ``forces.forward`` with ``torch.compile``.
        cuda_graph: When true, capture the forward pass with ``graph_forward``.
            The request is ignored (with a printed notice) when ``device`` is
            not a CUDA device, because graph capture needs one.

    Returns:
        Whatever ``calculate_forces`` returns for the loaded frames.
    """
    device = torch.device(device)
    precision = torch.float
    replicas = 1

    # Load prior_params
    with open(prior_params_path, 'r') as f:
        prior_params = json.load(f)
    forceterms = prior_params["forceterms"]
    exclusions = prior_params["exclusions"]

    mol = Molecule(psf)
    natoms = mol.numAtoms

    coords = np.load(coords_npz)
    box = None
    graph_box = None
    if box_npz:
        box = np.load(box_npz)
        box = torch.tensor(box, dtype=precision).to(device)
        graph_box = box[0]
    coords = torch.tensor(coords, dtype=precision).to(device)

    atom_vel = torch.zeros(replicas, natoms, 3)
    atom_pos = torch.zeros(natoms, 3, replicas)

    forces = TorchForceField(prior_path, mol, device, terms=forceterms, exclusions=exclusions, use_box=box is not None)
    print(forces)

    # The defaults (device="cpu", cuda_graph=True) would otherwise attempt a
    # CUDA graph capture on a device that cannot support it.
    if cuda_graph and device.type != "cuda":
        print(f"Skipping CUDA graph: device type is '{device.type}', not 'cuda'")
        cuda_graph = False

    if compile:
        t0 = time.time()
        print("Compiling prior...", end="", flush=True)
        forces.forward = torch.compile(forces.forward)
        print(f" Done ({time.time() - t0:.2f}s)")
    if cuda_graph:
        t0 = time.time()
        print("Building CUDA graph...", end="", flush=True)
        graph_forward(forces, graph_box, coords[0], repeats=20)
        print(f" Done ({time.time() - t0:.2f}s)")

    system = System(natoms, replicas, precision, device)
    system.set_positions(atom_pos)
    system.set_velocities(atom_vel)

    return calculate_forces(coords, box, system, forces)

In [4]:
class TestData():
    """File locations for one benchmark system stored under ``path``.

    The last component of ``path`` is used as the system identifier that
    prefixes the prior and topology file names.

    Attributes set by ``__init__``:
        ref_coords: ``<path>/raw/coordinates.npy``
        prior_params_path: ``<path>/raw/<id>_prior_params.json``
        prior_path: ``<path>/raw/<id>_priors.yaml``
        psf: ``<path>/processed/<id>_processed.psf``
        ref_box: ``<path>/raw/box.npy``, or ``None`` when ``box`` is false or
            that file does not exist.
    """
    def __init__(self, path, box=True):
        # normpath drops any number of trailing separators, so the identifier
        # is found for "dir/ID", "dir/ID/" and "dir/ID//" alike.
        pdbid = os.path.basename(os.path.normpath(path))
        self.ref_coords = os.path.join(path, "raw/coordinates.npy")
        self.prior_params_path = os.path.join(path, f"raw/{pdbid}_prior_params.json")
        self.prior_path = os.path.join(path, f"raw/{pdbid}_priors.yaml")
        self.psf = os.path.join(path, f"processed/{pdbid}_processed.psf")
        self.ref_box = os.path.join(path, "raw/box.npy")
        if not box or not os.path.exists(self.ref_box):
            self.ref_box = None

def calc_rmsd(a, b):
    """Root of the mean squared difference between two tensors.

    Returns ``None`` when either argument is ``None``.
    """
    for operand in (a, b):
        if operand is None:
            return None
    return torch.nn.functional.mse_loss(a, b).sqrt()

In [None]:
# Benchmark on the 6MRR data set; every run in this cell passes None as the box.
# Alternate data location:
# data = TestData("/home/argon/Stuff/seq_embedding/cg_single_chain_2024.06.26_subsetC_CA_lj_angleXCX_dihedralX/6MRR/")
data = TestData("/media/DATA_18_TB_1/daniel_s/cgschnet/seq_embedding/cg_single_chain_2024.06.26_subsetC_CA_lj_angleXCX_dihedralX/6MRR/")

# Positional arguments shared by all four runs below.
run_args = (data.ref_coords, None, data.psf, data.prior_path, data.prior_params_path, "cuda")

# Give every result a placeholder before any run executes.
tmd_forces_result, ttf_forces_result, ttf_graph_forces_result, ttf_comp_graph_forces_result = (None,) * 4
tmd_energy_result, ttf_energy_result, ttf_graph_energy_result, ttf_comp_graph_energy_result = (float("nan"),) * 4

# Reference (TorchMD), then TorchForceField: plain, CUDA graph, compiled + CUDA graph.
tmd_energy_result, tmd_forces_result = make_forces(*run_args)
ttf_energy_result, ttf_forces_result = make_forces_tff(*run_args, compile=False, cuda_graph=False)
ttf_graph_energy_result, ttf_graph_forces_result = make_forces_tff(*run_args, compile=False, cuda_graph=True)
ttf_comp_graph_energy_result, ttf_comp_graph_forces_result = make_forces_tff(*run_args, compile=True, cuda_graph=True)

print("---")
print("Force RMS magnitude", torch.sqrt(torch.mean(tmd_forces_result**2)))
for label, candidate in (
    ("RMSD: TorchMD vs TorchMD", tmd_forces_result),
    ("RMSD: TorchMD vs TorchForceField", ttf_forces_result),
    ("RMSD: TorchMD vs TorchForceField graph", ttf_graph_forces_result),
    ("RMSD: TorchMD vs TorchForceField compiled graph", ttf_comp_graph_forces_result),
):
    print(label, calc_rmsd(tmd_forces_result, candidate))
print("Energy:", tmd_energy_result, ttf_energy_result, ttf_graph_energy_result, ttf_comp_graph_energy_result)

  self.pos[:] = torch.tensor(
  self.box[r][torch.eye(3).bool()] = torch.tensor(
100%|██████████| 10000/10000 [00:13<00:00, 737.49it/s]
  prior_forces = torch.as_tensor(prior_forces, dtype=system.forces.dtype)


TorchForceField(
  (term_modules): ModuleList(
    (0): TFF_Bond()
    (1): TFF_Angle()
    (2): TFF_Dihedral()
    (3): TFF_RepulsionCG(cutoff=None, exclusions=['bonds', 'angles', 'dihedrals'])
  )
)


100%|██████████| 10000/10000 [00:07<00:00, 1350.44it/s]


TorchForceField(
  (term_modules): ModuleList(
    (0): TFF_Bond()
    (1): TFF_Angle()
    (2): TFF_Dihedral()
    (3): TFF_RepulsionCG(cutoff=None, exclusions=['bonds', 'angles', 'dihedrals'])
  )
)
Building CUDA graph... Done (0.06s)


100%|██████████| 10000/10000 [00:02<00:00, 3387.57it/s]


TorchForceField(
  (term_modules): ModuleList(
    (0): TFF_Bond()
    (1): TFF_Angle()
    (2): TFF_Dihedral()
    (3): TFF_RepulsionCG(cutoff=None, exclusions=['bonds', 'angles', 'dihedrals'])
  )
)
Compiling prior... Done (0.59s)
Building CUDA graph...

  self.pos[:] = torch.tensor(


 Done (2.16s)


  self.box[r][torch.eye(3).bool()] = torch.tensor(
100%|██████████| 10000/10000 [00:01<00:00, 6771.68it/s]


---
Force RMS magnitude tensor(8.0916)
RMSD: TorchMD vs TorchMD tensor(0.)
RMSD: TorchMD vs TorchForceField tensor(6.6535e-07)
RMSD: TorchMD vs TorchForceField graph tensor(6.8872e-07)
RMSD: TorchMD vs TorchForceField compiled graph tensor(7.3400e-07)
Energy: 852764.7920870781 tensor([852762.6250], device='cuda:0') tensor([852762.6250], device='cuda:0') tensor([852762.6250], device='cuda:0')


In [None]:
# Test a protein that uses box wrapping (something multichained)
# Alternate data location:
# data = TestData("/home/argon/Stuff/cg_2GJH_dimer_2024.05.31_400fs_CA_lj_angleXCX_dihedralX/2GJH/", box=True)
data = TestData("/media/DATA_18_TB_1/daniel_s/cgschnet/cg_2GJH_dimer_2024.05.31_400fs_CA_lj_angleXCX_dihedralX/2GJH/")

# Positional arguments shared by all four runs below (this time including the box file).
run_args = (data.ref_coords, data.ref_box, data.psf, data.prior_path, data.prior_params_path, "cuda")

# Give every result a placeholder before any run executes.
tmd_forces_result, ttf_forces_result, ttf_graph_forces_result, ttf_comp_graph_forces_result = (None,) * 4
tmd_energy_result, ttf_energy_result, ttf_graph_energy_result, ttf_comp_graph_energy_result = (float("nan"),) * 4

# Reference (TorchMD), then TorchForceField: plain, CUDA graph, compiled + CUDA graph.
tmd_energy_result, tmd_forces_result = make_forces(*run_args)
ttf_energy_result, ttf_forces_result = make_forces_tff(*run_args, compile=False, cuda_graph=False)
ttf_graph_energy_result, ttf_graph_forces_result = make_forces_tff(*run_args, compile=False, cuda_graph=True)
ttf_comp_graph_energy_result, ttf_comp_graph_forces_result = make_forces_tff(*run_args, compile=True, cuda_graph=True)

print("---")
print("Force RMS magnitude", torch.sqrt(torch.mean(tmd_forces_result**2)))
for label, candidate in (
    ("RMSD: TorchMD vs TorchMD", tmd_forces_result),
    ("RMSD: TorchMD vs TorchForceField", ttf_forces_result),
    ("RMSD: TorchMD vs TorchForceField graph", ttf_graph_forces_result),
    ("RMSD: TorchMD vs TorchForceField compiled graph", ttf_comp_graph_forces_result),
):
    print(label, calc_rmsd(tmd_forces_result, candidate))
print("Energy:", tmd_energy_result, ttf_energy_result, ttf_graph_energy_result, ttf_comp_graph_energy_result)

100%|██████████| 20000/20000 [00:28<00:00, 710.68it/s]


TorchForceField(
  (term_modules): ModuleList(
    (0): TFF_Bond()
    (1): TFF_Angle()
    (2): TFF_Dihedral()
    (3): TFF_RepulsionCG(cutoff=None, exclusions=['bonds', 'angles', 'dihedrals'])
  )
)


100%|██████████| 20000/20000 [00:15<00:00, 1279.95it/s]


TorchForceField(
  (term_modules): ModuleList(
    (0): TFF_Bond()
    (1): TFF_Angle()
    (2): TFF_Dihedral()
    (3): TFF_RepulsionCG(cutoff=None, exclusions=['bonds', 'angles', 'dihedrals'])
  )
)
Building CUDA graph... Done (0.09s)


100%|██████████| 20000/20000 [00:06<00:00, 3104.78it/s]


TorchForceField(
  (term_modules): ModuleList(
    (0): TFF_Bond()
    (1): TFF_Angle()
    (2): TFF_Dihedral()
    (3): TFF_RepulsionCG(cutoff=None, exclusions=['bonds', 'angles', 'dihedrals'])
  )
)
Compiling prior... Done (0.00s)
Building CUDA graph...

  self.pos[:] = torch.tensor(


 Done (1.51s)


  self.box[r][torch.eye(3).bool()] = torch.tensor(
100%|██████████| 20000/20000 [00:03<00:00, 5812.19it/s]


---
Force RMS magnitude tensor(8.6310)
RMSD: TorchMD vs TorchMD tensor(0.)
RMSD: TorchMD vs TorchForceField tensor(1.0310e-06)
RMSD: TorchMD vs TorchForceField graph tensor(1.0649e-06)
RMSD: TorchMD vs TorchForceField compiled graph tensor(1.0960e-06)
Energy: 2566794.674589157 tensor([2566793.7500], device='cuda:0') tensor([2566793.7500], device='cuda:0') tensor([2566793.7500], device='cuda:0')
