NaNs when using torchani #50

Open
peastman opened this issue Mar 17, 2023 · 11 comments

Comments

@peastman
Member

I'm running simulations of mixed ML/MM systems where part is computed with ANI-2x and part with Amber. As long as I specify implementation='nnpops' in the call to createMixedSystem() it works well. But if I specify implementation='torchani', the simulation immediately blows up with NaN coordinates. I tried a few molecules and the result is the same for all of them.

Does anyone have an idea what could be causing this? I can put together a test case to reproduce the problem, if that's helpful. My current system is too big to post here. Here are the relevant packages from my conda environment.

nnpops                    0.4             cuda112py310hd4d1af5_0    conda-forge
openmm                    8.0.0           py310h5728c26_0    conda-forge
openmm-ml                 1.0                pyhd8ed1ab_0    conda-forge
openmm-torch              1.0             cuda112py310hbd91edb_0    conda-forge
pytorch                   1.13.1          cuda112py310he33e0d6_200    conda-forge
torchani                  2.2.2           cuda112py310haf08e2f_7    conda-forge
@sef43
Member

sef43 commented Mar 17, 2023

Things I would try first:

  • Use an older version of pytorch, say 1.11.1.
  • If you simulate just the ML part as a pure ML system, does it also get NaNs?
  • Run it on the CPU only.

If you can provide an example, I am happy to take a look.

@peastman
Member Author

Here's a bit more information.

  • The problem doesn't happen when simulating just the ML part. It only happens in the mixed system. It has about 130,000 total atoms, of which 53 are simulated with ML.
  • The crash happens quickly but not instantly. It usually happens within the first 50 steps, and almost always within the first 100 steps.
  • When it happens, the first sign of trouble is a sudden huge increase in potential energy:
#"Step","Potential Energy (kJ/mole)","Kinetic Energy (kJ/mole)","Temperature (K)","Speed (ns/day)"
...
25,-5374000.105669722,548358.342327144,482.35639886791785,0.623
26,-5373635.169209611,548826.6602447805,482.76834873142826,0.631
27,-5370600.139194695,546261.4701012508,480.51191204662575,0.654
28,24362204.57730076,578217.3027816577,508.62145134739,0.676
29,591276527108.6628,52792615009164.625,46438348242.45094,0.697
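
As a rough illustration of spotting this kind of blow-up automatically, the sketch below scans reporter output for the first step where the potential energy changes by more than a threshold. The function name `first_energy_jump` and the threshold of 1e6 kJ/mol are hypothetical choices for this example, not part of OpenMM; the rows are copied from the excerpt above.

```python
import csv
import io

# Rows copied from the log excerpt above (CSV reporter format).
LOG = """#"Step","Potential Energy (kJ/mole)","Kinetic Energy (kJ/mole)","Temperature (K)","Speed (ns/day)"
25,-5374000.105669722,548358.342327144,482.35639886791785,0.623
26,-5373635.169209611,548826.6602447805,482.76834873142826,0.631
27,-5370600.139194695,546261.4701012508,480.51191204662575,0.654
28,24362204.57730076,578217.3027816577,508.62145134739,0.676
29,591276527108.6628,52792615009164.625,46438348242.45094,0.697
"""


def first_energy_jump(text, max_jump=1.0e6):
    """Return the first step whose potential energy differs from the previous
    row by more than max_jump (kJ/mol), or None if no such step exists.

    max_jump is an arbitrary threshold chosen for this example.
    """
    reader = csv.reader(io.StringIO(text))
    next(reader)  # skip the header row
    previous = None
    for row in reader:
        step, energy = int(row[0]), float(row[1])
        if previous is not None and abs(energy - previous) > max_jump:
            return step
        previous = energy
    return None


jump_step = first_energy_jump(LOG)
print(jump_step)
```

A check like this only flags where the trajectory went wrong; it says nothing about the cause, which is what the rest of the thread tracks down.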

It might be related to the pytorch version. I have two environments. In the first environment, all packages are installed with conda. In that environment, the problem always happens. Here are the versions of the most relevant packages.

nnpops                    0.4             cuda112py310hd4d1af5_0    conda-forge
openmm                    8.0.0           py310h5728c26_0    conda-forge
openmm-ml                 1.0                pyhd8ed1ab_0    conda-forge
openmm-torch              1.0             cuda112py310hbd91edb_0    conda-forge
pytorch                   1.13.1          cuda112py310he33e0d6_200    conda-forge
torchani                  2.2.2           cuda112py310haf08e2f_7    conda-forge

In the other environment, all the OpenMM related packages are installed from source. In that environment the problem usually does not happen. I have occasionally seen NaNs in it, but they're much less frequent. Here are the package versions.

pytorch                   1.11.0          cuda112py39ha0cca9b_1    conda-forge
pytorch-gpu               1.11.0          cuda112py39h68407e5_1    conda-forge
torchani                  2.2.2           cuda112py39h0dd23f4_5    conda-forge

I tried to downgrade pytorch to 1.11 in the first environment to see if that would fix the problem, but I'm getting version conflicts.

@JohannesKarwou

I don’t know if this is related to your problem, but when using createMixedSystem() for switching a whole waterbox (375 atoms) I ran into errors with boxes blowing up (#52). I think in my case it might be related to the CustomBondForce (https://github.com/openmm/openmm-ml/blob/c3d8c28eb92bf5c4b16efb81ad7a44b707fc5907/openmmml/mlpotential.py#LL298C14-L298C14) not taking PBC into account for the atoms described only by ML. I’m not sure, but could this cause problems here for you too?

@peastman
Member Author

That doesn't sound related. In my case, the error only happens if we use TorchANI instead of NNPOps to implement the TorchForce.

@sef43
Member

sef43 commented Mar 22, 2023

Are you using NNPOps from before or after this PR: openmm/NNPOps#83 ?
Before it, the NNPOps and TorchANI implementations would give differing results. Otherwise I don't really understand why they would be so different. You could do something similar to what Raul did here to narrow down where the difference comes from: openmm/NNPOps#82 (comment)

@peastman
Member Author

The problem isn't in NNPOps. It works correctly. The error happens when using torchani instead.

I'm making progress toward narrowing it down. Here is my current simplest code for reproducing it. It creates a mixed system, evaluates the forces, and then immediately evaluates the forces again. When using torchani 2.2.2 and pytorch 1.13.1, they come out different.

The error requires that only a small part of the system is modeled with ML. In this script I have 2000 atoms, with only 50 being ML. If I reduce it to 1000 atoms it works correctly.

The error requires there to be a NonbondedForce in the system, and for it to use a cutoff. It does not need to use periodic boundary conditions, though.

from openmm import *
from openmm.app import *
from openmm.unit import *
from openmmml import MLPotential
import numpy as np

potential = MLPotential('ani2x')
numParticles = 2000
topology = Topology()
chain = topology.addChain()
residue = topology.addResidue('UNK', chain)
system = System()
nb = NonbondedForce()
nb.setNonbondedMethod(NonbondedForce.CutoffNonPeriodic)
system.addForce(nb)
elements = np.random.choice([element.hydrogen, element.carbon, element.nitrogen, element.oxygen], numParticles)
for i in range(numParticles):
    system.addParticle(1.0)
    nb.addParticle(0, 1, 0)
    topology.addAtom(f'{i}', elements[i], residue)
pos = np.random.random((numParticles, 3))
ml_atoms = list(range(numParticles-50, numParticles))

system2 = potential.createMixedSystem(topology, system, ml_atoms, implementation='torchani')
integrator = LangevinMiddleIntegrator(300*kelvin, 1/picosecond, 0.002*picoseconds)
context = Context(system2, integrator)
context.setPositions(pos)
f1 = context.getState(getForces=True).getForces(asNumpy=True)._value
f2 = context.getState(getForces=True).getForces(asNumpy=True)._value
for i in ml_atoms:
    print(f1[i], f2[i])
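
Rather than comparing the printed pairs by eye, the two evaluations could be compared numerically. The following is a minimal sketch, not part of the original script: `compare_forces` is a hypothetical helper, the tolerances are arbitrary, and synthetic arrays stand in for `f1` and `f2` so the example runs without OpenMM.

```python
import numpy as np


def compare_forces(f1, f2, rtol=1e-3, atol=1e-6):
    """Compare two (N, 3) force arrays.

    Returns (match, worst_index, worst_deviation), where worst_deviation is
    the largest per-atom Euclidean norm of f1 - f2.
    """
    f1 = np.asarray(f1, dtype=float)
    f2 = np.asarray(f2, dtype=float)
    deviation = np.linalg.norm(f1 - f2, axis=1)
    worst = int(np.argmax(deviation))
    match = bool(np.allclose(f1, f2, rtol=rtol, atol=atol))
    return match, worst, float(deviation[worst])


# Synthetic stand-in data: 50 "ML atoms", with one force deliberately perturbed.
rng = np.random.default_rng(0)
forces_a = rng.normal(size=(50, 3))
forces_b = forces_a.copy()
forces_b[7] += 5.0

same = compare_forces(forces_a, forces_a)
different = compare_forces(forces_a, forces_b)
print(same)
print(different)
```

In the script above, the same call would be `compare_forces(f1[ml_atoms], f2[ml_atoms])`, restricting the comparison to the ML atoms.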

@sef43
Member

sef43 commented Apr 3, 2023

This seems to be a bug in PyTorch 1.13. The JIT profile guided optimisation does something to the torchani aev_computer functions that results in incorrect gradients on the position tensor. It does not happen with pytorch 1.12.1, and it does not happen if the torchani aev_computer is replaced by NNPOps TorchANISymmetryFunctions. It only happens on a CUDA device, and it does not occur if the JIT profile guided optimisations are disabled. The script below, when run on a CUDA GPU, demonstrates this.

from openmm import *
from openmm.app import *
from openmm.unit import *
from openmmml import MLPotential
import numpy as np
import torch

# make the test system
potential = MLPotential('ani2x')
numParticles = 2000
topology = Topology()
chain = topology.addChain()
residue = topology.addResidue('UNK', chain)
system = System()
nb = NonbondedForce()
nb.setNonbondedMethod(NonbondedForce.CutoffNonPeriodic)
system.addForce(nb)
elements = np.random.choice([element.hydrogen, element.carbon, element.nitrogen, element.oxygen], numParticles)
for i in range(numParticles):
    system.addParticle(1.0)
    nb.addParticle(0, 1, 0)
    topology.addAtom(f'{i}', elements[i], residue)
pos = np.random.random((numParticles, 3))
ml_atoms = list(range(numParticles-50, numParticles))
system2 = potential.createMixedSystem(topology, system, ml_atoms, implementation='torchani')


# load the pytorch model
# (animodel.pt is assumed to be the file written by createMixedSystem() above)
# CPU version for reference forces
model_cpu = torch.jit.load("animodel.pt", map_location="cpu")
pos_cpu = torch.tensor(pos, requires_grad=True, dtype=torch.float32, device="cpu")
e_cpu = model_cpu(pos_cpu)
e_cpu.backward()
f_cpu = -pos_cpu.grad

# load in a CUDA version of the model
model_cuda = torch.jit.load("animodel.pt",map_location="cuda")

# turn on JIT profile guided optimizations (These will be on by default I think)
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)

# later calls (the optimised ones) will fail to compute correct forces
forces_cuda_jitopt = []
N=5 # num reps
for n in range(N):
    pos_cuda = torch.tensor(pos, requires_grad=True, dtype=torch.float32, device="cuda")
    e_cuda = model_cuda(pos_cuda)
    e_cuda.backward()
    f_cuda = -pos_cuda.grad
    forces_cuda_jitopt.append(f_cuda.cpu().numpy())

# compare
print("compare cuda forces with JIT profile guided optimization enabled")
for n in range(N):
   if np.allclose(f_cpu, forces_cuda_jitopt[n], rtol=1e-3):
      print("n =",n, "forces are correct")
   else:
      print("n = ",n, "forces are wrong!")


# now do the same but disable JIT profile guided optimisations

# load in a CUDA version of the model
model_cuda = torch.jit.load("animodel.pt",map_location="cuda")

torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)


forces_cuda = []
N=5 # num reps
for n in range(N):
    pos_cuda = torch.tensor(pos, requires_grad=True, dtype=torch.float32, device="cuda")

    e_cuda = model_cuda(pos_cuda)
    e_cuda.backward()
    f_cuda = -pos_cuda.grad

    forces_cuda.append(f_cuda.cpu().numpy())

# compare
print("compare cuda forces with JIT profile guided optimization disabled")
for n in range(N):
   if np.allclose(f_cpu, forces_cuda[n], rtol=1e-3):
      print("n =",n, "forces are correct")
   else:
      print("n = ",n, "forces are wrong!")

The output I get on an RTX 3090 with PyTorch 1.13.1 and CUDA 11.7 is this:

compare cuda forces with JIT profile guided optimization enabled
n = 0 forces are correct
n = 1 forces are correct
n =  2 forces are wrong!
n =  3 forces are wrong!
n =  4 forces are wrong!
compare cuda forces with JIT profile guided optimization disabled
n = 0 forces are correct
n = 1 forces are correct
n = 2 forces are correct
n = 3 forces are correct
n = 4 forces are correct

@sef43
Member

sef43 commented Apr 4, 2023

This seems to be fixable for me by changing a ** operation to torch.float_power: aiqm/torchani@172b6fe

@peastman do you get correct forces if you use my fork of torchani?

To install it in an existing environment:

  1. First remove the conda-forge torchani:
    conda remove --force torchani
  2. Then install the fork from GitHub using pip:
    pip install git+https://github.com/sef43/torchani.git@patch_openmmml_issue50

@peastman
Member Author

peastman commented Apr 4, 2023

Your fix works for me. Fantastic work tracking this down! Hopefully they'll release an update soon.

@sef43
Member

sef43 commented Apr 4, 2023

The recommended workaround is to turn off NVFuser: aiqm/torchani#628 (comment)

(I don't know why changing the ** to float_power seems to fix it.)

At this moment, I would recommend disabling NVFuser by running the following:

torch._C._jit_set_nvfuser_enabled(False)

This is the relevant pytorch issue: pytorch/pytorch#84510
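
Because `torch._C._jit_set_nvfuser_enabled` is a private setter that may not exist in every PyTorch build, a script could apply the workaround defensively. The sketch below is an assumption-laden illustration: `disable_nvfuser` is a hypothetical helper name, and the only PyTorch call it relies on is the setter quoted above.

```python
def disable_nvfuser(torch_module):
    """Turn off NVFuser if the private setter named in this thread exists.

    torch_module is the imported torch package. Returns True if the setter
    was found and called, False otherwise. This is a hypothetical helper,
    not part of PyTorch, torchani, or openmm-ml.
    """
    private_api = getattr(torch_module, "_C", None)
    setter = getattr(private_api, "_jit_set_nvfuser_enabled", None)
    if setter is None:
        return False
    setter(False)
    return True


# Typical use, before building the OpenMM Context:
#   import torch
#   disable_nvfuser(torch)
```

Calling it once, before any TorchScript model is executed, should be enough, since the setting is global to the process.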

@sef43
Member

sef43 commented Aug 23, 2023

According to this comment, pytorch/pytorch#84510 (comment), NVFuser is being replaced by NNC. This means that in future PyTorch releases the default TorchScript settings will use NNC, but for the current PyTorch 2.0 we will need to tell people to switch from NVFuser to NNC if they want to use TorchANI without NNPOps.
