11 changes: 10 additions & 1 deletion src/pytorch_tabular/tabular_model_sweep.py
@@ -92,6 +92,7 @@ def _validate_args(
experiment_config: Optional[Union[ExperimentConfig, str]] = None,
common_model_args: Optional[dict] = {},
rank_metric: Optional[str] = "loss",
custom_fit_params: Optional[dict] = {},
):
assert task in [
"classification",
@@ -149,6 +150,8 @@ def _validate_args(
"lower_is_better",
"higher_is_better",
], "rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but" f" got {rank_metric[1]}"
if "metrics" in custom_fit_params.keys():
assert rank_metric[0] == "loss", "only loss is supported as the rank_metric when using custom metrics"


def model_sweep(
@@ -172,6 +175,7 @@ def model_sweep(
progress_bar: bool = True,
verbose: bool = True,
suppress_lightning_logger: bool = True,
custom_fit_params: Optional[dict] = {},
):
"""Compare multiple models on the same dataset.

@@ -231,6 +235,10 @@ def model_sweep(

suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.

custom_fit_params (dict, optional): A dict specifying a custom loss, metrics and optimizer.
These parameters behave similarly to the ones passed through the `fit` method
of `TabularModel`. Defaults to {}.

Returns:
results: Training results.

@@ -252,6 +260,7 @@ def model_sweep(
experiment_config=experiment_config,
common_model_args=common_model_args,
rank_metric=rank_metric,
custom_fit_params=custom_fit_params,
)
if suppress_lightning_logger:
suppress_lightning_logs()
@@ -326,7 +335,7 @@ def _init_tabular_model(m):
name = tabular_model.name
if verbose:
logger.info(f"Training {name}")
-model = tabular_model.prepare_model(datamodule)
+model = tabular_model.prepare_model(datamodule, **custom_fit_params)
if progress_bar:
progress.update(task_p, description=f"Training {name}", advance=1)
with OutOfMemoryHandler(handle_oom=True) as handler:
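For context, here is a minimal sketch of how the new `custom_fit_params` argument is meant to be used, based on the docstring above and the new test below. The synthetic dataframe, the `mae_metric` helper, and the config values are illustrative placeholders, not part of this PR.

```python
import numpy as np
import pandas as pd
import torch

from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.tabular_model_sweep import model_sweep

# Tiny synthetic regression dataset, purely for illustration.
rng = np.random.default_rng(42)
df = pd.DataFrame({"feat_1": rng.normal(size=256), "feat_2": rng.normal(size=256)})
df["target"] = 2 * df["feat_1"] - df["feat_2"] + rng.normal(scale=0.1, size=256)
train_df, test_df = df.iloc[:200], df.iloc[200:]


def mae_metric(y_hat, y):
    # Custom metrics are plain callables over (predictions, targets) tensors.
    return torch.mean(torch.abs(y_hat - y))


data_config = DataConfig(target=["target"], continuous_cols=["feat_1", "feat_2"])
trainer_config = TrainerConfig(max_epochs=3, checkpoints=None, early_stopping=None, accelerator="cpu")
optimizer_config = OptimizerConfig()

sweep_df, best_model = model_sweep(
    task="regression",
    train=train_df,
    test=test_df,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list="lite",
    # Forwarded to prepare_model() for every model in the sweep; the keys mirror
    # the custom loss/metrics/optimizer arguments of TabularModel.fit.
    custom_fit_params={
        "loss": torch.nn.L1Loss(),
        "metrics": [mae_metric],
        "metrics_prob_inputs": [False],
        "optimizer": torch.optim.Adagrad,
    },
    # Per the new check in _validate_args, custom metrics require ranking by the loss.
    rank_metric=("loss", "lower_is_better"),
)
print(sweep_df.head())
```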
72 changes: 71 additions & 1 deletion tests/test_common.py
@@ -1143,7 +1143,16 @@ def test_tta_regression(


def _run_model_compare(
-task, model_list, data_config, trainer_config, optimizer_config, train, test, metric, rank_metric
+task,
+model_list,
+data_config,
+trainer_config,
+optimizer_config,
+train,
+test,
+metric,
+rank_metric,
+custom_fit_params={},
):
model_list = copy.deepcopy(model_list)
if isinstance(model_list, list):
@@ -1161,6 +1170,7 @@ def _run_model_compare(
metrics_params=metric[1],
metrics_prob_input=metric[2],
rank_metric=rank_metric,
custom_fit_params=custom_fit_params,
)


@@ -1249,6 +1259,66 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
# assert best_model.model._get_name() in best_models


@pytest.mark.parametrize("model_list", ["lite", MODEL_CONFIG_MODEL_SWEEP_TEST])
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
@pytest.mark.parametrize(
"metric",
[
(["mean_squared_error"], [{}], [False]),
],
)
@pytest.mark.parametrize("rank_metric", [("loss", "lower_is_better")])
@pytest.mark.parametrize(
"custom_fit_params",
[
{
"loss": torch.nn.L1Loss(),
"metrics": [fake_metric],
"metrics_prob_inputs": [True],
"optimizer": torch.optim.Adagrad,
},
],
)
def test_model_compare_custom(
regression_data, model_list, continuous_cols, categorical_cols, metric, rank_metric, custom_fit_params
):
(train, test, target) = regression_data
data_config = DataConfig(
target=target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
handle_missing_values=True,
handle_unknown_categories=True,
)
trainer_config = TrainerConfig(
max_epochs=3,
checkpoints=None,
early_stopping=None,
accelerator="cpu",
fast_dev_run=True,
)
optimizer_config = OptimizerConfig()
comp_df, best_model = _run_model_compare(
"regression",
model_list,
data_config,
trainer_config,
optimizer_config,
train,
test,
metric,
rank_metric,
custom_fit_params=custom_fit_params,
)
if model_list == "lite":
assert len(comp_df) == 3
else:
assert len(comp_df) == len(model_list)
if fake_metric in custom_fit_params.get("metrics", []):
assert "test_fake_metric" in comp_df.columns


@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
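The new guard in `_validate_args` rejects custom `metrics` combined with any `rank_metric` other than the loss. A sketch of how that surfaces on the caller's side, reusing the placeholder data, configs, and `mae_metric` from the sketch after the first file above:

```python
# Sketch: supplying custom metrics while ranking by anything other than "loss"
# is rejected by _validate_args before any model is trained.
try:
    model_sweep(
        task="regression",
        train=train_df,
        test=test_df,
        data_config=data_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        model_list="lite",
        custom_fit_params={"metrics": [mae_metric], "metrics_prob_inputs": [False]},
        rank_metric=("mean_squared_error", "lower_is_better"),
    )
except AssertionError as err:
    print(err)  # "only loss is supported as the rank_metric when using custom metrics"
```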