Skip to content

Commit

Permalink
Adding changes for Flaml Sklearn integration
Browse files: browse the repository at this point in the history.
  • Loading branch information
Jineet Desai committed Nov 16, 2023
1 parent 1fbb74f commit a36685e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
2 changes: 2 additions & 0 deletions evadb/configuration/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,5 @@
DEFAULT_DOCUMENT_CHUNK_OVERLAP = 200
DEFAULT_TRAIN_REGRESSION_METRIC = "rmse"
DEFAULT_XGBOOST_TASK = "regression"
DEFAULT_SKLEARN_TRAIN_MODEL = "rf"
SKLEARN_SUPPORTED_MODELS = ["rf", "extra_tree", "kneighbor"]
26 changes: 20 additions & 6 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
from evadb.configuration.constants import (
DEFAULT_SKLEARN_TRAIN_MODEL,
DEFAULT_TRAIN_REGRESSION_METRIC,
DEFAULT_TRAIN_TIME_LIMIT,
DEFAULT_XGBOOST_TASK,
SKLEARN_SUPPORTED_MODELS,
EvaDB_INSTALLATION_DIR,
)
from evadb.database import EvaDBDatabase
Expand Down Expand Up @@ -165,7 +167,6 @@ def handle_sklearn_function(self):
Use Sklearn's regression to train models.
"""
try_to_import_sklearn()
from sklearn.linear_model import LinearRegression

assert (
len(self.children) == 1
Expand All @@ -181,13 +182,26 @@ def handle_sklearn_function(self):
aggregated_batch.drop_column_alias()

arg_map = {arg.key: arg.value for arg in self.node.metadata}
model = LinearRegression()
Y = aggregated_batch.frames[arg_map["predict"]]
aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True)
from flaml import AutoML

model = AutoML()
sklearn_model = arg_map.get("model", DEFAULT_SKLEARN_TRAIN_MODEL)
if sklearn_model not in SKLEARN_SUPPORTED_MODELS:
raise ValueError(
f"Sklearn Model {sklearn_model} provided as input is not supported."
)
settings = {
"time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
"metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
"estimator_list": [sklearn_model],
"task": arg_map.get("task", DEFAULT_XGBOOST_TASK),
}
start_time = int(time.time())
model.fit(X=aggregated_batch.frames, y=Y)
model.fit(
dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
)
train_time = int(time.time()) - start_time
score = model.score(X=aggregated_batch.frames, y=Y)
score = model.best_loss
model_path = os.path.join(
self.db.catalog().get_configuration_catalog_value("model_dir"),
self.node.name,
Expand Down
3 changes: 2 additions & 1 deletion test/integration_tests/long/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def test_sklearn_regression(self):
CREATE OR REPLACE FUNCTION PredictHouseRentSklearn FROM
( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )
TYPE Sklearn
PREDICT 'rental_price';
PREDICT 'rental_price'
MODEL 'extra_tree';
"""
execute_query_fetch_all(self.evadb, create_predict_function)

Expand Down

0 comments on commit a36685e

Please sign in to comment.