Skip to content
This repository has been archived by the owner on Oct 7, 2023. It is now read-only.

Commit

Permalink
fix(Arch): finished integration
Browse files Browse the repository at this point in the history
  • Loading branch information
almostintuitive committed May 1, 2023
1 parent 2a05c2b commit 1cb55e1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 32 deletions.
24 changes: 7 additions & 17 deletions src/fold_wrappers/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Optional, Union

import pandas as pd
from fold.base import fit_noop
from fold.models.base import Model
from fold.utils.checks import is_X_available

Expand Down Expand Up @@ -41,33 +42,22 @@ def fit(
)
if use_exogenous:
self.model = arch_model(y, x=X, **self.init_args)
self.model = self.model.fit()
else:
self.model = arch_model(y, **self.init_args)
self.model = self.model.fit()

def update(
self, X: pd.DataFrame, y: pd.Series, sample_weights: Optional[pd.Series] = None
) -> None:
if not hasattr(self.model, "append"):
return
use_exogenous = (
is_X_available(X) if self.use_exogenous is None else self.use_exogenous
)
if use_exogenous:
self.model = self.model.fit(starting_values=self.model.params, last_obs=y)
else:
self.model = self.model.append(endog=y, refit=True)
self.model = self.model.fit(disp="off")

def predict(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]:
use_exogenous = (
is_X_available(X) if self.use_exogenous is None else self.use_exogenous
)
if use_exogenous:
return pd.Series(self.model.forecast(horizon=len(X), x=X))
res = self.model.forecast(horizon=len(X), reindex=False, x=X)
else:
return pd.Series(self.model.forecast(horizon=len(X)))
res = self.model.forecast(horizon=len(X), reindex=False)
return pd.Series(res.variance.values[0], index=X.index)

def predict_in_sample(self, X: pd.DataFrame) -> Union[pd.Series, pd.DataFrame]:
res = self.model.forecast(horizon=len(X), start=0, reindex=True)
return res.variance[res.variance.columns[0]]

update = fit_noop
18 changes: 3 additions & 15 deletions tests/test_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,9 @@


def test_arch_univariate() -> None:
# run_pipeline_and_check_if_results_close_univariate(
# model=WrapArch(init_args=dict(vol="Garch", p=1, o=0, q=1, dist="Normal")),
# splitter=ExpandingWindowSplitter(initial_train_window=50, step=1),
# )

_, y = generate_sine_wave_data(length=200)
y = np.log(y + 2.0).diff().dropna() * 100
model = WrapArch(init_args=dict(vol="Garch", p=1, o=0, q=1, dist="Normal"))
splitter = ExpandingWindowSplitter(initial_train_window=0.5, step=0.05)
model = WrapArch(init_args=dict(vol="Garch", p=1, o=1, q=1, dist="Normal"))
splitter = ExpandingWindowSplitter(initial_train_window=0.5, step=1)
pred, _ = train_backtest(model, None, y, splitter)
assert np.isclose(y.squeeze()[pred.index], pred.squeeze().values, atol=0.1).all()


# def test_arch_univariate_online() -> None:
# run_pipeline_and_check_if_results_close_univariate(
# model=WrapArch(init_args={"order": (1, 1, 0)}, online_mode=True),
# splitter=ExpandingWindowSplitter(initial_train_window=50, step=10),
# )
# assert np.isclose(y[pred.index] ** 2, pred.squeeze().values, atol=0.1).all()

0 comments on commit 1cb55e1

Please sign in to comment.