Skip to content

Commit

Permalink
plot as a new column
Browse files Browse the repository at this point in the history
  • Loading branch information
americast committed Nov 10, 2023
1 parent d5d90ef commit 4ef2e70
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 6 deletions.
4 changes: 2 additions & 2 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,8 @@ def get_optuna_config(trial):
model_path = os.path.join(model_dir, existing_model_files[-1])
io_list = self._resolve_function_io(None)
data["ds"] = data.ds.astype(str)
last_ds = list(data["ds"])[-2*horizon:]
last_y = list(data["y"])[-2*horizon:]
last_ds = list(data["ds"])[-2 * horizon :]
last_y = list(data["y"])[-2 * horizon :]
metadata_here = [
FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
FunctionMetadataCatalogEntry("model_path", model_path),
Expand Down
62 changes: 58 additions & 4 deletions evadb/functions/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from evadb.functions.abstract.abstract_function import AbstractFunction
Expand All @@ -40,7 +43,7 @@ def setup(
library: str,
conf: int,
last_ds: list,
last_y: list
last_y: list,
):
self.library = library
if "neuralforecast" in self.library:
Expand Down Expand Up @@ -104,10 +107,61 @@ def forward(self, data) -> pd.DataFrame:
log_str += "\nMean normalized RMSE: " + str(self.rmse)
if self.hypers is not None:
log_str += "\nHyperparameters: " + self.hypers
import pudb; pu.db

# Plot figure

## Plot figure

pred_plt = self.last_y + list(
forecast_df[
self.model_name
if self.library == "statsforecast"
else self.model_name + "-median"
]
)
pred_plt_lo = self.last_y + list(
forecast_df[self.model_name + "-lo-" + str(self.conf)]
)
pred_plt_hi = self.last_y + list(
forecast_df[self.model_name + "-hi-" + str(self.conf)]
)

plt.plot(pred_plt, label="Prediction")
plt.fill_between(
x=range(len(pred_plt)), y1=pred_plt_lo, y2=pred_plt_hi, alpha=0.3
)
plt.plot(self.last_y, label="Actual")
plt.xlabel("Time")
plt.ylabel("Value")
xtick_strs = self.last_ds + list(forecast_df["ds"])
num_to_keep_args = list(
range(0, len(xtick_strs), int((len(xtick_strs) - 2) / 8))
) + [len(xtick_strs) - 1]
xtick_strs = [
x if i in num_to_keep_args else "" for i, x in enumerate(xtick_strs)
]
plt.xticks(range(len(pred_plt)), xtick_strs, rotation=85)
plt.legend()
plt.tight_layout()

# convert plt figure to opencv https://copyprogramming.com/howto/convert-matplotlib-figure-to-cv2-image-a-complete-guide-with-examples#converting-matplotlib-figure-to-cv2-image
# convert figure to canvas
canvas = plt.get_current_fig_manager().canvas

# render the canvas
canvas.draw()

# convert canvas to image
img = np.fromstring(canvas.tostring_rgb(), dtype="uint8")
img = img.reshape(canvas.get_width_height()[::-1] + (3,))

# convert image to cv2 format
cv2_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

## Conver to bytes
_, buffer = cv2.imencode(".jpg", cv2_img)
img_bytes = buffer.tobytes()

## Add to dataframe as a plot
forecast_df["plot"] = [img_bytes] + [None] * (len(forecast_df) - 1)

print(log_str)

Expand Down

0 comments on commit 4ef2e70

Please sign in to comment.