In [None]:
import os
import sys
import glob
import pickle
from tqdm import tqdm

import numpy as np
from scipy.signal import butter, filtfilt, hilbert
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow import keras
from comet_ml.api import API, APIExperiment

if '..' not in sys.path:
    sys.path.append('..')
from dlml.utils import collect_experiments
from dlml.data import load_data_areas
from dlml.nn import compute_receptive_field, compute_correlations

%matplotlib inline

In [None]:
def my_print(msg, fd=sys.stdout):
    """Write ``msg`` verbatim to ``fd`` and flush the stream immediately.

    Unlike ``print``, no newline is appended, so a progress message can be
    completed later on the same line.

    Parameters
    ----------
    msg : str
        Text to write.
    fd : file-like, optional
        Destination stream; it must provide ``write`` and ``flush``.
    """
    stream = fd
    stream.write(msg)
    stream.flush()

#### Make sure that we have the model requested by the user

In [None]:
# experiment_ID = '98475b819ecb4d569646d7e1467d7c9c'
experiment_ID = '474d2016e33b441889ce8b17531487cb'
experiments_path = '../experiments/neural_network'
model_dir = os.path.join(experiments_path, experiment_ID)
if not os.path.isdir(model_dir):
    raise Exception(f'{model_dir}: no such directory')
# NOTE: unpickling can execute arbitrary code; only load parameter files
# written by our own training runs.
# The context manager guarantees the file handle is closed even if unpickling fails.
with open(os.path.join(model_dir, 'parameters.pkl'), 'rb') as fid:
    network_parameters = pickle.load(fid)
# Optional entries default to False when they are missing from the parameters
low_high = network_parameters.get('low_high', False)
binary_classification = network_parameters['loss_function']['name'].lower() == 'binarycrossentropy'
if network_parameters.get('use_fft', False):
    raise Exception('This script assumes that the input data be in the time domain')

#### Get some info about the model

In [None]:
# The API key is read from the environment so that it never ends up in the notebook
api = API(api_key=os.environ['COMET_API_KEY'])
experiment = api.get_experiment('danielelinaro', 'inertia', experiment_ID)
my_print(f'Getting metrics for experiment {experiment_ID[:6]}... ')
metrics = experiment.get_metrics()
my_print('done.\n')
# Validation loss, in the order in which the entries are returned (used to pick the best checkpoint)
val_loss = np.array([float(m['metricValue']) for m in metrics if m['metricName'] == 'val_loss'])
# MAPE stays None when the experiment logged no 'mape_prediction' metric,
# instead of being left undefined; the last logged value wins otherwise
MAPE = None
for m in metrics:
    if m['metricName'] == 'mape_prediction':
        MAPE = float(m['metricValue'])

#### Load the model

In [None]:
# The pooling type determines which custom layer class (if any) is needed to deserialize the model
try:
    pooling_type = network_parameters['model_arch']['pooling_type']
except (KeyError, TypeError):
    pooling_type = ''
checkpoint_path = os.path.join(model_dir, 'checkpoints')
# glob returns files in arbitrary order: sort them so that the fallback below is deterministic
checkpoint_files = sorted(glob.glob(checkpoint_path + '/*.h5'))
if not checkpoint_files:
    raise FileNotFoundError(f'{checkpoint_path}: no checkpoint files (*.h5) found')
try:
    # the epoch number is parsed from file names of the form <prefix>.<epoch>-<...>.h5
    epochs = [int(os.path.split(file)[-1].split('.')[1].split('-')[0]) for file in checkpoint_files]
    # epochs are 1-based, hence the +1 on the index of the lowest validation loss
    best_checkpoint = checkpoint_files[epochs.index(np.argmin(val_loss) + 1)]
except (ValueError, IndexError):
    # unparsable file name, empty val_loss or no checkpoint for the best epoch:
    # fall back to the last file in lexicographic order
    best_checkpoint = checkpoint_files[-1]

# custom_objects is also used later when cloning the model: None means "no custom layers needed"
custom_objects = None
try:
    model = keras.models.load_model(best_checkpoint)
except Exception:
    # The plain load failed: retry with the custom layer that matches the pooling type
    if pooling_type == 'downsample':
        from dlml.nn import DownSampling1D
        custom_objects = {'DownSampling1D': DownSampling1D}
    elif pooling_type == 'spectral':
        from dlml.nn import SpectralPooling
        custom_objects = {'SpectralPooling': SpectralPooling}
    elif pooling_type == 'argmax':
        from dlml.nn import MaxPooling1DWithArgmax
        custom_objects = {'MaxPooling1DWithArgmax': MaxPooling1DWithArgmax}
    else:
        # no custom layer is known for this pooling type: surface the original error
        raise
    with keras.utils.custom_object_scope(custom_objects):
        model = keras.models.load_model(best_checkpoint)

if pooling_type == 'argmax':
    # imported here as well, so that the name is defined even when the plain load succeeded
    from dlml.nn import MaxPooling1DWithArgmax
    for layer in model.layers:
        if isinstance(layer, MaxPooling1DWithArgmax):
            print(f'Setting store_argmax = True for layer "{layer.name}".')
            layer.store_argmax = True
# Normalization constants and variable names saved with the network parameters
x_train_mean = network_parameters['x_train_mean']
x_train_std  = network_parameters['x_train_std']
var_names = network_parameters['var_names']
print(f'Loaded network from {best_checkpoint}.')
print(f'Variable names: {var_names}')

model.summary()

### Compute effective receptive field size and stride

In [None]:
stop_layer = None
# Without an explicit stop layer, the computation stops at (and excludes) the Flatten layer
if stop_layer is not None:
    rf_kwargs = dict(stop_layer=stop_layer, include_stop_layer=True)
else:
    rf_kwargs = dict(stop_layer=keras.layers.Flatten, include_stop_layer=False)
effective_RF_size,effective_stride = compute_receptive_field(model, **rf_kwargs)

# Print both tables with the same dotted layout, separated by an empty line
rf_tables = (('Effective receptive field size:', effective_RF_size),
             ('Effective stride:', effective_stride))
for table_idx,(title,table) in enumerate(rf_tables):
    if table_idx > 0:
        print()
    print(title)
    for i,(k,v) in enumerate(table.items()):
        dots = '.' * (20 - len(k))
        print(f'{i}. {k} {dots} {v:d}')

### Load the data set

In [None]:
set_name = 'test'
# One data directory per area, resolved relative to the parent directory
data_dirs = []
for area_ID,data_dir in zip(network_parameters['area_IDs'], network_parameters['data_dirs']):
    data_dirs += [os.path.join('..', data_dir.format(area_ID))]
# Only the first directory (and the first generators/areas map below) is used
data_dir = data_dirs[0]
data_files = sorted(glob.glob(f'{data_dir}{os.path.sep}*_{set_name}_set.h5'))
ret = load_data_areas({set_name: data_files},
                      network_parameters['var_names'],
                      network_parameters['generators_areas_map'][:1],
                      network_parameters['generators_Pnom'],
                      network_parameters['area_measure'],
                      trial_dur=network_parameters['trial_duration'],
                      max_block_size=1000,
                      use_tf=False,
                      add_omega_ref=True,
                      use_fft=False)

t = ret[0]
X_raw = ret[1][set_name]
# Normalize each input variable with the statistics stored with the network parameters
X = [(X_raw[i] - m) / s for i,(m,s) in enumerate(zip(x_train_mean, x_train_std))]
y = ret[2][set_name]
# Keep one trial every `ds` to limit the size of the data set
ds = 10
X[0] = X[0][::ds,:]
y = y[::ds]

if binary_classification:
    # Split the trials around the mean target value; the mean is taken before y is relabelled
    y_mean = y.mean()
    IDX = [np.where(y < y_mean)[0], np.where(y > y_mean)[0]]
    n_mom_values = len(IDX)
    for label,members in enumerate(IDX):
        y[members] = label
    classes = [np.round(tf.keras.activations.sigmoid(model.predict(X[0][jdx]))) for jdx in IDX]
    _,_,accuracy = model.evaluate(tf.squeeze(X[0]), y, verbose=0)
    print(f'Prediction accuracy (with optimized weights): {accuracy*100:.2f}%.')
else:
    if low_high:
        # Both groups are identified before y is modified in place
        y_mean = y.mean()
        below,_ = np.where(y < y_mean)
        above,_ = np.where(y > y_mean)
        y[below] = y[below].mean()
        y[above] = y[above].mean()
    # Predict the momentum using the model, separately for each distinct target value
    mom_values = np.unique(y)
    IDX = [np.where(y == mom)[0] for mom in mom_values]
    n_mom_values = len(IDX)
    momentum = [np.squeeze(model.predict(X[0][jdx])) for jdx in IDX]
    mean_momentum = [pred.mean() for pred in momentum]
    stddev_momentum = [pred.std() for pred in momentum]
    print('Mean momentum (with optimized weights):', mean_momentum)
    print(' Std momentum (with optimized weights):', stddev_momentum)

### Clone the trained model
The clone has the same architecture as the trained model but freshly initialized (random) weights.
It is used below as a control for the correlation analysis.

In [None]:
# Rebuild the architecture from its configuration: the new model gets fresh weights
reinit_model = type(model).from_config(model.get_config(), custom_objects)
if custom_objects is not None:
    # we have some subclassed layers: make the layer names match those of the trained model
    for idx,trained_layer in enumerate(model.layers):
        reinit_model.layers[idx]._name = trained_layer.name
if not binary_classification:
    reinit_momentum = [np.squeeze(reinit_model.predict(X[0][jdx])) for jdx in IDX]
    mean_reinit_momentum = [pred.mean() for pred in reinit_momentum]
    stddev_reinit_momentum = [pred.std() for pred in reinit_momentum]
    print('Mean momentum (with random weights):', mean_reinit_momentum)
    print(' Std momentum (with random weights):', stddev_reinit_momentum)
else:
    # the model must be compiled before evaluate() can report the metrics
    reinit_model.compile(metrics=['binary_crossentropy', 'acc'])
    reinit_classes = [np.round(tf.keras.activations.sigmoid(reinit_model.predict(X[0][jdx]))) for jdx in IDX]
    _,_,reinit_accuracy = reinit_model.evaluate(tf.squeeze(X[0]), y, verbose=0)
    print(f'Prediction accuracy (with random weights): {reinit_accuracy*100:.2f}%.')

### Build a model with as many outputs as there are convolutional or pooling layers

Also, build a control model with the same (multiple-output) architecture as the previous one but random weights:

In [None]:
def _rf_layer_outputs(net):
    """Return the outputs of the layers of `net` that are listed in `effective_RF_size`,
    leaving out input layers."""
    return [layer.output for layer in net.layers
            if not isinstance(layer, keras.layers.InputLayer) and layer.name in effective_RF_size.keys()]

# Trained model exposing every layer for which a receptive field was computed
outputs = _rf_layer_outputs(model)
multi_output_model = keras.Model(inputs=model.inputs, outputs=outputs)

# Same multiple-output architecture, built on the randomly initialized model
ctrl_outputs = _rf_layer_outputs(reinit_model)
ctrl_model = keras.Model(inputs=reinit_model.inputs, outputs=ctrl_outputs)

print(f'The model has {len(outputs)} outputs, corresponding to the following layers:')
for i,layer in enumerate(multi_output_model.layers):
    if isinstance(layer, keras.layers.InputLayer):
        continue
    print(f'    {i}. {layer.name}')

### Correlations in the actual model
Define some variables used here and for the control model below

In [None]:
spacing = 'log'
N_bands = 20
filter_order = 6
verbose = True

# Sampling step and (rounded) sampling rate, from the first two time samples
dt = np.diff(t[:2])[0]
fs = np.round(1/dt)
# Band edges go from 0.05 to 0.5/dt, linearly or logarithmically spaced
# NOTE(review): the top edge 0.5/dt coincides with fs/2 whenever 1/dt is an integer;
# confirm that the installed scipy accepts a band edge at the Nyquist frequency.
f_lo, f_hi = 0.05, 0.5/dt
if spacing != 'lin':
    edges = np.logspace(np.log10(f_lo), np.log10(f_hi), N_bands+1)
else:
    edges = np.linspace(f_lo, f_hi, N_bands+1)
edges_ctrl = edges
bands = [[lower,upper] for lower,upper in zip(edges[:-1], edges[1:])]
N_bands = len(bands)
last_layer = multi_output_model.layers[-1]
_, N_neurons, N_filters = last_layer.output.shape
N_trials = X[0].shape[0]
# The file name encodes every parameter of the analysis
output_file = os.path.join(model_dir, '_'.join(['correlations',
                                                 experiment_ID[:6],
                                                 f'{N_bands}-bands',
                                                 f'{N_filters}-filters',
                                                 f'{N_neurons}-neurons',
                                                 f'{N_trials}-trials',
                                                 f'{filter_order}-butter',
                                                 f'{last_layer.name}']))

#### Filter the input in various frequency bands

In [None]:
N_trials, N_samples = X[0].shape
N_bands = len(bands)
# One band-passed copy of the whole (normalized) input per frequency band
X_filt = np.zeros((N_bands, N_trials, N_samples))
if verbose:
    my_print(f'Filtering the input in {N_bands} frequency bands... ')
for i,band in enumerate(tqdm(bands)):
    # butter is designed with half the nominal order, and the filter is applied
    # forward and backward by filtfilt
    b,a = butter(filter_order//2, band, 'bandpass', fs=fs)
    X_filt[i] = filtfilt(b, a, X[0])
if verbose:
    print('done.')

#### Compute the envelope of the filtered signal

In [None]:
if verbose:
    my_print('Computing the envelope of the filtered signals... ')
# Envelope = magnitude of the analytic signal returned by hilbert()
X_filt_envel = np.abs(hilbert(X_filt))
if verbose:
    print('done.')

In [None]:
# Trial shown in the figure
i = 0
j = 0
fig,ax = plt.subplots(1, 1, figsize=(8,5))
# Normalized input in black ...
ax.plot(t, X[0][i,:], 'k', lw=1)
cmap = plt.get_cmap('viridis', N_bands)
# ... and, for every other band, the filtered signal (dash-dotted) with its envelope (solid)
for j in range(N_bands)[::2]:
    band_color = cmap(j)
    ax.plot(t, X_filt[j,i,:], '-.', color=band_color, lw=1)
    ax.plot(t, X_filt_envel[j,i,:], '-', color=band_color, lw=1)
ax.set(xlim=[10, 20], ylim=[-1, 2])
sns.despine()

### Compute the outputs of the last layer before the fully connected layer

In [None]:
layer_name = multi_output_model.layers[-1].name
if verbose: my_print(f'Computing the output of layer "{layer_name}"... ')
multi_Y = multi_output_model(X)
if verbose: print('done.')
# With several outputs the model returns a list and we want the last entry; with a
# single output it returns the tensor itself. Convert to a NumPy array in both cases,
# so that the code below always works with the same array type.
last_output = multi_Y[-1] if isinstance(multi_Y, (list, tuple)) else multi_Y
Y = last_output.numpy()
_, N_neurons, N_filters = Y.shape
if verbose: print(f'Layer "{layer_name}" has {N_filters} filters, each with {N_neurons} neurons.')

In [None]:
# Same computation as for the trained model, using the randomly initialized control model
if verbose: my_print(f'Computing the output of layer "{layer_name}"... ')
multi_Y = ctrl_model(X)
if verbose: print('done.')
# Take the last output (or the only one) and always convert it to a NumPy array
last_output_ctrl = multi_Y[-1] if isinstance(multi_Y, (list, tuple)) else multi_Y
Y_ctrl = last_output_ctrl.numpy()

In [None]:
# Shape check: the envelope has the same layout as X_filt, i.e. (N_bands, N_trials, N_samples)
X_filt_envel.shape

### Compute the mean squared envelope for each receptive field

In [None]:
RF_sz, RF_str = effective_RF_size[layer_name], effective_stride[layer_name]
if verbose:
    print(f'The effective RF size and stride of layer "{layer_name}" are {RF_sz} and {RF_str} respectively.')
# Both arrays are indexed as (trial, band, neuron)
mean_squared_envel = np.zeros((N_trials, N_bands, N_neurons))
mean_envel = np.zeros((N_trials, N_bands, N_neurons))
if verbose:
    my_print('Computing the mean squared envelope for each receptive field... ')
for i in range(N_neurons):
    # samples covered by the receptive field of the i-th neuron
    start = i * RF_str
    stop = start + RF_sz
    X_filt_envel_sub = X_filt_envel[..., start:stop]
    # average over time, then swap (band, trial) -> (trial, band)
    mean_squared_envel[:,:,i] = (X_filt_envel_sub ** 2).mean(axis=-1).T
    mean_envel[:,:,i] = X_filt_envel_sub.mean(axis=-1).T
if verbose:
    print('done.')

### Compute the correlation using `pearsonr`
For each frequency band, compute the correlation between the mean squared envelope
of the input (within each receptive field) and the output of each neuron in the layer.

In [None]:
import ctypes
# Pointer-to-double type shared by the argument declaration and the result buffers
pointer = ctypes.POINTER(ctypes.c_double)
libcorr = ctypes.CDLL(os.path.join('..', 'libcorr.so'))
# pearsonr(x, y, n, R, p): two input arrays, their length and two output locations
libcorr.pearsonr.argtypes = [pointer, pointer, ctypes.c_size_t, pointer, pointer]
# Reusable output locations passed to every call of the C routine
R_pointer = pointer(ctypes.c_double(0.0))
p_pointer = pointer(ctypes.c_double(0.0))

In [None]:
def my_pearsonr(x, y):
    """Call the C routine ``libcorr.pearsonr`` on two equally long sequences.

    Parameters
    ----------
    x, y : array_like
        Input values. Both are copied into C-contiguous float64 buffers before
        the call, because the C routine receives raw pointers to doubles.

    Returns
    -------
    tuple of float
        The two values written by the C routine into the module-level
        ``R_pointer`` and ``p_pointer`` (presumably the correlation coefficient
        and its p-value -- confirm against the libcorr sources).

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not contain the same number of elements.
    """
    # np.array copies by default: the inputs are never touched and the buffers are
    # guaranteed to be contiguous doubles, whatever the dtype or stride of x and y
    x = np.array(x, dtype=np.float64, order='C')
    y = np.array(y, dtype=np.float64, order='C')
    if x.size != y.size:
        # only one length is passed to C: a shorter y would be read out of bounds
        raise ValueError(f'x and y must have the same number of elements ({x.size} != {y.size})')
    double_pointer = ctypes.POINTER(ctypes.c_double)
    x_pointer = x.ctypes.data_as(double_pointer)
    y_pointer = y.ctypes.data_as(double_pointer)
    libcorr.pearsonr(x_pointer, y_pointer, x.size, R_pointer, p_pointer)
    return R_pointer[0], p_pointer[0]

In [None]:
# One panel per frequency band, showing the output of every filter against the
# mean squared envelope for a single trial.
# The grid is derived from N_bands (5x4 for 20 bands), so that changing the number
# of bands does not index past the available axes.
cols = 4
rows = int(np.ceil(N_bands / cols))
fig,ax = plt.subplots(rows, cols, figsize=(2*cols, 2*rows), squeeze=False)
cmap = plt.get_cmap('viridis', N_filters)
trial_to_plot = 0   # only the first trial is shown
for j in range(N_bands):
    a = ax[j//cols][j%cols]
    if N_trials > 0:
        for k in range(N_filters):
            a.plot(Y[trial_to_plot,:,k], mean_squared_envel[trial_to_plot,j,:], 'o', color=cmap(k), ms=2)
    a.set_xticks([])
    a.set_yticks([])
# hide the panels left over when N_bands is not a multiple of cols
for a in ax.flat[N_bands:]:
    a.set_visible(False)
sns.despine()

In [None]:
# Correlation coefficient and p-value for every (trial, band, filter) combination
corr_shape = (N_trials, N_bands, N_filters)
R = np.zeros(corr_shape)
p = np.zeros(corr_shape)
if verbose:
    print('Computing the correlations tensor...')
for i in tqdm(range(N_trials)):
    # ndindex visits (band, filter) pairs with the filter index varying fastest
    for j,k in np.ndindex(N_bands, N_filters):
        R[i,j,k], p[i,j,k] = my_pearsonr(Y[i,:,k], mean_squared_envel[i,j,:])

In [None]:
# Flat position of the largest correlation, converted into (trial, band, filter) indices.
# R has shape (N_trials, N_bands, N_filters) when the cells are run in order.
idx = np.argmax(R)
i, j, k = np.unravel_index(idx, R.shape)

In [None]:
# Indices of the significant (p < 0.05) correlations with a coefficient close to 0.5
I,J,K = np.where((R>0.495) & (R<0.505) & (p<0.05))
if I.size == 0:
    raise ValueError('No significant (p < 0.05) correlation with 0.495 < R < 0.505 was found')
# Use the sixth match when there are enough of them, the last available one otherwise
n = min(5, I.size - 1)
i,j,k = I[n],J[n],K[n]

In [None]:
# Scatter of the layer output (trial i, filter k) against the mean squared envelope (trial i, band j)
plt.plot(Y[i,:,k], mean_squared_envel[i,j,:], 'o', color='k', ms=2)

### Plots

In [None]:
cmap = plt.get_cmap('jet', N_bands)
rows, cols = N_bands // 4, 4
fig,ax = plt.subplots(rows, cols, figsize=(3*cols, 1.5*rows), sharex=True, sharey=True)
# Trial and neuron shown in the figure; start/stop delimit the neuron's receptive field
trial = 0
neuron = 30
start, stop = neuron * RF_str, neuron * RF_str + RF_sz
for k in range(N_bands):
    panel = ax[k//cols, k%cols]
    band_color = cmap(k)
    # raw input (gray), band-passed signal (black) and its envelope (band color)
    panel.plot(t, X[0][trial, :], color=[.6,.6,.6], lw=0.5)
    panel.plot(t, X_filt[k, trial, :].T, color='k', lw=1)
    panel.plot(t, X_filt_envel[k, trial, :].T, color=band_color, lw=1)
    # alternative: show mean_envel[trial, k, neuron] instead of the mean squared envelope
    # thick horizontal segment: mean squared envelope over the receptive field
    level = mean_squared_envel[trial, k, neuron]
    panel.plot(t[[start,stop]], np.full(2, level), color=band_color, lw=4)
ax[0,0].set_ylim([-1, 1])
xlim, ylim = ax[0,0].get_xlim(), ax[0,0].get_ylim()
# Annotate each panel with the edges of its band and remove the top/right spines
for k,panel in enumerate(ax.flat):
    panel.text(xlim[1], ylim[1], f'{bands[k][0]:.4f} - {bands[k][1]:.4f}',
               ha='right', va='top', color='m')
    for side in ('right', 'top'):
        panel.spines[side].set_visible(False)
for panel in ax[:,0]:
    panel.set_ylabel('Norm. V')
for panel in ax[-1,:]:
    panel.set_xlabel('Time [s]')
fig.tight_layout()

In [None]:
cmap = plt.get_cmap('jet', N_bands)
rows, cols = N_bands // 4, 4
fig,ax = plt.subplots(rows, cols, figsize=(3*cols, 1.5*rows), sharex=True, sharey=True)
# Filter shown in the figure; `trial` is the one selected in the previous plotting cell
filt = 0
ms = 5
# Per-band correlations for this single (trial, filter) pair. They are stored under their
# own names so that the full (N_trials, N_bands, N_filters) tensors R and p computed
# earlier are not overwritten, and the cells that use them can be safely re-run.
R_band, p_band = np.zeros((rows, cols)), np.zeros((rows, cols))
R_ctrl, p_ctrl = np.zeros((rows, cols)), np.zeros((rows, cols))
for k in range(N_bands):
    i,j = k//cols, k%cols
    R_band[i,j],p_band[i,j] = pearsonr(Y[trial,:,filt], mean_squared_envel[trial,k,:])
    R_ctrl[i,j],p_ctrl[i,j] = pearsonr(Y_ctrl[trial,:,filt], mean_squared_envel[trial,k,:])
    # control model: open black markers; trained model: filled markers in the band color
    ax[i,j].plot(mean_squared_envel[trial, k, :], Y_ctrl[trial, :, filt], 'o', color='k',
                 markersize=ms-1, markerfacecolor='w')
    ax[i,j].plot(mean_squared_envel[trial, k, :], Y[trial, :, filt], 'o', color=cmap(k),
                 markersize=ms)
xlim, ylim = ax[0,0].get_xlim(), ax[0,0].get_ylim()
# Vertical distance between text lines, as a plain scalar (np.diff would return an array)
text_step = (ylim[1] - ylim[0]) / 5
for i in range(rows):
    for j in range(cols):
        k = i * cols + j
        ax[i,j].text(xlim[1], ylim[1], f'{bands[k][0]:.4f} - {bands[k][1]:.4f}',
                     ha='right', va='top', color='m')
        # R and p of the trained model: black if significant (p < 0.05), red otherwise
        col = 'k' if p_band[i,j] < 0.05 else 'r'
        ax[i,j].text(xlim[1], ylim[1] - text_step, f'{R_band[i,j]:.3f}, {p_band[i,j]:.2f}',
                     ha='right', va='top', color=col)
        # same for the control model, one line below
        col = 'k' if p_ctrl[i,j] < 0.05 else 'r'
        ax[i,j].text(xlim[1], ylim[1] - 2*text_step, f'{R_ctrl[i,j]:.3f}, {p_ctrl[i,j]:.2f}',
                     ha='right', va='top', color=col)
        for side in 'right','top':
            ax[i,j].spines[side].set_visible(False)
    ax[i,0].set_ylabel('Layer output')
for i in range(cols):
    ax[-1,i].set_xlabel('Mean squared envel.')
fig.tight_layout()