In [1]:
import numpy as np
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN, Input, Activation, Dropout, Add, LSTM, GRU, RNN, Layer
from keras import backend as K
from keras.optimizers import Adam,SGD
import tensorflow as tf
from keras import Model, regularizers, activations, initializers
from keras.constraints import Constraint
import pickle

# load and save mat file
import h5py
import hdf5storage

In [2]:
# Read the recorded input (U) and output (Y) sequences. The .mat file is opened
# with h5py, so it must be an HDF5-based (MATLAB -v7.3) file.
with h5py.File('test_collected_data.mat', 'r') as file:
    # Materialize both datasets as in-memory numpy arrays before the file closes.
    U, Y = (np.array(file[name][:]) for name in ('U', 'Y'))

In [3]:
# get X and y data for training and testing
#
# Cut the recorded sequences into non-overlapping windows of WINDOW_LEN
# consecutive time steps (stride == window length), pairing each input window
# of U with the output window of Y covering the same time steps.
# Only *complete* windows are emitted: the previous loop bound
# (range(U.shape[0]-1) with i % 5 == 0) produced a short trailing window
# whenever the record length was not a multiple of 5, which makes the
# np.array stacking in the next cell ragged.
WINDOW_LEN = 5

U_input = list()
Y_output = list()

# Bound by the shorter record so U and Y windows always have the same length.
n_steps = min(U.shape[0], Y.shape[0])
for i in range(0, n_steps - WINDOW_LEN + 1, WINDOW_LEN):
    U_input.append(U[i:i + WINDOW_LEN, :])
    Y_output.append(Y[i:i + WINDOW_LEN, :])

In [4]:
# Stack the per-window lists into dense 3-D arrays of shape
# (num_windows, window_length, features) and report both shapes.
U_input = np.array(U_input)
Y_output = np.array(Y_output)
print(U_input.shape, Y_output.shape, sep='\n')

(4000, 5, 2)
(4000, 5, 2)


In [5]:
# The windowed inputs are fed to the RNN unchanged (no extra features added).
RNN_input = U_input
print(f"RNN_input shape is {RNN_input.shape}")

RNN_input shape is (4000, 5, 2)


In [6]:
# The windowed outputs are used directly as the sequence-to-sequence targets.
RNN_output = Y_output
print(f"RNN_output shape is {RNN_output.shape}")

RNN_output shape is (4000, 5, 2)


In [7]:
class MyLRNNCell(tf.keras.layers.Layer):
    """Recurrent cell for use inside ``keras.layers.RNN``.

    Each step computes::

        A   = beta * (B - B^T) + (1 - beta) * (B + B^T) - gamma * I
        W   = beta * (C - C^T) + (1 - beta) * (C + C^T) - gamma * I
        h_t = h_{t-1} + eps * alpha * (h_{t-1} @ A)
                      + eps * tanh(h_{t-1} @ W + x_t @ U + b)

    The update has the form of a forward-Euler step of size ``eps``.
    NOTE(review): this appears to follow the Lipschitz RNN formulation
    (Erichson et al.) -- confirm against the intended reference.

    Args:
        units: size of the hidden state (also the state size reported to RNN).
        eps: step size multiplying both the linear and the tanh terms.
        gamma: shift subtracted from the diagonal of both A and W.
        beta: weight of the skew-symmetric part; (1 - beta) weights the
            symmetric part of B and C.
        alpha: extra scale applied only to the linear ``h @ A`` term.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.

    NOTE: the class is not registered as Keras-serializable, so reloading a
    saved model presumably requires passing it via ``custom_objects``.
    """

    def __init__(self, units, eps=0.01, gamma=0.01, beta=0.8, alpha=1, **kwargs):
        # Initialize the base Layer before assigning attributes so Keras
        # attribute tracking is fully set up first.
        super(MyLRNNCell, self).__init__(**kwargs)
        self.units = units
        self.state_size = units
        self.I = tf.eye(units)
        self.eps = eps
        self.gamma = gamma
        self.beta = beta
        self.alpha = alpha

    def build(self, input_shape):
        """Create the recurrent (C, B), input (U) and bias (b) weights.

        Only the last dimension of ``input_shape`` (the per-step feature
        count) is used, to size the input projection ``U``.
        """
        # Creation order (C, B, U, b) is kept unchanged so weight ordering in
        # saved models / checkpoints stays the same.
        self.C = self.add_weight(shape=(self.units, self.units),
                                 initializer='random_normal',
                                 name='C',
                                 trainable=True)
        self.B = self.add_weight(shape=(self.units, self.units),
                                 initializer='random_normal',
                                 name='B',
                                 trainable=True)
        self.U = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 name='U',
                                 trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='zeros',
                                 name='b',
                                 trainable=True)
        # Let the base class mark the layer as built (and record the shape).
        super(MyLRNNCell, self).build(input_shape)

    def _mixed_matrix(self, M):
        """Return beta*(M - M^T) + (1 - beta)*(M + M^T) - gamma*I."""
        M_t = tf.transpose(M)
        return self.beta * (M - M_t) + (1 - self.beta) * (M + M_t) - self.gamma * self.I

    def call(self, inputs, states):
        """Advance the hidden state by one time step.

        Args:
            inputs: input at the current step (last axis = features).
            states: list whose first element is the previous hidden state.

        Returns:
            ``(h, [h])`` -- the new hidden state is both the output and the
            next state.
        """
        prev_h = states[0]

        A = self._mixed_matrix(self.B)
        W = self._mixed_matrix(self.C)

        # tf.matmul replaces K.dot (removed from keras.backend in Keras 3);
        # for these 2-D operands K.dot dispatched to tf.matmul anyway.
        linear = tf.matmul(prev_h, A)
        nonlinear = tf.nn.tanh(tf.matmul(prev_h, W) + tf.matmul(inputs, self.U) + self.b)

        h = prev_h + self.eps * self.alpha * linear + self.eps * nonlinear
        return h, [h]

    def get_config(self):
        """Return the constructor arguments so the cell can be re-created."""
        config = super(MyLRNNCell, self).get_config()
        config.update({"units": self.units, "eps": self.eps, "gamma": self.gamma,
                       "beta": self.beta, "alpha": self.alpha})
        return config

In [8]:
# set the seed for reproducibility
tf.random.set_seed(42)

# Window geometry is read from the data instead of being hard-coded, so it
# cannot drift out of sync with the windowing cell (and X / y no longer mix
# `num_dims` with a literal 2).
num_step = RNN_input.shape[1]   # time steps per window (5 for this data)
num_dims = RNN_input.shape[2]   # input features per time step
num_out = RNN_output.shape[2]   # output features per time step

# split into train and test sets (10% held out, fixed random_state)
X_train, X_test, y_train, y_test = train_test_split(RNN_input, RNN_output, test_size=0.1, random_state=123)

print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)

# Fit min-max scalers on the training data only (no test-set leakage),
# treating every time step of every window as one sample per feature.
scaler_X = preprocessing.MinMaxScaler().fit(X_train.reshape(-1, num_dims))
scaler_y = preprocessing.MinMaxScaler().fit(y_train.reshape(-1, num_out))

X_train = scaler_X.transform(X_train.reshape(-1, num_dims)).reshape(-1, num_step, num_dims)
X_test = scaler_X.transform(X_test.reshape(-1, num_dims)).reshape(-1, num_step, num_dims)
y_train = scaler_y.transform(y_train.reshape(-1, num_out)).reshape(-1, num_step, num_out)
# y_test itself stays in original units; this scaled copy matches the space
# the model is trained in and is what evaluation uses.
y_test_normalized = scaler_y.transform(y_test.reshape(-1, num_out)).reshape(-1, num_step, num_out)

# LRNN model: two stacked recurrent layers (64 units) returning full
# sequences, then a linear Dense read-out to the output features.
model = Sequential()
model.add(RNN(MyLRNNCell(units=64), return_sequences=True))
model.add(RNN(MyLRNNCell(units=64), return_sequences=True))
model.add(Dense(num_out, activation='linear'))

model.compile(optimizer='adam', loss='mean_squared_error', metrics=[tf.keras.metrics.MeanSquaredError()])
# NOTE(review): no validation data and no early stopping; the logged training
# loss plateaus long before epoch 1000.
history = model.fit(X_train, y_train, epochs=1000, batch_size=256, validation_split=0, verbose=2)

# evaluate() returns [loss, mse]; both are in the *scaled* target space.
loss = model.evaluate(X_test, y_test_normalized, batch_size=256)
test_loss = loss[0]
print(test_loss)

model.summary()

(3600, 5, 2)
(400, 5, 2)
(3600, 5, 2)
(400, 5, 2)
Epoch 1/1000
15/15 - 2s - loss: 0.1995 - mean_squared_error: 0.1995 - 2s/epoch - 141ms/step
Epoch 2/1000
15/15 - 0s - loss: 0.1852 - mean_squared_error: 0.1852 - 109ms/epoch - 7ms/step
Epoch 3/1000
15/15 - 0s - loss: 0.1703 - mean_squared_error: 0.1703 - 94ms/epoch - 6ms/step
Epoch 4/1000
15/15 - 0s - loss: 0.1548 - mean_squared_error: 0.1548 - 94ms/epoch - 6ms/step
Epoch 5/1000
15/15 - 0s - loss: 0.1384 - mean_squared_error: 0.1384 - 94ms/epoch - 6ms/step
Epoch 6/1000
15/15 - 0s - loss: 0.1216 - mean_squared_error: 0.1216 - 92ms/epoch - 6ms/step
Epoch 7/1000
15/15 - 0s - loss: 0.1048 - mean_squared_error: 0.1048 - 109ms/epoch - 7ms/step
Epoch 8/1000
15/15 - 0s - loss: 0.0891 - mean_squared_error: 0.0891 - 109ms/epoch - 7ms/step
Epoch 9/1000
15/15 - 0s - loss: 0.0757 - mean_squared_error: 0.0757 - 109ms/epoch - 7ms/step
Epoch 10/1000
15/15 - 0s - loss: 0.0658 - mean_squared_error: 0.0658 - 94ms/epoch - 6ms/step
Epoch 11/1000
15/15 - 0s 

Epoch 89/1000
15/15 - 0s - loss: 0.0348 - mean_squared_error: 0.0348 - 86ms/epoch - 6ms/step
Epoch 90/1000
15/15 - 0s - loss: 0.0348 - mean_squared_error: 0.0348 - 72ms/epoch - 5ms/step
Epoch 91/1000
15/15 - 0s - loss: 0.0347 - mean_squared_error: 0.0347 - 83ms/epoch - 6ms/step
Epoch 92/1000
15/15 - 0s - loss: 0.0347 - mean_squared_error: 0.0347 - 79ms/epoch - 5ms/step
Epoch 93/1000
15/15 - 0s - loss: 0.0347 - mean_squared_error: 0.0347 - 78ms/epoch - 5ms/step
Epoch 94/1000
15/15 - 0s - loss: 0.0347 - mean_squared_error: 0.0347 - 74ms/epoch - 5ms/step
Epoch 95/1000
15/15 - 0s - loss: 0.0347 - mean_squared_error: 0.0347 - 82ms/epoch - 5ms/step
Epoch 96/1000
15/15 - 0s - loss: 0.0346 - mean_squared_error: 0.0346 - 76ms/epoch - 5ms/step
Epoch 97/1000
15/15 - 0s - loss: 0.0346 - mean_squared_error: 0.0346 - 77ms/epoch - 5ms/step
Epoch 98/1000
15/15 - 0s - loss: 0.0346 - mean_squared_error: 0.0346 - 78ms/epoch - 5ms/step
Epoch 99/1000
15/15 - 0s - loss: 0.0346 - mean_squared_error: 0.0346 -

Epoch 177/1000
15/15 - 0s - loss: 0.0253 - mean_squared_error: 0.0253 - 80ms/epoch - 5ms/step
Epoch 178/1000
15/15 - 0s - loss: 0.0252 - mean_squared_error: 0.0252 - 76ms/epoch - 5ms/step
Epoch 179/1000
15/15 - 0s - loss: 0.0250 - mean_squared_error: 0.0250 - 89ms/epoch - 6ms/step
Epoch 180/1000
15/15 - 0s - loss: 0.0248 - mean_squared_error: 0.0248 - 79ms/epoch - 5ms/step
Epoch 181/1000
15/15 - 0s - loss: 0.0247 - mean_squared_error: 0.0247 - 98ms/epoch - 7ms/step
Epoch 182/1000
15/15 - 0s - loss: 0.0245 - mean_squared_error: 0.0245 - 95ms/epoch - 6ms/step
Epoch 183/1000
15/15 - 0s - loss: 0.0243 - mean_squared_error: 0.0243 - 90ms/epoch - 6ms/step
Epoch 184/1000
15/15 - 0s - loss: 0.0241 - mean_squared_error: 0.0241 - 93ms/epoch - 6ms/step
Epoch 185/1000
15/15 - 0s - loss: 0.0239 - mean_squared_error: 0.0239 - 95ms/epoch - 6ms/step
Epoch 186/1000
15/15 - 0s - loss: 0.0237 - mean_squared_error: 0.0237 - 92ms/epoch - 6ms/step
Epoch 187/1000
15/15 - 0s - loss: 0.0235 - mean_squared_erro

Epoch 264/1000
15/15 - 0s - loss: 0.0112 - mean_squared_error: 0.0112 - 94ms/epoch - 6ms/step
Epoch 265/1000
15/15 - 0s - loss: 0.0112 - mean_squared_error: 0.0112 - 109ms/epoch - 7ms/step
Epoch 266/1000
15/15 - 0s - loss: 0.0111 - mean_squared_error: 0.0111 - 112ms/epoch - 7ms/step
Epoch 267/1000
15/15 - 0s - loss: 0.0110 - mean_squared_error: 0.0110 - 94ms/epoch - 6ms/step
Epoch 268/1000
15/15 - 0s - loss: 0.0109 - mean_squared_error: 0.0109 - 109ms/epoch - 7ms/step
Epoch 269/1000
15/15 - 0s - loss: 0.0108 - mean_squared_error: 0.0108 - 109ms/epoch - 7ms/step
Epoch 270/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 109ms/epoch - 7ms/step
Epoch 271/1000
15/15 - 0s - loss: 0.0107 - mean_squared_error: 0.0107 - 123ms/epoch - 8ms/step
Epoch 272/1000
15/15 - 0s - loss: 0.0106 - mean_squared_error: 0.0106 - 125ms/epoch - 8ms/step
Epoch 273/1000
15/15 - 0s - loss: 0.0105 - mean_squared_error: 0.0105 - 109ms/epoch - 7ms/step
Epoch 274/1000
15/15 - 0s - loss: 0.0104 - mean_squa

Epoch 351/1000
15/15 - 0s - loss: 0.0072 - mean_squared_error: 0.0072 - 109ms/epoch - 7ms/step
Epoch 352/1000
15/15 - 0s - loss: 0.0072 - mean_squared_error: 0.0072 - 94ms/epoch - 6ms/step
Epoch 353/1000
15/15 - 0s - loss: 0.0072 - mean_squared_error: 0.0072 - 125ms/epoch - 8ms/step
Epoch 354/1000
15/15 - 0s - loss: 0.0071 - mean_squared_error: 0.0071 - 125ms/epoch - 8ms/step
Epoch 355/1000
15/15 - 0s - loss: 0.0071 - mean_squared_error: 0.0071 - 107ms/epoch - 7ms/step
Epoch 356/1000
15/15 - 0s - loss: 0.0071 - mean_squared_error: 0.0071 - 127ms/epoch - 8ms/step
Epoch 357/1000
15/15 - 0s - loss: 0.0071 - mean_squared_error: 0.0071 - 103ms/epoch - 7ms/step
Epoch 358/1000
15/15 - 0s - loss: 0.0071 - mean_squared_error: 0.0071 - 104ms/epoch - 7ms/step
Epoch 359/1000
15/15 - 0s - loss: 0.0070 - mean_squared_error: 0.0070 - 98ms/epoch - 7ms/step
Epoch 360/1000
15/15 - 0s - loss: 0.0070 - mean_squared_error: 0.0070 - 109ms/epoch - 7ms/step
Epoch 361/1000
15/15 - 0s - loss: 0.0070 - mean_squa

Epoch 438/1000
15/15 - 0s - loss: 0.0064 - mean_squared_error: 0.0064 - 79ms/epoch - 5ms/step
Epoch 439/1000
15/15 - 0s - loss: 0.0064 - mean_squared_error: 0.0064 - 75ms/epoch - 5ms/step
Epoch 440/1000
15/15 - 0s - loss: 0.0064 - mean_squared_error: 0.0064 - 76ms/epoch - 5ms/step
Epoch 441/1000
15/15 - 0s - loss: 0.0064 - mean_squared_error: 0.0064 - 77ms/epoch - 5ms/step
Epoch 442/1000
15/15 - 0s - loss: 0.0064 - mean_squared_error: 0.0064 - 75ms/epoch - 5ms/step
Epoch 443/1000
15/15 - 0s - loss: 0.0064 - mean_squared_error: 0.0064 - 75ms/epoch - 5ms/step
Epoch 444/1000
15/15 - 0s - loss: 0.0064 - mean_squared_error: 0.0064 - 76ms/epoch - 5ms/step
Epoch 445/1000
15/15 - 0s - loss: 0.0063 - mean_squared_error: 0.0063 - 82ms/epoch - 5ms/step
Epoch 446/1000
15/15 - 0s - loss: 0.0063 - mean_squared_error: 0.0063 - 76ms/epoch - 5ms/step
Epoch 447/1000
15/15 - 0s - loss: 0.0063 - mean_squared_error: 0.0063 - 73ms/epoch - 5ms/step
Epoch 448/1000
15/15 - 0s - loss: 0.0063 - mean_squared_erro

Epoch 525/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 94ms/epoch - 6ms/step
Epoch 526/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 109ms/epoch - 7ms/step
Epoch 527/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 104ms/epoch - 7ms/step
Epoch 528/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 93ms/epoch - 6ms/step
Epoch 529/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 89ms/epoch - 6ms/step
Epoch 530/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 94ms/epoch - 6ms/step
Epoch 531/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 94ms/epoch - 6ms/step
Epoch 532/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 110ms/epoch - 7ms/step
Epoch 533/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 88ms/epoch - 6ms/step
Epoch 534/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 84ms/epoch - 6ms/step
Epoch 535/1000
15/15 - 0s - loss: 0.0062 - mean_squared_e

Epoch 612/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 81ms/epoch - 5ms/step
Epoch 613/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 86ms/epoch - 6ms/step
Epoch 614/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 87ms/epoch - 6ms/step
Epoch 615/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 84ms/epoch - 6ms/step
Epoch 616/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 90ms/epoch - 6ms/step
Epoch 617/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 93ms/epoch - 6ms/step
Epoch 618/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 88ms/epoch - 6ms/step
Epoch 619/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 85ms/epoch - 6ms/step
Epoch 620/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 102ms/epoch - 7ms/step
Epoch 621/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 82ms/epoch - 5ms/step
Epoch 622/1000
15/15 - 0s - loss: 0.0062 - mean_squared_err

15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 99ms/epoch - 7ms/step
Epoch 700/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 93ms/epoch - 6ms/step
Epoch 701/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 104ms/epoch - 7ms/step
Epoch 702/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 88ms/epoch - 6ms/step
Epoch 703/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 79ms/epoch - 5ms/step
Epoch 704/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 89ms/epoch - 6ms/step
Epoch 705/1000
15/15 - 0s - loss: 0.0062 - mean_squared_error: 0.0062 - 94ms/epoch - 6ms/step
Epoch 706/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 101ms/epoch - 7ms/step
Epoch 707/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 107ms/epoch - 7ms/step
Epoch 708/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 87ms/epoch - 6ms/step
Epoch 709/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 

15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 109ms/epoch - 7ms/step
Epoch 787/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 125ms/epoch - 8ms/step
Epoch 788/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 94ms/epoch - 6ms/step
Epoch 789/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 111ms/epoch - 7ms/step
Epoch 790/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 105ms/epoch - 7ms/step
Epoch 791/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 95ms/epoch - 6ms/step
Epoch 792/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 99ms/epoch - 7ms/step
Epoch 793/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 98ms/epoch - 7ms/step
Epoch 794/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 97ms/epoch - 6ms/step
Epoch 795/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 85ms/epoch - 6ms/step
Epoch 796/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 -

Epoch 873/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 84ms/epoch - 6ms/step
Epoch 874/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 89ms/epoch - 6ms/step
Epoch 875/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 103ms/epoch - 7ms/step
Epoch 876/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 99ms/epoch - 7ms/step
Epoch 877/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 84ms/epoch - 6ms/step
Epoch 878/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 117ms/epoch - 8ms/step
Epoch 879/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 99ms/epoch - 7ms/step
Epoch 880/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 101ms/epoch - 7ms/step
Epoch 881/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 84ms/epoch - 6ms/step
Epoch 882/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 109ms/epoch - 7ms/step
Epoch 883/1000
15/15 - 0s - loss: 0.0061 - mean_squared_

Epoch 960/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 83ms/epoch - 6ms/step
Epoch 961/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 99ms/epoch - 7ms/step
Epoch 962/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 100ms/epoch - 7ms/step
Epoch 963/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 102ms/epoch - 7ms/step
Epoch 964/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 98ms/epoch - 7ms/step
Epoch 965/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 100ms/epoch - 7ms/step
Epoch 966/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 99ms/epoch - 7ms/step
Epoch 967/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 95ms/epoch - 6ms/step
Epoch 968/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 93ms/epoch - 6ms/step
Epoch 969/1000
15/15 - 0s - loss: 0.0061 - mean_squared_error: 0.0061 - 100ms/epoch - 7ms/step
Epoch 970/1000
15/15 - 0s - loss: 0.0061 - mean_squared_

In [9]:
# save model and scalers
# save model and scalers
# The scalers are needed to map new inputs into, and predictions out of, the
# normalized space the model was trained in.
model.save('lrnn_energy.keras')
# Use context managers so the pickle files are flushed and closed
# deterministically (the bare open() handles were never closed).
# NOTE: only unpickle these files from trusted locations.
with open('lrnn_scaler_X', 'wb') as fh:
    pickle.dump(scaler_X, fh)
with open('lrnn_scaler_y', 'wb') as fh:
    pickle.dump(scaler_y, fh)