Merge branch 'main' into operators/automlx_24.2
ahosler committed May 3, 2024
2 parents e8e09b9 + d410494 commit 754bb71
Showing 13 changed files with 377 additions and 9 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/run-forecast-unit-tests.yml
@@ -31,6 +31,10 @@ jobs:

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          ref: ${{ github.event.pull_request.head.sha }}


      - uses: actions/setup-python@v5
        with:
6 changes: 6 additions & 0 deletions THIRD_PARTY_LICENSES.txt
@@ -447,6 +447,12 @@ python-fire
* Source code: https://github.com/google/python-fire
* Project home: https://github.com/google/python-fire

mlforecast
* Copyright 2024 Nixtla
* License: Apache License 2.0
* Source code: https://github.com/Nixtla/mlforecast
* Project home: https://github.com/Nixtla/mlforecast

=======
=============================== Licenses ===============================
------------------------------------------------------------------------
55 changes: 49 additions & 6 deletions ads/jobs/builders/infrastructure/dsc_job.py
Expand Up @@ -9,6 +9,7 @@
import inspect
import logging
import os
import re
import time
import traceback
import uuid
@@ -375,12 +376,13 @@ def delete(self, force_delete: bool = False) -> DSCJob:
        """
        runs = self.run_list()
        for run in runs:
-            if run.lifecycle_state in [
-                DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
-                DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
-                DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
-            ]:
-                run.cancel(wait_for_completion=True)
+            if force_delete:
+                if run.lifecycle_state in [
+                    DataScienceJobRun.LIFECYCLE_STATE_ACCEPTED,
+                    DataScienceJobRun.LIFECYCLE_STATE_IN_PROGRESS,
+                    DataScienceJobRun.LIFECYCLE_STATE_NEEDS_ATTENTION,
+                ]:
+                    run.cancel(wait_for_completion=True)
            run.delete()
        self.client.delete_job(self.id)
        return self
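A minimal usage sketch of the new behavior, assuming the job already exists and that the from_ocid retrieval (standard on ADS OCI mixins, not part of this diff) is available:

from ads.jobs.builders.infrastructure.dsc_job import DSCJob

# Placeholder OCID; assumes the job already exists in the tenancy.
job = DSCJob.from_ocid("ocid1.datasciencejob.oc1..<unique_id>")

# Runs in ACCEPTED, IN_PROGRESS, or NEEDS_ATTENTION state are now cancelled
# (waiting for completion) only when force_delete=True; each run is then
# deleted, and finally the job itself is deleted.
job.delete(force_delete=True)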
@@ -582,6 +584,25 @@ def logging(self) -> OCILog:
            id=self.log_id, log_group_id=self.log_details.log_group_id, **auth
        )

    @property
    def exit_code(self):
        """The exit code of the job run, parsed from the lifecycle details.
        Note that None is returned if the job run has not finished,
        or if it failed without reporting an exit code.
        0 is returned if the job run succeeded.
        """
        if self.lifecycle_state == self.LIFECYCLE_STATE_SUCCEEDED:
            return 0
        if not self.lifecycle_details:
            return None
        match = re.search(r"exit code (\d+)", self.lifecycle_details)
        if not match:
            return None
        try:
            return int(match.group(1))
        except Exception:
            return None

    @staticmethod
    def _format_log(message: str, date_time: datetime.datetime) -> dict:
        """Formats a message as a log record with datetime.
@@ -655,6 +676,22 @@ def _check_and_print_status(self, prev_status) -> str:
            print(f"{timestamp} - {status}")
        return status

    def wait(self, interval: float = SLEEP_INTERVAL):
        """Waits for the job run until it finishes.
        Parameters
        ----------
        interval : float
            Time interval in seconds between each status check.
            Defaults to 3 (seconds).
        """
        self.sync()
        while self.status not in self.TERMINAL_STATES:
            time.sleep(interval)
            self.sync()
        return self

    def watch(
        self,
        interval: float = SLEEP_INTERVAL,
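A usage sketch combining the new wait() with the exit_code property added above; the OCID is a placeholder and the from_ocid retrieval is assumed from the ADS OCI mixins rather than shown in this diff:

from ads.jobs.builders.infrastructure.dsc_job import DataScienceJobRun

run = DataScienceJobRun.from_ocid("ocid1.datasciencejobrun.oc1..<unique_id>")  # placeholder OCID

# Sync every 10 seconds until the run reaches a terminal state, then inspect
# the exit code parsed from lifecycle_details (None if it cannot be determined).
run.wait(interval=10)
print(run.status, run.exit_code)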
@@ -830,6 +867,12 @@ def download(self, to_dir):
        self.job.download(to_dir)
        return self

    def delete(self, force_delete: bool = False):
        if force_delete:
            self.cancel(wait_for_completion=True)
        super().delete()
        return


# This is for backward compatibility
DSCJobRun = DataScienceJobRun
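The run-level override mirrors the job-level change: assuming run is the DataScienceJobRun from the previous sketch, force_delete cancels an in-progress run before deleting it.

run.delete(force_delete=True)  # cancel first (waiting for completion), then delete the run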
1 change: 1 addition & 0 deletions ads/opctl/operator/lowcode/forecast/const.py
@@ -14,6 +14,7 @@ class SupportedModels(str, metaclass=ExtendedEnumMeta):
    Prophet = "prophet"
    Arima = "arima"
    NeuralProphet = "neuralprophet"
    MLForecast = "mlforecast"
    AutoMLX = "automlx"
    AutoTS = "autots"
    Auto = "auto"
1 change: 1 addition & 0 deletions ads/opctl/operator/lowcode/forecast/environment.yaml
@@ -8,6 +8,7 @@ dependencies:
- oracle-ads>=2.9.0
- prophet
- neuralprophet
- mlforecast
- pmdarima
- statsmodels
- report-creator
2 changes: 2 additions & 0 deletions ads/opctl/operator/lowcode/forecast/model/factory.py
@@ -12,6 +12,7 @@
from .base_model import ForecastOperatorBaseModel
from .neuralprophet import NeuralProphetOperatorModel
from .prophet import ProphetOperatorModel
from .ml_forecast import MLForecastOperatorModel
from ..utils import select_auto_model
from .forecast_datasets import ForecastDatasets

@@ -32,6 +33,7 @@ class ForecastOperatorModelFactory:
        SupportedModels.Prophet: ProphetOperatorModel,
        SupportedModels.Arima: ArimaOperatorModel,
        SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
        SupportedModels.MLForecast: MLForecastOperatorModel,
        SupportedModels.AutoMLX: AutoMLXOperatorModel,
        SupportedModels.AutoTS: AutoTSOperatorModel
    }
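For context, a tiny sketch of what the new registry entry means downstream; only names visible in this diff are used, and the printed value assumes SupportedModels stores plain string constants:

from ads.opctl.operator.lowcode.forecast.const import SupportedModels

# A forecast operator spec with model "mlforecast" is now dispatched by the
# factory above to MLForecastOperatorModel.
print(SupportedModels.MLForecast)  # mlforecast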
16 changes: 15 additions & 1 deletion ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py
@@ -86,7 +86,10 @@ def create_horizon(self, spec, historical_data):
            pd.date_range(
                start=historical_data.get_max_time(),
                periods=spec.horizon + 1,
-                freq=historical_data.freq,
+                freq=historical_data.freq
+                or pd.infer_freq(
+                    historical_data.data.reset_index()[spec.datetime_column.name][-5:]
+                ),
            ),
            name=spec.datetime_column.name,
        )
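A standalone pandas illustration of the new fallback, with toy dates rather than the operator's historical_data access: when the stored frequency is missing, it is inferred from the last five timestamps.

import pandas as pd

timestamps = pd.Series(pd.date_range("2024-01-01", periods=10, freq="D"))
stored_freq = None  # e.g. the frequency could not be determined earlier

freq = stored_freq or pd.infer_freq(timestamps[-5:])
horizon_index = pd.date_range(start=timestamps.iloc[-1], periods=4, freq=freq)
print(freq)           # D
print(horizon_index)  # 2024-01-10 through 2024-01-13, daily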
@@ -135,6 +138,7 @@ def __init__(self, config: ForecastOperatorConfig):

        self._horizon = config.spec.horizon
        self._datetime_column_name = config.spec.datetime_column.name
        self._target_col = config.spec.target_column
        self._load_data(config.spec)

    def _load_data(self, spec):
@@ -158,6 +162,16 @@ def get_all_data_long(self, include_horizon=True):
            on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
        ).reset_index()

    def get_all_data_long_forecast_horizon(self):
        """Returns all data in long format for the forecast horizon."""
        test_data = pd.merge(
            self.historical_data.data,
            self.additional_data.data,
            how="outer",
            on=[self._datetime_column_name, ForecastOutputColumns.SERIES],
        ).reset_index()
        return test_data[test_data[self._target_col].isnull()].reset_index(drop=True)

    def get_data_multi_indexed(self):
        return pd.concat(
            [
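A minimal pandas illustration of the pattern used by get_all_data_long_forecast_horizon, with toy column names rather than the operator's real schema: an outer merge keeps the horizon rows contributed by the additional data, and the rows where the target is null are exactly the forecast horizon.

import pandas as pd

historical = pd.DataFrame(
    {"ds": pd.to_datetime(["2024-01-01", "2024-01-02"]), "series": "A", "y": [1.0, 2.0]}
)
additional = pd.DataFrame(
    {"ds": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]), "series": "A", "x": [5, 6, 7]}
)

merged = pd.merge(historical, additional, how="outer", on=["ds", "series"])
horizon_only = merged[merged["y"].isnull()].reset_index(drop=True)
print(horizon_only)  # the 2024-01-03 row, where the target y is missing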
(The remaining 6 changed files were not loaded on this page.)
