# forecast.py (forked from xqnwang/darima)
#! /usr/local/bin/python3.7
# FIXME: write a native `forecast.ar` R function.
import os, zipfile, pathlib
import numpy as np
import pandas as pd
import functools
import warnings
import rpy2.robjects as robjects
from rpy2.robjects import numpy2ri
import rpy2
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType
##--------------------------------------------------------------------------------------
# R version
##--------------------------------------------------------------------------------------
# The R implementation is read out of the zip archive that contains this package:
# `__file__` is expected to look like `<archive>.zip/darima/forecast.py`, so
# `parents[1]` is the archive itself (presumably the zip shipped to Spark workers).
# NOTE(review): importing from an unzipped checkout fails here, because
# `parents[1]` is then a directory rather than a zip file — confirm this is intended.
# Both the archive and the member handle are closed deterministically via `with`
# instead of relying on garbage collection to finalize them.
with zipfile.ZipFile(pathlib.Path(__file__).parents[1]) as _darima_zip, \
        _darima_zip.open("darima/R/forecast_darima.R") as _rcode_file:
    forecast_darima_rcode = _rcode_file.read().decode("utf-8")
# Parse and evaluate the R source so `forecast.darima` can be looked up via robjects.r.
robjects.r.source(exprs=rpy2.rinterface.parse(forecast_darima_rcode), verbose=False)
# Python-callable handle to the R function used by `darima_forec`.
forecast_darima = robjects.r['forecast.darima']
##--------------------------------------------------------------------------------------
# Python version
##--------------------------------------------------------------------------------------
def darima_forec(Theta, Sigma, x, period, h = 1, level = 95):
    """Produce ``h``-step-ahead forecasts and prediction intervals via R.

    The forecasting itself is done by the R function ``forecast.darima``
    (bound at module level as ``forecast_darima``). This wrapper only converts
    the inputs to R numeric vectors and gathers the R results into a
    pandas DataFrame.

    Parameters
    ----------
    Theta : pandas object
        Model coefficients; ``Theta.values`` is passed to R as a numeric vector.
    Sigma : pandas DataFrame
        Matrix (presumably a covariance matrix — confirm with caller). The sum
        of its diagonal divided by its row count is passed to R as ``sigma2``.
    x : pandas object
        Observed series; ``x.values`` is passed to R as a numeric vector.
    period
        Passed through to R unchanged.
    h : int, default 1
        Forecast horizon; each R result component is reshaped to ``h`` rows.
    level : default 95
        Prediction-interval level, passed through to R unchanged.

    Returns
    -------
    pandas.DataFrame
        ``h`` rows with columns ``pred``, ``lower`` and ``upper``. A warning is
        issued (not raised) if any entry is NaN.
    """
    # Average diagonal element of Sigma (builtin sum keeps left-to-right order).
    sigma_diag = Sigma.values.diagonal()
    sigma2 = float(sum(sigma_diag) / Sigma.shape[0])

    result_r = forecast_darima(
        Theta=robjects.FloatVector(Theta.values),
        sigma2=sigma2,
        x=robjects.FloatVector(x.values),
        period=period,
        h=h,
        level=level,
    )

    # R component names, in output-column order; each becomes an h-by-1 column.
    # reshape(h, 1) raises if R returned a different number of values.
    columns = [
        np.array(robjects.FloatVector(result_r.rx2(component))).reshape(h, 1)
        for component in ("mean", "lower", "upper")
    ]
    forecast_frame = pd.DataFrame(
        np.hstack(columns),  # h-by-3
        columns=pd.Index(["pred", "lower", "upper"]),
    )

    if forecast_frame.isna().values.any():
        warnings.warn("NAs appear in the final output")
    return forecast_frame