General improvements.
- fingerprints/gaussian.py: Fixed a problem with
  unique_element_symbols.
- models/merger.py:
    * Refactored convergence criteria conditional statement.
    * Moved docstrings so that they generate documentation with Sphinx
      correctly.
muammar committed Sep 25, 2019
1 parent 20b27cc commit 10b859f
Showing 2 changed files with 60 additions and 8 deletions.
5 changes: 5 additions & 0 deletions ml4chem/fingerprints/gaussian.py
@@ -210,6 +210,11 @@ def calculate_features(self, images=None, purpose="training", data=None, svm=Fal

            logger.info("Unique chemical elements: {}".format(unique_element_symbols))

        elif isinstance(data.unique_element_symbols, dict):
            unique_element_symbols = data.unique_element_symbols[purpose]

            logger.info("Unique chemical elements: {}".format(unique_element_symbols))

        # we make the features
        self.GP = self.custom.get("GP", None)
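
For context, the branch added above covers the case where data.unique_element_symbols is a dictionary keyed by purpose (for example "training") rather than an already-flattened list. Below is a minimal, self-contained sketch of that selection logic; the get_unique_symbols helper and the example element list are illustrative and not part of ml4chem:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_unique_symbols(unique_element_symbols, purpose="training"):
    """Return the unique chemical elements for the requested purpose."""
    if isinstance(unique_element_symbols, dict):
        # Purpose-keyed dictionary, e.g. {"training": ["Cu", "O"]}.
        unique_element_symbols = unique_element_symbols[purpose]

    logger.info("Unique chemical elements: {}".format(unique_element_symbols))
    return unique_element_symbols


# Example usage with a purpose-keyed dictionary.
get_unique_symbols({"training": ["Cu", "O"]}, purpose="training")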

63 changes: 55 additions & 8 deletions ml4chem/models/merger.py
@@ -83,8 +83,8 @@ def train(
        targets,
        data=None,
        optimizer=(None, None),
        regularization=None,
        epochs=100,
        regularization=None,
        convergence=None,
        lossfxn=None,
        device="cpu",
@@ -93,9 +93,56 @@
        independent_loss=True,
        loss_weights=None,
    ):
        """Train the models

        Parameters
        ----------
        inputs : dict
            Dictionary with hashed feature space.
        targets : list
            The expected values that the model has to learn, aka y.
        model : object
            The NeuralNetwork class.
        data : object
            DataSet object created from the handler.
        optimizer : tuple
            The optimizer is a tuple with the structure:

            >>> ('adam', {'lr': float, 'weight_decay': float})

        epochs : int
            Number of full training cycles.
        regularization : float
            This is the L2 regularization. It is not the same as weight decay.
        convergence : dict
            Instead of using epochs, users can set a convergence criterion.

            >>> convergence = {"rmse": [0.04, 0.02]}

        lossfxn : obj
            A loss function object.
        device : str
            Calculations can be run on the cpu or cuda (gpu).
        batch_size : int
            Number of data points per batch to use for training. Default is None.
        lr_scheduler : tuple
            Tuple with structure: scheduler's name and a dictionary with keyword
            arguments.

            >>> lr_scheduler = ('ReduceLROnPlateau',
                                {'mode': 'min', 'patience': 10})

        independent_loss : bool
            Whether or not the models' weights are optimized independently.
        loss_weights : list
            How much the loss of model(i) contributes to the total loss.
        """

        self.epochs = epochs

        if convergence is not None:
            if isinstance(convergence["rmse"], (float, int)):
                convergence["rmse"] = np.array([convergence["rmse"] for model in range(len(self.models))])
            elif isinstance(convergence["rmse"], list):
                if len(convergence["rmse"]) != len(self.models):
                    raise ValueError("Your convergence list is not the same length as the number of models")
                convergence["rmse"] = np.array(convergence["rmse"])

        logger.info(" ")
        logging.info("Model Merger")
        logging.info("============")
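
As an aside, the convergence handling a few lines above normalizes convergence["rmse"] so that there is one RMSE threshold per model, whether the user passed a single number or a list. A short standalone illustration of that behavior, assuming two merged models (the numbers are made up, and normalize_rmse_criterion is a hypothetical helper, not an ml4chem function):

import numpy as np

n_models = 2  # hypothetical number of merged models


def normalize_rmse_criterion(convergence, n_models):
    """Expand convergence["rmse"] into one threshold per model."""
    rmse = convergence["rmse"]
    if isinstance(rmse, (float, int)):
        convergence["rmse"] = np.array([rmse] * n_models)
    elif isinstance(rmse, list):
        if len(rmse) != n_models:
            raise ValueError("The convergence list length does not match the number of models")
        convergence["rmse"] = np.array(rmse)
    return convergence


print(normalize_rmse_criterion({"rmse": 0.03}, n_models))
# {'rmse': array([0.03, 0.03])}
print(normalize_rmse_criterion({"rmse": [0.04, 0.02]}, n_models))
# {'rmse': array([0.04, 0.02])}
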
@@ -253,10 +300,8 @@ def train(
            rmse = []
            for i, model in enumerate(self.models):
                rmse.append(compute_rmse(outputs[i], self.targets[i]))
            # print(outputs[1])
            # print(targets[1])
            rmse = np.array(rmse)

            # print(rmse)
            _rmse = np.average(rmse)

            if self.optimizer_name != "LBFGS":
@@ -271,9 +316,7 @@

            if convergence is None and epoch == self.epochs:
                converged = True
            elif convergence is not None and all(
                i <= convergence["rmse"] for i in rmse
            ):
            elif convergence is not None and (rmse <= convergence["rmse"]).all():
                converged = True
                new_state_dict = {}

@@ -285,8 +328,12 @@ def train(
                    print("Diff in {}".format(key))
                else:
                    print("No diff in {}".format(key))
            print(convergence)
            print(rmse)

        # print(rmse)
        print("Final")
        print(convergence)
        print(rmse)

    def closure(self, index, model, independent_loss, name=None):
        """Closure
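
The refactored convergence test shown in the @@ -271,9 +316,7 @@ hunk above relies on both the per-model RMSE values and the per-model thresholds being NumPy arrays, so the element-wise comparison (rmse <= convergence["rmse"]).all() replaces the earlier all(...) generator expression. A small sketch with made-up numbers for two models:

import numpy as np

convergence = {"rmse": np.array([0.04, 0.02])}  # one threshold per model

rmse = np.array([0.035, 0.019])  # hypothetical current RMSE per model
print((rmse <= convergence["rmse"]).all())  # True: every model is at or below its threshold

rmse = np.array([0.035, 0.025])
print((rmse <= convergence["rmse"]).all())  # False: the second model has not converged yet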
