Skip to content

Commit

Permalink
add ml_model_settings parameter to pass kwargs to the ml model init f…
Browse files Browse the repository at this point in the history
…unction
  • Loading branch information
mafrahm committed May 21, 2024
1 parent da009d3 commit a2aa1e2
Showing 1 changed file with 28 additions and 5 deletions.
33 changes: 28 additions & 5 deletions columnflow/tasks/framework/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

from columnflow.types import Sequence, Any, Iterable, Union
from columnflow.tasks.framework.base import AnalysisTask, ConfigTask, RESOLVE_DEFAULT
from columnflow.tasks.framework.parameters import SettingsParameter
from columnflow.calibration import Calibrator
from columnflow.selection import Selector
from columnflow.production import Producer
from columnflow.weight import WeightProducer
from columnflow.ml import MLModel
from columnflow.inference import InferenceModel
from columnflow.columnar_util import Route, ColumnCollection
from columnflow.util import maybe_import
from columnflow.util import maybe_import, DotDict

ak = maybe_import("awkward")

Expand Down Expand Up @@ -1060,8 +1061,30 @@ class MLModelMixinBase(AnalysisTask):
description="the name of the ML model to be applied",
)

ml_model_settings = SettingsParameter(
default=DotDict(),
description="settings passed to the init function of the ML model",
)

exclude_params_repr_empty = {"ml_model"}

@property
def ml_model_repr(self):
"""Returns a string representation of the ML model and it's settings."""
if hasattr(self, "_ml_model_repr"):
# if existing, return the cached value
return self._ml_model_repr

repr = self.ml_model_inst.cls_name

if hasattr(self.ml_model_inst, "parameters_repr"):
# if existing, return the parameters_repr of the ml_model_inst
repr += f"__{self.ml_model_inst.parameters_repr}"

# cache the value and return it
self._ml_model_repr = repr
return self._ml_model_repr

@classmethod
def req_params(cls, inst: law.Task, **kwargs) -> dict[str, Any]:
"""Get the required parameters for the task, preferring the ``--ml-model`` set on task-level via CLI.
Expand Down Expand Up @@ -1409,12 +1432,12 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# get the ML model instance
self.ml_model_inst = self.get_ml_model_inst(
self.ml_model,
self.analysis_inst,
configs=list(self.configs),
**self.ml_model_settings,
)

def store_parts(self) -> law.util.InsertableDict[str, str]:
Expand Down Expand Up @@ -1467,7 +1490,7 @@ def store_parts(self) -> law.util.InsertableDict[str, str]:
parts.insert_before("version", label, f"{label}__{part}")

if self.ml_model_inst:
parts.insert_before("version", "ml_model", f"ml__{self.ml_model_inst.cls_name}")
parts.insert_before("version", "ml_model", f"ml__{self.ml_model_repr}")

return parts

Expand Down Expand Up @@ -1530,7 +1553,7 @@ def store_parts(self) -> law.util.InsertableDict:
parts = super().store_parts()

if self.ml_model_inst:
parts.insert_before("version", "ml_model", f"ml__{self.ml_model_inst.cls_name}")
parts.insert_before("version", "ml_model", f"ml__{self.ml_model_repr}")

return parts

Expand All @@ -1551,7 +1574,7 @@ def store_parts(self) -> law.util.InsertableDict:
parts = super().store_parts()

# replace the ml_model entry
store_name = self.ml_model_inst.store_name or self.ml_model_inst.cls_name
store_name = self.ml_model_inst.store_name or self.ml_model_repr
parts.insert_before("ml_model", "ml_data", f"ml__{store_name}")
parts.pop("ml_model")

Expand Down

0 comments on commit a2aa1e2

Please sign in to comment.