Proof-of-concept convolutional neural network trained to predict the nonlinear subthreshold membrane-voltage dynamics of a biological neuron, modeled using the Hodgkin–Huxley formalism (J Physiol. 1952 Aug 28; 117(4): 500–544.
doi: https://doi.org/10.1113/jphysiol.1952.sp004764), in response to noisy (synaptic) input currents.

To generate the training/validation data file(s), first run the notebook HodkinHuxleyDataGeneration.ipynb.

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import HDF5Matrix

In [None]:
# Seed both random number generators used in this notebook. NumPy's seed alone
# does not affect TensorFlow ops (e.g. the random weight initialization done by
# Keras layers), so TensorFlow's global seed is set as well. Some GPU kernels
# may still be nondeterministic.
import tensorflow as tf

SEED = 1234
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
# number of samples (taken from the end of X / Y) held out for validation
n_val = 500

# all datasets live in the same HDF5 file; each HDF5Matrix wraps one named
# dataset inside it
data_file = 'HodkinHuxleySubthresholdData.h5'

# training/validation inputs and targets
X = HDF5Matrix(data_file, 'X')
Y = HDF5Matrix(data_file, 'Y')

# separate test inputs and targets
X_test = HDF5Matrix(data_file, 'X_test')
Y_test = HDF5Matrix(data_file, 'Y_test')

In [None]:
# sanity check: shapes of the inputs (X, input current) and targets (Y, membrane voltage)
(X.shape, Y.shape)

In [None]:
def generate_model(input_shape, lr=0.001, layer_sizes=(2, 6, 16),
                   kernel_sizes=(11, 5, 1)):
    """Build and compile a fully convolutional 1-D regression model.

    The network is a stack of ``Conv1D`` layers (ReLU activation, 'same'
    padding, L2 regularization on kernels and biases) followed by a
    per-time-step linear ``Dense(1)`` readout. It is compiled with the Adam
    optimizer and a mean-squared-error loss.

    Parameters
    ----------
    input_shape : tuple
        Shape of a single input sample, excluding the batch axis, e.g.
        ``(None, 1)`` for sequences of any length with one channel.
    lr : float
        Learning rate passed to the Adam optimizer.
    layer_sizes : sequence of int
        Number of filters in each convolutional layer.
    kernel_sizes : sequence of int
        Kernel width of each convolutional layer. It must have the same
        length as ``layer_sizes``.

    Returns
    -------
    keras.models.Sequential
        The compiled model.

    Raises
    ------
    ValueError
        If ``layer_sizes`` and ``kernel_sizes`` have different lengths.
    """
    # accept any iterable; tuples also avoid mutable default arguments
    layer_sizes = tuple(layer_sizes)
    kernel_sizes = tuple(kernel_sizes)
    # zip() would silently drop the extra entries of the longer sequence
    if len(layer_sizes) != len(kernel_sizes):
        raise ValueError(
            f"layer_sizes and kernel_sizes must have the same length, "
            f"got {len(layer_sizes)} and {len(kernel_sizes)}")

    # reset Keras' global state so repeated calls start from a clean session
    keras.backend.clear_session()

    # Define model
    model = keras.models.Sequential()

    # input layer
    model.add(keras.layers.InputLayer(input_shape))

    # convolutional layers and activation
    for n_filters, width in zip(layer_sizes, kernel_sizes):
        model.add(keras.layers.Conv1D(n_filters,
                                      kernel_size=width,
                                      padding='same',
                                      kernel_regularizer=l2(),
                                      bias_regularizer=l2(),
                                      activation='relu'))

    # dense output layer: one linear output value per time step
    model.add(keras.layers.TimeDistributed(
        keras.layers.Dense(1, activation='linear')))

    # optimizer (`learning_rate` replaces the deprecated `lr` keyword)
    opt = keras.optimizers.Adam(learning_rate=lr)

    # compile model; the reported 'loss' also includes the L2 penalties,
    # while the 'mse' metric is the plain data-fit term
    model.compile(loss='mse', optimizer=opt, metrics=['mse'])

    return model

In [None]:
# (time steps, channels): a `None` time axis lets the convolutional model
# accept sequences of any length; each time step carries one input channel
input_shape = (None, 1)
model = generate_model(input_shape=input_shape)

In [None]:
# print a layer-by-layer summary of the model (output shapes, parameter counts)
model.summary()

In [None]:
# Render the model architecture inline as an SVG diagram.
# `pydot` is not referenced below. Presumably it is imported here so the cell
# fails fast with ImportError if pydot (used by model_to_dot) is not installed.
import pydot
from tensorflow.keras.utils import plot_model
from IPython.display import SVG
from tensorflow.keras.utils import model_to_dot

# To write the diagram to an image file instead, call:
#   plot_model(model, to_file='model.png')
# `.create(prog='dot', ...)` renders the graph with Graphviz's `dot` program.
SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))

In [None]:
# hold out the last n_val samples for validation and train on the rest
train_part = slice(None, -n_val)
val_part = slice(-n_val, None)
history = model.fit(
    X[train_part],
    Y[train_part],
    batch_size=100,
    epochs=100,
    validation_data=(X[val_part], Y[val_part]),
)

In [None]:
# per-epoch learning curves on a log scale: training loss (MSE plus the L2
# regularization terms), training MSE, and validation MSE
plt.figure()
for curve in ('loss', 'mse', 'val_mse'):
    plt.semilogy(history.history[curve], '-o', label=curve)
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('training/validation loss')

In [None]:
# visualize predictions on some samples from the training/validation set
n_val_samples = 3  # number of validation samples to plot (one row each)

X_val = X[-n_val:]
Y_val = Y[-n_val:]

Y_pred = model.predict(X_val)

# compare prediction to ground truth; squeeze=False keeps `axes` 2-D even
# when only a single row is plotted
fig, axes = plt.subplots(n_val_samples, 2, figsize=(8, 8),
                         sharex=True, sharey='col', squeeze=False)
for i in range(n_val_samples):
    axes[i, 0].plot(X_val[i], label='$x(t)$')
    axes[i, 1].plot(Y_val[i], label='$y(t)$')
    # raw string: '\h' is an invalid escape sequence in a normal string
    axes[i, 1].plot(Y_pred[i], label=r'$\hat{y}(t)$')
    axes[i, 0].set_ylabel('I (pA)')
    axes[i, 1].set_ylabel('Vm (mV)')

# legend and column titles only on the first row
axes[0, 1].legend()
axes[0, 0].set_title('$x(t)$')
axes[0, 1].set_title(r'$y(t)$ vs $\hat{y}(t)$')
# x-axis labels only on the bottom row (the x axes are shared)
# NOTE(review): the x axis is the sample index; the 'ms' unit assumes one
# sample per ms -- confirm against the data-generation notebook
axes[-1, 0].set_xlabel('$t$ (ms)')
axes[-1, 1].set_xlabel('$t$ (ms)')

In [None]:
# test with time series longer than training set time series
n_val_samples = 3              # number of test samples to plot (one row each)
t_start, t_stop = 2900, 3000   # zoom window, in sample indices

Y_pred = model.predict(X_test)

# compare prediction to ground truth; squeeze=False keeps `axes` 2-D even
# when only a single row is plotted
fig, axes = plt.subplots(n_val_samples, 2, figsize=(8, 8),
                         sharex=True, sharey='col', squeeze=False)
for i in range(n_val_samples):
    x_win = X_test[i][t_start:t_stop]
    # plot against the true sample indices so the zoomed window keeps its
    # position on the time axis; the length follows the actual slice in case
    # a series is shorter than t_stop
    t = np.arange(t_start, t_start + len(x_win))
    axes[i, 0].plot(t, x_win, label='$x(t)$')
    axes[i, 1].plot(t, Y_test[i][t_start:t_stop], label='$y(t)$')
    # raw string: '\h' is an invalid escape sequence in a normal string
    axes[i, 1].plot(t, Y_pred[i][t_start:t_stop], label=r'$\hat{y}(t)$')
    axes[i, 0].set_ylabel('I (pA)')
    axes[i, 1].set_ylabel('Vm (mV)')

# legend and column titles only on the first row
axes[0, 1].legend()
axes[0, 0].set_title('$x(t)$')
axes[0, 1].set_title(r'$y(t)$ vs $\hat{y}(t)$')
# x-axis labels only on the bottom row (the x axes are shared)
# NOTE(review): the x axis is the sample index; the 'ms' unit assumes one
# sample per ms -- confirm against the data-generation notebook
axes[-1, 0].set_xlabel('$t$ (ms)')
axes[-1, 1].set_xlabel('$t$ (ms)')