Skip to content

Commit

Permalink
AtomicMSELoss supports uncertainty to penalize.
Browse files Browse the repository at this point in the history
  • Loading branch information
muammar committed Feb 13, 2020
1 parent 13a7aa7 commit 6781f7b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
23 changes: 18 additions & 5 deletions ml4chem/atomistic/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np


def AtomicMSELoss(outputs, targets, atoms_per_image):
def AtomicMSELoss(outputs, targets, atoms_per_image, uncertainty=None):
"""Default loss function
If user does not input loss function we provide mean-squared error loss
Expand All @@ -14,6 +14,11 @@ def AtomicMSELoss(outputs, targets, atoms_per_image):
Outputs of the model.
targets : tensor
Expected value of outputs.
atoms_per_image : tensor
A tensor with the number of atoms per image.
uncertainty : tensor, optional
A tensor of uncertainties that are used to penalize during the loss
function evaluation.
Returns
Expand All @@ -22,11 +27,19 @@ def AtomicMSELoss(outputs, targets, atoms_per_image):
The value of the loss function.
"""

criterion = torch.nn.MSELoss(reduction="sum")
outputs_atom = torch.div(outputs, atoms_per_image)
targets_atom = torch.div(targets, atoms_per_image)
if uncertainty == None:
criterion = torch.nn.MSELoss(reduction="sum")
outputs_atom = torch.div(outputs, atoms_per_image)
targets_atom = torch.div(targets, atoms_per_image)

loss = criterion(outputs_atom, targets_atom) * 0.5
loss = criterion(outputs_atom, targets_atom) * 0.5
else:
criterion = torch.nn.MSELoss(reduction="none")
outputs_atom = torch.div(outputs, atoms_per_image)
targets_atom = torch.div(targets, atoms_per_image)
loss = (
criterion(outputs_atom, targets_atom) / torch.pow(uncertainty, 2)
).sum() * 0.5

return loss

Expand Down
4 changes: 2 additions & 2 deletions ml4chem/atomistic/models/neuralnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ def trainer(self):
_rmse = []
epoch = 0

client = dask.distributed.get_client()

while not converged:
epoch += 1

Expand All @@ -384,8 +386,6 @@ def trainer(self):
self.optimizer.step(options)

# RMSE per image and per/atom
client = dask.distributed.get_client()

rmse = client.submit(compute_rmse, *(outputs_, self.targets))
atoms_per_image = torch.cat(self.atoms_per_image)
rmse_atom = client.submit(
Expand Down
1 change: 1 addition & 0 deletions ml4chem/atomistic/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def train(
feature_space, reference_features = self.features.calculate(
training_set, data=data_handler, purpose=purpose, svm=True
)

self.model.prepare_model(
feature_space, reference_features, data=data_handler
)
Expand Down

0 comments on commit 6781f7b

Please sign in to comment.