diff --git a/kliff/dataset/dataset.py b/kliff/dataset/dataset.py index 4945797d..3265b321 100644 --- a/kliff/dataset/dataset.py +++ b/kliff/dataset/dataset.py @@ -1,9 +1,13 @@ import copy +import hashlib +import importlib +import json import os from collections.abc import Iterable from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +import dill import numpy as np from loguru import logger from monty.dev import requires @@ -210,8 +214,8 @@ def from_colabfit( forces = np.asarray(property["atomic-forces"]["forces"]["source-value"]) elif property["type"] == "cauchy-stress": stress = np.asarray(property["cauchy-stress"]["stress"]["source-value"]) + stress = stress_to_voigt(stress) - stress = stress_to_voigt(stress) self = cls( cell, species, @@ -578,6 +582,8 @@ def __init__(self, configurations: Iterable = None): "configurations must be a iterable of Configuration objects." ) + self._metadata: dict = {} + @classmethod @requires(MongoDatabase is not None, "colabfit-tools is not installed") def from_colabfit( @@ -585,7 +591,7 @@ def from_colabfit( colabfit_database: str, colabfit_dataset: str, colabfit_uri: str = "mongodb://localhost:27017", - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, **kwargs, ) -> "Dataset": """ @@ -593,7 +599,11 @@ def from_colabfit( Args: weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). colabfit_database: Name of the colabfit Mongo database to read from. colabfit_dataset: Name of the colabfit dataset instance to read from, usually it is of form, e.g., "DS_xxxxxxxxxxxx_0" @@ -655,7 +665,7 @@ def add_from_colabfit( colabfit_database: str, colabfit_dataset: str, colabfit_uri: str = "mongodb://localhost:27017", - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, **kwargs, ): """ @@ -667,19 +677,29 @@ def add_from_colabfit( it is of form, e.g., "DS_xxxxxxxxxxxx_0") colabfit_uri: connection URI of the colabfit Mongo database to read from. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). 
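+            Example: a weights file for a dataset of four configurations might
+            look like the following (values are illustrative; see `add_weights`
+            for the header convention it expects):
+
+                Config  Energy  Forces  Stress
+                1.0     1.0     0.5     0.0
+                1.0     1.0     0.5     0.0
+                1.0     1.0     0.5     0.0
+                1.0     1.0     0.5     0.0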
""" # open link to the mongo mongo_client = MongoDatabase(colabfit_database, uri=colabfit_uri, **kwargs) - configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, weight) + if isinstance(weight, Weight): + configs = Dataset._read_from_colabfit( + mongo_client, colabfit_dataset, weight + ) + else: + configs = Dataset._read_from_colabfit(mongo_client, colabfit_dataset, None) + self.add_weights(weight) self.configs.extend(configs) @classmethod def from_path( cls, path: Union[Path, str], - weight: Optional[Weight] = None, + weight: Optional[Union[Path, Weight]] = None, file_format: str = "xyz", ) -> "Dataset": """ @@ -688,7 +708,11 @@ def from_path( Args: path: Path the directory (or filename) storing the configurations. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). file_format: Format of the file that stores the configuration, e.g. `xyz`. Returns: @@ -700,7 +724,9 @@ def from_path( @staticmethod def _read_from_path( - path: Path, weight: Optional[Weight] = None, file_format: str = "xyz" + path: Path, + weight: Optional[Weight] = None, + file_format: str = "xyz", ) -> List[Configuration]: """ Read configurations from path. @@ -737,10 +763,7 @@ def _read_from_path( parent = path.parent all_files = [path] - configs = [ - Configuration.from_file(f, copy.copy(weight), file_format) - for f in all_files - ] + configs = [Configuration.from_file(f, weight, file_format) for f in all_files] if len(configs) <= 0: raise DatasetError( @@ -751,7 +774,7 @@ def _read_from_path( def add_from_path( self, path: Union[Path, str], - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, file_format: str = "xyz", ): """ @@ -760,12 +783,21 @@ def add_from_path( Args: path: Path the directory (or filename) storing the configurations. weight: an instance that computes the weight of the configuration in the loss - function. + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). file_format: Format of the file that stores the configuration, e.g. `xyz`. """ if isinstance(path, str): path = Path(path) - configs = self._read_from_path(path, weight, file_format) + + if isinstance(weight, Weight): + configs = self._read_from_path(path, weight, file_format) + else: + configs = self._read_from_path(path, None, file_format) + self.add_weights(weight) self.configs.extend(configs) @classmethod @@ -773,7 +805,7 @@ def from_ase( cls, path: Union[Path, str] = None, ase_atoms_list: List[ase.Atoms] = None, - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, energy_key: str = "energy", forces_key: str = "forces", slices: str = ":", @@ -798,7 +830,11 @@ def from_ase( path: Path the directory (or filename) storing the configurations. ase_atoms_list: A list of ase.Atoms objects. weight: an instance that computes the weight of the configuration in the loss - function. 
+ function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). energy_key: Name of the field in extxyz/ase.Atoms that stores the energy. forces_key: Name of the field in extxyz/ase.Atoms that stores the forces. slices: Slice of the configurations to read. It is used only when `path` is @@ -851,11 +887,11 @@ def _read_from_ase( configs = [ Configuration.from_ase_atoms( config, - weight=copy.copy(weight), + weight=weight, energy_key=energy_key, forces_key=forces_key, ) - for config in ase_atoms_list + for config, weight_obj in zip(ase_atoms_list) ] else: try: @@ -882,10 +918,11 @@ def _read_from_ase( if len(all_files) == 1: # single xyz file with multiple configs all_configs = ase.io.read(all_files[0], index=slices) + configs = [ Configuration.from_ase_atoms( config, - weight=copy.copy(weight), + weight=weight, energy_key=energy_key, forces_key=forces_key, ) @@ -895,7 +932,7 @@ def _read_from_ase( configs = [ Configuration.from_ase_atoms( ase.io.read(f), - weight=copy.copy(weight), + weight=weight, energy_key=energy_key, forces_key=forces_key, ) @@ -914,7 +951,7 @@ def add_from_ase( self, path: Union[Path, str] = None, ase_atoms_list: List[ase.Atoms] = None, - weight: Optional[Weight] = None, + weight: Optional[Union[Weight, Path]] = None, energy_key: str = "energy", forces_key: str = "forces", slices: str = ":", @@ -940,6 +977,11 @@ def add_from_ase( path: Path the directory (or filename) storing the configurations. ase_atoms_list: A list of ase.Atoms objects. weight: an instance that computes the weight of the configuration in the loss + function. If a path is provided, it is used to read the weight from the + file. The file must be a plain text file with 4 whitespace separated + columns: config_weight, energy_weight, forces_weight, and stress_weight. + Length of the file must be equal to the number of configurations, or 1 + (in which case the same weight is used for all configurations). energy_key: Name of the field in extxyz/ase.Atoms that stores the energy. forces_key: Name of the field in extxyz/ase.Atoms that stores the forces. slices: Slice of the configurations to read. It is used only when `path` is @@ -948,9 +990,22 @@ def add_from_ase( """ if isinstance(path, str): path = Path(path) - configs = self._read_from_ase( - path, ase_atoms_list, weight, energy_key, forces_key, slices - ) + + if isinstance(weight, Weight): + configs = self._read_from_ase( + path, + ase_atoms_list, + weight, + energy_key, + forces_key, + slices, + file_format, + ) + else: + configs = self._read_from_ase( + path, ase_atoms_list, None, energy_key, forces_key, slices, file_format + ) + self.add_weights(weight) self.configs.extend(configs) def get_configs(self) -> List[Configuration]: @@ -969,17 +1024,399 @@ def __len__(self) -> int: """ return len(self.configs) - def __getitem__(self, idx) -> Configuration: + def __getitem__( + self, idx: Union[int, np.ndarray, List] + ) -> Union[Configuration, "Dataset"]: """ - Get the configuration at index `idx`. + Get the configuration at index `idx`. If the index is a list, it returns a new + dataset with the configurations at the indices. Args: - idx: Index of the configuration to get. + idx: Index of the configuration to get or a list of indices. 
Returns: - The configuration at index `idx`. + The configuration at index `idx` or a new dataset with the configurations at + the indices. """ - return self.configs[idx] + if isinstance(idx, int): + return self.configs[idx] + else: + configs = [self.configs[i] for i in idx] + return Dataset(configs) + + def save_weights(self, path: Union[Path, str]): + """ + Save the weights of the configurations to a file. + + Args: + path: Path of the file to save the weights. + """ + path = to_path(path) + with path.open("w") as f: + for config in self.configs: + f.write( + f"{config.weight.config_weight} " + + f"{config.weight.energy_weight} " + + f"{config.weight.forces_weight} " + + f"{config.weight.stress_weight}\n" + ) + + def add_weights(self, path: Union[Path, str]): + """ + Load weights from a text file. The text file should contain 1 to 4 columns, + whitespace seperated, formatted as, + ``` + Config Energy Forces Stress + 1.0 0.0 10.0 0.0 + ``` + ```{note} + The column headers are case-insensitive, but should have same name as above. + The weight of 0.0 will set respective weight as `None`. The length of column can + be either 1 (all configs same weight) or n, where n is the number of configs in + the dataset. + ``` + Missing columns are treated as 0.0, i.e. above example file can also be written + as + ``` + Config Forces + 1.0 10.0 + ``` + + Args: + path: Path to the configuration file + + """ + if path is None: + logger.info("No weights provided.") + return + + weights_data = np.genfromtxt(path, names=True) + weights_col = weights_data.dtype.names + + # sanity checks + if 1 > len(weights_col) > 4: + raise DatasetError( + "Weights file contains improper number of cols," + "there needs to be at least 1 col, and at most 4" + ) + + if not (weights_data.size == 1 or weights_data.size == len(self)): + raise DatasetError( + "Weights file contains improper number of rows," + "there can be either 1 row (all weights same), " + "or same number of rows as the configurations." + ) + + expected_cols = {"config", "energy", "forces", "stress"} + missing_cols = expected_cols - set([col.lower() for col in weights_col]) + + # missing weights are set to 0.0 + weights = {k.lower(): weights_data[k] for k in weights_col} + for fields in missing_cols: + weights[fields] = np.zeros_like(weights["config"]) + + # set weights + for i, config in enumerate(self.configs): + config.weight = Weight( + config_weight=weights["config"][i], + energy_weight=weights["energy"][i], + forces_weight=weights["forces"][i], + stress_weight=weights["stress"][i], + ) + + def add_metadata(self, metadata: dict): + """ + Add metadata to the dataset object. + + Args: + metadata: A dictionary containing the metadata. + """ + if not isinstance(metadata, dict): + raise DatasetError("metadata must be a dictionary.") + self._metadata.update(metadata) + + def get_metadata(self, key: str): + """ + Get the metadata of the dataset. + + Args: + key: Key of the metadata to get. + + Returns: + Value of the metadata. + """ + return self._metadata[key] + + @property + def metadata(self): + """ + Return the metadata of the dataset. + """ + return self._metadata + + def check_properties_consistency(self, properties: List[str] = None): + """ + Check which of the properties of the configurations are consistent. These + consistent properties are saved a list which can be used to get the attributes + from the configurations. "Consistent" in this context means that same property + is available for all the configurations. 
A property is not considered consistent + if it is None for any of the configurations. + + Args: + properties: List of properties to check for consistency. If None, no + properties are checked. All consistent properties are saved in the + metadata. + """ + if properties is None: + logger.warning("No properties provided to check for consistency.") + return + + property_list = list(copy.deepcopy(properties)) # make it mutable, if not + for config in self.configs: + for prop in property_list: + try: + getattr(config, prop) + except ConfigurationError: + property_list.remove(prop) + + self.add_metadata({"consistent_properties": tuple(property_list)}) + logger.info( + f"Consistent properties: {property_list}, stored in metadata key: `consistent_properties`" + ) + + @staticmethod + def get_manifest_checksum( + dataset_manifest: dict, transform_manifest: Optional[dict] = None + ) -> str: + """ + Get the checksum of the dataset manifest. + + Args: + dataset_manifest: Manifest of the dataset. + transform_manifest: Manifest of the transformation. + + Returns: + Checksum of the manifest. + """ + dataset_str = json.dumps(dataset_manifest, sort_keys=True) + if transform_manifest: + transform_str = json.dumps(transform_manifest, sort_keys=True) + dataset_str += transform_str + return hashlib.md5(dataset_str.encode()).hexdigest() + + @staticmethod + def get_dataset_from_manifest( + dataset_manifest: dict, transform_manifest: Optional[dict] = None + ) -> "Dataset": + """ + Get a dataset from a manifest. + + Examples: + 1. Manifest file for initializing dataset using ASE parser: + ```yaml + dataset: + type: ase # ase or path or colabfit + path: Si.xyz # Path to the dataset + save: True # Save processed dataset to a file + save_path: /folder/to # Save to this folder + shuffle: False # Shuffle the dataset + weights: /path/to/weights.dat # or dictionary with weights + keys: + energy: Energy # Key for energy, if ase dataset is used + forces: forces # Key for forces, if ase dataset is used + ``` + + 2. Manifest file for initializing dataset using KLIFF extxyz parser: + ```yaml + dataset: + type: path # ase or path or colabfit + path: /all/my/xyz # Path to the dataset + save: False # Save processed dataset to a file + shuffle: False # Shuffle the dataset + weights: # same weight for all, or file with weights + config: 1.0 + energy: 0.0 + forces: 10.0 + stress: 0.0 + ``` + + 3. Manifest file for initializing dataset using ColabFit parser: + ```yaml + dataset: + type: colabfit # ase or path or colabfit + save: False # Save processed dataset to a file + shuffle: False # Shuffle the dataset + weights: None + colabfit_dataset: + dataset_name: + database_name: + database_url: + ``` + + For dataset transformation, the transform manifest should be provided. Example, + ```yaml + property: + - energy: + name: NormalizedPropertyTransform + kwargs: + keep_original: True + - forces: + name: RMSNormalizePropertyTransform + kwargs: + keep_original: True + + configuration: # optional: generate fingerprints from the configuration + name: Descriptor # Graph, Descriptor, None + kwargs: + cutoff: 3.7 + species: ["Si"] + descriptor: "SymmetryFunctions" + hyperparameters: set51 + ``` + + TODO: Cross-validation splits, stratified splits, etc. + + Args: + dataset_manifest: List of configurations. + transform_manifest: List of configurations. + + Returns: + A dataset of configurations. 
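+        Example: a minimal sketch of driving this method directly from Python
+        (the file name and keys are illustrative):
+
+        ```python
+        manifest = {
+            "type": "ase",
+            "path": "Si.xyz",
+            "keys": {"energy": "Energy", "forces": "forces"},
+            "weights": {"config": 1.0, "energy": 1.0, "forces": 0.5},
+        }
+        dataset = Dataset.get_dataset_from_manifest(manifest)
+        ```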
+ """ + dataset_type = dataset_manifest.get("type").lower() + if ( + dataset_type != "ase" + and dataset_type != "path" + and dataset_type != "colabfit" + ): + raise DatasetError(f"Dataset type {dataset_type} not supported.") + weights = dataset_manifest.get("weights", None) + if weights is not None: + if isinstance(weights, str): + weights = Path(weights) + elif isinstance(weights, dict): + weights = Weight( + config_weight=weights.get("config", 0.0), + energy_weight=weights.get("energy", 0.0), + forces_weight=weights.get("forces", 0.0), + stress_weight=weights.get("stress", 0.0), + ) + else: + raise DatasetError("Weights must be a path or a dictionary.") + + if dataset_type == "ase": + dataset = Dataset.from_ase( + path=dataset_manifest.get("path", "."), + weight=weights, + energy_key=dataset_manifest.get("keys", {}).get("energy", "energy"), + forces_key=dataset_manifest.get("keys", {}).get("forces", "forces"), + ) + elif dataset_type == "path": + dataset = Dataset.from_path( + path=dataset_manifest.get("path", "."), + weight=weights, + ) + elif dataset_type == "colabfit": + try: + colabfit_dataset = dataset_manifest.get("colabfit_dataset") + colabfit_database = colabfit_dataset.database_name + except KeyError: + raise DatasetError("Colabfit dataset or database not provided.") + colabfit_uri = dataset_manifest.get( + "colabfit_uri", "mongodb://localhost:27017" + ) + + dataset = Dataset.from_colabfit( + colabfit_database=colabfit_database, + colabfit_dataset=colabfit_dataset, + colabfit_uri=colabfit_uri, + weight=weights, + ) + else: + # this should not happen + raise DatasetError(f"Dataset type {dataset_type} not supported.") + + # transforms? + if transform_manifest: + configuration_transform: Union[dict, None] = transform_manifest.get( + "configuration", None + ) + property_transform: Union[list, None] = transform_manifest.get( + "property", None + ) + + if property_transform: + for property_to_transform in property_transform: + property_name = property_to_transform.get("name", None) + if not property_name: + continue # it is probably an empty propery + transform_module_name = property_to_transform[property_name].get( + "name", None + ) + if not transform_module_name: + raise DatasetError( + "Property transform module name not provided." + ) + property_transform_module = importlib.import_module( + f"kliff.transforms.property_transforms" + ) + property_module = getattr( + property_transform_module, transform_module_name + ) + property_module = property_module( + proprty_key=property_name, + **property_to_transform[property_name].get("kwargs", {}), + ) + dataset = property_module(dataset) + + if configuration_transform: + configuration_module_name: Union[str, None] = ( + configuration_transform.get("name", None) + ) + if not configuration_module_name: + logger.warning( + "Configuration transform module name not provided." + "Skipping configuration transform." + ) + else: + configuration_transform_module = importlib.import_module( + f"kliff.transforms.configuration_transforms" + ) + configuration_module = getattr( + configuration_transform_module, configuration_module_name + ) + kwargs: Union[dict, None] = configuration_transform.get( + "kwargs", None + ) + if not kwargs: + raise DatasetError( + "Configuration transform module options not provided." 
+ ) + configuration_module = configuration_module( + **kwargs, copy_to_config=True + ) + + for config in dataset.configs: + _ = configuration_module(config) + + # dataset hash + dataset_checksum = Dataset.get_manifest_checksum( + dataset_manifest, transform_manifest + ) + dataset.add_metadata({"checksum": dataset_checksum}) + + if dataset_manifest.get("save", False): + # TODO: use Path for compatibility + dataset_save_path = dataset_manifest.get("save_path", "./") + logger.info( + f"Saving dataset to {dataset_save_path}/DS_{dataset_checksum[:10]}.pkl" + ) + dill.dump( + dataset, + open(f"{dataset_save_path}/DS_{dataset_checksum[:10]}.pkl", "wb"), + ) + + return dataset class ConfigurationError(Exception): diff --git a/kliff/dataset/weight.py b/kliff/dataset/weight.py index dfd2727d..fe5caa7c 100644 --- a/kliff/dataset/weight.py +++ b/kliff/dataset/weight.py @@ -44,18 +44,34 @@ def compute_weight(self, config): def config_weight(self): return self._config_weight + @config_weight.setter + def config_weight(self, value): + self._config_weight = value + @property def energy_weight(self): return self._energy_weight + @energy_weight.setter + def energy_weight(self, value): + self._energy_weight = value + @property def forces_weight(self): return self._forces_weight + @forces_weight.setter + def forces_weight(self, value): + self._forces_weight = value + @property def stress_weight(self): return self._stress_weight + @stress_weight.setter + def stress_weight(self, value): + self._stress_weight = value + def _check_compute_flag(self, config): """ Check whether compute flag correctly set when the corresponding weight in diff --git a/kliff/models/kim.py b/kliff/models/kim.py index 9adf2164..a82f3ffb 100644 --- a/kliff/models/kim.py +++ b/kliff/models/kim.py @@ -1,8 +1,12 @@ +import importlib import os +import subprocess +import tarfile from collections import OrderedDict from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import kimpy import numpy as np from loguru import logger @@ -12,6 +16,7 @@ from kliff.models.model import ComputeArguments, Model from kliff.models.parameter import Parameter from kliff.neighbor import assemble_forces, assemble_stress +from kliff.utils import install_kim_model, is_kim_model_installed try: import kimpy @@ -21,6 +26,13 @@ except ImportError: kimpy_avail = False +# list of model drivers that are not supported by this trainer. +# example quip, torchml, etc. +# TODO: Get the complete list of unsupported model drivers. +UNSUPPORTED_MODEL_DRIVERS = [ + "TorchML", +] + class KIMComputeArguments(ComputeArguments): """ @@ -88,6 +100,8 @@ def __init__( self._update_neigh(influence_distance) self._register_data(compute_energy, compute_forces) + self.model_trainable_via_kim_api = False + def _get_implemented_property(self): """ Get implemented property of model. @@ -681,6 +695,225 @@ def __call__( return kim_ca_instance.results + @staticmethod + def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None): + """ + Get the model from a configuration. If it is a valid KIM model, it will return + the KIMModel object. If it is a TorchML model, it will return the torch + ReverseScriptedModule object *in future*. Else raise error. If the model is a tarball, it + will extract and install the model. 
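+        For the tarball case, the archive is extracted under ``model_path`` and
+        the extracted model is installed with
+        ``kim-api-collections-management install --force`` into the requested
+        ``model_collection``.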
+ + Example `model_manifest`: + ```yaml + model: + model_type: kim # kim or torch or tar + model_path: ./ + model_name: SW_StillingerWeber_1985_Si__MO_405512056662_006 # KIM model name, installed if missing + model_collection: "user" + ``` + + Example `param_manifest`: + ```yaml + parameter: + - A # dict means the parameter is transformed + - B # these are the parameters that are not transformed + - sigma: + transform_name: LogParameterTransform + value: 2.0 + bounds: [[1.0, 10.0]] + ``` + + ```{note} + `parameter` block is usually defined as the children of the `transform` block + in trainer configuration file. + ``` + + Args: + model_manifest: configuration object + param_manifest: parameter transformation configuration + + Returns: + Model object + """ + model_name: Union[None, str] = model_manifest.get("model_name", None) + model_type: Union[None, str] = model_manifest.get("model_type", None) + model_path: Union[None, str, Path] = model_manifest.get("model_path", None) + model_driver = KIMModel.get_model_driver_name(model_name) + model_collection = model_manifest.get("model_collection") + + if model_driver in UNSUPPORTED_MODEL_DRIVERS: + logger.error( + "Model driver not supported for KIM-API based training. " + "Please use appropriate trainer for this model." + ) + raise KIMModelError( + f"Model driver {model_driver} not supported for KIMModel training." + ) + + # ensure model is installed + if model_type.lower() == "kim": + is_model_installed = install_kim_model(model_name, model_collection) + if not is_model_installed: + logger.error( + f"Mode: {model_name} neither installed nor available in the KIM API collections. Please check the model name and try again." + ) + raise KIMModelError(f"Model {model_name} not found.") + else: + logger.info( + f"Model {model_name} is present in {model_collection} collection." + ) + elif model_type.lower() == "tar": + archive_content = tarfile.open(model_path + "/" + model_name) + model = archive_content.getnames()[0] + archive_content.extractall(model_path) + subprocess.run( + [ + "kim-api-collections-management", + "install", + "--force", + model_collection, + model_path + "/" + model, + ], + check=True, + ) + logger.info( + f"Tarball Model {model} installed in {model_collection} collection." + ) + else: + raise KIMModelError(f"Model type {model_type} not supported.") + + model = KIMModel(model_name) + + if param_manifest: + mutable_param_list = [] + for param_to_transform in param_manifest.get("parameter", []): + if isinstance(param_to_transform, dict): + parameter_name = list(param_to_transform.keys())[0] + elif isinstance(param_to_transform, str): + parameter_name = param_to_transform + else: + raise KIMModelError(f"Parameter can be a str or dict") + mutable_param_list.append(parameter_name) + + model.set_params_mutable(mutable_param_list) + model_param_list = model.parameters() + + # apply transforms if needed + for model_params, input_params in zip( + model_param_list, param_manifest.get("parameter", []) + ): + if isinstance(input_params, dict): + param_name = list(input_params.keys())[0] + if param_name != model_params.name: + raise KIMModelError( + f"Parameter name mismatch. Expected {model_params.name}, got {param_name}." 
+ ) + + param_value_dict = input_params[param_name] + transform_name = param_value_dict.get("transform_name", None) + params_value = param_value_dict.get("value", None) + bounds = param_value_dict.get("bounds", None) + + if transform_name is not None: + transform_module = getattr( + importlib.import_module( + f"kliff.transforms.parameter_transforms" + ), + transform_name, + ) + transform_module = transform_module() + model_params.add_transform(transform_module) + + if params_value is not None: + model_params.copy_from_model_space(params_value) + + if bounds is not None: + model_params.add_bounds_model_space(np.array(bounds)) + + elif isinstance(input_params, str): + if input_params != model_params.name: + raise KIMModelError( + f"Parameter name mismatch. Expected {model_params.name}, got {input_params}." + ) + else: + raise KIMModelError( + f"Optimizable parameters must be string or value dict. Got {input_params} instead." + ) + + return model + + @staticmethod + def get_model_driver_name(model_name: str) -> Union[str, None]: + """ + Get the model driver from the model name. It will return the model driver + string from the installed KIM API model. If the model is not installed, and the + model name is a tarball, it will extract the model driver name from the CMakeLists.txt. + This is needed to ensure that it excludes the model drivers that it cannot handle. + Example: TorchML driver based models. These models are to be trained using the + TorchTrainer. + + TODO: This is not a clean solution. I think KIMPY must have a better way to handle this. + Ask Mingjian/Yaser for comment. + + Args: + model_name: name of the model. + + Returns: + Model driver name. + """ + # check if model is tarball + if "tar" in model_name: + return KIMModel._get_model_driver_name_for_tarball(model_name) + + collections = kimpy.collections.create() + try: + shared_obj_path, collection = ( + collections.get_item_library_file_name_and_collection( + kimpy.collection_item_type.portableModel, model_name + ) + ) + except RuntimeError: # not a portable model + return None + shared_obj_content = open(shared_obj_path, "rb").read() + md_start_idx = shared_obj_content.find(b"model-driver") + + if md_start_idx == -1: + return None + else: + md_start_idx += 15 # length of 'model-driver" "' + md_end_idx = shared_obj_content.find(b'"', md_start_idx) + return shared_obj_content[md_start_idx:md_end_idx].decode("utf-8") + + @staticmethod + def _get_model_driver_name_for_tarball(tarball: str) -> Union[str, None]: + """ + Get the model driver name from the tarball. It will extract the model driver + name from the CMakeLists.txt file in the tarball. This is needed to ensure that + it excludes the model drivers that it cannot handle. Example: TorchML driver based + models. These models are to be trained using the TorchTrainer. + + Args: + tarball: path to the tarball. + + Returns: + Model driver name. 
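+        The lookup is purely textual: it returns the first quoted value that
+        follows the ``DRIVER_NAME`` keyword in the archive's ``CMakeLists.txt``,
+        e.g. an entry of the form ``DRIVER_NAME "SW__MD_335816936951_004"``
+        (driver name shown only for illustration).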
+ """ + archive_content = tarfile.open(tarball) + cmake_file_path = archive_content.getnames()[0] + "/CMakeLists.txt" + cmake_file = archive_content.extractfile(cmake_file_path) + cmake_file_content = cmake_file.read().decode("utf-8") + + md_start_idx = cmake_file_content.find("DRIVER_NAME") + if md_start_idx == -1: + return None + else: + # name strats at " + md_start_idx = cmake_file_content.find('"', md_start_idx) + 1 + if md_start_idx == -1: + return None + md_end_idx = cmake_file_content.find('"', md_start_idx) + return cmake_file_content[md_start_idx:md_end_idx] + class KIMModelError(Exception): def __init__(self, msg): diff --git a/kliff/trainer/__init__.py b/kliff/trainer/__init__.py new file mode 100644 index 00000000..d669e814 --- /dev/null +++ b/kliff/trainer/__init__.py @@ -0,0 +1,2 @@ +from .base_trainer import Trainer +from .kim_trainer import KIMTrainer diff --git a/kliff/trainer/base_trainer.py b/kliff/trainer/base_trainer.py new file mode 100644 index 00000000..184d1756 --- /dev/null +++ b/kliff/trainer/base_trainer.py @@ -0,0 +1,538 @@ +import hashlib +import json +import os +import random +from copy import deepcopy +from datetime import datetime, timedelta +from glob import glob +from pathlib import Path +from typing import Callable, Union + +import dill # TODO: include dill in requirements.txt +import numpy as np +import yaml +from loguru import logger + +from kliff.dataset import Dataset + + +class Trainer: + """Base class for all trainers. + + This class is the base class for all trainers. It provides the basic structure for + training a model. The derived classes should implement the required methods. This + class will provide the basic functionality for training, such as setting up the + work directory, saving the configuration, and setting up the indices for training + and validation datasets. It will save hashes of the configuration fingerprints and + training configuration to the work directory. This would ensure reproducibility of + the training process, and easy restarting. + + The core trainer class will provide the following functionality: + - Set up the work directory + - Set up the dataset + - Set up the test train split + Model, parameter transform and optimizer setup are left for the derived classes to + implement. 
+ + Args: + training_manifest: training manifest + """ + + def __init__(self, training_manifest: dict): + # workspace variables + self.workspace: dict = { + "name": "kliff_workspace", + "seed": 12345, + "resume": False, + "walltime": "2:00:00:00", + } + + # dataset variables + self.dataset_manifest: dict = { + "type": "kliff", + "path": "./", + "save": False, + "shuffle": False, + "ase_keys": {"energy": "energy", "forces": "forces"}, + "colabfit_dataset": { + "dataset_name": None, + "database_name": None, + "database_url": None, + }, + } + self.dataset = None + + # model variables + self.model_manifest: dict = { + "type": "kim", + "name": None, + "path": None, + "instance": None, + } + self.model: Callable = None + + # transform variables + self.transform_manifest: dict = { + "property": [ + { + "name": None, + "property_key": None, + } + ], + "parameter": [], + "configuration": { + "name": None, + "kwargs": None, + }, + } + + # training variables + # this is too complicated to put it in singe dict, therefore the training + # block is divided into loss, optimizer, dataset_sample + self.training_manifest: dict = {} + self.loss_manifest: dict = { + "function": "mse", + "weights": { + "energy": 1.0, + "forces": None, + "stress": None, + "config": 1.0, + }, + "normalize_per_atom": False, + "loss_traj": False, + } + + self.optimizer_manifest: dict = { + "provider": "scipy", + "name": None, + "learning_rate": None, + "kwargs": None, + "epochs": 10000, + "stop_condition": None, + "num_workers": 1, + } + self.optimizer = None + + # part of current? + self.dataset_sample_manifest: dict = { + "train_size": None, + "val_size": None, + "indices_files": {"train": None, "val": None}, + "val_indices": None, + "train_indices": None, + } + self.train_dataset = None + self.val_dataset = None + + # export trained model + self.export_manifest: dict = { + "model_type": None, + "model_name": None, + "model_path": None, + } + + # state variables + self.current: dict = { + "run_title": None, + "run_dir": None, + "run_hash": None, + "start_time": None, + "end_time": None, + "best_loss": None, + "best_model": None, + "loss": None, + "epoch": 0, + "step": 0, + "expected_end_time": None, + "warned_once": False, + "dataset_hash": None, + "appending_to_previous_run": False, + "verbose": False, + "ckpt_interval": 100, + } + self.parse_manifest(training_manifest) + self.initialize() + + def parse_manifest(self, manifest: dict): + """ + It accepts the raw manifest dictionary, and processes it to the formatted + manifest. This includes mapping the string fields to enums, and setting sane + defaults for missing fields. + + Args: + manifest: raw incoming configuration + + Returns: + Processed manifest + """ + _date_time_format = "%Y-%m-%d-%H-%M-%S" + start_time = datetime.now() + date_time_str = start_time.strftime(_date_time_format) + self.current["start_time"] = start_time + + # Workspace variables ################################################ + workspace_block: Union[None, dict] = manifest.get("workspace", None) + if workspace_block is None: + logger.warning( + "Workspace block not found in the configuration. Using default values." 
+ ) + else: + self.workspace |= workspace_block + + if isinstance(self.workspace["walltime"], int): + expected_end_time = datetime.now() + timedelta( + seconds=self.workspace["walltime"] + ) + else: + expected_end_time = datetime.now() + timedelta( + days=int(self.workspace["walltime"].split(":")[0]), + hours=int(self.workspace["walltime"].split(":")[1]), + minutes=int(self.workspace["walltime"].split(":")[2]), + seconds=int(self.workspace["walltime"].split(":")[3]), + ) + self.current["expected_end_time"] = expected_end_time + + # Dataset manifest ################################################# + dataset_manifest: Union[None, dict] = manifest.get("dataset", None) + if dataset_manifest is None: + raise TrainerError("Dataset block not found in the configuration. Exiting.") + + self.dataset_manifest |= dataset_manifest + + # model variables #################################################### + model_manifest: Union[None, dict] = manifest.get("model", None) + if model_manifest is None: + raise TrainerError("Model block not found in the configuration. Exiting.") + self.model_manifest |= model_manifest + + if self.model_manifest.get("name", None) is None: + self.current["run_title"] = ( + f"{self.model_manifest.get('type')}_{date_time_str}" + ) + else: + self.current["run_title"] = ( + f"{self.model_manifest.get('name')}_{date_time_str}" + ) + + # transform variables #################################################### + transform_manifest: Union[None, dict] = manifest.get("transforms", None) + if transform_manifest is None: + logger.warning( + "Transform block not found in the configuration. This is bit unusual." + ) + else: + self.transform_manifest |= transform_manifest + + # training variables ######################################################## + training_manifest: Union[None, dict] = manifest.get("training", None) + if training_manifest is None: + logger.warning( + "Training block not found in the configuration." + "Will try and resume the previous run if possible." + ) + # TODO: implement resume + self.training_manifest |= training_manifest + + if self.training_manifest.get("loss", None) is None: + logger.warning( + "Loss block not found in the configuration. Using default values." + ) + + self.loss_manifest |= self.training_manifest.get("loss") + + if self.training_manifest.get("optimizer", None) is None: + logger.warning( + "Optimizer block not found in the configuration." + "Will resume the previous run if possible." + ) + # TODO: implement resume + + self.optimizer_manifest |= self.training_manifest.get("optimizer") + self.optimizer_manifest["epochs"] = self.training_manifest.get("epochs", 10000) + self.optimizer_manifest["stop_condition"] = self.training_manifest.get( + "stop_condition", None + ) + self.optimizer_manifest["num_workers"] = self.training_manifest.get( + "num_workers", 1 + ) + + self.current["ckpt_interval"] = self.training_manifest.get("ckpt_interval", 100) + self.current["verbose"] = self.training_manifest.get("verbose", False) + + # dataset sample variables will be processed in the setup_dataset method + self.export_manifest |= manifest.get("export", {}) + + def config_to_dict(self): + """ + Convert the configuration to a dictionary. + """ + config = {} + config |= self.workspace + config |= self.dataset_manifest + config |= self.model_manifest + config |= self.transform_manifest + config |= self.training_manifest + return config + + @classmethod + def from_file(cls, filename: Path): + """ + Load the manifest from a YAML file. 
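+        This is a thin wrapper: the file is parsed with ``yaml.safe_load`` and
+        the resulting dictionary is passed to the constructor, e.g.
+        ``trainer = KIMTrainer.from_file("training_manifest.yaml")`` (file name
+        illustrative). Use a concrete subclass such as `KIMTrainer`, since the
+        base `Trainer` leaves model and optimizer setup unimplemented.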
+ + Args: + filename: name of the yaml file + + Returns: + Trainer instance + + """ + manifest = yaml.safe_load(open(filename, "r")) + return cls(manifest) + + def get_trainer_hash(self): + """ + Get the hash of the current configuration. It will be used to create a unique + directory for the current run. It will be the hash of the configuration dictionary + string. + """ + config = self.config_to_dict() + config_immut_str = json.dumps(config, sort_keys=True) + return hashlib.md5(config_immut_str.encode()).hexdigest() + + def initialize(self): + """ + Initialize the trainer. Assigns the configuration objects, and + call setup methods. + """ + # Step 1 - Assign the processed configuration objects to the class variables + # This has been done in the __init__ method + # Step 2 - Initialize all seeds + self.seed_all() + # Step 3 - Set up the workspace folder + self.setup_workspace() + # Step 4 - Read or load the dataset, initialize the property/configuration transforms + self.setup_dataset() + # Step 5 - Set up the test and train datasets, based on the provided indices + self.setup_test_train_datasets() + # Step 6 - Set up the model + self.setup_model() + # Step 6.5 - Setup parameter transform + self.setup_parameter_transforms() + # Step 7 - Set up the optimizer + self.setup_optimizer() + # Step 8 - Save the configuration for future + self.save_config() + + def seed_all(self): + """ + Seed all the random number generators. + """ + np.random.seed(self.workspace["seed"]) + random.seed(self.workspace["seed"]) + # torch.manual_seed(self.seed) # enable torch seed in respective children + # torch.cuda.manual_seed_all(self.seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = False + + def setup_workspace(self): + """ + Check all the existing runs in the root directory and see if it finished the run + """ + dir_list = sorted(glob(f"{self.workspace['name']}*")) + if len(dir_list) == 0 or not self.workspace["resume"]: + self.current["appending_to_previous_run"] = False + self.current["run_dir"] = ( + f"{self.workspace['name']}/{self.current['run_title']}" + ) + os.makedirs(self.current["run_dir"], exist_ok=True) + else: + last_dir = dir_list[-1] + was_it_finished = os.path.exists(f"{last_dir}/.finished") + if was_it_finished: # start new run + current_run_dir = ( + f"{self.workspace['name']}/{self.current['run_title']}" + ) + os.makedirs(current_run_dir, exist_ok=True) + self.current["appending_to_previous_run"] = False + else: + self.current["appending_to_previous_run"] = True + self.current["run_dir"] = dir_list[-1] + + def setup_dataset(self): + """ + Set up the dataset based on the provided information. + + TODO: It will check the {workspace}/{dataset_name} directory to see if dataset hash + is already there. The dataset hash is determined but hashing the full path to + the dataset + transforms names + configuration transform properties. + TODO: reload hashed dataset if it exists. + """ + dataset_module_manifest = deepcopy(self.dataset_manifest) + dataset_module_manifest["weights"] = self.loss_manifest["weights"] + self.dataset = Dataset.get_dataset_from_manifest( + dataset_module_manifest, self.transform_manifest + ) + + def save_config(self): + """ + Hash and save the configuration to the current run directory. 
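+        The configuration is written to ``{run_dir}/{config_hash}.yaml``, where
+        the hash comes from `get_trainer_hash`.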
+ """ + config_hash = self.get_trainer_hash() + config_file = f"{self.current['run_dir']}/{config_hash}.yaml" + with open(config_file, "w") as f: + yaml.dump(self.config_to_dict(), f, default_flow_style=False) + logger.info(f"Configuration saved in {config_file}.") + + def setup_model(self): + """ + Set up the model based on the provided information. If the model is not provided, + it will be loaded from the model_path. If the model_path is not provided, it will + raise an error. If the model_type is KIM, it will be loaded from the KIM model + repository. If KIM type model is installed in CWD, it will be loaded from there, and + model_path will be set to the KIM CWD. If model is of type TAR, it will be untarred + and model_path will be set to the untarred directory. Left for the derived classes + to implement. + """ + raise TrainerError("setup_model not implemented.") + + def setup_parameter_transforms(self): + """ + This method set up the transformed parameter space for models. It can be used + for any model type in general, but as there exists a significant difference + between how models handles their parameters, it is left for the subclass to + implement. Although to ensure that `initialize` function remains consistent + this method will not raise NotImplemented error, rather it will quietly pass. + So be aware. + """ + pass + + def setup_optimizer(self): + """ + Set up the optimizer based on the provided information. If the optimizer is not + provided, it will be loaded from the optimizer_name. If the optimizer_name is not + provided, it will raise an error. If the optimizer_provider is scipy, it will be + loaded from the scipy.optimize. If the optimizer_provider is torch, it will be + loaded from the torch.optim. Left for the derived classes to implement. + """ + raise TrainerError("setup_optimizer not implemented.") + + def setup_test_train_datasets(self): + """ + Simple test train split for now, will have more options like stratification + in the future. + + """ + # test train splits + train_size = self.dataset_sample_manifest.get("train_size", len(self.dataset)) + val_size = self.dataset_sample_manifest.get("val_size", 0) + + # sanity checks + if not isinstance(train_size, int) or train_size < 1: + logger.warning( + "Train size is not provided or is less than 1. Using full dataset for training." + ) + train_size = len(self.dataset) + + if not isinstance(val_size, int) or val_size < 0: + logger.warning( + "Val size is not provided or is less than 0. Using 0 for validation." + ) + val_size = 0 + + if train_size + val_size > len(self.dataset): + raise TrainerError( + "Sum of train, val, and test sizes is greater than the dataset size." 
+ ) + + # check if indices are provided + train_indices = self.dataset_sample_manifest.get("train_indices") + if train_indices is None: + train_indices = np.arange(train_size, dtype=int) + elif isinstance(train_indices, str): + train_indices = np.genfromtxt(train_indices, dtype=int) + else: + TrainerError("train_indices should be a numpy array or a path to a file.") + + val_indices = self.dataset_sample_manifest.get("val_indices") + if val_indices is None: + val_indices = np.arange(train_size, train_size + val_size, dtype=int) + elif isinstance(val_indices, str): + val_indices = np.genfromtxt(val_indices, dtype=int) + else: + TrainerError("val_indices should be a numpy array or a path to a file.") + + if self.dataset_manifest.get("shuffle", False): + # instead of shuffling the main dataset, validation/train indices are shuffled + # this gives better control over future active learning scenarios + np.random.shuffle(train_indices) + np.random.shuffle(val_indices) + + train_dataset = self.dataset[train_indices] + + if val_size > 0: + val_dataset = self.dataset[val_indices] + else: + val_dataset = None + + self.dataset_sample_manifest["train_size"] = train_size + self.dataset_sample_manifest["val_size"] = val_size + self.dataset_sample_manifest["train_indices"] = train_indices + self.dataset_sample_manifest["val_indices"] = val_indices + + self.train_dataset = train_dataset + self.val_dataset = val_dataset + + # save the indices if generated + if isinstance(train_indices, str): + self.dataset_sample_manifest["indices_files"]["train"] = train_indices + else: + self.dataset_sample_manifest["indices_files"][ + "train" + ] = f"{self.current['run_dir']}/train_indices.txt" + np.savetxt( + self.dataset_sample_manifest["indices_files"]["train"], + train_indices, + fmt="%d", + ) + + if isinstance(val_indices, str): + self.dataset_sample_manifest["indices_files"]["val"] = val_indices + else: + self.dataset_sample_manifest["indices_files"][ + "val" + ] = f"{self.current['run_dir']}/val_indices.txt" + np.savetxt( + self.dataset_sample_manifest["indices_files"]["val"], + val_indices, + fmt="%d", + ) + + def loss(self, *args, **kwargs): + raise TrainerError("loss not implemented.") + + def checkpoint(self, *args, **kwargs): + raise TrainerError("checkpoint not implemented.") + + def train_step(self, *args, **kwargs): + raise TrainerError("train_step not implemented.") + + def validation_step(self, *args, **kwargs): + raise TrainerError("validation_step not implemented.") + + def get_optimizer(self, *args, **kwargs): + raise TrainerError("get_optimizer not implemented.") + + def train(self, *args, **kwargs): + raise TrainerError("train not implemented.") + + def save_kim_model(self, *args, **kwargs): + raise TrainerError("save_kim_model not implemented.") + + +class TrainerError(Exception): + """ + Exceptions to be raised in Trainer and associated classes. + """ + + def __init__(self, message): + super().__init__(message) diff --git a/kliff/trainer/kim_residuals.py b/kliff/trainer/kim_residuals.py new file mode 100644 index 00000000..eced653b --- /dev/null +++ b/kliff/trainer/kim_residuals.py @@ -0,0 +1,19 @@ +from typing import Any, Dict + +import numpy as np + + +def MSE_residuals( + predictions: np.ndarray, + targets: np.ndarray, +) -> np.ndarray: + r""" + Compute the mean squared error (MSE) of the residuals. + + Args: + + Returns: + The MSE of the residuals. 
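+    Example (illustrative values): ``MSE_residuals(np.array([1.0, 2.0]),
+    np.array([0.0, 0.0]))`` evaluates to ``2.5``.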
+ """ + residuals = predictions - targets + return np.mean(residuals**2) diff --git a/kliff/trainer/kim_trainer.py b/kliff/trainer/kim_trainer.py new file mode 100644 index 00000000..c0e5e214 --- /dev/null +++ b/kliff/trainer/kim_trainer.py @@ -0,0 +1,181 @@ +import importlib +import tarfile +from pathlib import Path + +from loguru import logger + +from kliff.models import KIMModel + +from .base_trainer import Trainer, TrainerError +from .kim_residuals import MSE_residuals + +SCIPY_MINIMIZE_METHODS = [ + "Nelder-Mead", + "Powell", + "CG", + "BFGS", + "Newton-CG", + "L-BFGS-B", + "TNC", + "COBYLA", + "SLSQP", + "trust-constr", + "dogleg", + "trust-ncg", + "trust-exact", + "trust-krylov", +] + + +class KIMTrainer(Trainer): + """ + This class extends the base Trainer class for training OpenKIM physics based models. + It will use the scipy optimizers. It will perform a check to exclude TorchML model + driver based models, as they would be handled by TorchTrainer. + """ + + def __init__(self, configuration: dict, collection: str = "user"): + self.collection = collection + + self.model_driver_name = None + self.parameters = None + self.mutable_parameters_list = [] + self.use_energy = True + self.use_forces = False + + super().__init__(configuration) + + self.loss_function = self._get_loss_fn() + + def setup_model(self): + """ + Load either the KIM model, or install it from the tarball. If the model driver + required is TorchML* family, then it will raise an error, as it should be handled + by the TorchTrainer. + """ + # check for unsupported model drivers + if ( + self.model_manifest["type"].lower() == "kim" + or self.model_manifest["type"].lower() == "tar" + ): + self.model = KIMModel.get_model_from_manifest( + self.model_manifest, self.transform_manifest + ) + else: + raise TrainerError( + f"Model type {self.model_manifest['type']} not supported." + ) + + self.parameters = self.model.get_model_params() + + def setup_optimizer(self): + """ + Set up the optimizer based on the provided information. If the optimizer is not + provided, it will be loaded from the optimizer_name. If the optimizer_name is not + provided, it will raise an error. If the optimizer_provider is scipy, it will be + loaded from the scipy.optimize. If the optimizer_provider is torch, it will be + loaded from the torch.optim. Left for the derived classes to implement. + """ + if self.optimizer_manifest["provider"] != "scipy": + raise TrainerError( + f"Optimizer provider {self.optimizer_manifest['provider']} not supported by KIMTrainer." + ) + + if self.optimizer_manifest["name"] not in SCIPY_MINIMIZE_METHODS: + raise TrainerError( + f"Optimizer not supported: {self.optimizer_manifest['name']}." + ) + optimizer_lib = importlib.import_module(f"scipy.optimize") + self.optimizer = getattr(optimizer_lib, "minimize") + + def loss(self, x): + """ + Compute the loss function for the given parameters. It will compute the loss + function. It seems like MPI might be only way to make it parallel as the + multiprocessing does not work with the KIM models. KIMPY models are not yet + pickelable. TODO: include MPI support. 
+ """ + # set the parameters + self.model.update_model_params(x) + # compute the loss + loss = 0.0 + for configuration in self.train_dataset: + compute_energy = True if configuration.weight.energy_weight else False + compute_forces = True if configuration.weight.forces_weight else False + compute_stress = True if configuration.weight.stress_weight else False + + prediction = self.model( + configuration, + compute_energy=compute_energy, + compute_forces=compute_forces, + compute_stress=compute_stress, + ) + + if configuration.weight.energy_weight: + loss += configuration.weight.energy_weight * self.loss_function( + prediction["energy"], configuration.energy + ) + if configuration.weight.forces_weight: + loss += configuration.weight.forces_weight * self.loss_function( + prediction["forces"], configuration.forces + ) + if configuration.weight.stress_weight: + loss += configuration.weight.stress_weight * self.loss_function( + prediction["stress"], configuration.stress + ) + loss *= configuration.weight.config_weight + + return loss + + def checkpoint(self, *args, **kwargs): + TrainerError("checkpoint not implemented.") + + def train_step(self, *args, **kwargs): + TrainerError("train_step not implemented.") + + def validation_step(self, *args, **kwargs): + TrainerError("validation_step not implemented.") + + def get_optimizer(self, *args, **kwargs): + TrainerError("get_optimizer not implemented.") + + def train(self, *args, **kwargs): + def _wrapper_func(x): + return self.loss(x) + + x = self.model.get_opt_params() + options = self.optimizer_manifest.get("kwargs", {}) + options["options"] = { + "maxiter": self.optimizer_manifest["epochs"], + "disp": self.current["verbose"], + } + result = self.optimizer( + _wrapper_func, x, method=self.optimizer_manifest["name"], **options + ) + + if result.success: + logger.info(f"Optimization successful: {result.message}") + self.model.update_model_params(result.x) + else: + logger.error(f"Optimization failed: {result.message}") + + def _get_loss_fn(self): + if self.loss_manifest["function"].lower() == "mse": + return MSE_residuals + + def save_kim_model(self): + if self.export_manifest["model_type"].lower() == "kim": + path = ( + Path(self.export_manifest["model_path"]) + / self.export_manifest["model_name"] + ) + self.model.write_kim_model(path) + elif self.export_manifest["model_type"] == "tar": + path = ( + Path(self.export_manifest["model_path"]) + / self.export_manifest["model_name"] + ) + self.model.write_kim_model(path) + tarfile_path = path.with_suffix(".tar.gz") + with tarfile.open(tarfile_path, "w:gz") as tar: + tar.add(path, arcname=path.name) diff --git a/kliff/utils.py b/kliff/utils.py index 5cddf2d0..57f3a1f4 100644 --- a/kliff/utils.py +++ b/kliff/utils.py @@ -1,6 +1,7 @@ import os import pickle import random +import subprocess import tarfile from collections.abc import Sequence from pathlib import Path @@ -222,3 +223,40 @@ def stress_to_tensor(input_stress: list) -> np.ndarray: stress[0, 1] = stress[1, 0] = input_stress[5] return stress + + +def is_kim_model_installed(model_name: str) -> bool: + """ + Check if the KIM model is installed in any collection. + + Args: + model_name: name of the model. + """ + model_list = subprocess.run( + ["kim-api-collections-management", "list"], capture_output=True, text=True + ) + if model_name in model_list.stdout: + return True + else: + return False + + +def install_kim_model(model_name: str, collection: str = "user") -> bool: + """ + Install the KIM model + + Args: + model_name: name of the model. 
+ collection: name of the collection. + + Returns: + True if the model is now installed, False otherwise. + """ + if not is_kim_model_installed(model_name): + output = subprocess.run( + ["kim-api-collections-management", "install", collection, model_name], + check=True, + ) + return output.returncode == 0 + else: + return True diff --git a/setup.py b/setup.py index f89573e8..db637906 100644 --- a/setup.py +++ b/setup.py @@ -113,7 +113,7 @@ def get_readme(): version=get_version(), packages=find_packages(), ext_modules=[sym_fn, bispectrum, neighlist, graph_module], - install_requires=["requests", "scipy", "pyyaml", "monty", "loguru"], + install_requires=["requests", "scipy", "pyyaml", "monty", "loguru", "dill"], extras_require={ "test": [ "pytest", @@ -126,6 +126,7 @@ def get_readme(): "ase", "libdescriptor", "torch_geometric", + "dill", ], "docs": [ "sphinx", diff --git a/tests/dataset/test_weight.py b/tests/dataset/test_weight.py index ec96fd2e..50405198 100644 --- a/tests/dataset/test_weight.py +++ b/tests/dataset/test_weight.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np from kliff.dataset import Dataset @@ -84,3 +86,21 @@ def _compute_magnitude_inverse_weight(c1, c2, norm): """ sigma = np.sqrt(c1**2 + (c2 * norm) ** 2) return 1 / sigma + + +def test_weight_from_file(): + ds = Dataset.from_ase( + Path(__file__).parents[1].joinpath("test_data/configs/Si_4.xyz"), + energy_key="Energy", + forces_key="force", + weight=Path(__file__).parents[1].joinpath("test_data/weights/Si_4_weights.dat"), + ) + configs = ds.get_configs() + assert len(configs) == 4 + assert configs[0].weight.config_weight == 1.0 + assert configs[0].weight.energy_weight == 0.5 + assert configs[0].weight.stress_weight == 1.0 + assert configs[3].weight.forces_weight == 0.5 + assert configs[3].weight.config_weight == 1.0 + assert configs[3].weight.energy_weight == 0.5 + assert configs[3].weight.stress_weight == 4.0 diff --git a/tests/test_data/weights/Si_4_weights.dat b/tests/test_data/weights/Si_4_weights.dat new file mode 100644 index 00000000..a80b9bb1 --- /dev/null +++ b/tests/test_data/weights/Si_4_weights.dat @@ -0,0 +1,4 @@ +1.0 0.5 0.5 1.0 +1.0 0.5 0.5 2.0 +1.0 0.5 0.5 3.0 +1.0 0.5 0.5 4.0
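For reference, a minimal end-to-end sketch of the trainer workflow added in this
diff (the manifest file name and its contents are placeholders, not files shipped
with this PR):

```python
import yaml

from kliff.trainer import KIMTrainer

# Load a training manifest; the top-level blocks (workspace, dataset, model,
# transforms, training, export) follow the docstring examples above.
with open("training_manifest.yaml") as f:  # hypothetical manifest file
    manifest = yaml.safe_load(f)

# Parses the manifest and runs dataset/model/optimizer setup in __init__.
trainer = KIMTrainer(manifest)

# Minimizes the weighted MSE loss with scipy.optimize.minimize.
trainer.train()

# Writes the fitted model according to the `export` block of the manifest.
trainer.save_kim_model()
```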