Skip to content

Commit

Permalink
Merge pull request #229 from hjkgrp/permutational_kernel
Browse files Browse the repository at this point in the history
Permutation invariant ML models
  • Loading branch information
ralf-meyer committed May 6, 2024
2 parents 2959f90 + a483469 commit bf9b0e5
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 68 deletions.
1 change: 0 additions & 1 deletion .github/workflows/python-linter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ jobs:
pip install mypy types-setuptools types-PyYAML types-requests types-tensorflow types-beautifulsoup4 pandas-stubs PyQt5-stubs
- name: Typecheck with mypy
run: |
# Exclude parts of Informatics for now
mypy --ignore-missing-imports molSimplify
- name: Report Status
Expand Down
15 changes: 15 additions & 0 deletions molSimplify/Classes/mol2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Union
from packaging import version
from molSimplify.Classes.globalvars import globalvars
from molSimplify.Classes.mol3D import mol3D as Mol3D

try:
from openbabel import openbabel # version 3 style import
Expand Down Expand Up @@ -122,6 +123,20 @@ def from_mol_file(cls, filename):

return mol

@classmethod
def from_mol3d(cls, mol3d: Mol3D):
    """Build a new instance from the atoms and graph of a Mol3D object.

    One node is added per entry of ``mol3d.atoms`` (indexed by position,
    carrying the atom's ``sym`` as the ``symbol`` attribute), and one edge
    per nonzero entry of ``mol3d.graph``.

    Parameters
    ----------
    mol3d : Mol3D
        Source molecule; must have a non-empty ``graph`` attribute.
        NOTE(review): ``graph`` is presumably a 2D adjacency matrix with a
        ``nonzero()`` method (e.g. a NumPy array) -- confirm against Mol3D.

    Returns
    -------
    An instance of ``cls`` holding the nodes and edges described above.

    Raises
    ------
    ValueError
        If ``mol3d.graph`` has length zero.
    """
    if len(mol3d.graph) == 0:
        raise ValueError("Mol3D object does not have molecular graph attached.")

    mol = cls()

    # Nodes are keyed by the atom's position in the Mol3D atom list.
    for index, atom in enumerate(mol3d.atoms):
        mol.add_node(index, symbol=atom.sym)

    # Cast the index pairs to plain ints before handing them to the graph.
    index_pairs = zip(*mol3d.graph.nonzero())
    mol.add_edges_from((int(pair[0]), int(pair[1])) for pair in index_pairs)
    return mol

def graph_hash(self) -> str:
"""Calculates the node attributed graph hash of the molecule.
Expand Down
58 changes: 50 additions & 8 deletions molSimplify/Informatics/graph_racs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def atom_centered_AC(
mol, source=starting_node, cutoff=depth
)
p_i = property_fun(mol, starting_node)
output = np.zeros((depth + 1, len(p_i)))
output = np.zeros((len(p_i), depth + 1))
for node, d_ij in lengths.items():
p_j = property_fun(mol, node)
output[d_ij] += operation(p_i, p_j)
output[:, d_ij] += operation(p_i, p_j)
return output


Expand Down Expand Up @@ -103,14 +103,14 @@ def multi_centered_AC(
full-scope autocorrelation vector
"""
n_props = len(property_fun(mol, list(mol.nodes.keys())[0]))
output = np.zeros((depth + 1, n_props))
output = np.zeros((n_props, depth + 1))
# Generate all pairwise path lengths
lengths = nx.all_pairs_shortest_path_length(mol, cutoff=depth)
for node_i, lengths_i in lengths:
p_i = property_fun(mol, node_i)
for node_j, d_ij in lengths_i.items():
p_j = property_fun(mol, node_j)
output[d_ij] += operation(p_i, p_j)
output[:, d_ij] += operation(p_i, p_j)
return output


Expand All @@ -123,7 +123,7 @@ def octahedral_racs(
# Following J. Phys. Chem. A 2017, 121, 8939 there are 6 start/scope
# combinations for product ACs and 3 for difference ACs.
n_props = len(property_fun(mol, list(mol.nodes.keys())[0]))
output = np.zeros((6 + 3, depth + 1, n_props))
output = np.zeros((6 + 3, n_props, depth + 1))

# start = f, scope = all, product
output[0] = multi_centered_AC(mol, depth=depth, property_fun=property_fun)
Expand Down Expand Up @@ -198,13 +198,15 @@ def octahedral_racs(
axis=0,
)

# start = f, scope = ax, product
output[4] = np.mean(
[
multi_centered_AC(g, depth=depth, property_fun=property_fun)
for (_, g) in axial_ligands
],
axis=0,
)
# start = f, scope = eq, product
output[5] = np.mean(
[
multi_centered_AC(g, depth=depth, property_fun=property_fun)
Expand Down Expand Up @@ -243,6 +245,29 @@ def octahedral_racs(
return output


def octahedral_racs_names(depth=3, properties=None) -> List[str]:
    """Return the labels of the octahedral RAC descriptors.

    Every label has the form ``"{start}-{prop}-{d}-{scope}"``. Labels are
    generated with the start/scope combination varying slowest, followed by
    the property and finally the depth ``d``, which runs from 0 through
    ``depth`` inclusive.

    Parameters
    ----------
    depth : int
        Largest depth to include, by default 3.
    properties : iterable of str, optional
        Property names; defaults to ``["Z", "chi", "T", "I", "S"]``.

    Returns
    -------
    List[str]
        ``9 * len(properties) * (depth + 1)`` labels.
    """
    prop_names = ["Z", "chi", "T", "I", "S"] if properties is None else properties

    # (start, scope) pairs: six product combinations followed by three
    # difference ("D_") combinations.
    combinations = (
        ("f", "all"),
        ("mc", "all"),
        ("lc", "ax"),
        ("lc", "eq"),
        ("f", "ax"),
        ("f", "eq"),
        ("D_mc", "all"),
        ("D_lc", "ax"),
        ("D_lc", "eq"),
    )

    names = []
    for start, scope in combinations:
        for prop in prop_names:
            for d in range(depth + 1):
                names.append(f"{start}-{prop}-{d}-{scope}")
    return names


def ligand_racs(
mol: Mol2D,
depth: int = 3,
Expand All @@ -261,8 +286,8 @@ def ligand_racs(

n_ligands = len(connecting_atoms)
n_props = len(property_fun(mol, list(mol.nodes.keys())[0]))
n_scopes = 4 if full_scope else 2
output = np.zeros((n_ligands, n_scopes, depth + 1, n_props))
n_scopes = 3 if full_scope else 2
output = np.zeros((n_ligands, n_scopes, n_props, depth + 1))

# Then cut the graph by removing all connections to the metal atom
subgraphs.remove_edges_from([(metal, c) for c in connecting_atoms])
Expand All @@ -283,6 +308,23 @@ def ligand_racs(
# Add full scope RACs if requested
if full_scope:
output[i, 2] = multi_centered_AC(g, depth=depth, operation=operator.mul, property_fun=property_fun)
output[i, 3] = multi_centered_AC(g, depth=depth, operation=operator.sub, property_fun=property_fun)

return output


def ligand_racs_names(depth: int = 3, properties=None, full_scope: bool = True) -> List[str]:
    """Return the labels of the per-ligand RAC descriptors.

    Every label has the form ``"{start}-{prop}-{d}"``. Labels are generated
    with the start varying slowest, followed by the property and finally the
    depth ``d``, which runs from 0 through ``depth`` inclusive.

    Parameters
    ----------
    depth : int
        Largest depth to include, by default 3.
    properties : iterable of str, optional
        Property names; defaults to ``["Z", "chi", "T", "I", "S"]``.
    full_scope : bool
        When True the ``"f"`` start is appended after ``"lc"`` and
        ``"D_lc"``.

    Returns
    -------
    List[str]
        ``n_starts * len(properties) * (depth + 1)`` labels, where
        ``n_starts`` is 3 with ``full_scope`` and 2 without.
    """
    prop_names = ["Z", "chi", "T", "I", "S"] if properties is None else properties
    start_labels = ["lc", "D_lc", "f"] if full_scope else ["lc", "D_lc"]

    names = []
    for start in start_labels:
        for prop in prop_names:
            for d in range(depth + 1):
                names.append(f"{start}-{prop}-{d}")
    return names
71 changes: 71 additions & 0 deletions molSimplify/ml/kernels.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from sklearn.gaussian_process.kernels import Kernel


Expand Down Expand Up @@ -34,3 +35,73 @@ def __repr__(self):
def is_stationary(self):
    """Report whether the wrapped kernel is stationary."""
    wrapped = self.kernel
    return wrapped.is_stationary()


class PermutationalKernel(Kernel):
    """Kernel that averages a base kernel over permutations of the features.

    Each sample (one row of the input) is viewed as an array of shape
    ``shape``. Every entry of ``permutations`` is an index sequence applied
    along the first axis of that array, producing one permuted copy of the
    sample. The kernel value for a pair of samples is the mean of the base
    kernel over all pairs of permuted copies.

    Parameters
    ----------
    shape : tuple of int
        Shape each flat feature vector is reshaped to before permuting.
    permutations : sequence of index sequences
        Index sequences used to reorder the first axis of ``shape``.
    kernel : Kernel
        Base kernel evaluated on the permuted feature vectors. Its
        hyperparameters (``theta``, ``bounds``) are exposed unchanged.
    """

    def __init__(self, shape, permutations, kernel):
        self.shape = shape
        self.permutations = permutations
        self.kernel = kernel

    @property
    def theta(self):
        """Hyperparameter vector of the base kernel."""
        return self.kernel.theta

    @theta.setter
    def theta(self, theta):
        self.kernel.theta = theta

    @property
    def bounds(self):
        """Hyperparameter bounds of the base kernel."""
        return self.kernel.bounds

    def _permuted_copies(self, X):
        """Stack all permuted copies of every sample along a new axis 1.

        For ``X`` of shape ``(n, n_features)`` the result has shape
        ``(n, n_perms, n_features)``.
        """
        X_reshaped = X.reshape(-1, *self.shape)
        return np.stack(
            [X_reshaped[:, perm].reshape(X.shape) for perm in self.permutations],
            axis=1,
        )

    def __call__(self, X, Y=None, eval_gradient=False):
        """Evaluate the permutation-averaged kernel.

        Parameters
        ----------
        X : ndarray of shape (n, n_features)
        Y : ndarray of shape (m, n_features), optional
            If None, ``k(X, X)`` is evaluated.
        eval_gradient : bool
            Also return the gradient with respect to the base kernel's
            hyperparameters. Only supported when ``Y`` is None.

        Returns
        -------
        K : ndarray of shape (n, m)
        K_gradient : ndarray of shape (n, n, n_dims)
            Only returned when ``eval_gradient`` is True.

        Raises
        ------
        ValueError
            If ``eval_gradient`` is True and ``Y`` is not None.
        """
        n = X.shape[0]
        n_perms = len(self.permutations)

        # The main idea of this implementation is to vectorize the double
        # loop over the permutations: build one array holding every permuted
        # copy of every sample, evaluate the base kernel once on it, and
        # average the resulting (n_perms x n_perms) blocks.
        X_permuted = self._permuted_copies(X).reshape(n * n_perms, -1)

        if eval_gradient:
            if Y is not None:
                raise ValueError("Gradient can only be evaluated when Y is None.")

            K, K_grad = self.kernel(X_permuted, eval_gradient=True)
            # Use the explicit trailing dimension: "-1" cannot be inferred by
            # NumPy when the gradient is empty (no free hyperparameters).
            n_dims = K_grad.shape[-1]
            return (
                K.reshape(n, n_perms, n, n_perms).sum(axis=(1, 3)) / n_perms**2,
                K_grad.reshape(n, n_perms, n, n_perms, n_dims).sum(axis=(1, 3))
                / n_perms**2,
            )

        if Y is None:
            m = n
            K = self.kernel(X_permuted)
        else:
            m = Y.shape[0]
            Y_permuted = self._permuted_copies(Y).reshape(m * n_perms, -1)
            K = self.kernel(X_permuted, Y_permuted)

        # Reshape and average over the permutations
        return K.reshape(n, n_perms, m, n_perms).sum(axis=(1, 3)) / n_perms**2

    def diag(self, X):
        """Return the diagonal of ``self(X)`` without building the full matrix.

        Only the (n_perms x n_perms) block of each sample with itself
        contributes to the diagonal, so the base kernel is evaluated once per
        sample on that sample's permuted copies.

        Parameters
        ----------
        X : ndarray of shape (n, n_features)

        Returns
        -------
        ndarray of shape (n,)
        """
        n_perms = len(self.permutations)
        return np.array(
            [
                self.kernel(copies).sum() / n_perms**2
                for copies in self._permuted_copies(X)
            ]
        )

    def __repr__(self):
        return "PermutationalKernel(shape={0}, permutations={1}, kernel={2})".format(
            self.shape, self.permutations, self.kernel)

    def is_stationary(self):
        """Returns whether the kernel is stationary."""
        return self.kernel.is_stationary()
10 changes: 10 additions & 0 deletions molSimplify/ml/layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import tensorflow as tf
from typing import List, Tuple

# Workaround for moved functionality
try:
from tensorflow.keras.saving import register_keras_serializable
except ImportError:
from tensorflow.keras.utils import register_keras_serializable


register_keras_serializable(package="molSimplify")
class PermutationLayer(tf.keras.layers.Layer):

def __init__(self, permutations: List[Tuple[int]]):
Expand All @@ -27,3 +34,6 @@ def call(self, inputs):
)
)
return tf.stack(outputs, axis=1)

def get_config(self):
    """Return the constructor arguments of this layer as a dictionary."""
    config = {"permutations": self.permutations}
    return config
83 changes: 28 additions & 55 deletions tests/informatics/test_graph_racs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
atom_centered_AC,
multi_centered_AC,
octahedral_racs,
octahedral_racs_names,
ligand_racs,
ligand_racs_names,
)


Expand Down Expand Up @@ -42,7 +44,7 @@ def test_atom_centered_AC(furan):
[112.0, 32.68, 16.0, 4.0, 1.6644],
[16.0, 15.136, 4.0, 2.0, 0.5402],
]
np.testing.assert_allclose(descriptors, ref)
np.testing.assert_allclose(descriptors.T, ref)


def test_atom_centered_AC_diff(furan):
Expand All @@ -54,7 +56,7 @@ def test_atom_centered_AC_diff(furan):
[18.0, 4.26, 0.0, 0.0, 0.64],
[14.0, 2.48, 2.0, 0.0, 0.72],
]
np.testing.assert_allclose(descriptors, ref)
np.testing.assert_allclose(descriptors.T, ref)


def test_multi_centered_AC(furan):
Expand All @@ -66,7 +68,7 @@ def test_multi_centered_AC(furan):
[512.0, 171.695, 122.0, 26.0, 10.3050],
[110.0, 126.632, 50.0, 22.0, 5.3206],
]
np.testing.assert_allclose(descriptors, ref)
np.testing.assert_allclose(descriptors.T, ref)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -95,41 +97,24 @@ def test_octahedral_racs(
ref_dict = json.load(fin)

depth = 3
properties = ["Z", "chi", "T", "I", "S"]
descriptors = octahedral_racs(
mol,
depth=depth,
equatorial_connecting_atoms=eq_atoms,
)

# Dictionary encoding the order of the descriptors in the numpy array
start_scopes = {
0: ("f", "all"),
1: ("mc", "all"),
2: ("lc", "ax"),
3: ("lc", "eq"),
4: ("f", "ax"),
5: ("f", "eq"),
6: ("D_mc", "all"),
7: ("D_lc", "ax"),
8: ("D_lc", "eq"),
}

for s, (start, scope) in start_scopes.items():
for d in range(depth + 1):
for p, prop in enumerate(properties):
print(
start,
scope,
d,
prop,
descriptors[s, d, p],
ref_dict[f"{start}-{prop}-{d}-{scope}"],
)
assert (
abs(descriptors[s, d, p] - ref_dict[f"{start}-{prop}-{d}-{scope}"])
< atol
)
descriptor_names = octahedral_racs_names(depth=depth)

for name, rac in zip(descriptor_names, descriptors.flatten()):
print(
name,
rac,
ref_dict[name],
)
assert (
abs(rac - ref_dict[name])
< atol
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -165,29 +150,17 @@ def test_ligand_racs(
full_scope=True
)

assert descriptors.shape == (n_ligs, 4, depth + 1, 5)
descriptor_names = ligand_racs_names(depth=depth)

starts = {
0: "lc_P",
1: "lc_D",
2: "f_P",
3: "f_D",
}
assert descriptors.shape == (n_ligs, 3, 5, depth + 1)

properties = ["Z", "chi", "T", "I", "S"]
for lig in range(n_ligs):
for s, start in starts.items():
for d in range(depth + 1):
for p, prop in enumerate(properties):
print(
f"lig_{lig}",
start,
d,
prop,
descriptors[lig, s, d, p],
ref_dict[f"lig_{lig}-{start}-{prop}-{d}"],
)
assert (
abs(descriptors[lig, s, d, p] - ref_dict[f"lig_{lig}-{start}-{prop}-{d}"])
< atol
)
for name, rac in zip(descriptor_names, descriptors[lig].flatten()):
print(
rac,
ref_dict[f"lig_{lig}-{name}"],
)
assert (
abs(rac - ref_dict[f"lig_{lig}-{name}"])
< atol
)

0 comments on commit bf9b0e5

Please sign in to comment.