Skip to content

Commit

Permalink
Merge pull request #434 from columnflow/feature/MLModel_settings
Browse files Browse the repository at this point in the history
add ml_model_settings parameter
  • Loading branch information
pkausw committed May 28, 2024
2 parents 2b8c051 + 76df8e7 commit 3253c97
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
34 changes: 32 additions & 2 deletions columnflow/ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ def __init__(
if "configs" in kwargs:
self._setup(kwargs["configs"])

def __str__(self):
    """
    Returns a string representation of this model instance. The string is composed of the class
    name, followed by a double underscore and :py:attr:`parameters_repr` if the latter is not
    empty.

    :returns: ``"<cls_name>"`` or ``"<cls_name>__<parameters_repr>"``.
    """
    model_str = f"{self.cls_name}"

    # read the property only once, every access re-creates the parameter hash
    parameters_repr = self.parameters_repr
    if parameters_repr:
        model_str += f"__{parameters_repr}"

    return model_str

@property
def config_inst(self: MLModel) -> od.Config:
if self.single_config and len(self.config_insts) != 1:
Expand Down Expand Up @@ -185,6 +195,26 @@ def _format_value(self: MLModel, value: Any) -> str:
# any other case
return str(value)

@property
def parameters_repr(self: MLModel) -> str:
    """
    Returns a hash of the joined string representation of all significant parameters. This is
    used to uniquely identify a model instance based on its parameters. The hash is stored on
    the instance and every later access verifies that a newly computed hash still matches it.

    :raises Exception: In case the hash differs from the one stored by a previous access.
    :returns: Hash of the parameter representation, or an empty string if no parameters are set.
    """
    if not self.parameters:
        return ""

    parameters_repr = law.util.create_hash(self._join_parameter_pairs(only_significant=True))

    # compare against the value stored by an earlier access, if there was one
    previous_repr = getattr(self, "_parameters_repr", parameters_repr)
    if previous_repr != parameters_repr:
        raise Exception(
            f"parameters_repr changed from {previous_repr} to {parameters_repr}; "
            "this should not happen",
        )

    self._parameters_repr = parameters_repr
    return self._parameters_repr

def _join_parameter_pairs(self: MLModel, only_significant: bool = True) -> str:
"""
Returns a joined string representation of all significant parameters. In this context,
Expand All @@ -198,11 +228,11 @@ def _join_parameter_pairs(self: MLModel, only_significant: bool = True) -> str:

def parameter_pairs(self: MLModel, only_significant: bool = False) -> list[tuple[str, Any]]:
    """
    Returns a sorted list of all parameter name-value tuples. In this context, significant parameters
    are those that potentially lead to different results (e.g. network architecture parameters
    as opposed to some log level).

    This implementation does not evaluate *only_significant* and always returns all parameters.
    Sorting makes the result independent of the order in which the parameters were set.

    :param only_significant: Not evaluated by this implementation.
    :returns: List of ``(name, value)`` tuples, sorted with :py:func:`sorted`.
    """
    return sorted(self.parameters.items())

@property
def accepts_scheduler_messages(self: MLModel) -> bool:
Expand Down
38 changes: 30 additions & 8 deletions columnflow/tasks/framework/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

from columnflow.types import Sequence, Any, Iterable, Union
from columnflow.tasks.framework.base import AnalysisTask, ConfigTask, RESOLVE_DEFAULT
from columnflow.tasks.framework.parameters import SettingsParameter
from columnflow.calibration import Calibrator
from columnflow.selection import Selector
from columnflow.production import Producer
from columnflow.weight import WeightProducer
from columnflow.ml import MLModel
from columnflow.inference import InferenceModel
from columnflow.columnar_util import Route, ColumnCollection
from columnflow.util import maybe_import
from columnflow.util import maybe_import, DotDict

ak = maybe_import("awkward")

Expand Down Expand Up @@ -1069,8 +1070,18 @@ class MLModelMixinBase(AnalysisTask):
description="the name of the ML model to be applied",
)

ml_model_settings = SettingsParameter(
default=DotDict(),
description="settings passed to the init function of the ML model",
)

exclude_params_repr_empty = {"ml_model"}

@property
def ml_model_repr(self):
    """
    String representation of the ML model instance of this task, obtained by passing
    :py:attr:`ml_model_inst` to :py:func:`str`.
    """
    model_inst = self.ml_model_inst
    return str(model_inst)

@classmethod
def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
"""Get the required parameters for the task, preferring the ``--ml-model`` set on task-level via CLI.
Expand Down Expand Up @@ -1389,7 +1400,11 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
analysis_inst = params["analysis_inst"]

# NOTE: we could try to implement resolving the default ml_model here
ml_model_inst = cls.get_ml_model_inst(params["ml_model"], analysis_inst)
ml_model_inst = cls.get_ml_model_inst(
params["ml_model"],
analysis_inst,
parameters=params["ml_model_settings"],
)
params["ml_model_inst"] = ml_model_inst

# resolve configs
Expand Down Expand Up @@ -1418,12 +1433,12 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:

def __init__(self, *args, **kwargs):
    """
    Initializes the task through the parent class and then stores the ML model instance in
    :py:attr:`ml_model_inst`.
    """
    super().__init__(*args, **kwargs)

    # the model instance is requested for all configs of this task and receives the
    # settings given on task level as its parameters
    get_inst = self.get_ml_model_inst
    self.ml_model_inst = get_inst(
        self.ml_model,
        self.analysis_inst,
        configs=list(self.configs),
        parameters=self.ml_model_settings,
    )

def store_parts(self) -> law.util.InsertableDict[str, str]:
Expand Down Expand Up @@ -1476,7 +1491,7 @@ def store_parts(self) -> law.util.InsertableDict[str, str]:
parts.insert_before("version", label, f"{label}__{part}")

if self.ml_model_inst:
parts.insert_before("version", "ml_model", f"ml__{self.ml_model_inst.cls_name}")
parts.insert_before("version", "ml_model", f"ml__{self.ml_model_repr}")

return parts

Expand Down Expand Up @@ -1517,6 +1532,7 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
params["ml_model"],
analysis_inst,
requested_configs=[config_inst],
parameters=params["ml_model_settings"],
)
elif not cls.allow_empty_ml_model:
raise Exception(f"no ml_model configured for {cls.task_family}")
Expand All @@ -1533,13 +1549,14 @@ def __init__(self, *args, **kwargs):
self.ml_model,
self.analysis_inst,
requested_configs=[self.config_inst],
parameters=self.ml_model_settings,
)

def store_parts(self) -> law.util.InsertableDict:
    """
    Extends the store parts of the parent class by an ``"ml_model"`` entry that is inserted
    before the ``"version"`` part if an ML model instance is set. The entry is built from
    :py:attr:`ml_model_repr`.

    :returns: The store parts obtained from the parent class, possibly extended.
    """
    parts = super().store_parts()

    if self.ml_model_inst:
        # a single insertion only, based on the model representation rather than the bare class name
        parts.insert_before("version", "ml_model", f"ml__{self.ml_model_repr}")

    return parts

Expand All @@ -1560,7 +1577,7 @@ def store_parts(self) -> law.util.InsertableDict:
parts = super().store_parts()

# replace the ml_model entry
store_name = self.ml_model_inst.store_name or self.ml_model_inst.cls_name
store_name = self.ml_model_inst.store_name or self.ml_model_repr
parts.insert_before("ml_model", "ml_data", f"ml__{store_name}")
parts.pop("ml_model")

Expand All @@ -1580,6 +1597,12 @@ class MLModelsMixin(ConfigTask):

exclude_params_repr_empty = {"ml_models"}

@property
def ml_models_repr(self):
    """
    Joined string representation of all ML model instances in :py:attr:`ml_model_insts`, with
    the individual representations separated by a double underscore.
    """
    return "__".join(map(str, self.ml_model_insts))

@classmethod
def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
params = super().resolve_param_values(params)
Expand Down Expand Up @@ -1635,8 +1658,7 @@ def store_parts(self) -> law.util.InsertableDict:
parts = super().store_parts()

if self.ml_model_insts:
part = "__".join(model_inst.cls_name for model_inst in self.ml_model_insts)
parts.insert_before("version", "ml_models", f"ml__{part}")
parts.insert_before("version", "ml_models", f"ml__{self.ml_models_repr}")

return parts

Expand Down

0 comments on commit 3253c97

Please sign in to comment.