Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Design Review 1: KLIFF Trainer framework #172

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
501 changes: 469 additions & 32 deletions kliff/dataset/dataset.py

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions kliff/dataset/weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,34 @@ def compute_weight(self, config):
def config_weight(self):
    # Read access to the per-configuration weight.
    # (presumably decorated with @property just above this hunk — confirm in full file)
    return self._config_weight

@config_weight.setter
def config_weight(self, value):
    # Plain passthrough setter; no validation is applied to `value`.
    self._config_weight = value

@property
def energy_weight(self):
    # Read access to the weight applied to the energy term.
    return self._energy_weight

@energy_weight.setter
def energy_weight(self, value):
    # Plain passthrough setter; no validation is applied to `value`.
    self._energy_weight = value

@property
def forces_weight(self):
    # Read access to the weight applied to the forces term.
    return self._forces_weight

@forces_weight.setter
def forces_weight(self, value):
    # Plain passthrough setter; no validation is applied to `value`.
    self._forces_weight = value

@property
def stress_weight(self):
    # Read access to the weight applied to the stress term.
    return self._stress_weight

@stress_weight.setter
def stress_weight(self, value):
    # Plain passthrough setter; no validation is applied to `value`.
    self._stress_weight = value

def _check_compute_flag(self, config):
"""
Check whether compute flag correctly set when the corresponding weight in
Expand Down
233 changes: 233 additions & 0 deletions kliff/models/kim.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import importlib
import os
import subprocess
import tarfile
from collections import OrderedDict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import kimpy
import numpy as np
from loguru import logger

Expand All @@ -12,6 +16,7 @@
from kliff.models.model import ComputeArguments, Model
from kliff.models.parameter import Parameter
from kliff.neighbor import assemble_forces, assemble_stress
from kliff.utils import install_kim_model, is_kim_model_installed

try:
import kimpy
Expand All @@ -21,6 +26,13 @@
except ImportError:
kimpy_avail = False

# Model drivers that cannot be trained through the KIM-API-based trainer;
# e.g. TorchML-driver models must be trained with the Torch trainer instead.
# TODO: Get the complete list of unsupported model drivers (e.g. quip).
UNSUPPORTED_MODEL_DRIVERS = [
    "TorchML",
]


class KIMComputeArguments(ComputeArguments):
"""
Expand Down Expand Up @@ -88,6 +100,8 @@ def __init__(
self._update_neigh(influence_distance)
self._register_data(compute_energy, compute_forces)

self.model_trainable_via_kim_api = False

def _get_implemented_property(self):
"""
Get implemented property of model.
Expand Down Expand Up @@ -681,6 +695,225 @@ def __call__(

return kim_ca_instance.results

@staticmethod
def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):
    """
    Get the model from a configuration. If it is a valid KIM model, it will return
    the KIMModel object. If it is a TorchML model, it will return the torch
    ReverseScriptedModule object *in future*. Else raise error. If the model is a
    tarball, it will extract and install the model.

    Example `model_manifest`:
    ```yaml
    model:
        model_type: kim # kim or torch or tar
        model_path: ./
        model_name: SW_StillingerWeber_1985_Si__MO_405512056662_006 # KIM model name, installed if missing
        model_collection: "user"
    ```

    Example `param_manifest`:
    ```yaml
    parameter:
        - A # dict means the parameter is transformed
        - B # these are the parameters that are not transformed
        - sigma:
            transform_name: LogParameterTransform
            value: 2.0
            bounds: [[1.0, 10.0]]
    ```

    ```{note}
    `parameter` block is usually defined as the children of the `transform` block
    in trainer configuration file.
    ```

    Args:
        model_manifest: configuration object
        param_manifest: parameter transformation configuration

    Returns:
        Model object

    Raises:
        KIMModelError: if the model driver is unsupported, the manifest is
            incomplete, the model is not installable, or the parameter
            manifest is malformed.
    """
    model_name: Union[None, str] = model_manifest.get("model_name", None)
    model_type: Union[None, str] = model_manifest.get("model_type", None)
    model_path: Union[None, str, Path] = model_manifest.get("model_path", None)
    model_driver = KIMModel.get_model_driver_name(model_name)
    model_collection = model_manifest.get("model_collection")

    if model_driver in UNSUPPORTED_MODEL_DRIVERS:
        logger.error(
            "Model driver not supported for KIM-API based training. "
            "Please use appropriate trainer for this model."
        )
        raise KIMModelError(
            f"Model driver {model_driver} not supported for KIMModel training."
        )

    # Fail with a clear, catchable error instead of an AttributeError on
    # `None.lower()` when the manifest omits `model_type`.
    if model_type is None:
        raise KIMModelError("`model_type` missing from the model manifest.")

    # ensure model is installed
    if model_type.lower() == "kim":
        is_model_installed = install_kim_model(model_name, model_collection)
        if not is_model_installed:
            logger.error(
                f"Model: {model_name} neither installed nor available in the KIM API collections. Please check the model name and try again."
            )
            raise KIMModelError(f"Model {model_name} not found.")
        else:
            logger.info(
                f"Model {model_name} is present in {model_collection} collection."
            )
    elif model_type.lower() == "tar":
        # os.path.join handles both str and Path `model_path`; the previous
        # `model_path + "/" + model_name` raised TypeError for Path objects.
        tarball_path = os.path.join(model_path, model_name)
        # Close the archive deterministically (the previous code leaked the
        # file handle).
        with tarfile.open(tarball_path) as archive_content:
            model = archive_content.getnames()[0]
            # NOTE(review): extractall on an untrusted tarball permits path
            # traversal; consider `filter="data"` (Python >= 3.12) or
            # validating member names before extraction.
            archive_content.extractall(model_path)
        subprocess.run(
            [
                "kim-api-collections-management",
                "install",
                "--force",
                model_collection,
                os.path.join(model_path, model),
            ],
            check=True,
        )
        logger.info(
            f"Tarball Model {model} installed in {model_collection} collection."
        )
    else:
        raise KIMModelError(f"Model type {model_type} not supported.")

    model = KIMModel(model_name)

    if param_manifest:
        # Parameters may be given as bare names (str) or single-key dicts
        # carrying transform/value/bounds information.
        mutable_param_list = []
        for param_to_transform in param_manifest.get("parameter", []):
            if isinstance(param_to_transform, dict):
                parameter_name = list(param_to_transform.keys())[0]
            elif isinstance(param_to_transform, str):
                parameter_name = param_to_transform
            else:
                raise KIMModelError("Parameter can be a str or dict")
            mutable_param_list.append(parameter_name)

        model.set_params_mutable(mutable_param_list)
        model_param_list = model.parameters()

        # apply transforms if needed; model parameters are assumed to come
        # back in the same order they were declared mutable
        for model_params, input_params in zip(
            model_param_list, param_manifest.get("parameter", [])
        ):
            if isinstance(input_params, dict):
                param_name = list(input_params.keys())[0]
                if param_name != model_params.name:
                    raise KIMModelError(
                        f"Parameter name mismatch. Expected {model_params.name}, got {param_name}."
                    )

                param_value_dict = input_params[param_name]
                transform_name = param_value_dict.get("transform_name", None)
                params_value = param_value_dict.get("value", None)
                bounds = param_value_dict.get("bounds", None)

                if transform_name is not None:
                    # look the transform class up by name and instantiate it
                    transform_module = getattr(
                        importlib.import_module(
                            "kliff.transforms.parameter_transforms"
                        ),
                        transform_name,
                    )
                    transform_module = transform_module()
                    model_params.add_transform(transform_module)

                if params_value is not None:
                    model_params.copy_from_model_space(params_value)

                if bounds is not None:
                    model_params.add_bounds_model_space(np.array(bounds))

            elif isinstance(input_params, str):
                if input_params != model_params.name:
                    raise KIMModelError(
                        f"Parameter name mismatch. Expected {model_params.name}, got {input_params}."
                    )
            else:
                raise KIMModelError(
                    f"Optimizable parameters must be string or value dict. Got {input_params} instead."
                )

    return model

@staticmethod
def get_model_driver_name(model_name: str) -> Union[str, None]:
    """
    Get the model driver from the model name. It will return the model driver
    string from the installed KIM API model. If the model is not installed, and the
    model name is a tarball, it will extract the model driver name from the CMakeLists.txt.
    This is needed to ensure that it excludes the model drivers that it cannot handle.
    Example: TorchML driver based models. These models are to be trained using the
    TorchTrainer.

    TODO: This is not a clean solution. I think KIMPY must have a better way to handle this.
    Ask Mingjian/Yaser for comment.

    Args:
        model_name: name of the model.

    Returns:
        Model driver name, or None if it cannot be determined.
    """
    # check if model is tarball (crude substring check — any name containing
    # "tar" is routed to the tarball path; confirm this is intended)
    if "tar" in model_name:
        return KIMModel._get_model_driver_name_for_tarball(model_name)

    collections = kimpy.collections.create()
    try:
        shared_obj_path, collection = (
            collections.get_item_library_file_name_and_collection(
                kimpy.collection_item_type.portableModel, model_name
            )
        )
    except RuntimeError:  # not a portable model
        return None

    # Read the shared object via a context manager so the file handle is
    # closed (the previous code leaked it).
    with open(shared_obj_path, "rb") as shared_obj:
        shared_obj_content = shared_obj.read()

    md_start_idx = shared_obj_content.find(b"model-driver")
    if md_start_idx == -1:
        return None
    # The driver name is embedded as: model-driver" "<name>"
    md_start_idx += 15  # length of 'model-driver" "'
    md_end_idx = shared_obj_content.find(b'"', md_start_idx)
    return shared_obj_content[md_start_idx:md_end_idx].decode("utf-8")

@staticmethod
def _get_model_driver_name_for_tarball(tarball: str) -> Union[str, None]:
"""
Get the model driver name from the tarball. It will extract the model driver
name from the CMakeLists.txt file in the tarball. This is needed to ensure that
it excludes the model drivers that it cannot handle. Example: TorchML driver based
models. These models are to be trained using the TorchTrainer.

Args:
tarball: path to the tarball.

Returns:
Model driver name.
"""
archive_content = tarfile.open(tarball)
cmake_file_path = archive_content.getnames()[0] + "/CMakeLists.txt"
cmake_file = archive_content.extractfile(cmake_file_path)
cmake_file_content = cmake_file.read().decode("utf-8")

md_start_idx = cmake_file_content.find("DRIVER_NAME")
if md_start_idx == -1:
return None
else:
# name strats at "
md_start_idx = cmake_file_content.find('"', md_start_idx) + 1
if md_start_idx == -1:
return None
md_end_idx = cmake_file_content.find('"', md_start_idx)
return cmake_file_content[md_start_idx:md_end_idx]


class KIMModelError(Exception):
def __init__(self, msg):
Expand Down
2 changes: 2 additions & 0 deletions kliff/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base_trainer import Trainer
from .kim_trainer import KIMTrainer