Skip to content

Commit

Permalink
Adding `purpose` argument to fix GaussianProcess and KernelRidge classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
muammar committed Aug 24, 2019
1 parent bf88896 commit 91ed153
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
22 changes: 13 additions & 9 deletions ml4chem/models/gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
else:
self.weights = weights

def get_potential_energy(self, fingerprints, reference_space):
def get_potential_energy(self, fingerprints, reference_space, purpose):
"""Get potential energy with Kernel Ridge
Parameters
Expand All @@ -134,41 +134,45 @@ def get_potential_energy(self, fingerprints, reference_space):
Dictionary with hash and features.
reference_space : array
Array with reference feature space.
purpose : str
Purpose of this function: 'training', 'inference'.
Returns
-------
energy, variance
Energy of a molecule and its respective variance.
"""
reference_space = reference_space[b"reference_space"]
computations = self.get_kernel_matrix(fingerprints, reference_space)
computations = self.get_kernel_matrix(fingerprints, reference_space, purpose)
kernel = np.array(dask.compute(*computations, scheduler=self.scheduler))
weights = np.array(self.weights["energy"])
dim = int(kernel.shape[0] / weights.shape[0])
kernel = kernel.reshape(dim, len(weights))
energy_per_atom = np.dot(kernel, weights)
energy = np.sum(energy_per_atom)
variance = self.get_variance(fingerprints, kernel, reference_space)
variance = self.get_variance(fingerprints, kernel, reference_space, purpose)
return energy, variance

def get_variance(self, fingerprints, ks, reference_space):
def get_variance(self, fingerprints, ks, reference_space, purpose):
"""Compute predictive variance
Parameters
----------
fingerprints : dict
Dictionary with data point to be predicted.
Dictionary with data point to be predicted.
ks : array
Variance between data point and reference space.
reference_space : list
Reference space used to compute kernel.
purpose : str
Purpose of this function: 'training', 'inference'.
Returns
-------
variance
Predictive variance.
"""
K = self.get_kernel_matrix(reference_space, reference_space)
K = self.get_kernel_matrix(reference_space, reference_space, purpose)
K = np.array(dask.compute(*K, scheduler=self.scheduler))
dim = int(np.sqrt(len(K)))
K = K.reshape(dim, dim)
Expand All @@ -184,7 +188,7 @@ def get_variance(self, fingerprints, ks, reference_space):
betas = np.linalg.solve(cholesky_U.T, ks.T)

variance = ks.dot(np.linalg.solve(cholesky_U, betas))
kxx = self.get_kernel_matrix(fingerprints, fingerprints)
kxx = self.get_kernel_matrix(fingerprints, fingerprints, purpose)
kxx = np.array(dask.compute(*kxx, scheduler=self.scheduler))
dim = int(np.sqrt(len(kxx)))
kxx = kxx.reshape(dim, dim)
Expand Down
6 changes: 4 additions & 2 deletions ml4chem/models/kernelridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def train(self, inputs, targets, data=None):
]
self.weights["energy"] = _weights

def get_potential_energy(self, fingerprints, reference_space):
def get_potential_energy(self, fingerprints, reference_space, purpose):
"""Get potential energy with Kernel Ridge
Parameters
Expand All @@ -379,6 +379,8 @@ def get_potential_energy(self, fingerprints, reference_space):
Dictionary with hash and features.
reference_space : array
Array with reference feature space.
purpose : str
Purpose of this function: 'training', 'inference'.
Returns
-------
Expand All @@ -387,7 +389,7 @@ def get_potential_energy(self, fingerprints, reference_space):
"""
reference_space = reference_space[b"reference_space"]
computations = self.get_kernel_matrix(
fingerprints, reference_space, purpose="inference"
fingerprints, reference_space, purpose=purpose
)
kernel = np.array(dask.compute(*computations, scheduler=self.scheduler))
weights = np.array(self.weights["energy"])
Expand Down
2 changes: 1 addition & 1 deletion ml4chem/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def calculate(self, atoms, properties, system_changes):
except:
raise ("This is not a database...")

energy = self.model.get_potential_energy(fingerprints, reference_space)
energy = self.model.get_potential_energy(fingerprints, reference_space, purpose=purpose)
else:
input_dimension = len(list(fingerprints.values())[0][0][-1])
model = copy.deepcopy(self.model)
Expand Down

0 comments on commit 91ed153

Please sign in to comment.