Skip to content

Commit

Permalink
Fix ARIMA.
Browse files Browse the repository at this point in the history
  • Loading branch information
csadorf committed Feb 6, 2023
1 parent a682a5a commit 19a7d39
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/cuml/tsa/arima.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ class ARIMA(Base):
raise NotImplementedError("ARIMA is unable to be cloned via "
"`get_params` and `set_params`.")

@cuml.internals.api_base_return_generic(input_arg=None)
@cuml.internals.api_base_return_array()
def predict(
self,
start=0,
Expand Down Expand Up @@ -745,7 +745,7 @@ class ARIMA(Base):
d_upper)

@nvtx.annotate(message="tsa.arima.ARIMA.forecast", domain="cuml_python")
@cuml.internals.api_base_return_generic_skipall
@cuml.internals.api_base_return_array()
def forecast(
self,
nsteps: int,
Expand Down Expand Up @@ -938,7 +938,7 @@ class ARIMA(Base):
return self
@nvtx.annotate(message="tsa.arima.ARIMA._loglike", domain="cuml_python")
@cuml.internals.api_base_return_any_skipall
@cuml.internals.api_base_return_array()
def _loglike(self, x, trans=True, method="ml", truncate=0):
"""Compute the batched log-likelihood for the given parameters.

Expand Down Expand Up @@ -1002,7 +1002,7 @@ class ARIMA(Base):
@nvtx.annotate(message="tsa.arima.ARIMA._loglike_grad",
domain="cuml_python")
@cuml.internals.api_base_return_any_skipall
@cuml.internals.api_base_return_array()
def _loglike_grad(self, x, h=1e-8, trans=True, method="ml", truncate=0):
"""Compute the gradient (via finite differencing) of the batched
log-likelihood.
Expand Down

0 comments on commit 19a7d39

Please sign in to comment.