# Tuning Envelope Detector net and best model weights interpretation

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import json
import tensorflow as tf
import numpy as np
# Seed both NumPy and TensorFlow so a single experiment is reproducible.
# Dataset.shuffle and Keras weight initialisation draw from TensorFlow's RNG,
# which np.random.seed alone does not control.
np.random.seed(42)
tf.random.set_seed(42)

In [None]:
import json
# Load the experiment configuration (data paths, model and tuner settings).
# JSON is UTF-8 by specification, so the encoding is stated explicitly instead
# of relying on the platform default.
with open('config.json', encoding='utf-8') as config_file:
    config = json.load(config_file)

In [None]:
# Problem dimensions, read from config.json.
N_CHANNELS = config["n_channels"]
N_CLASSES = config["n_classes"]
ECOG_INT_LEN = config["n_points"]

# Location of the data files and the arrays themselves.
DATA_DIR = config["data_dir"]
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# only load files that come from a trusted source.
X = np.load(os.path.join(DATA_DIR, config["x_file"]), allow_pickle=True)
Y = np.load(os.path.join(DATA_DIR, config["y_file"]), allow_pickle=True)

# Split the example indices: the first `split` fraction feeds the training
# dataset, the remainder feeds the validation dataset.
# NOTE(review): the fraction is read from "val_split" but sizes the *training*
# part - confirm this is what the config intends.
split = config["val_split"]
n_examples = Y.shape[0]
n_train = int(n_examples * split)

y_tr = np.arange(0, n_examples)  # every example index (despite the name)
x_tr = y_tr[:n_train]            # indices used for training
y_val = y_tr[n_train:]           # indices used for validation



def read_one(idx, training=True):
    """Fetch one example from the module-level arrays ``X`` and ``Y``.

    Args:
        idx: index of the example in ``X`` / ``Y``.
        training: when truthy, the augmentation branch is entered; it is
            currently an empty placeholder, so the flag has no effect.

    Returns:
        Tuple ``(X[idx] as float32, Y[idx] as int32)``.
    """
    sample, label = X[idx], Y[idx]

    if training:
        # Placeholder for training-time data augmentation.
        pass

    return sample.astype("float32"), label.astype("int32")
    
def preprocess(idx):
    """Map a training example index to a ``(float32 sample, int32 label)`` pair.

    Wraps ``read_one`` in ``tf.numpy_function`` so it can be used from
    ``Dataset.map``. ``training=True`` is passed so that augmentation added to
    ``read_one`` applies to the training pipeline (the validation mapper
    passes ``False``).

    Args:
        idx: scalar tensor holding the example index.

    Returns:
        Tuple ``(sample, label)`` of tensors with dtypes float32 and int32.
        NOTE(review): tensors coming out of tf.numpy_function carry no static
        shape information.
    """
    sample, label = tf.numpy_function(read_one, [idx, True], [tf.float32, tf.int32])

    return sample, label


def preprocess_val(idx):
    """Map a validation example index to a ``(float32 sample, int32 label)`` pair.

    Calls ``read_one`` through ``tf.numpy_function`` with ``training=False``.

    Args:
        idx: scalar tensor holding the example index.

    Returns:
        Tuple ``(sample, label)`` of tensors with dtypes float32 and int32.
    """
    outputs = tf.numpy_function(
        func=read_one,
        inp=[idx, False],
        Tout=[tf.float32, tf.int32],
    )
    # Return a tuple (not the list numpy_function yields) so Dataset.map
    # treats the two tensors as separate components.
    return outputs[0], outputs[1]

    
# Datasets iterate over example *indices*; the mapped function loads the
# actual sample/label pair for each index.
train_dataset = tf.data.Dataset.from_tensor_slices((x_tr,)).map(
    preprocess,
    num_parallel_calls=tf.data.AUTOTUNE,
)

test_dataset = tf.data.Dataset.from_tensor_slices((y_val,)).map(
    preprocess_val,
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [None]:
from tensorflow.keras import layers, Input
from tensorflow import keras
import keras_tuner
import tensorflow_addons as tfa


# Optimizer hyper-parameters, read from config.json.
LR_RATE = config["learning_rate"]
weight_decay = config["weight_decay"]

# Module-level AdamW optimizer (decoupled weight decay, from TensorFlow Addons).
optimizer = tfa.optimizers.AdamW(learning_rate=LR_RATE, weight_decay=weight_decay)

# Loss for integer class labels.
lfn = tf.keras.losses.SparseCategoricalCrossentropy()

# Accuracy metric; it is named 'accuracy', so Keras reports the validation
# value under 'val_accuracy'.
msca = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')


#envelope detector net
def ed_net(input_shape, n_branches, lstm_units, filtering_size = 25, envelope_size = 15):
    """Build and compile the envelope-detector network.

    Layer order (later cells index ``model.layers`` by position, so keep it):
      0 input
      1 pointwise Conv1D (kernel size 1) mixing input channels into branches
      2 BatchNormalization
      3 grouped Conv1D - one temporal filter per branch
      4 BatchNormalization
      5 LeakyReLU with slope -1
      6 grouped Conv1D - one envelope filter per branch
      7 temporal downsampling by strided slicing
      8 bidirectional LSTM
      9 BatchNormalization
      10 softmax Dense classifier

    Args:
        input_shape: per-example input shape, passed to ``keras.Input``.
        n_branches: number of branches (filters of every convolution).
        lstm_units: total LSTM width; each direction gets ``lstm_units // 2``.
        filtering_size: kernel length of the temporal filtering convolution.
        envelope_size: kernel length of the envelope convolution.

    Returns:
        A compiled ``keras.Model`` producing ``N_CLASSES`` softmax outputs.
    """
    DOWNSAMPLING = 10

    inputs = keras.Input(shape=input_shape)

    # Spatial (channel-mixing) stage.
    x = layers.Conv1D(n_branches, 1, padding="same")(inputs)
    x = layers.BatchNormalization(center=False, scale=False)(x)

    # Per-branch temporal filtering: groups=n_branches keeps branches independent.
    x = layers.Conv1D(n_branches, filtering_size, padding="same", groups=n_branches, use_bias=False)(x)
    x = layers.BatchNormalization(center=False, scale=False)(x)
    # A negative-side slope of -1 maps x < 0 to -x, i.e. this layer computes |x|.
    x = layers.LeakyReLU(-1)(x)

    # Per-branch envelope filter applied to the rectified signal.
    x = layers.Conv1D(n_branches, envelope_size, padding="same", groups=n_branches, use_bias=False)(x)

    # Keep every DOWNSAMPLING-th time step.
    x = x[:, ::DOWNSAMPLING, :]

    x = layers.Bidirectional(layers.LSTM(lstm_units // 2))(x)

    x = layers.BatchNormalization(center=False, scale=False)(x)

    outputs = layers.Dense(N_CLASSES, activation='softmax')(x)

    model = keras.Model(inputs, outputs, name=f'ed_net_b_{n_branches}_l_{lstm_units}_f_{filtering_size}_e_{envelope_size}')

    # Create the optimizer, loss and metric per model. This function is called
    # once per tuner trial/execution, and sharing one optimizer instance across
    # those models would carry its step counter and slot state from one trial
    # into the next.
    model.compile(
        optimizer=tfa.optimizers.AdamW(learning_rate=LR_RATE, weight_decay=weight_decay),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
    )

    return model


def build_model(hp):
    """Hyper-model entry point for Keras Tuner.

    Each architecture parameter is either tuned (when the matching ``tune_*``
    flag in the config is truthy) or left at its default value.

    Args:
        hp: ``keras_tuner.HyperParameters`` instance supplied by the tuner.

    Returns:
        The compiled model returned by ``ed_net``.
    """
    # Defaults used when a parameter is not tuned.
    n_branches = 64
    lstm_units = 32
    f_size = 25
    e_size = 15

    if config["tune_n_branches"]:
        n_branches = hp.Int(
            "n_branches",
            min_value=config["branches_min"],
            max_value=config["branches_max"],
            step=config["branches_step"],
        )

    if config["tune_n_lstm"]:
        lstm_units = hp.Int(
            "lstm_units",
            min_value=config["lstm_units_min"],
            max_value=config["lstm_units_max"],
            step=config["lstm_units_step"],
        )

    if config["tune_filtering_size"]:
        # The step comes from "filtering_step", like every other parameter.
        # Using "filtering_max" as the step would leave the minimum as the
        # only reachable value.
        f_size = hp.Int(
            "f_size",
            min_value=config["filtering_min"],
            max_value=config["filtering_max"],
            step=config["filtering_step"],
        )

    if config["tune_envelope_size"]:
        e_size = hp.Int(
            "e_size",
            min_value=config["envelope_min"],
            max_value=config["envelope_max"],
            step=config["envelope_step"],
        )

    # Call the existing model-building code with the chosen values.
    return ed_net(
        (ECOG_INT_LEN, N_CHANNELS),
        n_branches=n_branches,
        lstm_units=lstm_units,
        filtering_size=f_size,
        envelope_size=e_size,
    )


# Build the model once with a fresh HyperParameters object; the result is
# discarded - presumably a quick check that the model builds before searching.
build_model(keras_tuner.HyperParameters())


In [None]:
# Training & tuning

# Random search over the hyper-parameters declared in build_model.
tuner = keras_tuner.RandomSearch(
    hypermodel=build_model,
    # Modify the objective here for your project (and the settings in config.json).
    objective="val_accuracy",
    max_trials=config["n_trials"],
    executions_per_trial=config["n_runs_per_trial"],
    overwrite=True,
    directory=config["tuner_directory"],
    project_name=config["project_name"],
)

BATCH_SIZE = config["batch_size"]
N_EPOCH = config["n_epoch"]

# Input pipelines: only the training stream is shuffled.
train_batches = train_dataset.shuffle(200).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_batches = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

tuner.search(train_batches, validation_data=val_batches, epochs=N_EPOCH)



In [None]:
# Print the tuner's summary of the finished search.
tuner.results_summary()

In [None]:
# Retrieve the top two models from the search; the list is ordered best-first.
models = tuner.get_best_models(num_models=2)

best_model = models[0]

# Build the best model. Keras input shapes include the batch axis, which is
# left unspecified (None).
best_model.build(input_shape=(None, ECOG_INT_LEN, N_CHANNELS))
best_model.summary()



# Model weights interpretation

In [None]:
# Preparing data

# Probe model sharing the best model's layers. It returns the class
# probabilities together with the output of layer 2 (the batch-normalised
# result of the spatial convolution, i.e. the "unmixed" branch signals).
model = keras.Model(
    best_model.inputs,
    [best_model.output, best_model.layers[2].output],
)

X_unmixed = []

feature_grads = []

# The datasets yield single examples, so batch them before calling the model.
# Loop names are distinct from the global index arrays (x_tr / y_val).
for x_batch, y_batch in train_dataset.batch(BATCH_SIZE):

    with tf.GradientTape() as tape:

        probs, unmix = model(x_batch, training=False)

        # Differentiate the loss rather than the raw outputs. tape.gradient
        # sums a non-scalar target, and softmax probabilities always sum to 1,
        # so the gradient of the summed outputs is zero up to rounding noise.
        loss = lfn(y_batch, probs)

    # Gradient w.r.t. the variables of layer 6 (the envelope convolution); its
    # only variable is the kernel because the layer has no bias.
    grad = tape.gradient(loss, model.layers[6].variables)

    feature_grads.append(grad[0].numpy())

    X_unmixed.append(unmix.numpy())

X_unmixed = np.concatenate(X_unmixed, axis=0)


In [None]:
# Interpreting ...
import sklearn
from sklearn.preprocessing import minmax_scale
import scipy
# scipy.signal.welch is used below; a bare `import scipy` does not guarantee
# that the signal submodule is loaded on every SciPy version.
import scipy.signal
import matplotlib.pyplot as plt

# Kernel of the spatial (pointwise) convolution, layer 1.
spat_l = model.layers[1].weights[0]

# Kernel of the per-branch temporal filtering convolution, layer 3.
temp_l = model.layers[3].weights[0]


# Welch segment length and FFT size used for all spectra.
NPERSEQ = 500

# Number of branches, taken from the last axis of the temporal kernel so it
# follows whatever architecture the tuner selected.
HIDDEN_CHANNELS = int(temp_l.shape[-1])

# How many of the most important branches are plotted.
N_TOP_BRANCHES = 5

# Sampling rate handed to the spectral estimators (the plots label it in Hz).
FREQUENCY = 1000

# Number of decimals used when comparing the two frequency grids.
COMPARISON_TOLERANCE = 10

# Number of leading frequency bins shown in the plots.
ZOOM = 125

def get_spatial_patterns(X, temporal_weights, spatial_weights):
    """Compute one spatial pattern per branch.

    For every branch, each column of ``X`` is convolved with that branch's
    temporal kernel, the covariance of the filtered columns is estimated, and
    the covariance is multiplied by the branch's spatial weight vector.

    Args:
        X: 2-D array indexed as ``[sample, column]``.
        temporal_weights: array indexed as ``[tap, 0, branch]``.
        spatial_weights: 2-D array indexed as ``[column, branch]``.

    Returns:
        Array of shape ``(spatial_weights.shape[1], spatial_weights.shape[0])``
        whose row ``i`` is the pattern of branch ``i``.
    """
    patterns = np.zeros((spatial_weights.shape[1], spatial_weights.shape[0]))

    for branch in range(temporal_weights.shape[2]):
        kernel = temporal_weights[:, 0, branch]

        # Filter every column with this branch's temporal kernel.
        filtered = np.zeros(X.shape)
        for col, series in enumerate(X.T):
            filtered[:, col] = np.convolve(series, kernel, mode="same")

        covariance = np.cov(filtered, rowvar=False)
        branch_weights = spatial_weights[:, branch].reshape((-1, 1))
        patterns[branch, :] = np.dot(covariance, branch_weights)[:, 0]

    return patterns

def get_freq_domain(signal, frequency, n_fft=None):
    """Return the amplitude spectrum of a 1-D signal on the lower half of the FFT grid.

    Args:
        signal: 1-D sequence of samples; it is zero-padded or truncated to the
            FFT length by ``np.fft.fft``.
        frequency: sampling rate; bin frequencies are expressed in its unit.
        n_fft: FFT length. Defaults to the module-level ``NPERSEQ``.

    Returns:
        Tuple ``(frequencies, amplitude)``, each holding the first
        ``n_fft // 2`` bins.

    Raises:
        ValueError: if ``signal`` is not one-dimensional.
    """
    n = NPERSEQ if n_fft is None else n_fft
    signal = np.asarray(signal)
    if signal.ndim != 1:
        raise ValueError(f"signal must be 1-D, got shape {signal.shape}")

    amplitude = np.abs(np.fft.fft(signal, n))
    frequencies = np.fft.fftfreq(n, 1 / frequency)

    end = n // 2
    return frequencies[:end], amplitude[:end]



convs_weights = temp_l

# Number of branches, read from the temporal kernel's last axis.
n_branches = int(convs_weights.shape[-1])

# Upper bound on the number of time samples used for the covariance estimate.
MAX_PATTERN_SAMPLES = 500000

# Use held-out examples (those after the split point) as the reference signal
# and lay them end to end as a (time, channel) matrix.
# NOTE(review): assumes X is a numeric array of shape
# (n_examples, ECOG_INT_LEN, N_CHANNELS) - confirm for pickled/object data.
val_start = int(Y.shape[0] * split)
max_epochs = max(1, MAX_PATTERN_SAMPLES // ECOG_INT_LEN)
X_patterns = np.asarray(
    X[val_start:val_start + max_epochs], dtype="float64"
).reshape(-1, N_CHANNELS)

# get_spatial_patterns returns (branch, channel). Transpose - do not reshape -
# to get (channel, branch), otherwise entries of different channels and
# branches are interleaved.
interpret_patterns = get_spatial_patterns(
    X_patterns,
    convs_weights.numpy(),
    np.array(spat_l[0]),
).T[:, :, np.newaxis]

assert interpret_patterns.shape == (N_CHANNELS, n_branches, 1)

# Branch importance: absolute gradient summed over everything except the
# branch axis. Stacking first keeps this correct even with a single batch.
grad_magnitude = np.abs(np.stack(feature_grads))
importance_weights = grad_magnitude.reshape(-1, grad_magnitude.shape[-1]).sum(axis=0)
importance_weights = importance_weights / np.sum(importance_weights)

# Branch indices ordered from most to least important.
importance_indexes = np.argsort(-importance_weights)

assert len(importance_indexes) == interpret_patterns.shape[1]

# Lay all examples end to end: (time, branch).
X_unmixed = X_unmixed.reshape(-1, n_branches)

# Set global font options once, before any figure is created, so every
# figure is drawn with the same settings.
plt.rc('font', family='serif', size=14)
plt.rc('ytick', labelsize=14)

for i in importance_indexes[:N_TOP_BRANCHES]:

    fig, (ax_spatial, ax_freq) = plt.subplots(
        1, 2, gridspec_kw={'width_ratios': [1, 2], 'height_ratios': [1]}
    )
    fig.set_figwidth(15)
    fig.set_figheight(3)

    ax_spatial.set_title(f'Branch {i}, ({str(round(importance_weights[i], 2))})')
    ax_freq.set_xlabel('Frequency, Hz')
    ax_freq.set_title("Frequency domain profiles")

    # Temporal filter taps of this branch.
    weights = convs_weights[:, 0, i].numpy().T

    # Power spectrum of the branch input. Welch returns nperseg // 2 + 1 bins,
    # so the last one is dropped to match the FFT grid of the filter below.
    frequencies_input, spectrum_input = scipy.signal.welch(
        X_unmixed[:, i], FREQUENCY, nperseg=NPERSEQ, detrend='linear'
    )
    frequencies_input = frequencies_input[:-1]
    spectrum_input = spectrum_input[:-1]

    # Amplitude response of the temporal filter.
    frequencies, amplitude = get_freq_domain(weights, FREQUENCY)

    # Both spectra must live on the same frequency grid before being combined.
    if len(frequencies_input) != len(frequencies):
        raise ValueError(f"{len(frequencies_input)}!={len(frequencies)}")
    if not np.allclose(frequencies_input, frequencies, rtol=0.0, atol=10.0 ** -COMPARISON_TOLERANCE):
        raise ValueError("Welch and FFT frequency grids differ")

    input_scaled = minmax_scale(spectrum_input)
    amplitude_scaled = minmax_scale(amplitude)

    # Input spectrum weighted by the filter response once ("Patterns") and
    # by its square ("Out").
    recovered = amplitude_scaled * input_scaled
    out_spectrum = np.power(amplitude_scaled, 2) * input_scaled

    ax_freq.plot(frequencies_input[:ZOOM], input_scaled[:ZOOM] * 5, label='Input')
    ax_freq.plot(frequencies[:ZOOM], minmax_scale(recovered)[:ZOOM], label='Patterns')
    ax_freq.plot(frequencies[:ZOOM], amplitude_scaled[:ZOOM], label='Weights')
    ax_freq.plot(frequencies[:ZOOM], minmax_scale(out_spectrum)[:ZOOM], label='Out')
    ax_freq.grid()
    ax_freq.axis(ymin=0, ymax=1)
    ax_freq.legend(bbox_to_anchor=(1, -0.35), ncol=4)

    # Normalised absolute spatial pattern; one bar per channel, numbered from 1.
    spatial_pattern = np.abs(interpret_patterns[:, i, 0]).transpose()
    ax_spatial.bar(
        np.arange(1, len(spatial_pattern) + 1),
        spatial_pattern / np.sum(spatial_pattern),
    )
    ax_spatial.set_xlabel('Channel')
    ax_spatial.set_ylabel('Importance')


# Show every figure created above; this is also safe when the loop body
# never ran.
plt.show()

