Skip to content

Commit

Permalink
Universal train import using the dynamic_import function
Browse files Browse the repository at this point in the history
The way train() classes are imported in potentials module has been
changed to be more universal.
  • Loading branch information
muammar committed Jul 31, 2019
1 parent 0aeccb5 commit 0a2a7e0
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions ml4chem/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,14 @@ def train(
self.model.to(device_)

# This is something specific of pytorch.
if self.model.name() == 'RetentionTimes':
from ml4chem.models.rt import train
else:
from ml4chem.models.neuralnetwork import train
module_names = {
"PytorchPotentials": "neuralnetwork",
"PytorchIonicPotentials": "ionic",
"RetentionTimes": "rt",
}

module = module_names[self.model.name()]
train = dynamic_import("train", "ml4chem.models", alt_name=module)

train(
feature_space,
Expand Down

0 comments on commit 0a2a7e0

Please sign in to comment.