In [40]:
import numpy as np
from statsmodels.tsa.arima_model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX
import statsmodels.api as sm
import time

In [62]:
train = np.load("../data/train.npy")
val = np.load("../data/val.npy")
mini = np.load("../data/mini_train.npy")

In [3]:
print(mini.shape)

(28, 100, 100)


In [25]:
p, d, q = 1, 1, 1

window_size = train.shape[-1]

print(f"window_size: {window_size}")
fitted_models = []
for i in range(window_size**2):
    start_time = time.time()

    print(f"training model: {i+1}/{window_size**2 + 1}")
    model = ARIMA(train[:, i // window_size, i % window_size], order=(p, d, q))
    model_fit = model.fit(disp=0)
    fitted_models.append(model_fit)
#     model_fit.save(f"temp/save_{i}.pickle")
    
    elapsed = time.time() - start_time
    print(f"took: {elapsed} seconds")
    
    break

window_size: 100
training model: 1/10001
took: 0.8920671939849854 seconds


In [39]:
model = fitted_models[0]
model.predict(start=10, end=20, exog=np.random.randn(100)).shape

(11,)

In [19]:
for i in range(8):
    start_time = time.time()
    print(f"loading model")
    saved = sm.load(f"temp/save_{i}.pickle")
    
    elapsed = time.time() - start_time
    print(f"took: {elapsed} seconds")

loading model
took: 0.0029108524322509766 seconds
loading model
took: 0.0020389556884765625 seconds
loading model
took: 0.0020127296447753906 seconds
loading model
took: 0.0020232200622558594 seconds
loading model
took: 0.0015790462493896484 seconds
loading model
took: 0.002247333526611328 seconds
loading model
took: 0.0019788742065429688 seconds
loading model
took: 0.0015597343444824219 seconds


In [None]:
def predict(coef, history):
    yhat = 0.0
    for i in range(1, len(coef)+1):
        yhat += coef[i-1] * history[-i]
    return yhat
 
def difference(dataset):
    diff = list()
    for i in range(1, len(dataset)):
        value = dataset[i] - dataset[i - 1]
        diff.append(value)
    return numpy.array(diff)
 
def predict(model_fit, orders, inputs, output_length)
# series = Series.from_csv('daily-minimum-temperatures.csv', header=0)
# X = series.values
# size = len(X) - 7
# train, test = X[0:size], X[size:]
# history = [x for x in train]
# predictions = list()
for t in range(len(test)):

    ar_coef, ma_coef = model_fit.arparams, model_fit.maparams
    resid = model_fit.resid
    for i in range(d):
        inputs = difference(inputs)
        
    
    yhat = history[-1] + predict(ar_coef, diff) + predict(ma_coef, resid)
    predictions.append(yhat)
    obs = test[t]
    history.append(obs)
    print('>predicted=%.3f, expected=%.3f' % (yhat, obs))
rmse = sqrt(mean_squared_error(test, predictions))
print('Test RMSE: %.3f' % rmse)

In [43]:
mod = SARIMAX(train[:50, 1, 1], order=(1,1,1))
mod_fit = mod.fit()



In [46]:
naujas = SARIMAX(train[50:100, 1, 1], order=(1,1,1))
naujas_fit = naujas.filter(mod_fit.params)

In [50]:
pred = naujas_fit.get_prediction(start=10, end=14, dynamic=True)

In [54]:
pred.predicted_mean

array([-0.45269765, -0.4576199 , -0.45877185, -0.45904144, -0.45910453])

In [92]:
naujas_fit.params

array([ 2.34030314e-01, -8.47683842e-01,  3.41470988e-05])

In [61]:
naujas_fit.save("temp.pickle", remove_data=True)

In [114]:
def nrmse(targets, predictions):
    targets = np.squeeze(targets)
    predictions = np.squeeze(predictions)

    error = predictions - targets
    mse = np.mean(np.square(error))
    rmse = np.sqrt(mse)
    return rmse / np.mean(targets)

def make_prediction(save_path, eval_data, order, input_start, input_size, output_size):
    grid_size = 100
    prediction = np.zeros((output_size, grid_size, grid_size))
    for x_coord in range(grid_size):
        print(f"x_coord: {x_coord}")
        for y_coord in range(grid_size):
            
            trained_model = sm.load(f"{save_path}/{x_coord}_{y_coord}.pickle")
            model = SARIMAX(eval_data[:, x_coord, y_coord], order=order)
            model_fit = model.filter(trained_model.params)
            
            prediction_wrapper = model_fit.get_prediction(start=input_start, 
                                     end=input_start + 3 + output_size - 1, dynamic=3)
            

            print(f"pred mean: {prediction_wrapper.predicted_mean}")
            prediction[:, x_coord, y_coord] += prediction_wrapper.predicted_mean[-output_size:]
#             print(prediction[:, x_coord, y_coord])
            
            
    return prediction

def evaluate(save_path, eval_data, order, output_size=12, 
             train_mean=67.61768898039853, train_std=132.47248595705986):
    input_size = order[0]  # p is input_size
    errors = []
    for i in range(len(eval_data) - output_size - input_size):
        predictions = make_prediction(save_path, eval_data, order, i, input_size, output_size)
        targets = eval_data[i:i+output_size]
        
        predictions = predictions * train_std + train_mean 
        targets = targets * train_std + train_mean
        
        error = nrmse(targets, predictions)
        print(f"error: {error}")
        errors.append(error)
        break
            
    print(f"mean error: {np.array(errors).mean()}")
    print(f"error std: {np.array(errors).std()}")
    return errors
    

            

In [115]:
save_path = "../results/arima/p0_d0_q1/saved_models"

errors = evaluate(save_path, val, order=(0, 0, 1))

x_coord: 0
pred mean: [ 0.         -0.23802954 -0.16230908 -0.23728833  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.2377447  -0.16217328 -0.23699141  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.23774444 -0.16217315 -0.23699089  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.23774445 -0.16217316 -0.23699091  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.23774095 -0.16217151 -0.23698401  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24446742 -0.16567567 -0.24559828  0.          0.
  0.          0.  

pred mean: [ 0.         -0.03835317 -0.04297147 -0.03473885  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.03835317 -0.04297147 -0.03473885  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.0567731  -0.06021763 -0.05288835  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.06555032 -0.06422799 -0.06019779  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.06417521 -0.05055983 -0.0588178   0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.04259421 -0.02597966 -0.03379857  0.          0.
  0.          0.          0. 

pred mean: [ 0.         -0.23774446 -0.16217316 -0.23699093  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.23774397 -0.16217293 -0.23698997  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24543017 -0.16617752 -0.24682695  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24749908 -0.16725649 -0.24946085  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24445204 -0.16449456 -0.2472571   0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24400243 -0.16408751 -0.24693054  0.          0.
  0.          0.          0. 

pred mean: [ 0.         -0.03835317 -0.04297147 -0.03473885  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.03835317 -0.04297147 -0.03473885  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.03838959 -0.04294812 -0.03480554  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.04582037 -0.04119195 -0.04245089  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.0379366  -0.0237082  -0.02933666  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.03755484 -0.02030065 -0.02800071  0.          0.
  0.          0.          0. 

pred mean: [ 0.         -0.2382009  -0.16265262 -0.23726917  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.23729599 -0.16026973 -0.23687235  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24671545 -0.16679159 -0.24846419  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24749904 -0.16725647 -0.24946078  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24749631 -0.16725521 -0.24945182  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24749643 -0.16725527 -0.24945233  0.          0.
  0.          0.          0. 

pred mean: [ 0.         -0.04770516 -0.04860485 -0.04376893  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.03835317 -0.04297147 -0.03473885  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.01937665 -0.03778248 -0.00289168  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.04181302 -0.02995185 -0.0319021   0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.03755607 -0.02029986 -0.02800291  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.03755484 -0.02030065 -0.02800071  0.          0.
  0.          0.          0. 

pred mean: [ 0.         -0.24751197 -0.16334899 -0.24948007  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24749303 -0.16699256 -0.24944868  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24360997 -0.1645141  -0.24512734  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24702381 -0.16687402 -0.24880611  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.24734825 -0.16705175 -0.24922271  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.2474967  -0.16725539 -0.24945607  0.          0.
  0.          0.          0. 

pred mean: [0.         0.15616804 0.0778125  0.10420737 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.        ]
pred mean: [ 0.         -0.10403366 -0.10028773 -0.1133091   0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.12102144 -0.11175637 -0.12825613  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.12102144 -0.11175637 -0.12825613  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.12102144 -0.11175637 -0.12825613  0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.        ]
pred mean: [ 0.         -0.08074799 -0.09186007 -0.08876163  0.          0.
  0.          0.          0.          0.    

KeyboardInterrupt: 