Skip to content

Commit

Permalink
Setup dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
xzdandy committed Oct 11, 2023
1 parent 9887342 commit 6825ac1
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 8 deletions.
6 changes: 3 additions & 3 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@
from evadb.utils.generic_utils import (
load_function_class_from_file,
string_comparison_case_insensitive,
try_to_import_automl,
try_to_import_ludwig,
try_to_import_neuralforecast,
try_to_import_sklearn,
try_to_import_statsforecast,
try_to_import_torch,
try_to_import_ultralytics,
try_to_import_xgboost,
)
from evadb.utils.logging_manager import logger

Expand Down Expand Up @@ -169,7 +169,7 @@ def handle_xgboost_function(self):
We use the Flaml AutoML model for training xgboost models.
"""
try_to_import_automl()
try_to_import_xgboost()

assert (
len(self.children) == 1
Expand Down Expand Up @@ -445,7 +445,7 @@ def handle_forecasting_function(self):
if int(x.split("horizon")[1].split(".pkl")[0]) >= horizon
]
if len(existing_model_files) == 0:
print("Training, please wait...")
logger.info("Training, please wait...")
if library == "neuralforecast":
model.fit(df=data, val_size=horizon)
else:
Expand Down
4 changes: 2 additions & 2 deletions evadb/functions/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pandas as pd

from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.utils.generic_utils import try_to_import_automl
from evadb.utils.generic_utils import try_to_import_xgboost


class GenericXGBoostModel(AbstractFunction):
Expand All @@ -26,7 +26,7 @@ def name(self) -> str:
return "GenericXGBoostModel"

def setup(self, model_path: str, **kwargs):
try_to_import_automl()
try_to_import_xgboost()

self.model = pickle.load(open(model_path, "rb"))

Expand Down
10 changes: 9 additions & 1 deletion evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def is_sklearn_available() -> bool:
return False


def try_to_import_automl():
def try_to_import_xgboost():
try:
import flaml # noqa: F401
from flaml import AutoML # noqa: F401
Expand All @@ -388,6 +388,14 @@ def try_to_import_automl():
)


def is_xgboost_available() -> bool:
try:
try_to_import_xgboost()
return True
except ValueError: # noqa: E722
return False


##############################
## VISION
##############################
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def read(path, encoding="utf-8"):

sklearn_libs = ["scikit-learn"]

xgboost_libs = ["flaml[automl]"]

forecasting_libs = [
"statsforecast", # MODEL TRAIN AND FINE TUNING
"neuralforecast" # MODEL TRAIN AND FINE TUNING
Expand Down Expand Up @@ -165,9 +167,10 @@ def read(path, encoding="utf-8"):
"postgres": postgres_libs,
"ludwig": ludwig_libs,
"sklearn": sklearn_libs,
"xgboost": xgboost_libs,
"forecasting": forecasting_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs,
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + xgboost_libs,
}

setup(
Expand Down
3 changes: 2 additions & 1 deletion test/integration_tests/long/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test.markers import ludwig_skip_marker, sklearn_skip_marker
from test.markers import ludwig_skip_marker, sklearn_skip_marker, xgboost_skip_marker
from test.util import get_evadb_for_testing, shutdown_ray

import pytest
Expand Down Expand Up @@ -95,6 +95,7 @@ def test_sklearn_regression(self):
self.assertEqual(len(result.columns), 1)
self.assertEqual(len(result), 10)

@xgboost_skip_marker
def test_xgboost_regression(self):
create_predict_function = """
CREATE FUNCTION IF NOT EXISTS PredictRent FROM
Expand Down
5 changes: 5 additions & 0 deletions test/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
is_pinecone_available,
is_qdrant_available,
is_sklearn_available,
is_xgboost_available,
)

asyncio_skip_marker = pytest.mark.skipif(
Expand Down Expand Up @@ -88,6 +89,10 @@
is_sklearn_available() is False, reason="Run only if sklearn is available"
)

xgboost_skip_marker = pytest.mark.skipif(
is_xgboost_available() is False, reason="Run only if xgboost is available"
)

chatgpt_skip_marker = pytest.mark.skip(
reason="requires chatgpt",
)
Expand Down

0 comments on commit 6825ac1

Please sign in to comment.