Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve explosions in memory and time during inference #102

Merged
merged 15 commits into from
Apr 5, 2024
7 changes: 5 additions & 2 deletions openff/nagl/features/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,14 +282,17 @@ def _encode(self, molecule: "Molecule") -> torch.Tensor:
lowest_energy_only=True,
include_all_transfer_pathways=False,
as_dicts=True,
as_fragments=True,
)
formal_charges: typing.List[float] = []
for index in range(molecule.n_atoms):
charges = [
graph["atoms"][index]["formal_charge"] for graph in resonance_forms
graph["atoms"][index]["formal_charge"]
for graph in resonance_forms
if index in graph["atoms"]
]
if not charges:
molecule.atoms[index].formal_charge
charges = [molecule.atoms[index].formal_charge]

charges = [q.m_as(unit.elementary_charge) for q in charges]
charge = np.mean(charges)
Expand Down
6 changes: 2 additions & 4 deletions openff/nagl/molecule/_graph/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openff.nagl.features._featurizers import AtomFeaturizer, BondFeaturizer

import networkx as nx
import numpy as np
import torch
from ._batch import FrameDict

Expand Down Expand Up @@ -136,13 +137,10 @@ def number_of_nodes(self):

def in_edges(self, nodes, form="uv"):
u, v, i = self._all_edges()
# mask = [x in nodes for x in v]

mask = []
for node in nodes:
for i_, v_ in enumerate(v):
if v_ == node:
mask.append(i_)
mask.extend(np.where(v == node)[0])

U, V, I = u[mask], v[mask], i[mask]

Expand Down
16 changes: 16 additions & 0 deletions openff/nagl/tests/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,22 @@ def test_load_and_compute(self, smiles):

assert_allclose(computed, desired, atol=1e-5)

def test_protein_computable(self):
"""
Test that working with moderately sized protein
is feasible in time and memory.
See Issue #101
"""
from openff.toolkit import Molecule

model = GNNModel.load(EXAMPLE_AM1BCC_MODEL, eval_mode=True)

protein = Molecule.from_smiles(
"CC[C@H](C)[C@H](NC(=O)CNC(=O)CNC(=O)[C@H](CCC(N)=O)NC(=O)[C@H](CCC(N)=O)NC(=O)[C@H](CC(=O)[O-])NC(=O)[C@H](CC(=O)[O-])NC(=O)[C@H](CCCNC(N)=[NH2+])NC(=O)CNC(=O)[C@H](CS)NC(=O)[C@@H](NC(=O)[C@H](CCCNC(N)=[NH2+])NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@@H](NC(=O)[C@H](CC(N)=O)NC(=O)[C@H](Cc1c[nH]c2ccccc12)NC(=O)[C@H](CC(C)C)NC(=O)CNC(=O)[C@H](CCCNC(N)=[NH2+])NC(=O)[C@H](CCC(=O)[O-])NC(=O)[C@@H](NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@@H](NC(=O)[C@@H]1CCCN1C(=O)[C@H](CC(N)=O)NC(=O)[C@@H](NC(=O)[C@H](C)NC(=O)[C@H](C)NC(=O)[C@@H]1CCCN1C(=O)[C@@H](NC(=O)[C@@H](NC(=O)[C@H](CCC(=O)[O-])NC(=O)[C@H](CC(C)C)NC(=O)[C@@H](NC(=O)CNC(=O)[C@H](CC(N)=O)NC(=O)[C@H](CCC(=O)[O-])NC(=O)[C@H](C)NC(=O)[C@H](Cc1ccc(O)cc1)NC(=O)[C@@H](NC(=O)[C@H](CCC(=O)[O-])NC(=O)[C@H](CO)NC(=O)[C@H](C)NC(=O)[C@@H]1CCCN1C(=O)[C@@H](NC(=O)[C@H](CO)NC(=O)[C@@H](NC(=O)CNC(=O)[C@H](CO)NC(=O)[C@H](CC(N)=O)NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](Cc1cnc[nH]1)NC(=O)CNC(=O)[C@@H](NC(=O)[C@@H](NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](CCCC[NH3+])NC(=O)[C@@H](NC(=O)CNC(=O)[C@H](CC(=O)[O-])NC(=O)[C@H](C)NC(=O)[C@@H](NC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](CCCC[NH3+])NC(=O)[C@H](CC(N)=O)NC(=O)[C@H](C)NC(=O)[C@@H](NC(=O)[C@H](CO)NC(=O)[C@@H](NC(=O)CNC(=O)[C@H](CCC(N)=O)NC(=O)[C@H](CCCC[NH3+])NC(=O)[C@@H]1CCCN1C(=O)[C@H](CC(=O)[O-])NC(=O)[C@H](CO)NC(=O)[C@@H](NC(=O)[C@H](CC(=O)[O-])NC(=O)[C@H](CC(=O)[O-])NC(=O)CNC(=O)[C@H](CC(C)C)NC(=O)[C@@H](NC(=O)[C@@H](NC(=O)[C@H](CCCC[NH3+])NC(=O)[C@@H](NC(=O)[C@H](CCC(N)=O)NC(=O)[C@H](CCC(=O)[O-])NC(=O)CNC(=O)[C@H](CC(N)=O)NC(=O)[C@@H](NC(=O)CNC(=O)CNC(=O)[C@H](C)NC(=O)[C@H](C)NC(=O)[C@H](CC(N)=O)NC(=O)[C@@H](NC(=O)[C@H](CC(=O)[O-])NC(=O)[C@H](CCCC[NH3+])NC(=O)[C@H](C)NC(=O)[C@H](C)NC(=O)[C@H](CCC(N)=O)NC(=O)[C@H](CCC(=O)[O-])NC(=O)[C@H](C)NC(=O)CNC(=O)[C@H](CCCC[NH3+])NC(=O)[C@H](CCC(N)=O)NC(=O)[C@@H](NC(=O)[C@H](CCC(N)=O)NC(=O)[C@H](C)NC(=O)CNC(=O)[C@H](Cc1ccccc1)NC(=O)[C@H](C)NC(=O)[C@H](C)NC(=O)[C@H](CC(N)=O)NC(=O)[C@@H]1CCCN1C(=O)CNC(=O)[C@@H](NC(=O)[C@H](CC(C)C)NC(=O)[C@@H]1CCCN1C(=O)[C@H](C)NC(=O)CNC(=O)[C@@H](NC(
=O)[C@H](C)NC(=O)[C@@H](NC(=O)[C@@H](NC(=O)[C@@H](NC(=O)[C@H](CC(=O)[O-])NC(=O)[C@@H]([NH3+])CCSC)C(C)C)C(C)C)[C@@H](C)CC)C(C)C)[C@@H](C)O)[C@@H](C)CC)[C@@H](C)CC)[C@@H](C)CC)[C@@H](C)CC)[C@@H](C)CC)C(C)C)C(C)C)[C@@H](C)CC)C(C)C)C(C)C)C(C)C)C(C)C)C(C)C)C(C)C)[C@@H](C)CC)C(C)C)[C@@H](C)CC)[C@@H](C)CC)[C@@H](C)O)[C@@H](C)O)C(C)C)[C@@H](C)O)[C@@H](C)O)[C@@H](C)O)C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](C)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@H](C(=O)N[C@@H](C)C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@@H](Cc1cnc[nH]1)C(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@H](C(=O)N1CCC[C@H]1C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)NCC(=O)N[C@@H](CCC(N)=O)C(=O)NCC(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@H](C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(N)=O)C(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@H](C(=O)N[C@@H](CCSC)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CCC(=O)[O-])C(=O)NCC(=O)N[C@H](C(=O)N[C@@H](CC(N)=O)C(=O)N[C@H](C(=O)N[C@@H](CS[C@H]1CC(=O)N(c2ccc3c(c2)C(=O)OC32c3ccc(O)cc3Oc3cc(O)ccc32)C1=O)C(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CO)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(C)C)C(=O)N[C@H](C(=O)N[C@@H](CO)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CCSC)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@H](C(=O)N[C@@H](CO)C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](Cc1c[nH]c2ccccc12)C(=O)NCC(=O)NCC(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](Cc1c[nH]cn1)C(=O)N[C@H](C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@@H](CC(C)C)C(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@@H](CCCNC(N)=[NH2+])C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(=O)[O-])C(
=O)N[C@@H](CCC(N)=O)C(=O)NCC(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](C)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CC(C)C)C(=O)N[C@H](C(=O)N[C@@H](CO)C(=O)NCC(=O)N[C@@H](CC(=O)[O-])C(=O)NCC(=O)N[C@H](C(=O)N[C@H](C(=O)N[C@@H](CO)C(=O)N[C@@H](CC(N)=O)C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](C)C(=O)N[C@@H](CO)C(=O)N[C@H](C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](C)C(=O)N[C@H](C(=O)N[C@@H](CCC(=O)[O-])C(=O)NCC(=O)N[C@H](C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CC(N)=O)C(=O)N[C@H](C(=O)N[C@@H](Cc1ccccc1)C(=O)NCC(=O)N1CCC[C@H]1C(=O)N[C@@H](CC(=O)[O-])C(=O)N1CCC[C@H]1C(=O)N[C@H](C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCNC(N)=[NH2+])C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](CC(N)=O)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](CC(C)C)C(=O)N[C@H](C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CC(N)=O)C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CO)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](CCSC)C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](C)C(=O)N[C@H](C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](C)C(=O)N[C@@H](C)C(=O)NCC(=O)N[C@@H](CO)C(=O)N[C@H](C(=O)N[C@@H](CCC(=O)[O-])C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@H](C(=O)N[C@@H](C)C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](C)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CCCC[NH3+])C(=O)NCC(=O)N[C@@H](CO)C(=O)N[C@@H](Cc1ccccc1)C(=O)N1CCC[C@H]1C(=O)N[C@H](C(=O)N[C@@H](C)C(=O)N[C@@H](CC(C)C)C(=O)NCC(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@H](C(=O)N[C@@H](CO)C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@@H](CC(=O)[O-])C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](CCCC[NH3+])C(=O)NCC(=O)N[C@@H](CC(=O)[O-])C(=O)N1CCC[C@H]1C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CC(C)C)C(=O)N1CCC[C@H]1C
(=O)NCC(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](C(=O)N[C@@H](CCSC)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CCC(=O)[O-])C(=O)N[C@@H](Cc1c[nH]c2ccccc12)C(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](CCCC[NH3+])C(=O)NCC(=O)N1CCC[C@H]1C(=O)N[C@@H](CC(=O)[O-])C(=O)NCC(=O)N[C@@H](CCCC[NH3+])C(=O)N[C@@H](Cc1ccccc1)C(=O)N[C@H](C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@H](C(=O)N[C@@H](CCC(N)=O)C(=O)N[C@@H](CCC(N)=O)C(=O)NC)[C@@H](C)CC)[C@@H](C)O)C(C)C)[C@@H](C)CC)[C@@H](C)O)C(C)C)C(C)C)[C@@H](C)CC)[C@@H](C)O)C(C)C)[C@@H](C)O)[C@@H](C)O)[C@@H](C)O)C(C)C)[C@@H](C)CC)C(C)C)[C@@H](C)CC)C(C)C)[C@@H](C)CC)[C@@H](C)CC)[C@@H](C)O)[C@@H](C)CC)[C@@H](C)CC)C(C)C)[C@@H](C)CC)C(C)C)C(C)C)C(C)C)[C@@H](C)O)C(C)C)[C@@H](C)O)[C@@H](C)O)[C@@H](C)CC)[C@@H](C)CC)C(C)C"
)
model.compute_property(protein, as_numpy=True)

def test_save(self, am1bcc_model, openff_methane_uncharged, tmpdir, expected_methane_charges):
with tmpdir.as_cwd():
am1bcc_model.save("model.pt")
Expand Down
63 changes: 63 additions & 0 deletions openff/nagl/tests/utils/test_openff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
smiles_to_inchi_key,
calculate_circular_fingerprint_similarity,
capture_toolkit_warnings,
molecule_from_networkx,
_molecule_from_dict,
_molecule_to_dict,
)
from openff.nagl.utils._utils import transform_coordinates

Expand Down Expand Up @@ -228,3 +231,63 @@ def test_openff_toolkit_registry(openff_methane_uncharged):
rdkit_registry = ToolkitRegistry([NAGLRDKitToolkitWrapper()])
with toolkit_registry_manager(rdkit_registry):
normalize_molecule(openff_methane_uncharged)


def test_molecule_from_networkx(openff_methane_uncharged):
    """
    Rebuild methane from its networkx graph and check the per-atom and
    per-bond attributes, then confirm the result is isomorphic with the input.
    """
    molecule = molecule_from_networkx(openff_methane_uncharged.to_networkx())
    assert len(molecule.atoms) == 5

    # First atom has atomic number 6, the remaining four have atomic number 1.
    assert [atom.atomic_number for atom in molecule.atoms] == [6, 1, 1, 1, 1]
    assert [atom.is_aromatic for atom in molecule.atoms] == [False] * 5
    assert [atom.formal_charge for atom in molecule.atoms] == [0] * 5
    assert [bond.bond_order for bond in molecule.bonds] == [1] * 4

    assert molecule.is_isomorphic_with(openff_methane_uncharged)


def test_molecule_to_dict(openff_methane_uncharged):
    """
    Check the dictionary graph produced for methane: five atom records
    (atomic number 6 first, then four with atomic number 1) and four
    order-1 bonds joining atom 0 to each of the others.
    """
    graph = _molecule_to_dict(openff_methane_uncharged)
    atoms, bonds = graph["atoms"], graph["bonds"]
    assert len(atoms) == 5
    assert len(bonds) == 4

    def expected_atom(atomic_number):
        # Every atom record is identical apart from its atomic number.
        return {
            "atomic_number": atomic_number,
            "is_aromatic": False,
            "formal_charge": 0,
            "stereochemistry": None,
        }

    assert atoms[0] == expected_atom(6)
    for hydrogen_index in range(1, 5):
        assert atoms[hydrogen_index] == expected_atom(1)

    expected_bond = {
        "bond_order": 1,
        "is_aromatic": False,
        "stereochemistry": None,
    }

    # Bond keys are (low, high) index tuples, so atom 0 always comes first.
    for hydrogen_index in range(1, 5):
        assert bonds[(0, hydrogen_index)] == expected_bond


def test_molecule_from_dict(openff_methane_uncharged):
    """
    Round-trip methane through the dictionary graph form and confirm the
    rebuilt molecule is isomorphic with the original.
    """
    round_tripped = _molecule_from_dict(
        _molecule_to_dict(openff_methane_uncharged)
    )
    assert round_tripped.is_isomorphic_with(openff_methane_uncharged)
98 changes: 96 additions & 2 deletions openff/nagl/toolkits/openff.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import copy
import functools
from typing import TYPE_CHECKING, Tuple, List, Union, Dict
from typing import TYPE_CHECKING, Tuple, List, Union, Dict, NamedTuple, Any, Optional

import numpy as np

Expand Down Expand Up @@ -380,9 +380,12 @@ def normalize_molecule(
"[N,P,As,Sb;-1:1]=[C+;v3:2]>>[*+0:1]#[C+0:2]", # Charge recombination
)

molecule_ = type(molecule)(molecule)
molecule_._conformers = None

normalized = call_toolkit_function(
"_run_normalization_reactions",
molecule=molecule,
molecule=molecule_,
normalization_reactions=normalizations,
max_iter=max_iter,
toolkit_registry=toolkit_registry,
Expand Down Expand Up @@ -711,3 +714,94 @@ def molecule_from_networkx(graph):
stereochemistry=info.get("stereochemistry", None),
)
return molecule


def _molecule_to_dict(molecule: "Molecule") -> dict[str, dict]:
"""
Convert an OpenFF molecule to a graph representation.

Parameters
----------
molecule
The molecule to convert.


Returns
-------
graph: dict[str, dict]
This is a dictionary with the keys "atoms" and "bonds".
The "atoms" key maps to a dictionary of atom indices to atom information.
Each atom information dictionary contains the following keys:
atomic_number, formal_charge, is_aromatic, stereochemistry.
The "bonds" key maps to a dictionary of bond indices as a tuple of integers.
The bond indices are sorted so the lowest value is first.
Each bond indices tuple is mapped to bond information.
Each bond information dictionary contains the following keys:
bond_order, is_aromatic, stereochemistry.
"""
atoms = {}
for i, atom in enumerate(molecule.atoms):
atoms[i] = {
"atomic_number": atom.atomic_number,
"formal_charge": atom.formal_charge,
"is_aromatic": atom.is_aromatic,
"stereochemistry": atom.stereochemistry,
}

bonds = {}
for bond in molecule.bonds:
indices = tuple(sorted((bond.atom1_index, bond.atom2_index)))
bonds[indices] = {
"bond_order": bond.bond_order,
"is_aromatic": bond.is_aromatic,
"stereochemistry": bond.stereochemistry,
}

return {"atoms": atoms, "bonds": bonds}



def _molecule_from_dict(graph: dict[str, dict]) -> "Molecule":
    """
    Build an OpenFF molecule from a plain dictionary graph.

    Parameters
    ----------
    graph
        A dictionary with the keys "atoms" and "bonds".
        "atoms" maps atom indices to atom information dictionaries.
        Each must contain atomic_number, formal_charge and is_aromatic;
        stereochemistry is optional and treated as None when absent.
        "bonds" maps two-integer tuples of atom indices to bond
        information dictionaries. Each must contain bond_order and
        is_aromatic; stereochemistry is optional and treated as None
        when absent.

    Returns
    -------
    molecule
        The OpenFF molecule representation.
    """
    from openff.toolkit.topology import Molecule

    molecule = Molecule()

    # Atoms are added in ascending key order, whatever order the dict holds.
    atom_table = graph["atoms"]
    for atom_key in sorted(atom_table):
        atom_info = atom_table[atom_key]
        molecule.add_atom(
            atomic_number=atom_info["atomic_number"],
            formal_charge=atom_info["formal_charge"],
            is_aromatic=atom_info["is_aromatic"],
            stereochemistry=atom_info.get("stereochemistry"),
        )

    # NOTE(review): bond endpoints are passed through unchanged, so they
    # presumably only line up with the atoms added above when the atom keys
    # are the consecutive integers 0..n-1 (as `_molecule_to_dict` produces).
    # Confirm for graphs coming from other producers.
    for endpoints, bond_info in graph["bonds"].items():
        first, second = endpoints
        molecule.add_bond(
            first,
            second,
            bond_order=bond_info["bond_order"],
            is_aromatic=bond_info["is_aromatic"],
            stereochemistry=bond_info.get("stereochemistry"),
        )
    return molecule
Loading
Loading