From 910845079914fc6bcba830c7af57f01905e6782e Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Wed, 14 Feb 2024 19:22:40 -0600 Subject: [PATCH 1/8] Trainer module --- kliff/trainer/__init__.py | 0 kliff/trainer/kliff_trainer.py | 403 +++++++++++++++++++++++++++++++++ 2 files changed, 403 insertions(+) create mode 100644 kliff/trainer/__init__.py create mode 100644 kliff/trainer/kliff_trainer.py diff --git a/kliff/trainer/__init__.py b/kliff/trainer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kliff/trainer/kliff_trainer.py b/kliff/trainer/kliff_trainer.py new file mode 100644 index 00000000..06337a74 --- /dev/null +++ b/kliff/trainer/kliff_trainer.py @@ -0,0 +1,403 @@ +import os +from copy import deepcopy +from datetime import datetime, timedelta +from enum import Enum +from glob import glob +from pathlib import Path + +import numpy as np +import yaml +from loguru import logger + + +class ModelTypes(Enum): + KIM = 0 + TORCH = 1 + TAR = 2 + + @staticmethod + def get_model_type(input_str: str): + if input_str.lower() == "kim": + return ModelTypes.KIM + elif ( + input_str.lower() == "torch" + or input_str.lower() == "pt" + or input_str.lower() == "pth" + ): + return ModelTypes.TORCH + elif input_str.lower() == "tar": + return ModelTypes.TAR + else: + raise TrainerError(f"Model type {input_str} not supported.") + + @staticmethod + def get_model_config(input_type): + if input_type == ModelTypes.KIM: + return "KIM" + elif input_type == ModelTypes.TORCH: + return "TORCH" + elif input_type == ModelTypes.TAR: + return "TAR" + else: + raise TrainerError(f"Model type {input_type} not supported.") + + +class DataTypes(Enum): + ASE = 0 + COLABFIT = 1 + KLIFF = 2 + TORCH_GEOMETRIC = 3 + + @staticmethod + def get_data_type(input_str: str): + if input_str.lower() == "ase": + return DataTypes.ASE + elif input_str.lower() == "colabfit": + return DataTypes.COLABFIT + elif input_str.lower() == "kliff": + return DataTypes.KLIFF + elif input_str.lower() == "torch_geometric": + return DataTypes.TORCH_GEOMETRIC + else: + raise TrainerError(f"Data type {input_str} not supported.") + + @staticmethod + def get_data_config(input_type): + if input_type == DataTypes.ASE: + return "ASE" + elif input_type == DataTypes.COLABFIT: + return "COLABFIT" + elif input_type == DataTypes.KLIFF: + return "KLIFF" + elif input_type == DataTypes.TORCH_GEOMETRIC: + return "TORCH_GEOMETRIC" + else: + raise TrainerError(f"Data type {input_type} not supported.") + + +class ConfigurationTransformationTypes(Enum): + GRAPH = 0 + DESCRIPTORS = 1 + NEIGHBORS = 2 + + @staticmethod + def get_config_transformation_type(input_str: str): + if input_str.lower() == "graph": + return ConfigurationTransformationTypes.GRAPH + elif input_str.lower() == "descriptors": + return ConfigurationTransformationTypes.DESCRIPTORS + elif input_str.lower() == "neighbors" or input_str.lower() == "none": + return ConfigurationTransformationTypes.NEIGHBORS + else: + raise TrainerError( + f"Configuration transformation type {input_str} not supported." + ) + + @staticmethod + def get_config_transformation_config(input_type): + if input_type == ConfigurationTransformationTypes.GRAPH: + return "GRAPH" + elif input_type == ConfigurationTransformationTypes.DESCRIPTORS: + return "DESCRIPTORS" + else: + raise TrainerError( + f"Configuration transformation type {input_type} not supported." 
+ ) + + +class OptimizerProvider(Enum): + TORCH = 0 + SCIPY = 1 + + @staticmethod + def get_optimizer_provider(input_str: str): + if input_str.lower() == "torch": + return OptimizerProvider.TORCH + elif input_str.lower() == "scipy": + return OptimizerProvider.SCIPY + else: + raise TrainerError(f"Optimizer provider {input_str} not supported.") + + @staticmethod + def get_optimizer_config(input_type): + if input_type == OptimizerProvider.TORCH: + return "TORCH" + elif input_type == OptimizerProvider.SCIPY: + return "SCIPY" + else: + raise TrainerError(f"Optimizer provider {input_type} not supported.") + + +class Trainer: + def __init__(self, configuration: dict): + self.start_time = datetime.now() + self.indices_file = None + self.val_indices = None + self.train_indices = None + self.train_dataset = None + self.val_dataset = None + self.dataset = None + self.model_source = None + self.model = None + self.optimizer = None + logger.info( + f"Starting training. Time: {self.start_time.strftime('%Y-%m-%d-%H-%M-%S')}" + ) + + self.configuration = self.parse_dict(configuration) + + # set computation limits + logger.info(f"Starting trainer. {self.configuration['optimizer_provider']}") + if self.configuration["optimizer_provider"] == OptimizerProvider.TORCH: + # Cant interject SCIPY optimizer with walltime + max_walltime = timedelta(seconds=configuration["max_walltime"]) + self.end_time = self.start_time + max_walltime + + self.root_dir = configuration["root_dir"] + self.current_run_title = configuration["run_title"] + self.append = configuration["append"] + self.resume = configuration["resume"] + self.current_run_dir = configuration["current_run_dir"] + self.optimizer_provider = configuration["optimizer_provider"] + self.device = configuration["device"] + self.model_name = configuration["model_name"] + self.model_source = configuration["model_source"] + self.dataset_type = configuration["dataset_type"] + self.dataset_path = configuration["dataset_path"] + if self.dataset_type == DataTypes.COLABFIT: + self.dataset_name = configuration["dataset_name"] + self.database_name = configuration["database_name"] + self.seed = configuration["seed"] + + # set up indices and dataset + self.indices_file = configuration["indices_file"] + self.energy_loss_weight = configuration["loss_weights"]["energy"] + self.forces_loss_weight = configuration["loss_weights"]["forces"] + + self.checkpoint_freq = configuration["checkpoint_freq"] + self.max_epochs = configuration["max_epoch"] + + def parse_dict(self, configuration: dict): + if "run_title" not in configuration: + logger.error("run_title not provided.") + raise ValueError("run_title not provided.") + + if "root_dir" not in configuration: + logger.warning("root_dir not provided.") + configuration["root_dir"] = "root_dir" + + if "append" not in configuration: + configuration["append"] = False + + resume, current_run_dir = self.workdir( + f"{configuration['root_dir']}/{configuration['run_title']}", + configuration["append"], + ) + configuration["current_run_dir"] = current_run_dir + configuration["resume"] = resume + + if "seed" not in configuration: + configuration["seed"] = None + + if "model_name" not in configuration: + configuration["model_name"] = "model" + + if "model_source" not in configuration: + TrainerError("model_source not provided.") + else: + configuration["model_source"] = ModelTypes.get_model_type( + configuration["model_source"] + ) + + if "dataset_type" not in configuration: + TrainerError("dataset_type not provided.") + else: + 
configuration["dataset_type"] = DataTypes.get_data_type( + configuration["dataset_type"] + ) + + if configuration["dataset_type"] == DataTypes.COLABFIT: + if ( + "dataset_name" not in configuration + or "database_name" not in configuration + ): + raise TrainerError("colabfit_name not provided.") + elif ( + configuration["dataset_type"] == DataTypes.ASE + or configuration["dataset_type"] == DataTypes.KLIFF + ): + if "dataset_path" not in configuration: + raise TrainerError("dataset_name not provided.") + + # optimizer parameters + configuration["optimizer_provider"] = OptimizerProvider.get_optimizer_provider( + configuration["optimizer_provider"] + ) + + if configuration["optimizer_provider"] == OptimizerProvider.TORCH: + if "n_train" not in configuration: + raise TrainerError("n_train not provided.") + if "n_val" not in configuration: + raise TrainerError("n_val not provided.") + if "batch_size" not in configuration: + raise TrainerError("batch_size not provided.") + if "max_epoch" not in configuration: + configuration["max_epoch"] = 10000 + if "max_walltime" not in configuration: + configuration["max_walltime"] = 48 * 60 * 60 # max in NYU Greene + + if "optimizer" not in configuration: + configuration["optimizer"] = ( + "adam" + if configuration["optimizer_provider"] == OptimizerProvider.TORCH + else "l-bfgs-b" + ) + + # defaults + + if "optimizer_kwargs" not in configuration: + configuration["optimizer_kwargs"] = {} + + if "indices_file" not in configuration: + # to be populated later + configuration["indices_file"] = {"train": None, "val": None} + + if "checkpoint_freq" not in configuration: + configuration["checkpoint_freq"] = 100 + + if "device" not in configuration: + configuration["device"] = "cpu" + + if "loss_weights" not in configuration: + configuration["loss_weights"] = {"energy": 1.0, "forces": 1.0} + + if "cpu_workers" not in configuration: + configuration["cpu_workers"] = 1 + + if "max_epoch" not in configuration: + configuration["max_epoch"] = None + + if "max_walltime" not in configuration: + configuration["max_walltime"] = None + + return configuration + + def get_dict(self): + configuration_dict = deepcopy(self.configuration) + configuration_dict["model_source"] = ModelTypes.get_model_config( + configuration_dict["model_source"] + ) + configuration_dict["dataset_type"] = DataTypes.get_data_config( + configuration_dict["dataset_type"] + ) + configuration_dict[ + "optimizer_provider" + ] = OptimizerProvider.get_optimizer_config( + configuration_dict["optimizer_provider"] + ) + return configuration_dict + + def get_indices(self, size_of_dataset: int): + if self.configuration["indices_file"]["train"] is None: + all_indices = np.arange(size_of_dataset) + np.random.shuffle(all_indices) + self.train_indices = all_indices[: self.configuration["n_train"]] + self.val_indices = all_indices[-self.configuration["n_val"] :] + else: + self.train_indices = np.load(self.configuration["indices_file"]["train"]) + self.val_indices = np.load(self.configuration["indices_file"]["val"]) + + def workdir( + self, + current_run_dir, + append, + ): + """ + Check all the existing runs in the root directory and see if it finished the run + :param current_run_dir: + :return: + """ + dir_list = sorted(glob(f"{current_run_dir}*")) + if len(dir_list) == 0: + resume = False + current_run_dir = current_run_dir + return resume, current_run_dir + elif not append: + resume = False + current_run_dir = ( + f"{current_run_dir}_{self.start_time.strftime('%Y-%m-%d-%H-%M-%S')}" + ) + return resume, current_run_dir 
+ else: + last_dir = dir_list[-1] + was_it_finished = os.path.exists(f"{last_dir}/.finished") + if was_it_finished: + resume = False + current_run_dir = ( + f"{current_run_dir}_{self.start_time.strftime('%Y-%m-%d-%H-%M-%S')}" + ) + return resume, current_run_dir + + # incomplete run encountered + # config_file = f"{dir_list[-1]}/config.yaml" + # try: + # with open(config_file, "r") as f: + # last_config = yaml.safe_load(f) + # except FileNotFoundError: + # raise FileNotFoundError(f"Previous config file not found, most likely corrupted data.") + + # check if anything changed from the last time + # dataset + # when can we resume vs new run? + return True, dir_list[-1] + + @classmethod + def from_file(cls, filename: Path): + with open(filename, "r") as f: + configuration = yaml.safe_load(f) + configuration["filename"] = str(filename) + return cls(configuration) + + def to_file(self, filename): + configuration = self.get_dict() + try: + if self.indices_file is None: + configuration["indices_file"]["train"] = ( + filename.split("/")[-1] + "train_indices.txt" + ) + configuration["indices_file"]["val"] = ( + filename.split("/")[-1] + "val_indices.txt" + ) + np.savetxt(configuration["indices_file"]["train"], self.train_indices) + np.savetxt(configuration["indices_file"]["val"], self.val_indices) + except ValueError: + logger.warning("Indices file not saved. It is normal for KIM models.") + + with open(filename, "w") as f: + yaml.dump(configuration, f, default_flow_style=False) + + def loss(self, energy_prediction, energy_target, forces_prediction, forces_target): + TrainerError("loss not implemented.") + + def checkpoint(self): + TrainerError("checkpoint not implemented.") + + def train_step(self): + TrainerError("train_step not implemented.") + + def validation_step(self): + TrainerError("validation_step not implemented.") + + def get_optimizer(self): + TrainerError("get_optimizer not implemented.") + + def get_dataset(self): # Specific to trainer + TrainerError("get_dataset not implemented.") + + def train(self): + TrainerError("train not implemented.") + + +class TrainerError(Exception): + def __init__(self, message): + super().__init__(message) From 731c5df7cdafe881026f2940729c01eec71ba48d Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Fri, 23 Feb 2024 08:36:29 -0600 Subject: [PATCH 2/8] Trainer base class implemented --- kliff/trainer/kliff_trainer.py | 845 +++++++++++++++++++-------------- 1 file changed, 490 insertions(+), 355 deletions(-) diff --git a/kliff/trainer/kliff_trainer.py b/kliff/trainer/kliff_trainer.py index 06337a74..485dc84f 100644 --- a/kliff/trainer/kliff_trainer.py +++ b/kliff/trainer/kliff_trainer.py @@ -1,403 +1,538 @@ +import json import os +import random from copy import deepcopy from datetime import datetime, timedelta -from enum import Enum from glob import glob from pathlib import Path +from typing import Callable import numpy as np import yaml from loguru import logger +from kliff._exceptions import TrainerError +from.option_enumerations import DataSource, ModelTypes, OptimizerProvider +from kliff.transforms.configuration_transforms import ConfigurationTransform +from kliff.transforms.parameter_transforms import ParameterTransform +from kliff.transforms.property_transforms import PropertyTransform +from kliff.dataset import Dataset -class ModelTypes(Enum): - KIM = 0 - TORCH = 1 - TAR = 2 - - @staticmethod - def get_model_type(input_str: str): - if input_str.lower() == "kim": - return ModelTypes.KIM - elif ( - input_str.lower() == "torch" - or input_str.lower() == 
"pt" - or input_str.lower() == "pth" - ): - return ModelTypes.TORCH - elif input_str.lower() == "tar": - return ModelTypes.TAR - else: - raise TrainerError(f"Model type {input_str} not supported.") - - @staticmethod - def get_model_config(input_type): - if input_type == ModelTypes.KIM: - return "KIM" - elif input_type == ModelTypes.TORCH: - return "TORCH" - elif input_type == ModelTypes.TAR: - return "TAR" - else: - raise TrainerError(f"Model type {input_type} not supported.") - - -class DataTypes(Enum): - ASE = 0 - COLABFIT = 1 - KLIFF = 2 - TORCH_GEOMETRIC = 3 - - @staticmethod - def get_data_type(input_str: str): - if input_str.lower() == "ase": - return DataTypes.ASE - elif input_str.lower() == "colabfit": - return DataTypes.COLABFIT - elif input_str.lower() == "kliff": - return DataTypes.KLIFF - elif input_str.lower() == "torch_geometric": - return DataTypes.TORCH_GEOMETRIC - else: - raise TrainerError(f"Data type {input_str} not supported.") - - @staticmethod - def get_data_config(input_type): - if input_type == DataTypes.ASE: - return "ASE" - elif input_type == DataTypes.COLABFIT: - return "COLABFIT" - elif input_type == DataTypes.KLIFF: - return "KLIFF" - elif input_type == DataTypes.TORCH_GEOMETRIC: - return "TORCH_GEOMETRIC" - else: - raise TrainerError(f"Data type {input_type} not supported.") - - -class ConfigurationTransformationTypes(Enum): - GRAPH = 0 - DESCRIPTORS = 1 - NEIGHBORS = 2 - - @staticmethod - def get_config_transformation_type(input_str: str): - if input_str.lower() == "graph": - return ConfigurationTransformationTypes.GRAPH - elif input_str.lower() == "descriptors": - return ConfigurationTransformationTypes.DESCRIPTORS - elif input_str.lower() == "neighbors" or input_str.lower() == "none": - return ConfigurationTransformationTypes.NEIGHBORS - else: - raise TrainerError( - f"Configuration transformation type {input_str} not supported." - ) - - @staticmethod - def get_config_transformation_config(input_type): - if input_type == ConfigurationTransformationTypes.GRAPH: - return "GRAPH" - elif input_type == ConfigurationTransformationTypes.DESCRIPTORS: - return "DESCRIPTORS" - else: - raise TrainerError( - f"Configuration transformation type {input_type} not supported." - ) - - -class OptimizerProvider(Enum): - TORCH = 0 - SCIPY = 1 - - @staticmethod - def get_optimizer_provider(input_str: str): - if input_str.lower() == "torch": - return OptimizerProvider.TORCH - elif input_str.lower() == "scipy": - return OptimizerProvider.SCIPY - else: - raise TrainerError(f"Optimizer provider {input_str} not supported.") - - @staticmethod - def get_optimizer_config(input_type): - if input_type == OptimizerProvider.TORCH: - return "TORCH" - elif input_type == OptimizerProvider.SCIPY: - return "SCIPY" - else: - raise TrainerError(f"Optimizer provider {input_type} not supported.") +import dill # TODO: include dill in requirements.txt +import importlib class Trainer: + """Base class for all trainers. + + This class is the base class for all trainers. It provides the basic structure for + training a model. The derived classes should implement the required methods. This + class will provide the basic functionality for training, such as setting up the + work directory, saving the configuration, and setting up the indices for training + and validation datasets. It will save hashes of the configuration fingerprints and + training configuration to the work directory. This would ensure reproducibility of + the training process, and easy restarting. 
+ + The core trainer class will provide the following functionality: + - Set up the work directory + - Set up the dataset + - Set up the test train split + Model and optimizer setup are left for the derived classes to implement. + + Args: + configuration: configuration dictionary + """ def __init__(self, configuration: dict): - self.start_time = datetime.now() - self.indices_file = None + # workspace variables + self.workspace_name = None # name of default directory, root + self.workspace_name = None # where to save everything from current run (inside workspace) + self.current_run_title = None # title of current run, usually model name + date and time + self.export_kim_model = False # whether to export the model to KIM model + self.seed = 12345 # random seed + self.resume = False # whether to resume from previous run (conditions apply) + self.walltime = None # maximum walltime for the run + + # dataset variables + self.dataset_type: DataSource = DataSource.UNDEFINED + self.dataset_path = None + self.dataset_save = None + self.dataset_shuffle = None + self.dataset = None + self.train_size = None + self.val_size = None + self.indices_files: dict = {"train": None, "val": None} self.val_indices = None self.train_indices = None self.train_dataset = None self.val_dataset = None - self.dataset = None - self.model_source = None - self.model = None - self.optimizer = None - logger.info( - f"Starting training. Time: {self.start_time.strftime('%Y-%m-%d-%H-%M-%S')}" - ) - - self.configuration = self.parse_dict(configuration) - - # set computation limits - logger.info(f"Starting trainer. {self.configuration['optimizer_provider']}") - if self.configuration["optimizer_provider"] == OptimizerProvider.TORCH: - # Cant interject SCIPY optimizer with walltime - max_walltime = timedelta(seconds=configuration["max_walltime"]) - self.end_time = self.start_time + max_walltime - - self.root_dir = configuration["root_dir"] - self.current_run_title = configuration["run_title"] - self.append = configuration["append"] - self.resume = configuration["resume"] - self.current_run_dir = configuration["current_run_dir"] - self.optimizer_provider = configuration["optimizer_provider"] - self.device = configuration["device"] - self.model_name = configuration["model_name"] - self.model_source = configuration["model_source"] - self.dataset_type = configuration["dataset_type"] - self.dataset_path = configuration["dataset_path"] - if self.dataset_type == DataTypes.COLABFIT: - self.dataset_name = configuration["dataset_name"] - self.database_name = configuration["database_name"] - self.seed = configuration["seed"] - - # set up indices and dataset - self.indices_file = configuration["indices_file"] - self.energy_loss_weight = configuration["loss_weights"]["energy"] - self.forces_loss_weight = configuration["loss_weights"]["forces"] - - self.checkpoint_freq = configuration["checkpoint_freq"] - self.max_epochs = configuration["max_epoch"] - - def parse_dict(self, configuration: dict): - if "run_title" not in configuration: - logger.error("run_title not provided.") - raise ValueError("run_title not provided.") - - if "root_dir" not in configuration: - logger.warning("root_dir not provided.") - configuration["root_dir"] = "root_dir" - - if "append" not in configuration: - configuration["append"] = False - - resume, current_run_dir = self.workdir( - f"{configuration['root_dir']}/{configuration['run_title']}", - configuration["append"], - ) - configuration["current_run_dir"] = current_run_dir - configuration["resume"] = resume - - if 
"seed" not in configuration: - configuration["seed"] = None - - if "model_name" not in configuration: - configuration["model_name"] = "model" - - if "model_source" not in configuration: - TrainerError("model_source not provided.") - else: - configuration["model_source"] = ModelTypes.get_model_type( - configuration["model_source"] - ) + self.colabfit_dataset: dict = { + "dataset_name": None, + "database_name": None, + "database_url": None, + } + + # model variables + self.model_type: ModelTypes = ModelTypes.UNDEFINED + self.model: Callable = None + self.model_name = None # KIM string or name of pt/pth file + self.model_path = None # path to the model file + + # transform variables + self.property_transform:PropertyTransform = None + self.property_transform_options = None + self.parameter_transform:ParameterTransform = None + self.parameter_transform_options = None + self.configuration_transform:ConfigurationTransform = None + self.configuration_transform_options = None + + # training variables + self.loss_function: Callable = None + self.energy_loss_weight = 1.0 + self.forces_loss_weight = 0.0 + + self.optimizer_provider: OptimizerProvider = OptimizerProvider.UNDEFINED + self.optimizer = None # instance of optimizer, "scipy" for scipy torch.optim instance for torch + self.optimizer_name = None # name of optimizer, e.g. "l-bfgs-b" for scipy, "adam" for torch + self.learning_rate = None # learning rate for torch + + self.max_epochs = 10000 # maximum number of epochs + self.device = "cpu" + self.batch_size = 1 + self.chkpt_interval = 100 + self.stop_condition = None # function to check if training should stop + + self.configuration = self.config_from_dict(configuration) + + # state variables + self.current_epoch = 0 + self.current_step = 0 + self.current_best_loss = None + self.current_best_model = None + self.current_loss = None + self.current_run_dir = None + self.appending_to_previous_run = False + self.current_dataset_hash = None + self.start_current_run_title = None # start time of the current run + self.expected_end_time = None + + self._initialize() + + def config_from_dict(self, configuration: dict): + """ + It accepts the raw configuration dictionary, and processes it to the formatted + configuration. This includes mapping the string fields to enums, and setting sane + defaults for missing fields. 
- if "dataset_type" not in configuration: - TrainerError("dataset_type not provided.") - else: - configuration["dataset_type"] = DataTypes.get_data_type( - configuration["dataset_type"] - ) + Args: + configuration: raw incoming dictionary - if configuration["dataset_type"] == DataTypes.COLABFIT: - if ( - "dataset_name" not in configuration - or "database_name" not in configuration - ): - raise TrainerError("colabfit_name not provided.") - elif ( - configuration["dataset_type"] == DataTypes.ASE - or configuration["dataset_type"] == DataTypes.KLIFF - ): - if "dataset_path" not in configuration: - raise TrainerError("dataset_name not provided.") - - # optimizer parameters - configuration["optimizer_provider"] = OptimizerProvider.get_optimizer_provider( - configuration["optimizer_provider"] - ) - - if configuration["optimizer_provider"] == OptimizerProvider.TORCH: - if "n_train" not in configuration: - raise TrainerError("n_train not provided.") - if "n_val" not in configuration: - raise TrainerError("n_val not provided.") - if "batch_size" not in configuration: - raise TrainerError("batch_size not provided.") - if "max_epoch" not in configuration: - configuration["max_epoch"] = 10000 - if "max_walltime" not in configuration: - configuration["max_walltime"] = 48 * 60 * 60 # max in NYU Greene - - if "optimizer" not in configuration: - configuration["optimizer"] = ( - "adam" - if configuration["optimizer_provider"] == OptimizerProvider.TORCH - else "l-bfgs-b" + Returns: + Processed configuration dictionary + """ + start_time = datetime.now() + date_time_str = start_time.strftime("%Y-%m-%d-%H-%M-%S") + processed_configuration = {} + + # Workspace variables + workspace_block = configuration.get("workspace", None) + if workspace_block is not None: + processed_configuration["start_time"] = start_time + processed_configuration["workspace_name"] = workspace_block.get("name", f"kliff_{date_time_str}") + processed_configuration["current_run_title"] = None # will be assigned in the model block + processed_configuration["export_kim_model"] = workspace_block.get("export", False) + processed_configuration["seed"] = workspace_block.get("seed", 12345) + processed_configuration["resume"] = workspace_block.get("resume", False) + walltime = workspace_block.get("walltime", "2:00:00:00") + processed_configuration["walltime"] = timedelta( + days=int(walltime.split(":")[0]), + hours=int(walltime.split(":")[1]), + minutes=int(walltime.split(":")[2]), + seconds=int(walltime.split(":")[3]), ) + processed_configuration["expected_end_time"] = start_time + processed_configuration["walltime"] + else: + raise TrainerError("Workspace block not found in the configuration.") + + # Dataset variables + dataset_block = configuration.get("dataset", None) + if dataset_block is not None: + processed_configuration["dataset_type"] = DataSource.get_data_config(dataset_block.get("type", "kliff")) + processed_configuration["dataset_path"] = dataset_block.get("dataset_path", None) + processed_configuration["dataset_save"] = dataset_block.get("save", False) + processed_configuration["dataset_shuffle"] = dataset_block.get("shuffle", False) + train_dataset_info = dataset_block.get("training_dataset", None) + if train_dataset_info is not None: + # none values will be tackled during dataset loading + processed_configuration["train_size"] = train_dataset_info.get("train_size", None) + processed_configuration["train_indices"] = train_dataset_info.get("train_indices", None) + else: + processed_configuration["train_size"] = None + 
processed_configuration["train_indices"] = None + + val_dataset_info = dataset_block.get("validation_dataset", None) + if val_dataset_info is not None: + processed_configuration["val_size"] = val_dataset_info.get("val_size", None) + processed_configuration["val_indices"] = val_dataset_info.get("val_indices", None) + else: + processed_configuration["val_size"] = None + processed_configuration["val_indices"] = None + processed_configuration["indices_file"] = {"train": None, "val": None} + if type(processed_configuration["train_indices"]) is str: + processed_configuration["indices_file"] = processed_configuration["train_indices"] + if type(processed_configuration["val_indices"]) is str: + processed_configuration["indices_file"] = processed_configuration["val_indices"] + + processed_configuration["train_dataset"] = None # To be assigned + processed_configuration["val_dataset"] = None # To be assigned + processed_configuration["dataset"] = None # To be assigned + + colabfit_dict = dataset_block.get("colabfit_dataset", None) + if colabfit_dict is not None: + processed_configuration["colabfit_dataset"] = { + "dataset_name": colabfit_dict.get("dataset_name", None), + "database_name": colabfit_dict.get("database_name", None), + "database_url": colabfit_dict.get("database_url", None), + } + else: + raise TrainerError("Dataset block not found in the configuration.") + + # model variables + model_block = configuration.get("model", {}) + processed_configuration["model_type"] = ModelTypes.get_model_type(model_block.get("model_type", "kim")) + processed_configuration["model_name"] = model_block.get("model_name", None) + processed_configuration["model_path"] = model_block.get("model_path", None) + processed_configuration["model"] = None # To be assigned + if processed_configuration["model_name"] is None: + processed_configuration["current_run_title"] = f"{processed_configuration['model_type']}_{date_time_str}" + else: + processed_configuration["current_run_title"] = f"{processed_configuration['model_name']}_{date_time_str}" + + # transform variables + transform_block = configuration.get("transform", {}) + property_transform_sub_block = transform_block.get("property_transform", {}) + parameter_transform_sub_block = transform_block.get("parameter_transform", {}) + configuration_transform_sub_block = transform_block.get("configuration_transform", {}) + processed_configuration["property_transform_option"]["name"] = property_transform_sub_block.get("name", None) + processed_configuration["property_transform"]= property_transform_sub_block.get("instance", None) # no executable given. initialize on own + processed_configuration["parameter_transform_option"]["name"] = parameter_transform_sub_block.get("name", None) + processed_configuration["parameter_transform"]= parameter_transform_sub_block.get("instance", None) # no executable given. initialize on own + processed_configuration["configuration_transform_option"] = configuration_transform_sub_block # this might contain lot of variables + processed_configuration["configuration_transform"] = configuration_transform_sub_block.get("instance", None) # no executable given. 
initialize on own + + # training variables + training_block = configuration.get("training", {}) + loss_block = training_block.get("loss", {}) + processed_configuration["loss_function"] = loss_block.get("loss_function", None) + processed_configuration["energy_loss_weight"] = loss_block.get("energy_loss_weight", 1.0) + processed_configuration["forces_loss_weight"] = loss_block.get("forces_loss_weight", 0.0) + + optimizer_block = training_block.get("optimizer", {}) + processed_configuration["optimizer_provider"] = OptimizerProvider.get_optimizer_provider(optimizer_block.get("provider", "scipy")) + processed_configuration["optimizer"] = None # To be assigned + processed_configuration["optimizer_name"] = optimizer_block.get("name", None) + processed_configuration["learning_rate"] = optimizer_block.get("learning_rate", None) + + processed_configuration["max_epochs"] = training_block.get("max_epochs", 10000) + processed_configuration["device"] = training_block.get("device", "cpu") + processed_configuration["batch_size"] = training_block.get("batch_size", 1) + processed_configuration["chkpt_interval"] = training_block.get("chkpt_interval", 100) + processed_configuration["stop_condition"] = training_block.get("stop_condition", None) + + return processed_configuration + + def config_to_dict(self): + """ + Convert the configuration to a dictionary. + """ + configuration = {} + configuration["workspace"] = { + "name": self.workspace_name, + "export": self.export_kim_model, + "seed": self.seed, + "resume": self.resume, + "walltime": f"{self.walltime.days}:{self.walltime.seconds // 3600}:{(self.walltime.seconds // 60) % 60}:{self.walltime.seconds % 60}", + } + + configuration["dataset"] = { + "type": DataSource.get_data_config(self.dataset_type), + "dataset_path": self.dataset_path, + "save": self.dataset_save, + "shuffle": self.dataset_shuffle, + "training_dataset": { + "train_size": self.train_size, + "train_indices": self.train_indices, + }, + "validation_dataset": { + "val_size": self.val_size, + "val_indices": self.val_indices, + }, + "colabfit_dataset": self.colabfit_dataset, + } + + configuration["model"] = { + "model_type": ModelTypes.get_model_config(self.model_type), + "model_name": self.model_name, + "model_path": self.model_path, + } + # TODO: Add transforms correctly + configuration["transform"] = { + "property_transform": self.property_transform, + "parameter_transform": self.parameter_transform, + "configuration_transform": self.configuration_transform, + } + configuration["training"] = { + "loss": { + "loss_function": self.loss_function, + "energy_loss_weight": self.energy_loss_weight, + "forces_loss_weight": self.forces_loss_weight, + }, + "optimizer": { + "provider": OptimizerProvider.get_optimizer_config(self.optimizer_provider), + "name": self.optimizer_name, + "learning_rate": self.learning_rate, + }, + "max_epochs": self.max_epochs, + "device": self.device, + "batch_size": self.batch_size, + "chkpt_interval": self.chkpt_interval, + "stop_condition": self.stop_condition, + } + return configuration - # defaults - - if "optimizer_kwargs" not in configuration: - configuration["optimizer_kwargs"] = {} - - if "indices_file" not in configuration: - # to be populated later - configuration["indices_file"] = {"train": None, "val": None} - - if "checkpoint_freq" not in configuration: - configuration["checkpoint_freq"] = 100 - - if "device" not in configuration: - configuration["device"] = "cpu" - - if "loss_weights" not in configuration: - configuration["loss_weights"] = {"energy": 1.0, 
"forces": 1.0} + @classmethod + def from_file(cls, filename: Path): + """ + Load the configuration from a YAML file. - if "cpu_workers" not in configuration: - configuration["cpu_workers"] = 1 + Args: + filename: name of the yaml file - if "max_epoch" not in configuration: - configuration["max_epoch"] = None + Returns: + Trainer instance - if "max_walltime" not in configuration: - configuration["max_walltime"] = None + """ + with open(filename, "r") as f: + configuration = yaml.safe_load(f) + configuration["filename"] = str(filename) + return cls(configuration) - return configuration + def get_trainer_hash(self): + """ + Get the hash of the current configuration. It will be used to create a unique + directory for the current run. It will be the hash of the configuration dictionary + string. + """ + config = self.config_to_dict() + config_immut_str = json.dumps(config, sort_keys=True) + return hash(config_immut_str) - def get_dict(self): - configuration_dict = deepcopy(self.configuration) - configuration_dict["model_source"] = ModelTypes.get_model_config( - configuration_dict["model_source"] - ) - configuration_dict["dataset_type"] = DataTypes.get_data_config( - configuration_dict["dataset_type"] - ) - configuration_dict[ - "optimizer_provider" - ] = OptimizerProvider.get_optimizer_config( - configuration_dict["optimizer_provider"] - ) - return configuration_dict - - def get_indices(self, size_of_dataset: int): - if self.configuration["indices_file"]["train"] is None: - all_indices = np.arange(size_of_dataset) - np.random.shuffle(all_indices) - self.train_indices = all_indices[: self.configuration["n_train"]] - self.val_indices = all_indices[-self.configuration["n_val"] :] - else: - self.train_indices = np.load(self.configuration["indices_file"]["train"]) - self.val_indices = np.load(self.configuration["indices_file"]["val"]) - - def workdir( - self, - current_run_dir, - append, - ): + def _initialize(self): + """ + Initialize the trainer. Assigns the configuration objects, and + call setup methods. + """ + # Step 1 - Assign the processed configuration objects to the class variables + for key, value in self.configuration.items(): + setattr(self, key, value) + # Step 2 - Initialize all seeds + self.seed_all() + # Step 3 - Set up the workspace folder + self.setup_workspace() + # Step 4 - Read or load the dataset, initialize the property/configuration transforms + self.setup_dataset() + # Step 5 - Set up the test and train datasets, based on the provided indices + self.setup_test_train_datasets() + # Step 6 - Set up the model + self.setup_model() + # Step 7 - Set up the optimizer + self.setup_optimizer() + + def seed_all(self): + """ + Seed all the random number generators. 
+ """ + np.random.seed(self.seed) + random.seed(self.seed) + # torch.manual_seed(self.seed) # enable torch seed in respective children + # torch.cuda.manual_seed_all(self.seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + + def setup_workspace(self): """ Check all the existing runs in the root directory and see if it finished the run - :param current_run_dir: - :return: """ - dir_list = sorted(glob(f"{current_run_dir}*")) - if len(dir_list) == 0: - resume = False - current_run_dir = current_run_dir - return resume, current_run_dir - elif not append: - resume = False - current_run_dir = ( - f"{current_run_dir}_{self.start_time.strftime('%Y-%m-%d-%H-%M-%S')}" - ) - return resume, current_run_dir + dir_list = sorted(glob(f"{self.workspace_name}*")) + if len(dir_list) == 0 or not self.resume: + self.appending_to_previous_run = False + self.current_run_dir = f"{self.workspace_name}/{self.current_run_title}" + os.makedirs(self.current_run_dir, exist_ok=True) else: last_dir = dir_list[-1] was_it_finished = os.path.exists(f"{last_dir}/.finished") - if was_it_finished: - resume = False - current_run_dir = ( - f"{current_run_dir}_{self.start_time.strftime('%Y-%m-%d-%H-%M-%S')}" - ) - return resume, current_run_dir - - # incomplete run encountered - # config_file = f"{dir_list[-1]}/config.yaml" - # try: - # with open(config_file, "r") as f: - # last_config = yaml.safe_load(f) - # except FileNotFoundError: - # raise FileNotFoundError(f"Previous config file not found, most likely corrupted data.") - - # check if anything changed from the last time - # dataset - # when can we resume vs new run? - return True, dir_list[-1] + if was_it_finished: # start new run + current_run_dir = f"{self.workspace_name}/{self.current_run_title}" + os.makedirs(current_run_dir, exist_ok=True) + self.appending_to_previous_run = False + else: + self.appending_to_previous_run = True + self.current_run_dir = dir_list[-1] + + def setup_dataset(self): + """ + Set up the dataset based on the provided information. It will check the + {workspace}/{dataset_name} directory to see if dataset hash is already there. + The dataset hash is determined but hashing the full path to the dataset + + transforms names + configuration transform properties. + """ + dataset_path = "" + dataset_transforms = "" + dataset_hash = "" + if self.dataset_type == DataSource.KLIFF or self.dataset_type == DataSource.ASE: + dataset_path = os.path.abspath(self.dataset_path) + elif self.dataset_type == DataSource.COLABFIT: + dataset_path = self.colabfit_dataset["database_url"] + "/" + self.colabfit_dataset["database_name"] + else: + raise TrainerError(f"Dataset type {self.dataset_type} not supported.") + + if self.property_transform is not None: + dataset_transforms += self.property_transform_options["name"] + if self.configuration_transform is not None: + dataset_transforms += self.configuration_transform_options["name"] + dataset_transforms += str(self.configuration_transform_options["cutoff"]) + + dataset_hash = hash(dataset_path + dataset_transforms) + self.current_dataset_hash = dataset_hash + dataset_dir = f"{self.workspace_name}/{dataset_hash}" + os.makedirs(dataset_dir, exist_ok=True) + try: + self.dataset = dill.load(open(f"{dataset_dir}/dataset.dill", "rb")) + logger.info(f"Dataset found in {dataset_dir}.") + return + except FileNotFoundError: + logger.info(f"Dataset not found in {self.workspace_name} directory. 
Creating dataset.") + + if self.dataset_type == DataSource.KLIFF: + dataset = Dataset.from_path(dataset_path) + elif self.dataset_type == DataSource.ASE: + dataset = Dataset.from_path(dataset_path) + elif self.dataset_type == DataSource.COLABFIT: + dataset = Dataset.from_colabfit(self.colabfit_dataset["dataset_name"], + self.colabfit_dataset["database_name"], + self.colabfit_dataset["database_url"]) + else: + raise TrainerError(f"Dataset type {self.dataset_type} not supported.") - @classmethod - def from_file(cls, filename: Path): - with open(filename, "r") as f: - configuration = yaml.safe_load(f) - configuration["filename"] = str(filename) - return cls(configuration) + if self.property_transform is not None: + if not isinstance(self.property_transform, PropertyTransform): + raise TrainerError("Property transform is not none and not an instance of PropertyTransform.") + else: + # check if property_instance_options have "instance" + if "instance" in self.property_transform_options: + self.property_transform = self.property_transform_options["instance"] + else: + try: + # try getting class "name" from kliff.transforms.property_transforms + module = importlib.import_module("kliff.transforms.property_transforms") + class_ = getattr(module, self.property_transform_options["name"]) + self.property_transform = class_(property_key=self.property_transform_options["property_key"],) + except AttributeError: + raise TrainerError(f"Property transform {self.property_transform_options['name']} not found." + "If it is a custom transform, please provide the instance.") + + self.property_transform(dataset) + + if self.configuration_transform is not None: + if not isinstance(self.configuration_transform, ConfigurationTransform): + raise TrainerError("Configuration transform is not none and not an instance of ConfigurationTransform.") + else: + # check if configuration_instance_options have "instance" + if "instance" in self.configuration_transform_options: + self.configuration_transform = self.configuration_transform_options["instance"] + else: + try: + # try getting class "name" from kliff.transforms.configuration_transforms + module = importlib.import_module("kliff.transforms.configuration_transforms") + class_ = getattr(module, self.configuration_transform_options["name"]) + self.configuration_transform = class_(**self.configuration_transform_options["kwargs"],copy_to_config=True) + except AttributeError: + raise TrainerError(f"Configuration transform {self.configuration_transform_options['name']} not found." + "If it is a custom transform, please provide the instance.") + for configuration in dataset: + self.configuration_transform(configuration) + + dill.dump(dataset, open(f"{dataset_dir}/dataset.dill", "wb")) + logger.info(f"Dataset saved in {dataset_dir}.") + if self.dataset_shuffle: + random.shuffle(dataset.configs) + self.dataset = dataset + + def setup_test_train_datasets(self): + """ + Set up the test and train datasets based on the provided indices. If the indices + are not provided, shuffled serial indices will be used. If val_indices are not + provided, the train_indices no validation dataset will be used. 
+ """ - def to_file(self, filename): - configuration = self.get_dict() - try: - if self.indices_file is None: - configuration["indices_file"]["train"] = ( - filename.split("/")[-1] + "train_indices.txt" - ) - configuration["indices_file"]["val"] = ( - filename.split("/")[-1] + "val_indices.txt" - ) - np.savetxt(configuration["indices_file"]["train"], self.train_indices) - np.savetxt(configuration["indices_file"]["val"], self.val_indices) - except ValueError: - logger.warning("Indices file not saved. It is normal for KIM models.") - - with open(filename, "w") as f: - yaml.dump(configuration, f, default_flow_style=False) - - def loss(self, energy_prediction, energy_target, forces_prediction, forces_target): + # training indices + if self.indices_files["train"] is not None: + self.train_indices = np.load(self.indices_files["train"]) + else: + if self.train_size is not None: + self.train_indices = np.arange(self.train_size) + else: + self.train_indices = np.arange(len(self.dataset)) + + # validation indices + if self.indices_files["val"] is not None: + self.val_indices = np.load(self.indices_files["val"]) + else: + if self.val_size is not None: + self.val_indices = np.arange(self.val_size) + else: + self.val_indices = None + + self.train_dataset = self.dataset[self.train_indices] + if self.val_indices: + self.val_dataset = self.dataset[self.val_indices] + + def setup_model(self): + """ + Set up the model based on the provided information. If the model is not provided, + it will be loaded from the model_path. If the model_path is not provided, it will + raise an error. If the model_type is KIM, it will be loaded from the KIM model + repository. If KIM type model is installed in CWD, it will be loaded from there, and + model_path will be set to the KIM CWD. If model is of type TAR, it will be untarred + and model_path will be set to the untarred directory. Left for the derived classes + to implement. + """ + raise TrainerError("setup_model not implemented.") + + def setup_optimizer(self): + """ + Set up the optimizer based on the provided information. If the optimizer is not + provided, it will be loaded from the optimizer_name. If the optimizer_name is not + provided, it will raise an error. If the optimizer_provider is scipy, it will be + loaded from the scipy.optimize. If the optimizer_provider is torch, it will be + loaded from the torch.optim. Left for the derived classes to implement. 
+ """ + raise TrainerError("setup_optimizer not implemented.") + + def loss(self, *args, **kwargs): TrainerError("loss not implemented.") - def checkpoint(self): + def checkpoint(self, *args, **kwargs): TrainerError("checkpoint not implemented.") - def train_step(self): + def train_step(self, *args, **kwargs): TrainerError("train_step not implemented.") - def validation_step(self): + def validation_step(self, *args, **kwargs): TrainerError("validation_step not implemented.") - def get_optimizer(self): + def get_optimizer(self, *args, **kwargs): TrainerError("get_optimizer not implemented.") - def get_dataset(self): # Specific to trainer - TrainerError("get_dataset not implemented.") - - def train(self): + def train(self, *args, **kwargs): TrainerError("train not implemented.") - -class TrainerError(Exception): - def __init__(self, message): - super().__init__(message) From 86ab5d2778b010837c2bdb551da5381fed59043d Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Wed, 28 Feb 2024 00:46:10 -0600 Subject: [PATCH 3/8] Trainer base class working --- kliff/dataset/dataset.py | 16 +- kliff/trainer/__init__.py | 1 + kliff/trainer/kliff_trainer.py | 398 +++++++++++++++++++++++---------- 3 files changed, 292 insertions(+), 123 deletions(-) diff --git a/kliff/dataset/dataset.py b/kliff/dataset/dataset.py index c7f0918b..68f56bef 100644 --- a/kliff/dataset/dataset.py +++ b/kliff/dataset/dataset.py @@ -962,17 +962,23 @@ def __len__(self) -> int: """ return len(self.configs) - def __getitem__(self, idx) -> Configuration: + def __getitem__(self, idx:Union[int,np.ndarray, List]) -> Union[Configuration, "Dataset"]: """ - Get the configuration at index `idx`. + Get the configuration at index `idx`. If the index is a list, it returns a new + dataset with the configurations at the indices. Args: - idx: Index of the configuration to get. + idx: Index of the configuration to get or a list of indices. Returns: - The configuration at index `idx`. + The configuration at index `idx` or a new dataset with the configurations at + the indices. 
""" - return self.configs[idx] + if isinstance(idx, int): + return self.configs[idx] + else: + configs = [self.configs[i] for i in idx] + return Dataset(configs) class ConfigurationError(Exception): diff --git a/kliff/trainer/__init__.py b/kliff/trainer/__init__.py index e69de29b..39d10110 100644 --- a/kliff/trainer/__init__.py +++ b/kliff/trainer/__init__.py @@ -0,0 +1 @@ +from .kliff_trainer import Trainer diff --git a/kliff/trainer/kliff_trainer.py b/kliff/trainer/kliff_trainer.py index 485dc84f..b8c25222 100644 --- a/kliff/trainer/kliff_trainer.py +++ b/kliff/trainer/kliff_trainer.py @@ -5,21 +5,25 @@ from datetime import datetime, timedelta from glob import glob from pathlib import Path -from typing import Callable +from typing import Callable, Union import numpy as np import yaml from loguru import logger +import kliff.transforms.configuration_transforms from kliff._exceptions import TrainerError -from.option_enumerations import DataSource, ModelTypes, OptimizerProvider + +from .option_enumerations import DataSource, ModelTypes, OptimizerProvider +import importlib + +import dill # TODO: include dill in requirements.txt +import hashlib + +from kliff.dataset import Dataset from kliff.transforms.configuration_transforms import ConfigurationTransform from kliff.transforms.parameter_transforms import ParameterTransform from kliff.transforms.property_transforms import PropertyTransform -from kliff.dataset import Dataset - -import dill # TODO: include dill in requirements.txt -import importlib class Trainer: @@ -42,15 +46,20 @@ class will provide the basic functionality for training, such as setting up the Args: configuration: configuration dictionary """ + def __init__(self, configuration: dict): # workspace variables - self.workspace_name = None # name of default directory, root - self.workspace_name = None # where to save everything from current run (inside workspace) - self.current_run_title = None # title of current run, usually model name + date and time - self.export_kim_model = False # whether to export the model to KIM model - self.seed = 12345 # random seed - self.resume = False # whether to resume from previous run (conditions apply) - self.walltime = None # maximum walltime for the run + self.workspace_name = None # name of default directory, root + self.workspace_name = ( + None # where to save everything from current run (inside workspace) + ) + self.current_run_title = ( + None # title of current run, usually model name + date and time + ) + self.export_kim_model = False # whether to export the model to KIM model + self.seed = 12345 # random seed + self.resume = False # whether to resume from previous run (conditions apply) + self.walltime = None # maximum walltime for the run # dataset variables self.dataset_type: DataSource = DataSource.UNDEFINED @@ -61,6 +70,7 @@ def __init__(self, configuration: dict): self.train_size = None self.val_size = None self.indices_files: dict = {"train": None, "val": None} + self.ase_keys = {"energy_key": "energy", "forces_key": "forces"} self.val_indices = None self.train_indices = None self.train_dataset = None @@ -74,15 +84,15 @@ def __init__(self, configuration: dict): # model variables self.model_type: ModelTypes = ModelTypes.UNDEFINED self.model: Callable = None - self.model_name = None # KIM string or name of pt/pth file - self.model_path = None # path to the model file + self.model_name = None # KIM string or name of pt/pth file + self.model_path = None # path to the model file # transform variables - 
self.property_transform:PropertyTransform = None + self.property_transform: PropertyTransform = None self.property_transform_options = None - self.parameter_transform:ParameterTransform = None + self.parameter_transform: ParameterTransform = None self.parameter_transform_options = None - self.configuration_transform:ConfigurationTransform = None + self.configuration_transform: ConfigurationTransform = None self.configuration_transform_options = None # training variables @@ -91,15 +101,17 @@ def __init__(self, configuration: dict): self.forces_loss_weight = 0.0 self.optimizer_provider: OptimizerProvider = OptimizerProvider.UNDEFINED - self.optimizer = None # instance of optimizer, "scipy" for scipy torch.optim instance for torch - self.optimizer_name = None # name of optimizer, e.g. "l-bfgs-b" for scipy, "adam" for torch - self.learning_rate = None # learning rate for torch + self.optimizer = None # instance of optimizer, "scipy" for scipy torch.optim instance for torch + self.optimizer_name = ( + None # name of optimizer, e.g. "l-bfgs-b" for scipy, "adam" for torch + ) + self.learning_rate = None # learning rate for torch - self.max_epochs = 10000 # maximum number of epochs + self.max_epochs = 10000 # maximum number of epochs self.device = "cpu" self.batch_size = 1 self.chkpt_interval = 100 - self.stop_condition = None # function to check if training should stop + self.stop_condition = None # function to check if training should stop self.configuration = self.config_from_dict(configuration) @@ -112,7 +124,7 @@ def __init__(self, configuration: dict): self.current_run_dir = None self.appending_to_previous_run = False self.current_dataset_hash = None - self.start_current_run_title = None # start time of the current run + self.start_current_run_title = None # start time of the current run self.expected_end_time = None self._initialize() @@ -137,54 +149,90 @@ def config_from_dict(self, configuration: dict): workspace_block = configuration.get("workspace", None) if workspace_block is not None: processed_configuration["start_time"] = start_time - processed_configuration["workspace_name"] = workspace_block.get("name", f"kliff_{date_time_str}") - processed_configuration["current_run_title"] = None # will be assigned in the model block - processed_configuration["export_kim_model"] = workspace_block.get("export", False) + processed_configuration["workspace_name"] = workspace_block.get( + "name", f"kliff_{date_time_str}" + ) + processed_configuration["current_run_title"] = ( + None # will be assigned in the model block + ) + processed_configuration["export_kim_model"] = workspace_block.get( + "export", False + ) processed_configuration["seed"] = workspace_block.get("seed", 12345) processed_configuration["resume"] = workspace_block.get("resume", False) - walltime = workspace_block.get("walltime", "2:00:00:00") - processed_configuration["walltime"] = timedelta( - days=int(walltime.split(":")[0]), - hours=int(walltime.split(":")[1]), - minutes=int(walltime.split(":")[2]), - seconds=int(walltime.split(":")[3]), + walltime: Union[str,int] = workspace_block.get("walltime", "2:00:00:00") + if type(walltime) is int: # yaml parsed the time + processed_configuration["walltime"] = timedelta(seconds=walltime) + elif type(walltime) is str: + processed_configuration["walltime"] = timedelta( + days=int(walltime.split(":")[0]), + hours=int(walltime.split(":")[1]), + minutes=int(walltime.split(":")[2]), + seconds=int(walltime.split(":")[3]), + ) + else: + raise TrainerError("Walltime not in correct format. 
dd:hh:mm:ss expected.") + processed_configuration["expected_end_time"] = ( + start_time + processed_configuration["walltime"] ) - processed_configuration["expected_end_time"] = start_time + processed_configuration["walltime"] else: raise TrainerError("Workspace block not found in the configuration.") # Dataset variables dataset_block = configuration.get("dataset", None) if dataset_block is not None: - processed_configuration["dataset_type"] = DataSource.get_data_config(dataset_block.get("type", "kliff")) - processed_configuration["dataset_path"] = dataset_block.get("dataset_path", None) + processed_configuration["dataset_type"] = DataSource.get_data_enum( + dataset_block.get("type", "kliff") + ) + processed_configuration["dataset_path"] = dataset_block.get( + "path", None + ) processed_configuration["dataset_save"] = dataset_block.get("save", False) - processed_configuration["dataset_shuffle"] = dataset_block.get("shuffle", False) + processed_configuration["dataset_shuffle"] = dataset_block.get( + "shuffle", False + ) + ase_keys = dataset_block.get("keys", {}) + processed_configuration["ase_keys"] = { + "energy_key": ase_keys.get("energy", "energy"), + "forces_key": ase_keys.get("forces", "forces"), + } train_dataset_info = dataset_block.get("training_dataset", None) if train_dataset_info is not None: # none values will be tackled during dataset loading - processed_configuration["train_size"] = train_dataset_info.get("train_size", None) - processed_configuration["train_indices"] = train_dataset_info.get("train_indices", None) + processed_configuration["train_size"] = train_dataset_info.get( + "train_size", None + ) + processed_configuration["train_indices"] = train_dataset_info.get( + "train_indices", None + ) else: processed_configuration["train_size"] = None processed_configuration["train_indices"] = None val_dataset_info = dataset_block.get("validation_dataset", None) if val_dataset_info is not None: - processed_configuration["val_size"] = val_dataset_info.get("val_size", None) - processed_configuration["val_indices"] = val_dataset_info.get("val_indices", None) + processed_configuration["val_size"] = val_dataset_info.get( + "val_size", None + ) + processed_configuration["val_indices"] = val_dataset_info.get( + "val_indices", None + ) else: processed_configuration["val_size"] = None processed_configuration["val_indices"] = None processed_configuration["indices_file"] = {"train": None, "val": None} if type(processed_configuration["train_indices"]) is str: - processed_configuration["indices_file"] = processed_configuration["train_indices"] + processed_configuration["indices_file"] = processed_configuration[ + "train_indices" + ] if type(processed_configuration["val_indices"]) is str: - processed_configuration["indices_file"] = processed_configuration["val_indices"] + processed_configuration["indices_file"] = processed_configuration[ + "val_indices" + ] - processed_configuration["train_dataset"] = None # To be assigned - processed_configuration["val_dataset"] = None # To be assigned - processed_configuration["dataset"] = None # To be assigned + processed_configuration["train_dataset"] = None # To be assigned + processed_configuration["val_dataset"] = None # To be assigned + processed_configuration["dataset"] = None # To be assigned colabfit_dict = dataset_block.get("colabfit_dataset", None) if colabfit_dict is not None: @@ -198,45 +246,88 @@ def config_from_dict(self, configuration: dict): # model variables model_block = configuration.get("model", {}) - processed_configuration["model_type"] 
= ModelTypes.get_model_type(model_block.get("model_type", "kim")) + processed_configuration["model_type"] = ModelTypes.get_model_enum( + model_block.get("model_type", "kim") + ) processed_configuration["model_name"] = model_block.get("model_name", None) processed_configuration["model_path"] = model_block.get("model_path", None) - processed_configuration["model"] = None # To be assigned + processed_configuration["model"] = None # To be assigned if processed_configuration["model_name"] is None: - processed_configuration["current_run_title"] = f"{processed_configuration['model_type']}_{date_time_str}" + processed_configuration["current_run_title"] = ( + f"{processed_configuration['model_type']}_{date_time_str}" + ) else: - processed_configuration["current_run_title"] = f"{processed_configuration['model_name']}_{date_time_str}" + processed_configuration["current_run_title"] = ( + f"{processed_configuration['model_name']}_{date_time_str}" + ) # transform variables - transform_block = configuration.get("transform", {}) - property_transform_sub_block = transform_block.get("property_transform", {}) - parameter_transform_sub_block = transform_block.get("parameter_transform", {}) - configuration_transform_sub_block = transform_block.get("configuration_transform", {}) - processed_configuration["property_transform_option"]["name"] = property_transform_sub_block.get("name", None) - processed_configuration["property_transform"]= property_transform_sub_block.get("instance", None) # no executable given. initialize on own - processed_configuration["parameter_transform_option"]["name"] = parameter_transform_sub_block.get("name", None) - processed_configuration["parameter_transform"]= parameter_transform_sub_block.get("instance", None) # no executable given. initialize on own - processed_configuration["configuration_transform_option"] = configuration_transform_sub_block # this might contain lot of variables - processed_configuration["configuration_transform"] = configuration_transform_sub_block.get("instance", None) # no executable given. initialize on own + transform_block = configuration.get("transforms", {}) + property_transform_sub_block = transform_block.get("property", {}) + parameter_transform_sub_block = transform_block.get("parameter", {}) + configuration_transform_sub_block = transform_block.get("configuration", {}) + + processed_configuration["property_transform_options"] = { + "name": property_transform_sub_block.get("name", None), + "property_key": property_transform_sub_block.get("property_key", None) + } + processed_configuration["property_transform"] = ( + property_transform_sub_block.get("instance", None) + ) # no executable given. initialize on own + + processed_configuration["parameter_transform_options"] = { + "name": parameter_transform_sub_block.get("name", None), + } + processed_configuration["parameter_transform"] = ( + parameter_transform_sub_block.get("instance", None) + ) # no executable given. 
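An illustrative transforms block matching the sub-blocks read above; the class names are hypothetical placeholders that would be looked up in the corresponding kliff.transforms modules, and either a "name" or a pre-built "instance" may be supplied.

transforms_block = {
    "property": {
        "name": "MyPropertyTransform",       # hypothetical class in kliff.transforms.property_transforms
        "property_key": "energy",
        "instance": None,                    # alternatively, a ready-made PropertyTransform instance
    },
    "parameter": {"name": None},
    "configuration": {
        "name": "MyConfigurationTransform",  # hypothetical class in kliff.transforms.configuration_transforms
        "kwargs": {"cutoff": 4.0},
        "instance": None,
    },
}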
initialize on own + + # map default hyperparameters + configuration_transform_kwargs = configuration_transform_sub_block.get("kwargs", {}) + hyperparams = configuration_transform_kwargs.get("hyperparameters", None) + if hyperparams == "default": + configuration_transform_kwargs["hyperparameters"] = \ + kliff.transforms.configuration_transforms.get_default_hyperparams() + + processed_configuration["configuration_transform_options"] = ( + configuration_transform_sub_block # this might contain lot of variables + ) + processed_configuration["configuration_transform"] = ( + configuration_transform_sub_block.get("instance", None) + ) # no executable given. initialize on own # training variables training_block = configuration.get("training", {}) loss_block = training_block.get("loss", {}) processed_configuration["loss_function"] = loss_block.get("loss_function", None) - processed_configuration["energy_loss_weight"] = loss_block.get("energy_loss_weight", 1.0) - processed_configuration["forces_loss_weight"] = loss_block.get("forces_loss_weight", 0.0) + processed_configuration["energy_loss_weight"] = loss_block.get( + "energy_loss_weight", 1.0 + ) + processed_configuration["forces_loss_weight"] = loss_block.get( + "forces_loss_weight", 0.0 + ) optimizer_block = training_block.get("optimizer", {}) - processed_configuration["optimizer_provider"] = OptimizerProvider.get_optimizer_provider(optimizer_block.get("provider", "scipy")) - processed_configuration["optimizer"] = None # To be assigned + processed_configuration["optimizer_provider"] = ( + OptimizerProvider.get_optimizer_enum( + optimizer_block.get("provider", "scipy") + ) + ) + processed_configuration["optimizer"] = None # To be assigned processed_configuration["optimizer_name"] = optimizer_block.get("name", None) - processed_configuration["learning_rate"] = optimizer_block.get("learning_rate", None) + processed_configuration["learning_rate"] = optimizer_block.get( + "learning_rate", None + ) processed_configuration["max_epochs"] = training_block.get("max_epochs", 10000) processed_configuration["device"] = training_block.get("device", "cpu") processed_configuration["batch_size"] = training_block.get("batch_size", 1) - processed_configuration["chkpt_interval"] = training_block.get("chkpt_interval", 100) - processed_configuration["stop_condition"] = training_block.get("stop_condition", None) + processed_configuration["chkpt_interval"] = training_block.get( + "chkpt_interval", 100 + ) + processed_configuration["stop_condition"] = training_block.get( + "stop_condition", None + ) return processed_configuration @@ -244,60 +335,81 @@ def config_to_dict(self): """ Convert the configuration to a dictionary. 
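Similarly, a sketch of the training block as parsed at this point in the patch series (the loss keys are reorganized in a later patch); all values are placeholders.

training_block = {
    "loss": {
        "loss_function": "MSE",              # placeholder name
        "energy_loss_weight": 1.0,
        "forces_loss_weight": 0.5,
    },
    "optimizer": {
        "provider": "scipy",                 # mapped through OptimizerProvider.get_optimizer_enum
        "name": "L-BFGS-B",
        "learning_rate": None,
    },
    "max_epochs": 1000,
    "device": "cpu",
    "batch_size": 1,
    "chkpt_interval": 100,
    "stop_condition": None,
}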
""" - configuration = {} - configuration["workspace"] = { + config = {} + config["workspace"] = { "name": self.workspace_name, "export": self.export_kim_model, "seed": self.seed, "resume": self.resume, - "walltime": f"{self.walltime.days}:{self.walltime.seconds // 3600}:{(self.walltime.seconds // 60) % 60}:{self.walltime.seconds % 60}", + "walltime": self.walltime.total_seconds(), } - configuration["dataset"] = { - "type": DataSource.get_data_config(self.dataset_type), - "dataset_path": self.dataset_path, + config["dataset"] = { + "type": DataSource.get_data_str(self.dataset_type), + "path": self.dataset_path, "save": self.dataset_save, "shuffle": self.dataset_shuffle, "training_dataset": { "train_size": self.train_size, - "train_indices": self.train_indices, + "train_indices": self.indices_files["train"], }, "validation_dataset": { "val_size": self.val_size, - "val_indices": self.val_indices, + "val_indices": self.indices_files["val"], }, - "colabfit_dataset": self.colabfit_dataset, + "colabfit_dataset": { + "dataset_name": self.colabfit_dataset["dataset_name"], + "database_name": self.colabfit_dataset["database_name"], + "database_url": self.colabfit_dataset["database_url"], + } } - - configuration["model"] = { - "model_type": ModelTypes.get_model_config(self.model_type), + if self.ase_keys is not None: + config["dataset"]["keys"] = { + "energy": self.ase_keys["energy_key"], + "forces": self.ase_keys["forces_key"], + } + + config["model"] = { + "model_type": ModelTypes.get_model_str(self.model_type), "model_name": self.model_name, "model_path": self.model_path, } - # TODO: Add transforms correctly - configuration["transform"] = { - "property_transform": self.property_transform, - "parameter_transform": self.parameter_transform, - "configuration_transform": self.configuration_transform, + + config["transforms"] = { + "property": { + "name": self.property_transform_options["name"], + "property_key": self.property_transform_options["property_key"], + }, + "parameter": { + "name": self.parameter_transform_options["name"], + }, + "configuration": { + "name": self.configuration_transform_options["name"], + "kwargs": self.configuration_transform_options, + } } - configuration["training"] = { + + config["training"] = { "loss": { "loss_function": self.loss_function, - "energy_loss_weight": self.energy_loss_weight, - "forces_loss_weight": self.forces_loss_weight, + "weight": { + "energy": self.energy_loss_weight, + "forces": self.forces_loss_weight, + }, }, "optimizer": { - "provider": OptimizerProvider.get_optimizer_config(self.optimizer_provider), + "provider": OptimizerProvider.get_optimizer_str(self.optimizer_provider), "name": self.optimizer_name, "learning_rate": self.learning_rate, }, - "max_epochs": self.max_epochs, + "epochs": self.max_epochs, "device": self.device, "batch_size": self.batch_size, "chkpt_interval": self.chkpt_interval, "stop_condition": self.stop_condition, } - return configuration + + return config @classmethod def from_file(cls, filename: Path): @@ -324,7 +436,7 @@ def get_trainer_hash(self): """ config = self.config_to_dict() config_immut_str = json.dumps(config, sort_keys=True) - return hash(config_immut_str) + return hashlib.md5(config_immut_str.encode()).hexdigest() def _initialize(self): """ @@ -346,6 +458,8 @@ def _initialize(self): self.setup_model() # Step 7 - Set up the optimizer self.setup_optimizer() + # Step 8 - Save the configuration for future + self.save_config() def seed_all(self): """ @@ -370,7 +484,7 @@ def setup_workspace(self): else: last_dir = 
dir_list[-1] was_it_finished = os.path.exists(f"{last_dir}/.finished") - if was_it_finished: # start new run + if was_it_finished: # start new run current_run_dir = f"{self.workspace_name}/{self.current_run_title}" os.makedirs(current_run_dir, exist_ok=True) self.appending_to_previous_run = False @@ -391,17 +505,25 @@ def setup_dataset(self): if self.dataset_type == DataSource.KLIFF or self.dataset_type == DataSource.ASE: dataset_path = os.path.abspath(self.dataset_path) elif self.dataset_type == DataSource.COLABFIT: - dataset_path = self.colabfit_dataset["database_url"] + "/" + self.colabfit_dataset["database_name"] + dataset_path = ( + self.colabfit_dataset["database_url"] + + "/" + + self.colabfit_dataset["database_name"] + ) else: raise TrainerError(f"Dataset type {self.dataset_type} not supported.") - if self.property_transform is not None: + if self.property_transform_options is not None: dataset_transforms += self.property_transform_options["name"] - if self.configuration_transform is not None: - dataset_transforms += self.configuration_transform_options["name"] - dataset_transforms += str(self.configuration_transform_options["cutoff"]) - dataset_hash = hash(dataset_path + dataset_transforms) + dataset_transforms += "_" + + if self.configuration_transform_options is not None: + dataset_transforms += self.configuration_transform_options["name"] + "_" + dataset_transforms += str(self.configuration_transform_options["kwargs"]["cutoff"]) + + dataset_hash_str = dataset_path + "_" + dataset_transforms + dataset_hash = hashlib.md5(dataset_hash_str.encode()).hexdigest() self.current_dataset_hash = dataset_hash dataset_dir = f"{self.workspace_name}/{dataset_hash}" os.makedirs(dataset_dir, exist_ok=True) @@ -410,54 +532,80 @@ def setup_dataset(self): logger.info(f"Dataset found in {dataset_dir}.") return except FileNotFoundError: - logger.info(f"Dataset not found in {self.workspace_name} directory. Creating dataset.") + logger.info( + f"Dataset not found in {self.workspace_name} directory. Creating dataset." + ) if self.dataset_type == DataSource.KLIFF: dataset = Dataset.from_path(dataset_path) elif self.dataset_type == DataSource.ASE: - dataset = Dataset.from_path(dataset_path) + dataset = Dataset.from_ase(dataset_path, **self.ase_keys) elif self.dataset_type == DataSource.COLABFIT: - dataset = Dataset.from_colabfit(self.colabfit_dataset["dataset_name"], - self.colabfit_dataset["database_name"], - self.colabfit_dataset["database_url"]) + dataset = Dataset.from_colabfit( + self.colabfit_dataset["dataset_name"], + self.colabfit_dataset["database_name"], + self.colabfit_dataset["database_url"], + ) else: raise TrainerError(f"Dataset type {self.dataset_type} not supported.") if self.property_transform is not None: if not isinstance(self.property_transform, PropertyTransform): - raise TrainerError("Property transform is not none and not an instance of PropertyTransform.") + raise TrainerError( + "Property transform is not none and not an instance of PropertyTransform." 
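A compact sketch of the dataset-caching scheme used in setup_dataset above: the cache directory is named after the md5 of the dataset path plus a string describing the applied transforms. The transform names and cutoff below are placeholders.

import hashlib
import os

dataset_path = "/abs/path/to/dataset.xyz"
dataset_transforms = "MyPropertyTransform_" + "MyConfigurationTransform_" + str(4.0)
dataset_hash = hashlib.md5((dataset_path + "_" + dataset_transforms).encode()).hexdigest()
dataset_dir = os.path.join("my_workspace", dataset_hash)
os.makedirs(dataset_dir, exist_ok=True)   # reused on later runs with the same data and transforms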
+ ) else: # check if property_instance_options have "instance" - if "instance" in self.property_transform_options: + if self.property_transform_options.get("instance") is not None: self.property_transform = self.property_transform_options["instance"] else: try: # try getting class "name" from kliff.transforms.property_transforms - module = importlib.import_module("kliff.transforms.property_transforms") + module = importlib.import_module( + "kliff.transforms.property_transforms" + ) class_ = getattr(module, self.property_transform_options["name"]) - self.property_transform = class_(property_key=self.property_transform_options["property_key"],) + self.property_transform = class_( + property_key=self.property_transform_options["property_key"], + ) except AttributeError: - raise TrainerError(f"Property transform {self.property_transform_options['name']} not found." - "If it is a custom transform, please provide the instance.") + raise TrainerError( + f"Property transform {self.property_transform_options['name']} not found." + "If it is a custom transform, please provide the instance." + ) self.property_transform(dataset) if self.configuration_transform is not None: if not isinstance(self.configuration_transform, ConfigurationTransform): - raise TrainerError("Configuration transform is not none and not an instance of ConfigurationTransform.") + raise TrainerError( + "Configuration transform is not none and not an instance of ConfigurationTransform." + ) else: # check if configuration_instance_options have "instance" - if "instance" in self.configuration_transform_options: - self.configuration_transform = self.configuration_transform_options["instance"] + if "instance" in self.configuration_transform_options \ + and self.configuration_transform_options["instance"] is not None: + self.configuration_transform = self.configuration_transform_options[ + "instance" + ] else: try: # try getting class "name" from kliff.transforms.configuration_transforms - module = importlib.import_module("kliff.transforms.configuration_transforms") - class_ = getattr(module, self.configuration_transform_options["name"]) - self.configuration_transform = class_(**self.configuration_transform_options["kwargs"],copy_to_config=True) + module = importlib.import_module( + "kliff.transforms.configuration_transforms" + ) + class_ = getattr( + module, self.configuration_transform_options["name"] + ) + self.configuration_transform = class_( + **self.configuration_transform_options["kwargs"], + copy_to_config=True, + ) except AttributeError: - raise TrainerError(f"Configuration transform {self.configuration_transform_options['name']} not found." - "If it is a custom transform, please provide the instance.") + raise TrainerError( + f"Configuration transform {self.configuration_transform_options['name']} not found." + "If it is a custom transform, please provide the instance." + ) for configuration in dataset: self.configuration_transform(configuration) @@ -493,8 +641,23 @@ def setup_test_train_datasets(self): self.val_indices = None self.train_dataset = self.dataset[self.train_indices] + self.indices_files["train"] = f"{self.current_run_dir}/train_indices.npy" + self.train_indices.dump(self.indices_files["train"]) + if self.val_indices: self.val_dataset = self.dataset[self.val_indices] + self.indices_files["val"] = f"{self.current_run_dir}/val_indices.npy" + self.val_indices.dump(self.indices_files["val"]) + + def save_config(self): + """ + Hash and save the configuration to the current run directory. 
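The transform classes are resolved dynamically; the lookup pattern in isolation, with the module and class names supplied by the caller (the commented example uses hypothetical names).

import importlib

def build_transform(module_name: str, class_name: str, **kwargs):
    """Resolve a class by name from a module and instantiate it with user-supplied kwargs."""
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)   # raises AttributeError if the class is not found
    return cls(**kwargs)

# e.g. (hypothetical names):
# transform = build_transform("kliff.transforms.configuration_transforms", "Descriptor", cutoff=4.0)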
+ """ + config_hash = self.get_trainer_hash() + config_file = f"{self.current_run_dir}/{config_hash}.yaml" + with open(config_file, "w") as f: + yaml.dump(self.configuration, f, default_flow_style=False) + logger.info(f"Configuration saved in {config_file}.") def setup_model(self): """ @@ -535,4 +698,3 @@ def get_optimizer(self, *args, **kwargs): def train(self, *args, **kwargs): TrainerError("train not implemented.") - From 2f57f561fc58d5d1c2aa2b3effbc0188fb8e771f Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Sun, 3 Mar 2024 19:17:01 -0600 Subject: [PATCH 4/8] First draft trainer framework --- kliff/_exceptions.py | 13 + kliff/dataset/dataset.py | 229 +++++++++++-- kliff/trainer/__init__.py | 1 + kliff/trainer/kim_residuals.py | 19 ++ kliff/trainer/kim_trainer.py | 393 ++++++++++++++++++++++ kliff/trainer/kliff_trainer.py | 399 ++++++++++++++--------- kliff/trainer/option_enumerations.py | 225 +++++++++++++ kliff/utils.py | 38 +++ setup.py | 3 +- tests/dataset/test_weight.py | 20 ++ tests/test_data/weights/Si_4_weights.dat | 4 + 11 files changed, 1163 insertions(+), 181 deletions(-) create mode 100644 kliff/_exceptions.py create mode 100644 kliff/trainer/kim_residuals.py create mode 100644 kliff/trainer/kim_trainer.py create mode 100644 kliff/trainer/option_enumerations.py create mode 100644 tests/test_data/weights/Si_4_weights.dat diff --git a/kliff/_exceptions.py b/kliff/_exceptions.py new file mode 100644 index 00000000..de1c8288 --- /dev/null +++ b/kliff/_exceptions.py @@ -0,0 +1,13 @@ +""" +This module contains exceptions to be raised in kliff modules, along with details on +where they are raised. +""" + + +class TrainerError(Exception): + """ + Exceptions to be raised in Trainer and associated classes. + """ + + def __init__(self, message): + super().__init__(message) diff --git a/kliff/dataset/dataset.py b/kliff/dataset/dataset.py index 68f56bef..b8424a74 100644 --- a/kliff/dataset/dataset.py +++ b/kliff/dataset/dataset.py @@ -580,14 +580,18 @@ def from_colabfit( colabfit_database: str, colabfit_dataset: str, colabfit_uri: str = "mongodb://localhost:27017", - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, ) -> "Dataset": """ Read configurations from colabfit database and initialize a dataset. Args: weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). colabfit_database: Name of the colabfit Mongo database to read from. colabfit_dataset: Name of the colabfit dataset instance to read from, usually it is of form, e.g., "DS_xxxxxxxxxxxx_0" @@ -607,7 +611,7 @@ def from_colabfit( def _read_from_colabfit( database_client: MongoDatabase, colabfit_dataset: str, - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, ) -> List[Configuration]: """ Read configurations from colabfit database. @@ -617,7 +621,11 @@ def _read_from_colabfit( fetch database from colabfit-tools dataset. colabfit_dataset: Name of the colabfit dataset instance to read from. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. 
The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). Returns: A list of configurations. @@ -630,10 +638,37 @@ def _read_from_colabfit( logger.error(f"{colabfit_dataset} is either empty or does not exist") raise DatasetError(f"{colabfit_dataset} is either empty or does not exist") + if isinstance(weight, Path): + print(weight) + weights = np.loadtxt(weight) + if weights.ndim == 1 and len(weights) == 4: + weights = np.tile(weights, (len(data_objects), 1)) + elif weights.ndim == 2 and len(weights) == len(data_objects): + pass + else: + raise DatasetError( + "Weight file must be a plain text file with 4 whitespace separated " + "columns: config_weight, energy_weight, forces_weight, and " + "stress_weight. Length of the file must be equal to the number of " + "configurations, or 1 (in which case the same weight is used for all " + "configurations)." + ) + weights = [ + Weight( + config_weight=w[0], + energy_weight=w[1], + forces_weight=w[2], + stress_weight=w[3], + ) + for w in weights + ] + else: + weights = [weight] * len(data_objects) + configs = [] - for data_object in data_objects: + for data_object, weight_obj in zip(data_objects, weights): configs.append( - Configuration.from_colabfit(database_client, data_object, weight) + Configuration.from_colabfit(database_client, data_object, weight_obj) ) if len(configs) <= 0: @@ -649,7 +684,7 @@ def add_from_colabfit( colabfit_database: str, colabfit_dataset: str, colabfit_uri: str = "mongodb://localhost:27017", - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, ): """ Read configurations from colabfit database and add them to the dataset. @@ -660,7 +695,11 @@ def add_from_colabfit( it is of form, e.g., "DS_xxxxxxxxxxxx_0") colabfit_uri: connection URI of the colabfit Mongo database to read from. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). """ # open link to the mongo @@ -672,7 +711,7 @@ def add_from_colabfit( def from_path( cls, path: Union[Path, str], - weight: Optional[Weight] = None, + weight: Optional[Union[Path, Weight]] = None, file_format: str = "xyz", ) -> "Dataset": """ @@ -681,7 +720,11 @@ def from_path( Args: path: Path the directory (or filename) storing the configurations. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). file_format: Format of the file that stores the configuration, e.g. `xyz`. 
Returns: @@ -693,7 +736,9 @@ def from_path( @staticmethod def _read_from_path( - path: Path, weight: Optional[Weight] = None, file_format: str = "xyz" + path: Path, + weight: Optional[Union[Weight, Path]] = None, + file_format: str = "xyz", ) -> List[Configuration]: """ Read configurations from path. @@ -702,7 +747,11 @@ def _read_from_path( path: Path of the directory storing the configurations in individual files. For single file with multiple configurations, use `_read_from_ase()` instead. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). file_format: Format of the file that stores the configuration, e.g. `xyz`. Returns: @@ -730,9 +779,37 @@ def _read_from_path( parent = path.parent all_files = [path] + if isinstance(weight, Path): + print(weight) + weights = np.loadtxt(weight) + if weights.ndim == 1 and len(weights) == 4: + weights = np.tile(weights, (len(all_files), 1)) + elif weights.ndim == 2 and len(weights) == len(all_files): + pass + else: + raise DatasetError( + "Weight file must be a plain text file with 4 whitespace separated " + "columns: config_weight, energy_weight, forces_weight, and " + "stress_weight. Length of the file must be equal to the number of " + "configurations, or 1 (in which case the same weight is used for all " + "configurations)." + ) + weights = [ + Weight( + config_weight=w[0], + energy_weight=w[1], + forces_weight=w[2], + stress_weight=w[3], + ) + for w in weights + ] + + else: + weights = [weight] * len(all_files) + configs = [ - Configuration.from_file(f, copy.copy(weight), file_format) - for f in all_files + Configuration.from_file(f, copy.copy(w), file_format) + for f, w in zip(all_files, weights) ] if len(configs) <= 0: @@ -744,7 +821,7 @@ def _read_from_path( def add_from_path( self, path: Union[Path, str], - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, file_format: str = "xyz", ): """ @@ -753,7 +830,11 @@ def add_from_path( Args: path: Path the directory (or filename) storing the configurations. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). file_format: Format of the file that stores the configuration, e.g. `xyz`. """ if isinstance(path, str): @@ -766,7 +847,7 @@ def from_ase( cls, path: Union[Path, str] = None, ase_atoms_list: List[ase.Atoms] = None, - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, energy_key: str = "energy", forces_key: str = "forces", slices: str = ":", @@ -791,7 +872,11 @@ def from_ase( path: Path the directory (or filename) storing the configurations. ase_atoms_list: A list of ase.Atoms objects. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. 
The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). energy_key: Name of the field in extxyz/ase.Atoms that stores the energy. forces_key: Name of the field in extxyz/ase.Atoms that stores the forces. slices: Slice of the configurations to read. It is used only when `path` is @@ -825,7 +910,11 @@ def _read_from_ase( path: Path the directory (or filename) storing the configurations. ase_atoms_list: A list of ase.Atoms objects. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). energy_key: Name of the field in extxyz/ase.Atoms that stores the energy. forces_key: Name of the field in extxyz/ase.Atoms that stores the forces. slices: Slice of the configurations to read. It is used only when `path` is @@ -841,14 +930,42 @@ def _read_from_ase( ) if ase_atoms_list: + if isinstance(weight, Path): + weights = np.loadtxt(weight) + if weights.ndim == 1 and len(weights) == 4: + weights = np.tile(weights, (len(ase_atoms_list), 1)) + if weights.ndim == 2 and len(weights) == len(ase_atoms_list): + pass + else: + raise DatasetError( + "Length of weights must be equal to the number of configurations, or 1 " + "(in which case the same weight is used for all configurations)." + ) + weights = [ + Weight( + config_weight=w[0], + energy_weight=w[1], + forces_weight=w[2], + stress_weight=w[3], + ) + for w in weights + ] + else: + weights = [weight] * len(ase_atoms_list) + + if len(ase_atoms_list) != len(weights): + raise DatasetError( + "Length of weights must be equal to the number of configurations, or 1 " + "(in which case the same weight is used for all configurations)." + ) configs = [ Configuration.from_ase_atoms( config, - weight=copy.copy(weight), + weight=copy.copy(weight_obj), energy_key=energy_key, forces_key=forces_key, ) - for config in ase_atoms_list + for config, weight_obj in zip(ase_atoms_list, weights) ] else: try: @@ -875,24 +992,77 @@ def _read_from_ase( if len(all_files) == 1: # single xyz file with multiple configs all_configs = ase.io.read(all_files[0], index=slices) + + # This code fragment is duplicated because in ASE loading, there can be multiple + # branches on how the configurations are loaded, and it is simplest to + # assign weights accordingly per configuration. + if isinstance(weight, Path): + weights = np.loadtxt(weight) + if weights.ndim == 1 and len(weights) == 4: + weights = np.tile(weights, (len(all_configs), 1)) + if weights.ndim == 2 and len(weights) == len(all_configs): + pass + else: + raise DatasetError( + "Length of weights must be equal to the number of configurations, or 1 " + "(in which case the same weight is used for all configurations)." 
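As the comments note, the weight-file parsing is duplicated across the loading branches; one possible helper to factor it out (an editorial sketch, not part of the patch).

import numpy as np
from pathlib import Path
from typing import List, Optional, Union

from kliff.dataset.weight import Weight

def weights_from_file(weight: Optional[Union[Weight, Path]], n_configs: int) -> List[Optional[Weight]]:
    """Expand a Weight instance (or None) or a 4-column weight file into one Weight per configuration."""
    if not isinstance(weight, Path):
        return [weight] * n_configs
    arr = np.loadtxt(weight)
    if arr.ndim == 1 and len(arr) == 4:
        arr = np.tile(arr, (n_configs, 1))
    elif not (arr.ndim == 2 and len(arr) == n_configs):
        raise ValueError("Weight file must have 4 columns and either 1 or n_configs rows.")
    return [
        Weight(config_weight=w[0], energy_weight=w[1], forces_weight=w[2], stress_weight=w[3])
        for w in arr
    ]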
+ ) + weights = [ + Weight( + config_weight=w[0], + energy_weight=w[1], + forces_weight=w[2], + stress_weight=w[3], + ) + for w in weights + ] + else: + weights = [weight] * len(all_configs) + configs = [ Configuration.from_ase_atoms( config, - weight=copy.copy(weight), + weight=copy.copy(weight_obj), energy_key=energy_key, forces_key=forces_key, ) - for config in all_configs + for config, weight_obj in zip(all_configs, weights) ] else: + # This code fragment is duplicated because in ASE loading, there can be multiple + # branches on how the configurations are loaded, and it is simplest to + # assign weights accordingly per configuration. + if isinstance(weight, Path): + weights = np.loadtxt(weight) + if weights.ndim == 1 and len(weights) == 4: + weights = np.tile(weights, (len(all_files), 1)) + if weights.ndim == 2 and len(weights) == len(all_files): + pass + else: + raise DatasetError( + "Length of weights must be equal to the number of configurations, or 1 " + "(in which case the same weight is used for all configurations)." + ) + weights = [ + Weight( + config_weight=w[0], + energy_weight=w[1], + forces_weight=w[2], + stress_weight=w[3], + ) + for w in weights + ] + else: + weights = [weight] * len(all_files) + configs = [ Configuration.from_ase_atoms( ase.io.read(f), - weight=copy.copy(weight), + weight=copy.copy(w), energy_key=energy_key, forces_key=forces_key, ) - for f in all_files + for f, w in zip(all_files, weights) ] if len(configs) <= 0: @@ -933,6 +1103,11 @@ def add_from_ase( path: Path the directory (or filename) storing the configurations. ase_atoms_list: A list of ase.Atoms objects. weight: an instance that computes the weight of the configuration in the loss + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). energy_key: Name of the field in extxyz/ase.Atoms that stores the energy. forces_key: Name of the field in extxyz/ase.Atoms that stores the forces. slices: Slice of the configurations to read. It is used only when `path` is @@ -962,7 +1137,9 @@ def __len__(self) -> int: """ return len(self.configs) - def __getitem__(self, idx:Union[int,np.ndarray, List]) -> Union[Configuration, "Dataset"]: + def __getitem__( + self, idx: Union[int, np.ndarray, List] + ) -> Union[Configuration, "Dataset"]: """ Get the configuration at index `idx`. If the index is a list, it returns a new dataset with the configurations at the indices. diff --git a/kliff/trainer/__init__.py b/kliff/trainer/__init__.py index 39d10110..18f04637 100644 --- a/kliff/trainer/__init__.py +++ b/kliff/trainer/__init__.py @@ -1 +1,2 @@ +from .kim_trainer import KIMTrainer from .kliff_trainer import Trainer diff --git a/kliff/trainer/kim_residuals.py b/kliff/trainer/kim_residuals.py new file mode 100644 index 00000000..eced653b --- /dev/null +++ b/kliff/trainer/kim_residuals.py @@ -0,0 +1,19 @@ +from typing import Any, Dict + +import numpy as np + + +def MSE_residuals( + predictions: np.ndarray, + targets: np.ndarray, +) -> np.ndarray: + r""" + Compute the mean squared error (MSE) of the residuals. + + Args: + + Returns: + The MSE of the residuals. 
+ """ + residuals = predictions - targets + return np.mean(residuals**2) diff --git a/kliff/trainer/kim_trainer.py b/kliff/trainer/kim_trainer.py new file mode 100644 index 00000000..e50f32aa --- /dev/null +++ b/kliff/trainer/kim_trainer.py @@ -0,0 +1,393 @@ +import importlib +import multiprocessing +import subprocess +import tarfile +from pathlib import Path +from typing import Callable, Tuple, Union + +import kimpy +import numpy as np +from loguru import logger + +import kliff.models +from kliff._exceptions import TrainerError +from kliff.dataset import Configuration +from kliff.models import KIMModel +from kliff.utils import install_kim_model + +from .kim_residuals import MSE_residuals +from .kliff_trainer import Trainer +from .option_enumerations import ModelTypes, OptimizerProvider + +# list of model drivers that are not supported by this trainer. +# example quip, torchml, etc. +# TODO: Get the complete list of unsupported model drivers. +UNSUPPORTED_MODEL_DRIVERS = [ + "TorchML", +] +SCIPY_MINIMIZE_METHODS = [ + "Nelder-Mead", + "Powell", + "CG", + "BFGS", + "Newton-CG", + "L-BFGS-B", + "TNC", + "COBYLA", + "SLSQP", + "trust-constr", + "dogleg", + "trust-ncg", + "trust-exact", + "trust-krylov", +] + + +class KIMTrainer(Trainer): + """ + This class extends the base Trainer class for training OpenKIM physics based models. + It will use the scipy optimizers. It will perform a check to exclude TorchML model + driver based models, as they would be handled by TorchTrainer. + """ + + def __init__(self, configuration: dict, collection: str = "user"): + self.collection = collection + + self.model_driver_name = None + self.parameters = None + self.mutable_parameters_list = [] + self.use_energy = True + self.use_forces = False + + super().__init__(configuration) + + self.loss_function = self._get_loss_fn() + + def setup_model(self): + """ + Load either the KIM model, or install it from the tarball. If the model driver + required is TorchML* family, then it will raise an error, as it should be handled + by the TorchTrainer. + """ + # check for unsupported model drivers + # 1. get the model driver name + if self.model_type == ModelTypes.KIM: + self.model_driver_name = self.get_model_driver_name_for_kim(self.model_name) + elif self.model_type == ModelTypes.TAR: + self.model_driver_name = self.get_model_driver_name_for_tarball( + self.model_name + ) + else: + raise TrainerError(f"Model type {self.model_type} not supported.") + + # 2. check if the model driver is supported + if self.model_driver_name in UNSUPPORTED_MODEL_DRIVERS: + raise TrainerError( + f"Model driver {self.model_driver_name} not supported by KIMTrainer." + ) + elif self.model_driver_name is None: + logger.warning( + f"Could not determine model-driver name for {self.model_name}. Please be careful and check if the model is supported." + ) + else: + logger.info(f"Model driver name: {self.model_driver_name}") + + # 3. load the model + if self.model_type == ModelTypes.KIM: + self.ensure_kim_model_installation(self.model_name, self.collection) + elif self.model_type == ModelTypes.TAR: + # reinstall model to be sure + self.ensure_tarball_model_installation(self.model_name, self.collection) + + self.model = KIMModel(self.model_name) + self.parameters = self.model.get_model_params() + + def setup_parameter_transforms(self): + """ + This method set up the transformed parameter space for models. 
It can be used + for any model type in general, but as there exists a significant difference + between how models handles their parameters, it is left for the subclass to + implement. Although to ensure that `initialize` function remains consistent + this method will not raise NotImplemented error, rather it will quietly pass. + So be aware. + """ + self.set_parameters_as_mutable() + mutable_params = self.model.parameters() + parameter_transforms_input = self.parameter_transform_options["parameter_list"] + if parameter_transforms_input is not None: + for model_params, input_params in zip( + mutable_params, parameter_transforms_input + ): + if isinstance(input_params, dict): + param_name = list(input_params.keys())[0] + if param_name != model_params.name: + raise TrainerError( + f"Parameter name mismatch. Expected {model_params.name}, got {param_name}." + ) + + param_value_dict = input_params[param_name] + transform_name = param_value_dict.get("transform_name", None) + params_value = param_value_dict.get("value", None) + bounds = param_value_dict.get("bounds", None) + + if transform_name is not None: + transform_module = getattr( + importlib.import_module( + f"kliff.transforms.parameter_transforms" + ), + transform_name, + ) + transform_module = transform_module() + model_params.add_transform(transform_module) + + if params_value is not None: + model_params.copy_from_model_space(params_value) + + if bounds is not None: + model_params.add_bounds_model_space(np.array(bounds)) + + elif isinstance(input_params, str): + if input_params != model_params.name: + raise TrainerError( + f"Parameter name mismatch. Expected {model_params.name}, got {input_params}." + ) + else: + raise TrainerError( + f"Optimizable parameters must be string or value dict. Got {input_params} instead." + ) + + def setup_optimizer(self): + """ + Set up the optimizer based on the provided information. If the optimizer is not + provided, it will be loaded from the optimizer_name. If the optimizer_name is not + provided, it will raise an error. If the optimizer_provider is scipy, it will be + loaded from the scipy.optimize. If the optimizer_provider is torch, it will be + loaded from the torch.optim. Left for the derived classes to implement. + """ + if self.optimizer_provider is not OptimizerProvider.SCIPY: + raise TrainerError( + f"Optimizer provider {self.optimizer_provider} not supported by KIMTrainer." + ) + + if self.optimizer_name not in SCIPY_MINIMIZE_METHODS: + raise TrainerError(f"Optimizer not supported: {self.optimizer_name}.") + optimizer_lib = importlib.import_module(f"scipy.optimize") + self.optimizer = getattr(optimizer_lib, "minimize") + + def loss(self, x): + """ + Compute the loss function for the given parameters. It will compute the loss + function. It seems like MPI might be only way to make it parallel as the + multiprocessing does not work with the KIM models. KIMPY models are not yet + pickelable. TODO: include MPI support. 
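A sketch of the parameter_list structure this method appears to expect: each entry is either a bare parameter name, or a one-key dict whose value may carry a transform name, a starting value, and bounds (both given in model space). Parameter and transform names below are placeholders.

parameter_list = [
    "B",                                                # optimize parameter "B" in its native space
    {
        "A": {                                          # optimize "A" in a transformed space
            "transform_name": "LogParameterTransform",  # hypothetical class in kliff.transforms.parameter_transforms
            "value": [15.28],                           # starting value, in model (untransformed) space
            "bounds": [[1.0, 30.0]],                    # bounds, also in model space
        }
    },
]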
+ """ + # set the parameters + self.model.update_model_params(x) + # compute the loss + loss = 0.0 + for configuration in self.train_dataset: + compute_energy = True if configuration.weight.energy_weight else False + compute_forces = True if configuration.weight.forces_weight else False + compute_stress = True if configuration.weight.stress_weight else False + + prediction = self.model( + configuration, + compute_energy=compute_energy, + compute_forces=compute_forces, + compute_stress=compute_stress, + ) + + if configuration.weight.energy_weight: + loss += configuration.weight.energy_weight * self.loss_function( + prediction["energy"], configuration.energy + ) + if configuration.weight.forces_weight: + loss += configuration.weight.forces_weight * self.loss_function( + prediction["forces"], configuration.forces + ) + if configuration.weight.stress_weight: + loss += configuration.weight.stress_weight * self.loss_function( + prediction["stress"], configuration.stress + ) + loss *= configuration.weight.config_weight + + return loss + + def checkpoint(self, *args, **kwargs): + TrainerError("checkpoint not implemented.") + + def train_step(self, *args, **kwargs): + TrainerError("train_step not implemented.") + + def validation_step(self, *args, **kwargs): + TrainerError("validation_step not implemented.") + + def get_optimizer(self, *args, **kwargs): + TrainerError("get_optimizer not implemented.") + + def train(self, *args, **kwargs): + def _wrapper_func(x): + return self.loss(x) + + x = self.model.get_opt_params() + options = self.optimizer_kwargs + options["options"] = {"maxiter": self.max_epochs, "disp": self.verbose} + result = self.optimizer( + _wrapper_func, x, method=self.optimizer_name, **self.optimizer_kwargs + ) + + if result.success: + logger.info(f"Optimization successful: {result.message}") + self.model.update_model_params(result.x) + else: + logger.error(f"Optimization failed: {result.message}") + + @staticmethod + def get_model_driver_name_for_kim(model_name: str) -> Union[str, None]: + """ + Get the model driver from the model name. It will return the model driver + string from the installed KIM API model. If the model is not installed, and the + model name is a tarball, it will extract the model driver name from the CMakeLists.txt. + This is needed to ensure that it excludes the model drivers that it cannot handle. + Example: TorchML driver based models. These models are to be trained using the + TorchTrainer. + + TODO: This is not a clean solution. I think KIMPY must have a better way to handle this. + Ask Mingjian/Yaser for comment. + + Args: + model_name: name of the model. + + Returns: + Model driver name. + """ + collections = kimpy.collections.create() + try: + shared_obj_path, collection = ( + collections.get_item_library_file_name_and_collection( + kimpy.collection_item_type.portableModel, model_name + ) + ) + except RuntimeError: # not a portable model + return None + shared_obj_content = open(shared_obj_path, "rb").read() + md_start_idx = shared_obj_content.find(b"model-driver") + + if md_start_idx == -1: + return None + else: + md_start_idx += 15 # length of 'model-driver" "' + md_end_idx = shared_obj_content.find(b'"', md_start_idx) + return shared_obj_content[md_start_idx:md_end_idx].decode("utf-8") + + @staticmethod + def get_model_driver_name_for_tarball(tarball: str) -> Union[str, None]: + """ + Get the model driver name from the tarball. It will extract the model driver + name from the CMakeLists.txt file in the tarball. 
This is needed to ensure that + it excludes the model drivers that it cannot handle. Example: TorchML driver based + models. These models are to be trained using the TorchTrainer. + + Args: + tarball: path to the tarball. + + Returns: + Model driver name. + """ + archive_content = tarfile.open(tarball) + cmake_file_path = archive_content.getnames()[0] + "/CMakeLists.txt" + cmake_file = archive_content.extractfile(cmake_file_path) + cmake_file_content = cmake_file.read().decode("utf-8") + + md_start_idx = cmake_file_content.find("DRIVER_NAME") + if md_start_idx == -1: + return None + else: + # name strats at " + md_start_idx = cmake_file_content.find('"', md_start_idx) + 1 + if md_start_idx == -1: + return None + md_end_idx = cmake_file_content.find('"', md_start_idx) + return cmake_file_content[md_start_idx:md_end_idx] + + @staticmethod + def ensure_kim_model_installation(model_name: str, collection: str = "user"): + """ + Ensure that the KIM model is installed. If the model is not installed, it will + install the model in the user collection. If the model is already installed, it + will not do anything. + + Args: + model_name: name of the model. + collection: collection to install the model in. + """ + is_model_installed = install_kim_model(model_name) + if not install_kim_model(model_name): + logger.error( + f"Mode: {model_name} neither installed nor available in the KIM API collections. Please check the model name and try again." + ) + raise TrainerError(f"Model {model_name} not found.") + else: + logger.info(f"Model {model_name} is present in {collection} collection.") + + def ensure_tarball_model_installation(self, tarball: str, collection: str = "user"): + """ + Ensure that the model is installed from the tarball. If the model is not installed, + it will install the model in the user collection. If the model is already installed, + it will reinstall the model. + + Args: + tarball: path to the tarball. + collection: collection to install the model in. + """ + scratch_dir = f"{self.current_run_dir}/.scratch" + archive_content = tarfile.open(tarball) + model = archive_content.getnames()[0] + archive_content.extractall(scratch_dir) + subprocess.run( + [ + "kim-api-collections-management", + "install", + "--force", + collection, + scratch_dir + "/" + model, + ], + check=True, + ) + logger.info(f"Tarball Model {model} installed in {collection} collection.") + + def set_parameters_as_mutable(self): + if self.parameter_transform_options is not None: + for param_to_transform in self.parameter_transform_options[ + "parameter_list" + ]: + if isinstance(param_to_transform, dict): + parameter_name = list(param_to_transform.keys())[0] + elif isinstance(param_to_transform, str): + parameter_name = param_to_transform + else: + raise TrainerError( + f"Optimizable parameters must be string or value dict. Got {param_to_transform} instead." 
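The DRIVER_NAME extraction above scans the CMakeLists.txt text for the first quoted string after the keyword; an equivalent standalone sketch using a regular expression (the file content and driver id below are made up).

import re

cmake_file_content = 'set(DRIVER_NAME "SomeDriver__MD_000000000000_000")\n'   # placeholder content

match = re.search(r'DRIVER_NAME[^"]*"([^"]+)"', cmake_file_content)
driver_name = match.group(1) if match else None
print(driver_name)   # -> SomeDriver__MD_000000000000_000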
+ ) + self.mutable_parameters_list.append(parameter_name) + else: + for param in self.parameters: + self.mutable_parameters_list.append(param) + + self.model.set_params_mutable(self.mutable_parameters_list) + logger.info(f"Mutable parameters: {self.mutable_parameters_list}") + + def _get_loss_fn(self): + if self.loss_function_name == "MSE": + return MSE_residuals + + def save_kim_model(self): + if self.export_model_type is ModelTypes.KIM: + path = Path(self.export_model_path) / self.export_model_name + self.model.write_kim_model(path) + elif self.export_model_type is ModelTypes.TAR: + path = Path(self.export_model_path) / self.export_model_name + self.model.write_kim_model(path) + tarfile_path = path.with_suffix(".tar.gz") + with tarfile.open(tarfile_path, "w:gz") as tar: + tar.add(path, arcname=path.name) diff --git a/kliff/trainer/kliff_trainer.py b/kliff/trainer/kliff_trainer.py index b8c25222..d8de33ce 100644 --- a/kliff/trainer/kliff_trainer.py +++ b/kliff/trainer/kliff_trainer.py @@ -1,3 +1,5 @@ +import hashlib +import importlib import json import os import random @@ -7,24 +9,21 @@ from pathlib import Path from typing import Callable, Union +import dill # TODO: include dill in requirements.txt import numpy as np import yaml from loguru import logger import kliff.transforms.configuration_transforms from kliff._exceptions import TrainerError - -from .option_enumerations import DataSource, ModelTypes, OptimizerProvider -import importlib - -import dill # TODO: include dill in requirements.txt -import hashlib - from kliff.dataset import Dataset from kliff.transforms.configuration_transforms import ConfigurationTransform from kliff.transforms.parameter_transforms import ParameterTransform from kliff.transforms.property_transforms import PropertyTransform +from ..dataset.weight import Weight +from .option_enumerations import DataSource, ModelTypes, OptimizerProvider + class Trainer: """Base class for all trainers. @@ -41,7 +40,8 @@ class will provide the basic functionality for training, such as setting up the - Set up the work directory - Set up the dataset - Set up the test train split - Model and optimizer setup are left for the derived classes to implement. + Model, parameter transform and optimizer setup are left for the derived classes to + implement. Args: configuration: configuration dictionary @@ -97,14 +97,23 @@ def __init__(self, configuration: dict): # training variables self.loss_function: Callable = None - self.energy_loss_weight = 1.0 - self.forces_loss_weight = 0.0 + self.loss_function_name = None + self.loss_weights: Union[dict, Path] = { + "energy": 1.0, + "forces": None, + "stress": None, + "config": 1.0, + } + self.loss_normalize_per_atom = False self.optimizer_provider: OptimizerProvider = OptimizerProvider.UNDEFINED self.optimizer = None # instance of optimizer, "scipy" for scipy torch.optim instance for torch self.optimizer_name = ( None # name of optimizer, e.g. 
"l-bfgs-b" for scipy, "adam" for torch ) + self.optimizer_kwargs = ( + None # kwargs for optimizer to bundle all unknown kwargs + ) self.learning_rate = None # learning rate for torch self.max_epochs = 10000 # maximum number of epochs @@ -112,6 +121,14 @@ def __init__(self, configuration: dict): self.batch_size = 1 self.chkpt_interval = 100 self.stop_condition = None # function to check if training should stop + self.num_workers = 1 # number of workers for data loading + self.verbose = True # whether to print verbose output + + # export trained model + self.export_kim_model = False + self.export_model_type = None + self.export_model_name = None + self.export_model_path = None self.configuration = self.config_from_dict(configuration) @@ -126,10 +143,12 @@ def __init__(self, configuration: dict): self.current_dataset_hash = None self.start_current_run_title = None # start time of the current run self.expected_end_time = None + self.warned_once = False # log all one time warnings - self._initialize() + self.initialize() - def config_from_dict(self, configuration: dict): + @staticmethod + def config_from_dict(configuration: dict): """ It accepts the raw configuration dictionary, and processes it to the formatted configuration. This includes mapping the string fields to enums, and setting sane @@ -145,106 +164,96 @@ def config_from_dict(self, configuration: dict): date_time_str = start_time.strftime("%Y-%m-%d-%H-%M-%S") processed_configuration = {} - # Workspace variables - workspace_block = configuration.get("workspace", None) - if workspace_block is not None: - processed_configuration["start_time"] = start_time - processed_configuration["workspace_name"] = workspace_block.get( - "name", f"kliff_{date_time_str}" - ) - processed_configuration["current_run_title"] = ( - None # will be assigned in the model block - ) - processed_configuration["export_kim_model"] = workspace_block.get( - "export", False - ) - processed_configuration["seed"] = workspace_block.get("seed", 12345) - processed_configuration["resume"] = workspace_block.get("resume", False) - walltime: Union[str,int] = workspace_block.get("walltime", "2:00:00:00") - if type(walltime) is int: # yaml parsed the time - processed_configuration["walltime"] = timedelta(seconds=walltime) - elif type(walltime) is str: - processed_configuration["walltime"] = timedelta( - days=int(walltime.split(":")[0]), - hours=int(walltime.split(":")[1]), - minutes=int(walltime.split(":")[2]), - seconds=int(walltime.split(":")[3]), - ) - else: - raise TrainerError("Walltime not in correct format. 
dd:hh:mm:ss expected.") - processed_configuration["expected_end_time"] = ( - start_time + processed_configuration["walltime"] + # Workspace variables ################################################ + workspace_block = configuration.get("workspace", {}) + if workspace_block == {}: + raise TrainerError("Workspace block not defined in trainer configuration") + processed_configuration["start_time"] = start_time + processed_configuration["workspace_name"] = workspace_block.get( + "name", f"kliff_{date_time_str}" + ) + processed_configuration["current_run_title"] = ( + None # will be assigned in the model block + ) + processed_configuration["export_kim_model"] = workspace_block.get( + "export", False + ) + processed_configuration["seed"] = workspace_block.get("seed", 12345) + processed_configuration["resume"] = workspace_block.get("resume", False) + walltime: Union[str, int] = workspace_block.get("walltime", "2:00:00:00") + if type(walltime) is int: # yaml parsed the time + processed_configuration["walltime"] = timedelta(seconds=walltime) + elif type(walltime) is str: + processed_configuration["walltime"] = timedelta( + days=int(walltime.split(":")[0]), + hours=int(walltime.split(":")[1]), + minutes=int(walltime.split(":")[2]), + seconds=int(walltime.split(":")[3]), ) else: - raise TrainerError("Workspace block not found in the configuration.") + raise TrainerError("Walltime not in correct format. dd:hh:mm:ss expected.") + processed_configuration["expected_end_time"] = ( + start_time + processed_configuration["walltime"] + ) - # Dataset variables - dataset_block = configuration.get("dataset", None) - if dataset_block is not None: - processed_configuration["dataset_type"] = DataSource.get_data_enum( - dataset_block.get("type", "kliff") - ) - processed_configuration["dataset_path"] = dataset_block.get( - "path", None - ) - processed_configuration["dataset_save"] = dataset_block.get("save", False) - processed_configuration["dataset_shuffle"] = dataset_block.get( - "shuffle", False + # Dataset variables ################################################# + dataset_block = configuration.get("dataset", {}) + if dataset_block == {}: + raise TrainerError( + "Dataset block is missing from the trainer configuration file" ) - ase_keys = dataset_block.get("keys", {}) - processed_configuration["ase_keys"] = { - "energy_key": ase_keys.get("energy", "energy"), - "forces_key": ase_keys.get("forces", "forces"), - } - train_dataset_info = dataset_block.get("training_dataset", None) - if train_dataset_info is not None: - # none values will be tackled during dataset loading - processed_configuration["train_size"] = train_dataset_info.get( - "train_size", None - ) - processed_configuration["train_indices"] = train_dataset_info.get( - "train_indices", None - ) - else: - processed_configuration["train_size"] = None - processed_configuration["train_indices"] = None - val_dataset_info = dataset_block.get("validation_dataset", None) - if val_dataset_info is not None: - processed_configuration["val_size"] = val_dataset_info.get( - "val_size", None - ) - processed_configuration["val_indices"] = val_dataset_info.get( - "val_indices", None - ) - else: - processed_configuration["val_size"] = None - processed_configuration["val_indices"] = None - processed_configuration["indices_file"] = {"train": None, "val": None} - if type(processed_configuration["train_indices"]) is str: - processed_configuration["indices_file"] = processed_configuration[ - "train_indices" - ] - if type(processed_configuration["val_indices"]) is str: - 
processed_configuration["indices_file"] = processed_configuration[ - "val_indices" - ] + processed_configuration["dataset_type"] = DataSource.get_data_enum( + dataset_block.get("type", "kliff") + ) + processed_configuration["dataset_path"] = dataset_block.get("path", None) + processed_configuration["dataset_save"] = dataset_block.get("save", False) + processed_configuration["dataset_shuffle"] = dataset_block.get("shuffle", False) + ase_keys = dataset_block.get("keys", {}) + processed_configuration["ase_keys"] = { + "energy_key": ase_keys.get("energy", "energy"), + "forces_key": ase_keys.get("forces", "forces"), + } + train_dataset_info = dataset_block.get("training_dataset", {}) - processed_configuration["train_dataset"] = None # To be assigned - processed_configuration["val_dataset"] = None # To be assigned - processed_configuration["dataset"] = None # To be assigned - - colabfit_dict = dataset_block.get("colabfit_dataset", None) - if colabfit_dict is not None: - processed_configuration["colabfit_dataset"] = { - "dataset_name": colabfit_dict.get("dataset_name", None), - "database_name": colabfit_dict.get("database_name", None), - "database_url": colabfit_dict.get("database_url", None), - } - else: - raise TrainerError("Dataset block not found in the configuration.") + # none values will be tackled during dataset loading + processed_configuration["train_size"] = train_dataset_info.get( + "train_size", None + ) + processed_configuration["train_indices"] = train_dataset_info.get( + "train_indices", None + ) - # model variables + val_dataset_info = dataset_block.get("validation_dataset", {}) + processed_configuration["val_size"] = val_dataset_info.get("val_size", None) + processed_configuration["val_indices"] = val_dataset_info.get( + "val_indices", None + ) + + processed_configuration["indices_file"] = {"train": None, "val": None} + + if type(processed_configuration["train_indices"]) is str: + processed_configuration["indices_file"] = processed_configuration[ + "train_indices" + ] + + if type(processed_configuration["val_indices"]) is str: + processed_configuration["indices_file"] = processed_configuration[ + "val_indices" + ] + + processed_configuration["train_dataset"] = None # To be assigned + processed_configuration["val_dataset"] = None # To be assigned + processed_configuration["dataset"] = None # To be assigned + colabfit_dict = dataset_block.get("colabfit_dataset", {}) + + processed_configuration["colabfit_dataset"] = { + "dataset_name": colabfit_dict.get("dataset_name", None), + "database_name": colabfit_dict.get("database_name", None), + "database_url": colabfit_dict.get("database_url", None), + } + + # model variables #################################################### model_block = configuration.get("model", {}) processed_configuration["model_type"] = ModelTypes.get_model_enum( model_block.get("model_type", "kim") @@ -261,7 +270,7 @@ def config_from_dict(self, configuration: dict): f"{processed_configuration['model_name']}_{date_time_str}" ) - # transform variables + # transform variables #################################################### transform_block = configuration.get("transforms", {}) property_transform_sub_block = transform_block.get("property", {}) parameter_transform_sub_block = transform_block.get("parameter", {}) @@ -269,26 +278,19 @@ def config_from_dict(self, configuration: dict): processed_configuration["property_transform_options"] = { "name": property_transform_sub_block.get("name", None), - "property_key": property_transform_sub_block.get("property_key", 
None) + "property_key": property_transform_sub_block.get("property_key", None), } processed_configuration["property_transform"] = ( property_transform_sub_block.get("instance", None) ) # no executable given. initialize on own processed_configuration["parameter_transform_options"] = { - "name": parameter_transform_sub_block.get("name", None), + "parameter_list": parameter_transform_sub_block.get("parameter_list", None), } processed_configuration["parameter_transform"] = ( parameter_transform_sub_block.get("instance", None) ) # no executable given. initialize on own - # map default hyperparameters - configuration_transform_kwargs = configuration_transform_sub_block.get("kwargs", {}) - hyperparams = configuration_transform_kwargs.get("hyperparameters", None) - if hyperparams == "default": - configuration_transform_kwargs["hyperparameters"] = \ - kliff.transforms.configuration_transforms.get_default_hyperparams() - processed_configuration["configuration_transform_options"] = ( configuration_transform_sub_block # this might contain lot of variables ) @@ -296,15 +298,22 @@ def config_from_dict(self, configuration: dict): configuration_transform_sub_block.get("instance", None) ) # no executable given. initialize on own - # training variables + # training variables ######################################################## training_block = configuration.get("training", {}) loss_block = training_block.get("loss", {}) - processed_configuration["loss_function"] = loss_block.get("loss_function", None) - processed_configuration["energy_loss_weight"] = loss_block.get( - "energy_loss_weight", 1.0 - ) - processed_configuration["forces_loss_weight"] = loss_block.get( - "forces_loss_weight", 0.0 + processed_configuration["loss_function_name"] = loss_block.get("function", None) + weights = loss_block.get("weights", {}) + if isinstance(weights, str): + processed_configuration["loss_weights"] = Path(weights) + else: + processed_configuration["loss_weights"] = { + "energy": weights.get("energy", 1.0), + "forces": weights.get("forces", None), + "stress": weights.get("stress", None), + "config": weights.get("config", 1.0), + } + processed_configuration["loss_normalize_per_atom"] = loss_block.get( + "normalize_per_atom", False ) optimizer_block = training_block.get("optimizer", {}) @@ -318,6 +327,9 @@ def config_from_dict(self, configuration: dict): processed_configuration["learning_rate"] = optimizer_block.get( "learning_rate", None ) + processed_configuration["optimizer_kwargs"] = optimizer_block.get( + "kwargs", None + ) processed_configuration["max_epochs"] = training_block.get("max_epochs", 10000) processed_configuration["device"] = training_block.get("device", "cpu") @@ -328,6 +340,22 @@ def config_from_dict(self, configuration: dict): processed_configuration["stop_condition"] = training_block.get( "stop_condition", None ) + processed_configuration["num_workers"] = training_block.get("num_workers", 1) + processed_configuration["verbose"] = training_block.get("verbose", True) + + # export trained model #################################################### + export_block = configuration.get("export", {}) + processed_configuration["export_kim_model"] = export_block is not {} + if export_block is not {}: + processed_configuration["export_model_type"] = ModelTypes.get_model_enum( + export_block.get("model_type", "kim") + ) + processed_configuration["export_model_name"] = export_block.get( + "model_name", None + ) + processed_configuration["export_model_path"] = export_block.get( + "model_path", None + ) return 
processed_configuration @@ -361,7 +389,7 @@ def config_to_dict(self): "dataset_name": self.colabfit_dataset["dataset_name"], "database_name": self.colabfit_dataset["database_name"], "database_url": self.colabfit_dataset["database_url"], - } + }, } if self.ase_keys is not None: config["dataset"]["keys"] = { @@ -381,33 +409,47 @@ def config_to_dict(self): "property_key": self.property_transform_options["property_key"], }, "parameter": { - "name": self.parameter_transform_options["name"], + "parameter_list": self.parameter_transform_options["parameter_list"], }, "configuration": { "name": self.configuration_transform_options["name"], "kwargs": self.configuration_transform_options, - } + }, } config["training"] = { "loss": { - "loss_function": self.loss_function, - "weight": { - "energy": self.energy_loss_weight, - "forces": self.forces_loss_weight, + "loss_function": self.loss_function_name, + "weights": { + "energy": self.loss_weights["energy"], + "forces": self.loss_weights["forces"], + "stress": self.loss_weights["stress"], + "config": self.loss_weights["config"], }, + "normalize_per_atom": self.loss_normalize_per_atom, }, "optimizer": { - "provider": OptimizerProvider.get_optimizer_str(self.optimizer_provider), + "provider": OptimizerProvider.get_optimizer_str( + self.optimizer_provider + ), "name": self.optimizer_name, "learning_rate": self.learning_rate, + "kwargs": self.optimizer_kwargs, }, "epochs": self.max_epochs, "device": self.device, "batch_size": self.batch_size, "chkpt_interval": self.chkpt_interval, "stop_condition": self.stop_condition, + "num_workers": self.num_workers, + "verbose": self.verbose, } + if self.export_kim_model: + config["export"] = { + "model_type": ModelTypes.get_model_str(self.export_model_type), + "model_name": self.export_model_name, + "model_path": self.export_model_path, + } return config @@ -438,7 +480,7 @@ def get_trainer_hash(self): config_immut_str = json.dumps(config, sort_keys=True) return hashlib.md5(config_immut_str.encode()).hexdigest() - def _initialize(self): + def initialize(self): """ Initialize the trainer. Assigns the configuration objects, and call setup methods. @@ -456,6 +498,8 @@ def _initialize(self): self.setup_test_train_datasets() # Step 6 - Set up the model self.setup_model() + # Step 6.5 - Setup parameter transform + self.setup_parameter_transforms() # Step 7 - Set up the optimizer self.setup_optimizer() # Step 8 - Save the configuration for future @@ -514,15 +558,23 @@ def setup_dataset(self): raise TrainerError(f"Dataset type {self.dataset_type} not supported.") if self.property_transform_options is not None: - dataset_transforms += self.property_transform_options["name"] + dataset_transforms += str(self.property_transform_options["name"]) dataset_transforms += "_" if self.configuration_transform_options is not None: dataset_transforms += self.configuration_transform_options["name"] + "_" - dataset_transforms += str(self.configuration_transform_options["kwargs"]["cutoff"]) + dataset_transforms += str( + self.configuration_transform_options["kwargs"]["cutoff"] + ) - dataset_hash_str = dataset_path + "_" + dataset_transforms + dataset_hash_str = ( + dataset_path + + "_" + + dataset_transforms + + "_" + + json.dumps(self.loss_weights, sort_keys=True) + ) dataset_hash = hashlib.md5(dataset_hash_str.encode()).hexdigest() self.current_dataset_hash = dataset_hash dataset_dir = f"{self.workspace_name}/{dataset_hash}" @@ -536,15 +588,32 @@ def setup_dataset(self): f"Dataset not found in {self.workspace_name} directory. 
Creating dataset." ) + if isinstance(self.loss_weights, Path): + weights = self.loss_weights + elif isinstance(self.loss_weights, dict): + weights = Weight( + config_weight=self.loss_weights["config"], + energy_weight=self.loss_weights["energy"], + forces_weight=self.loss_weights["forces"], + stress_weight=self.loss_weights["stress"], + ) + elif self.loss_weights is None: + weights = None + else: + raise TrainerError( + "Loss weights should be a dictionary or a path to a file, or None." + ) + if self.dataset_type == DataSource.KLIFF: - dataset = Dataset.from_path(dataset_path) + dataset = Dataset.from_path(dataset_path, weight=weights) elif self.dataset_type == DataSource.ASE: - dataset = Dataset.from_ase(dataset_path, **self.ase_keys) + dataset = Dataset.from_ase(dataset_path, **self.ase_keys, weight=weights) elif self.dataset_type == DataSource.COLABFIT: dataset = Dataset.from_colabfit( self.colabfit_dataset["dataset_name"], self.colabfit_dataset["database_name"], self.colabfit_dataset["database_url"], + weight=weights, ) else: raise TrainerError(f"Dataset type {self.dataset_type} not supported.") @@ -559,22 +628,28 @@ def setup_dataset(self): if self.property_transform_options.get("instance") is not None: self.property_transform = self.property_transform_options["instance"] else: - try: - # try getting class "name" from kliff.transforms.property_transforms - module = importlib.import_module( - "kliff.transforms.property_transforms" - ) - class_ = getattr(module, self.property_transform_options["name"]) - self.property_transform = class_( - property_key=self.property_transform_options["property_key"], - ) - except AttributeError: - raise TrainerError( - f"Property transform {self.property_transform_options['name']} not found." - "If it is a custom transform, please provide the instance." - ) - - self.property_transform(dataset) + if self.property_transform_options.get("name"): + try: + # try getting class "name" from kliff.transforms.property_transforms + module = importlib.import_module( + "kliff.transforms.property_transforms" + ) + class_ = getattr( + module, self.property_transform_options["name"] + ) + self.property_transform = class_( + property_key=self.property_transform_options[ + "property_key" + ], + ) + except AttributeError: + raise TrainerError( + f"Property transform {self.property_transform_options['name']} not found." + "If it is a custom transform, please provide the instance." + ) + + if self.property_transform: + self.property_transform(dataset) if self.configuration_transform is not None: if not isinstance(self.configuration_transform, ConfigurationTransform): @@ -583,8 +658,10 @@ def setup_dataset(self): ) else: # check if configuration_instance_options have "instance" - if "instance" in self.configuration_transform_options \ - and self.configuration_transform_options["instance"] is not None: + if ( + "instance" in self.configuration_transform_options + and self.configuration_transform_options["instance"] is not None + ): self.configuration_transform = self.configuration_transform_options[ "instance" ] @@ -671,6 +748,17 @@ def setup_model(self): """ raise TrainerError("setup_model not implemented.") + def setup_parameter_transforms(self): + """ + This method set up the transformed parameter space for models. It can be used + for any model type in general, but as there exists a significant difference + between how models handles their parameters, it is left for the subclass to + implement. 
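
For clarity, the weight handling in `setup_dataset()` above reduces to the following; a minimal sketch, assuming the `loss_weights` block parsed from the trainer configuration (either a dictionary of per-property weights or a path to a weights file):

```python
from pathlib import Path

from kliff.dataset.weight import Weight

loss_weights = {"config": 1.0, "energy": 1.0, "forces": 10.0, "stress": None}

if isinstance(loss_weights, (str, Path)):
    # a per-configuration weights file; resolved later by the Dataset loaders
    weight = Path(loss_weights)
else:
    # one Weight object shared by every configuration
    weight = Weight(
        config_weight=loss_weights["config"],
        energy_weight=loss_weights["energy"],
        forces_weight=loss_weights["forces"],
        stress_weight=loss_weights["stress"],
    )
```
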
Although to ensure that `initialize` function remains consistent + this method will not raise NotImplemented error, rather it will quietly pass. + So be aware. + """ + pass + def setup_optimizer(self): """ Set up the optimizer based on the provided information. If the optimizer is not @@ -698,3 +786,6 @@ def get_optimizer(self, *args, **kwargs): def train(self, *args, **kwargs): TrainerError("train not implemented.") + + def save_kim_model(self, *args, **kwargs): + TrainerError("save_kim_model not implemented.") diff --git a/kliff/trainer/option_enumerations.py b/kliff/trainer/option_enumerations.py new file mode 100644 index 00000000..9a002a05 --- /dev/null +++ b/kliff/trainer/option_enumerations.py @@ -0,0 +1,225 @@ +from enum import Enum + +from kliff._exceptions import TrainerError + + +class ModelTypes(Enum): + """ + Enumerates the different types of models that can be used in the training + process. The different types are: + - KIM: The KIM model, passed as a string of KIM ID. + - TORCH: A model implemented in PyTorch, passed as a Python callable. + - TAR: A model saved in a tar file, which can either be a valid KIM API + model, where it will be installed in CWD using kim-api collections + management, or TorchML driver based model, in which case, the TorchScript + file will be extracted and used. The kind of model is determined from the + KIM API CMakelists.txt file model drive name string. + """ + + UNDEFINED = -1 + KIM = 0 + TORCH = 1 + TAR = 2 + + @staticmethod + def get_model_enum(input_str: str): + """ + Get the model type from the input string. + + Args: + input_str: name of the model type. "kim", "torch", "pt", "pth" or "tar" + + Returns: + Model type enum. + """ + if input_str.lower() == "kim": + return ModelTypes.KIM + elif ( + input_str.lower() == "torch" + or input_str.lower() == "pt" + or input_str.lower() == "pth" + ): + return ModelTypes.TORCH + elif input_str.lower() == "tar": + return ModelTypes.TAR + else: + raise TrainerError(f"Model type {input_str} not supported.") + + @staticmethod + def get_model_str(input_type): + """ + Get the model configuration string from the model type. + + Args: + input_type: input model enum. + + Returns: + Model configuration string. "KIM", "TORCH" or "TAR" + + """ + if input_type == ModelTypes.KIM: + return "KIM" + elif input_type == ModelTypes.TORCH: + return "TORCH" + elif input_type == ModelTypes.TAR: + return "TAR" + else: + raise TrainerError(f"Model type {input_type} not supported.") + + +class DataSource(Enum): + """ + Enumerates the different types of data sources. The different types are: + - ASE: ASE atoms objects, or xyz file with configurations. Uses + ~:class:`~kliff.dataset.Dataset.from_ase` method. + - COLABFIT: uUses ColabFit dataset exchange instance. Uses + ~:class:`~kliff.dataset.Dataset.from_colabfit` method. + - KLIFF: Uses KLIFF compatible extxyz files path. Uses + ~:class:`~kliff.dataset.Dataset.from_path` method. + """ + + UNDEFINED = -1 + ASE = 0 + COLABFIT = 1 + KLIFF = 2 + + @staticmethod + def get_data_enum(input_str: str): + """ + Get the data type from the input string. + + Args: + input_str: name of the data type. "ase", "colabfit" or "kliff" + + Returns: + Data type enum. 
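
A hedged sketch of how a trainer can dispatch on the parsed `DataSource` enum (file and database names below are placeholders; note that a later patch in this series removes `option_enumerations.py` again, so the import path may change):

```python
from kliff.dataset import Dataset
from kliff.trainer.option_enumerations import DataSource

source = DataSource.get_data_enum("ase")

if source == DataSource.ASE:
    # single ASE-readable file (e.g. extxyz) with multiple configurations
    dataset = Dataset.from_ase("Si.xyz", energy_key="Energy", forces_key="force")
elif source == DataSource.KLIFF:
    # directory of KLIFF-compatible extxyz files
    dataset = Dataset.from_path("configs/")
elif source == DataSource.COLABFIT:
    dataset = Dataset.from_colabfit(
        colabfit_database="colabfit_database", colabfit_dataset="my_dataset"
    )
```
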
+ + """ + if input_str.lower() == "ase": + return DataSource.ASE + elif input_str.lower() == "colabfit": + return DataSource.COLABFIT + elif input_str.lower() == "kliff": + return DataSource.KLIFF + else: + raise TrainerError(f"Data type {input_str} not supported.") + + @staticmethod + def get_data_str(input_type): + """ + Get the data configuration string from the data type. + + Args: + input_type: input data enum. + + Returns: + Data configuration string. "ASE", "COLABFIT" or "KLIFF" + + """ + if input_type == DataSource.ASE: + return "ASE" + elif input_type == DataSource.COLABFIT: + return "COLABFIT" + elif input_type == DataSource.KLIFF: + return "KLIFF" + else: + raise TrainerError(f"Data type {input_type} not supported.") + + +class ConfigurationTransformationTypes(Enum): + """ + Enumerates the different types of configuration transformations that can be + applied to the input data. The different types are: + - GRAPH: Graph based transformation. + - DESCRIPTORS: Descriptor based transformation. + - NEIGHBORS: No transformation besides neighbor list computation. + """ + + UNDEFINED = -1 + GRAPH = 0 + DESCRIPTORS = 1 + NEIGHBORS = 2 + + @staticmethod + def get_config_transformation_enum(input_str: str): + """ + Get the configuration transformation type from the input string. + + Args: + input_str: name of the configuration transformation type. "graph", "descriptors" or "neighbors" + + Returns: + Configuration transformation type enum. + + """ + if input_str.lower() == "graph": + return ConfigurationTransformationTypes.GRAPH + elif input_str.lower() == "descriptors": + return ConfigurationTransformationTypes.DESCRIPTORS + elif input_str.lower() == "neighbors" or input_str.lower() == "none": + return ConfigurationTransformationTypes.NEIGHBORS + else: + raise TrainerError( + f"Configuration transformation type {input_str} not supported." + ) + + @staticmethod + def get_config_transformation_str(input_type): + """ + Get the configuration transformation configuration string from the + configuration transformation type. + + Args: + input_type: input configuration transformation enum. + + Returns: + Configuration transformation configuration string. "GRAPH", "DESCRIPTORS" or "NEIGHBORS" + + """ + if input_type == ConfigurationTransformationTypes.GRAPH: + return "GRAPH" + elif input_type == ConfigurationTransformationTypes.DESCRIPTORS: + return "DESCRIPTORS" + else: + raise TrainerError( + f"Configuration transformation type {input_type} not supported." + ) + + +class OptimizerProvider(Enum): + """ + Enumerates the different types of optimizer providers that can be used in the + training process. The different types are "TORCH" and "SCIPY". + """ + + UNDEFINED = -1 + TORCH = 0 + SCIPY = 1 + + @staticmethod + def get_optimizer_enum(input_str: str): + """ + Get the optimizer provider from the input string. + + Args: + input_str: name of the optimizer provider. "torch" or "scipy" + + Returns: + Optimizer provider enum. 
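
The string-to-enum helpers are symmetric with the `*_str` helpers used when the configuration is serialized back out (e.g. in `config_to_dict()`); a small usage sketch, again assuming the module layout introduced by this patch:

```python
from kliff.trainer.option_enumerations import ModelTypes, OptimizerProvider

model_type = ModelTypes.get_model_enum("pth")  # "torch", "pt" and "pth" all map to TORCH
optimizer = OptimizerProvider.get_optimizer_enum("scipy")

# the *_str helpers invert the mapping
assert ModelTypes.get_model_str(model_type) == "TORCH"
assert OptimizerProvider.get_optimizer_str(optimizer) == "SCIPY"
```
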
+ + """ + if input_str.lower() == "torch": + return OptimizerProvider.TORCH + elif input_str.lower() == "scipy": + return OptimizerProvider.SCIPY + else: + raise TrainerError(f"Optimizer provider {input_str} not supported.") + + @staticmethod + def get_optimizer_str(input_type): + if input_type == OptimizerProvider.TORCH: + return "TORCH" + elif input_type == OptimizerProvider.SCIPY: + return "SCIPY" + else: + raise TrainerError(f"Optimizer provider {input_type} not supported.") diff --git a/kliff/utils.py b/kliff/utils.py index 5cddf2d0..57f3a1f4 100644 --- a/kliff/utils.py +++ b/kliff/utils.py @@ -1,6 +1,7 @@ import os import pickle import random +import subprocess import tarfile from collections.abc import Sequence from pathlib import Path @@ -222,3 +223,40 @@ def stress_to_tensor(input_stress: list) -> np.ndarray: stress[0, 1] = stress[1, 0] = input_stress[5] return stress + + +def is_kim_model_installed(model_name: str) -> bool: + """ + Check if the KIM model is installed in any collection. + + Args: + model_name: name of the model. + """ + model_list = subprocess.run( + ["kim-api-collections-management", "list"], capture_output=True, text=True + ) + if model_name in model_list.stdout: + return True + else: + return False + + +def install_kim_model(model_name: str, collection: str = "user") -> bool: + """ + Install the KIM model + + Args: + model_name: name of the model. + collection: name of the collection. + + Returns: + True if the model is now installed, False otherwise. + """ + if not is_kim_model_installed(model_name): + output = subprocess.run( + ["kim-api-collections-management", "install", collection, model_name], + check=True, + ) + return output.returncode == 0 + else: + return True diff --git a/setup.py b/setup.py index 9c62527b..8cf4b845 100644 --- a/setup.py +++ b/setup.py @@ -113,7 +113,7 @@ def get_readme(): version=get_version(), packages=find_packages(), ext_modules=[sym_fn, bispectrum, neighlist, graph_module], - install_requires=["requests", "scipy", "pyyaml", "monty", "loguru"], + install_requires=["requests", "scipy", "pyyaml", "monty", "loguru", "dill"], extras_require={ "test": [ "pytest", @@ -125,6 +125,7 @@ def get_readme(): "ase", "libdescriptor", "torch_geometric", + "dill", ], "docs": [ "sphinx", diff --git a/tests/dataset/test_weight.py b/tests/dataset/test_weight.py index ec96fd2e..50405198 100644 --- a/tests/dataset/test_weight.py +++ b/tests/dataset/test_weight.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np from kliff.dataset import Dataset @@ -84,3 +86,21 @@ def _compute_magnitude_inverse_weight(c1, c2, norm): """ sigma = np.sqrt(c1**2 + (c2 * norm) ** 2) return 1 / sigma + + +def test_weight_from_file(): + ds = Dataset.from_ase( + Path(__file__).parents[1].joinpath("test_data/configs/Si_4.xyz"), + energy_key="Energy", + forces_key="force", + weight=Path(__file__).parents[1].joinpath("test_data/weights/Si_4_weights.dat"), + ) + configs = ds.get_configs() + assert len(configs) == 4 + assert configs[0].weight.config_weight == 1.0 + assert configs[0].weight.energy_weight == 0.5 + assert configs[0].weight.stress_weight == 1.0 + assert configs[3].weight.forces_weight == 0.5 + assert configs[3].weight.config_weight == 1.0 + assert configs[3].weight.energy_weight == 0.5 + assert configs[3].weight.stress_weight == 4.0 diff --git a/tests/test_data/weights/Si_4_weights.dat b/tests/test_data/weights/Si_4_weights.dat new file mode 100644 index 00000000..a80b9bb1 --- /dev/null +++ b/tests/test_data/weights/Si_4_weights.dat @@ -0,0 +1,4 @@ +1.0 
0.5 0.5 1.0 +1.0 0.5 0.5 2.0 +1.0 0.5 0.5 3.0 +1.0 0.5 0.5 4.0 From 01322ecb879cf55d97784c2f3256cde212a64302 Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Sat, 6 Apr 2024 12:26:41 -0500 Subject: [PATCH 5/8] from config functionality in KIMModel --- kliff/models/kim.py | 226 +++++++++++++++++++++++++++++++++++ kliff/trainer/kim_trainer.py | 175 +-------------------------- 2 files changed, 227 insertions(+), 174 deletions(-) diff --git a/kliff/models/kim.py b/kliff/models/kim.py index 9adf2164..7e2622fb 100644 --- a/kliff/models/kim.py +++ b/kliff/models/kim.py @@ -1,4 +1,6 @@ +import importlib import os +import subprocess from collections import OrderedDict from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Union @@ -12,6 +14,11 @@ from kliff.models.model import ComputeArguments, Model from kliff.models.parameter import Parameter from kliff.neighbor import assemble_forces, assemble_stress +from kliff.utils import install_kim_model, is_kim_model_installed + +from omegaconf import DictConfig, OmegaConf +import kimpy +import tarfile try: import kimpy @@ -21,6 +28,12 @@ except ImportError: kimpy_avail = False +# list of model drivers that are not supported by this trainer. +# example quip, torchml, etc. +# TODO: Get the complete list of unsupported model drivers. +UNSUPPORTED_MODEL_DRIVERS = [ + "TorchML", +] class KIMComputeArguments(ComputeArguments): """ @@ -88,6 +101,8 @@ def __init__( self._update_neigh(influence_distance) self._register_data(compute_energy, compute_forces) + self.model_trainable_via_kim_api = False + def _get_implemented_property(self): """ Get implemented property of model. @@ -681,6 +696,217 @@ def __call__( return kim_ca_instance.results + @staticmethod + def get_model_from_config(model_config: DictConfig, param_config: DictConfig = None): + """ + Get the model from a configuration. If it is a valid KIM model, it will return + the KIMModel object. If it is a TorchML model, it will return the torch + ReverseScriptedModule object *in future*. Else raise error. If the model is a tarball, it + will extract and install the model. + + Example `model_config`: + ```yaml + model: + model_type: kim # kim or torch or tar + model_path: ./ + model_name: SW_StillingerWeber_1985_Si__MO_405512056662_006 # KIM model name, installed if missing + model_collection: "user" + ``` + + Example `param_config`: + ```yaml + parameter: + parameter_list: # optional for KIM models, list of parameters to optimize + - A # dict means the parameter is transformed + - B # these are the parameters that are not transformed + - sigma: + transform_name: LogParameterTransform + value: 2.0 + bounds: [[1.0, 10.0]] + ``` + + ```{note} + `parameter` block is usually defined as the children of the `transform` block + in trainer configuration file. + ``` + + Args: + model_config: configuration object + param_config: parameter transformation configuration + + Returns: + Model object + """ + model_name = model_config.model_name + model_type = model_config.model_type + model_path = model_config.model_path + model_driver = KIMModel.get_model_driver_name(model_name) + model_collection = model_config.model_collection + + if model_driver in UNSUPPORTED_MODEL_DRIVERS: + logger.error("Model driver not supported for KIM-API based training. 
" + "Please use appropriate trainer for this model.") + raise KIMModelError(f"Model driver {model_driver} not supported for KIMModel training.") + + # ensure model is installed + if model_type.lower() == "kim": + is_model_installed = install_kim_model(model_name, model_collection) + if not is_model_installed: + logger.error( + f"Mode: {model_name} neither installed nor available in the KIM API collections. Please check the model name and try again." + ) + raise KIMModelError(f"Model {model_name} not found.") + else: + logger.info(f"Model {model_name} is present in {model_collection} collection.") + elif model_type.lower() == "tar": + archive_content = tarfile.open(model_path + "/" + model_name) + model = archive_content.getnames()[0] + archive_content.extractall(model_path) + subprocess.run( + [ + "kim-api-collections-management", + "install", + "--force", + model_collection, + model_path + "/" + model, + ], + check=True, + ) + logger.info(f"Tarball Model {model} installed in {model_collection} collection.") + else: + raise KIMModelError(f"Model type {model_type} not supported.") + + model = KIMModel(model_name) + + if param_config: + parameter_list = param_config.parameter_list + mutable_param_list = [] + for param_to_transform in parameter_list: + if isinstance(param_to_transform, dict): + parameter_name = list(param_to_transform.keys())[0] + elif isinstance(param_to_transform, str): + parameter_name = param_to_transform + else: + raise KIMModelError(f"Parameter can be a str or dict") + mutable_param_list.append(parameter_name) + + model.set_params_mutable(mutable_param_list) + model_param_list = model.parameters() + + # apply transforms if needed + for model_params, input_params in zip(model_param_list, parameter_list): + if isinstance(input_params, dict): + param_name = list(input_params.keys())[0] + if param_name != model_params.name: + raise KIMModelError( + f"Parameter name mismatch. Expected {model_params.name}, got {param_name}." + ) + + param_value_dict = input_params[param_name] + transform_name = param_value_dict.get("transform_name", None) + params_value = param_value_dict.get("value", None) + bounds = param_value_dict.get("bounds", None) + + if transform_name is not None: + transform_module = getattr( + importlib.import_module( + f"kliff.transforms.parameter_transforms" + ), + transform_name, + ) + transform_module = transform_module() + model_params.add_transform(transform_module) + + if params_value is not None: + model_params.copy_from_model_space(params_value) + + if bounds is not None: + model_params.add_bounds_model_space(np.array(bounds)) + + elif isinstance(input_params, str): + if input_params != model_params.name: + raise KIMModelError( + f"Parameter name mismatch. Expected {model_params.name}, got {input_params}." + ) + else: + raise KIMModelError( + f"Optimizable parameters must be string or value dict. Got {input_params} instead." + ) + + return model + + @staticmethod + def get_model_driver_name(model_name: str) -> Union[str, None]: + """ + Get the model driver from the model name. It will return the model driver + string from the installed KIM API model. If the model is not installed, and the + model name is a tarball, it will extract the model driver name from the CMakeLists.txt. + This is needed to ensure that it excludes the model drivers that it cannot handle. + Example: TorchML driver based models. These models are to be trained using the + TorchTrainer. + + TODO: This is not a clean solution. I think KIMPY must have a better way to handle this. 
+ Ask Mingjian/Yaser for comment. + + Args: + model_name: name of the model. + + Returns: + Model driver name. + """ + # check if model is tarball + if "tar" in model_name: + return KIMModel._get_model_driver_name_for_tarball(model_name) + + collections = kimpy.collections.create() + try: + shared_obj_path, collection = ( + collections.get_item_library_file_name_and_collection( + kimpy.collection_item_type.portableModel, model_name + ) + ) + except RuntimeError: # not a portable model + return None + shared_obj_content = open(shared_obj_path, "rb").read() + md_start_idx = shared_obj_content.find(b"model-driver") + + if md_start_idx == -1: + return None + else: + md_start_idx += 15 # length of 'model-driver" "' + md_end_idx = shared_obj_content.find(b'"', md_start_idx) + return shared_obj_content[md_start_idx:md_end_idx].decode("utf-8") + + @staticmethod + def _get_model_driver_name_for_tarball(tarball: str) -> Union[str, None]: + """ + Get the model driver name from the tarball. It will extract the model driver + name from the CMakeLists.txt file in the tarball. This is needed to ensure that + it excludes the model drivers that it cannot handle. Example: TorchML driver based + models. These models are to be trained using the TorchTrainer. + + Args: + tarball: path to the tarball. + + Returns: + Model driver name. + """ + archive_content = tarfile.open(tarball) + cmake_file_path = archive_content.getnames()[0] + "/CMakeLists.txt" + cmake_file = archive_content.extractfile(cmake_file_path) + cmake_file_content = cmake_file.read().decode("utf-8") + + md_start_idx = cmake_file_content.find("DRIVER_NAME") + if md_start_idx == -1: + return None + else: + # name strats at " + md_start_idx = cmake_file_content.find('"', md_start_idx) + 1 + if md_start_idx == -1: + return None + md_end_idx = cmake_file_content.find('"', md_start_idx) + return cmake_file_content[md_start_idx:md_end_idx] + class KIMModelError(Exception): def __init__(self, msg): diff --git a/kliff/trainer/kim_trainer.py b/kliff/trainer/kim_trainer.py index e50f32aa..04e9f337 100644 --- a/kliff/trainer/kim_trainer.py +++ b/kliff/trainer/kim_trainer.py @@ -19,12 +19,7 @@ from .kliff_trainer import Trainer from .option_enumerations import ModelTypes, OptimizerProvider -# list of model drivers that are not supported by this trainer. -# example quip, torchml, etc. -# TODO: Get the complete list of unsupported model drivers. -UNSUPPORTED_MODEL_DRIVERS = [ - "TorchML", -] + SCIPY_MINIMIZE_METHODS = [ "Nelder-Mead", "Powell", @@ -102,60 +97,6 @@ def setup_model(self): self.model = KIMModel(self.model_name) self.parameters = self.model.get_model_params() - def setup_parameter_transforms(self): - """ - This method set up the transformed parameter space for models. It can be used - for any model type in general, but as there exists a significant difference - between how models handles their parameters, it is left for the subclass to - implement. Although to ensure that `initialize` function remains consistent - this method will not raise NotImplemented error, rather it will quietly pass. - So be aware. 
- """ - self.set_parameters_as_mutable() - mutable_params = self.model.parameters() - parameter_transforms_input = self.parameter_transform_options["parameter_list"] - if parameter_transforms_input is not None: - for model_params, input_params in zip( - mutable_params, parameter_transforms_input - ): - if isinstance(input_params, dict): - param_name = list(input_params.keys())[0] - if param_name != model_params.name: - raise TrainerError( - f"Parameter name mismatch. Expected {model_params.name}, got {param_name}." - ) - - param_value_dict = input_params[param_name] - transform_name = param_value_dict.get("transform_name", None) - params_value = param_value_dict.get("value", None) - bounds = param_value_dict.get("bounds", None) - - if transform_name is not None: - transform_module = getattr( - importlib.import_module( - f"kliff.transforms.parameter_transforms" - ), - transform_name, - ) - transform_module = transform_module() - model_params.add_transform(transform_module) - - if params_value is not None: - model_params.copy_from_model_space(params_value) - - if bounds is not None: - model_params.add_bounds_model_space(np.array(bounds)) - - elif isinstance(input_params, str): - if input_params != model_params.name: - raise TrainerError( - f"Parameter name mismatch. Expected {model_params.name}, got {input_params}." - ) - else: - raise TrainerError( - f"Optimizable parameters must be string or value dict. Got {input_params} instead." - ) - def setup_optimizer(self): """ Set up the optimizer based on the provided information. If the optimizer is not @@ -242,120 +183,6 @@ def _wrapper_func(x): else: logger.error(f"Optimization failed: {result.message}") - @staticmethod - def get_model_driver_name_for_kim(model_name: str) -> Union[str, None]: - """ - Get the model driver from the model name. It will return the model driver - string from the installed KIM API model. If the model is not installed, and the - model name is a tarball, it will extract the model driver name from the CMakeLists.txt. - This is needed to ensure that it excludes the model drivers that it cannot handle. - Example: TorchML driver based models. These models are to be trained using the - TorchTrainer. - - TODO: This is not a clean solution. I think KIMPY must have a better way to handle this. - Ask Mingjian/Yaser for comment. - - Args: - model_name: name of the model. - - Returns: - Model driver name. - """ - collections = kimpy.collections.create() - try: - shared_obj_path, collection = ( - collections.get_item_library_file_name_and_collection( - kimpy.collection_item_type.portableModel, model_name - ) - ) - except RuntimeError: # not a portable model - return None - shared_obj_content = open(shared_obj_path, "rb").read() - md_start_idx = shared_obj_content.find(b"model-driver") - - if md_start_idx == -1: - return None - else: - md_start_idx += 15 # length of 'model-driver" "' - md_end_idx = shared_obj_content.find(b'"', md_start_idx) - return shared_obj_content[md_start_idx:md_end_idx].decode("utf-8") - - @staticmethod - def get_model_driver_name_for_tarball(tarball: str) -> Union[str, None]: - """ - Get the model driver name from the tarball. It will extract the model driver - name from the CMakeLists.txt file in the tarball. This is needed to ensure that - it excludes the model drivers that it cannot handle. Example: TorchML driver based - models. These models are to be trained using the TorchTrainer. - - Args: - tarball: path to the tarball. - - Returns: - Model driver name. 
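
Taken together with the new `kliff.utils` helpers, the driver check described in these docstrings amounts to something like the following hedged sketch (the KIM model name is only an example):

```python
from kliff.models.kim import KIMModel, UNSUPPORTED_MODEL_DRIVERS
from kliff.utils import install_kim_model, is_kim_model_installed

model_name = "SW_StillingerWeber_1985_Si__MO_405512056662_006"

# Reject drivers this trainer cannot handle (e.g. TorchML models belong to a
# Torch-based trainer), then make sure the model is visible to the KIM API.
driver = KIMModel.get_model_driver_name(model_name)
if driver in UNSUPPORTED_MODEL_DRIVERS:
    raise RuntimeError(f"{driver}-driven models cannot be trained with this trainer.")

if not is_kim_model_installed(model_name):
    # shells out to `kim-api-collections-management install user <model_name>`
    install_kim_model(model_name, collection="user")
```
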
- """ - archive_content = tarfile.open(tarball) - cmake_file_path = archive_content.getnames()[0] + "/CMakeLists.txt" - cmake_file = archive_content.extractfile(cmake_file_path) - cmake_file_content = cmake_file.read().decode("utf-8") - - md_start_idx = cmake_file_content.find("DRIVER_NAME") - if md_start_idx == -1: - return None - else: - # name strats at " - md_start_idx = cmake_file_content.find('"', md_start_idx) + 1 - if md_start_idx == -1: - return None - md_end_idx = cmake_file_content.find('"', md_start_idx) - return cmake_file_content[md_start_idx:md_end_idx] - - @staticmethod - def ensure_kim_model_installation(model_name: str, collection: str = "user"): - """ - Ensure that the KIM model is installed. If the model is not installed, it will - install the model in the user collection. If the model is already installed, it - will not do anything. - - Args: - model_name: name of the model. - collection: collection to install the model in. - """ - is_model_installed = install_kim_model(model_name) - if not install_kim_model(model_name): - logger.error( - f"Mode: {model_name} neither installed nor available in the KIM API collections. Please check the model name and try again." - ) - raise TrainerError(f"Model {model_name} not found.") - else: - logger.info(f"Model {model_name} is present in {collection} collection.") - - def ensure_tarball_model_installation(self, tarball: str, collection: str = "user"): - """ - Ensure that the model is installed from the tarball. If the model is not installed, - it will install the model in the user collection. If the model is already installed, - it will reinstall the model. - - Args: - tarball: path to the tarball. - collection: collection to install the model in. - """ - scratch_dir = f"{self.current_run_dir}/.scratch" - archive_content = tarfile.open(tarball) - model = archive_content.getnames()[0] - archive_content.extractall(scratch_dir) - subprocess.run( - [ - "kim-api-collections-management", - "install", - "--force", - collection, - scratch_dir + "/" + model, - ], - check=True, - ) - logger.info(f"Tarball Model {model} installed in {collection} collection.") - def set_parameters_as_mutable(self): if self.parameter_transform_options is not None: for param_to_transform in self.parameter_transform_options[ From 2fa8bb8c47b3a7a4ef06d11c3adcdee4ba35b1e4 Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Mon, 8 Apr 2024 16:12:50 -0500 Subject: [PATCH 6/8] DS and Model manifest initialization --- kliff/dataset/dataset.py | 591 ++++++++++++++++++++++++++++----------- kliff/dataset/weight.py | 16 ++ kliff/models/kim.py | 26 +- 3 files changed, 450 insertions(+), 183 deletions(-) diff --git a/kliff/dataset/dataset.py b/kliff/dataset/dataset.py index b8424a74..b1f58363 100644 --- a/kliff/dataset/dataset.py +++ b/kliff/dataset/dataset.py @@ -1,17 +1,24 @@ import copy +import json import os from collections.abc import Iterable from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union, Tuple +import dill import numpy as np from loguru import logger from monty.dev import requires +from omegaconf import OmegaConf, DictConfig +from omegaconf.errors import ConfigAttributeError from kliff.dataset.extxyz import read_extxyz, write_extxyz from kliff.dataset.weight import Weight from kliff.utils import stress_to_tensor, stress_to_voigt, to_path +import importlib +import hashlib + # For type checking if TYPE_CHECKING: from colabfit.tools.configuration import Configuration as 
ColabfitConfiguration @@ -214,8 +221,8 @@ def from_colabfit( forces = np.asarray(property["atomic-forces"]["forces"]["source-value"]) elif property["type"] == "cauchy-stress": stress = np.asarray(property["cauchy-stress"]["stress"]["source-value"]) + stress = stress_to_voigt(stress) - stress = stress_to_voigt(stress) self = cls( cell, species, @@ -573,6 +580,8 @@ def __init__(self, configurations: Iterable = None): "configurations must be a iterable of Configuration objects." ) + self._metadata: dict = {} + @classmethod @requires(MongoDatabase is not None, "colabfit-tools is not installed") def from_colabfit( @@ -611,7 +620,7 @@ def from_colabfit( def _read_from_colabfit( database_client: MongoDatabase, colabfit_dataset: str, - weight: Optional[Union[Weight, Path]] = None, + weight: Optional[Weight] = None, ) -> List[Configuration]: """ Read configurations from colabfit database. @@ -621,11 +630,7 @@ def _read_from_colabfit( fetch database from colabfit-tools dataset. colabfit_dataset: Name of the colabfit dataset instance to read from. weight: an instance that computes the weight of the configuration in the loss - function. If a path is provided, it is used to read the weight from the - file. The file must be a plain text file with 4 whitespace separated - columns: config_weight, energy_weight, forces_weight, and stress_weight. - Length of the file must be equal to the number of configurations, or 1 - (in which case the same weight is used for all configurations). + function. Returns: A list of configurations. @@ -638,37 +643,10 @@ def _read_from_colabfit( logger.error(f"{colabfit_dataset} is either empty or does not exist") raise DatasetError(f"{colabfit_dataset} is either empty or does not exist") - if isinstance(weight, Path): - print(weight) - weights = np.loadtxt(weight) - if weights.ndim == 1 and len(weights) == 4: - weights = np.tile(weights, (len(data_objects), 1)) - elif weights.ndim == 2 and len(weights) == len(data_objects): - pass - else: - raise DatasetError( - "Weight file must be a plain text file with 4 whitespace separated " - "columns: config_weight, energy_weight, forces_weight, and " - "stress_weight. Length of the file must be equal to the number of " - "configurations, or 1 (in which case the same weight is used for all " - "configurations)." 
- ) - weights = [ - Weight( - config_weight=w[0], - energy_weight=w[1], - forces_weight=w[2], - stress_weight=w[3], - ) - for w in weights - ] - else: - weights = [weight] * len(data_objects) - configs = [] - for data_object, weight_obj in zip(data_objects, weights): + for data_object in data_objects: configs.append( - Configuration.from_colabfit(database_client, data_object, weight_obj) + Configuration.from_colabfit(database_client, data_object, weight) ) if len(configs) <= 0: @@ -704,7 +682,11 @@ def add_from_colabfit( """ # open link to the mongo mongo_client = MongoDatabase(colabfit_database, uri=colabfit_uri) - configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, weight) + if isinstance(weight, Weight): + configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, weight) + else: + configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, None) + self.add_weights(weight) self.configs.extend(configs) @classmethod @@ -737,7 +719,7 @@ def from_path( @staticmethod def _read_from_path( path: Path, - weight: Optional[Union[Weight, Path]] = None, + weight: Optional[Weight] = None, file_format: str = "xyz", ) -> List[Configuration]: """ @@ -747,11 +729,7 @@ def _read_from_path( path: Path of the directory storing the configurations in individual files. For single file with multiple configurations, use `_read_from_ase()` instead. weight: an instance that computes the weight of the configuration in the loss - function. If a path is provided, it is used to read the weight from the - file. The file must be a plain text file with 4 whitespace separated - columns: config_weight, energy_weight, forces_weight, and stress_weight. - Length of the file must be equal to the number of configurations, or 1 - (in which case the same weight is used for all configurations). + function. file_format: Format of the file that stores the configuration, e.g. `xyz`. Returns: @@ -779,37 +757,9 @@ def _read_from_path( parent = path.parent all_files = [path] - if isinstance(weight, Path): - print(weight) - weights = np.loadtxt(weight) - if weights.ndim == 1 and len(weights) == 4: - weights = np.tile(weights, (len(all_files), 1)) - elif weights.ndim == 2 and len(weights) == len(all_files): - pass - else: - raise DatasetError( - "Weight file must be a plain text file with 4 whitespace separated " - "columns: config_weight, energy_weight, forces_weight, and " - "stress_weight. Length of the file must be equal to the number of " - "configurations, or 1 (in which case the same weight is used for all " - "configurations)." - ) - weights = [ - Weight( - config_weight=w[0], - energy_weight=w[1], - forces_weight=w[2], - stress_weight=w[3], - ) - for w in weights - ] - - else: - weights = [weight] * len(all_files) - configs = [ - Configuration.from_file(f, copy.copy(w), file_format) - for f, w in zip(all_files, weights) + Configuration.from_file(f, weight, file_format) + for f in all_files ] if len(configs) <= 0: @@ -839,7 +789,12 @@ def add_from_path( """ if isinstance(path, str): path = Path(path) - configs = self._read_from_path(path, weight, file_format) + + if isinstance(weight, Weight): + configs = self._read_from_path(path, weight, file_format) + else: + configs = self._read_from_path(path, None, file_format) + self.add_weights(weight) self.configs.extend(configs) @classmethod @@ -910,11 +865,7 @@ def _read_from_ase( path: Path the directory (or filename) storing the configurations. ase_atoms_list: A list of ase.Atoms objects. 
weight: an instance that computes the weight of the configuration in the loss - function. If a path is provided, it is used to read the weight from the - file. The file must be a plain text file with 4 whitespace separated - columns: config_weight, energy_weight, forces_weight, and stress_weight. - Length of the file must be equal to the number of configurations, or 1 - (in which case the same weight is used for all configurations). + function. energy_key: Name of the field in extxyz/ase.Atoms that stores the energy. forces_key: Name of the field in extxyz/ase.Atoms that stores the forces. slices: Slice of the configurations to read. It is used only when `path` is @@ -930,42 +881,14 @@ def _read_from_ase( ) if ase_atoms_list: - if isinstance(weight, Path): - weights = np.loadtxt(weight) - if weights.ndim == 1 and len(weights) == 4: - weights = np.tile(weights, (len(ase_atoms_list), 1)) - if weights.ndim == 2 and len(weights) == len(ase_atoms_list): - pass - else: - raise DatasetError( - "Length of weights must be equal to the number of configurations, or 1 " - "(in which case the same weight is used for all configurations)." - ) - weights = [ - Weight( - config_weight=w[0], - energy_weight=w[1], - forces_weight=w[2], - stress_weight=w[3], - ) - for w in weights - ] - else: - weights = [weight] * len(ase_atoms_list) - - if len(ase_atoms_list) != len(weights): - raise DatasetError( - "Length of weights must be equal to the number of configurations, or 1 " - "(in which case the same weight is used for all configurations)." - ) configs = [ Configuration.from_ase_atoms( config, - weight=copy.copy(weight_obj), + weight=weight, energy_key=energy_key, forces_key=forces_key, ) - for config, weight_obj in zip(ase_atoms_list, weights) + for config, weight_obj in zip(ase_atoms_list) ] else: try: @@ -993,76 +916,24 @@ def _read_from_ase( if len(all_files) == 1: # single xyz file with multiple configs all_configs = ase.io.read(all_files[0], index=slices) - # This code fragment is duplicated because in ASE loading, there can be multiple - # branches on how the configurations are loaded, and it is simplest to - # assign weights accordingly per configuration. - if isinstance(weight, Path): - weights = np.loadtxt(weight) - if weights.ndim == 1 and len(weights) == 4: - weights = np.tile(weights, (len(all_configs), 1)) - if weights.ndim == 2 and len(weights) == len(all_configs): - pass - else: - raise DatasetError( - "Length of weights must be equal to the number of configurations, or 1 " - "(in which case the same weight is used for all configurations)." - ) - weights = [ - Weight( - config_weight=w[0], - energy_weight=w[1], - forces_weight=w[2], - stress_weight=w[3], - ) - for w in weights - ] - else: - weights = [weight] * len(all_configs) - configs = [ Configuration.from_ase_atoms( config, - weight=copy.copy(weight_obj), + weight=weight, energy_key=energy_key, forces_key=forces_key, ) - for config, weight_obj in zip(all_configs, weights) + for config, weight_obj in zip(all_configs) ] else: - # This code fragment is duplicated because in ASE loading, there can be multiple - # branches on how the configurations are loaded, and it is simplest to - # assign weights accordingly per configuration. 
- if isinstance(weight, Path): - weights = np.loadtxt(weight) - if weights.ndim == 1 and len(weights) == 4: - weights = np.tile(weights, (len(all_files), 1)) - if weights.ndim == 2 and len(weights) == len(all_files): - pass - else: - raise DatasetError( - "Length of weights must be equal to the number of configurations, or 1 " - "(in which case the same weight is used for all configurations)." - ) - weights = [ - Weight( - config_weight=w[0], - energy_weight=w[1], - forces_weight=w[2], - stress_weight=w[3], - ) - for w in weights - ] - else: - weights = [weight] * len(all_files) - configs = [ Configuration.from_ase_atoms( ase.io.read(f), - weight=copy.copy(w), + weight=weight, energy_key=energy_key, forces_key=forces_key, ) - for f, w in zip(all_files, weights) + for f in zip(all_files) ] if len(configs) <= 0: @@ -1077,7 +948,7 @@ def add_from_ase( self, path: Union[Path, str] = None, ase_atoms_list: List[ase.Atoms] = None, - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, energy_key: str = "energy", forces_key: str = "forces", slices: str = ":", @@ -1116,9 +987,16 @@ def add_from_ase( """ if isinstance(path, str): path = Path(path) - configs = self._read_from_ase( - path, ase_atoms_list, weight, energy_key, forces_key, slices - ) + + if isinstance(weight, Weight): + configs = self._read_from_ase( + path, ase_atoms_list, weight, energy_key, forces_key, slices, file_format + ) + else: + configs = self._read_from_ase( + path, ase_atoms_list, None, energy_key, forces_key, slices, file_format + ) + self.add_weights(weight) self.configs.extend(configs) def get_configs(self) -> List[Configuration]: @@ -1157,6 +1035,379 @@ def __getitem__( configs = [self.configs[i] for i in idx] return Dataset(configs) + def save_weights(self, path: Union[Path, str]): + """ + Save the weights of the configurations to a file. + + Args: + path: Path of the file to save the weights. + """ + path = to_path(path) + with path.open("w") as f: + for config in self.configs: + f.write( + f"{config.weight.config_weight} " + + f"{config.weight.energy_weight} " + + f"{config.weight.forces_weight} " + + f"{config.weight.stress_weight}\n" + ) + + def add_weights(self, path: Union[Path, str]): + """ + Load weights from a text file. The text file should contain 1 to 4 columns, + whitespace seperated, formatted as, + ``` + Config Energy Forces Stress + 1.0 0.0 10.0 0.0 + ``` + ```{note} + The column headers are case-insensitive, but should have same name as above. + The weight of 0.0 will set respective weight as `None`. The length of column can + be either 1 (all configs same weight) or n, where n is the number of configs in + the dataset. + ``` + Missing columns are treated as 0.0, i.e. 
above example file can also be written + as + ``` + Config Forces + 1.0 10.0 + ``` + + Args: + path: Path to the configuration file + + """ + if path is None: + logger.info("No weights provided.") + return + + weights_data = np.genfromtxt(path, names=True) + weights_col = weights_data.dtype.names + + # sanity checks + if 1 > len(weights_col) > 4: + raise DatasetError("Weights file contains improper number of cols," + "there needs to be at least 1 col, and at most 4") + + if not (weights_data.size == 1 or weights_data.size == len(self)): + raise DatasetError("Weights file contains improper number of rows," + "there can be either 1 row (all weights same), " + "or same number of rows as the configurations.") + + expected_cols = {"config", "energy", "forces", "stress"} + missing_cols = expected_cols - set([col.lower() for col in weights_col]) + + # missing weights are set to 0.0 + weights = {k.lower(): weights_data[k] for k in weights_col} + for fields in missing_cols: + weights[fields] = np.zeros_like(weights["config"]) + + # set weights + for i, config in enumerate(self.configs): + config.weight = Weight( + config_weight=weights["config"][i], + energy_weight=weights["energy"][i], + forces_weight=weights["forces"][i], + stress_weight=weights["stress"][i], + ) + + def add_metadata(self, metadata: dict): + """ + Add metadata to the dataset object. + + Args: + metadata: A dictionary containing the metadata. + """ + if not isinstance(metadata, dict): + raise DatasetError("metadata must be a dictionary.") + self._metadata.update(metadata) + + def get_metadata(self, key: str): + """ + Get the metadata of the dataset. + + Args: + key: Key of the metadata to get. + + Returns: + Value of the metadata. + """ + return self._metadata[key] + + @property + def metadata(self): + """ + Return the metadata of the dataset. + """ + return self._metadata + + def check_properties_consistency(self, properties: List[str] = None): + """ + Check which of the properties of the configurations are consistent. These + consistent properties are saved a list which can be used to get the attributes + from the configurations. "Consistent" in this context means that same property + is available for all the configurations. A property is not considered consistent + if it is None for any of the configurations. + + Args: + properties: List of properties to check for consistency. If None, no + properties are checked. All consistent properties are saved in the + metadata. + """ + if properties is None: + logger.warning("No properties provided to check for consistency.") + return + + property_list = list(copy.deepcopy(properties)) # make it mutable, if not + for config in self.configs: + for prop in property_list: + try: + getattr(config, prop) + except ConfigurationError: + property_list.remove(prop) + + self.add_metadata({"consistent_properties": tuple(property_list)}) + logger.info( + f"Consistent properties: {property_list}, stored in metadata key: `consistent_properties`" + ) + + @staticmethod + def get_manifest_checksum(dataset_manifest: DictConfig, transform_manifest: Optional[DictConfig] = None) -> str: + """ + Get the checksum of the dataset manifest. + + Args: + dataset_manifest: Manifest of the dataset. + transform_manifest: Manifest of the transformation. + + Returns: + Checksum of the manifest. 
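
A minimal sketch of what this checksum buys: both manifests are converted to plain dictionaries, serialized deterministically, and hashed, so identical dataset and transform settings always map to the same cached dataset (the keys shown are examples only):

```python
import hashlib
import json

from omegaconf import OmegaConf

dataset_manifest = OmegaConf.create({"type": "path", "path": "my_configs", "shuffle": False})
transform_manifest = OmegaConf.create(
    {"configuration": {"name": "Descriptor", "kwargs": {"cutoff": 3.7}}}
)

blob = json.dumps(OmegaConf.to_object(dataset_manifest), sort_keys=True)
blob += json.dumps(OmegaConf.to_object(transform_manifest), sort_keys=True)
checksum = hashlib.md5(blob.encode()).hexdigest()
print(checksum[:10])  # the saved dataset below is named DS_<first 10 chars>.pkl
```
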
+ """ + dataset_str = json.dumps(OmegaConf.to_object(dataset_manifest), sort_keys=True) + if transform_manifest: + transform_str = json.dumps(OmegaConf.to_object(transform_manifest), sort_keys=True) + dataset_str += transform_str + return hashlib.md5(dataset_str.encode()).hexdigest() + + + @staticmethod + def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: Optional[DictConfig] = None) -> ( + Tuple[Optional["Dataset"], Optional["Dataset"], Optional["Dataset"]]): + """ + Get a dataset from a manifest. + + Examples: + 1. Manifest file for initializing dataset using ASE parser: + ```yaml + dataset: + type: ase # ase or path or colabfit + path: Si.xyz # Path to the dataset + save: True # Save processed dataset to a file + save_path: /folder/to # Save to this folder + shuffle: False # Shuffle the dataset + weights: /path/to/weights.dat # or dictionary with weights + keys: + energy: Energy # Key for energy, if ase dataset is used + forces: forces # Key for forces, if ase dataset is used + training_dataset: + train_size: 3 # Number of training samples + train_indices: # files with indices [optional] + val_dataset: + val_size: 1 # Number of validation samples + val_indices: # files with indices [optional] + test_dataset: + test_size: + test_indices: + ``` + + 2. Manifest file for initializing dataset using KLIFF extxyz parser: + ```yaml + dataset: + type: path # ase or path or colabfit + path: /all/my/xyz # Path to the dataset + save: False # Save processed dataset to a file + shuffle: False # Shuffle the dataset + weights: # same weight for all, or file with weights + config: 1.0 + energy: 0.0 + forces: 10.0 + stress: 0.0 + ``` + + 3. Manifest file for initializing dataset using ColabFit parser: + ```yaml + dataset: + type: colabfit # ase or path or colabfit + save: False # Save processed dataset to a file + shuffle: False # Shuffle the dataset + weights: None + colabfit_dataset: + dataset_name: + database_name: + database_url: + ``` + + For dataset transformation, the transform manifest should be provided. Example, + ```yaml + property: + - energy: + name: NormalizedPropertyTransform + kwargs: + keep_original: True + - forces: + name: RMSNormalizePropertyTransform + kwargs: + keep_original: True + + configuration: # optional: generate fingerprints from the configuration + name: Descriptor # Graph, Descriptor, None + kwargs: + cutoff: 3.7 + species: ["Si"] + descriptor: "SymmetryFunctions" + hyperparameters: set51 + ``` + + TODO: Cross-validation splits, stratified splits, etc. + + Args: + dataset_manifest: List of configurations. + transform_manifest: List of configurations. + + Returns: + A dataset of configurations. 
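
A hedged usage sketch for this entry point as defined at this stage of the series, assuming a YAML file laid out like the examples above (the file name is a placeholder; the validation and test splits come back as `None` when their sizes are zero):

```python
from omegaconf import OmegaConf

from kliff.dataset import Dataset

manifest = OmegaConf.load("training_manifest.yaml")

train_ds, val_ds, test_ds = Dataset.get_datasets_from_manifest(
    manifest.dataset,                  # the `dataset:` block
    manifest.get("transforms", None),  # optional property/configuration transform block
)
print(len(train_ds), len(val_ds) if val_ds else 0, len(test_ds) if test_ds else 0)
```
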
+ """ + dataset_type = dataset_manifest.type.tolower() + if dataset_type != "ase" and dataset_type != "path" and dataset_type != "colabfit": + raise DatasetError(f"Dataset type {dataset_type} not supported.") + weights = dataset_manifest.weights + if weights is not None: + if isinstance(weights, str): + weights = Path(weights) + elif isinstance(weights, DictConfig): + weights = Weight( + config_weight=weights.get("config", 0.0), + energy_weight=weights.get("energy", 0.0), + forces_weight=weights.get("forces", 0.0), + stress_weight=weights.get("stress", 0.0), + ) + else: + raise DatasetError("Weights must be a path or a dictionary.") + + if dataset_type == "ase": + dataset = Dataset.from_ase( + path=dataset_manifest.get("path", "."), + weight=weights, + energy_key=dataset_manifest.get("keys", {}).get("energy", "energy"), + forces_key=dataset_manifest.get("keys", {}).get("forces", "forces"), + ) + elif dataset_type == "path": + dataset = Dataset.from_path( + path=dataset_manifest.get("path", "."), + weight=weights, + ) + elif dataset_type == "colabfit": + try: + colabfit_dataset = dataset_manifest.colabfit_dataset + colabfit_database = colabfit_dataset.database_name + except ConfigAttributeError: + raise DatasetError("Colabfit dataset or database not provided.") + colabfit_uri = dataset_manifest.get("colabfit_uri", "mongodb://localhost:27017") + + dataset = Dataset.from_colabfit( + colabfit_database=colabfit_database, + colabfit_dataset=colabfit_dataset, + colabfit_uri=colabfit_uri, + weight=weights, + ) + else: + # this should not happen + raise DatasetError(f"Dataset type {dataset_type} not supported.") + + # transforms? + if transform_manifest: + configuration_transform = transform_manifest.get("configuration", None) + property_transform = transform_manifest.get("property", None) + + if property_transform: + for property_to_transform in property_transform: + property_name = list(property_to_transform.keys())[0] + transform_module_name = property_to_transform[property_name].get("name", None) + if not transform_module_name: + raise DatasetError("Property transform module name not provided.") + property_transform_module = importlib.import_module( + f"kliff.transforms.property_transforms" + ) + property_module = getattr(property_transform_module, transform_module_name) + property_module = property_module(proprty_key=property_name, **property_to_transform[property_name].get("kwargs", {})) + dataset = property_module(dataset) + + if configuration_transform: + configuration_module_name = configuration_transform.get("name", None) + if not configuration_module_name: + raise DatasetError("Configuration transform module name not provided.") + configuration_transform_module = importlib.import_module( + f"kliff.transforms.configuration_transforms" + ) + configuration_module = getattr(configuration_transform_module, configuration_module_name) + kwargs = configuration_transform.get("kwargs", None) + if not kwargs: + raise DatasetError("Configuration transform module options not provided.") + configuration_module = configuration_module(**kwargs, copy_to_config=True) + + for config in dataset.configs: + _ = configuration_module(config) + + # dataset hash + dataset_checksum = Dataset.get_manifest_checksum(dataset_manifest, transform_manifest) + dataset.add_metadata({"checksum": dataset_checksum}) + + if dataset_manifest.get("save", False): + # TODO: use Path for compatibility + dataset_save_path = dataset_manifest.get("save_path", "./") + logger.info(f"Saving dataset to 
{dataset_save_path}/DS_{dataset_checksum[:10]}.pkl") + dill.dump(dataset, open(f"{dataset_save_path}/DS_{dataset_checksum[:10]}.pkl", "wb")) + + # test train splits + train_size = dataset_manifest.training_dataset.get("train_size", len(dataset)) + val_size = dataset_manifest.val_dataset.get("val_size", 0) + test_size = dataset_manifest.test_dataset.get("test_size", 0) + if train_size + val_size + test_size > len(dataset): + raise DatasetError("Sum of train, val, and test sizes is greater than the dataset size.") + + # check if indices are provided + train_indices = dataset_manifest.training_dataset.get("train_indices", np.arange(train_size)) + val_indices = dataset_manifest.val_dataset.get("val_indices", train_size + np.arange(val_size)) + test_indices = dataset_manifest.test_dataset.get("test_indices", train_size + val_size + np.arange(test_size)) + + if isinstance(train_indices, str): + train_indices = np.genfromtxt(train_indices, dtype=int) + if isinstance(val_indices, str): + val_indices = np.genfromtxt(val_indices, dtype=int) + if isinstance(test_indices, str): + test_indices = np.genfromtxt(test_indices, dtype=int) + + if dataset_manifest.get("shuffle", False): + np.random.shuffle(train_indices) + np.random.shuffle(val_indices) + np.random.shuffle(test_indices) + + if train_size > 0: + train_dataset = Dataset([dataset[i] for i in train_indices]) + else: + train_dataset = dataset + + if val_size > 0: + val_dataset = Dataset([dataset[i] for i in val_indices]) + else: + val_dataset = None + if test_size > 0: + test_dataset = Dataset([dataset[i] for i in test_indices]) + else: + test_dataset = None + + return train_dataset, val_dataset, test_dataset + class ConfigurationError(Exception): def __init__(self, msg): diff --git a/kliff/dataset/weight.py b/kliff/dataset/weight.py index dfd2727d..fe5caa7c 100644 --- a/kliff/dataset/weight.py +++ b/kliff/dataset/weight.py @@ -44,18 +44,34 @@ def compute_weight(self, config): def config_weight(self): return self._config_weight + @config_weight.setter + def config_weight(self, value): + self._config_weight = value + @property def energy_weight(self): return self._energy_weight + @energy_weight.setter + def energy_weight(self, value): + self._energy_weight = value + @property def forces_weight(self): return self._forces_weight + @forces_weight.setter + def forces_weight(self, value): + self._forces_weight = value + @property def stress_weight(self): return self._stress_weight + @stress_weight.setter + def stress_weight(self, value): + self._stress_weight = value + def _check_compute_flag(self, config): """ Check whether compute flag correctly set when the corresponding weight in diff --git a/kliff/models/kim.py b/kliff/models/kim.py index 7e2622fb..fa5647aa 100644 --- a/kliff/models/kim.py +++ b/kliff/models/kim.py @@ -697,14 +697,14 @@ def __call__( return kim_ca_instance.results @staticmethod - def get_model_from_config(model_config: DictConfig, param_config: DictConfig = None): + def get_model_from_manifest(model_manifest: DictConfig, param_manifest: DictConfig = None): """ Get the model from a configuration. If it is a valid KIM model, it will return the KIMModel object. If it is a TorchML model, it will return the torch ReverseScriptedModule object *in future*. Else raise error. If the model is a tarball, it will extract and install the model. 
- Example `model_config`: + Example `model_manifest`: ```yaml model: model_type: kim # kim or torch or tar @@ -713,12 +713,12 @@ def get_model_from_config(model_config: DictConfig, param_config: DictConfig = N model_collection: "user" ``` - Example `param_config`: + Example `param_manifest`: ```yaml parameter: parameter_list: # optional for KIM models, list of parameters to optimize - - A # dict means the parameter is transformed - - B # these are the parameters that are not transformed + - A # dict means the parameter is transformed + - B # these are the parameters that are not transformed - sigma: transform_name: LogParameterTransform value: 2.0 @@ -731,17 +731,17 @@ def get_model_from_config(model_config: DictConfig, param_config: DictConfig = N ``` Args: - model_config: configuration object - param_config: parameter transformation configuration + model_manifest: configuration object + param_manifest: parameter transformation configuration Returns: Model object """ - model_name = model_config.model_name - model_type = model_config.model_type - model_path = model_config.model_path + model_name = model_manifest.model_name + model_type = model_manifest.model_type + model_path = model_manifest.model_path model_driver = KIMModel.get_model_driver_name(model_name) - model_collection = model_config.model_collection + model_collection = model_manifest.model_collection if model_driver in UNSUPPORTED_MODEL_DRIVERS: logger.error("Model driver not supported for KIM-API based training. " @@ -778,8 +778,8 @@ def get_model_from_config(model_config: DictConfig, param_config: DictConfig = N model = KIMModel(model_name) - if param_config: - parameter_list = param_config.parameter_list + if param_manifest: + parameter_list = param_manifest.parameter_list mutable_param_list = [] for param_to_transform in parameter_list: if isinstance(param_to_transform, dict): From e1ef24e2a3155068b4dfde913ccf814aed4d378d Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Tue, 16 Apr 2024 18:11:02 -0500 Subject: [PATCH 7/8] Moved back from omegaconf to dict --- kliff/_exceptions.py | 13 - kliff/dataset/dataset.py | 110 +--- kliff/models/kim.py | 17 +- kliff/trainer/__init__.py | 2 +- kliff/trainer/kliff_trainer.py | 857 +++++++++------------------ kliff/trainer/option_enumerations.py | 225 ------- 6 files changed, 324 insertions(+), 900 deletions(-) delete mode 100644 kliff/_exceptions.py delete mode 100644 kliff/trainer/option_enumerations.py diff --git a/kliff/_exceptions.py b/kliff/_exceptions.py deleted file mode 100644 index de1c8288..00000000 --- a/kliff/_exceptions.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -This module contains exceptions to be raised in kliff modules, along with details on -where they are raised. -""" - - -class TrainerError(Exception): - """ - Exceptions to be raised in Trainer and associated classes. 
- """ - - def __init__(self, message): - super().__init__(message) diff --git a/kliff/dataset/dataset.py b/kliff/dataset/dataset.py index b1f58363..4dc856ba 100644 --- a/kliff/dataset/dataset.py +++ b/kliff/dataset/dataset.py @@ -9,8 +9,6 @@ import numpy as np from loguru import logger from monty.dev import requires -from omegaconf import OmegaConf, DictConfig -from omegaconf.errors import ConfigAttributeError from kliff.dataset.extxyz import read_extxyz, write_extxyz from kliff.dataset.weight import Weight @@ -923,7 +921,7 @@ def _read_from_ase( energy_key=energy_key, forces_key=forces_key, ) - for config, weight_obj in zip(all_configs) + for config in all_configs ] else: configs = [ @@ -933,7 +931,7 @@ def _read_from_ase( energy_key=energy_key, forces_key=forces_key, ) - for f in zip(all_files) + for f in all_files ] if len(configs) <= 0: @@ -1172,7 +1170,7 @@ def check_properties_consistency(self, properties: List[str] = None): ) @staticmethod - def get_manifest_checksum(dataset_manifest: DictConfig, transform_manifest: Optional[DictConfig] = None) -> str: + def get_manifest_checksum(dataset_manifest: dict, transform_manifest: Optional[dict] = None) -> str: """ Get the checksum of the dataset manifest. @@ -1183,16 +1181,15 @@ def get_manifest_checksum(dataset_manifest: DictConfig, transform_manifest: Opti Returns: Checksum of the manifest. """ - dataset_str = json.dumps(OmegaConf.to_object(dataset_manifest), sort_keys=True) + dataset_str = json.dumps(dataset_manifest, sort_keys=True) if transform_manifest: - transform_str = json.dumps(OmegaConf.to_object(transform_manifest), sort_keys=True) + transform_str = json.dumps(transform_manifest, sort_keys=True) dataset_str += transform_str return hashlib.md5(dataset_str.encode()).hexdigest() - @staticmethod - def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: Optional[DictConfig] = None) -> ( - Tuple[Optional["Dataset"], Optional["Dataset"], Optional["Dataset"]]): + def get_dataset_from_manifest(dataset_manifest: dict, transform_manifest: Optional[dict] = None) -> ( + "Dataset"): """ Get a dataset from a manifest. @@ -1209,15 +1206,6 @@ def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: keys: energy: Energy # Key for energy, if ase dataset is used forces: forces # Key for forces, if ase dataset is used - training_dataset: - train_size: 3 # Number of training samples - train_indices: # files with indices [optional] - val_dataset: - val_size: 1 # Number of validation samples - val_indices: # files with indices [optional] - test_dataset: - test_size: - test_indices: ``` 2. Manifest file for initializing dataset using KLIFF extxyz parser: @@ -1277,14 +1265,14 @@ def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: Returns: A dataset of configurations. 
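After this change the same flow works with plain dictionaries, for example as produced by `yaml.safe_load`; a minimal sketch, again with an illustrative file name and top-level keys:

```python
import yaml

from kliff.dataset import Dataset

# "train_manifest.yaml" is hypothetical; only the dataset/transforms sections
# shown in the docstring examples are assumed to exist in it.
with open("train_manifest.yaml") as f:
    manifest = yaml.safe_load(f)

dataset = Dataset.get_dataset_from_manifest(
    manifest["dataset"], transform_manifest=manifest.get("transforms")
)
```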
""" - dataset_type = dataset_manifest.type.tolower() + dataset_type = dataset_manifest.get("type").lower() if dataset_type != "ase" and dataset_type != "path" and dataset_type != "colabfit": raise DatasetError(f"Dataset type {dataset_type} not supported.") - weights = dataset_manifest.weights + weights = dataset_manifest.get("weights", None) if weights is not None: if isinstance(weights, str): weights = Path(weights) - elif isinstance(weights, DictConfig): + elif isinstance(weights, dict): weights = Weight( config_weight=weights.get("config", 0.0), energy_weight=weights.get("energy", 0.0), @@ -1308,9 +1296,9 @@ def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: ) elif dataset_type == "colabfit": try: - colabfit_dataset = dataset_manifest.colabfit_dataset + colabfit_dataset = dataset_manifest.get("colabfit_dataset") colabfit_database = colabfit_dataset.database_name - except ConfigAttributeError: + except KeyError: raise DatasetError("Colabfit dataset or database not provided.") colabfit_uri = dataset_manifest.get("colabfit_uri", "mongodb://localhost:27017") @@ -1326,12 +1314,14 @@ def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: # transforms? if transform_manifest: - configuration_transform = transform_manifest.get("configuration", None) - property_transform = transform_manifest.get("property", None) + configuration_transform: Union[dict, None] = transform_manifest.get("configuration", None) + property_transform: Union[list, None] = transform_manifest.get("property", None) if property_transform: for property_to_transform in property_transform: - property_name = list(property_to_transform.keys())[0] + property_name = property_to_transform.get("name", None) + if not property_name: + continue # it is probably an empty propery transform_module_name = property_to_transform[property_name].get("name", None) if not transform_module_name: raise DatasetError("Property transform module name not provided.") @@ -1343,20 +1333,22 @@ def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: dataset = property_module(dataset) if configuration_transform: - configuration_module_name = configuration_transform.get("name", None) + configuration_module_name: Union[str, None] = configuration_transform.get("name", None) if not configuration_module_name: - raise DatasetError("Configuration transform module name not provided.") - configuration_transform_module = importlib.import_module( - f"kliff.transforms.configuration_transforms" - ) - configuration_module = getattr(configuration_transform_module, configuration_module_name) - kwargs = configuration_transform.get("kwargs", None) - if not kwargs: - raise DatasetError("Configuration transform module options not provided.") - configuration_module = configuration_module(**kwargs, copy_to_config=True) + logger.warning("Configuration transform module name not provided." 
+ "Skipping configuration transform.") + else: + configuration_transform_module = importlib.import_module( + f"kliff.transforms.configuration_transforms" + ) + configuration_module = getattr(configuration_transform_module, configuration_module_name) + kwargs: Union[dict, None] = configuration_transform.get("kwargs", None) + if not kwargs: + raise DatasetError("Configuration transform module options not provided.") + configuration_module = configuration_module(**kwargs, copy_to_config=True) - for config in dataset.configs: - _ = configuration_module(config) + for config in dataset.configs: + _ = configuration_module(config) # dataset hash dataset_checksum = Dataset.get_manifest_checksum(dataset_manifest, transform_manifest) @@ -1368,45 +1360,7 @@ def get_datasets_from_manifest(dataset_manifest: DictConfig, transform_manifest: logger.info(f"Saving dataset to {dataset_save_path}/DS_{dataset_checksum[:10]}.pkl") dill.dump(dataset, open(f"{dataset_save_path}/DS_{dataset_checksum[:10]}.pkl", "wb")) - # test train splits - train_size = dataset_manifest.training_dataset.get("train_size", len(dataset)) - val_size = dataset_manifest.val_dataset.get("val_size", 0) - test_size = dataset_manifest.test_dataset.get("test_size", 0) - if train_size + val_size + test_size > len(dataset): - raise DatasetError("Sum of train, val, and test sizes is greater than the dataset size.") - - # check if indices are provided - train_indices = dataset_manifest.training_dataset.get("train_indices", np.arange(train_size)) - val_indices = dataset_manifest.val_dataset.get("val_indices", train_size + np.arange(val_size)) - test_indices = dataset_manifest.test_dataset.get("test_indices", train_size + val_size + np.arange(test_size)) - - if isinstance(train_indices, str): - train_indices = np.genfromtxt(train_indices, dtype=int) - if isinstance(val_indices, str): - val_indices = np.genfromtxt(val_indices, dtype=int) - if isinstance(test_indices, str): - test_indices = np.genfromtxt(test_indices, dtype=int) - - if dataset_manifest.get("shuffle", False): - np.random.shuffle(train_indices) - np.random.shuffle(val_indices) - np.random.shuffle(test_indices) - - if train_size > 0: - train_dataset = Dataset([dataset[i] for i in train_indices]) - else: - train_dataset = dataset - - if val_size > 0: - val_dataset = Dataset([dataset[i] for i in val_indices]) - else: - val_dataset = None - if test_size > 0: - test_dataset = Dataset([dataset[i] for i in test_indices]) - else: - test_dataset = None - - return train_dataset, val_dataset, test_dataset + return dataset class ConfigurationError(Exception): diff --git a/kliff/models/kim.py b/kliff/models/kim.py index fa5647aa..84f001f4 100644 --- a/kliff/models/kim.py +++ b/kliff/models/kim.py @@ -16,7 +16,6 @@ from kliff.neighbor import assemble_forces, assemble_stress from kliff.utils import install_kim_model, is_kim_model_installed -from omegaconf import DictConfig, OmegaConf import kimpy import tarfile @@ -697,7 +696,7 @@ def __call__( return kim_ca_instance.results @staticmethod - def get_model_from_manifest(model_manifest: DictConfig, param_manifest: DictConfig = None): + def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): """ Get the model from a configuration. If it is a valid KIM model, it will return the KIMModel object. 
If it is a TorchML model, it will return the torch @@ -716,7 +715,6 @@ def get_model_from_manifest(model_manifest: DictConfig, param_manifest: DictConf Example `param_manifest`: ```yaml parameter: - parameter_list: # optional for KIM models, list of parameters to optimize - A # dict means the parameter is transformed - B # these are the parameters that are not transformed - sigma: @@ -737,11 +735,11 @@ def get_model_from_manifest(model_manifest: DictConfig, param_manifest: DictConf Returns: Model object """ - model_name = model_manifest.model_name - model_type = model_manifest.model_type - model_path = model_manifest.model_path + model_name: Union[None, str] = model_manifest.get("model_name", None) + model_type: Union[None, str] = model_manifest.get("model_type", None) + model_path: Union[None, str, Path] = model_manifest.get("model_path", None) model_driver = KIMModel.get_model_driver_name(model_name) - model_collection = model_manifest.model_collection + model_collection = model_manifest.get("model_collection") if model_driver in UNSUPPORTED_MODEL_DRIVERS: logger.error("Model driver not supported for KIM-API based training. " @@ -779,9 +777,8 @@ def get_model_from_manifest(model_manifest: DictConfig, param_manifest: DictConf model = KIMModel(model_name) if param_manifest: - parameter_list = param_manifest.parameter_list mutable_param_list = [] - for param_to_transform in parameter_list: + for param_to_transform in param_manifest: if isinstance(param_to_transform, dict): parameter_name = list(param_to_transform.keys())[0] elif isinstance(param_to_transform, str): @@ -794,7 +791,7 @@ def get_model_from_manifest(model_manifest: DictConfig, param_manifest: DictConf model_param_list = model.parameters() # apply transforms if needed - for model_params, input_params in zip(model_param_list, parameter_list): + for model_params, input_params in zip(model_param_list, param_manifest): if isinstance(input_params, dict): param_name = list(input_params.keys())[0] if param_name != model_params.name: diff --git a/kliff/trainer/__init__.py b/kliff/trainer/__init__.py index 18f04637..24372b3f 100644 --- a/kliff/trainer/__init__.py +++ b/kliff/trainer/__init__.py @@ -1,2 +1,2 @@ -from .kim_trainer import KIMTrainer +# from .kim_trainer import KIMTrainer from .kliff_trainer import Trainer diff --git a/kliff/trainer/kliff_trainer.py b/kliff/trainer/kliff_trainer.py index d8de33ce..19c55f7e 100644 --- a/kliff/trainer/kliff_trainer.py +++ b/kliff/trainer/kliff_trainer.py @@ -15,14 +15,12 @@ from loguru import logger import kliff.transforms.configuration_transforms -from kliff._exceptions import TrainerError from kliff.dataset import Dataset from kliff.transforms.configuration_transforms import ConfigurationTransform from kliff.transforms.parameter_transforms import ParameterTransform from kliff.transforms.property_transforms import PropertyTransform from ..dataset.weight import Weight -from .option_enumerations import DataSource, ModelTypes, OptimizerProvider class Trainer: @@ -44,419 +42,228 @@ class will provide the basic functionality for training, such as setting up the implement. 
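The manifest-driven model loading changed above can be exercised with plain Python containers once this patch lands; a hedged sketch, where the KIM model ID is only an example of an installed model and the parameter entries mirror the YAML shown in the docstring:

```python
from kliff.models.kim import KIMModel

# Illustrative model manifest; any installed KIM model ID would do here.
model_manifest = {
    "model_name": "SW_StillingerWeber_1985_Si__MO_405512056662_006",
    "model_type": "kim",
    "model_path": None,
    "model_collection": "user",
}

# The parameter manifest is the list under the `parameter:` key: plain strings
# mark parameters to optimize as-is, while a dict entry attaches a transform.
param_manifest = [
    "A",
    "B",
    {"sigma": {"transform_name": "LogParameterTransform", "value": 2.0}},
]

model = KIMModel.get_model_from_manifest(model_manifest, param_manifest)
```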
Args: - configuration: configuration dictionary + training_manifest: training manifest """ - def __init__(self, configuration: dict): + def __init__(self, training_manifest: dict): # workspace variables - self.workspace_name = None # name of default directory, root - self.workspace_name = ( - None # where to save everything from current run (inside workspace) - ) - self.current_run_title = ( - None # title of current run, usually model name + date and time - ) - self.export_kim_model = False # whether to export the model to KIM model - self.seed = 12345 # random seed - self.resume = False # whether to resume from previous run (conditions apply) - self.walltime = None # maximum walltime for the run + self.workspace: dict = { + "name": "kliff_workspace", + "seed": 12345, + "resume": False, + "walltime": "2:00:00:00", + } # dataset variables - self.dataset_type: DataSource = DataSource.UNDEFINED - self.dataset_path = None - self.dataset_save = None - self.dataset_shuffle = None - self.dataset = None - self.train_size = None - self.val_size = None - self.indices_files: dict = {"train": None, "val": None} - self.ase_keys = {"energy_key": "energy", "forces_key": "forces"} - self.val_indices = None - self.train_indices = None - self.train_dataset = None - self.val_dataset = None - self.colabfit_dataset: dict = { - "dataset_name": None, - "database_name": None, - "database_url": None, + self.dataset_manifest: dict = { + "type": "kliff", + "path": "./", + "save": False, + "shuffle": False, + "ase_keys": {"energy": "energy", "forces": "forces"}, + "colabfit_dataset": { + "dataset_name": None, + "database_name": None, + "database_url": None, + }, } + self.dataset = None # model variables - self.model_type: ModelTypes = ModelTypes.UNDEFINED + self.model_manifest: dict = { + "type": "kim", + "name": None, + "path": None, + "instance": None, + } self.model: Callable = None - self.model_name = None # KIM string or name of pt/pth file - self.model_path = None # path to the model file # transform variables + self.transform_manifest: dict = { + "property": [{ + "name": None, + "property_key": None, + }], + "parameter": [], + "configuration": { + "name": None, + "kwargs": None, + }, + } self.property_transform: PropertyTransform = None - self.property_transform_options = None self.parameter_transform: ParameterTransform = None - self.parameter_transform_options = None self.configuration_transform: ConfigurationTransform = None - self.configuration_transform_options = None # training variables - self.loss_function: Callable = None - self.loss_function_name = None - self.loss_weights: Union[dict, Path] = { - "energy": 1.0, - "forces": None, - "stress": None, - "config": 1.0, + # this is too complicated to put it in singe dict, therefore the training + # block is divided into loss, optimizer, dataset_sample + self.training_manifest: dict = {} + self.loss_manifest: dict = { + "function": "mse", + "weights": { + "energy": 1.0, + "forces": None, + "stress": None, + "config": 1.0, + }, + "normalize_per_atom": False, + "loss_traj": False, + } + self.optimizer_manifest: dict = { + "provider": "scipy", + "name": None, + "learning_rate": None, + "kwargs": None, + "epochs": 10000, + "stop_condition": None, + "num_workers": 1, } - self.loss_normalize_per_atom = False - - self.optimizer_provider: OptimizerProvider = OptimizerProvider.UNDEFINED - self.optimizer = None # instance of optimizer, "scipy" for scipy torch.optim instance for torch - self.optimizer_name = ( - None # name of optimizer, e.g. 
"l-bfgs-b" for scipy, "adam" for torch - ) - self.optimizer_kwargs = ( - None # kwargs for optimizer to bundle all unknown kwargs - ) - self.learning_rate = None # learning rate for torch - - self.max_epochs = 10000 # maximum number of epochs - self.device = "cpu" - self.batch_size = 1 - self.chkpt_interval = 100 - self.stop_condition = None # function to check if training should stop - self.num_workers = 1 # number of workers for data loading - self.verbose = True # whether to print verbose output - # export trained model - self.export_kim_model = False - self.export_model_type = None - self.export_model_name = None - self.export_model_path = None + # part of current? + self.dataset_sample_manifest: dict = { + "train_size": None, + "val_size": None, + "indices_files": {"train": None, "val": None}, + "val_indices": None, + "train_indices": None, + "train_dataset": None, + "val_dataset": None, + } - self.configuration = self.config_from_dict(configuration) + # export trained model + self.export_manifest: dict = { + "model_type": None, + "model_name": None, + "model_path": None, + } # state variables - self.current_epoch = 0 - self.current_step = 0 - self.current_best_loss = None - self.current_best_model = None - self.current_loss = None - self.current_run_dir = None - self.appending_to_previous_run = False - self.current_dataset_hash = None - self.start_current_run_title = None # start time of the current run - self.expected_end_time = None - self.warned_once = False # log all one time warnings - + self.current: dict = { + "run_title": None, + "run_dir": None, + "run_hash": None, + "start_time": None, + "end_time": None, + "best_loss": None, + "best_model": None, + "loss": None, + "epoch": 0, + "step": 0, + "expected_end_time": None, + "warned_once": False, + "dataset_hash": None, + "appending_to_previous_run": False, + "verbose": False, + "chkpt_interval": 100, + } + self.parse_manifest(training_manifest) self.initialize() - @staticmethod - def config_from_dict(configuration: dict): + def parse_manifest(self, manifest: dict): """ - It accepts the raw configuration dictionary, and processes it to the formatted - configuration. This includes mapping the string fields to enums, and setting sane + It accepts the raw manifest dictionary, and processes it to the formatted + manifest. This includes mapping the string fields to enums, and setting sane defaults for missing fields. 
Args: - configuration: raw incoming dictionary + manifest: raw incoming configuration Returns: - Processed configuration dictionary + Processed manifest """ + _date_time_format = "%Y-%m-%d-%H-%M-%S" start_time = datetime.now() - date_time_str = start_time.strftime("%Y-%m-%d-%H-%M-%S") - processed_configuration = {} + date_time_str = start_time.strftime(_date_time_format) + self.current["start_time"] = start_time # Workspace variables ################################################ - workspace_block = configuration.get("workspace", {}) - if workspace_block == {}: - raise TrainerError("Workspace block not defined in trainer configuration") - processed_configuration["start_time"] = start_time - processed_configuration["workspace_name"] = workspace_block.get( - "name", f"kliff_{date_time_str}" - ) - processed_configuration["current_run_title"] = ( - None # will be assigned in the model block - ) - processed_configuration["export_kim_model"] = workspace_block.get( - "export", False - ) - processed_configuration["seed"] = workspace_block.get("seed", 12345) - processed_configuration["resume"] = workspace_block.get("resume", False) - walltime: Union[str, int] = workspace_block.get("walltime", "2:00:00:00") - if type(walltime) is int: # yaml parsed the time - processed_configuration["walltime"] = timedelta(seconds=walltime) - elif type(walltime) is str: - processed_configuration["walltime"] = timedelta( - days=int(walltime.split(":")[0]), - hours=int(walltime.split(":")[1]), - minutes=int(walltime.split(":")[2]), - seconds=int(walltime.split(":")[3]), - ) + workspace_block: Union[None, dict] = manifest.get("workspace", None) + if workspace_block is None: + logger.warning("Workspace block not found in the configuration. Using default values.") else: - raise TrainerError("Walltime not in correct format. 
dd:hh:mm:ss expected.") - processed_configuration["expected_end_time"] = ( - start_time + processed_configuration["walltime"] - ) - - # Dataset variables ################################################# - dataset_block = configuration.get("dataset", {}) - if dataset_block == {}: - raise TrainerError( - "Dataset block is missing from the trainer configuration file" - ) - - processed_configuration["dataset_type"] = DataSource.get_data_enum( - dataset_block.get("type", "kliff") - ) - processed_configuration["dataset_path"] = dataset_block.get("path", None) - processed_configuration["dataset_save"] = dataset_block.get("save", False) - processed_configuration["dataset_shuffle"] = dataset_block.get("shuffle", False) - ase_keys = dataset_block.get("keys", {}) - processed_configuration["ase_keys"] = { - "energy_key": ase_keys.get("energy", "energy"), - "forces_key": ase_keys.get("forces", "forces"), - } - train_dataset_info = dataset_block.get("training_dataset", {}) - - # none values will be tackled during dataset loading - processed_configuration["train_size"] = train_dataset_info.get( - "train_size", None - ) - processed_configuration["train_indices"] = train_dataset_info.get( - "train_indices", None - ) - - val_dataset_info = dataset_block.get("validation_dataset", {}) - processed_configuration["val_size"] = val_dataset_info.get("val_size", None) - processed_configuration["val_indices"] = val_dataset_info.get( - "val_indices", None - ) - - processed_configuration["indices_file"] = {"train": None, "val": None} - - if type(processed_configuration["train_indices"]) is str: - processed_configuration["indices_file"] = processed_configuration[ - "train_indices" - ] - - if type(processed_configuration["val_indices"]) is str: - processed_configuration["indices_file"] = processed_configuration[ - "val_indices" - ] - - processed_configuration["train_dataset"] = None # To be assigned - processed_configuration["val_dataset"] = None # To be assigned - processed_configuration["dataset"] = None # To be assigned - colabfit_dict = dataset_block.get("colabfit_dataset", {}) - - processed_configuration["colabfit_dataset"] = { - "dataset_name": colabfit_dict.get("dataset_name", None), - "database_name": colabfit_dict.get("database_name", None), - "database_url": colabfit_dict.get("database_url", None), - } + self.workspace |= workspace_block + + if isinstance(self.workspace["walltime"], int): + expected_end_time = datetime.now() + timedelta(seconds=self.workspace["walltime"]) + else: + expected_end_time = datetime.now() + timedelta( + days=int(self.workspace["walltime"].split(":")[0]), + hours=int(self.workspace["walltime"].split(":")[1]), + minutes=int(self.workspace["walltime"].split(":")[2]), + seconds=int(self.workspace["walltime"].split(":")[3]), + ) + self.current["expected_end_time"] = expected_end_time + + # Dataset manifest ################################################# + dataset_manifest: Union[None, dict] = manifest.get("dataset", None) + if dataset_manifest is None: + raise TrainerError("Dataset block not found in the configuration. 
Exiting.") + + self.dataset_manifest |= dataset_manifest # model variables #################################################### - model_block = configuration.get("model", {}) - processed_configuration["model_type"] = ModelTypes.get_model_enum( - model_block.get("model_type", "kim") - ) - processed_configuration["model_name"] = model_block.get("model_name", None) - processed_configuration["model_path"] = model_block.get("model_path", None) - processed_configuration["model"] = None # To be assigned - if processed_configuration["model_name"] is None: - processed_configuration["current_run_title"] = ( - f"{processed_configuration['model_type']}_{date_time_str}" - ) + model_manifest: Union[None, dict] = manifest.get("model", None) + if model_manifest is None: + raise TrainerError("Model block not found in the configuration. Exiting.") + self.model_manifest |= model_manifest + + if self.model_manifest.get("name", None) is None: + self.current["run_title"] = f"{self.model_manifest.get('type')}_{date_time_str}" else: - processed_configuration["current_run_title"] = ( - f"{processed_configuration['model_name']}_{date_time_str}" - ) + self.current["run_title"] = f"{self.model_manifest.get('name')}_{date_time_str}" # transform variables #################################################### - transform_block = configuration.get("transforms", {}) - property_transform_sub_block = transform_block.get("property", {}) - parameter_transform_sub_block = transform_block.get("parameter", {}) - configuration_transform_sub_block = transform_block.get("configuration", {}) - - processed_configuration["property_transform_options"] = { - "name": property_transform_sub_block.get("name", None), - "property_key": property_transform_sub_block.get("property_key", None), - } - processed_configuration["property_transform"] = ( - property_transform_sub_block.get("instance", None) - ) # no executable given. initialize on own + transform_manifest: Union[None, dict] = manifest.get("transforms", None) + if transform_manifest is None: + logger.warning("Transform block not found in the configuration. This is bit unusual.") + else: + self.transform_manifest |= transform_manifest - processed_configuration["parameter_transform_options"] = { - "parameter_list": parameter_transform_sub_block.get("parameter_list", None), - } - processed_configuration["parameter_transform"] = ( - parameter_transform_sub_block.get("instance", None) - ) # no executable given. initialize on own + # training variables ######################################################## + training_manifest: Union[None, dict] = manifest.get("training", None) + if training_manifest is None: + logger.warning("Training block not found in the configuration." + "Will try and resume the previous run if possible.") + # TODO: implement resume + self.training_manifest |= training_manifest - processed_configuration["configuration_transform_options"] = ( - configuration_transform_sub_block # this might contain lot of variables - ) - processed_configuration["configuration_transform"] = ( - configuration_transform_sub_block.get("instance", None) - ) # no executable given. initialize on own + if self.training_manifest.get("loss", None) is None: + logger.warning("Loss block not found in the configuration. 
Using default values.") - # training variables ######################################################## - training_block = configuration.get("training", {}) - loss_block = training_block.get("loss", {}) - processed_configuration["loss_function_name"] = loss_block.get("function", None) - weights = loss_block.get("weights", {}) - if isinstance(weights, str): - processed_configuration["loss_weights"] = Path(weights) - else: - processed_configuration["loss_weights"] = { - "energy": weights.get("energy", 1.0), - "forces": weights.get("forces", None), - "stress": weights.get("stress", None), - "config": weights.get("config", 1.0), - } - processed_configuration["loss_normalize_per_atom"] = loss_block.get( - "normalize_per_atom", False - ) - - optimizer_block = training_block.get("optimizer", {}) - processed_configuration["optimizer_provider"] = ( - OptimizerProvider.get_optimizer_enum( - optimizer_block.get("provider", "scipy") - ) - ) - processed_configuration["optimizer"] = None # To be assigned - processed_configuration["optimizer_name"] = optimizer_block.get("name", None) - processed_configuration["learning_rate"] = optimizer_block.get( - "learning_rate", None - ) - processed_configuration["optimizer_kwargs"] = optimizer_block.get( - "kwargs", None - ) - - processed_configuration["max_epochs"] = training_block.get("max_epochs", 10000) - processed_configuration["device"] = training_block.get("device", "cpu") - processed_configuration["batch_size"] = training_block.get("batch_size", 1) - processed_configuration["chkpt_interval"] = training_block.get( - "chkpt_interval", 100 - ) - processed_configuration["stop_condition"] = training_block.get( - "stop_condition", None - ) - processed_configuration["num_workers"] = training_block.get("num_workers", 1) - processed_configuration["verbose"] = training_block.get("verbose", True) - - # export trained model #################################################### - export_block = configuration.get("export", {}) - processed_configuration["export_kim_model"] = export_block is not {} - if export_block is not {}: - processed_configuration["export_model_type"] = ModelTypes.get_model_enum( - export_block.get("model_type", "kim") - ) - processed_configuration["export_model_name"] = export_block.get( - "model_name", None - ) - processed_configuration["export_model_path"] = export_block.get( - "model_path", None - ) - - return processed_configuration + self.loss_manifest |= self.training_manifest.get("loss") + + if self.training_manifest.get("optimizer", None) is None: + logger.warning("Optimizer block not found in the configuration." + "Will resume the previous run if possible.") + # TODO: implement resume + + self.optimizer_manifest |= self.training_manifest.get("optimizer") + self.optimizer_manifest["epochs"] = self.training_manifest.get("epochs", 10000) + self.optimizer_manifest["stop_condition"] = self.training_manifest.get("stop_condition", None) + self.optimizer_manifest["num_workers"] = self.training_manifest.get("num_workers", 1) + + self.current["chkpt_interval"] = self.training_manifest.get("chkpt_interval", 100) + self.current["verbose"] = self.training_manifest.get("verbose", False) + + # dataset sample variables will be processed in the setup_dataset method + self.export_manifest |= manifest.get("export", {}) def config_to_dict(self): """ Convert the configuration to a dictionary. 
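Pieced together from the defaults above, a complete manifest consumed by `parse_manifest` might look like the following sketch. Only the block names (workspace, dataset, model, transforms, training, export) come from the code; every value is illustrative rather than a canonical schema:

```yaml
workspace:
  name: kliff_workspace
  seed: 12345
  walltime: "2:00:00:00"      # dd:hh:mm:ss, or a plain number of seconds
  resume: False

dataset:
  type: path                  # ase, path, or colabfit
  path: Si_training_set       # illustrative path
  shuffle: False

model:
  type: kim
  name: SW_StillingerWeber_1985_Si__MO_405512056662_006   # example model

transforms:
  parameter:
    - A
    - B

training:
  loss:
    function: mse
    weights:
      energy: 1.0
      forces: 0.5
  optimizer:
    provider: scipy
    name: L-BFGS-B            # illustrative optimizer name
  epochs: 1000
  num_workers: 1
  chkpt_interval: 100
  verbose: False

export:
  model_type: kim
  model_name: SW_StillingerWeber_trained   # illustrative
  model_path: ./
```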
""" config = {} - config["workspace"] = { - "name": self.workspace_name, - "export": self.export_kim_model, - "seed": self.seed, - "resume": self.resume, - "walltime": self.walltime.total_seconds(), - } - - config["dataset"] = { - "type": DataSource.get_data_str(self.dataset_type), - "path": self.dataset_path, - "save": self.dataset_save, - "shuffle": self.dataset_shuffle, - "training_dataset": { - "train_size": self.train_size, - "train_indices": self.indices_files["train"], - }, - "validation_dataset": { - "val_size": self.val_size, - "val_indices": self.indices_files["val"], - }, - "colabfit_dataset": { - "dataset_name": self.colabfit_dataset["dataset_name"], - "database_name": self.colabfit_dataset["database_name"], - "database_url": self.colabfit_dataset["database_url"], - }, - } - if self.ase_keys is not None: - config["dataset"]["keys"] = { - "energy": self.ase_keys["energy_key"], - "forces": self.ase_keys["forces_key"], - } - - config["model"] = { - "model_type": ModelTypes.get_model_str(self.model_type), - "model_name": self.model_name, - "model_path": self.model_path, - } - - config["transforms"] = { - "property": { - "name": self.property_transform_options["name"], - "property_key": self.property_transform_options["property_key"], - }, - "parameter": { - "parameter_list": self.parameter_transform_options["parameter_list"], - }, - "configuration": { - "name": self.configuration_transform_options["name"], - "kwargs": self.configuration_transform_options, - }, - } - - config["training"] = { - "loss": { - "loss_function": self.loss_function_name, - "weights": { - "energy": self.loss_weights["energy"], - "forces": self.loss_weights["forces"], - "stress": self.loss_weights["stress"], - "config": self.loss_weights["config"], - }, - "normalize_per_atom": self.loss_normalize_per_atom, - }, - "optimizer": { - "provider": OptimizerProvider.get_optimizer_str( - self.optimizer_provider - ), - "name": self.optimizer_name, - "learning_rate": self.learning_rate, - "kwargs": self.optimizer_kwargs, - }, - "epochs": self.max_epochs, - "device": self.device, - "batch_size": self.batch_size, - "chkpt_interval": self.chkpt_interval, - "stop_condition": self.stop_condition, - "num_workers": self.num_workers, - "verbose": self.verbose, - } - if self.export_kim_model: - config["export"] = { - "model_type": ModelTypes.get_model_str(self.export_model_type), - "model_name": self.export_model_name, - "model_path": self.export_model_path, - } - + config |= self.workspace + config |= self.dataset_manifest + config |= self.model_manifest + config |= self.transform_manifest + config |= self.training_manifest return config @classmethod def from_file(cls, filename: Path): """ - Load the configuration from a YAML file. + Load the manifest from a YAML file. Args: filename: name of the yaml file @@ -465,10 +272,8 @@ def from_file(cls, filename: Path): Trainer instance """ - with open(filename, "r") as f: - configuration = yaml.safe_load(f) - configuration["filename"] = str(filename) - return cls(configuration) + manifest = yaml.safe_load(open(filename, "r")) + return cls(manifest) def get_trainer_hash(self): """ @@ -486,8 +291,7 @@ def initialize(self): call setup methods. 
""" # Step 1 - Assign the processed configuration objects to the class variables - for key, value in self.configuration.items(): - setattr(self, key, value) + # This has been done in the __init__ method # Step 2 - Initialize all seeds self.seed_all() # Step 3 - Set up the workspace folder @@ -509,8 +313,8 @@ def seed_all(self): """ Seed all the random number generators. """ - np.random.seed(self.seed) - random.seed(self.seed) + np.random.seed(self.workspace["seed"]) + random.seed(self.workspace["seed"]) # torch.manual_seed(self.seed) # enable torch seed in respective children # torch.cuda.manual_seed_all(self.seed) # torch.backends.cudnn.deterministic = True @@ -520,220 +324,45 @@ def setup_workspace(self): """ Check all the existing runs in the root directory and see if it finished the run """ - dir_list = sorted(glob(f"{self.workspace_name}*")) - if len(dir_list) == 0 or not self.resume: - self.appending_to_previous_run = False - self.current_run_dir = f"{self.workspace_name}/{self.current_run_title}" - os.makedirs(self.current_run_dir, exist_ok=True) + dir_list = sorted(glob(f"{self.workspace['name']}*")) + if len(dir_list) == 0 or not self.workspace["resume"]: + self.current["appending_to_previous_run"] = False + self.current["run_dir"] = f"{self.workspace['name']}/{self.current['run_title']}" + os.makedirs(self.current["run_dir"], exist_ok=True) else: last_dir = dir_list[-1] was_it_finished = os.path.exists(f"{last_dir}/.finished") if was_it_finished: # start new run - current_run_dir = f"{self.workspace_name}/{self.current_run_title}" + current_run_dir = f"{self.workspace['name']}/{self.current['run_title']}" os.makedirs(current_run_dir, exist_ok=True) - self.appending_to_previous_run = False + self.current["appending_to_previous_run"] = False else: - self.appending_to_previous_run = True - self.current_run_dir = dir_list[-1] + self.current["appending_to_previous_run"] = True + self.current["run_dir"] = dir_list[-1] def setup_dataset(self): """ - Set up the dataset based on the provided information. It will check the - {workspace}/{dataset_name} directory to see if dataset hash is already there. - The dataset hash is determined but hashing the full path to the dataset + - transforms names + configuration transform properties. 
- """ - dataset_path = "" - dataset_transforms = "" - dataset_hash = "" - if self.dataset_type == DataSource.KLIFF or self.dataset_type == DataSource.ASE: - dataset_path = os.path.abspath(self.dataset_path) - elif self.dataset_type == DataSource.COLABFIT: - dataset_path = ( - self.colabfit_dataset["database_url"] - + "/" - + self.colabfit_dataset["database_name"] - ) - else: - raise TrainerError(f"Dataset type {self.dataset_type} not supported.") - - if self.property_transform_options is not None: - dataset_transforms += str(self.property_transform_options["name"]) - - dataset_transforms += "_" - - if self.configuration_transform_options is not None: - dataset_transforms += self.configuration_transform_options["name"] + "_" - dataset_transforms += str( - self.configuration_transform_options["kwargs"]["cutoff"] - ) - - dataset_hash_str = ( - dataset_path - + "_" - + dataset_transforms - + "_" - + json.dumps(self.loss_weights, sort_keys=True) - ) - dataset_hash = hashlib.md5(dataset_hash_str.encode()).hexdigest() - self.current_dataset_hash = dataset_hash - dataset_dir = f"{self.workspace_name}/{dataset_hash}" - os.makedirs(dataset_dir, exist_ok=True) - try: - self.dataset = dill.load(open(f"{dataset_dir}/dataset.dill", "rb")) - logger.info(f"Dataset found in {dataset_dir}.") - return - except FileNotFoundError: - logger.info( - f"Dataset not found in {self.workspace_name} directory. Creating dataset." - ) - - if isinstance(self.loss_weights, Path): - weights = self.loss_weights - elif isinstance(self.loss_weights, dict): - weights = Weight( - config_weight=self.loss_weights["config"], - energy_weight=self.loss_weights["energy"], - forces_weight=self.loss_weights["forces"], - stress_weight=self.loss_weights["stress"], - ) - elif self.loss_weights is None: - weights = None - else: - raise TrainerError( - "Loss weights should be a dictionary or a path to a file, or None." - ) - - if self.dataset_type == DataSource.KLIFF: - dataset = Dataset.from_path(dataset_path, weight=weights) - elif self.dataset_type == DataSource.ASE: - dataset = Dataset.from_ase(dataset_path, **self.ase_keys, weight=weights) - elif self.dataset_type == DataSource.COLABFIT: - dataset = Dataset.from_colabfit( - self.colabfit_dataset["dataset_name"], - self.colabfit_dataset["database_name"], - self.colabfit_dataset["database_url"], - weight=weights, - ) - else: - raise TrainerError(f"Dataset type {self.dataset_type} not supported.") - - if self.property_transform is not None: - if not isinstance(self.property_transform, PropertyTransform): - raise TrainerError( - "Property transform is not none and not an instance of PropertyTransform." - ) - else: - # check if property_instance_options have "instance" - if self.property_transform_options.get("instance") is not None: - self.property_transform = self.property_transform_options["instance"] - else: - if self.property_transform_options.get("name"): - try: - # try getting class "name" from kliff.transforms.property_transforms - module = importlib.import_module( - "kliff.transforms.property_transforms" - ) - class_ = getattr( - module, self.property_transform_options["name"] - ) - self.property_transform = class_( - property_key=self.property_transform_options[ - "property_key" - ], - ) - except AttributeError: - raise TrainerError( - f"Property transform {self.property_transform_options['name']} not found." - "If it is a custom transform, please provide the instance." 
- ) - - if self.property_transform: - self.property_transform(dataset) - - if self.configuration_transform is not None: - if not isinstance(self.configuration_transform, ConfigurationTransform): - raise TrainerError( - "Configuration transform is not none and not an instance of ConfigurationTransform." - ) - else: - # check if configuration_instance_options have "instance" - if ( - "instance" in self.configuration_transform_options - and self.configuration_transform_options["instance"] is not None - ): - self.configuration_transform = self.configuration_transform_options[ - "instance" - ] - else: - try: - # try getting class "name" from kliff.transforms.configuration_transforms - module = importlib.import_module( - "kliff.transforms.configuration_transforms" - ) - class_ = getattr( - module, self.configuration_transform_options["name"] - ) - self.configuration_transform = class_( - **self.configuration_transform_options["kwargs"], - copy_to_config=True, - ) - except AttributeError: - raise TrainerError( - f"Configuration transform {self.configuration_transform_options['name']} not found." - "If it is a custom transform, please provide the instance." - ) - for configuration in dataset: - self.configuration_transform(configuration) - - dill.dump(dataset, open(f"{dataset_dir}/dataset.dill", "wb")) - logger.info(f"Dataset saved in {dataset_dir}.") - if self.dataset_shuffle: - random.shuffle(dataset.configs) - self.dataset = dataset + Set up the dataset based on the provided information. - def setup_test_train_datasets(self): + TODO: It will check the {workspace}/{dataset_name} directory to see if dataset hash + is already there. The dataset hash is determined but hashing the full path to + the dataset + transforms names + configuration transform properties. + TODO: reload hashed dataset if it exists. """ - Set up the test and train datasets based on the provided indices. If the indices - are not provided, shuffled serial indices will be used. If val_indices are not - provided, the train_indices no validation dataset will be used. - """ - - # training indices - if self.indices_files["train"] is not None: - self.train_indices = np.load(self.indices_files["train"]) - else: - if self.train_size is not None: - self.train_indices = np.arange(self.train_size) - else: - self.train_indices = np.arange(len(self.dataset)) - - # validation indices - if self.indices_files["val"] is not None: - self.val_indices = np.load(self.indices_files["val"]) - else: - if self.val_size is not None: - self.val_indices = np.arange(self.val_size) - else: - self.val_indices = None - - self.train_dataset = self.dataset[self.train_indices] - self.indices_files["train"] = f"{self.current_run_dir}/train_indices.npy" - self.train_indices.dump(self.indices_files["train"]) - - if self.val_indices: - self.val_dataset = self.dataset[self.val_indices] - self.indices_files["val"] = f"{self.current_run_dir}/val_indices.npy" - self.val_indices.dump(self.indices_files["val"]) + dataset_module_manifest = deepcopy(self.dataset_manifest) + dataset_module_manifest["weights"] = self.loss_manifest["weights"] + self.dataset = Dataset.get_dataset_from_manifest( + dataset_module_manifest, + self.transform_manifest) def save_config(self): """ Hash and save the configuration to the current run directory. 
""" config_hash = self.get_trainer_hash() - config_file = f"{self.current_run_dir}/{config_hash}.yaml" + config_file = f"{self.current['run_dir']}/{config_hash}.yaml" with open(config_file, "w") as f: - yaml.dump(self.configuration, f, default_flow_style=False) + yaml.dump(self.config_to_dict(), f, default_flow_style=False) logger.info(f"Configuration saved in {config_file}.") def setup_model(self): @@ -769,23 +398,105 @@ def setup_optimizer(self): """ raise TrainerError("setup_optimizer not implemented.") + def setup_test_train_datasets(self): + """ + Simple test train split for now, will have more options like stratification + in the future. + + """ + # test train splits + train_size = self.dataset_sample_manifest.get("train_size", len(self.dataset)) + val_size = self.dataset_sample_manifest.get("val_size", 0) + + # sanity checks + if not isinstance(train_size, int) or train_size < 1: + logger.warning("Train size is not provided or is less than 1. Using full dataset for training.") + train_size = len(self.dataset) + + if not isinstance(val_size, int) or val_size < 0: + logger.warning("Val size is not provided or is less than 0. Using 0 for validation.") + val_size = 0 + + if train_size + val_size > len(self.dataset): + raise TrainerError("Sum of train, val, and test sizes is greater than the dataset size.") + + # check if indices are provided + train_indices = self.dataset_sample_manifest.get("train_indices") + if train_indices is None: + train_indices = np.arange(train_size, dtype=int) + elif isinstance(train_indices, str): + train_indices = np.genfromtxt(train_indices, dtype=int) + else: + TrainerError("train_indices should be a numpy array or a path to a file.") + + val_indices = self.dataset_sample_manifest.get("val_indices") + if val_indices is None: + val_indices = np.arange(train_size, train_size + val_size, dtype=int) + elif isinstance(val_indices, str): + val_indices = np.genfromtxt(val_indices, dtype=int) + else: + TrainerError("val_indices should be a numpy array or a path to a file.") + + if self.dataset_manifest.get("shuffle", False): + # instead of shuffling the main dataset, validation/train indices are shuffled + # this gives better control over future active learning scenarios + np.random.shuffle(train_indices) + np.random.shuffle(val_indices) + + print(train_indices, val_indices, self.dataset) + + train_dataset = self.dataset[train_indices] + + if val_size > 0: + val_dataset = self.dataset[val_indices] + else: + val_dataset = None + + self.dataset_sample_manifest["train_size"] = train_size + self.dataset_sample_manifest["val_size"] = val_size + self.dataset_sample_manifest["train_indices"] = train_indices + self.dataset_sample_manifest["val_indices"] = val_indices + self.dataset_sample_manifest["train_dataset"] = train_dataset + self.dataset_sample_manifest["val_dataset"] = val_dataset + + # save the indices if generated + if isinstance(train_indices, str): + self.dataset_sample_manifest["indices_files"]["train"] = train_indices + else: + self.dataset_sample_manifest["indices_files"]["train"] = f"{self.current['run_dir']}/train_indices.txt" + np.savetxt(self.dataset_sample_manifest["indices_files"]["train"], train_indices, fmt="%d") + + if isinstance(val_indices, str): + self.dataset_sample_manifest["indices_files"]["val"] = val_indices + else: + self.dataset_sample_manifest["indices_files"]["val"] = f"{self.current['run_dir']}/val_indices.txt" + np.savetxt(self.dataset_sample_manifest["indices_files"]["val"], val_indices, fmt="%d") + def loss(self, *args, **kwargs): - 
TrainerError("loss not implemented.") + raise TrainerError("loss not implemented.") def checkpoint(self, *args, **kwargs): - TrainerError("checkpoint not implemented.") + raise TrainerError("checkpoint not implemented.") def train_step(self, *args, **kwargs): - TrainerError("train_step not implemented.") + raise TrainerError("train_step not implemented.") def validation_step(self, *args, **kwargs): - TrainerError("validation_step not implemented.") + raise TrainerError("validation_step not implemented.") def get_optimizer(self, *args, **kwargs): - TrainerError("get_optimizer not implemented.") + raise TrainerError("get_optimizer not implemented.") def train(self, *args, **kwargs): - TrainerError("train not implemented.") + raise TrainerError("train not implemented.") def save_kim_model(self, *args, **kwargs): - TrainerError("save_kim_model not implemented.") + raise TrainerError("save_kim_model not implemented.") + + +class TrainerError(Exception): + """ + Exceptions to be raised in Trainer and associated classes. + """ + def __init__(self, message): + super().__init__(message) diff --git a/kliff/trainer/option_enumerations.py b/kliff/trainer/option_enumerations.py deleted file mode 100644 index 9a002a05..00000000 --- a/kliff/trainer/option_enumerations.py +++ /dev/null @@ -1,225 +0,0 @@ -from enum import Enum - -from kliff._exceptions import TrainerError - - -class ModelTypes(Enum): - """ - Enumerates the different types of models that can be used in the training - process. The different types are: - - KIM: The KIM model, passed as a string of KIM ID. - - TORCH: A model implemented in PyTorch, passed as a Python callable. - - TAR: A model saved in a tar file, which can either be a valid KIM API - model, where it will be installed in CWD using kim-api collections - management, or TorchML driver based model, in which case, the TorchScript - file will be extracted and used. The kind of model is determined from the - KIM API CMakelists.txt file model drive name string. - """ - - UNDEFINED = -1 - KIM = 0 - TORCH = 1 - TAR = 2 - - @staticmethod - def get_model_enum(input_str: str): - """ - Get the model type from the input string. - - Args: - input_str: name of the model type. "kim", "torch", "pt", "pth" or "tar" - - Returns: - Model type enum. - """ - if input_str.lower() == "kim": - return ModelTypes.KIM - elif ( - input_str.lower() == "torch" - or input_str.lower() == "pt" - or input_str.lower() == "pth" - ): - return ModelTypes.TORCH - elif input_str.lower() == "tar": - return ModelTypes.TAR - else: - raise TrainerError(f"Model type {input_str} not supported.") - - @staticmethod - def get_model_str(input_type): - """ - Get the model configuration string from the model type. - - Args: - input_type: input model enum. - - Returns: - Model configuration string. "KIM", "TORCH" or "TAR" - - """ - if input_type == ModelTypes.KIM: - return "KIM" - elif input_type == ModelTypes.TORCH: - return "TORCH" - elif input_type == ModelTypes.TAR: - return "TAR" - else: - raise TrainerError(f"Model type {input_type} not supported.") - - -class DataSource(Enum): - """ - Enumerates the different types of data sources. The different types are: - - ASE: ASE atoms objects, or xyz file with configurations. Uses - ~:class:`~kliff.dataset.Dataset.from_ase` method. - - COLABFIT: uUses ColabFit dataset exchange instance. Uses - ~:class:`~kliff.dataset.Dataset.from_colabfit` method. - - KLIFF: Uses KLIFF compatible extxyz files path. Uses - ~:class:`~kliff.dataset.Dataset.from_path` method. 
- """ - - UNDEFINED = -1 - ASE = 0 - COLABFIT = 1 - KLIFF = 2 - - @staticmethod - def get_data_enum(input_str: str): - """ - Get the data type from the input string. - - Args: - input_str: name of the data type. "ase", "colabfit" or "kliff" - - Returns: - Data type enum. - - """ - if input_str.lower() == "ase": - return DataSource.ASE - elif input_str.lower() == "colabfit": - return DataSource.COLABFIT - elif input_str.lower() == "kliff": - return DataSource.KLIFF - else: - raise TrainerError(f"Data type {input_str} not supported.") - - @staticmethod - def get_data_str(input_type): - """ - Get the data configuration string from the data type. - - Args: - input_type: input data enum. - - Returns: - Data configuration string. "ASE", "COLABFIT" or "KLIFF" - - """ - if input_type == DataSource.ASE: - return "ASE" - elif input_type == DataSource.COLABFIT: - return "COLABFIT" - elif input_type == DataSource.KLIFF: - return "KLIFF" - else: - raise TrainerError(f"Data type {input_type} not supported.") - - -class ConfigurationTransformationTypes(Enum): - """ - Enumerates the different types of configuration transformations that can be - applied to the input data. The different types are: - - GRAPH: Graph based transformation. - - DESCRIPTORS: Descriptor based transformation. - - NEIGHBORS: No transformation besides neighbor list computation. - """ - - UNDEFINED = -1 - GRAPH = 0 - DESCRIPTORS = 1 - NEIGHBORS = 2 - - @staticmethod - def get_config_transformation_enum(input_str: str): - """ - Get the configuration transformation type from the input string. - - Args: - input_str: name of the configuration transformation type. "graph", "descriptors" or "neighbors" - - Returns: - Configuration transformation type enum. - - """ - if input_str.lower() == "graph": - return ConfigurationTransformationTypes.GRAPH - elif input_str.lower() == "descriptors": - return ConfigurationTransformationTypes.DESCRIPTORS - elif input_str.lower() == "neighbors" or input_str.lower() == "none": - return ConfigurationTransformationTypes.NEIGHBORS - else: - raise TrainerError( - f"Configuration transformation type {input_str} not supported." - ) - - @staticmethod - def get_config_transformation_str(input_type): - """ - Get the configuration transformation configuration string from the - configuration transformation type. - - Args: - input_type: input configuration transformation enum. - - Returns: - Configuration transformation configuration string. "GRAPH", "DESCRIPTORS" or "NEIGHBORS" - - """ - if input_type == ConfigurationTransformationTypes.GRAPH: - return "GRAPH" - elif input_type == ConfigurationTransformationTypes.DESCRIPTORS: - return "DESCRIPTORS" - else: - raise TrainerError( - f"Configuration transformation type {input_type} not supported." - ) - - -class OptimizerProvider(Enum): - """ - Enumerates the different types of optimizer providers that can be used in the - training process. The different types are "TORCH" and "SCIPY". - """ - - UNDEFINED = -1 - TORCH = 0 - SCIPY = 1 - - @staticmethod - def get_optimizer_enum(input_str: str): - """ - Get the optimizer provider from the input string. - - Args: - input_str: name of the optimizer provider. "torch" or "scipy" - - Returns: - Optimizer provider enum. 
- - """ - if input_str.lower() == "torch": - return OptimizerProvider.TORCH - elif input_str.lower() == "scipy": - return OptimizerProvider.SCIPY - else: - raise TrainerError(f"Optimizer provider {input_str} not supported.") - - @staticmethod - def get_optimizer_str(input_type): - if input_type == OptimizerProvider.TORCH: - return "TORCH" - elif input_type == OptimizerProvider.SCIPY: - return "SCIPY" - else: - raise TrainerError(f"Optimizer provider {input_type} not supported.") From 614c1b98cd75d3f738d8bfc829bd1417bc9874ed Mon Sep 17 00:00:00 2001 From: Amit Gupta Date: Tue, 16 Apr 2024 22:51:38 -0500 Subject: [PATCH 8/8] Working KIM trainer module --- kliff/dataset/dataset.py | 145 ++++++++++++------ kliff/models/kim.py | 30 ++-- kliff/trainer/__init__.py | 4 +- .../{kliff_trainer.py => base_trainer.py} | 138 +++++++++++------ kliff/trainer/kim_trainer.py | 101 ++++-------- 5 files changed, 237 insertions(+), 181 deletions(-) rename kliff/trainer/{kliff_trainer.py => base_trainer.py} (83%) diff --git a/kliff/dataset/dataset.py b/kliff/dataset/dataset.py index 4dc856ba..02cd0232 100644 --- a/kliff/dataset/dataset.py +++ b/kliff/dataset/dataset.py @@ -1,9 +1,11 @@ import copy +import hashlib +import importlib import json import os from collections.abc import Iterable from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import dill import numpy as np @@ -14,9 +16,6 @@ from kliff.dataset.weight import Weight from kliff.utils import stress_to_tensor, stress_to_voigt, to_path -import importlib -import hashlib - # For type checking if TYPE_CHECKING: from colabfit.tools.configuration import Configuration as ColabfitConfiguration @@ -681,7 +680,9 @@ def add_from_colabfit( # open link to the mongo mongo_client = MongoDatabase(colabfit_database, uri=colabfit_uri) if isinstance(weight, Weight): - configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, weight) + configs = Dataset._read_from_colabfit( + mongo_client, colabfit_dataset, weight + ) else: configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, None) self.add_weights(weight) @@ -755,10 +756,7 @@ def _read_from_path( parent = path.parent all_files = [path] - configs = [ - Configuration.from_file(f, weight, file_format) - for f in all_files - ] + configs = [Configuration.from_file(f, weight, file_format) for f in all_files] if len(configs) <= 0: raise DatasetError( @@ -988,7 +986,13 @@ def add_from_ase( if isinstance(weight, Weight): configs = self._read_from_ase( - path, ase_atoms_list, weight, energy_key, forces_key, slices, file_format + path, + ase_atoms_list, + weight, + energy_key, + forces_key, + slices, + file_format, ) else: configs = self._read_from_ase( @@ -1084,13 +1088,17 @@ def add_weights(self, path: Union[Path, str]): # sanity checks if 1 > len(weights_col) > 4: - raise DatasetError("Weights file contains improper number of cols," - "there needs to be at least 1 col, and at most 4") + raise DatasetError( + "Weights file contains improper number of cols," + "there needs to be at least 1 col, and at most 4" + ) if not (weights_data.size == 1 or weights_data.size == len(self)): - raise DatasetError("Weights file contains improper number of rows," - "there can be either 1 row (all weights same), " - "or same number of rows as the configurations.") + raise DatasetError( + "Weights file contains improper number of rows," + "there can be either 1 row (all weights same), " + "or same number of 
rows as the configurations." + ) expected_cols = {"config", "energy", "forces", "stress"} missing_cols = expected_cols - set([col.lower() for col in weights_col]) @@ -1156,7 +1164,7 @@ def check_properties_consistency(self, properties: List[str] = None): logger.warning("No properties provided to check for consistency.") return - property_list = list(copy.deepcopy(properties)) # make it mutable, if not + property_list = list(copy.deepcopy(properties)) # make it mutable, if not for config in self.configs: for prop in property_list: try: @@ -1170,7 +1178,9 @@ def check_properties_consistency(self, properties: List[str] = None): ) @staticmethod - def get_manifest_checksum(dataset_manifest: dict, transform_manifest: Optional[dict] = None) -> str: + def get_manifest_checksum( + dataset_manifest: dict, transform_manifest: Optional[dict] = None + ) -> str: """ Get the checksum of the dataset manifest. @@ -1188,8 +1198,9 @@ def get_manifest_checksum(dataset_manifest: dict, transform_manifest: Optional[d return hashlib.md5(dataset_str.encode()).hexdigest() @staticmethod - def get_dataset_from_manifest(dataset_manifest: dict, transform_manifest: Optional[dict] = None) -> ( - "Dataset"): + def get_dataset_from_manifest( + dataset_manifest: dict, transform_manifest: Optional[dict] = None + ) -> "Dataset": """ Get a dataset from a manifest. @@ -1266,7 +1277,11 @@ def get_dataset_from_manifest(dataset_manifest: dict, transform_manifest: Option A dataset of configurations. """ dataset_type = dataset_manifest.get("type").lower() - if dataset_type != "ase" and dataset_type != "path" and dataset_type != "colabfit": + if ( + dataset_type != "ase" + and dataset_type != "path" + and dataset_type != "colabfit" + ): raise DatasetError(f"Dataset type {dataset_type} not supported.") weights = dataset_manifest.get("weights", None) if weights is not None: @@ -1284,15 +1299,15 @@ def get_dataset_from_manifest(dataset_manifest: dict, transform_manifest: Option if dataset_type == "ase": dataset = Dataset.from_ase( - path=dataset_manifest.get("path", "."), - weight=weights, - energy_key=dataset_manifest.get("keys", {}).get("energy", "energy"), - forces_key=dataset_manifest.get("keys", {}).get("forces", "forces"), + path=dataset_manifest.get("path", "."), + weight=weights, + energy_key=dataset_manifest.get("keys", {}).get("energy", "energy"), + forces_key=dataset_manifest.get("keys", {}).get("forces", "forces"), ) elif dataset_type == "path": dataset = Dataset.from_path( - path=dataset_manifest.get("path", "."), - weight=weights, + path=dataset_manifest.get("path", "."), + weight=weights, ) elif dataset_type == "colabfit": try: @@ -1300,13 +1315,15 @@ def get_dataset_from_manifest(dataset_manifest: dict, transform_manifest: Option colabfit_database = colabfit_dataset.database_name except KeyError: raise DatasetError("Colabfit dataset or database not provided.") - colabfit_uri = dataset_manifest.get("colabfit_uri", "mongodb://localhost:27017") + colabfit_uri = dataset_manifest.get( + "colabfit_uri", "mongodb://localhost:27017" + ) dataset = Dataset.from_colabfit( - colabfit_database=colabfit_database, - colabfit_dataset=colabfit_dataset, - colabfit_uri=colabfit_uri, - weight=weights, + colabfit_database=colabfit_database, + colabfit_dataset=colabfit_dataset, + colabfit_uri=colabfit_uri, + weight=weights, ) else: # this should not happen @@ -1314,51 +1331,83 @@ def get_dataset_from_manifest(dataset_manifest: dict, transform_manifest: Option # transforms? 
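        # Editor's note (illustrative sketch, not part of the patch): based on the
        # keys read above, a minimal dataset manifest passed to
        # Dataset.get_dataset_from_manifest could look like
        #
        #     dataset_manifest = {
        #         "type": "ase",                                    # or "path" / "colabfit"
        #         "path": "my_configs.xyz",                         # hypothetical file name
        #         "keys": {"energy": "energy", "forces": "forces"},
        #         "save": False,
        #     }
        #
        # and, per the parsing below, each entry of transform_manifest["property"]
        # is keyed by the property name it transforms, e.g.
        #
        #     {"name": "energy", "energy": {"name": "SomeTransform", "kwargs": {}}}
        #
        # where "SomeTransform" stands in for a class from
        # kliff.transforms.property_transforms (the class name here is illustrative).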
if transform_manifest: - configuration_transform: Union[dict, None] = transform_manifest.get("configuration", None) - property_transform: Union[list, None] = transform_manifest.get("property", None) + configuration_transform: Union[dict, None] = transform_manifest.get( + "configuration", None + ) + property_transform: Union[list, None] = transform_manifest.get( + "property", None + ) if property_transform: for property_to_transform in property_transform: property_name = property_to_transform.get("name", None) if not property_name: - continue # it is probably an empty propery - transform_module_name = property_to_transform[property_name].get("name", None) + continue # it is probably an empty propery + transform_module_name = property_to_transform[property_name].get( + "name", None + ) if not transform_module_name: - raise DatasetError("Property transform module name not provided.") + raise DatasetError( + "Property transform module name not provided." + ) property_transform_module = importlib.import_module( f"kliff.transforms.property_transforms" ) - property_module = getattr(property_transform_module, transform_module_name) - property_module = property_module(proprty_key=property_name, **property_to_transform[property_name].get("kwargs", {})) + property_module = getattr( + property_transform_module, transform_module_name + ) + property_module = property_module( + proprty_key=property_name, + **property_to_transform[property_name].get("kwargs", {}), + ) dataset = property_module(dataset) if configuration_transform: - configuration_module_name: Union[str, None] = configuration_transform.get("name", None) + configuration_module_name: Union[str, None] = ( + configuration_transform.get("name", None) + ) if not configuration_module_name: - logger.warning("Configuration transform module name not provided." - "Skipping configuration transform.") + logger.warning( + "Configuration transform module name not provided." + "Skipping configuration transform." + ) else: configuration_transform_module = importlib.import_module( f"kliff.transforms.configuration_transforms" ) - configuration_module = getattr(configuration_transform_module, configuration_module_name) - kwargs: Union[dict, None] = configuration_transform.get("kwargs", None) + configuration_module = getattr( + configuration_transform_module, configuration_module_name + ) + kwargs: Union[dict, None] = configuration_transform.get( + "kwargs", None + ) if not kwargs: - raise DatasetError("Configuration transform module options not provided.") - configuration_module = configuration_module(**kwargs, copy_to_config=True) + raise DatasetError( + "Configuration transform module options not provided." 
+ ) + configuration_module = configuration_module( + **kwargs, copy_to_config=True + ) for config in dataset.configs: _ = configuration_module(config) # dataset hash - dataset_checksum = Dataset.get_manifest_checksum(dataset_manifest, transform_manifest) + dataset_checksum = Dataset.get_manifest_checksum( + dataset_manifest, transform_manifest + ) dataset.add_metadata({"checksum": dataset_checksum}) if dataset_manifest.get("save", False): # TODO: use Path for compatibility dataset_save_path = dataset_manifest.get("save_path", "./") - logger.info(f"Saving dataset to {dataset_save_path}/DS_{dataset_checksum[:10]}.pkl") - dill.dump(dataset, open(f"{dataset_save_path}/DS_{dataset_checksum[:10]}.pkl", "wb")) + logger.info( + f"Saving dataset to {dataset_save_path}/DS_{dataset_checksum[:10]}.pkl" + ) + dill.dump( + dataset, + open(f"{dataset_save_path}/DS_{dataset_checksum[:10]}.pkl", "wb"), + ) return dataset diff --git a/kliff/models/kim.py b/kliff/models/kim.py index 84f001f4..a82f3ffb 100644 --- a/kliff/models/kim.py +++ b/kliff/models/kim.py @@ -1,10 +1,12 @@ import importlib import os import subprocess +import tarfile from collections import OrderedDict from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import kimpy import numpy as np from loguru import logger @@ -16,9 +18,6 @@ from kliff.neighbor import assemble_forces, assemble_stress from kliff.utils import install_kim_model, is_kim_model_installed -import kimpy -import tarfile - try: import kimpy from kimpy import neighlist as nl @@ -34,6 +33,7 @@ "TorchML", ] + class KIMComputeArguments(ComputeArguments): """ KIMModel potentials arguments. @@ -742,9 +742,13 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): model_collection = model_manifest.get("model_collection") if model_driver in UNSUPPORTED_MODEL_DRIVERS: - logger.error("Model driver not supported for KIM-API based training. " - "Please use appropriate trainer for this model.") - raise KIMModelError(f"Model driver {model_driver} not supported for KIMModel training.") + logger.error( + "Model driver not supported for KIM-API based training. " + "Please use appropriate trainer for this model." + ) + raise KIMModelError( + f"Model driver {model_driver} not supported for KIMModel training." + ) # ensure model is installed if model_type.lower() == "kim": @@ -755,7 +759,9 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): ) raise KIMModelError(f"Model {model_name} not found.") else: - logger.info(f"Model {model_name} is present in {model_collection} collection.") + logger.info( + f"Model {model_name} is present in {model_collection} collection." + ) elif model_type.lower() == "tar": archive_content = tarfile.open(model_path + "/" + model_name) model = archive_content.getnames()[0] @@ -770,7 +776,9 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): ], check=True, ) - logger.info(f"Tarball Model {model} installed in {model_collection} collection.") + logger.info( + f"Tarball Model {model} installed in {model_collection} collection." 
+ ) else: raise KIMModelError(f"Model type {model_type} not supported.") @@ -778,7 +786,7 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): if param_manifest: mutable_param_list = [] - for param_to_transform in param_manifest: + for param_to_transform in param_manifest.get("parameter", []): if isinstance(param_to_transform, dict): parameter_name = list(param_to_transform.keys())[0] elif isinstance(param_to_transform, str): @@ -791,7 +799,9 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): model_param_list = model.parameters() # apply transforms if needed - for model_params, input_params in zip(model_param_list, param_manifest): + for model_params, input_params in zip( + model_param_list, param_manifest.get("parameter", []) + ): if isinstance(input_params, dict): param_name = list(input_params.keys())[0] if param_name != model_params.name: diff --git a/kliff/trainer/__init__.py b/kliff/trainer/__init__.py index 24372b3f..d669e814 100644 --- a/kliff/trainer/__init__.py +++ b/kliff/trainer/__init__.py @@ -1,2 +1,2 @@ -# from .kim_trainer import KIMTrainer -from .kliff_trainer import Trainer +from .base_trainer import Trainer +from .kim_trainer import KIMTrainer diff --git a/kliff/trainer/kliff_trainer.py b/kliff/trainer/base_trainer.py similarity index 83% rename from kliff/trainer/kliff_trainer.py rename to kliff/trainer/base_trainer.py index 19c55f7e..184d1756 100644 --- a/kliff/trainer/kliff_trainer.py +++ b/kliff/trainer/base_trainer.py @@ -1,5 +1,4 @@ import hashlib -import importlib import json import os import random @@ -14,13 +13,7 @@ import yaml from loguru import logger -import kliff.transforms.configuration_transforms from kliff.dataset import Dataset -from kliff.transforms.configuration_transforms import ConfigurationTransform -from kliff.transforms.parameter_transforms import ParameterTransform -from kliff.transforms.property_transforms import PropertyTransform - -from ..dataset.weight import Weight class Trainer: @@ -80,19 +73,18 @@ def __init__(self, training_manifest: dict): # transform variables self.transform_manifest: dict = { - "property": [{ - "name": None, - "property_key": None, - }], + "property": [ + { + "name": None, + "property_key": None, + } + ], "parameter": [], "configuration": { "name": None, "kwargs": None, }, } - self.property_transform: PropertyTransform = None - self.parameter_transform: ParameterTransform = None - self.configuration_transform: ConfigurationTransform = None # training variables # this is too complicated to put it in singe dict, therefore the training @@ -109,6 +101,7 @@ def __init__(self, training_manifest: dict): "normalize_per_atom": False, "loss_traj": False, } + self.optimizer_manifest: dict = { "provider": "scipy", "name": None, @@ -118,6 +111,7 @@ def __init__(self, training_manifest: dict): "stop_condition": None, "num_workers": 1, } + self.optimizer = None # part of current? 
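        # Editor's note (illustrative sketch, not part of the patch): the defaults
        # above are merged with the user's "training" block in parse_manifest(); a
        # minimal block, written as the dict that YAML parsing would produce,
        # might be
        #
        #     "training": {
        #         "loss": {"function": "MSE"},                  # "MSE" maps to MSE_residuals in KIMTrainer
        #         "optimizer": {"provider": "scipy", "name": "Nelder-Mead"},
        #         "epochs": 500,
        #         "ckpt_interval": 100,
        #     }
        #
        # Field values are examples only; the "weights" entry of the loss block is
        # forwarded to the dataset manifest in setup_dataset().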
self.dataset_sample_manifest: dict = { @@ -126,9 +120,9 @@ def __init__(self, training_manifest: dict): "indices_files": {"train": None, "val": None}, "val_indices": None, "train_indices": None, - "train_dataset": None, - "val_dataset": None, } + self.train_dataset = None + self.val_dataset = None # export trained model self.export_manifest: dict = { @@ -154,7 +148,7 @@ def __init__(self, training_manifest: dict): "dataset_hash": None, "appending_to_previous_run": False, "verbose": False, - "chkpt_interval": 100, + "ckpt_interval": 100, } self.parse_manifest(training_manifest) self.initialize() @@ -179,19 +173,23 @@ def parse_manifest(self, manifest: dict): # Workspace variables ################################################ workspace_block: Union[None, dict] = manifest.get("workspace", None) if workspace_block is None: - logger.warning("Workspace block not found in the configuration. Using default values.") + logger.warning( + "Workspace block not found in the configuration. Using default values." + ) else: self.workspace |= workspace_block if isinstance(self.workspace["walltime"], int): - expected_end_time = datetime.now() + timedelta(seconds=self.workspace["walltime"]) + expected_end_time = datetime.now() + timedelta( + seconds=self.workspace["walltime"] + ) else: expected_end_time = datetime.now() + timedelta( - days=int(self.workspace["walltime"].split(":")[0]), - hours=int(self.workspace["walltime"].split(":")[1]), - minutes=int(self.workspace["walltime"].split(":")[2]), - seconds=int(self.workspace["walltime"].split(":")[3]), - ) + days=int(self.workspace["walltime"].split(":")[0]), + hours=int(self.workspace["walltime"].split(":")[1]), + minutes=int(self.workspace["walltime"].split(":")[2]), + seconds=int(self.workspace["walltime"].split(":")[3]), + ) self.current["expected_end_time"] = expected_end_time # Dataset manifest ################################################# @@ -208,41 +206,57 @@ def parse_manifest(self, manifest: dict): self.model_manifest |= model_manifest if self.model_manifest.get("name", None) is None: - self.current["run_title"] = f"{self.model_manifest.get('type')}_{date_time_str}" + self.current["run_title"] = ( + f"{self.model_manifest.get('type')}_{date_time_str}" + ) else: - self.current["run_title"] = f"{self.model_manifest.get('name')}_{date_time_str}" + self.current["run_title"] = ( + f"{self.model_manifest.get('name')}_{date_time_str}" + ) # transform variables #################################################### transform_manifest: Union[None, dict] = manifest.get("transforms", None) if transform_manifest is None: - logger.warning("Transform block not found in the configuration. This is bit unusual.") + logger.warning( + "Transform block not found in the configuration. This is bit unusual." + ) else: self.transform_manifest |= transform_manifest # training variables ######################################################## training_manifest: Union[None, dict] = manifest.get("training", None) if training_manifest is None: - logger.warning("Training block not found in the configuration." - "Will try and resume the previous run if possible.") + logger.warning( + "Training block not found in the configuration." + "Will try and resume the previous run if possible." + ) # TODO: implement resume self.training_manifest |= training_manifest if self.training_manifest.get("loss", None) is None: - logger.warning("Loss block not found in the configuration. Using default values.") + logger.warning( + "Loss block not found in the configuration. 
Using default values." + ) self.loss_manifest |= self.training_manifest.get("loss") if self.training_manifest.get("optimizer", None) is None: - logger.warning("Optimizer block not found in the configuration." - "Will resume the previous run if possible.") + logger.warning( + "Optimizer block not found in the configuration." + "Will resume the previous run if possible." + ) # TODO: implement resume self.optimizer_manifest |= self.training_manifest.get("optimizer") self.optimizer_manifest["epochs"] = self.training_manifest.get("epochs", 10000) - self.optimizer_manifest["stop_condition"] = self.training_manifest.get("stop_condition", None) - self.optimizer_manifest["num_workers"] = self.training_manifest.get("num_workers", 1) - - self.current["chkpt_interval"] = self.training_manifest.get("chkpt_interval", 100) + self.optimizer_manifest["stop_condition"] = self.training_manifest.get( + "stop_condition", None + ) + self.optimizer_manifest["num_workers"] = self.training_manifest.get( + "num_workers", 1 + ) + + self.current["ckpt_interval"] = self.training_manifest.get("ckpt_interval", 100) self.current["verbose"] = self.training_manifest.get("verbose", False) # dataset sample variables will be processed in the setup_dataset method @@ -327,13 +341,17 @@ def setup_workspace(self): dir_list = sorted(glob(f"{self.workspace['name']}*")) if len(dir_list) == 0 or not self.workspace["resume"]: self.current["appending_to_previous_run"] = False - self.current["run_dir"] = f"{self.workspace['name']}/{self.current['run_title']}" + self.current["run_dir"] = ( + f"{self.workspace['name']}/{self.current['run_title']}" + ) os.makedirs(self.current["run_dir"], exist_ok=True) else: last_dir = dir_list[-1] was_it_finished = os.path.exists(f"{last_dir}/.finished") if was_it_finished: # start new run - current_run_dir = f"{self.workspace['name']}/{self.current['run_title']}" + current_run_dir = ( + f"{self.workspace['name']}/{self.current['run_title']}" + ) os.makedirs(current_run_dir, exist_ok=True) self.current["appending_to_previous_run"] = False else: @@ -352,8 +370,8 @@ def setup_dataset(self): dataset_module_manifest = deepcopy(self.dataset_manifest) dataset_module_manifest["weights"] = self.loss_manifest["weights"] self.dataset = Dataset.get_dataset_from_manifest( - dataset_module_manifest, - self.transform_manifest) + dataset_module_manifest, self.transform_manifest + ) def save_config(self): """ @@ -410,15 +428,21 @@ def setup_test_train_datasets(self): # sanity checks if not isinstance(train_size, int) or train_size < 1: - logger.warning("Train size is not provided or is less than 1. Using full dataset for training.") + logger.warning( + "Train size is not provided or is less than 1. Using full dataset for training." + ) train_size = len(self.dataset) if not isinstance(val_size, int) or val_size < 0: - logger.warning("Val size is not provided or is less than 0. Using 0 for validation.") + logger.warning( + "Val size is not provided or is less than 0. Using 0 for validation." + ) val_size = 0 if train_size + val_size > len(self.dataset): - raise TrainerError("Sum of train, val, and test sizes is greater than the dataset size.") + raise TrainerError( + "Sum of train, val, and test sizes is greater than the dataset size." 
+ ) # check if indices are provided train_indices = self.dataset_sample_manifest.get("train_indices") @@ -443,8 +467,6 @@ def setup_test_train_datasets(self): np.random.shuffle(train_indices) np.random.shuffle(val_indices) - print(train_indices, val_indices, self.dataset) - train_dataset = self.dataset[train_indices] if val_size > 0: @@ -456,21 +478,34 @@ def setup_test_train_datasets(self): self.dataset_sample_manifest["val_size"] = val_size self.dataset_sample_manifest["train_indices"] = train_indices self.dataset_sample_manifest["val_indices"] = val_indices - self.dataset_sample_manifest["train_dataset"] = train_dataset - self.dataset_sample_manifest["val_dataset"] = val_dataset + + self.train_dataset = train_dataset + self.val_dataset = val_dataset # save the indices if generated if isinstance(train_indices, str): self.dataset_sample_manifest["indices_files"]["train"] = train_indices else: - self.dataset_sample_manifest["indices_files"]["train"] = f"{self.current['run_dir']}/train_indices.txt" - np.savetxt(self.dataset_sample_manifest["indices_files"]["train"], train_indices, fmt="%d") + self.dataset_sample_manifest["indices_files"][ + "train" + ] = f"{self.current['run_dir']}/train_indices.txt" + np.savetxt( + self.dataset_sample_manifest["indices_files"]["train"], + train_indices, + fmt="%d", + ) if isinstance(val_indices, str): self.dataset_sample_manifest["indices_files"]["val"] = val_indices else: - self.dataset_sample_manifest["indices_files"]["val"] = f"{self.current['run_dir']}/val_indices.txt" - np.savetxt(self.dataset_sample_manifest["indices_files"]["val"], val_indices, fmt="%d") + self.dataset_sample_manifest["indices_files"][ + "val" + ] = f"{self.current['run_dir']}/val_indices.txt" + np.savetxt( + self.dataset_sample_manifest["indices_files"]["val"], + val_indices, + fmt="%d", + ) def loss(self, *args, **kwargs): raise TrainerError("loss not implemented.") @@ -498,5 +533,6 @@ class TrainerError(Exception): """ Exceptions to be raised in Trainer and associated classes. """ + def __init__(self, message): super().__init__(message) diff --git a/kliff/trainer/kim_trainer.py b/kliff/trainer/kim_trainer.py index 04e9f337..c0e5e214 100644 --- a/kliff/trainer/kim_trainer.py +++ b/kliff/trainer/kim_trainer.py @@ -1,24 +1,13 @@ import importlib -import multiprocessing -import subprocess import tarfile from pathlib import Path -from typing import Callable, Tuple, Union -import kimpy -import numpy as np from loguru import logger -import kliff.models -from kliff._exceptions import TrainerError -from kliff.dataset import Configuration from kliff.models import KIMModel -from kliff.utils import install_kim_model +from .base_trainer import Trainer, TrainerError from .kim_residuals import MSE_residuals -from .kliff_trainer import Trainer -from .option_enumerations import ModelTypes, OptimizerProvider - SCIPY_MINIMIZE_METHODS = [ "Nelder-Mead", @@ -65,36 +54,18 @@ def setup_model(self): by the TorchTrainer. """ # check for unsupported model drivers - # 1. 
get the model driver name - if self.model_type == ModelTypes.KIM: - self.model_driver_name = self.get_model_driver_name_for_kim(self.model_name) - elif self.model_type == ModelTypes.TAR: - self.model_driver_name = self.get_model_driver_name_for_tarball( - self.model_name + if ( + self.model_manifest["type"].lower() == "kim" + or self.model_manifest["type"].lower() == "tar" + ): + self.model = KIMModel.get_model_from_manifest( + self.model_manifest, self.transform_manifest ) else: - raise TrainerError(f"Model type {self.model_type} not supported.") - - # 2. check if the model driver is supported - if self.model_driver_name in UNSUPPORTED_MODEL_DRIVERS: raise TrainerError( - f"Model driver {self.model_driver_name} not supported by KIMTrainer." - ) - elif self.model_driver_name is None: - logger.warning( - f"Could not determine model-driver name for {self.model_name}. Please be careful and check if the model is supported." + f"Model type {self.model_manifest['type']} not supported." ) - else: - logger.info(f"Model driver name: {self.model_driver_name}") - - # 3. load the model - if self.model_type == ModelTypes.KIM: - self.ensure_kim_model_installation(self.model_name, self.collection) - elif self.model_type == ModelTypes.TAR: - # reinstall model to be sure - self.ensure_tarball_model_installation(self.model_name, self.collection) - self.model = KIMModel(self.model_name) self.parameters = self.model.get_model_params() def setup_optimizer(self): @@ -105,13 +76,15 @@ def setup_optimizer(self): loaded from the scipy.optimize. If the optimizer_provider is torch, it will be loaded from the torch.optim. Left for the derived classes to implement. """ - if self.optimizer_provider is not OptimizerProvider.SCIPY: + if self.optimizer_manifest["provider"] != "scipy": raise TrainerError( - f"Optimizer provider {self.optimizer_provider} not supported by KIMTrainer." + f"Optimizer provider {self.optimizer_manifest['provider']} not supported by KIMTrainer." ) - if self.optimizer_name not in SCIPY_MINIMIZE_METHODS: - raise TrainerError(f"Optimizer not supported: {self.optimizer_name}.") + if self.optimizer_manifest["name"] not in SCIPY_MINIMIZE_METHODS: + raise TrainerError( + f"Optimizer not supported: {self.optimizer_manifest['name']}." + ) optimizer_lib = importlib.import_module(f"scipy.optimize") self.optimizer = getattr(optimizer_lib, "minimize") @@ -171,10 +144,13 @@ def _wrapper_func(x): return self.loss(x) x = self.model.get_opt_params() - options = self.optimizer_kwargs - options["options"] = {"maxiter": self.max_epochs, "disp": self.verbose} + options = self.optimizer_manifest.get("kwargs", {}) + options["options"] = { + "maxiter": self.optimizer_manifest["epochs"], + "disp": self.current["verbose"], + } result = self.optimizer( - _wrapper_func, x, method=self.optimizer_name, **self.optimizer_kwargs + _wrapper_func, x, method=self.optimizer_manifest["name"], **options ) if result.success: @@ -183,37 +159,22 @@ def _wrapper_func(x): else: logger.error(f"Optimization failed: {result.message}") - def set_parameters_as_mutable(self): - if self.parameter_transform_options is not None: - for param_to_transform in self.parameter_transform_options[ - "parameter_list" - ]: - if isinstance(param_to_transform, dict): - parameter_name = list(param_to_transform.keys())[0] - elif isinstance(param_to_transform, str): - parameter_name = param_to_transform - else: - raise TrainerError( - f"Optimizable parameters must be string or value dict. Got {param_to_transform} instead." 
- ) - self.mutable_parameters_list.append(parameter_name) - else: - for param in self.parameters: - self.mutable_parameters_list.append(param) - - self.model.set_params_mutable(self.mutable_parameters_list) - logger.info(f"Mutable parameters: {self.mutable_parameters_list}") - def _get_loss_fn(self): - if self.loss_function_name == "MSE": + if self.loss_manifest["function"].lower() == "mse": return MSE_residuals def save_kim_model(self): - if self.export_model_type is ModelTypes.KIM: - path = Path(self.export_model_path) / self.export_model_name + if self.export_manifest["model_type"].lower() == "kim": + path = ( + Path(self.export_manifest["model_path"]) + / self.export_manifest["model_name"] + ) self.model.write_kim_model(path) - elif self.export_model_type is ModelTypes.TAR: - path = Path(self.export_model_path) / self.export_model_name + elif self.export_manifest["model_type"] == "tar": + path = ( + Path(self.export_manifest["model_path"]) + / self.export_manifest["model_name"] + ) self.model.write_kim_model(path) tarfile_path = path.with_suffix(".tar.gz") with tarfile.open(tarfile_path, "w:gz") as tar:
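For orientation, an end-to-end use of the new trainer might look roughly as follows. This is an editor's sketch, not code from the patch: the manifest file name and its contents are hypothetical, and it assumes KIMTrainer accepts the parsed manifest dict the same way the base Trainer does. The top-level blocks ("workspace", "dataset", "model", "transforms", "training") mirror those read in Trainer.parse_manifest above.

    import yaml

    from kliff.trainer import KIMTrainer

    # Hypothetical YAML manifest; each top-level block is merged with the
    # defaults set in Trainer.__init__.
    with open("training_manifest.yaml") as f:
        manifest = yaml.safe_load(f)

    trainer = KIMTrainer(manifest)  # parses the manifest and sets up workspace,
                                    # dataset, model, and the scipy optimizer
    trainer.train()                 # wraps scipy.optimize.minimize
    trainer.save_kim_model()        # writes the fitted KIM model, optionally as a tarball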