In [1]:
import numpy as np
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN, Input, Activation, Dropout, Add, LSTM, GRU, RNN, Layer
from keras import backend as K
from keras.optimizers import Adam,SGD
import tensorflow as tf
from keras import Model, regularizers, activations, initializers
from keras.constraints import Constraint
import pickle

# load and save mat file
import h5py
import hdf5storage

In [2]:
# Read the U and Y matrices from the .mat file. It is opened with h5py, so it
# must be HDF5-based (MATLAB v7.3). NOTE(review): h5py exposes MATLAB arrays
# transposed relative to MATLAB's own view; later cells slice axis 0 as
# consecutive samples -- confirm that orientation matches the data.
with h5py.File('test_collected_data.mat', 'r') as file:
    U, Y = (np.array(file[key][:]) for key in ('U', 'Y'))

In [3]:
# get X and y data for training and testing:
# cut the U / Y series into non-overlapping windows of WINDOW_LEN consecutive
# rows along axis 0 (presumably time samples -- confirm against the data).
WINDOW_LEN = 5

# Start a window only where WINDOW_LEN full rows remain. Stepping by WINDOW_LEN
# keeps windows non-overlapping. Stopping at len - WINDOW_LEN + 1 prevents a
# short trailing window when the series length is not a multiple of
# WINDOW_LEN; such a window would make np.array(...) in the next cell ragged.
window_starts = range(0, U.shape[0] - WINDOW_LEN + 1, WINDOW_LEN)

U_input = [U[i:i + WINDOW_LEN, :] for i in window_starts]
Y_output = [Y[i:i + WINDOW_LEN, :] for i in window_starts]

In [4]:
def _to_array_and_report(windows):
    """Stack a list of equal-shape windows into one ndarray and print its shape."""
    stacked = np.array(windows)
    print(stacked.shape)
    return stacked


# Dense (n_windows, window_len, n_channels) arrays for inputs, then outputs.
U_input = _to_array_and_report(U_input)
Y_output = _to_array_and_report(Y_output)

(4000, 5, 2)
(4000, 5, 2)


In [5]:
# Append the negated input channels along the feature axis, so each step
# carries [u, -u] (presumably the input augmentation the ICRNN model expects).
RNN_input = np.concatenate([U_input, np.negative(U_input)], axis=2)
print("RNN_input shape is {}".format(np.shape(RNN_input)))

RNN_input shape is (4000, 5, 4)


In [6]:
RNN_output = Y_output  # targets are the output windows themselves (aliased, not copied)
print("RNN_output shape is {}".format(np.shape(RNN_output)))

RNN_output shape is (4000, 5, 2)


In [7]:
class MyRNNCell(tf.keras.layers.Layer):
    """ICRNN recurrent cell, meant to be wrapped in a Keras ``RNN`` layer.

    For the current input ``x_t`` and carried states ``(h_{t-1}, x_{t-1})`` one step computes::

        h_t = relu(x_t @ kernel + h_{t-1} @ recurrent_kernel + x_{t-1} @ D2)
        y_t = relu(h_t @ V + h_{t-1} @ D1 + x_t @ D3)

    It returns ``y_t`` as the step output and ``[h_t, x_t]`` as the next states.
    Every weight matrix has a ``NonNeg`` constraint, which Keras applies after
    each optimizer update (negative entries are clipped to zero).

    Args:
        units: size of the hidden state ``h`` and of the step output ``y``.
        input_shape_custom: feature size of the per-step input. It sizes the
            carried previous-input state and must equal the last dimension of
            the input the cell is built with.
    """

    def __init__(self, units, input_shape_custom, **kwargs):
        self.units = units
        self.input_shape_custom = input_shape_custom
        # Two carried states: the hidden vector h and the previous step's input.
        self.state_size = [tf.TensorShape([units]), tf.TensorShape([input_shape_custom])]
        super(MyRNNCell, self).__init__(**kwargs)

    def _nonneg_weight(self, name, shape):
        """Create a trainable, 'uniform'-initialised weight constrained to be non-negative."""
        return self.add_weight(shape=shape,
                               initializer='uniform',
                               name=name,
                               constraint=tf.keras.constraints.NonNeg(),
                               trainable=True)

    def build(self, input_shape):
        """Create the six weight matrices for the given per-step input shape.

        Raises:
            ValueError: if the input feature size differs from ``input_shape_custom``.
        """
        input_dim = input_shape[-1]
        # The carried previous input (size input_shape_custom) is multiplied by
        # D2, whose first axis is input_dim. Reject a mismatch here rather than
        # failing later with an opaque matmul shape error inside call().
        if input_dim is not None and input_dim != self.input_shape_custom:
            raise ValueError(
                "MyRNNCell: input feature size {} does not match "
                "input_shape_custom={}".format(input_dim, self.input_shape_custom))
        # Creation order (kernel, recurrent_kernel, D1, D2, D3, V) is kept fixed so
        # the weight list lines up with previously saved models.
        self.kernel = self._nonneg_weight('kernel', (input_dim, self.units))
        self.recurrent_kernel = self._nonneg_weight('recurrent_kernel', (self.units, self.units))
        self.D1 = self._nonneg_weight('D1', (self.units, self.units))
        self.D2 = self._nonneg_weight('D2', (input_dim, self.units))
        self.D3 = self._nonneg_weight('D3', (input_dim, self.units))
        self.V = self._nonneg_weight('V', (self.units, self.units))
        self.built = True

    def call(self, inputs, states):
        """Run one ICRNN step; returns ``(y, [h, inputs])``."""
        prev_h, prev_input = states
        # Hidden update from the current input, previous hidden state and previous input.
        # tf.matmul is used instead of keras.backend.dot (removed in Keras 3);
        # the two are equivalent for these 2-D (batch, features) step tensors.
        h = tf.nn.relu(tf.matmul(inputs, self.kernel)
                       + tf.matmul(prev_h, self.recurrent_kernel)
                       + tf.matmul(prev_input, self.D2))
        # Step output from the new hidden state plus skip paths from prev_h and inputs.
        y = tf.nn.relu(tf.matmul(h, self.V)
                       + tf.matmul(prev_h, self.D1)
                       + tf.matmul(inputs, self.D3))
        # Carry h and the current raw input forward as the next step's states.
        return y, [h, inputs]

    def get_config(self):
        """Serialisable config: base Layer config plus the two constructor args."""
        config = super(MyRNNCell, self).get_config()
        config.update({"units": self.units, "input_shape_custom": self.input_shape_custom})
        return config

In [8]:
# set the seed for reproducibility
tf.random.set_seed(42)

# define parameter values -- derived from the data rather than hard-coded, so the
# reshapes below cannot silently merge time steps into features if the window
# length or channel counts ever change
assert RNN_input.ndim == 3 and RNN_output.ndim == 3, "expected (n_windows, num_step, n_features) arrays"
num_step, num_dims = RNN_input.shape[1], RNN_input.shape[2]
num_outputs = RNN_output.shape[2]
hidden_units = 64

# split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(RNN_input, RNN_output, test_size=0.1, random_state=123)

print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)


def _apply_scaler(scaler, windows):
    """Scale (n_windows, num_step, n_features) data feature-wise with a fitted scaler."""
    flat = windows.reshape(-1, windows.shape[-1])
    return scaler.transform(flat).reshape(windows.shape)


# define scalers for both X and y based on training data only
scaler_X = preprocessing.MinMaxScaler().fit(X_train.reshape(-1, num_dims))
scaler_y = preprocessing.MinMaxScaler().fit(y_train.reshape(-1, num_outputs))

X_train = _apply_scaler(scaler_X, X_train)
X_test = _apply_scaler(scaler_X, X_test)
y_train = _apply_scaler(scaler_y, y_train)
# y_test is kept in original units; the scaled copy is used for evaluation.
y_test_normalized = _apply_scaler(scaler_y, y_test)

# ICRNN model training and evaluation
model = Sequential()
model.add(RNN(MyRNNCell(units=hidden_units, input_shape_custom=num_dims), return_sequences=True))
model.add(RNN(MyRNNCell(units=hidden_units, input_shape_custom=hidden_units), return_sequences=True))
model.add(Dense(num_outputs, activation='relu', kernel_constraint=tf.keras.constraints.NonNeg()))

model.compile(optimizer='adam', loss='mean_squared_error', metrics=[tf.keras.metrics.MeanSquaredError()])
history = model.fit(X_train, y_train, epochs=1000, batch_size=256, validation_split=0, verbose=2)

loss = model.evaluate(X_test, y_test_normalized, batch_size=256)
test_loss = loss[0]  # evaluate returns [loss, *metrics]; index 0 is the MSE loss
print(test_loss)

model.summary()

(3600, 5, 4)
(400, 5, 4)
(3600, 5, 2)
(400, 5, 2)
Epoch 1/1000
15/15 - 2s - loss: 3.4259 - mean_squared_error: 3.4259 - 2s/epoch - 133ms/step
Epoch 2/1000
15/15 - 0s - loss: 0.0932 - mean_squared_error: 0.0932 - 94ms/epoch - 6ms/step
Epoch 3/1000
15/15 - 0s - loss: 0.0870 - mean_squared_error: 0.0870 - 120ms/epoch - 8ms/step
Epoch 4/1000
15/15 - 0s - loss: 0.0878 - mean_squared_error: 0.0878 - 98ms/epoch - 7ms/step
Epoch 5/1000
15/15 - 0s - loss: 0.0806 - mean_squared_error: 0.0806 - 110ms/epoch - 7ms/step
Epoch 6/1000
15/15 - 0s - loss: 0.0729 - mean_squared_error: 0.0729 - 114ms/epoch - 8ms/step
Epoch 7/1000
15/15 - 0s - loss: 0.0682 - mean_squared_error: 0.0682 - 110ms/epoch - 7ms/step
Epoch 8/1000
15/15 - 0s - loss: 0.0656 - mean_squared_error: 0.0656 - 135ms/epoch - 9ms/step
Epoch 9/1000
15/15 - 0s - loss: 0.0636 - mean_squared_error: 0.0636 - 94ms/epoch - 6ms/step
Epoch 10/1000
15/15 - 0s - loss: 0.0618 - mean_squared_error: 0.0618 - 121ms/epoch - 8ms/step
Epoch 11/1000
15/15 - 0

15/15 - 0s - loss: 0.0133 - mean_squared_error: 0.0133 - 110ms/epoch - 7ms/step
Epoch 89/1000
15/15 - 0s - loss: 0.0133 - mean_squared_error: 0.0133 - 103ms/epoch - 7ms/step
Epoch 90/1000
15/15 - 0s - loss: 0.0132 - mean_squared_error: 0.0132 - 99ms/epoch - 7ms/step
Epoch 91/1000
15/15 - 0s - loss: 0.0132 - mean_squared_error: 0.0132 - 105ms/epoch - 7ms/step
Epoch 92/1000
15/15 - 0s - loss: 0.0132 - mean_squared_error: 0.0132 - 101ms/epoch - 7ms/step
Epoch 93/1000
15/15 - 0s - loss: 0.0131 - mean_squared_error: 0.0131 - 101ms/epoch - 7ms/step
Epoch 94/1000
15/15 - 0s - loss: 0.0131 - mean_squared_error: 0.0131 - 89ms/epoch - 6ms/step
Epoch 95/1000
15/15 - 0s - loss: 0.0131 - mean_squared_error: 0.0131 - 101ms/epoch - 7ms/step
Epoch 96/1000
15/15 - 0s - loss: 0.0130 - mean_squared_error: 0.0130 - 103ms/epoch - 7ms/step
Epoch 97/1000
15/15 - 0s - loss: 0.0130 - mean_squared_error: 0.0130 - 101ms/epoch - 7ms/step
Epoch 98/1000
15/15 - 0s - loss: 0.0130 - mean_squared_error: 0.0130 - 95ms/

Epoch 175/1000
15/15 - 0s - loss: 0.0117 - mean_squared_error: 0.0117 - 130ms/epoch - 9ms/step
Epoch 176/1000
15/15 - 0s - loss: 0.0117 - mean_squared_error: 0.0117 - 133ms/epoch - 9ms/step
Epoch 177/1000
15/15 - 0s - loss: 0.0116 - mean_squared_error: 0.0116 - 134ms/epoch - 9ms/step
Epoch 178/1000
15/15 - 0s - loss: 0.0116 - mean_squared_error: 0.0116 - 117ms/epoch - 8ms/step
Epoch 179/1000
15/15 - 0s - loss: 0.0116 - mean_squared_error: 0.0116 - 116ms/epoch - 8ms/step
Epoch 180/1000
15/15 - 0s - loss: 0.0116 - mean_squared_error: 0.0116 - 130ms/epoch - 9ms/step
Epoch 181/1000
15/15 - 0s - loss: 0.0116 - mean_squared_error: 0.0116 - 133ms/epoch - 9ms/step
Epoch 182/1000
15/15 - 0s - loss: 0.0116 - mean_squared_error: 0.0116 - 133ms/epoch - 9ms/step
Epoch 183/1000
15/15 - 0s - loss: 0.0115 - mean_squared_error: 0.0115 - 114ms/epoch - 8ms/step
Epoch 184/1000
15/15 - 0s - loss: 0.0115 - mean_squared_error: 0.0115 - 141ms/epoch - 9ms/step
Epoch 185/1000
15/15 - 0s - loss: 0.0115 - mean_sq

Epoch 262/1000
15/15 - 0s - loss: 0.0108 - mean_squared_error: 0.0108 - 124ms/epoch - 8ms/step
Epoch 263/1000
15/15 - 0s - loss: 0.0108 - mean_squared_error: 0.0108 - 131ms/epoch - 9ms/step
Epoch 264/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 117ms/epoch - 8ms/step
Epoch 265/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 130ms/epoch - 9ms/step
Epoch 266/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 126ms/epoch - 8ms/step
Epoch 267/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 135ms/epoch - 9ms/step
Epoch 268/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 137ms/epoch - 9ms/step
Epoch 269/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 139ms/epoch - 9ms/step
Epoch 270/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 119ms/epoch - 8ms/step
Epoch 271/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 129ms/epoch - 9ms/step
Epoch 272/1000
15/15 - 0s - loss: 0.0107 - mean_sq

15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 146ms/epoch - 10ms/step
Epoch 349/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 134ms/epoch - 9ms/step
Epoch 350/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 121ms/epoch - 8ms/step
Epoch 351/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 117ms/epoch - 8ms/step
Epoch 352/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 116ms/epoch - 8ms/step
Epoch 353/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 127ms/epoch - 8ms/step
Epoch 354/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 141ms/epoch - 9ms/step
Epoch 355/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 158ms/epoch - 11ms/step
Epoch 356/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 137ms/epoch - 9ms/step
Epoch 357/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 136ms/epoch - 9ms/step
Epoch 358/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 

Epoch 435/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 151ms/epoch - 10ms/step
Epoch 436/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 136ms/epoch - 9ms/step
Epoch 437/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 125ms/epoch - 8ms/step
Epoch 438/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 141ms/epoch - 9ms/step
Epoch 439/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 150ms/epoch - 10ms/step
Epoch 440/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 125ms/epoch - 8ms/step
Epoch 441/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 141ms/epoch - 9ms/step
Epoch 442/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 156ms/epoch - 10ms/step
Epoch 443/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 143ms/epoch - 10ms/step
Epoch 444/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 124ms/epoch - 8ms/step
Epoch 445/1000
15/15 - 0s - loss: 0.0104 - mea

15/15 - 0s - loss: 0.0105 - mean_squared_error: 0.0105 - 125ms/epoch - 8ms/step
Epoch 522/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 129ms/epoch - 9ms/step
Epoch 523/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 156ms/epoch - 10ms/step
Epoch 524/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 141ms/epoch - 9ms/step
Epoch 525/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 141ms/epoch - 9ms/step
Epoch 526/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 156ms/epoch - 10ms/step
Epoch 527/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 125ms/epoch - 8ms/step
Epoch 528/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 109ms/epoch - 7ms/step
Epoch 529/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 125ms/epoch - 8ms/step
Epoch 530/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 116ms/epoch - 8ms/step
Epoch 531/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 

Epoch 608/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 122ms/epoch - 8ms/step
Epoch 609/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 111ms/epoch - 7ms/step
Epoch 610/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 101ms/epoch - 7ms/step
Epoch 611/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 100ms/epoch - 7ms/step
Epoch 612/1000
15/15 - 0s - loss: 0.0105 - mean_squared_error: 0.0105 - 105ms/epoch - 7ms/step
Epoch 613/1000
15/15 - 0s - loss: 0.0105 - mean_squared_error: 0.0105 - 92ms/epoch - 6ms/step
Epoch 614/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 109ms/epoch - 7ms/step
Epoch 615/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 91ms/epoch - 6ms/step
Epoch 616/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 109ms/epoch - 7ms/step
Epoch 617/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 109ms/epoch - 7ms/step
Epoch 618/1000
15/15 - 0s - loss: 0.0104 - mean_squa

Epoch 695/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 125ms/epoch - 8ms/step
Epoch 696/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 143ms/epoch - 10ms/step
Epoch 697/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 125ms/epoch - 8ms/step
Epoch 698/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 132ms/epoch - 9ms/step
Epoch 699/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 113ms/epoch - 8ms/step
Epoch 700/1000
15/15 - 0s - loss: 0.0105 - mean_squared_error: 0.0105 - 102ms/epoch - 7ms/step
Epoch 701/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 128ms/epoch - 9ms/step
Epoch 702/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 117ms/epoch - 8ms/step
Epoch 703/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 119ms/epoch - 8ms/step
Epoch 704/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 135ms/epoch - 9ms/step
Epoch 705/1000
15/15 - 0s - loss: 0.0104 - mean_s

Epoch 782/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 106ms/epoch - 7ms/step
Epoch 783/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 91ms/epoch - 6ms/step
Epoch 784/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 113ms/epoch - 8ms/step
Epoch 785/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 100ms/epoch - 7ms/step
Epoch 786/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 103ms/epoch - 7ms/step
Epoch 787/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 99ms/epoch - 7ms/step
Epoch 788/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 94ms/epoch - 6ms/step
Epoch 789/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 113ms/epoch - 8ms/step
Epoch 790/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 102ms/epoch - 7ms/step
Epoch 791/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 106ms/epoch - 7ms/step
Epoch 792/1000
15/15 - 0s - loss: 0.0104 - mean_squar

Epoch 869/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 109ms/epoch - 7ms/step
Epoch 870/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 97ms/epoch - 6ms/step
Epoch 871/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 109ms/epoch - 7ms/step
Epoch 872/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 91ms/epoch - 6ms/step
Epoch 873/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 111ms/epoch - 7ms/step
Epoch 874/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 103ms/epoch - 7ms/step
Epoch 875/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 103ms/epoch - 7ms/step
Epoch 876/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 95ms/epoch - 6ms/step
Epoch 877/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 111ms/epoch - 7ms/step
Epoch 878/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 95ms/epoch - 6ms/step
Epoch 879/1000
15/15 - 0s - loss: 0.0104 - mean_square

Epoch 956/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 130ms/epoch - 9ms/step
Epoch 957/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 126ms/epoch - 8ms/step
Epoch 958/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 154ms/epoch - 10ms/step
Epoch 959/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 143ms/epoch - 10ms/step
Epoch 960/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 126ms/epoch - 8ms/step
Epoch 961/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 153ms/epoch - 10ms/step
Epoch 962/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 142ms/epoch - 9ms/step
Epoch 963/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 110ms/epoch - 7ms/step
Epoch 964/1000
15/15 - 0s - loss: 0.0105 - mean_squared_error: 0.0105 - 118ms/epoch - 8ms/step
Epoch 965/1000
15/15 - 0s - loss: 0.0104 - mean_squared_error: 0.0104 - 111ms/epoch - 7ms/step
Epoch 966/1000
15/15 - 0s - loss: 0.0104 - mean

In [9]:
# save model and scalers
# NOTE(review): reloading this .keras file will presumably need the custom cell,
# e.g. keras.models.load_model(path, custom_objects={'MyRNNCell': MyRNNCell}).
model.save('icrnn_energy.keras')

# Write each fitted scaler through a context manager so the file is flushed and
# closed deterministically. Only unpickle these files from trusted sources:
# pickle.load can execute arbitrary code.
for scaler_path, fitted_scaler in (('icrnn_scaler_X', scaler_X), ('icrnn_scaler_y', scaler_y)):
    with open(scaler_path, 'wb') as scaler_file:
        pickle.dump(fitted_scaler, scaler_file)