-
Notifications
You must be signed in to change notification settings - Fork 3
/
model.py
249 lines (217 loc) · 10.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Add, Dense, Input, Lambda, Reshape, Subtract
from tensorflow.keras.models import Model
def mase_loss(y_true,y_pred,input,frequency):
    """MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf

    The absolute forecast error is scaled by the mean absolute seasonal-naive
    difference of the in-sample (backcast) window; zero-valued targets are
    treated as padding and masked out of the final mean.

    :param y_true: Target values. Shape: (batch, time_o, ...). Entries equal
        to zero are masked out of the loss.
    :param y_pred: Forecast values, same shape as ``y_true``.
    :param input: In-sample (backcast) values. Shape: (batch, time_i, ...).
    :param frequency: Seasonal period for the scaling differences. Must be
        >= 1 and < time_i, otherwise the seasonal-difference slice is empty.
    :return: Scalar tensor — mean of the masked, scaled absolute errors.
    """
    # 1.0 where the target is non-zero, 0.0 where it is zero/padding.
    mask = tf.cast(tf.cast(y_true, tf.bool), tf.float32)
    # Seasonal differences must be taken along the TIME axis (axis 1).
    # The previous `input[:-frequency] - input[frequency:]` sliced the BATCH
    # axis, comparing unrelated series with each other.
    seas_diff = tf.abs(input[:, frequency:] - input[:, :-frequency])
    # Epsilon guards against a zero scale (constant in-sample series),
    # which would otherwise make the loss NaN/Inf.
    scale = tf.reduce_mean(seas_diff) + K.epsilon()
    return tf.reduce_mean(tf.abs(y_true - y_pred) / scale * mask)
# tf.config.threading.get_inter_op_parallelism_threads()
# from nbeats_keras.model import NBeatsNet
class NBeatsNet:
    """Keras implementation of the N-BEATS architecture.

    Two functional models are built over shared layers:
      * 'forecast' — maps the backcast window to the forecast horizon.
      * 'backcast' — maps the backcast window to the residual backcast.
    Attribute access (``fit``, ``predict``, ...) is proxied to the forecast
    model through ``__getattr__``; ``predict(..., return_backcast=True)``
    dispatches to the backcast model instead.
    """

    # Block-type identifiers accepted in `stack_types`.
    GENERIC_BLOCK = 'generic'
    TREND_BLOCK = 'trend'
    SEASONALITY_BLOCK = 'seasonality'

    # Keys of `self.models` / names of the two Keras models.
    _BACKCAST = 'backcast'
    _FORECAST = 'forecast'

    def __init__(self,
                 input_dim=1,
                 output_dim=1,
                 exo_dim=0,
                 backcast_length=10,
                 forecast_length=1,
                 stack_types=(TREND_BLOCK, SEASONALITY_BLOCK),
                 nb_blocks_per_stack=3,
                 thetas_dim=(4, 8),
                 share_weights_in_stack=False,
                 hidden_layer_units=(256,256),
                 nb_harmonics=None,
                 use_mase=False,
                 mase_frequency=0,
                 ):
        """Build the N-BEATS graph eagerly at construction time.

        :param input_dim: Number of target channels in the input window.
        :param output_dim: Number of output channels (a linear projection is
            added when it differs from ``input_dim``).
        :param exo_dim: Number of exogenous channels (0 disables the exo input).
        :param backcast_length: Length of the lookback window.
        :param forecast_length: Length of the forecast horizon.
        :param stack_types: One block type per stack; must align with
            ``thetas_dim`` (asserted below).
        :param nb_blocks_per_stack: Blocks per stack.
        :param thetas_dim: Basis-expansion width per stack.
        :param share_weights_in_stack: Reuse each block's layers across all
            blocks of its stack (see ``_r``).
        :param hidden_layer_units: Hidden width per stack.
        :param nb_harmonics: Optional harmonics count for seasonality blocks.
        :param use_mase: Attach a MASE loss via ``add_loss`` and add a target
            input to the forecast model.
        :param mase_frequency: Seasonal period handed to ``mase_loss``.
        """
        self.stack_types = stack_types
        self.nb_blocks_per_stack = nb_blocks_per_stack
        self.thetas_dim = thetas_dim
        self.units = hidden_layer_units
        self.share_weights_in_stack = share_weights_in_stack
        self.backcast_length = backcast_length
        self.forecast_length = forecast_length
        self.input_dim = input_dim
        self.output_dim=output_dim
        self.exo_dim = exo_dim
        # Shapes are (time, channels); the batch axis is implicit.
        self.input_shape = (self.backcast_length, self.input_dim)
        self.exo_shape = (self.backcast_length, self.exo_dim)
        self.output_shape = (self.forecast_length, self.output_dim)
        # Registry used by `_r` when share_weights_in_stack is enabled.
        self.weights = {}
        self.nb_harmonics = nb_harmonics
        self.use_mase=use_mase
        self.mase_frequency=mase_frequency
        assert len(self.stack_types) == len(self.thetas_dim)
        x = Input(shape=self.input_shape, name='input_variable')
        # Split the input into one 2-D (batch, time) tensor per channel.
        # NOTE(review): `k` is captured by reference in the lambda. The Lambda
        # is invoked immediately here so the current `k` is used at trace
        # time, but re-tracing (e.g. after deserialization) could late-bind
        # `k` — consider `lambda z, k=k: z[..., k]`.
        x_ = {}
        for k in range(self.input_dim):
            x_[k] = Lambda(lambda z: z[..., k])(x)
        e_ = {}
        if self.has_exog():
            e = Input(shape=self.exo_shape, name='exos_variables')
            for k in range(self.exo_dim):
                e_[k] = Lambda(lambda z: z[..., k])(e)
        else:
            e = None
        # Doubly-residual stacking: each block's backcast is subtracted from
        # the running input and its forecast is accumulated into the output.
        y_ = {}
        for stack_id in range(len(self.stack_types)):
            stack_type = self.stack_types[stack_id]
            nb_poly = self.thetas_dim[stack_id]
            layer_size = self.units[stack_id]
            for block_id in range(self.nb_blocks_per_stack):
                backcast, forecast = self.create_block(x_, e_, stack_id, block_id, stack_type, nb_poly,layer_size)
                for k in range(self.input_dim):
                    x_[k] = Subtract()([x_[k], backcast[k]])
                    if stack_id == 0 and block_id == 0:
                        y_[k] = forecast[k]
                    else:
                        y_[k] = Add()([y_[k], forecast[k]])
        # Restore the channel axis, then merge channels back together.
        for k in range(self.input_dim):
            y_[k] = Reshape(target_shape=(self.forecast_length, 1))(y_[k])
            x_[k] = Reshape(target_shape=(self.backcast_length, 1))(x_[k])
        if self.input_dim > 1:
            y_ = Concatenate()([y_[ll] for ll in range(self.input_dim)])
            x_ = Concatenate()([x_[ll] for ll in range(self.input_dim)])
        else:
            y_ = y_[0]
            x_ = x_[0]
        # Linear projection when input and output channel counts differ.
        if self.input_dim != self.output_dim:
            y_ = Dense(self.output_dim, activation='linear', name='reg_y')(y_)
            x_ = Dense(self.output_dim, activation='linear', name='reg_x')(x_)
        inputs_x = [x, e] if self.has_exog() else x
        n_beats_forecast = Model(inputs_x, y_, name=self._FORECAST)
        n_beats_backcast = Model(inputs_x, x_, name=self._BACKCAST)
        if self.use_mase:
            # Extra target input so the MASE loss can be attached via
            # add_loss (no separate `loss=` needed at compile time).
            # NOTE(review): this rebuilds the forecast model with inputs
            # [x, y] only — when has_exog() is true the exo input `e` is
            # dropped here; confirm this is intended.
            y = Input(shape=self.output_shape, name='target_variable')
            n_beats_forecast = Model([x,y], y_, name=self._FORECAST)
            n_beats_forecast.add_loss(mase_loss(y,y_,x,self.mase_frequency))
        self.models = {model.name: model for model in [n_beats_backcast, n_beats_forecast]}
        self.cast_type = self._FORECAST

    def has_exog(self):
        """Return True when exogenous inputs are configured (exo_dim > 0)."""
        # exo/exog is short for 'exogenous variable', i.e. any input
        # features other than the target time-series itself.
        return self.exo_dim > 0

    @staticmethod
    def load(filepath, custom_objects=None, compile=True):
        """Load a previously saved Keras model from `filepath`.

        NOTE(review): parameter `compile` shadows the builtin; kept for
        backward compatibility with existing callers.
        """
        from tensorflow.keras.models import load_model
        return load_model(filepath, custom_objects, compile)

    def _r(self, layer_with_weights, stack_id):
        """Return a shared layer for `stack_id`, registering it on first use.

        With share_weights_in_stack=False this is the identity.
        """
        # mechanism to restore weights when block share the same weights.
        # only useful when share_weights_in_stack=True.
        if self.share_weights_in_stack:
            # Strip the stack/block prefix so all blocks in a stack map to
            # the same registry key.
            layer_name = layer_with_weights.name.split('/')[-1]
            try:
                reused_weights = self.weights[stack_id][layer_name]
                return reused_weights
            except KeyError:
                pass
            if stack_id not in self.weights:
                self.weights[stack_id] = {}
            self.weights[stack_id][layer_name] = layer_with_weights
        return layer_with_weights

    def create_block(self, x, e, stack_id, block_id, stack_type, nb_poly,units):
        """Create one N-BEATS block and apply it per input channel.

        :param x: Dict of per-channel (batch, time) input tensors.
        :param e: Dict of per-channel exogenous tensors (empty when unused).
        :param stack_id: Index of the enclosing stack (for naming/sharing).
        :param block_id: Index of this block within the stack.
        :param stack_type: 'generic', 'trend', or 'seasonality'.
        :param nb_poly: Theta (basis-expansion) width.
        :param units: Hidden width of the four fully-connected layers.
        :return: (backcast_, forecast_) dicts keyed by input channel.
        """
        # register weights (useful when share_weights_in_stack=True)
        def reg(layer):
            return self._r(layer, stack_id)

        # update name (useful when share_weights_in_stack=True)
        def n(layer_name):
            return '/'.join([str(stack_id), str(block_id), stack_type, layer_name])

        backcast_ = {}
        forecast_ = {}
        # Four-layer fully-connected trunk shared by both heads.
        d1 = reg(Dense(units, activation='relu', name=n('d1')))
        d2 = reg(Dense(units, activation='relu', name=n('d2')))
        d3 = reg(Dense(units, activation='relu', name=n('d3')))
        d4 = reg(Dense(units, activation='relu', name=n('d4')))
        if stack_type == 'generic':
            # Generic block: learned linear basis for both casts.
            theta_b = reg(Dense(nb_poly, activation='linear', use_bias=False, name=n('theta_b')))
            theta_f = reg(Dense(nb_poly, activation='linear', use_bias=False, name=n('theta_f')))
            backcast = reg(Dense(self.backcast_length, activation='linear', name=n('backcast')))
            forecast = reg(Dense(self.forecast_length, activation='linear', name=n('forecast')))
        elif stack_type == 'trend':
            # Trend block: shared thetas projected onto a polynomial basis.
            theta_f = theta_b = reg(Dense(nb_poly, activation='linear', use_bias=False, name=n('theta_f_b')))
            backcast = Lambda(trend_model, arguments={'is_forecast': False, 'backcast_length': self.backcast_length,
                                                      'forecast_length': self.forecast_length})
            forecast = Lambda(trend_model, arguments={'is_forecast': True, 'backcast_length': self.backcast_length,
                                                      'forecast_length': self.forecast_length})
        else:  # 'seasonality'
            # Seasonality block: thetas projected onto a Fourier basis.
            if self.nb_harmonics:
                # NOTE(review): theta width formula presumably follows the
                # N-BEATS harmonics parameterization — verify against the
                # reference implementation.
                theta_size=4*int(self.nb_harmonics/2*self.forecast_length-self.nb_harmonics+1)
                theta_b = reg(Dense(theta_size, activation='linear', use_bias=False, name=n('theta_b')))
            else:
                theta_b = reg(Dense(self.forecast_length, activation='linear', use_bias=False, name=n('theta_b')))
            theta_f = reg(Dense(self.forecast_length, activation='linear', use_bias=False, name=n('theta_f')))
            backcast = Lambda(seasonality_model,
                              arguments={'is_forecast': False, 'backcast_length': self.backcast_length,
                                         'forecast_length': self.forecast_length})
            forecast = Lambda(seasonality_model,
                              arguments={'is_forecast': True, 'backcast_length': self.backcast_length,
                                         'forecast_length': self.forecast_length})
        # Apply the trunk + heads independently to every input channel
        # (the layers themselves are shared across channels).
        for k in range(self.input_dim):
            if self.has_exog():
                d0 = Concatenate()([x[k]] + [e[ll] for ll in range(self.exo_dim)])
            else:
                d0 = x[k]
            d1_ = d1(d0)
            d2_ = d2(d1_)
            d3_ = d3(d2_)
            d4_ = d4(d3_)
            theta_f_ = theta_f(d4_)
            theta_b_ = theta_b(d4_)
            backcast_[k] = backcast(theta_b_)
            forecast_[k] = forecast(theta_f_)
        return backcast_, forecast_

    def __getattr__(self, name):
        """Proxy attribute access to the underlying forecast Keras model.

        Callable attributes are wrapped so that
        ``predict(..., return_backcast=True)`` routes to the backcast model.
        """
        # https://github.com/faif/python-patterns
        # model.predict() instead of model.n_beats.predict()
        # same for fit(), train_on_batch()...
        attr = getattr(self.models[self._FORECAST], name)
        if not callable(attr):
            return attr

        def wrapper(*args, **kwargs):
            cast_type = self._FORECAST
            if attr.__name__ == 'predict' and 'return_backcast' in kwargs and kwargs['return_backcast']:
                # Keras' predict() doesn't know this kwarg — strip it.
                del kwargs['return_backcast']
                cast_type = self._BACKCAST
            return getattr(self.models[cast_type], attr.__name__)(*args, **kwargs)

        return wrapper
def linear_space(backcast_length, forecast_length, is_forecast=True):
    """Return a normalized time grid for the requested horizon.

    Produces the sequence 0/h, 1/h, ..., (h-1)/h where h is the forecast
    horizon when `is_forecast` is true and the backcast window otherwise.
    """
    if is_forecast:
        horizon = forecast_length
    else:
        horizon = backcast_length
    return K.arange(0, horizon) / horizon
def seasonality_model(thetas, backcast_length, forecast_length, is_forecast):
    """Project thetas onto a Fourier (seasonality) basis.

    :param thetas: (batch, p) coefficient tensor; p cosine/sine terms.
    :param backcast_length: Length of the backcast window.
    :param forecast_length: Length of the forecast horizon.
    :param is_forecast: Select the forecast (True) or backcast (False) grid.
    :return: (batch, horizon) tensor = thetas @ basis.
    """
    p = thetas.get_shape().as_list()[-1]
    # Split p into p1 cosine terms and p2 sine terms (sine gets the extra
    # term when p is odd; p == 1 yields a single sine row).
    p1, p2 = (p // 2, p // 2) if p % 2 == 0 else (p // 2, p // 2 + 1)
    t = linear_space(backcast_length, forecast_length, is_forecast=is_forecast)
    # Build one combined row list and stack once. The previous version
    # unconditionally evaluated K.stack on the cosine rows, which raises on
    # an empty list when p == 1 (p1 == 0), making its `p == 1` branch
    # unreachable. Stacking the concatenated rows is identical to stacking
    # the two halves and concatenating along axis 0.
    rows = [K.cos(2 * np.pi * i * t) for i in range(p1)]
    rows += [K.sin(2 * np.pi * i * t) for i in range(p2)]
    s = K.stack(rows)  # (p, horizon)
    s = K.cast(s, np.float32)
    return K.dot(thetas, s)
def trend_model(thetas, backcast_length, forecast_length, is_forecast):
    """Project thetas onto a polynomial (trend) basis.

    :param thetas: (batch, p) coefficient tensor; one weight per power.
    :param backcast_length: Length of the backcast window.
    :param forecast_length: Length of the forecast horizon.
    :param is_forecast: Select the forecast (True) or backcast (False) grid.
    :return: (batch, horizon) tensor = thetas @ [t**0; t**1; ...; t**(p-1)].
    """
    degree = thetas.shape[-1]
    t = linear_space(backcast_length, forecast_length, is_forecast=is_forecast)
    basis = K.stack([t ** power for power in range(degree)])  # (p, horizon)
    basis = K.cast(basis, np.float32)
    return K.dot(thetas, basis)  # (batch, horizon)