Support test error during training.
- The `NeuralNetwork` and `RetentionTime` classes now support test set error
  prediction during training.
- Black clean-up.
- Started porting from str.format() to f-strings for more readable code.
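As a quick illustration of that port, the log message below is the one changed in ml4chem/data/preprocessing.py in this commit; the variable names are local placeholders rather than the class attributes used in the source:

preprocessing, kwargs = "MinMaxScaler", {"feature_range": (-1, 1)}
old_style = "{} with {} is not supported.".format(preprocessing, kwargs)
new_style = f"Preprocessor ({preprocessing}, {kwargs}) is not supported."
print(old_style)
print(new_style)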
muammar committed Feb 17, 2020
1 parent 641e4eb commit fb5885c
Showing 5 changed files with 115 additions and 26 deletions.
121 changes: 102 additions & 19 deletions ml4chem/atomistic/models/neuralnetwork.py
@@ -232,6 +232,7 @@ class train(DeepLearningTrainer):
This is the L2 regularization. It is not the same as weight decay.
convergence : dict
Instead of using epochs, users can set a convergence criterion.
Supported keys are "training" and "test".
lossfxn : obj
A loss function object.
device : str
@@ -255,6 +256,18 @@ class train(DeepLearningTrainer):
`label` refers to the name used to save the checkpoint, `checkpoint`
is an integer or -1 for saving all epochs, and the path is where the
checkpoint is stored. Default is None and no checkpoint is saved.
test : dict
A dictionary used to compute the error over a validation/test set
during training procedures.
>>> test = {"features": test_space, "targets": test_targets, "data": data_test}
The keys and values of the dictionary are:
- "data": a `Data` object.
- "targets": test set targets.
- "features": a feature space obtained using `features.calculate()`.
"""

def __init__(
@@ -273,6 +286,7 @@ def __init__(
lr_scheduler=None,
uncertainty=None,
checkpoint=None,
test=None,
):

self.initial_time = time.time()
@@ -342,27 +356,14 @@ def __init__(
if lr_scheduler is not None:
self.scheduler = get_lr_scheduler(self.optimizer, lr_scheduler)

logger.info(" ")
logger.info("Starting training...")
logger.info(" ")

logger.info(
"{:6s} {:19s} {:12s} {:8s} {:8s}".format(
"Epoch", "Time Stamp", "Loss", "RMSE/img", "RMSE/atom"
)
)
logger.info(
"{:6s} {:19s} {:12s} {:8s} {:8s}".format(
"------", "-------------------", "------------", "--------", "---------"
)
)
self.atoms_per_image = atoms_per_image
self.convergence = convergence
self.device = device
self.epochs = epochs
self.model = model
self.lr_scheduler = lr_scheduler
self.checkpoint = checkpoint
self.test = test

# Data scattering
client = dask.distributed.get_client()
@@ -385,6 +386,55 @@ def __init__(
def trainer(self):
"""Run the training class"""

logger.info(" ")
logger.info("Starting training...\n")
if self.uncertainty is not None:
logger.info("Loss function will penalize based on uncertainties.\n")

if self.test is None:
logger.info(
"{:6s} {:19s} {:12s} {:12s} {:8s}".format(
"Epoch", "Time Stamp", "Loss", "Error/img", "Error/atom"
)
)
logger.info(
"{:6s} {:19s} {:12s} {:8s} {:8s}".format(
"------",
"-------------------",
"------------",
"------------",
"------------",
)
)

else:
test_features = self.test.get("features", None)
test_targets = self.test.get("targets", None)
test_data = self.test.get("data", None)

logger.info(
"{:6s} {:19s} {:12s} {:12s} {:12s} {:12s} {:16s}".format(
"Epoch",
"Time Stamp",
"Loss",
"Error/img",
"Error/atom",
"Error/img (t)",
"Error/atom (t)",
)
)
logger.info(
"{:6s} {:19s} {:12s} {:8s} {:8s} {:8s} {:8s}".format(
"------",
"-------------------",
"------------",
"------------",
"------------",
"------------",
"------------",
)
)

converged = False
_loss = []
_rmse = []
@@ -413,27 +463,60 @@ def trainer(self):
self.optimizer.step(options)

# RMSE per image and per atom

rmse = client.submit(compute_rmse, *(outputs_, self.targets))
atoms_per_image = torch.cat(self.atoms_per_image)

rmse_atom = client.submit(
compute_rmse, *(outputs_, self.targets, atoms_per_image)
)
rmse = rmse.result()
rmse_atom = rmse_atom.result()

_loss.append(loss.item())
_rmse.append(rmse)

# In the case that lr_scheduler is not None
if self.lr_scheduler is not None:
self.scheduler.step(loss)
print("Epoch {} lr {}".format(epoch, get_lr(self.optimizer)))

ts = time.time()
ts = datetime.datetime.fromtimestamp(ts).strftime("%Y-%m-%d " "%H:%M:%S")

logger.info(
"{:6d} {} {:8e} {:8f} {:8f}".format(epoch, ts, loss, rmse, rmse_atom)
)
if self.test is None:
logger.info(
"{:6d} {} {:8e} {:4e} {:4e}".format(
epoch, ts, loss.detach(), rmse, rmse_atom
)
)
else:
test_model = self.model.eval()
test_predictions = test_model(test_features).detach()
rmse_test = client.submit(
compute_rmse, *(test_predictions, test_targets)
)

atoms_per_image_test = torch.tensor(
test_data.atoms_per_image, requires_grad=False
)
rmse_atom_test = client.submit(
compute_rmse,
*(test_predictions, test_targets, atoms_per_image_test)
)

rmse_test = rmse_test.result()
rmse_atom_test = rmse_atom_test.result()

logger.info(
"{:6d} {} {:8e} {:4e} {:4e} {:4e} {:4e}".format(
epoch,
ts,
loss.detach(),
rmse,
rmse_atom,
rmse_test,
rmse_atom_test,
)
)

if self.checkpoint is not None:
self.checkpoint_save(epoch, self.model, **self.checkpoint)
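A hedged sketch of what the per-image and per-atom error columns above reduce to in plain PyTorch; ml4chem's `compute_rmse` is dispatched through dask and may differ in detail, and the tensors below are toy values for illustration only.

import torch

def rmse_per_image(outputs, targets):
    # Root-mean-square error over whole images.
    return torch.sqrt(torch.mean((outputs - targets) ** 2)).item()

def rmse_per_atom(outputs, targets, atoms_per_image):
    # The same error, normalized by the number of atoms in each image.
    return torch.sqrt(
        torch.mean(((outputs - targets) / atoms_per_image) ** 2)
    ).item()

outputs = torch.tensor([10.2, 20.1])
targets = torch.tensor([10.0, 20.0])
atoms_per_image = torch.tensor([2.0, 4.0])
print(rmse_per_image(outputs, targets), rmse_per_atom(outputs, targets, atoms_per_image))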
5 changes: 3 additions & 2 deletions ml4chem/atomistic/potentials.py
@@ -259,11 +259,11 @@ def train(
"""

purpose = "training"
data_handler = Data(training_set, purpose=purpose)
# Raw input and targets aka X, y
data_handler = Data(training_set, purpose=purpose)
training_set, targets = data_handler.get_data(purpose=purpose)

# Now let's train
# Now let's featurize
# SVM models
if self.model.name() in Potentials.svm_models:
# Mapping raw positions into a feature space aka X
@@ -317,6 +317,7 @@ def train(
module = Potentials.module_names[self.model.name()]
train = dynamic_import("train", "ml4chem.atomistic.models", alt_name=module)

# Let's train
train(
feature_space,
targets,
2 changes: 1 addition & 1 deletion ml4chem/data/handler.py
@@ -33,7 +33,7 @@ def __init__(self, images, purpose=None):
self.images = None
self.targets = None
self.unique_element_symbols = None
logger.info("Data")
logger.info("\nData")
logger.info("====")
now = datetime.datetime.now()
logger.info("Module accessed on {}.".format(now.strftime("%Y-%m-%d %H:%M:%S")))
9 changes: 6 additions & 3 deletions ml4chem/data/preprocessing.py
@@ -83,23 +83,26 @@ def set(self, purpose):
preprocessor_name = "Normalizer"

elif self.preprocessing is not None and purpose == "inference":
logger.info("\nData preprocessing")
logger.info("------------------")
logger.info(f"Preprocessor loaded from file : {self.preprocessing}.")
self.preprocessor = joblib.load(self.preprocessing)

else:
logger.warning(
"{} with {} is not supported.".format(self.preprocessing, self.kwargs)
f"Preprocessor ({self.preprocessing}, {self.kwargs}) is not supported."
)
self.preprocessor = preprocessor_name = None

if purpose == "training" and preprocessor_name is not None:
logger.info("Data preprocessing")
logger.info("\nData preprocessing")
logger.info("------------------")
logger.info("Preprocessor: {}.".format(preprocessor_name))
logger.info("Options:")
for k, v in self.kwargs.items():
logger.info(" - {}: {}.".format(k, v))

logger.info(" ")
logger.info(" ")

return self.preprocessor

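For context, a minimal sketch of the preprocessor persistence pattern the `set()` method above relies on: a fitted scikit-learn transformer saved with joblib and reloaded for inference. The file name and the choice of scaler here are placeholders, not taken from ml4chem.

import joblib
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler(feature_range=(-1, 1)).fit([[0.0], [1.0], [2.0]])
joblib.dump(scaler, "preprocessor.joblib")
# Mirrors the purpose == "inference" branch, which loads a saved preprocessor.
restored = joblib.load("preprocessor.joblib")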
4 changes: 3 additions & 1 deletion readthedocs.yml
@@ -3,4 +3,6 @@ conda:

python:
version: 3.7
pip_install: true
install:
- requirements: docs/requirements.txt
# pip_install: true
