Simple recurrent neural network (RNN) implementation in Keras, using LSTM (long short-term memory) units to identify the time of occurrence of events in temporal data from the data's wavelet spectrogram

In [1]:
# Interactive (JavaScript-rendered) matplotlib figures in the classic Jupyter Notebook;
# this is why the plotting cells below emit `<IPython.core.display.Javascript object>`.
# NOTE(review): in JupyterLab this backend is unavailable; `%matplotlib widget` (ipympl) is the usual replacement.
%matplotlib notebook

In [2]:
import numpy as np
import scipy.signal as ss
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import keras

Using TensorFlow backend.


In [3]:
# Fix NumPy's global RNG so the synthetic data below (event times, noise, phases) is reproducible.
# NOTE(review): this does not appear to seed Keras/TensorFlow, so training may still vary between runs.
np.random.seed(seed=1234)

# Create training/validation data
Set up the data by superimposing brown noise and a cosine with a Gaussian envelope, inserted once at a random time in each dataset

In [4]:
n_samples = 1000     # number of training + validation datasets
signal_length = 1000 # number of time points in segment
dt = 0.001           # s, temporal resolution

# some parameters for the events
impulse_length = 101 # length of impulse (wave packet) in units of time steps
impulse_freq = 100   # Hz, carrier frequency of impulse
impulse_std = 20     # std. dev of Gaussian envelope, in time steps
impulse_amplitude = 0.2 # amplitude of impulse

# Set up labels as boxcars around each event time
y = np.zeros((n_samples, signal_length, 1))
offset = impulse_length // 2 # timesteps; keeps every event (and its label) inside the segment
label_width = 51 # timesteps
times = np.random.randint(offset, signal_length-offset, size=n_samples) # times of events on grid

# create data as brown noise + some oscillatory event at labeled times
X0 = np.random.randn(n_samples, signal_length) # white noise
X0 = X0.cumsum(axis=-1) # 1/f**2 noise
X0 /= X0.std() # normalize (one global scale shared by all samples)

# Window functions come from scipy.signal.windows: the old top-level aliases
# ss.boxcar / ss.gaussian were deprecated and have been removed from recent SciPy.
# The kernels and the carrier's phase argument do not depend on the sample, so
# build them once instead of on every iteration.
label_kernel = ss.windows.boxcar(label_width)
envelope = ss.windows.gaussian(impulse_length, impulse_std)
carrier_arg = impulse_freq*2*np.pi*np.arange(impulse_length)*dt  # rad

for i in range(n_samples):
    # unit spike at the event time; a 'same' convolution with an odd-length
    # kernel centres that kernel on the spike
    x = np.zeros(signal_length)
    x[times[i]] = 1

    # set up label arrays as boxcars centered on time of events
    y[i, :, 0] = np.convolve(x, label_kernel, 'same')

    # compute and superimpose event (random carrier phase) on background noise
    impulse = envelope * np.cos(carrier_arg - np.random.rand()*2*np.pi)
    X0[i] = X0[i] + np.convolve(x, impulse*impulse_amplitude, 'same')

# center raw data: subtract each sample's own mean
X0 = X0 - X0.mean(axis=-1, keepdims=True)

# time vector
time = np.arange(signal_length) * dt

In [5]:
# test plot: raw trace of the first sample with its boxcar label overlaid
fig, ax = plt.subplots()
ax.plot(time, X0[0], label='raw data')
ax.plot(time, y[0, :, 0], label='label (y)')
ax.legend()
ax.set_xlabel('t (s)')

<IPython.core.display.Javascript object>

Text(0.5, 0, 't (s)')

In [6]:
# plot all labels and raw data matrices (rows = samples, columns = time)
fig, axes = plt.subplots(2, 1, sharex=True, sharey=True)
row_index = np.arange(n_samples)
panels = (
    (y[:, :, 0], 'labels (y)'),
    (X0, 'raw data'),
)
for ax, (matrix, title) in zip(axes, panels):
    ax.pcolormesh(time, row_index, matrix)
    ax.set_ylabel('#')
    ax.set_title(title)
axes[-1].set_xlabel('t (s)')
for ax in axes:
    # re-apply the tight limits explicitly so they stay fixed
    ax.axis(ax.axis('tight'))

<IPython.core.display.Javascript object>

In [7]:
# Set up and apply complex morlet wavelet transform of raw data
Fs = 1 / dt # sampling frequency
waveletfreqs = np.arange(25., 200, 10) # Hz

# continuous (Morlet) wavelet parameters
w = 6.  # omega0: the carrier term is exp(1j*w*x)
s = 1.  # scaling: the wavelet is sampled on x in [-s*2*pi, s*2*pi]


def morlet(M, w=5.0, s=1.0, complete=True):
    """Return a complex Morlet wavelet sampled at ``M`` points.

    Reproduces the formula of ``scipy.signal.morlet``, which is deprecated and
    has been removed from recent SciPy, so this cell keeps working there.

    Parameters
    ----------
    M : int
        Number of samples.
    w : float
        omega0, the carrier parameter in ``exp(1j*w*x)``.
    s : float
        Scaling; ``x`` spans ``[-s*2*pi, s*2*pi]``.
    complete : bool
        If True, subtract the correction term ``exp(-w**2/2)``, which makes the
        continuous wavelet zero-mean also for small ``w``.

    Returns
    -------
    ndarray of complex, shape (M,)
    """
    x = np.linspace(-s * 2 * np.pi, s * 2 * np.pi, M)
    output = np.exp(1j * w * x)
    if complete:
        output -= np.exp(-0.5 * (w**2))
    output *= np.exp(-0.5 * (x**2)) * np.pi**(-0.25)
    return output


# One wavelet per frequency. A wavelet of 2*s*Fs*w/f samples lasts 2*s*w/f seconds
# and contains 2*s*w carrier cycles, so its carrier oscillates at f Hz. The sample
# count is generally not an integer; truncate it explicitly, because np.linspace
# rejects a float number of points (NumPy >= 1.18).
wavelets = [morlet(int(2. * s * Fs * w / f), w=w, s=s, complete=True)
            for f in waveletfreqs]

# Container for preprocessed training/validation data: (sample, time, frequency)
X = np.empty(X0.shape + (waveletfreqs.size,), dtype=complex)

#apply wavelets
for j, x in enumerate(X0):
    for i, wavelet in enumerate(wavelets):
        X[j, :, i] = ss.convolve(x, wavelet, 'same')

# power: squared magnitude of the complex wavelet coefficients (np.abs already returns float)
X = np.abs(X)**2

In [8]:
# plot wavelet spectrograms vs. labels and raw data for some samples
spec_vmax = X.std() * 2  # common colour ceiling; X is large, so compute it once
for i in range(3):
    gs = GridSpec(4, 1)  # top row: traces, bottom three rows: spectrogram
    fig = plt.figure()
    ax0 = fig.add_subplot(gs[0, 0])
    ax0.plot(time, X0[i], label='raw data')
    ax0.plot(time, y[i, :, 0], label='label (y)')
    ax0.legend(ncol=2)
    ax0.axis(ax0.axis('tight'))
    ax0.set_title('label and raw data')
    # x axis is shared with the spectrogram below, so only label it there
    plt.setp(ax0.get_xticklabels(), visible=False)

    ax1 = fig.add_subplot(gs[1:, 0], sharex=ax0)
    ax1.pcolormesh(time, waveletfreqs, X[i].T, vmin=0, vmax=spec_vmax)
    ax1.axis(ax1.axis('tight'))
    ax1.set_ylabel('f (Hz)')
    ax1.set_xlabel('t (s)')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# Set up recurrent neural network

In [9]:
def generate_model(input_shape, lr=0.01, dropout_rate=0.2, layer_sizes=(5, 5, 5)):
    """Build and compile a Conv1D -> LSTM -> LSTM per-time-step binary classifier.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of a single sample without the batch axis, ``(n_timesteps, n_features)``
        as expected by ``Conv1D``.
    lr : float
        Learning rate of the Adam optimizer.
    dropout_rate : float
        Dropout rate, applied once after each of the three hidden stages.
    layer_sizes : sequence of int
        ``(conv_filters, lstm1_units, lstm2_units)``; only the first three
        entries are used. The default is a tuple to avoid a shared mutable default.

    Returns
    -------
    keras.models.Model
        Model compiled with binary cross-entropy and Adam. Its output has shape
        ``(batch, n_timesteps, 1)``: one sigmoid score per time step.
    """
    # reset Keras' global state so repeated calls start from a clean slate
    keras.backend.clear_session()

    # input layer
    inputs = keras.layers.Input(shape=input_shape)

    # conv layer: local 5-step features; padding='same' preserves the sequence length
    x = keras.layers.Conv1D(layer_sizes[0],
                            kernel_size=5, strides=1,
                            padding='same'
                           )(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation('relu')(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    # LSTM layer 1 (return_sequences=True keeps one output per time step)
    x = keras.layers.LSTM(layer_sizes[1], return_sequences=True)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    # LSTM layer 2: same LSTM -> BN -> Dropout pattern as layer 1. Dropout is
    # applied only once here; stacking a second Dropout in front of the
    # BatchNormalization would compound the drop rate for this stage alone.
    x = keras.layers.LSTM(layer_sizes[2], return_sequences=True)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    # dense output layer: one sigmoid unit applied at every time step
    predictions = keras.layers.TimeDistributed(
        keras.layers.Dense(1, activation='sigmoid'))(x)

    # Define model
    model = keras.models.Model(inputs=inputs, outputs=predictions)

    # `learning_rate` replaces the deprecated `lr` keyword (available since
    # Keras 2.3; `lr` is rejected by newer Keras releases)
    opt = keras.optimizers.Adam(learning_rate=lr)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy', 'mse'])

    return model

In [10]:
# build a fresh model sized to a single sample: (time points, wavelet frequencies)
model = generate_model(input_shape=X.shape[1:])

In [11]:
# Print a table of each layer's output shape and parameter count
model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 1000, 18)          0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 1000, 5)           455       
_________________________________________________________________
batch_normalization_1 (Batch (None, 1000, 5)           20        
_________________________________________________________________
activation_1 (Activation)    (None, 1000, 5)           0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 1000, 5)           0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 1000, 5)           220       
_________________________________________________________________
batch_normalization_2 (Batch (None, 1000, 5)           20  

In [12]:
# Train the model. Per the Keras `fit` docs, validation_split holds out the trailing
# 5% of X / y (taken before shuffling) as the validation set.
history = model.fit(
    X,
    y,
    validation_split=0.05,
    epochs=10,
    batch_size=20,
)

Train on 950 samples, validate on 50 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [13]:
# training curves: binary cross-entropy loss and MSE, each for training and validation data
fig, ax = plt.subplots()
for key in ('loss', 'val_loss', 'mse', 'val_mse'):
    ax.semilogy(history.history[key], '-o', label=key)
ax.legend()
ax.set_xlabel('epochs')
ax.set_ylabel('loss / mse')
ax.set_title('training/validation loss');

<IPython.core.display.Javascript object>

[]

In [14]:
# visualize predictions on some samples from the validation set
n_val_samples = 3

# History.validation_data is an undocumented attribute that newer Keras versions may
# leave unset (None). In that case, fall back to the slice that validation_split
# holds out, i.e. the trailing fraction of X / y. Keep `val_split` equal to the
# validation_split passed to model.fit.
val_data = getattr(history, 'validation_data', None)
if val_data is not None and len(val_data) >= 2:
    X_val = val_data[0][:n_val_samples]
    y_val = val_data[1][:n_val_samples]
else:
    val_split = 0.05
    val_start = int(len(X) * (1. - val_split))
    X_val = X[val_start:val_start + n_val_samples]
    y_val = y[val_start:val_start + n_val_samples]

y_pred = model.predict(X_val)

spec_vmax = X.std() * 2  # common colour ceiling for all spectrogram panels

# compare prediction to ground truth; squeeze=False keeps `axes` 2-D even for a single row
fig, axes = plt.subplots(n_val_samples, 2, figsize=(8, 8),
                         sharex=True, sharey='col', squeeze=False)
for i in range(n_val_samples):
    axes[i, 0].pcolormesh(time, waveletfreqs, X_val[i].T, vmin=0, vmax=spec_vmax)
    axes[i, 1].plot(time, y_val[i], label='$y(t)$')
    # raw strings: '\h' is not a valid escape sequence in a normal string literal
    axes[i, 1].plot(time, y_pred[i], label=r'$\hat{y}(t)$')
    if i == 0:
        axes[i, 1].legend()
        axes[i, 0].set_title('$X(t)$')
        axes[i, 1].set_title(r'$y(t)$ vs $\hat{y}(t)$')
    axes[i, 0].set_ylabel('f (Hz)')
    axes[i, 1].set_ylabel('probability')
# x axes are shared, so label only the bottom row
axes[-1, 0].set_xlabel('$t$ (s)')
axes[-1, 1].set_xlabel('$t$ (s)')

<IPython.core.display.Javascript object>

Text(0.5, 0, '$t$ (s)')