### Import libraries

In [None]:
import os
import sys
import time
import logging

import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from hdf5storage import loadmat, savemat

import tensorflow as tf
import tensorflow_probability as tfp

### Load sample data

Load the provided sample data. It consists of a training set of 54 trials and a test set of 8 unique trials, each presented multiple times. In the cell below, the test-set repetitions of each trial are averaged into two groups: even-indexed and odd-indexed repetitions (0-based). They are also averaged over all repetitions. The stimuli are auditory spectrograms of the presented audio. The responses are preprocessed, z-scored high-gamma envelopes from three sample ECoG electrodes.

$X_{tr}$ and $Y_{tr}$ are the stimulus and response data for the train set, respectively, while $X_{te}$ and $Y_{te}$ represent the test set. $X_{tr}$ and $X_{te}$ consist of multiple trials, each with shape $[time \times freq\_bins]$. $Y_{tr}$ and $Y_{te}$ have shapes $[time \times channels]$ and $[time \times channels \times repetitions]$ per trial, respectively.

In [None]:
data = loadmat('../sample_data_keshishian_etal.mat')

# train set (flattened to 1-D: one element per trial, as described above)
X_tr = data['X_tr'].flatten()
Y_tr = data['Y_tr'].flatten()

# test set
X_te = data['X_te'].flatten()
Y_te = data['Y_te'].flatten()

# Average the even-indexed (0, 2, ...) and odd-indexed (1, 3, ...) repetitions
# of each test trial separately, plus the grand average over all repetitions.
# Trials may differ in length, so results are kept as plain per-trial lists:
# np.array() on a ragged list raises ValueError in NumPy >= 1.24, and the
# following cells only iterate over trials.
Y_r0 = [y[:, :, 0::2].mean(-1) for y in Y_te]
Y_r1 = [y[:, :, 1::2].mean(-1) for y in Y_te]
Y_te = [y.mean(-1) for y in Y_te]

# number of stimulus frequency channels (last axis of each stimulus trial)
freq_bins = X_tr[0].shape[-1]

Display a sample stimulus (the first 300 time samples of the first training trial) and the corresponding response of the first channel:

In [None]:
fig, (ax_stim, ax_resp) = plt.subplots(2, 1, figsize=(12, 6))

# Spectrogram of the first 300 samples of the first training trial.
# imshow's origin must be 'upper' or 'lower'; 'lower' draws row 0 (the first
# frequency bin) at the bottom.
ax_stim.imshow(X_tr[0][:300].T, origin="lower", aspect=2)
ax_stim.set_xlim([0, 300])
ax_stim.set_ylabel("frequency bin")

# Matching response of the first channel.
ax_resp.plot(Y_tr[0][:300, 0])
ax_resp.set_xlim([0, 300])
ax_resp.set_xlabel("time (samples)")
ax_resp.set_ylabel("response")

plt.show()

Split the training set into training and validation subsets. The validation subset is used for model selection during training. Other splitting schemes, such as jackknife resampling, can be used instead.

In [None]:
# Hold out trials 1 and 3 as the validation subset; all others stay in training.
val_idx = [1, 3]
trials_vl = np.full(len(X_tr), False)
trials_vl[val_idx] = True
trials_tr = ~trials_vl

# Validation subset must be taken before X_tr / Y_tr are overwritten.
X_vl, Y_vl = X_tr[trials_vl], Y_tr[trials_vl]
X_tr, Y_tr = X_tr[trials_tr], Y_tr[trials_tr]

### Select channel

Select the target channel from the list of electrodes. After this step, each response variable holds one 1-dimensional time course per trial.

In [None]:
channel = 0
print(f"Fitting models for channel {channel:d}...")

# Keep only the target channel, so each trial becomes a 1-D time course.
# Plain lists are used because trials can differ in length (np.array on a
# ragged list raises ValueError in NumPy >= 1.24); the next cell only
# iterates over the trials.
Y_tr = [y[:, channel] for y in Y_tr]
Y_vl = [y[:, channel] for y in Y_vl]
Y_te = [y[:, channel] for y in Y_te]
Y_r0 = [y[:, channel] for y in Y_r0]
Y_r1 = [y[:, channel] for y in Y_r1]

### Prepare data for training

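Concatenate all trials into a single NumPy array for ease of training. Each trial is padded with 25 samples on both sides: zeros for the stimulus and NaN for the response. Consecutive trials are therefore spaced by 500 ms, which avoids between-trial contamination by the network. The NaN-padded response samples are ignored by the loss.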

In [None]:
# Samples of padding inserted before and after every trial.
trial_pad = 25


def _pad_stim(trials, width=trial_pad):
    """Zero-pad each stimulus trial along time and join them along time."""
    padded = [np.pad(x, ((width, width), (0, 0)), constant_values=0) for x in trials]
    return np.concatenate(padded)


def _pad_resp(trials, width=trial_pad):
    """NaN-pad each 1-D response trial and join them; NaNs are dropped by the loss."""
    padded = [np.pad(y, ((width, width),), constant_values=np.nan) for y in trials]
    return np.concatenate(padded)


X_tr, Y_tr = _pad_stim(X_tr), _pad_resp(Y_tr)
X_vl, Y_vl = _pad_stim(X_vl), _pad_resp(Y_vl)
X_te, Y_te = _pad_stim(X_te), _pad_resp(Y_te)
Y_r0, Y_r1 = _pad_resp(Y_r0), _pad_resp(Y_r1)

print(X_tr.shape, X_vl.shape, X_te.shape)
print(Y_tr.shape, Y_vl.shape, Y_te.shape)

Split the data into continuous chunks for batch training if needed.<br/>Network inputs must have shape $[batch \times time \times freq\_bins]$ and outputs must have shape $[batch \times time]$. Here the full data are used as a single batch.

In [None]:
# The whole concatenated recording forms one batch: prepend a batch axis of size 1.
X_tr, Y_tr = np.expand_dims(X_tr, 0), np.expand_dims(Y_tr, 0)
X_vl, Y_vl = np.expand_dims(X_vl, 0), np.expand_dims(Y_vl, 0)
X_te, Y_te = np.expand_dims(X_te, 0), np.expand_dims(Y_te, 0)
Y_r0, Y_r1 = np.expand_dims(Y_r0, 0), np.expand_dims(Y_r1, 0)

print(X_tr.shape, X_vl.shape, X_te.shape)
print(Y_tr.shape, Y_vl.shape, Y_te.shape)

### Define model

In [None]:
# Remove non-finite (e.g. NaN-padded) response entries before scoring.
def drop_nan(response, prediction):
    """Restrict ``response`` and ``prediction`` to entries where ``response`` is finite.

    The same boolean mask is applied to both tensors, so the two outputs stay
    aligned element-for-element (``tf.boolean_mask`` returns them flattened
    when the mask covers every axis).
    """
    keep = tf.math.is_finite(response)
    return tuple(tf.boolean_mask(t, keep) for t in (response, prediction))

# training loss
def loss_se(response, prediction):
    """Squared error loss normalized by the mean squared response.

    Entries where ``response`` is not finite (and the matching predictions)
    are excluded first. A variance-normalized alternative would subtract the
    response mean before squaring in the denominator.
    """
    y, y_hat = drop_nan(response, prediction)
    mse = tf.reduce_mean(tf.square(y - y_hat))
    power = tf.reduce_mean(tf.square(y))
    return mse / power

# validation metric
def metric_corr(response, prediction):
    """Pearson correlation between finite response entries and the prediction.

    Both masked vectors get a leading axis of size 1 and are passed to
    ``tfp.stats.correlation`` with samples along axis 1 and events along
    axis 0.
    """
    y, y_hat = drop_nan(response, prediction)
    return tfp.stats.correlation(
        y[tf.newaxis], y_hat[tf.newaxis], sample_axis=1, event_axis=0)

# noise-corrected correlation
def fn_ncorr(prediction, y_r0=None, y_r1=None):
    """Noise-corrected Pearson correlation of a prediction with the test responses.

    The prediction is correlated separately with each of the two repetition-
    group averages. The mean of those two correlations is divided by the
    square root of the correlation between the two groups.

    Args:
        prediction: model output aligned element-wise with the responses; any
            array-like, including an eager TensorFlow tensor (converted with
            ``np.asarray`` because TF tensors cannot be indexed with a NumPy
            boolean mask).
        y_r0, y_r1: the two repetition-group averaged responses, same shape as
            ``prediction``. Default to the module-level ``Y_r0`` / ``Y_r1``.

    Returns:
        The noise-corrected correlation. It is NaN when the correlation
        between the two groups is negative (square root of a negative number).
    """
    if y_r0 is None:
        y_r0 = Y_r0
    if y_r1 is None:
        y_r1 = Y_r1
    y_r0 = np.asarray(y_r0)
    y_r1 = np.asarray(y_r1)
    prediction = np.asarray(prediction)

    # Skip the NaN padding: keep samples where both group averages are finite.
    mask = np.isfinite(y_r0) & np.isfinite(y_r1)
    r0 = scipy.stats.pearsonr(y_r0[mask], prediction[mask])[0]
    r1 = scipy.stats.pearsonr(y_r1[mask], prediction[mask])[0]
    rr = scipy.stats.pearsonr(y_r0[mask], y_r1[mask])[0]
    return (r0 + r1) / 2 / np.sqrt(rr)

In [None]:
l2 = tf.keras.regularizers.l2(0.0001)
layer_opts = dict(activation='relu', use_bias=False, kernel_regularizer=l2)

# Length (in time samples) of the final linear kernel.
strf_len = 40

layers = (
    tf.keras.layers.InputLayer(input_shape=(None, freq_bins)),
    # add a channel axis: (batch, time, freq) -> (batch, time, freq, 1)
    tf.keras.layers.Reshape((-1, freq_bins, 1)),

    # Three 3x3 conv layers (Conv2D defaults to 'valid' padding). Padding 2
    # samples before and none after in time keeps them causal; 1 bin on each
    # side in frequency keeps the frequency axis at freq_bins.
    tf.keras.layers.ZeroPadding2D(((2, 0), (1, 1))),
    tf.keras.layers.Conv2D(16, 3, **layer_opts),
    tf.keras.layers.ZeroPadding2D(((2, 0), (1, 1))),
    tf.keras.layers.Conv2D(16, 3, **layer_opts),
    tf.keras.layers.ZeroPadding2D(((2, 0), (1, 1))),
    tf.keras.layers.Conv2D(16, 3, **layer_opts),

    # Final linear layer: causal in time (strf_len - 1 samples of left padding)
    # and spanning the full frequency axis. Using freq_bins instead of a fixed
    # width collapses frequency to a single bin for any input size. Each output
    # sample depends on the current and previous 3 * 2 + (strf_len - 1) inputs.
    tf.keras.layers.ZeroPadding2D(((strf_len - 1, 0), (0, 0))),
    tf.keras.layers.Conv2D(1, (strf_len, freq_bins), use_bias=True, kernel_regularizer=l2),
    # (batch, time, 1, 1) -> (batch, time)
    tf.keras.layers.Flatten()
)

blueprint = tf.keras.models.Sequential(layers)
blueprint.summary()

### Fit model

In [None]:
# prepare directory for saving trained models
model_dir = "models/"
os.makedirs(model_dir, exist_ok=True)

# path for saving the best model of this channel
model_path = os.path.join(model_dir, f"mdl-v1-{channel:03d}")

# clone_model builds new layers with freshly initialized weights from the blueprint
model = tf.keras.models.clone_model(blueprint)
optim = tf.keras.optimizers.Adam(1e-3)
model.compile(optimizer=optim, loss=loss_se, metrics=[metric_corr])

# Stop after 500 epochs without improvement in validation correlation, and
# write the model to model_path whenever validation correlation improves.
# The save format depends on the installed TF/Keras version; verify before
# changing the load step.
callbk_early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_metric_corr', patience=500, mode='max', restore_best_weights=True)
callbk_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    model_path, monitor='val_metric_corr', save_best_only=True, mode='max')
callbacks = [callbk_early_stop, callbk_checkpoint]

# fit model to data
history = model.fit(X_tr, Y_tr, validation_data=(X_vl, Y_vl),
                    epochs=5000, verbose=1, callbacks=callbacks)

# plot performance curves
plt.figure()
plt.plot(history.history['metric_corr'], label='train')
plt.plot(history.history['val_metric_corr'], label='validation')
plt.xlabel('epoch')
plt.ylabel('correlation')
plt.legend()
plt.show()

### Evaluate test set

In [None]:
# Reload the best checkpoint written to model_path by ModelCheckpoint while fitting.
model.load_weights(model_path)
# Convert to NumPy: fn_ncorr boolean-indexes the prediction, which TF tensors do not support.
pred = model(X_te).numpy()
print(f"\tchannel #{channel:d}\t{fn_ncorr(pred):.3f}")

### Calculate dynamic STRFs

### Save dynamic STRFs