In [1]:
import numpy as np
import os
np.random.seed(123)
from six.moves import cPickle

from keras import backend as K
from keras.models import Model
from keras.layers import Input, Dense, Flatten
from keras.layers import LSTM
from keras.layers import TimeDistributed
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import Adam

from prednet import PredNet
from data_utils import SequenceGenerator

Using TensorFlow backend.


In [2]:
# Directory layout relative to this notebook: preprocessed datasets are read
# from DATA_DIR, trained artifacts are written under WEIGHTS_DIR.
DATA_DIR = '../data/'
WEIGHTS_DIR = './weights/'

In [3]:
# When save_model is True, training checkpoints the weights to weights_file and
# the final cell writes the architecture (JSON) to json_file.
save_model = True
weights_file, json_file = (
    os.path.join(WEIGHTS_DIR, artifact_name)
    for artifact_name in ('prednet_weather_weights.hdf5', 'prednet_weather_model.json')
)

In [4]:
# Data files
# Frame arrays (x_*.hkl) and the matching sources files consumed by
# SequenceGenerator; presumably sources_*.hkl tags each frame with its
# originating sequence — confirm against data_utils.
train_file, train_sources = (
    os.path.join(DATA_DIR, fname) for fname in ('x_train.hkl', 'sources_train.hkl')
)
val_file, val_sources = (
    os.path.join(DATA_DIR, fname) for fname in ('x_val.hkl', 'sources_val.hkl')
)

In [5]:
# Training parameters
nb_epoch, batch_size = 150, 4
samples_per_epoch = 100  # training sequences per epoch; fit steps are derived from this and batch_size
N_seq_val = 50  # number of sequences to use for validation

In [6]:
# Model parameters
# Input frames: n_channels channels on an im_height x im_width grid; the axis
# order follows the active Keras backend image data format.
n_channels, im_height, im_width = (7, 20, 40)
if K.image_data_format() == 'channels_first':
    input_shape = (n_channels, im_height, im_width)
else:
    input_shape = (im_height, im_width, n_channels)

# Per-layer PredNet configuration; the first stack size is the input channel count.
stack_sizes = (n_channels, 12, 24)
R_stack_sizes = stack_sizes
nb_layers = len(stack_sizes)
A_filt_sizes = (2, 2)         # one entry per layer above the bottom one
Ahat_filt_sizes = (2, 2, 2)   # one entry per layer
R_filt_sizes = (2, 2, 2)      # one entry per layer
# Fail fast if a tuple is edited without updating the others.
assert len(R_stack_sizes) == nb_layers
assert len(A_filt_sizes) == nb_layers - 1
assert len(Ahat_filt_sizes) == nb_layers
assert len(R_filt_sizes) == nb_layers

# Weighting for each layer's error in the final loss, sized from nb_layers:
# "L_0" model uses [1, 0, 0, ...]; "L_all" uses [1, 0.1, 0.1, ...].
layer_loss_weights = np.array([1.] + [0.] * (nb_layers - 1))
# Column vector (nb_layers, 1) so it can serve as the kernel of a Dense(1) layer.
layer_loss_weights = np.expand_dims(layer_loss_weights, 1)

nt = 24  # number of timesteps used for sequences in training
# Equally weight all timesteps except the first, which gets zero weight.
time_loss_weights = 1. / (nt - 1) * np.ones((nt, 1))
time_loss_weights[0] = 0

In [7]:
# PredNet in 'error' mode with return_sequences=True: calling it yields an
# error value per layer at every timestep rather than predicted frames.
prednet = PredNet(
    stack_sizes,
    R_stack_sizes,
    A_filt_sizes,
    Ahat_filt_sizes,
    R_filt_sizes,
    output_mode='error',
    return_sequences=True,
)

In [8]:
# Reduce PredNet's (batch_size, nt, nb_layers) error tensor to one scalar per
# sequence using two frozen Dense(1) layers (zero bias) whose kernels are the
# layer and time weightings. The model is trained with MAE on that scalar;
# presumably the generator supplies zero targets — confirm in data_utils.
inputs = Input(shape=(nt,) + input_shape)
errors = prednet(inputs)  # (batch_size, nt, nb_layers)

# Layers are constructed in the same order as they are applied so that the
# auto-generated layer names match the summary output.
layer_weighting = TimeDistributed(Dense(1, trainable=False),
                                  weights=[layer_loss_weights, np.zeros(1)],
                                  trainable=False)
flatten_time = Flatten()
time_weighting = Dense(1, weights=[time_loss_weights, np.zeros(1)], trainable=False)

errors_by_time = flatten_time(layer_weighting(errors))  # (batch_size, nt)
final_errors = time_weighting(errors_by_time)            # (batch_size, 1)

model = Model(inputs=inputs, outputs=final_errors)
model.compile(optimizer='adam', loss='mean_absolute_error')

In [9]:
# Show layer output shapes and parameter counts; the two frozen Dense weighting
# layers make up the non-trainable parameters.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 24, 20, 40, 7)     0         
_________________________________________________________________
pred_net_1 (PredNet)         (None, 24, 3)             49167     
_________________________________________________________________
time_distributed_1 (TimeDist (None, 24, 1)             4         
_________________________________________________________________
flatten_1 (Flatten)          (None, 24)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 25        
Total params: 49,196
Trainable params: 49,167
Non-trainable params: 29
_________________________________________________________________


In [10]:
# Training batches use shuffle=False; validation is limited to N_seq_val sequences.
# NOTE(review): no shuffling of training sequences — confirm this is intended
# (e.g. deliberate for time-ordered weather data) rather than an oversight.
train_generator = SequenceGenerator(
    train_file, train_sources, nt,
    batch_size=batch_size,
    shuffle=False,
)
val_generator = SequenceGenerator(
    val_file, val_sources, nt,
    batch_size=batch_size,
    N_seq=N_seq_val,
)

In [11]:
# Step learning-rate schedule: LR_INITIAL for epoch indices below LR_DROP_EPOCH,
# LR_FINAL afterwards.
LR_INITIAL = 0.001
LR_FINAL = 0.0001
LR_DROP_EPOCH = 75


def lr_schedule(epoch):
    """Return the learning rate for the given epoch index.

    Keep the signature to a single positional argument: newer Keras versions
    first try ``schedule(epoch, lr)`` in LearningRateScheduler, so an extra
    optional parameter would silently receive the current learning rate.
    """
    return LR_INITIAL if epoch < LR_DROP_EPOCH else LR_FINAL
# Always apply the learning-rate schedule; checkpoint only when saving is enabled.
callbacks = [LearningRateScheduler(lr_schedule)]
if save_model:
    # makedirs (unlike mkdir) also creates missing parent directories, so
    # WEIGHTS_DIR may be a nested path. exist_ok is avoided because it is
    # Python 3-only and the file uses six for Python 2 compatibility.
    if not os.path.isdir(WEIGHTS_DIR):
        os.makedirs(WEIGHTS_DIR)
    # Write weights_file only when val_loss improves on the best value seen so far.
    callbacks.append(ModelCheckpoint(filepath=weights_file, monitor='val_loss', save_best_only=True))

In [12]:
# Keras expects integer step counts, but plain "/" yields floats under Python 3
# (N_seq_val / batch_size == 12.5). Use ceiling division so every sequence is
# covered with whole batches; this matches the number of batches Keras already
# ran with the float values (25 training, 13 validation).
steps_per_epoch = (samples_per_epoch + batch_size - 1) // batch_size
validation_steps = (N_seq_val + batch_size - 1) // batch_size
history = model.fit_generator(train_generator, steps_per_epoch=steps_per_epoch,
                              epochs=nb_epoch, callbacks=callbacks,
                              validation_data=val_generator, validation_steps=validation_steps,
                              verbose=2, workers=0)  # workers=0: run the generator on the main thread

Epoch 1/150
 - 18s - loss: 2.2654e-04 - val_loss: 1.6606e-04
Epoch 2/150
 - 13s - loss: 1.6084e-04 - val_loss: 1.6482e-04
Epoch 3/150
 - 13s - loss: 1.6346e-04 - val_loss: 1.5707e-04
Epoch 4/150
 - 14s - loss: 1.5891e-04 - val_loss: 1.5682e-04
Epoch 5/150
 - 15s - loss: 1.5410e-04 - val_loss: 1.4740e-04
Epoch 6/150
 - 13s - loss: 1.3540e-04 - val_loss: 1.4354e-04
Epoch 7/150
 - 14s - loss: 1.4448e-04 - val_loss: 1.2114e-04
Epoch 8/150
 - 13s - loss: 1.3061e-04 - val_loss: 1.5603e-04
Epoch 9/150
 - 13s - loss: 1.2790e-04 - val_loss: 1.6199e-04
Epoch 10/150
 - 13s - loss: 1.2041e-04 - val_loss: 1.2755e-04
Epoch 11/150
 - 13s - loss: 1.2834e-04 - val_loss: 1.3263e-04
Epoch 12/150
 - 13s - loss: 1.3211e-04 - val_loss: 1.3543e-04
Epoch 13/150
 - 13s - loss: 1.2578e-04 - val_loss: 1.2785e-04
Epoch 14/150
 - 14s - loss: 1.2164e-04 - val_loss: 1.1487e-04
Epoch 15/150
 - 13s - loss: 1.1284e-04 - val_loss: 1.2066e-04
Epoch 16/150
 - 13s - loss: 1.2309e-04 - val_loss: 1.1847e-04
Epoch 17/150
 - 1

Epoch 133/150
 - 13s - loss: 4.8813e-05 - val_loss: 5.2158e-05
Epoch 134/150
 - 13s - loss: 5.1009e-05 - val_loss: 5.2366e-05
Epoch 135/150
 - 13s - loss: 5.2787e-05 - val_loss: 5.2123e-05
Epoch 136/150
 - 13s - loss: 4.9104e-05 - val_loss: 5.2410e-05
Epoch 137/150
 - 14s - loss: 4.8515e-05 - val_loss: 5.2213e-05
Epoch 138/150
 - 13s - loss: 4.9778e-05 - val_loss: 5.1791e-05
Epoch 139/150
 - 13s - loss: 5.2401e-05 - val_loss: 5.1701e-05
Epoch 140/150
 - 13s - loss: 5.0269e-05 - val_loss: 5.2384e-05
Epoch 141/150
 - 13s - loss: 4.7956e-05 - val_loss: 5.3089e-05
Epoch 142/150
 - 13s - loss: 4.8717e-05 - val_loss: 5.1456e-05
Epoch 143/150
 - 13s - loss: 5.0973e-05 - val_loss: 5.1221e-05
Epoch 144/150
 - 13s - loss: 5.1292e-05 - val_loss: 5.1897e-05
Epoch 145/150
 - 13s - loss: 4.7986e-05 - val_loss: 5.1949e-05
Epoch 146/150
 - 13s - loss: 4.8139e-05 - val_loss: 5.1388e-05
Epoch 147/150
 - 13s - loss: 4.9055e-05 - val_loss: 5.1023e-05
Epoch 148/150
 - 13s - loss: 5.2175e-05 - val_loss: 5.0

In [13]:
# Persist the model architecture next to the checkpointed weights_file so the
# trained model can be reconstructed later.
if save_model:
    json_string = model.to_json()
    with open(json_file, "w") as json_out:
        json_out.write(json_string)