Skip to content

Commit

Permalink
Adding parameter for regression metric and time limit.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jineet Desai committed Oct 18, 2023
1 parent 8e1bd05 commit 37c94f8
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
1 change: 1 addition & 0 deletions evadb/configuration/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@
DEFAULT_TRAIN_TIME_LIMIT = 120
DEFAULT_DOCUMENT_CHUNK_SIZE = 4000
DEFAULT_DOCUMENT_CHUNK_OVERLAP = 200
DEFAULT_TRAIN_REGRESSION_METRIC = "rmse"
3 changes: 2 additions & 1 deletion evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.configuration.constants import (
DEFAULT_TRAIN_REGRESSION_METRIC,
DEFAULT_TRAIN_TIME_LIMIT,
EvaDB_INSTALLATION_DIR,
)
Expand Down Expand Up @@ -190,7 +191,7 @@ def handle_xgboost_function(self):
model = AutoML()
settings = {
"time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
"metric": "r2",
"metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
"estimator_list": ["xgboost"],
"task": "regression",
}
Expand Down
3 changes: 2 additions & 1 deletion test/integration_tests/long/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def test_xgboost_regression(self):
( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )
TYPE XGBoost
PREDICT 'rental_price'
TIME_LIMIT 180;
TIME_LIMIT 180
METRIC 'r2';
"""
execute_query_fetch_all(self.evadb, create_predict_function)

Expand Down

0 comments on commit 37c94f8

Please sign in to comment.