Skip to content

Commit

Permalink
pass ml_model_settings directly to the self.parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
mafrahm committed May 21, 2024
1 parent 51f8356 commit 4a94a78
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions columnflow/tasks/framework/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,11 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
analysis_inst = params["analysis_inst"]

# NOTE: we could try to implement resolving the default ml_model here
ml_model_inst = cls.get_ml_model_inst(params["ml_model"], analysis_inst, **params["ml_model_settings"])
ml_model_inst = cls.get_ml_model_inst(
params["ml_model"],
analysis_inst,
parameters=params["ml_model_settings"],
)
params["ml_model_inst"] = ml_model_inst

# resolve configs
Expand Down Expand Up @@ -1437,7 +1441,7 @@ def __init__(self, *args, **kwargs):
self.ml_model,
self.analysis_inst,
configs=list(self.configs),
**self.ml_model_settings,
parameters=self.ml_model_settings,
)

def store_parts(self) -> law.util.InsertableDict[str, str]:
Expand Down Expand Up @@ -1531,6 +1535,7 @@ def resolve_param_values(cls, params: dict[str, Any]) -> dict[str, Any]:
params["ml_model"],
analysis_inst,
requested_configs=[config_inst],
parameters=params["ml_model_settings"],
)
elif not cls.allow_empty_ml_model:
raise Exception(f"no ml_model configured for {cls.task_family}")
Expand All @@ -1547,6 +1552,7 @@ def __init__(self, *args, **kwargs):
self.ml_model,
self.analysis_inst,
requested_configs=[self.config_inst],
parameters=self.ml_model_settings,
)

def store_parts(self) -> law.util.InsertableDict:
Expand Down

0 comments on commit 4a94a78

Please sign in to comment.