In [1]:
from tensorflow import keras
from rnn_models import lstm_1Dcnn_model, lstm_3Dcnn_model, gru_1Dcnn_model, gru_3Dcnn_model

In [2]:
# model parameters
n_output = 10 # 9 phonemes + 1 blank
n_input_time = 200 # 1 second length neural data at 200 Hz
n_input_channel = 111 # number of significant HG channels (of 128)
n_channel_x = 8 # number of electrode columns (8x16 grid = 128 channels)
n_channel_y = 16 # number of electrode rows (8x16 grid = 128 channels)
n_filters = 100
filter_size = 2
n_units = 500 # number of units in LSTM/GRU layers
reg_lambda = 1e-5

In [3]:
# training_model, inf_enc, inf_dec = lstm_1Dcnn_model(n_input_time, n_input_channel, n_output, n_filters, filter_size, n_units, reg_lambda)
# training_model, inf_enc, inf_dec = lstm_3Dcnn_model(n_input_time, n_channel_x, n_channel_y, n_output, n_filters, filter_size, n_units, reg_lambda)
# training_model, inf_enc, inf_dec = gru_1Dcnn_model(n_input_time, n_input_channel, n_output, n_filters, filter_size, n_units, reg_lambda)
training_model, inf_enc, inf_dec = gru_3Dcnn_model(n_input_time, n_channel_x, n_channel_y, n_output, n_filters, filter_size, n_units, reg_lambda)

In [4]:
print(training_model.summary())
print(training_model.layers[-1].summary())
print(inf_enc.summary())
print(inf_enc.layers[-1].summary())
print(inf_dec.summary())

Model: "training_model_final"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 8, 16, 200,  0           []                               
                                 1)]                                                              
                                                                                                  
 conv3d (Conv3D)                (None, 4, 8, 100, 1  900         ['input_1[0][0]']                
                                00)                                                               
                                                                                                  
 permute (Permute)              (None, 100, 4, 8, 1  0           ['conv3d[0][0]']                 
                                00)                                            

In [5]:
training_model.compile(optimizer='rmsprop', loss=keras.losses.SparseCategoricalCrossentropy(), metrics=['accuracy'])



In [16]:
from train.train import shuffle_weights
import numpy as np

initial_train_weights = training_model.layers[-1].layers[2].get_weights()
initial_enc_weights = inf_enc.layers[-1].layers[-1].get_weights()
print('##### Initial Training Weights #####')
for w in range(len(initial_train_weights)):
    print(np.array_equal(initial_train_weights[w], initial_enc_weights[w]))
# print(np.array_equal(initial_train_weights[0], initial_enc_weights[0]))

shuffle_weights(training_model.layers[-1].layers[2])
shuf_train_weights = training_model.layers[-1].layers[2].get_weights()
shuf_enc_weights = inf_enc.layers[-1].layers[-1].get_weights()
print('##### Control: Shuffled and Initial Weights Not Equal #####')
for w in range(len(initial_train_weights)):
    # print(np.array_equal(shuf_train_weights[w], initial_train_weights[w]) and
    #       np.array_equal(shuf_enc_weights[w], initial_enc_weights[w]))
    # print(np.array_equal(shuf_train_weights[w], initial_train_weights[w]))
    print(np.array_equal(shuf_enc_weights[w], initial_enc_weights[w]))


print('##### Shuffled Training Weights #####')
for w in range(len(initial_train_weights)):
    print(np.array_equal(shuf_train_weights[w], shuf_enc_weights[w]))

##### Initial Training Weights #####
True
True
True
##### Control: Shuffled and Initial Weights Not Equal #####
False
False
True
##### Shuffled Training Weights #####
True
True
True


[[ 0.01910936 -0.00182319  0.00090101 ...  0.01409642 -0.02776063
  -0.00903148]
 [ 0.0165459  -0.00401019  0.00919852 ...  0.0302639  -0.01348658
   0.01896115]
 [-0.0254625   0.01318161  0.03321907 ...  0.01105202 -0.01236829
   0.02774541]
 ...
 [-0.00880198  0.00288168 -0.00499636 ...  0.0319857   0.03192992
   0.02508578]
 [-0.00305493  0.01437064 -0.0082486  ... -0.01029604  0.01168732
   0.01807649]
 [ 0.02930294 -0.01737107 -0.02468666 ... -0.00814121 -0.03346367
   0.01907596]]
