Skip to content
This repository has been archived by the owner on Oct 7, 2023. It is now read-only.

Commit

Permalink
fix(Tunable): updated interface
Browse files Browse the repository at this point in the history
  • Loading branch information
almostintuitive committed May 11, 2023
1 parent 29c60f5 commit 14fcea9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/fold_wrappers/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Optional, Type, Union
from typing import Any, Callable, Optional, Type, Union

import pandas as pd
from fold.base import Tunable
Expand Down Expand Up @@ -81,7 +81,9 @@ def predict(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]:
def get_params(self) -> dict:
return self.model.get_params()

def clone_with_params(self, **parameters) -> Tunable:
def clone_with_params(
self, parameters: dict, clone_children: Optional[Callable] = None
) -> Tunable:
return WrapLGBM(
self.model_class,
init_args=parameters,
Expand Down
6 changes: 4 additions & 2 deletions src/fold_wrappers/xgboost.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Optional, Type, Union
from typing import Any, Callable, Optional, Type, Union

import pandas as pd
from fold.base import Tunable
Expand Down Expand Up @@ -85,7 +85,9 @@ def predict(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]:
def get_params(self) -> dict:
return self.model.get_params()

def clone_with_params(self, **parameters) -> Tunable:
def clone_with_params(
self, parameters: dict, clone_children: Optional[Callable] = None
) -> Tunable:
return WrapXGB(
model_class=self.model_class,
init_args=parameters,
Expand Down

0 comments on commit 14fcea9

Please sign in to comment.