In [None]:
from sys import path
path.insert(0, '..')  # make the parent directory importable so `trainer` resolves

import keras
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal
import tensorflow as tf

from trainer import data as D
from trainer import model as M

In [None]:
# Training step of the checkpoint to inspect, and the run directory holding it.
ckpt, DIR = 7000, './../logs/cross_val/augment_0.4/split_0/'

In [None]:
# Checkpoint files are named <DIR>ckpt-<step>.h5 (DIR already ends with '/').
# M.Scale is passed in custom_objects so load_model can deserialize it.
modelfile = '%sckpt-%i.h5' % (DIR, ckpt)
model = keras.models.load_model(modelfile,
                                custom_objects=dict(Scale=M.Scale))

In [None]:
# Unpack the shape of weight tensor #1 of layer #4.
# NOTE(review): both indices are hard-coded; the names assume a kernel laid out
# as (samples, input layers, filters) -- confirm against the model definition.
(num_samples, num_input_layers, num_filters) = model.layers[4].get_weights()[1].shape
(num_samples, num_input_layers, num_filters)

In [None]:
# All weights of every layer, keyed by layer name.
layer_weights = {layer.name: layer.get_weights() for layer in model.layers}

# Sub-model name -> signal channel it processes.
# NOTE(review): these are auto-generated Keras names; confirm they still match.
layer_to_channel = dict(zip(('sequential_1', 'sequential_2', 'sequential_3'),
                            ('EEG', 'EOG', 'EMG')))
assert set(layer_to_channel) <= set(layer_weights), \
    'layers not found in model: %s' % sorted(set(layer_to_channel) - set(layer_weights))

input_layer_selection = 0  # index along axis 1 of the kernel to inspect

# Slice weight tensor #1 at the selected input, transpose so rows are output
# filters and columns are kernel samples (see the label below).
# .copy() detaches each array from layer_weights: without it the values are
# views, and any later in-place edit would silently rewrite the raw weights.
channel_weights = {
    layer_to_channel[name]: weights[1][:, input_layer_selection, :].T.copy()
    for name, weights in layer_weights.items()
    if name in layer_to_channel
}
channel_weights['EEG'].shape, '= output_layer, #samples'

In [None]:
# Heat-map of every filter kernel, one panel per channel.
fig = plt.figure(figsize=(15, 5))
for panel, (channel, weights) in enumerate(channel_weights.items(), start=1):
    ax = fig.add_subplot(1, 3, panel)
    ax.set_title(channel)
    ax.imshow(weights, interpolation='nearest', aspect='auto')
    if panel == 1:  # axis labels on the left-most panel only
        ax.set_xlabel('time index', fontsize=15)
        ax.set_ylabel('filter index', fontsize=15)
fig.tight_layout()
plt.show()

In [None]:
# Stacked line plot of the kernels: kernel k is shifted down by 0.5*k so the
# traces do not overlap.
fig = plt.figure(figsize=(15, 5))
for panel, (channel, weights) in enumerate(channel_weights.items(), start=1):
    ax = fig.add_subplot(1, 3, panel)
    ax.set_title(channel)
    for row_idx, kernel in enumerate(weights):
        ax.plot(kernel - row_idx*0.5, 'ko-')
    if panel == 1:  # axis labels on the left-most panel only
        ax.set_xlabel('time index -->', fontsize=15)
        ax.set_ylabel('<-- filter index', fontsize=15)
fig.tight_layout()
plt.show()

In [None]:
# Power spectral density of every kernel (16-point FFT at the data's sampling
# rate D.sr), overlaid per channel.
fig = plt.figure(figsize=(15, 5))
for panel, (channel, weights) in enumerate(channel_weights.items(), start=1):
    ax = fig.add_subplot(1, 3, panel)
    ax.set_title(channel)
    for kernel in weights:
        ax.psd(kernel, Fs=D.sr, NFFT=16)
fig.tight_layout()
plt.show()

In [None]:
def plt_cov_mat(X, ax=None, vmax=0.8):
    """Plot |covariance| between the standardized rows of ``X``; return the covariance.

    Each row is shifted to zero mean and scaled by its population std (ddof=0)
    before ``np.cov`` (ddof=1) is taken. Entries are therefore correlation
    coefficients scaled by n/(n-1), where n is the row length.

    Parameters
    ----------
    X : array-like, 2-D
        One variable per row (here: one filter kernel per row). NOT modified.
    ax : matplotlib Axes, optional
        Axes to draw on. When omitted, a new 6x6 figure with one axes is created.
    vmax : float, optional
        Upper colour limit for |cov|; the lower limit is 0. Default 0.8.

    Returns
    -------
    numpy.ndarray
        Signed covariance matrix of shape (n_rows, n_rows).
    """
    X = np.asarray(X)
    # Standardize out of place so the caller's array (e.g. an entry of
    # channel_weights) is left untouched for later cells. A constant row has
    # zero std and yields NaN here, as with the original in-place version.
    X_std = (X - X.mean(axis=1)[:, None]) / X.std(axis=1)[:, None]
    cov = np.cov(X_std)
    if ax is None:
        plt.figure(figsize=(6, 6))
        ax = plt.subplot(111)
    # Draw on the requested axes explicitly rather than pyplot's "current" axes.
    im = ax.imshow(np.abs(cov), aspect='auto')
    im.set_clim(0, vmax)
    plt.colorbar(im, ax=ax)
    return cov

In [None]:
# Covariance between the standardized kernels of each channel, drawn in the
# same EEG, EOG, EMG order used when channel_weights was built.
channel_covs = {}
plt.figure(figsize=(18, 6))
for c, channel in enumerate(('EEG', 'EOG', 'EMG')):
    ax = plt.subplot(131 + c)
    plt.title('channel %s' % channel)
    # Pass a copy: plt_cov_mat may normalize its input in place, which would
    # otherwise leak into channel_weights and every later cell.
    channel_covs[channel] = plt_cov_mat(channel_weights[channel].copy(), ax=ax)
plt.show()

In [None]:
# For every EEG kernel, compare two estimates of its frequency response:
#   * black ('filt fft'): Welch PSD of the kernel itself;
#   * red ('test fun'):   variance of a 0.25-amplitude, random-phase sinusoid
#                         after filtering with the kernel, at the same frequencies.
rng = np.random.RandomState(0)  # seeded so the random phases are reproducible
t = np.arange(1024)/D.sr        # time axis of the 1024-sample test signals

for filt in channel_weights['EEG']:
    # min(size, 256) is the segment length welch would pick anyway (it clamps
    # its default of 256 to the input length, with a warning); passing it
    # explicitly gives the same spectrum without the warning.
    f, Pxx = scipy.signal.welch(filt, fs=D.sr,
                                nperseg=min(filt.size, 256), nfft=filt.size)

    test_signals = {
        w: 0.25*np.sin(t*2*np.pi*w + 2.*np.pi*rng.rand())
        for w in f
    }
    # Reversing the kernel makes np.convolve compute a cross-correlation,
    # presumably to mirror the model's conv layers -- TODO confirm.
    power = {
        w: np.var(np.convolve(s, filt[::-1], mode='valid'), ddof=1)
        for w, s in test_signals.items()
    }
    x, y = list(power.keys()), list(power.values())

    fig, (ax_time, ax_freq) = plt.subplots(1, 2, figsize=(10, 3))
    ax_time.plot(np.arange(filt.size)/D.sr, filt, 'ko-')
    ax_time.set_xlabel('time')
    ax_freq.semilogy(x, y, 'ro-', alpha=0.5, label='test fun')
    ax_freq.semilogy(f, Pxx, 'ko-', label='filt fft')
    ax_freq.set_ylim(10**-5, 10**1)
    ax_freq.set_xlabel('frequency')
    ax_freq.legend()
    fig.tight_layout()
    # Show and release each figure inside the loop; otherwise one open figure
    # per kernel piles up until the end of the cell.
    plt.show()
    plt.close(fig)