In [1]:
import json
from scipy.io import loadmat
from IPython.display import Markdown as md
import yaml
import numpy as np
from scipy import signal
from matplotlib import pyplot as plt
from datetime import datetime
from sklearn.decomposition import FastICA

# Load the job description stored as JSON in the working directory.
# After `json.load` this is plain JSON data (dicts / lists), not a live Snakemake object,
# which is why later cells index it with [ 'wildcards' ] and [ 'input' ].
with open('.preprocess.ipynb_snakemake.json', 'r') as json_file:
    snakemake = json.load(json_file)

# Read the configuration file; a later cell reads the `sampleRate` entry from it
with open('config.yml', 'r') as yaml_file:
    config = yaml.safe_load(yaml_file)

In [None]:
def plot_periodogram(data, sample_rate, log = True, save = False, filename = ''):
    """
    Plot the Welch power spectrum of every channel on the current pyplot figure.

    Parameters
    ----------
    data : 2-D array indexed as [ sample, channel ]; one line is drawn per column.
    sample_rate : sampling frequency, forwarded to `signal.welch` as `fs`.
    log : when True (default) the natural logarithm of the power is plotted.
    save : when True the figure is written with `plt.savefig(filename)`,
        otherwise it is displayed with `plt.show()`.
    filename : output path, only used when `save` is True.
    """
    for channel in range(data.shape[ 1 ]):
        f, x = signal.welch(data[ :, channel ], fs = sample_rate)
        x = np.log(x) if log else x

        plt.plot(f, x, linewidth = 0.75)

    plt.xlabel('Frequency (Hz)')
    # The axis label has to follow the `log` switch: with log = False the values are linear power
    plt.ylabel('Relative power, log' if log else 'Relative power')

    if save:
        plt.savefig(filename)
    else:
        plt.show()


def apply_filter(data, filter):
    """
    Run a second-order-sections filter forwards and backwards over every channel.

    Parameters
    ----------
    data : array indexed as [ sample, channel ]; filtering runs along the sample axis
        (axis 0). A 1-D array is treated as a single channel.
    filter : filter coefficients in SOS form, e.g. from `signal.butter(..., output = 'sos')`.
        The name shadows the builtin `filter` inside this function only.

    Returns
    -------
    Array of filtered values with the same shape as `data`.
    """
    # `sosfiltfilt` filters along one axis for all columns at once, so no per-channel
    # Python loop (and no list -> array -> transpose round trip) is needed
    return signal.sosfiltfilt(filter, data, axis = 0)


def remove_powerline(data, frequencies, sample_rate, range = 1):
    """
    Suppress a band around each of the given frequencies, one after another.

    For every entry of `frequencies` a 20th-order Butterworth band-stop filter spanning
    [ frequency - range, frequency + range ] is designed in SOS form and handed to
    `apply_filter` together with the output of the previous pass.

    Parameters
    ----------
    data : signal array, passed through to `apply_filter`.
    frequencies : iterable of centre frequencies to remove.
    sample_rate : sampling frequency, forwarded to `signal.butter` as `fs`.
    range : half-width of each stop band. The name shadows the builtin `range`
        inside this function.

    Returns
    -------
    The filtered data, or `data` itself when `frequencies` is empty.
    """
    cleaned = data

    for frequency in frequencies:
        print(f'Removing frequency {frequency}')

        stop_band = [ frequency - range, frequency + range ]
        notch = signal.butter(N = 20, Wn = stop_band, btype = 'bandstop', output = 'sos', fs = sample_rate)
        cleaned = apply_filter(cleaned, notch)

    return cleaned


def low_pass(data, frequency, sample_rate):
    """
    Keep the content of `data` below `frequency`.

    A 20th-order Butterworth low-pass filter is designed in SOS form and the actual
    filtering is delegated to `apply_filter`.

    Parameters
    ----------
    data : signal array, passed through to `apply_filter`.
    frequency : cut-off frequency, forwarded to `signal.butter` as `Wn`.
    sample_rate : sampling frequency, forwarded to `signal.butter` as `fs`.

    Returns
    -------
    Whatever `apply_filter` returns for the designed filter.
    """
    print(f'Passing frequencies lower than {frequency}')

    design = {
        'N': 20,
        'Wn': frequency,
        'btype': 'low',
        'output': 'sos',
        'fs': sample_rate,
    }
    sections = signal.butter(**design)

    return apply_filter(data, sections)


def high_pass(data, frequency, sample_rate):
    """
    Keep the content of `data` above `frequency`.

    A 20th-order Butterworth high-pass filter is designed in SOS form and the actual
    filtering is delegated to `apply_filter`.

    Parameters
    ----------
    data : signal array, passed through to `apply_filter`.
    frequency : cut-off frequency, forwarded to `signal.butter` as `Wn`.
    sample_rate : sampling frequency, forwarded to `signal.butter` as `fs`.

    Returns
    -------
    Whatever `apply_filter` returns for the designed filter.
    """
    print(f'Passing frequencies higher than {frequency}')

    design = {
        'N': 20,
        'Wn': frequency,
        'btype': 'high',
        'output': 'sos',
        'fs': sample_rate,
    }
    sections = signal.butter(**design)

    return apply_filter(data, sections)

def plot_periodograms(datasets, labels, sample_rate, log = True, save = False, filename = ''):
    """
    Overlay the Welch power spectra of several datasets on the current pyplot figure.

    All channels of one dataset share a colour and a single legend entry.

    Parameters
    ----------
    datasets : sequence of 2-D arrays indexed as [ sample, channel ].
    labels : legend text per dataset, paired with `datasets` by position
        (pairing stops at the shorter of the two sequences).
    sample_rate : sampling frequency, forwarded to `signal.welch` as `fs`.
    log : when True (default) the natural logarithm of the power is plotted.
    save : when True the figure is written with `plt.savefig(filename, dpi = 1000)`,
        otherwise it is displayed with `plt.show()`.
    filename : output path, only used when `save` is True.
    """
    palette = plt.rcParams[ 'axes.prop_cycle' ].by_key()[ 'color' ]

    for index, (dataset, label) in enumerate(zip(datasets, labels)):
        # Wrap around the colour cycle: slicing it to len(datasets) and zipping would
        # silently skip every dataset beyond the number of colours in the cycle
        color = palette[ index % len(palette) ]

        for channel in range(dataset.shape[ 1 ]):
            f, x = signal.welch(dataset[ :, channel ], fs = sample_rate)
            x = np.log(x) if log else x

            # Label only the first channel so the legend shows one entry per dataset
            legend_entry = { 'label': label } if channel == 0 else { }
            plt.plot(f, x, linewidth = 0.75, color = color, **legend_entry)

    plt.xlabel('Frequency (Hz)')
    # The axis label has to follow the `log` switch: with log = False the values are linear power
    plt.ylabel('Relative power, log' if log else 'Relative power')
    plt.legend()

    if save:
        plt.savefig(filename, dpi = 1000)
    else:
        plt.show()

# TODO: needs fixing!
def plot_channels(data, timepoints, time_start, time_end, n_channels = 0, channel_list: list = None, offset = 1.0, y_lab = 'Channel'):
    """
    Draw the selected channels as vertically stacked traces over a sample window.

    Parameters
    ----------
    data : 2-D array indexed as [ sample, channel ].
    timepoints : 1-D array of x values, sliced with the same window as `data`.
    time_start, time_end : sample indices (not seconds) delimiting the window;
        they are used directly as slice bounds.
    n_channels : number of leading channels to draw when `channel_list` is not given.
        0 (the default) means all channels.
    channel_list : explicit channel indices to draw; takes precedence over `n_channels`.
    offset : vertical distance between neighbouring traces.
    y_lab : label of the y axis.

    Raises
    ------
    ValueError : when the window / channel selection contains no data.
    """
    # Get only the data from time range of interest and only the channels of interest
    # If specific channel list is specified, use channels from the list
    if channel_list is not None and len(channel_list) != 0:
        data = data[ time_start:time_end, channel_list ]
        channel_names = list(channel_list)
    else:
        # n_channels = 0 used to select zero columns and crash in np.max below;
        # treat it as "no limit" instead
        last_channel = n_channels if n_channels else data.shape[ 1 ]
        data = data[ time_start:time_end, 0:last_channel ]
        channel_names = list(range(data.shape[ 1 ]))

    if data.size == 0:
        raise ValueError('plot_channels: the requested time range / channel selection is empty')

    # Use the number of channels actually selected (the request may exceed what is available)
    n_channels = data.shape[ 1 ]

    # Scale by the largest absolute value in the window so every trace fits in [ -1, 1 ];
    # an all-zero window is left as is instead of dividing by zero
    peak = np.max(np.abs(data))
    scaled = data / (peak if peak != 0 else 1.0)

    # Shift trace i upwards by i * offset
    channel_offset = np.arange(n_channels) * offset
    scaled = scaled + channel_offset

    plt.plot(timepoints[ time_start:time_end ], scaled, color = 'black', linewidth = 0.4)
    plt.yticks(channel_offset, channel_names)
    # The top trace is centred at offset * (n_channels - 1) and reaches at most 1 above it.
    # The earlier bound `offset * n_channels - 1` equals this only for offset = 2.
    plt.ylim(-1, offset * (n_channels - 1) + 1)
    plt.ylabel(y_lab)
    plt.xlabel('Time, s')
    plt.show()

In [11]:
# Render a Markdown heading that names the entry at index 1 of the job's wildcards
md(f'# Preprocessing of {snakemake[ "wildcards" ][ 1 ]}')

# Preprocessing of S2_run1

In [None]:
# Read variable `y` from the MATLAB file listed at index 1 of the job's inputs and
# transpose it; the following cells index the result as [ sample, column ]
raw = loadmat(snakemake[ 'input' ][ 1 ])[ 'y' ].transpose()
# Display the dimensions of the loaded array
raw.shape

In [None]:
# Sampling rate from the `sampleRate` entry of config.yml
# (presumably samples per second, since it is compared against seconds below — confirm)
sample_rate = config[ 'sampleRate' ]
sample_rate

In [None]:
# Split the recording by column:
# first column -> time stamps, last column -> event codes, everything in between -> signal channels
timepoints = raw[ :, 0 ]
events = raw[ :, -1 ]
data = raw[ :, 1:(-1) ]
data.shape

In [None]:
# Sanity check: the duration implied by (number of samples / sample rate) must agree with
# the last time stamp once both are rounded to whole numbers
assert round(raw.shape[ 0 ] / sample_rate) == round(timepoints[ -1 ])
f'Measurement length in seconds is {round(timepoints[ -1 ], 2)}'

In [None]:
# Power spectrum of every channel before any filtering
plt.figure(figsize = (12, 6))
plot_periodogram(data, sample_rate)

In [None]:
# Frequencies to suppress: 60 and its multiples up to 240
# (presumably the mains frequency and its harmonics, in Hz — confirm)
powerline = [ 60, 120, 180, 240 ]

# range = 10 makes each stop band reach from 10 below to 10 above the listed frequency
filtered = remove_powerline(data, powerline, sample_rate, range = 10)
# Afterwards keep only the content below 290
filtered = low_pass(filtered, 290, sample_rate)

In [None]:
# Power spectrum of every channel after band-stop and low-pass filtering
plt.figure(figsize = (12, 6))
plot_periodogram(filtered, sample_rate)

In [None]:
# Overlay the spectra before and after filtering, one colour per dataset
plt.figure(figsize = (12, 6))
plot_periodograms([ data, filtered ], [ 'Before filtering', 'After filtering' ], sample_rate)

In [None]:
# Index of the first sample whose event code equals 2, rounded to the nearest multiple of 1000.
# Raises IndexError if no sample carries event code 2.
time_start = round(int(np.where(events == 2)[ 0 ][ 0 ]), -3)
# The window shown is 5000 samples long
time_end = int(time_start + 5e3)

plt.figure(figsize = (15, 10))
# time_start / time_end are sample indices (plot_channels slices with them), not seconds
plot_channels(filtered, timepoints, time_start, time_end, n_channels = 10, offset = 2)

In [None]:
# FastICA with 20 components; random_state is fixed for reproducibility
transformer = FastICA(n_components = 20, random_state = 42, whiten = 'unit-variance', max_iter = 200)

start_time = datetime.now()

# Fit the ICA model on the filtered channels and estimate the independent sources
# (this does not reconstruct the signal; it produces the component time courses)
sources = transformer.fit_transform(filtered)

end_time = datetime.now()
# Wall-clock time spent on fitting and transforming
print('Duration: {}'.format(end_time - start_time))

In [None]:
# Show all 20 ICA components over the same sample window used for the channel plot
plt.figure(figsize = (15, 10))
plot_channels(sources, timepoints, time_start, time_end, n_channels = 20, offset = 2, y_lab = 'Component')

In [None]:
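# TODO: plot the signal split into frequency bands. The draft below is disabled; it
# refers to `final_data`, which is not defined in any of the cells above.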

# bandwidths = {
#     'delta': (1, 4),
#     'theta': (4, 8),
#     'alpha': (8, 13),
#     'beta': (13, 32),
#     'gamma': (32, 50)
# }
# 
# bandwidth_filters = {
#     name: signal.butter(N = 20, Wn = frange, btype = 'bp', output = 'sos', fs = sample_rate)
#     for name, frange in bandwidths.items()
# }
# 
# fig, axes = plt.subplots(len(bandwidths), figsize = (12, 10), sharex = True, sharey = True)
# 
# for i, (b, filter_i) in enumerate(bandwidth_filters.items()):
#     decomposed = signal.sosfiltfilt(filter_i, final_data[ :, 77 ])
#     axes[ i ].plot(decomposed[ :sample_rate * 3 ])
#     axes[ i ].set_title(f'{b} ({bandwidths[ b ][ 0 ]}-{bandwidths[ b ][ 1 ]} Hz)')
# 
# plt.tight_layout()
# plt.show()