Skip to content

Commit

Permalink
Fix column rename and output binded object
Browse files Browse the repository at this point in the history
  • Loading branch information
xzdandy committed Sep 9, 2023
1 parent e904fc5 commit 4680380
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 39 deletions.
3 changes: 1 addition & 2 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,7 @@ def _bind_tuple_expr(self, node: TupleValueExpression):
self._binder_context.enable_audio_retrieval()
if node.name == VideoColumnName.data:
self._binder_context.enable_video_retrieval()
#node.col_alias = "{}.{}".format(table_alias, node.name.lower())
node.col_alias = "{}.{}".format(table_alias, node.name)
node.col_alias = "{}.{}".format(table_alias, node.name.lower())
node.col_object = col_obj

@bind.register(FunctionExpression)
Expand Down
19 changes: 4 additions & 15 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,24 +220,13 @@ def handle_forecasting_function(self):
pickle.dump(model, f)
f.close()

arg_map_here = {"model_name": model_name, "model_path": model_path}
function = self._try_initializing_function(impl_path, arg_map_here)
io_list = self._resolve_function_io(function)
io_list = self._resolve_function_io(None)

metadata_here = [
FunctionMetadataCatalogEntry("model_name", model_name),
FunctionMetadataCatalogEntry("model_path", model_path),
FunctionMetadataCatalogEntry(
key="model_name",
value=model_name,
function_id=None,
function_name=None,
row_id=None,
),
FunctionMetadataCatalogEntry(
key="model_path",
value=model_path,
function_id=None,
function_name=None,
row_id=None,
"output_column_rename", self.node.outputs[0].name
),
]

Expand Down
28 changes: 6 additions & 22 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@

import pandas as pd

from evadb.catalog.catalog_type import NdArrayType
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward, setup
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.functions.decorators.decorators import setup


class ForecastModel(AbstractFunction):
Expand All @@ -30,35 +28,21 @@ def name(self) -> str:
return "ForecastModel"

@setup(cacheable=False, function_type="Forecasting", batchable=True)
def setup(self, model_name: str, model_path: str):
def setup(self, model_name: str, model_path: str, output_column_rename: str):
f = open(model_path, "rb")
loaded_model = pickle.load(f)
f.close()
self.model = loaded_model
self.model_name = model_name
self.output_column_rename = output_column_rename

@forward(
input_signatures=[],
output_signatures=[
PandasDataframe(
columns=["y"],
column_types=[
NdArrayType.FLOAT32,
],
column_shapes=[(None,)],
)
],
)
def forward(self, data) -> pd.DataFrame:
horizon = list(data.iloc[:, -1])[0]
assert (
type(horizon) is int
), "Forecast UDF expects integral horizon in parameter."
forecast_df = self.model.predict(h=horizon)
forecast_df = forecast_df.rename(columns={self.model_name: "y"})
return pd.DataFrame(
forecast_df,
columns=[
"y",
],
forecast_df = forecast_df.rename(
columns={self.model_name: self.output_column_rename}
)
return forecast_df

0 comments on commit 4680380

Please sign in to comment.