# Wiener filter training

In [None]:
import loading
import spir
import numpy as np
from pathlib import Path
import os

In [None]:
import pickle

# save.p contains a dictionary with keys == filenames. Each key contains a list
# of interferences
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# open a save.p that comes from a trusted source.
with open('save.p', 'rb') as f:
    # `inter` is iterated by filename in the cells below; '.edf' is appended to
    # each key to load the recording, so keys are paths without extension.
    inter = pickle.load(f)

## Estimate the interference covariance matrices

In [None]:
## Compute and store one covariance matrix per interference event

lag = 50

for filename, interferences in inter.items():
    # Recordings without any interference produce no output.
    if len(interferences) == 0:
        continue
    print(filename)
    (fs, data, labels) = loading.loadRecording(filename + '.edf')
    # Mirror the part of the recording path after its first six components
    # below 'results-ii'; the last component is the recording name.
    path_parts = filename.split('/')
    folder = os.path.join('results-ii', '/'.join(path_parts[6:-1]))
    file = path_parts[-1]
    Path(folder).mkdir(parents=True, exist_ok=True)
    for j, interference in enumerate(interferences):
        # Event bounds are divided by fs before being handed to build_cov.
        window = [[interference[0]/fs, interference[1]/fs]]
        rnn = spir.build_cov(data, window, lag, fs)
        # np.save appends the '.npy' extension itself.
        target = os.path.join(folder, file +  '-rii-{}'.format(j))
        np.save(target, rnn)

In [None]:
## Load previously calculated covariance matrices

# Position (in load order) of the covariance matrix excluded from the analysis.
# NOTE(review): the reason this entry is bad is not recorded in this notebook.
bad_index = 70

rnns = []
for filename in inter:
    if len(inter[filename]):
        folder = '/'.join(filename.split('/')[6:-1])
        folder = os.path.join('results-ii', folder)
        file = filename.split('/')[-1]
        # Reading must not create directories: if the folder is missing the
        # matrices were never computed, and np.load reports that directly.
        for j in range(len(inter[filename])):
            rnn = np.load(os.path.join(folder, file +  '-rii-{}.npy'.format(j)))
            # Flatten so that every matrix becomes one row of `rnns`.
            rnns.append(rnn.flatten())
rnns = np.array(rnns)

# Removing a row shifts all later rows up by one: row k >= bad_index of `rnns`
# belongs to interference k + 1 in load order.
rnns = np.delete(rnns, bad_index, axis=0)

## Compress the covariance matrices (Rnns)

In [None]:
from sklearn.decomposition import PCA

# A float n_components in (0, 1) asks PCA to keep enough components to explain
# that fraction of the variance (here 99.9 %). fit_transform fits the model
# and returns the projection of `rnns` in a single pass.
pca = PCA(0.999)
compressed = pca.fit_transform(rnns)

In [None]:
# The second axis of `compressed` holds the retained components.
n_kept = compressed.shape[1]
print('Number of compressed components: {}'.format(n_kept))

## Perform K-means clustering

In [None]:
import h5py
import numpy as np
# Load the compressed features from disk. This rebinds `compressed`, replacing
# the array produced by the PCA cell above.
# NOTE(review): no visible cell writes compressed.h5 — confirm it holds the
# same rows, in the same order, as `rnns`.
with h5py.File('compressed.h5', 'r') as h5f:
    # Copy the dataset into memory so it stays usable after the file is closed.
    compressed = np.array(h5f['compressed'])

In [None]:
from sklearn.cluster import KMeans

## Find n-clusters
def calculate_WSS(points, kmax):
    """Return the within-cluster sum of squares for k = 1..kmax.

    For every k a fresh KMeans model is fitted on ``points`` and the squared
    Euclidean distance of each point to its assigned centroid is summed over
    all feature dimensions.

    Args:
        points: 2-D array-like, one row per sample and one column per feature.
        kmax: largest number of clusters to try (inclusive).
    Return:
        List of ``kmax`` floats; entry ``i`` is the sum of squared errors
        obtained with ``i + 1`` clusters.
    """
    points = np.asarray(points)
    sse = []
    for k in range(1, kmax + 1):
        kmeans = KMeans(n_clusters=k).fit(points)
        centroids = kmeans.cluster_centers_
        pred_clusters = kmeans.predict(points)
        # Use every feature column: restricting the distance to the first two
        # columns is only correct for 2-D data and misplaces the elbow
        # otherwise.
        residuals = points - centroids[pred_clusters]
        sse.append(float(np.sum(residuals ** 2)))
    return sse


# Sweep k from 1 to 50 on the compressed features.
sse = calculate_WSS(points=compressed, kmax=50)

In [None]:
import matplotlib.pyplot as plt

# sse[i] was obtained with i + 1 clusters, so plot against 1..len(sse) to make
# the x axis read as the actual number of clusters.
plt.figure(figsize=(16, 6))
plt.plot(range(1, len(sse) + 1), sse)
plt.ylabel('# SSE')
plt.xlabel('# of clusters')
plt.title('Choice of # of cluster')
plt.show()

In [None]:
# Number of clusters used for the final k-means fit.
# NOTE(review): presumably read off the SSE curve plotted above — confirm.
n_clusters = 9

In [None]:
# Fit the final k-means model on `compressed`.
# NOTE(review): no random_state is passed, so cluster numbering may change
# between runs, while later cells select clusters by fixed label (0 and 5) —
# consider fixing the seed.
kmeans = KMeans(n_clusters=n_clusters).fit(compressed)

In [None]:
# Use one bin per integer label, centred on the label, so each bar sits
# exactly above the cluster it counts (the default 10 equal-width bins do not
# line up with integer labels).
label_bins = np.arange(kmeans.n_clusters + 1) - 0.5
plt.figure(figsize=(16, 6))
plt.hist(kmeans.labels_, bins=label_bins)
plt.xticks(range(kmeans.n_clusters))
plt.xlabel('Cluster labels')
plt.title('histogram of cluster labels')
plt.show()

Clusters 0 and 5 are retained to build the filters.

In [None]:
import plotting
# Show up to five example events that k-means assigned to `cluster_n`.
#
# `t` counts events in the order used to build `rnns`. Row 70 was deleted from
# `rnns` before clustering, so that event has no label and every later event
# is stored one position earlier in kmeans.labels_.
# NOTE(review): assumes `compressed` has exactly one row per row of `rnns` —
# confirm.
removed_index = 70
t = 0
plot_5 = 0
cluster_n = 0
print('# 5 Examples of cluster {}'.format(cluster_n))
for filename in inter:
    if plot_5 >= 5:
        # Enough examples shown; do not load further recordings.
        break
    if len(inter[filename]):
        (fs, data, labels) = loading.loadRecording(filename + '.edf')
        for j in range(len(inter[filename])):
            if t != removed_index and plot_5 < 5:
                # Map the event index onto the shortened label array.
                label_index = t if t < removed_index else t - 1
                if kmeans.labels_[label_index] == cluster_n:
                    a, p = plotting.plot_event(fs, data, labels, inter[filename][j])
                    plotting.show(p)
                    plot_5 += 1
            t += 1

In [None]:
import plotting
# Show up to five example events that k-means assigned to `cluster_n`.
#
# `t` counts events in the order used to build `rnns`. Row 70 was deleted from
# `rnns` before clustering, so that event has no label and every later event
# is stored one position earlier in kmeans.labels_.
# NOTE(review): assumes `compressed` has exactly one row per row of `rnns` —
# confirm.
removed_index = 70
t = 0
plot_5 = 0
cluster_n = 5
print('# 5 Examples of cluster {}'.format(cluster_n))
for filename in inter:
    if plot_5 >= 5:
        # Enough examples shown; do not load further recordings.
        break
    if len(inter[filename]):
        (fs, data, labels) = loading.loadRecording(filename + '.edf')
        for j in range(len(inter[filename])):
            if t != removed_index and plot_5 < 5:
                # Map the event index onto the shortened label array.
                label_index = t if t < removed_index else t - 1
                if kmeans.labels_[label_index] == cluster_n:
                    a, p = plotting.plot_event(fs, data, labels, inter[filename][j])
                    plotting.show(p)
                    plot_5 += 1
            t += 1

## Calculate filters

In [None]:
## Calculate filters

# Average the flattened covariance matrices of each retained cluster.
filters = list()
for i in [0, 5]:
    filters.append(np.mean(rnns[kmeans.labels_ == i,:], axis=0))

# Replace each averaged covariance by its leading eigenvectors.
for i, filt in enumerate(filters):
    # The flattened vector holds a square matrix; recover its side length.
    dim = int(round(filt.shape[0]**0.5))
    w, v = np.linalg.eig(filt.reshape(dim, dim))
    # np.linalg.eig does not order its eigenvalues. Sort them (and the
    # matching eigenvector columns) by decreasing value so the cumulative sum
    # below measures the share captured by the leading components.
    order = np.argsort(np.real(w))[::-1]
    w = np.real(w[order])
    v = v[:, order]
    explained = np.cumsum(w)/np.sum(w)
    # argmax returns the first component at which the share exceeds 90 %.
    # That component is kept as well (hence + 1), so the filter is never
    # empty, even when the first component alone exceeds the threshold.
    index_i = np.argmax(explained > 0.9)
    filters[i] = np.real(v[:,:index_i + 1])

In [None]:
from numba import jit, prange

def wiener_filter(data, v):
    """Apply a bank of multi-channel FIR filters and map back to channels.

    Each column of ``v`` is one filter, stored as the row-major flattening of
    a (channels, lag) coefficient matrix. For every filter, each channel is
    convolved with its own row of coefficients and the results are summed
    into one component signal. The component signals are then combined per
    channel, weighted by the first coefficient of that channel's row in each
    filter.

    Args:
        data: 2-D array (row = channels, column = samples).
        v: 2-D array with ``channels * lag`` rows and one column per filter.
    Return:
        out: array with the same shape as ``data``. All zeros when ``v`` has
            no columns.
    """
    n_channels, n_samples = data.shape
    n_filters = v.shape[1]
    lag = v.shape[0] // n_channels
    if n_filters == 0:
        # No filter means no component to project onto.
        return np.zeros((n_channels, n_samples))

    filtered = []
    # Plain range: this function is not compiled with numba, so a parallel
    # range would bring nothing here.
    for j in range(n_filters):
        v_shaped = np.reshape(v[:, j], (n_channels, lag))
        # Sum of the per-channel convolutions gives one component signal.
        out = sum(
            np.convolve(v_shaped[i, :], data[i, :], 'full')
            for i in range(n_channels)
        )
        filtered.append(out)

    # Rows 0, lag, 2*lag, ... of v hold the first coefficient of each channel.
    t = np.arange(0, v.shape[0], step=lag, dtype=int)
    filtered = np.dot(v[t, :], filtered)
    # 'full' convolution is lag - 1 samples longer than the input; trim it.
    return np.array(filtered[:, :n_samples])

## Example

In [None]:
# Recording used for the example, given without extension: the '.tse' and
# '.edf' files are both derived from this base path below.
filename = '/esat/biomeddata/Neureka_challenge/edf/train/01_tcp_ar/006/00000630/s002_2003_05_28/00000630_s002_t001'

import nedc
# Load the annotations of this recording from its .tse file.
# NOTE(review): entries are later multiplied by fs to obtain sample indices,
# so they are presumably (start, end) times in seconds — confirm.
seizures = nedc.loadTSE(filename + '.tse')
print('# seizures : {}'.format(len(seizures)))
# Load the signals; this rebinds fs, data and labels used by earlier cells.
(fs, data, labels) = loading.loadRecording(filename + '.edf')

In [None]:
# Remove the two interference classes in cascade: subtract what the first
# filter bank captures from the raw data, then apply the second filter bank
# to that residual and subtract again.
filtered_0 = wiener_filter(data, filters[0])
data_filt0 = data - filtered_0
datafiltered_1 = wiener_filter(data_filt0, filters[1])
data_filt = data_filt0 - datafiltered_1

In [None]:
j = 0

# Plot the j-th annotated seizure before and after filtering. Annotation times
# are converted to sample indices with fs; the window starts 4000 samples
# after the annotated start.
event = [int(seizures[j][0]*fs) + 4000, int(seizures[j][1]*fs)]

# plot_event is reached through the `plotting` module, as in the cluster
# example cells; no bare `plot_event` name is imported.
a, p = plotting.plot_event(fs, data, labels, event)
plotting.show(p)
a, p = plotting.plot_event(fs, data_filt, labels, event)
plotting.show(p)

In [None]:
# Write the list `filters` (one eigenvector matrix per retained cluster) to
# filters.pickle.
with open('filters.pickle', 'wb') as handle:
    pickle.dump(filters, handle)