OpenMMException when adding TorchForce to the system #134

Open
dmighty007 opened this issue Feb 11, 2024 · 6 comments

dmighty007 commented Feb 11, 2024

I am trying to use TorchForce to bias a simulation of a box full of water molecules. The Torch model that calculates the CV finds the nearest neighbors of each reference water molecule (within a cutoff) and then calculates the pairwise distances between them; these distances are my features. When I add this jitted (TorchScript) model to the OpenMM system, it throws an OpenMMException. I suspect the issue is related to the gradient of the tensor I return from my TorchForce model.
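For reference, the feature described above (distances from each water molecule to its nearest neighbors under periodic boundaries) can be written so that it stays differentiable with respect to the positions. The sketch below is not the reporter's code: it assumes an orthorhombic box given as a length-3 tensor, and `nearest_neighbor_distances` and `k` are hypothetical names. It masks each particle's distance to itself before selecting neighbors with `torch.topk`, so the square root is only ever taken of strictly positive values.

```python
import torch


def nearest_neighbor_distances(pos: torch.Tensor, box: torch.Tensor, k: int) -> torch.Tensor:
    """k smallest minimum-image distances from every particle to the others (N x k).

    Assumes an orthorhombic box: `box` holds the three box lengths.
    """
    n = pos.shape[0]
    diff = pos.unsqueeze(1) - pos.unsqueeze(0)              # N x N x 3 displacement vectors
    diff = diff - box * torch.floor(diff / box + 0.5)       # minimum-image convention
    distsq = (diff * diff).sum(dim=2)                        # N x N squared distances
    # Exclude each particle's distance to itself (exactly zero) before selecting neighbors.
    self_mask = torch.eye(n, dtype=torch.bool)
    distsq = distsq.masked_fill(self_mask, float("inf"))
    nearest_sq, _ = torch.topk(distsq, k, dim=1, largest=False, sorted=True)
    # The square root is applied only to the selected, strictly positive entries.
    return torch.sqrt(nearest_sq)


pos = torch.tensor([[0.05, 0.0, 0.0], [0.95, 0.0, 0.0], [0.5, 0.0, 0.0]], requires_grad=True)
box = torch.tensor([1.0, 1.0, 1.0])
d = nearest_neighbor_distances(pos, box, k=2)
d.sum().backward()
print(d)
print(pos.grad)
```

With `k=100` this would give one 100-wide row per molecule, the input width `Regres_CV2` expects; matching the original features exactly additionally assumes every molecule has at least 100 neighbors inside the 1 nm cutoff, which has not been checked here.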

The model used in TorchForce:


import numpy as np
import torch
import torch.nn as nn


class Regres_CV2(nn.Module):
    def __init__(self):
        """
        Fully connected encoder that maps 100 sorted neighbor distances to a 2-D output.

        input_dim : length of the flattened input feature vector
        hidden1 : number of nodes in hidden layer 1
        hidden2 : number of nodes in hidden layer 2
        hidden3 : number of nodes in hidden layer 3
        hidden4 : dimension of the output (latent) layer
        """
        
        super(Regres_CV2, self).__init__()
        
        self.input_dim = 100
        self.hidden1 = 1024
        self.hidden2 = 512
        self.hidden3 = 100
        self.hidden4 = 2
        
        torch.manual_seed(1)
        np.random.seed(1)
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden1),
            nn.GELU(),
            nn.Linear(self.hidden1, self.hidden2),
            nn.GELU(),
            nn.Linear(self.hidden2, self.hidden3),
            nn.GELU(),
            nn.Linear(self.hidden3, self.hidden4),
        )
        
        
    def transform(self, x):
        mean = torch.tensor([2.71049179, 2.7858348 , 2.8624867 , 2.98495639, 3.7791522 ,
                        3.97462569, 4.10705615, 4.21363918, 4.30592099, 4.38936177,
                        4.46746004, 4.54203554, 4.61539552, 4.68864327, 4.76333205,
                        4.84088127, 4.92357368, 5.01360315, 5.10504332, 5.19608193,
                        5.28742064, 5.37936281, 5.4723322 , 5.56674506, 5.66513154,
                        5.77511278, 5.92270633, 6.02856036, 6.11510288, 6.19097639,
                        6.26000642, 6.32406077, 6.38447937, 6.44241717, 6.4985381 ,
                        6.55356604, 6.60778704, 6.6614738 , 6.71468639, 6.76745529,
                        6.81982973, 6.87181133, 6.92385802, 6.97626398, 7.02942503,
                        7.08359104, 7.139082  , 7.19567081, 7.25226343, 7.30630277,
                        7.35744459, 7.40608602, 7.45265641, 7.49777495, 7.54154908,
                        7.58434985, 7.62622659, 7.66738141, 7.7078045 , 7.74760026,
                        7.78690909, 7.82570807, 7.86406368, 7.90206851, 7.93969587,
                        7.97700164, 8.01400973, 8.05068654, 8.08717092, 8.12329801,
                        8.15917206, 8.1947259 , 8.22994791, 8.26485376, 8.29941423,
                        8.33366086, 8.36766181, 8.40139063, 8.43490873, 8.46820756,
                        8.50136546, 8.53428456, 8.56699235, 8.59918515, 8.6309053 ,
                        8.66224525, 8.69325112, 8.72396482, 8.75443437, 8.78466108,
                        8.81469268, 8.84469551, 8.87456353, 8.9043148 , 8.93397793,
                        8.96364157, 8.99321065, 9.02271019, 9.05217446, 9.08154678])/10.0
        var = torch.tensor([0.06644448, 0.0720813 , 0.09388647, 0.15218329, 0.37277633,
                        0.31781363, 0.27569099, 0.24100772, 0.21193125, 0.18754059,
                        0.16740371, 0.15137762, 0.13928688, 0.13143184, 0.12765353,
                        0.12734585, 0.1298747 , 0.13416642, 0.13855183, 0.14191056,
                        0.1447239 , 0.1472054 , 0.15001784, 0.1522338 , 0.15287467,
                        0.15215965, 0.16101845, 0.1660292 , 0.16650034, 0.16420291,
                        0.16062643, 0.15659865, 0.15245382, 0.14878753, 0.14553498,
                        0.1429001 , 0.14109884, 0.14015729, 0.14015134, 0.141089  ,
                        0.1427195 , 0.1450639 , 0.14823157, 0.15220138, 0.15710731,
                        0.16332683, 0.17076869, 0.17942144, 0.18831005, 0.19544865,
                        0.20027134, 0.2031551 , 0.20431226, 0.20431089, 0.20333588,
                        0.20176965, 0.19970677, 0.19729833, 0.194629  , 0.1918973 ,
                        0.18898873, 0.1861282 , 0.18330681, 0.18048504, 0.17777579,
                        0.1751567 , 0.17269368, 0.17036853, 0.1680993 , 0.16591258,
                        0.16380179, 0.16170763, 0.15970761, 0.15767716, 0.15568707,
                        0.15367787, 0.15177837, 0.14989855, 0.14813889, 0.1464694 ,
                        0.14496545, 0.14371142, 0.14268205, 0.14166094, 0.14057099,
                        0.13943382, 0.13808576, 0.13678063, 0.13545642, 0.13413571,
                        0.13285683, 0.1315521 , 0.13035399, 0.12925464, 0.12816313,
                        0.12720239, 0.12632232, 0.12558945, 0.12489357, 0.12431238])/10.0
        return (x - mean)/var

    def periodic_neighbours(self, pos, maxdist, L):
        """
        Finds periodic neighbors within a given maximum distance in a PyTorch tensor context.

        Args:
            pos (torch.Tensor): Positions of particles (N x D tensor).
            maxdist (float or torch.Tensor): Maximum distance for neighbors.
            L (torch.Tensor): Box dimensions (D-dimensional tensor).

        Returns:
            torch.Tensor: Indices of neighbor pairs (M x 2 tensor), including each
                particle paired with itself and both orderings (i, j) and (j, i).
            torch.Tensor: Distances to neighbors (M-dimensional tensor).
        """

        maxdistsq = maxdist**2
        rL = 1. / L  # Inverse box dimensions (D-dimensional tensor)

        # Calculate pairwise squared distances using broadcasting
        diff = pos.unsqueeze(1) - pos.unsqueeze(0)  # N x N x D
        diff_wrapped = diff - L.unsqueeze(0) * torch.floor(diff * rL.unsqueeze(0) + 0.5)
        distsq = torch.sum(diff_wrapped * diff_wrapped, dim=2)  # N x N

        # Keep only pairs closer than maxdist and take the square root of those entries
        dists = torch.sqrt(distsq[distsq < maxdistsq])

        # Indices of the same pairs; boolean masking and torch.where both use
        # row-major order, so the rows of bonds line up with dists
        idx_1, idx_2 = torch.where(distsq < maxdistsq)
        bonds = torch.stack([idx_1, idx_2], dim=1)  # M x 2

        return bonds, dists

    def nnDistance(self, i, bonds, dist):
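        # i is a 0-d index tensor, so it broadcasts directly against the bond columns
        i_tensor = i

        # Select every pair that involves molecule i and keep the 201 smallest distances
        neighbor_idx = torch.where((bonds[:, 0] == i_tensor) | (bonds[:, 1] == i_tensor))[0]
        distances = dist[neighbor_idx]
        sorted_distances, _ = torch.topk(distances, 201, dim=0, largest=False, sorted=True)

        # The sorted list starts with the self-distance (0) and then contains each neighbor
        # twice, as (i, j) and (j, i); [1::2] keeps one copy of each of the 100 nearest
        return sorted_distances[1::2]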
        
    def getForOneFrame(self, bonds, dist):
        num_frames = 2000  # number of reference molecules (hard-coded)
        indices = torch.arange(num_frames)
        # One row of 100 nearest-neighbor distances per molecule -> (2000, 100) tensor
        features = torch.stack([self.nnDistance(i, bonds, dist) for i in indices])
        return features
    
    def forward(self, positions, boxvectors):
        # Diagonal of the (orthorhombic) box vectors
        box = torch.tensor([float(boxvectors[i][i]) for i in range(3)])
        # Every 4th atom, i.e. one site per molecule of a 4-site water model (assumed)
        pos = positions[::4].to("cpu")
        bonds, dist = self.periodic_neighbours(pos, torch.tensor([1.0]), box)
        # Note: .detach() cuts the autograd graph between the features and `positions`
        features = self.getForOneFrame(bonds, dist).detach()
        x = self.transform(features).to(torch.float)
        y = self.encoder(x)[:,1].sum()
        return y
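A related pitfall may appear once gradients do reach the positions (an editorial observation, not something reported in this issue): `periodic_neighbours` keeps each molecule's distance to itself, which is exactly zero, and `nnDistance` only discards it afterwards with the `[1::2]` slice. The derivative of `torch.sqrt` at zero is infinite, so even a zero upstream gradient becomes 0 / 0 = NaN in the backward pass. The toy example below (two particles, no periodic wrapping; `pair_distance_grad` is a hypothetical helper) reproduces this and shows one common workaround, clamping the squared distances away from zero before the square root.

```python
import torch


def pair_distance_grad(clamp: bool) -> torch.Tensor:
    """Gradient of the 0-1 distance w.r.t. positions, computed from a full N x N matrix."""
    pos = torch.tensor([[0.0, 0.0, 0.0], [0.3, 0.4, 0.0]], requires_grad=True)
    diff = pos.unsqueeze(1) - pos.unsqueeze(0)  # N x N x 3; the diagonal is exactly zero
    distsq = (diff * diff).sum(dim=2)            # N x N squared distances
    if clamp:
        # Keep sqrt away from zero; clamp_min passes zero gradient to clamped entries.
        distsq = distsq.clamp_min(1e-12)
    dist = torch.sqrt(distsq)
    # Only the off-diagonal entry is used, so the self-distances receive a zero
    # upstream gradient, which sqrt's backward turns into NaN when unclamped.
    dist[0, 1].backward()
    return pos.grad


print(pair_distance_grad(clamp=False))  # contains NaN
print(pair_distance_grad(clamp=True))   # finite
```

In `periodic_neighbours` the analogous change would be something like `torch.sqrt(distsq.clamp_min(1e-12))`; the self-distance would then become 1e-6 instead of 0, which still sorts first, so the `[1::2]` slicing would be unaffected. This is an untested suggestion for this particular CV.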

The OpenMM simulation (with metadynamics):


import os, sys
import openmm as mm
import openmm.app as app
from openmm.app import GromacsGroFile, GromacsTopFile
from openmm.app import StateDataReporter
from openmm.app import XTCReporter
from openmm.app import PME, HBonds
from openmm.unit import nanometer, kelvin, picosecond, picoseconds, bar, kilojoules_per_mole
from openmmtorch import TorchForce
from openmm.app import BiasVariable
from openmm.app.metadynamics import Metadynamics

gro = GromacsGroFile('/home/dm/Dibyendu/Projects/ICE_AE/Phase_Space_Scaling/Liquid_test/md.gro')
top = GromacsTopFile('/home/dm/Dibyendu/Projects/ICE_AE/Phase_Space_Scaling/Liquid_test/topol.top',
                        periodicBoxVectors=gro.getPeriodicBoxVectors(),
                    includeDir='/home/dm/Soft/GMX22/share/gromacs/top/')

##### Create the OpenMM System based on the topology
system = top.createSystem(nonbondedMethod=PME, nonbondedCutoff=1*nanometer, constraints=HBonds)

##### Remove MM forces
while len(system.getForces()) > 0:
    system.removeForce(0)

force = TorchForce('forcemodel2.pt')
force.setUsesPeriodicBoundaryConditions(periodic=True)
psi = BiasVariable(force, -5, 5, 0.5, True)

meta = Metadynamics(system, [ psi],
                    250*kelvin, 1.2, 1.2*kilojoules_per_mole, 100)

platform = mm.Platform.getPlatformByName('CPU')

# Specify properties for CUDA platform (e.g., mixed precision)
#prop = dict(CudaPrecision='mixed')  # Use mixed single/double precision

# Add thermostat and barostat forces to the system
system.addForce(mm.AndersenThermostat(250*kelvin, 1/picosecond))
system.addForce(mm.MonteCarloBarostat(1*bar, 250*kelvin))

# Create Langevin integrator for molecular dynamics simulation
integrator = mm.LangevinMiddleIntegrator(250*kelvin, 1/picosecond, 0.0001*picoseconds)

# Create the OpenMM Simulation object
sim = app.Simulation(top.topology, system, integrator, platform)
config = gro
sim.context.setPositions(config.positions)
sim.minimizeEnergy()
meta.step(sim, 50000)
#sim.step(50000)
reporter = StateDataReporter(file=sys.stdout, reportInterval=100, step=True, time=True, potentialEnergy=True, temperature=True)
sim.reporters.append(reporter)  # attached after meta.step(), so it does not report during that run

The error it throws:


---------------------------------------------------------------------------
OpenMMException                           Traceback (most recent call last)
/tmp/ipykernel_357599/4025831078.py in ?()
     32 # Create the OpenMM Simulation object
     33 sim = app.Simulation(top.topology, system, integrator, platform)
     34 config = gro
     35 sim.context.setPositions(config.positions)
---> 36 sim.minimizeEnergy()
     37 meta.step(sim, 50000)
     38 #sim.step(50000)
     39 reporter = StateDataReporter(file=sys.stdout, reportInterval=100, step=True, time=True, potentialEnergy=True, temperature=True)

~/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/app/simulation.py in ?(self, tolerance, maxIterations, reporter)
    139         reporter : MinimizationReporter = None
    140             an optional reporter to invoke after each iteration.  This can be used to monitor the progress
    141             of minimization or to stop minimization early.
    142         """
--> 143         mm.LocalEnergyMinimizer.minimize(self.context, tolerance, maxIterations, reporter)

~/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/openmm.py in ?(context, tolerance, maxIterations, reporter)
   4421         if unit.is_quantity(tolerance):
   4422             tolerance = tolerance.value_in_unit(unit.kilojoules_per_mole/unit.nanometer)
   4423 
   4424 
-> 4425         return _openmm.LocalEnergyMinimizer_minimize(context, tolerance, maxIterations, reporter)

OpenMMException: Expected a proper Tensor but got None (or an undefined Tensor in C++) for argument #0 'self'
Exception raised from checked_cast_variable at /home/conda/feedstock_root/build_artifacts/libtorch_1706712676143/work/torch/csrc/autograd/VariableTypeManual.cpp:60 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xaa (0x77db8353587a in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x77db834eac7e in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libc10.so)
frame #2: <unknown function> + 0x4a04cda (0x77db2fa04cda in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #3: <unknown function> + 0x36a1f5a (0x77db2e6a1f5a in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #4: <unknown function> + 0x36a25d0 (0x77db2e6a25d0 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #5: at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>) + 0x1f7 (0x77db2cacd127 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #6: at::native::to(at::Tensor const&, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) + 0xbc (0x77db2c5e311c in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #7: <unknown function> + 0x2446ea5 (0x77db2d446ea5 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #8: at::_ops::to_dtype::call(at::Tensor const&, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) + 0x188 (0x77db2cc44b98 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/../libtorch_cpu.so)
frame #9: TorchPlugin::ReferenceCalcTorchForceKernel::execute(OpenMM::ContextImpl&, bool, bool) + 0xac0 (0x77db82114790 in /home/dm/Soft/miniconda3/envs/torch/lib/plugins/libOpenMMTorchReference.so)
frame #10: OpenMM::ContextImpl::calcForcesAndEnergy(bool, bool, int) + 0xc9 (0x77db7bb20109 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #11: OpenMM::ReferenceCustomCVForce::calculateIxn(OpenMM::ContextImpl&, std::vector<OpenMM::Vec3, std::allocator<OpenMM::Vec3> >&, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, double, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, double> > > const&, std::vector<OpenMM::Vec3, std::allocator<OpenMM::Vec3> >&, double*, std::map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, double, std::less<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, double> > >&) + 0xfc (0x77db7bc3363c in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #12: OpenMM::ReferenceCalcCustomCVForceKernel::execute(OpenMM::ContextImpl&, OpenMM::ContextImpl&, bool, bool) + 0x2ff (0x77db7bc0b06f in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #13: OpenMM::ContextImpl::calcForcesAndEnergy(bool, bool, int) + 0xc9 (0x77db7bb20109 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #14: OpenMM::Context::getState(int, bool, int) const + 0x122 (0x77db7bb1d772 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #15: <unknown function> + 0x18847b (0x77db7bb8847b in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #16: <unknown function> + 0x188c26 (0x77db7bb88c26 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #17: lbfgs + 0x584 (0x77db7bbe8444 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #18: OpenMM::LocalEnergyMinimizer::minimize(OpenMM::Context&, double, int, OpenMM::MinimizationReporter*) + 0x7d9 (0x77db7bb89769 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/../../../libOpenMM.so.8.1)
frame #19: <unknown function> + 0x1293ae (0x77db887293ae in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/site-packages/openmm/_openmm.cpython-310-x86_64-linux-gnu.so)
frame #20: <unknown function> + 0x144468 (0x620ee873e468 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #21: _PyObject_MakeTpCall + 0x26b (0x620ee873797b in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #22: _PyEval_EvalFrameDefault + 0x54b6 (0x620ee87338c6 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #23: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x4c12 (0x620ee8733022 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #25: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #26: _PyEval_EvalFrameDefault + 0x72c (0x620ee872eb3c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #27: <unknown function> + 0x1d7870 (0x620ee87d1870 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #28: PyEval_EvalCode + 0x87 (0x620ee87d17b7 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #29: <unknown function> + 0x1de9ba (0x620ee87d89ba in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #30: <unknown function> + 0x144a93 (0x620ee873ea93 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #31: _PyEval_EvalFrameDefault + 0x320 (0x620ee872e730 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #32: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #33: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #34: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #35: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #36: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #37: <unknown function> + 0x1f55f7 (0x620ee87ef5f7 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #38: <unknown function> + 0x14f3bd (0x620ee87493bd in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #39: _PyEval_EvalFrameDefault + 0x72c (0x620ee872eb3c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #40: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #41: _PyEval_EvalFrameDefault + 0x320 (0x620ee872e730 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #42: _PyFunction_Vectorcall + 0x6c (0x620ee873e8cc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #43: _PyEval_EvalFrameDefault + 0x72c (0x620ee872eb3c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #44: <unknown function> + 0x150402 (0x620ee874a402 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #45: PyObject_Call + 0xbc (0x620ee874ad9c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #46: _PyEval_EvalFrameDefault + 0x2d84 (0x620ee8731194 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #47: <unknown function> + 0x150402 (0x620ee874a402 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #48: _PyEval_EvalFrameDefault + 0x13cc (0x620ee872f7dc in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #49: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #50: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #51: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #52: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #53: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #54: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #55: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #56: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #57: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #58: _PyEval_EvalFrameDefault + 0x1bc0 (0x620ee872ffd0 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #59: <unknown function> + 0x1e0fe4 (0x620ee87dafe4 in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #60: <unknown function> + 0x7bf6 (0x77db8ddf3bf6 in /home/dm/Soft/miniconda3/envs/torch/lib/python3.10/lib-dynload/_asyncio.cpython-310-x86_64-linux-gnu.so)
frame #61: <unknown function> + 0x143d2a (0x620ee873dd2a in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #62: <unknown function> + 0x25f22c (0x620ee885922c in /home/dm/Soft/miniconda3/envs/torch/bin/python)
frame #63: <unknown function> + 0xfda7b (0x620ee86f7a7b in /home/dm/Soft/miniconda3/envs/torch/bin/python)
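Frames #8 and #9 show the failure inside `TorchPlugin::ReferenceCalcTorchForceKernel::execute`, in a dtype conversion (`to_dtype`) applied to a tensor that is undefined. One plausible reading (not confirmed in this issue) is that the plugin obtains forces by backpropagating the energy to the positions, and that gradient is `None` because `forward` detaches the features; backward itself can still succeed because the encoder's weights require gradients. This can be checked outside OpenMM before saving the model. The sketch below assumes the plugin calls `forward(positions, boxvectors)` with an N x 3 positions tensor and a 3 x 3 box tensor (float32 here); `ToyCV` and `position_grad_ok` are hypothetical names.

```python
import torch


class ToyCV(torch.nn.Module):
    """Hypothetical stand-in for a CV model with a (positions, boxvectors) signature."""

    def __init__(self, detach_features: bool):
        super().__init__()
        self.detach_features = detach_features
        self.head = torch.nn.Linear(1, 1)  # trainable weights, like the encoder

    def forward(self, positions: torch.Tensor, boxvectors: torch.Tensor) -> torch.Tensor:
        # Toy feature: squared distance between the first two particles.
        feature = (positions[1] - positions[0]).pow(2).sum().reshape(1, 1)
        if self.detach_features:
            feature = feature.detach()  # same effect as .detach() in Regres_CV2.forward
        return self.head(feature).sum()


def position_grad_ok(module: torch.nn.Module, n_atoms: int = 8) -> bool:
    """Script the module, run forward + backward, and check d(energy)/d(positions)."""
    scripted = torch.jit.script(module)
    positions = torch.rand(n_atoms, 3, dtype=torch.float32, requires_grad=True)
    boxvectors = torch.eye(3) * 2.0
    energy = scripted(positions, boxvectors)
    energy.backward()  # succeeds in both cases, since the Linear weights require grad
    grad = positions.grad
    return grad is not None and bool(torch.isfinite(grad).all())


print(position_grad_ok(ToyCV(detach_features=False)))  # True
print(position_grad_ok(ToyCV(detach_features=True)))   # False
```

Applying the same check to `Regres_CV2` itself would need a realistic configuration, because `forward` hard-codes 2000 molecules and `nnDistance` needs at least 100 neighbors within 1 nm of each.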

My conda environment:

conda list:


# packages in environment at /home/dm/Soft/miniconda3/envs/torch:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
anyio                     4.2.0              pyhd8ed1ab_0    conda-forge
argon2-cffi               23.1.0             pyhd8ed1ab_0    conda-forge
argon2-cffi-bindings      21.2.0          py310h2372a71_4    conda-forge
arrow                     1.3.0              pyhd8ed1ab_0    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
async-lru                 2.0.4              pyhd8ed1ab_0    conda-forge
attrs                     23.2.0             pyh71513ae_0    conda-forge
babel                     2.14.0             pyhd8ed1ab_0    conda-forge
beautifulsoup4            4.12.3             pyha770c72_0    conda-forge
blas                      1.0                         mkl
bleach                    6.1.0              pyhd8ed1ab_0    conda-forge
brotli-python             1.1.0           py310hc6cd4ac_1    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
cffi                      1.16.0          py310h2fee648_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
comm                      0.2.1              pyhd8ed1ab_0    conda-forge
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0              h4ba93d1_13    conda-forge
cudnn                     8.8.0.121            hcdd5f01_4    conda-forge
debugpy                   1.8.1           py310hc6cd4ac_0    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
entrypoints               0.4                pyhd8ed1ab_0    conda-forge
exceptiongroup            1.2.0              pyhd8ed1ab_2    conda-forge
executing                 2.0.1              pyhd8ed1ab_0    conda-forge
filelock                  3.13.1          py310h06a4308_0
fqdn                      1.5.1              pyhd8ed1ab_0    conda-forge
fsspec                    2024.2.0           pyhca7485f_0    conda-forge
gmp                       6.2.1                h295c915_3
gmpy2                     2.1.2           py310heeb90bb_0
h11                       0.14.0             pyhd8ed1ab_0    conda-forge
h2                        4.1.0              pyhd8ed1ab_0    conda-forge
hpack                     4.0.0              pyh9f0ad1d_0    conda-forge
httpcore                  1.0.2              pyhd8ed1ab_0    conda-forge
httpx                     0.26.0             pyhd8ed1ab_0    conda-forge
hyperframe                6.0.1              pyhd8ed1ab_0    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
idna                      3.6                pyhd8ed1ab_0    conda-forge
importlib-metadata        7.0.1              pyha770c72_0    conda-forge
importlib_metadata        7.0.1                hd8ed1ab_0    conda-forge
importlib_resources       6.1.1              pyhd8ed1ab_0    conda-forge
intel-openmp              2022.0.1          h06a4308_3633
ipykernel                 6.29.2             pyhd33586a_0    conda-forge
ipython                   8.21.0             pyh707e725_0    conda-forge
ipywidgets                8.1.2              pyhd8ed1ab_0    conda-forge
isoduration               20.11.0            pyhd8ed1ab_0    conda-forge
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
jinja2                    3.1.3           py310h06a4308_0
joblib                    1.3.2              pyhd8ed1ab_0    conda-forge
json5                     0.9.14             pyhd8ed1ab_0    conda-forge
jsonpointer               2.4             py310hff52083_3    conda-forge
jsonschema                4.21.1             pyhd8ed1ab_0    conda-forge
jsonschema-specifications 2023.12.1          pyhd8ed1ab_0    conda-forge
jsonschema-with-format-nongpl 4.21.1             pyhd8ed1ab_0    conda-forge
jupyter                   1.0.0             pyhd8ed1ab_10    conda-forge
jupyter-lsp               2.2.2              pyhd8ed1ab_0    conda-forge
jupyter_client            8.6.0              pyhd8ed1ab_0    conda-forge
jupyter_console           6.6.3              pyhd8ed1ab_0    conda-forge
jupyter_core              5.7.1           py310hff52083_0    conda-forge
jupyter_events            0.9.0              pyhd8ed1ab_0    conda-forge
jupyter_server            2.12.5             pyhd8ed1ab_0    conda-forge
jupyter_server_terminals  0.5.2              pyhd8ed1ab_0    conda-forge
jupyterlab                4.1.0              pyhd8ed1ab_0    conda-forge
jupyterlab_pygments       0.3.0              pyhd8ed1ab_1    conda-forge
jupyterlab_server         2.25.2             pyhd8ed1ab_0    conda-forge
jupyterlab_widgets        3.0.10             pyhd8ed1ab_0    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libabseil                 20230802.1      cxx17_h59595ed_0    conda-forge
libblas                   3.9.0           1_h86c2bf4_netlib    conda-forge
libcblas                  3.9.0           5_h92ddd45_netlib    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.2.0               h807b86a_5    conda-forge
libgfortran-ng            13.2.0               h69a702a_5    conda-forge
libgfortran5              13.2.0               ha4646dd_5    conda-forge
libgomp                   13.2.0               h807b86a_5    conda-forge
libhwloc                  2.9.3           default_h554bfaf_1009    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
liblapack                 3.9.0           5_h92ddd45_netlib    conda-forge
libmagma                  2.7.2                h09b5827_2    conda-forge
libmagma_sparse           2.7.2                h09b5827_2    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libprotobuf               4.25.1               hf27288f_1    conda-forge
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libsqlite                 3.45.1               h2797004_0    conda-forge
libstdcxx-ng              13.2.0               h7e041cc_5    conda-forge
libtorch                  2.1.2           cuda118_h12fe058_301    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.47.0               hd590300_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.5               h232c23b_0    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               17.0.6               h4dfa4b3_0    conda-forge
magma                     2.7.2                h4aca40b_2    conda-forge
markupsafe                2.1.3           py310h5eee18b_0
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
mistune                   3.0.2              pyhd8ed1ab_0    conda-forge
mkl                       2023.2.0         h84fe81f_50496    conda-forge
mpc                       1.1.0                h10f8cd9_1
mpfr                      4.0.2                hb69a4c5_1
mpmath                    1.3.0           py310h06a4308_0
nbclient                  0.8.0              pyhd8ed1ab_0    conda-forge
nbconvert                 7.16.0             pyhd8ed1ab_0    conda-forge
nbconvert-core            7.16.0             pyhd8ed1ab_0    conda-forge
nbconvert-pandoc          7.16.0             pyhd8ed1ab_0    conda-forge
nbformat                  5.9.2              pyhd8ed1ab_0    conda-forge
nccl                      2.19.4.1             h6103f9b_0    conda-forge
ncurses                   6.4                  h59595ed_2    conda-forge
nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
networkx                  3.1             py310h06a4308_0
notebook                  7.0.7              pyhd8ed1ab_0    conda-forge
notebook-shim             0.2.3              pyhd8ed1ab_0    conda-forge
numpy                     1.26.4          py310hb13e2d6_0    conda-forge
ocl-icd                   2.3.1                h7f98852_0    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openmm                    8.1.1           py310h43b6314_1    conda-forge
openmm-torch              1.4             cuda118py310hde6f947_3    conda-forge
openssl                   3.2.1                hd590300_0    conda-forge
overrides                 7.7.0              pyhd8ed1ab_0    conda-forge
packaging                 23.2               pyhd8ed1ab_0    conda-forge
pandoc                    3.1.11.1             ha770c72_0    conda-forge
pandocfilters             1.5.0              pyhd8ed1ab_0    conda-forge
parso                     0.8.3              pyhd8ed1ab_0    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
pkgutil-resolve-name      1.3.10             pyhd8ed1ab_1    conda-forge
platformdirs              4.2.0              pyhd8ed1ab_0    conda-forge
prometheus_client         0.19.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.42             pyha770c72_0    conda-forge
prompt_toolkit            3.0.42               hd8ed1ab_0    conda-forge
psutil                    5.9.8           py310h2372a71_0    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pygments                  2.17.2             pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.10.13         hd12c33a_1_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-fastjsonschema     2.19.1             pyhd8ed1ab_0    conda-forge
python-json-logger        2.0.7              pyhd8ed1ab_0    conda-forge
python_abi                3.10                    4_cp310    conda-forge
pytorch                   2.1.2           cuda118_py310h59774e7_301    conda-forge
pytorch-mutex             1.0                         cpu    pytorch
pytz                      2024.1             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0.1           py310h2372a71_1    conda-forge
pyzmq                     25.1.2          py310h795f18f_0    conda-forge
qtconsole-base            5.5.1              pyha770c72_0    conda-forge
qtpy                      2.4.1              pyhd8ed1ab_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
referencing               0.33.0             pyhd8ed1ab_0    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
rfc3339-validator         0.1.4              pyhd8ed1ab_0    conda-forge
rfc3986-validator         0.1.1              pyh9f0ad1d_0    conda-forge
rpds-py                   0.17.1          py310hcb5633a_0    conda-forge
scikit-learn              1.4.0           py310h1fdf081_0    conda-forge
scipy                     1.12.0          py310hb13e2d6_2    conda-forge
send2trash                1.8.2              pyh41d4057_0    conda-forge
setuptools                69.0.3             pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
sniffio                   1.3.0              pyhd8ed1ab_0    conda-forge
soupsieve                 2.5                pyhd8ed1ab_1    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
sympy                     1.12            py310h06a4308_0
tbb                       2021.11.0            h00ab1b0_1    conda-forge
terminado                 0.18.0             pyh0d859eb_0    conda-forge
threadpoolctl             3.2.0              pyha21a80b_0    conda-forge
tinycss2                  1.2.1              pyhd8ed1ab_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
tornado                   6.3.3           py310h2372a71_1    conda-forge
traitlets                 5.14.1             pyhd8ed1ab_0    conda-forge
types-python-dateutil     2.8.19.20240106    pyhd8ed1ab_0    conda-forge
typing-extensions         4.9.0           py310h06a4308_1
typing_extensions         4.9.0           py310h06a4308_1
typing_utils              0.1.0              pyhd8ed1ab_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
uri-template              1.3.0              pyhd8ed1ab_0    conda-forge
urllib3                   2.2.0              pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
webcolors                 1.13               pyhd8ed1ab_0    conda-forge
webencodings              0.5.1              pyhd8ed1ab_2    conda-forge
websocket-client          1.7.0              pyhd8ed1ab_0    conda-forge
wheel                     0.42.0             pyhd8ed1ab_0    conda-forge
widgetsnbextension        4.0.10             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zeromq                    4.3.5                h59595ed_0    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zstd                      1.5.5                hfc55251_0    conda-forge

Used mamba to install openmm-torch.
I suppose this is a question and not likely a bug. It would be really helpful if you could find where I am making a mistake!

@dmighty007
Author

Please let me know if you require any more information.

@RaulPPelaez
Contributor

I believe this is an error with your model, in particular because of lines like these:

        pos = positions[::4].to("cpu")
...
        y = self.encoder(x)[:,1].sum()

PyTorch does not like it when you run backward only on a subset of the output.
To test this, try running backward on the model alone (no OpenMM or openmm-torch involved). My guess is that you will see a similar error. Something like this:

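import torch
pos = torch.rand(10, 3)                  # random test positions, shape (n_atoms, 3)
box = torch.eye(3) * 10                  # 3x3 box tensor, the form openmm-torch passes in
model = torch.jit.load("model.pt")
pos.requires_grad_()                     # track gradients with respect to the positions
y = model(pos, box)
y.backward()  # compute gradients
print(pos.grad)  # None means no gradient reached the positions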
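One common way to end up with pos.grad being None is for the positions to leave the autograd graph somewhere inside forward, for example by passing through NumPy. The sketch below uses two hypothetical toy models (ConnectedCV and DisconnectedCV, not the model from this issue) to show the difference: both call backward() without error because their nn.Linear weights require grad, but only the one that stays in torch produces a gradient on the positions.

```python
import numpy as np
import torch
import torch.nn as nn


class ConnectedCV(nn.Module):
    """Hypothetical toy CV: pairwise distances computed entirely in torch."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(1, 1)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        d = torch.pdist(positions)  # differentiable pairwise distances
        return self.head(d.mean().reshape(1, 1)).sum()


class DisconnectedCV(nn.Module):
    """Same toy CV, but the distances go through NumPy, which autograd cannot track."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(1, 1)

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        p = positions.detach().numpy()  # the positions leave the autograd graph here
        diff = p[:, None, :] - p[None, :, :]
        mean_d = np.sqrt((diff ** 2).sum(-1)).mean()
        d = torch.tensor(mean_d, dtype=positions.dtype)  # brand-new tensor, no history
        return self.head(d.reshape(1, 1)).sum()


def position_grad(model: nn.Module, n_atoms: int = 10):
    pos = torch.rand(n_atoms, 3, requires_grad=True)
    # backward() succeeds in both cases because the Linear weights require grad
    model(pos).backward()
    return pos.grad


torch.manual_seed(0)
print(position_grad(ConnectedCV()) is None)     # False: forces can be computed
print(position_grad(DisconnectedCV()) is None)  # True: no gradient reaches the positions
```

If the real model behaves like DisconnectedCV, TorchForce has no way to turn the CV into forces on the atoms.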

As a side note, the box is already passed to your model as a 3x3 PyTorch tensor, so you should not need to convert it. You can extract its diagonal with box.diag().
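For a water box, the diagonal is typically what a neighbor or distance calculation needs for periodic wrapping. Below is a minimal sketch of minimum-image pair distances that works directly on the 3x3 box tensor. It assumes a rectangular (orthorhombic) box, so box.diag() holds the three edge lengths; the function name periodic_pair_distances is hypothetical, and triclinic boxes would need a different wrapping step.

```python
import torch


def periodic_pair_distances(positions: torch.Tensor, box: torch.Tensor) -> torch.Tensor:
    """Distances between all unique atom pairs using the minimum-image convention.

    Assumes a rectangular box: the 3x3 box tensor is diagonal, so box.diag()
    gives the three edge lengths.
    """
    lengths = box.diag()                                    # (3,) edge lengths
    n = positions.shape[0]
    i, j = torch.triu_indices(n, n, offset=1)               # each pair once, no self-pairs
    delta = positions[i] - positions[j]                     # (n_pairs, 3)
    delta = delta - lengths * torch.round(delta / lengths)  # wrap into [-L/2, L/2]
    return torch.linalg.norm(delta, dim=-1)


box = torch.eye(3) * 10.0
pos = torch.tensor([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]], requires_grad=True)
d = periodic_pair_distances(pos, box)
print(d)         # the two atoms are 1.0 apart across the boundary, not 9.0
d.sum().backward()
print(pos.grad)  # gradients flow back to the positions through the wrapping
```

torch.round has a zero gradient, so the wrapping shift does not interfere with the forces; the gradient of each wrapped displacement with respect to the positions is the same as for the unwrapped one.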

@dmighty007
Author

Thanks! I get your point. It does return None, which makes sense. I'll try to modify the model to operate on the whole system. But is there any trick to do certain operations on a subset of the positions? Say I want only the oxygen atoms of my system!

@dmighty007
Author

Also, I did notice box.diag() in the documentation, but was too lazy to change it :).

@RaulPPelaez
Contributor

You want your TorchForce to act only on a subset of the system?
As an easy workaround, you can just multiply the ones you do not want by zero, right?
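One way to apply the multiply-by-zero idea is sketched below. Rather than zeroing the coordinates themselves (which, if the model still computes distances over all atoms, would effectively place those atoms at the origin), it computes every pair and multiplies the non-oxygen pairs by zero, so hydrogens receive exactly zero force. The class name OxygenOnlyPairEnergy and the toy energy (a plain sum of O-O distances) are hypothetical stand-ins for the real CV and bias, and the sketch assumes each water is stored as O, H, H so that oxygens sit at indices 0, 3, 6, and so on; check this against your actual topology.

```python
import torch


class OxygenOnlyPairEnergy(torch.nn.Module):
    """Toy bias that acts only on O-O pairs via a 0/1 pair mask.

    Assumption (hypothetical): waters are stored as O, H, H, so oxygens are at
    indices 0, 3, 6, ...
    """

    def __init__(self, n_atoms: int, atoms_per_water: int = 3):
        super().__init__()
        is_oxygen = torch.zeros(n_atoms)
        is_oxygen[::atoms_per_water] = 1.0
        i, j = torch.triu_indices(n_atoms, n_atoms, offset=1)
        # Buffers follow the module across devices (e.g. .to("cuda"))
        self.register_buffer("i", i)
        self.register_buffer("j", j)
        # 1.0 for O-O pairs, 0.0 for any pair that involves a hydrogen
        self.register_buffer("pair_mask", is_oxygen[i] * is_oxygen[j])

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        d = torch.linalg.norm(positions[self.i] - positions[self.j], dim=-1)
        return (self.pair_mask * d).sum()  # non-O-O pairs are multiplied by zero


torch.manual_seed(0)
model = OxygenOnlyPairEnergy(n_atoms=6)  # two hypothetical waters
pos = torch.rand(6, 3, requires_grad=True)
model(pos).backward()
print(pos.grad)  # rows 1, 2, 4, 5 (the hydrogens) are zero
```

Because the whole position tensor still enters the graph, backward reaches every atom; the mask simply makes the unwanted contributions vanish.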

@dmighty007
Author

Yes. Thanks! I'll try that.
