In [1]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt


In [2]:
import os

import numpy as np
# import loader as L


from scipy.interpolate import griddata, interp1d

from scipy.ndimage import rotate

from functools import partial
from math import sin, cos, pi

In [3]:
import logging
import sys
from braindecode.datasets.bbci import  BBCIDataset
from collections import OrderedDict
from braindecode.datautil.trial_segment import \
    create_signal_target_from_raw_mne
from braindecode.mne_ext.signalproc import mne_apply, resample_cnt
from braindecode.datautil.signalproc import highpass_cnt
from braindecode.datautil.signalproc import exponential_running_standardize

import torch

# Notebook-level logger; its level is set to DEBUG so no message is filtered here.
log = logging.getLogger(__name__)
log.setLevel('DEBUG')

# Configure the root logger: DEBUG and above, written to stdout (so the
# records appear in the cell output), prefixed with timestamp and level.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                 level=logging.DEBUG, stream=sys.stdout)

In [4]:

## Relative 2D locations of the 22 channels, one row per channel
## (the channel name is given as a comment to the right).
## The input channel data is expected to be in exactly this order.
## Stored as an np-array straight away, because further manipulation
## happens in combination with other arrays.
locs = np.array([
    [ 35.5,   -0.1 ],    # Fz
    [ 12.2,   56.16],    # FC3
    [  9.9,   27.6 ],    # FC1
    [  9.4,   -0.1 ],    # FCz
    [ 10.5,  -28.32],    # FC2
    [ 13.0,  -57.0 ],    # FC4
    [-15.1,   87.48],    # C5
    [-16.2,   57.48],    # C3
    [-16.5,   28.44],    # C1
    [-16.8,   -0.1 ],    # Cz
    [-16.3,  -29.64],    # C2
    [-15.5,  -58.68],    # C4
    [-14.1,  -87.6 ],    # C6
    [-40.9,   53.16],    # CP3
    [-39.3,   26.4 ],    # CP1
    [-39.3,   -0.12],    # CPz
    [-38.9,  -27.84],    # CP2
    [-40.2,  -54.24],    # CP4
    [-65.5,   22.08],    # P1
    [-63.8,   -0.12],    # Pz
    [-65.1,  -23.28],    # P2
    [-79.7,   -0.1 ],    # POz
])

# Channel names, index-aligned with the rows of 'locs'
# (grouped here by electrode row of the cap).
channel_labels = [
    'Fz',
    'FC3', 'FC1', 'FCz', 'FC2', 'FC4',
    'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6',
    'CP3', 'CP1', 'CPz', 'CP2', 'CP4',
    'P1', 'Pz', 'P2',
    'POz',
]


In [5]:
def rotate2D(point: list) -> list:
    """Rotate a 2D point by 45° around the origin.

    The point is multiplied with the classic rotation matrix
        [cos(ϕ) -sin(ϕ)]
        [sin(ϕ)  cos(ϕ)]
    for ϕ = 45°.

    Args:
        point (list): The point to rotate. Any indexable object works;
            only the elements at index 0 and 1 are read.

    Returns:
        list: Two elements, the rotated coordinates [x', y'].

    Raises:
        IndexError: If 'point' has fewer than two elements.
        TypeError: If 'point' cannot be indexed.
    """
    # 45° expressed in radians.
    angle = (45 * 2 * pi) / 360
    cos_a, sin_a = cos(angle), sin(angle)
    x, y = point[0], point[1]
    return [x * cos_a - y * sin_a, x * sin_a + y * cos_a]

In [6]:
# def gen_images(locs, features, n_gridpoints, normalize=True,
#                augment=False, pca=False, std_mult=0.1, n_components=2, edgeless=False):
# needs defined: locs, features, n_gridpoints

# earlier interpolation function. Not as 'general', but supports rotation.

def interpolate(locs, features, n_gridpoints=32, rotate=True, edge_channels=[0, 6, 12, 21]):
    """
    Generates EEG images given electrode locations in 2D space and multiple feature values for each electrode

    Args:
        locs: Array-like with shape [n_electrodes, 2]
            containing X, Y coordinates for each electrode.
        features: Feature matrix as [n_samples, n_features]
            Features are as columns. Features corresponding
            to each frequency band are concatenated.
            (alpha1, alpha2, ..., beta1, beta2,...)
            n_features must be a multiple of n_electrodes.
        n_gridpoints: Number of pixels per image side.
            Default=32
        rotate: If the whole layout should be centred, scaled to a
            square and rotated by 45 degrees before interpolating.
            Default=True
            (Inside this function the name shadows the file-level
            'scipy.ndimage.rotate' import.)
        edge_channels: Indices into 'locs' of the channels with, in this
            order, the highest x, highest y, lowest y and lowest x
            coordinate. Only relevant when rotating.
            (Assuming that Fz is on the x axis to the right.)
            The default is never modified in here.

    Returns:
        NumPy array of shape [samples, colors, W, H] containing the
        generated images. Pixels outside the convex hull of the
        electrodes are NaN.

    Raises:
        ValueError: If rotating and the selected edge channels span no
            extent along x.
        AssertionError: If n_features is not a multiple of n_electrodes.
    """
    locs = np.asarray(locs, dtype=float)

    if rotate:
        high_x = locs[edge_channels[0]][0]
        high_y = locs[edge_channels[1]][1]
        low_y  = locs[edge_channels[2]][1]
        low_x  = locs[edge_channels[3]][0]

        # Extent of the layout along each axis. Using the difference (not
        # abs(high) + abs(low)) is also right when both edges share a sign.
        extent_x = abs(high_x - low_x)
        extent_y = abs(high_y - low_y)
        if extent_x == 0:
            raise ValueError("edge_channels span no extent along x; cannot scale the layout")

        # Stretch x so that the layout becomes square.
        factor = extent_y / extent_x

        # centering:
        xdiff = (high_x + low_x) / 2
        ydiff = (high_y + low_y) / 2

        centred = np.column_stack(((locs[:, 0] - xdiff) * factor,
                                   locs[:, 1] - ydiff))

        locs = np.array([rotate2D(point) for point in centred])

    # Drop a trailing third axis, if there is one.
    cut = lambda a: a if len(a.shape) < 3 else a[:, :, 0]

    nElectrodes = locs.shape[0]     # Number of electrodes
    # Test whether the feature vector length is divisible by number of electrodes
    assert features.shape[1] % nElectrodes == 0, \
        "number of features must be a multiple of the number of electrodes"
    n_colors = features.shape[1] // nElectrodes
    nSamples = features.shape[0]

    # One [n_samples, n_electrodes] slab per color / frequency band.
    feat_array_temp = [cut(features[:, c * nElectrodes : nElectrodes * (c + 1)])
                       for c in range(n_colors)]

    # Regular pixel grid spanning the bounding box of the electrodes.
    grid_x, grid_y = np.mgrid[
                     locs[:, 0].min():locs[:, 0].max():n_gridpoints*1j,
                     locs[:, 1].min():locs[:, 1].max():n_gridpoints*1j
                     ]

    temp_interp = [np.zeros([nSamples, n_gridpoints, n_gridpoints])
                   for _ in range(n_colors)]

    # Interpolating
    for i in range(nSamples):
        for c in range(n_colors):
            temp_interp[c][i, :, :] = cut(griddata(locs, feat_array_temp[c][i, :], (grid_x, grid_y),
                                    method='cubic', fill_value=np.nan))
        print('Interpolating {0}/{1}'.format(i+1, nSamples), end='\r')

    images = np.swapaxes(np.asarray(temp_interp), 0, 1)     # swap axes to have [samples, colors, W, H]
    return images

In [7]:
# This is a relatively ordered list mapping each channel to its
# location in a 19x11 grid (19 rows, 11 columns); an empty string marks
# a grid cell without a channel.
# Various other channels exist, so note that this list is not complete.
#
# Labels are looked up by exact, case-sensitive string comparison, so
# they must be spelled exactly like the channel names in the data
# (e.g. 'Fpz' and 'AFp4h', not 'FPz' / 'Afp4h').
tight_cap_positions = [
    ['', '', '', '', 'Fp1', 'Fpz', 'Fp2', '', '', '', ''],
    ['', '', '', 'AFp3h', '', '', '', 'AFp4h', '', '', ''],
    ['', 'AF7', 'AF5', 'AF3', 'AF1', 'AFz', 'AF2', 'AF4', 'AF6', 'AF8', ''],
    ['', '', 'AFF5h', '', 'AFF1', '', 'AFF2', '', 'AFF6h', '', ''],
    ['', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', ''],
    ['FFT9h', 'FFT7h', 'FFC5h', 'FFC3h', 'FFC1h', '', 'FFC2h', 'FFC4h', 'FFC6h', 'FFT8h', 'FFT10h'],
    ['FT9', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'FT10'],
    ['FTT9h', 'FTT7h', 'FCC5h', 'FCC3h', 'FCC1h', '', 'FCC2h', 'FCC4h', 'FCC6h', 'FTT8h', 'FTT10h'],
    ['M1', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8', 'M2'],
    ['', 'TTP7h', 'CCP5h', 'CCP3h', 'CCP1h', '', 'CCP2h', 'CCP4h', 'CCP6h', 'TTP8h', ''],
    ['TP9', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'TP10'],
    ['TPP9h', 'TPP7h', 'CPP5h', 'CPP3h', 'CPP1h', '', 'CPP2h', 'CPP4h', 'CPP6h', 'TPP8h', 'TPP10h'],
    ['P9', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'P10'],
    ['PPO9h', '', 'PPO5h', '', 'PPO1', '', 'PPO2', '', 'PPO6h', '', 'PPO10h'],
    ['PO9', 'PO7', 'PO5', 'PO3', 'PO1', 'POz', 'PO2', 'PO4', 'PO6', 'PO8', 'PO10'],
    ['POO9h', '', '', 'POO3h', '', '', '', 'POO4h', '', '', 'POO10h'],
    ['', '', '', '', 'O1', 'Oz', 'O2', '', '', '', ''],
    ['', '', '', '', 'OI1h', '', 'OI2h', '', '', '', ''],
    ['', '', '', '', 'I1', 'Iz', 'I2', '', '', '', '']]



In [8]:
def mapper(locs):
    """ Higher-order function: builds and returns a function that places
    a flat list of per-channel values onto the 19x11 electrode grid
    described by 'tight_cap_positions'.

    'locs' names the channel behind each index of the value list. For
    every grid cell whose label occurs in 'locs', the returned function
    writes the value found at that label's (first) index; every other
    cell is NaN.

    Mapper Creation: O(1)
    Actual Mapping: O(11x19x|locs|)

    Full function signature:
        mapper :: [String] -> ([Num] -> np.[[Num]])

    Args:
        locs (list): Channel names, in the order of the values that
            will be mapped. Must support 'in' and '.index()'.

    Returns:
        func: Maps a sequence of values (indexable like 'locs') to a
            (19, 11) float array.
    """
    def actual(values):
        # Start from an all-NaN grid and fill in only the known channels.
        grid = np.full((19, 11), np.nan)
        for row_idx, labels in enumerate(tight_cap_positions):
            for col_idx, label in enumerate(labels):
                if label not in locs:
                    continue
                grid[row_idx, col_idx] = values[locs.index(label)]
        return grid
    return actual

In [9]:
def mapper2(locs):
    """ This is a higher-order linear function. Functionality
    is the same as 'mapper', as well as type signature.
    It is supposed to be faster than mapper, and requires the
    original mapper for map creation.

    The channel-to-grid assignment is computed once with 'mapper' and
    stored as a 0/1 selection matrix; the returned function then only
    performs a matrix-vector product followed by a reshape.

    Supposedly faster, especially for larger |locs| values.

    Also, mapping does not return 'NaN's if there is no result,
    but f64 zeros. Callers that detect channel cells via np.isnan
    (as 'interpolate2' does) will therefore see every grid cell as
    occupied.

    The matrix is sized from 'locs' and from the grid returned by
    'mapper', so channel lists of any length are supported.

    Mapper Creation: O(|grid| x |locs|)
    Actual Mapping: O(|grid| x |locs|)

    Args:
        locs (list): Channel names, in the order of the values that
            will be mapped.

    Returns:
        func: Maps a vector with len(locs) values to a float array with
            the shape of the electrode grid; cells without a channel
            are 0.
    """
    # Map every channel to its own index to learn which grid cell it lands in.
    index_grid = mapper(locs)(list(range(len(locs))))
    grid_shape = index_grid.shape
    flat_index = index_grid.reshape(-1)

    # Row = grid cell, column = channel; a 1 selects that channel for the cell.
    matrix = np.zeros((flat_index.size, len(locs)))
    occupied = np.flatnonzero(~np.isnan(flat_index))
    matrix[occupied, flat_index[occupied].astype(int)] = 1
    return lambda v: matrix.dot(v).reshape(grid_shape)

In [10]:
# All 129 channels of the High-Gamma Dataset, in the order in which
# they appear in the data (the last entry, "STI 014", has no cell in
# the electrode grid).
positions = [
    "Fp1", "Fp2", "Fpz", "F7", "F3", "Fz", "F4", "F8", "FC5",
    "FC1", "FC2", "FC6", "M1", "T7", "C3", "Cz", "C4", "T8", "M2",
    "CP5", "CP1", "CP2", "CP6", "P7", "P3", "Pz", "P4", "P8", "POz",
    "O1", "Oz", "O2", "AF7", "AF3", "AF4", "AF8", "F5", "F1",
    "F2", "F6", "FC3", "FCz", "FC4", "C5", "C1", "C2", "C6", "CP3",
    "CPz", "CP4", "P5", "P1", "P2", "P6", "PO5", "PO3", "PO4", "PO6",
    "FT7", "FT8", "TP7", "TP8", "PO7", "PO8", "FT9", "FT10", "TPP9h",
    "TPP10h", "PO9", "PO10", "P9", "P10", "AFF1", "AFz", "AFF2", "FFC5h",
    "FFC3h", "FFC4h", "FFC6h", "FCC5h", "FCC3h", "FCC4h", "FCC6h",
    "CCP5h", "CCP3h", "CCP4h", "CCP6h", "CPP5h", "CPP3h", "CPP4h",
    "CPP6h", "PPO1", "PPO2", "I1", "Iz", "I2", "AFp3h", "AFp4h", "AFF5h",
    "AFF6h", "FFT7h", "FFC1h", "FFC2h", "FFT8h", "FTT9h", "FTT7h",
    "FCC1h", "FCC2h", "FTT8h", "FTT10h", "TTP7h", "CCP1h", "CCP2h",
    "TTP8h", "TPP7h", "CPP1h", "CPP2h", "TPP8h", "PPO9h", "PPO5h",
    "PPO6h", "PPO10h", "POO9h", "POO3h", "POO4h", "POO10h", "OI1h",
    "OI2h", "STI 014",
]


In [11]:
def load_bbci_data(filename, low_cut_hz):
    """ Provided loader function: reads one BBCI file and returns its
    cleaned, preprocessed trials.

    Steps, in order:
      1. load the continuous recording with all sensors,
      2. cut trials on the raw signal and flag those whose absolute
         peak value stays below 800,
      3. resample to 50 Hz, high-pass at 'low_cut_hz' and apply an
         exponential running standardization,
      4. cut trials again on the preprocessed signal and keep only the
         ones flagged in step 2.

    Args:
        filename: Path of the BBCI file, handed to BBCIDataset.
        low_cut_hz: Cut-off frequency handed to the high-pass filter.

    Returns:
        The signal/target set from create_signal_target_from_raw_mne,
        with X and y restricted to the clean trials.
    """
    # No explicit sensor selection: all sensors are loaded, to always get
    # the same cleaning results independent of sensor selection.
    # There is an inbuilt heuristic that tries to use only EEG channels and
    # that definitely works for the datasets of the paper.
    bbci_file = BBCIDataset(filename, load_sensor_names=None)

    log.info("Loading data...")
    cnt = bbci_file.load()

    # Cleaning: first find all trials that have absolute values larger
    # than +- 800 (microvolt, per the original note) inside them and
    # remember them for removal later.
    log.info("Cutting trials...")

    marker_def = OrderedDict([
        ('Right Hand', [1]),
        ('Left Hand', [2]),
        ('Rest', [3]),
        ('Feet', [4]),
    ])

    raw_trials = create_signal_target_from_raw_mne(cnt, marker_def, [0, 4000])
    peak_per_trial = np.max(np.abs(raw_trials.X), axis=(1, 2))
    clean_trial_mask = peak_per_trial < 800

    n_clean = np.sum(clean_trial_mask)
    n_total = len(raw_trials.X)
    clean_percent = np.mean(clean_trial_mask) * 100
    log.info("Clean trials: {:3d}  of {:3d} ({:5.1f}%)".format(
        n_clean, n_total, clean_percent))

    # Further preprocessing as described in the paper.

    log.info("Resampling...")
    # The original frequency is actually 500Hz,
    # so we can resample to everything below that.
    cnt = resample_cnt(cnt, 50.0)

    log.info("Highpassing...")

    def _highpass(a):
        # Reads the sampling rate from 'cnt' at call time, i.e. from the
        # resampled recording.
        return highpass_cnt(a, low_cut_hz, cnt.info['sfreq'],
                            filt_order=3, axis=1)

    cnt = mne_apply(_highpass, cnt)

    log.info("Standardizing...")

    def _standardize(a):
        # The standardizer receives the transposed array; transpose back after.
        standardized = exponential_running_standardize(
            a.T, factor_new=1e-3, init_block_size=1000, eps=1e-4)
        return standardized.T

    cnt = mne_apply(_standardize, cnt)

    # Trial interval, start at -500 already, since improved decoding for networks
    dataset = create_signal_target_from_raw_mne(cnt, marker_def, [-500, 4000])
    dataset.X = dataset.X[clean_trial_mask]
    dataset.y = dataset.y[clean_trial_mask]
    return dataset

In [12]:
class LazyHGDLoader:
    """ LazyLoader of High-Gamma BBCI Dataset.

    Nothing is read on construction. The 'train' and 'test' properties
    load their split through 'load_bbci_data' on first access and keep
    the result on the instance ('train_set' / 'test_set'), so every
    split is loaded at most once per loader.
    """

    # adjust FILENAME template as needed.
    # Filled with (split name, dataset number).
    # NOTE(review): absolute, machine-specific path.
    FILENAME = "/datadisk/Coding/high-gamma-dataset/data/%s/%i.mat"
    dataset_number = None
    low_cut_hz = 4

    def __init__(self, number: int, low_cut_hz: int = 4):
        """
        Args:
            number (int): The number of the Dataset to be loaded.
            low_cut_hz (int): The lower threshold for cutting frequencies.
                Default=4
        """
        self.dataset_number = number
        self.low_cut_hz = low_cut_hz

    def _load_split(self, split: str):
        """ Load the 'train' or 'test' file of this dataset number. """
        filename = self.FILENAME % (split, self.dataset_number)
        return load_bbci_data(filename=filename, low_cut_hz=self.low_cut_hz)

    @property
    def train(self):
        """ The training set; loaded on first access, cached afterwards. """
        if not hasattr(self, 'train_set'):
            self.train_set = self._load_split('train')
        return self.train_set

    @property
    def test(self):
        """ The test set; loaded on first access, cached afterwards. """
        if not hasattr(self, 'test_set'):
            self.test_set = self._load_split('test')
        return self.test_set

In [13]:
def interpolate2(features, ch_names=None, ch_mapper=None, grid_x=11, grid_y=19, method='linear'):
    """ More recent interpolation function. Also, much faster. Does not support Rotation, though.

    Every (trial, frame) vector of channel values is placed on the
    electrode grid by 'ch_mapper' and then interpolated onto the full
    grid_y x grid_x pixel grid with scipy's griddata.

    The grid cells that carry a channel are taken to be the non-NaN
    cells of the mapped first frame of the first trial.

    NOTE(review): a mapper that fills empty cells with zeros instead of
    NaN (such as 'mapper2') makes every cell count as a data point, so
    the output is just the zero-padded grid and nothing is actually
    interpolated. Use a NaN-filling mapper (such as 'mapper') if real
    interpolation between electrodes is wanted - confirm the intent.

    Args:
        features (array-like): the values between which will be interpolated,
            indexed as [trial, channel, frame].
        ch_names ([str]): Optional list with the order of channels from 'features'.
            Only used to build a 'mapper2' when no 'ch_mapper' is given.
        ch_mapper (mapper): Optional Channel to Location mapper.
        grid_x (int): Number of grid columns of the output.
            Default=11
        grid_y (int): Number of grid rows of the output.
            Default=19
        method (str): Interpolation method handed to griddata.
            Default='linear'

    Returns:
        NumPy array of shape [trials, grid_y * grid_x, frames]; each
        interpolated grid is flattened row by row.
    """
    assert ch_mapper or ch_names, "needs either channel names or a mapper instance!"
    if not ch_mapper:
        ch_mapper = mapper2(ch_names)

    n_trials, n_frames = features.shape[0], features.shape[2]
    result = np.empty((n_trials, grid_y * grid_x, n_frames))
    if n_trials == 0 or n_frames == 0:
        # Nothing to interpolate (and no first frame to probe).
        return result

    # Grid coordinates of the cells that carry a channel.
    locs = list(map(tuple, np.argwhere(~np.isnan(ch_mapper(features[0, :, 0])))))

    # The target pixel grid is the same for every frame, so build it once.
    target_grid = tuple(np.mgrid[0:grid_y, 0:grid_x])

    for trial in range(n_trials):
        for frame in range(n_frames):
            r = ch_mapper(features[trial, :, frame])
            res = griddata(locs, r[~np.isnan(r)], target_grid, method=method)
            result[trial, :, frame] = res.reshape(grid_y * grid_x)

        # Report finished trials, so that the last message reads N/N.
        print("Interpolating: %s/%s" % (trial + 1, n_trials), end='\r')

    return result

In [14]:
def one_hot_labels(label, classes=4):
    """ Classic One-Hot encoding for the labels.

    Row i of the result is all zeros except for a single 1 in column
    label[i].

    Args:
        label (array-like): The actual labels,
            Numbers from 0 to classes-1. Must provide '.shape'.
        classes (int): The number of distinct categories to encode.
            Default=4

    Returns:
        PyTorch Tensor of shape [len(label), classes] containing
        One-Hot encodings of the labels.
    """
    n_samples = label.shape[0]
    encoded = torch.zeros(n_samples, classes)
    # Pair every row index with its label to address one cell per row.
    rows = torch.arange(n_samples)
    encoded[rows, label] = 1
    return encoded

In [15]:
def try_load_compute_save(filename: str, func: callable):
    """ Small helper function. Try to load the given
    File if the filename exists, and otherwise compute and save it.

    The result is pickled to a temporary file next to 'filename' and
    only moved into place once it has been written completely. A failed
    or interrupted save therefore never leaves a truncated cache file
    behind that would break every later load.

    Args:
        filename (str): Name of the file that should be loaded.
        func (callable): Function to compute and return the data,
            if it was not found as a file with 'filename'.

    Returns:
        The content of the loading-, or computing-operation.

    Raises:
        Whatever 'func' or pickling raises; in that case no cache file
        is created.
    """
    import os
    import pickle

    try:
        with open(filename, 'rb') as f:
            # NOTE: unpickling can execute arbitrary code - only point
            # this at cache files you have written yourself.
            data = pickle.load(f)
        log.info("Succeeded loading " + filename + ".")
        return data
    except FileNotFoundError:
        log.info("Computing " + filename + ". This might take some time.")

    data = func()

    tmp_filename = filename + '.tmp'
    try:
        with open(tmp_filename, 'wb') as f:
            pickle.dump(data, f)
        # Replaces the target in one step, so readers never see a partial file.
        os.replace(tmp_filename, filename)
    except BaseException:
        # Do not leave the half-written temporary file around.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise
    return data


def tlcs(fn, fu):
    """ Short alias for 'try_load_compute_save(fn, fu)'. """
    return try_load_compute_save(fn, fu)

In [16]:
# This is the main functionality.
#
# Basically: load the interpolated images and the labels from their
# cache files if those exist; otherwise compute them from the dataset
# (which is then loaded and preprocessed from the raw data) and write
# the cache files for the next run.

# Loader for the dataset with number 2, and the channel-to-grid mapper
# for the full channel list of the dataset.
hgd = LazyHGDLoader(2)
r = mapper2(positions)

log.info("Loading interpolated image data.")

# 'hgd.train' is only touched inside the lambdas, and 'tlcs' calls a
# lambda only when its cache file is missing - so with all caches
# present the raw data is never read.
# The loader is meant to load each split at most once, even when both
# the images and the labels have to be recomputed.
images = tlcs('images_hgd', lambda: interpolate2(hgd.train.X, ch_mapper=r))
labels = tlcs('labels_hgd', lambda: one_hot_labels(hgd.train.y))

# Same for the test split.
images_test = tlcs('images_test_hgd', lambda: interpolate2(hgd.test.X, ch_mapper=r))
labels_test = tlcs('labels_test_hgd', lambda: one_hot_labels(hgd.test.y))


log.info("Successfully loaded interpolated images")

2019-11-10 20:41:35,821 INFO : Loading interpolated image data.
2019-11-10 20:41:36,148 INFO : Succeeded loading images_hgd.
2019-11-10 20:41:36,150 INFO : Succeeded loading labels_hgd.
2019-11-10 20:41:36,227 INFO : Succeeded loading images_test_hgd.
2019-11-10 20:41:36,228 INFO : Succeeded loading labels_test_hgd.
2019-11-10 20:41:36,229 INFO : Successfully loaded interpolated images
