diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py
index b6e749f2..eb1e2d60 100644
--- a/src/pytorch_tabular/tabular_model_sweep.py
+++ b/src/pytorch_tabular/tabular_model_sweep.py
@@ -92,6 +92,7 @@ def _validate_args(
     experiment_config: Optional[Union[ExperimentConfig, str]] = None,
     common_model_args: Optional[dict] = {},
     rank_metric: Optional[str] = "loss",
+    custom_fit_params: Optional[dict] = {},
 ):
     assert task in [
         "classification",
@@ -149,6 +150,8 @@ def _validate_args(
         "lower_is_better",
         "higher_is_better",
     ], "rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but" f" got {rank_metric[1]}"
+    if "metrics" in custom_fit_params.keys():
+        assert rank_metric[0] == "loss", "only loss is supported as the rank_metric when using custom metrics"


 def model_sweep(
@@ -172,6 +175,7 @@ def model_sweep(
     progress_bar: bool = True,
     verbose: bool = True,
     suppress_lightning_logger: bool = True,
+    custom_fit_params: Optional[dict] = {},
 ):
     """Compare multiple models on the same dataset.

@@ -231,6 +235,10 @@ def model_sweep(
         suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.

+        custom_fit_params (dict, optional): A dict specifying a custom loss, custom metrics, and a custom optimizer.
+            These custom parameters behave the same way as the ones passed through the `fit` method
+            of `TabularModel`.
+
     Returns:
         results: Training results.

@@ -252,6 +260,7 @@ def model_sweep(
         experiment_config=experiment_config,
         common_model_args=common_model_args,
         rank_metric=rank_metric,
+        custom_fit_params=custom_fit_params,
     )
     if suppress_lightning_logger:
         suppress_lightning_logs()
@@ -326,7 +335,7 @@ def _init_tabular_model(m):
         name = tabular_model.name
         if verbose:
             logger.info(f"Training {name}")
-        model = tabular_model.prepare_model(datamodule)
+        model = tabular_model.prepare_model(datamodule, **custom_fit_params)
         if progress_bar:
             progress.update(task_p, description=f"Training {name}", advance=1)
         with OutOfMemoryHandler(handle_oom=True) as handler:
diff --git a/tests/test_common.py b/tests/test_common.py
index d0c3e26c..b02b0782 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -1143,7 +1143,16 @@ def test_tta_regression(


 def _run_model_compare(
-    task, model_list, data_config, trainer_config, optimizer_config, train, test, metric, rank_metric
+    task,
+    model_list,
+    data_config,
+    trainer_config,
+    optimizer_config,
+    train,
+    test,
+    metric,
+    rank_metric,
+    custom_fit_params={},
 ):
     model_list = copy.deepcopy(model_list)
     if isinstance(model_list, list):
@@ -1161,6 +1170,7 @@ def _run_model_compare(
         metrics_params=metric[1],
         metrics_prob_input=metric[2],
         rank_metric=rank_metric,
+        custom_fit_params=custom_fit_params,
     )


@@ -1249,6 +1259,66 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
     # assert best_model.model._get_name() in best_models


+@pytest.mark.parametrize("model_list", ["lite", MODEL_CONFIG_MODEL_SWEEP_TEST])
+@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
+@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
+@pytest.mark.parametrize(
+    "metric",
+    [
+        (["mean_squared_error"], [{}], [False]),
+    ],
+)
+@pytest.mark.parametrize("rank_metric", [("loss", "lower_is_better")])
+@pytest.mark.parametrize(
+    "custom_fit_params",
+    [
+        {
+            "loss": torch.nn.L1Loss(),
+            "metrics": [fake_metric],
+            "metrics_prob_inputs": [True],
+            "optimizer": torch.optim.Adagrad,
+        },
+    ],
+)
+def test_model_compare_custom(
+    regression_data, model_list, continuous_cols, categorical_cols, metric, rank_metric, custom_fit_params
+):
+    (train, test, target) = regression_data
+    data_config = DataConfig(
+        target=target,
+        continuous_cols=continuous_cols,
+        categorical_cols=categorical_cols,
+        handle_missing_values=True,
+        handle_unknown_categories=True,
+    )
+    trainer_config = TrainerConfig(
+        max_epochs=3,
+        checkpoints=None,
+        early_stopping=None,
+        accelerator="cpu",
+        fast_dev_run=True,
+    )
+    optimizer_config = OptimizerConfig()
+    comp_df, best_model = _run_model_compare(
+        "regression",
+        model_list,
+        data_config,
+        trainer_config,
+        optimizer_config,
+        train,
+        test,
+        metric,
+        rank_metric,
+        custom_fit_params=custom_fit_params,
+    )
+    if model_list == "lite":
+        assert len(comp_df) == 3
+    else:
+        assert len(comp_df) == len(model_list)
+    if custom_fit_params.get("metrics", None) == [fake_metric]:
+        assert "test_fake_metric" in comp_df.columns
+
+
 @pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
 @pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
 @pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
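
For reviewers, a minimal usage sketch of the new `custom_fit_params` argument (not part of the diff). The `train_df`/`test_df` DataFrames, the column names, and the `my_mae` metric are hypothetical placeholders, and the metric is assumed to follow the same callable convention as metrics passed to `TabularModel.fit`:

    # Illustrative sketch only; train_df/test_df and column names are placeholders.
    import torch
    from pytorch_tabular import model_sweep
    from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

    def my_mae(y_hat, y):
        # Custom metric callable: predictions first, targets second (assumed convention).
        return torch.nn.functional.l1_loss(y_hat, y)

    sweep_df, best_model = model_sweep(
        task="regression",
        train=train_df,  # hypothetical training DataFrame
        test=test_df,    # hypothetical hold-out DataFrame
        data_config=DataConfig(target=["target"], continuous_cols=["feat_a", "feat_b"]),
        optimizer_config=OptimizerConfig(),
        trainer_config=TrainerConfig(max_epochs=3, checkpoints=None, early_stopping=None),
        model_list="lite",
        # New in this diff: forwarded as **kwargs to prepare_model for every swept model.
        custom_fit_params={
            "loss": torch.nn.L1Loss(),
            "metrics": [my_mae],
            "metrics_prob_inputs": [False],
            "optimizer": torch.optim.Adagrad,
        },
        # When custom metrics are supplied, only "loss" is accepted as rank_metric[0].
        rank_metric=("loss", "lower_is_better"),
    )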