Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup of dumped model from training framework on drop function #1442

Open
wants to merge 3 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions evadb/catalog/catalog_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,20 @@ def get_function_catalog_entry_by_name(self, name: str) -> FunctionCatalogEntry:
"""
return self._function_service.get_entry_by_name(name)

def get_function_catalog_entries_by_type(
self, type: str
) -> List[FunctionCatalogEntry]:
"""
Get function information based on type.

Arguments:
type (str): type of the function

Returns:
List of FunctionCatalogEntry object
"""
return self._function_service.get_entries_by_type(type)

def delete_function_catalog_entry_by_name(self, function_name: str) -> bool:
return self._function_service.delete_entry_by_name(function_name)

Expand Down
18 changes: 18 additions & 0 deletions evadb/catalog/services/function_catalog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ def get_entry_by_name(self, name: str) -> FunctionCatalogEntry:
return function_obj.as_dataclass()
return None

def get_entries_by_type(self, function_type: str) -> List[FunctionCatalogEntry]:
"""returns the function entries that matches the type provided.
Empty list if no such entry found.

Arguments:
type (str): name to be searched
"""

entries = (
self.session.execute(
select(self.model).filter(self.model._type == function_type)
)
.scalars()
.all()
)

return [entry.as_dataclass() for entry in entries]

def get_entry_by_id(self, id: int, return_alchemy=False) -> FunctionCatalogEntry:
"""return the function entry that matches the id provided.
None if no such entry found.
Expand Down
1 change: 1 addition & 0 deletions evadb/configuration/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@
DEFAULT_XGBOOST_TASK = "regression"
DEFAULT_SKLEARN_TRAIN_MODEL = "rf"
SKLEARN_SUPPORTED_MODELS = ["rf", "extra_tree", "kneighbor"]
TRAINING_FRAMEWORKS = ["Sklearn", "Ludwig", "XGBoost", "Forecasting"]
65 changes: 54 additions & 11 deletions evadb/executor/drop_object_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil

import pandas as pd

from evadb.configuration.constants import TRAINING_FRAMEWORKS
from evadb.database import EvaDBDatabase
from evadb.executor.abstract_executor import AbstractExecutor
from evadb.executor.executor_utils import ExecutorError, handle_vector_store_params
Expand All @@ -24,6 +27,7 @@
from evadb.plan_nodes.drop_object_plan import DropObjectPlan
from evadb.storage.storage_engine import StorageEngine
from evadb.third_party.vector_stores.utils import VectorStoreFactory
from evadb.utils.generic_utils import string_comparison_case_insensitive
from evadb.utils.logging_manager import logger


Expand Down Expand Up @@ -94,19 +98,58 @@ def _handle_drop_function(self, function_name: str, if_exists: bool):
function_entry = self.catalog().get_function_catalog_entry_by_name(
function_name
)
for cache in function_entry.dep_caches:
self.catalog().drop_function_cache_catalog_entry(cache)

# todo also delete the indexes associated with the table

self.catalog().delete_function_catalog_entry_by_name(function_name)

return Batch(
pd.DataFrame(
{f"Function {function_name} successfully dropped"},
index=[0],
# training framework model cleanup on drop function
err_msg = (
f"Error removing {function_entry.type} model for function {function_name}."
)
try:
if function_entry.type.lower() in [x.lower() for x in TRAINING_FRAMEWORKS]:
filtered_metadata = list(
filter(lambda x: x.key == "model_path", function_entry.metadata)
)
if len(filtered_metadata) > 0:
model_path = os.path.abspath(filtered_metadata[0].value)
"""For 'Forecasting' the entire function catalog of forecasting functions
is checked to see if the model path is shared"""
if string_comparison_case_insensitive(
function_entry.type, "Forecasting"
):
forecasting_function_entries = (
self.catalog().get_function_catalog_entries_by_type(
function_entry.type
)
)
functions_using_same_model = sum(
1
for entry in forecasting_function_entries
if any(
x.key == "model_path"
and os.path.abspath(x.value) == model_path
for x in entry.metadata
)
)
if functions_using_same_model == 1:
dir_path = os.path.abspath(os.path.dirname(model_path))
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
else:
if os.path.exists(model_path):
os.remove(model_path)

except Exception as e:
raise RuntimeError(f"{err_msg}\n{e}")

for cache in function_entry.dep_caches:
self.catalog().drop_function_cache_catalog_entry(cache)

# todo also delete the indexes associated with the table
self.catalog().delete_function_catalog_entry_by_name(function_name)
return Batch(
pd.DataFrame(
{f"Function {function_name} successfully dropped"},
index=[0],
)
)

def _handle_drop_index(self, index_name: str, if_exists: bool):
index_obj = self.catalog().get_index_catalog_entry_by_name(index_name)
Expand Down