# Pipeline

In [2]:
import mne
import numpy as np
from scipy.signal import find_peaks
import scipy.stats as sp_stats
import matplotlib.pyplot as plt

from NirsLabProject.config.consts import *
from NirsLabProject.config.subject import Subject

from NirsLabProject.utils import *
from NirsLabProject.utils.sleeping_utils import *

sleep_cycle = False
subject = Subject('p487', True)

raw = mne.io.read_raw_fif(subject.paths.subject_resampled_fif_path)
# raw.plot()

    

In [3]:
# spikes = np.load(subject.paths.subject_spikes_path)
from scipy.stats import skew, kurtosis

def find_high_spikes_channels(spikes: np.ndarray, raw: mne.io.Raw, threshold: float = 5):
    high_spikes_channels = {}
    spikes_per_minute = np.zeros_like(raw.ch_names, dtype=float)
    for i, channel_name in enumerate(spikes):
        channel_raw = raw.copy().pick_channels([channel_name])
        channel_spikes = spikes[channel_name]
        number_of_spikes = channel_spikes.shape[0]
        channel_total_time_in_minutes = (channel_raw.tmax - channel_raw.tmin) / (60)
        spikes_per_minute[i] = number_of_spikes / channel_total_time_in_minutes

    zscores = sp_stats.zscore(spikes_per_minute)
    print(zscores)
    print(spikes_per_minute)
    for i, channel_name in enumerate(spikes):
        if zscores[i] > 1:
            print(channel_name)
            
    return high_spikes_channels
            

def find_bad_channels(spikes: np.ndarray, raw: mne.io.Raw):
    n_time_windows = int(raw.tmax-raw.tmin)
    channel_names = list(raw.ch_names)
    bad_channels = np.zeros((len(channel_names), 3))
    print(bad_channels.shape)
    
    area_norms = {}
    area_norms_stds = {}
    area_norms_skew = {}
    area_norms_kurtosis = {}

    total_channel_index = 0
    for i, channel_name in enumerate(channel_names):
        channel_area = channel_name[:-1]
        print(channel_area)
        if channel_area not in area_norms:
            area_norms[channel_area] = []
            area_norms_stds[channel_area] = []
            area_norms_skew[channel_area] = []
            area_norms_kurtosis[channel_area] = []
            
        channel_index = channel_name[-1]
        channel_raw = raw.copy().pick_channels([channel_name])
        channel_data = channel_raw.get_data()[0]
        norms = np.zeros(n_time_windows)
        for window_index in range(n_time_windows):
            norms[window_index] = np.sqrt(np.nansum(channel_data[(window_index-1)*SR:window_index*SR]**2))
        area_norms[channel_area].append(norms)
        area_norms_stds[channel_area].append(np.std(norms))
        area_norms_skew[channel_area].append(skew(norms))
        area_norms_kurtosis[channel_area].append(kurtosis(norms))
        
        if i == len(channel_names)-1 or channel_names[i+1][-1] < channel_index:
            print(total_channel_index)
            for channel_area, channels in area_norms.items():
                area_norms_stds_median = np.median(area_norms_stds[channel_area])
                area_norms_skew_median = np.median(area_norms_skew[channel_area])
                area_norms_kurtosis_median = np.median(area_norms_kurtosis[channel_area])


                for i, _ in enumerate(channels):
                    if area_norms_stds[channel_area][i] > area_norms_stds_median*4:
                        bad_channels[total_channel_index][0] = 2
                    elif area_norms_stds[channel_area][i] > area_norms_stds_median*2:
                          bad_channels[total_channel_index][0] = 1.5
                    elif area_norms_stds[channel_area][i] > area_norms_stds_median*1.5:
                        bad_channels[total_channel_index][0] = 1

                    if area_norms_skew[channel_area][i] > area_norms_skew_median*4:
                        bad_channels[total_channel_index][1] = 2
                    elif area_norms_skew[channel_area][i] > area_norms_skew_median*2:
                        bad_channels[total_channel_index][1] = 1

                    if area_norms_kurtosis[channel_area][i] > area_norms_kurtosis_median*4:
                        bad_channels[total_channel_index][2] = 2
                    elif area_norms_kurtosis[channel_area][i] > area_norms_kurtosis_median*2:
                        bad_channels[total_channel_index][2] = 1
                                            
                    total_channel_index += 1
            area_norms = {}
        
    return bad_channels
    
# channels_names_list = list(spikes.keys())
print(43344)
bd = find_bad_channels([], raw.copy())
print(6666666)

43344
(72, 3)
LA
LA
LA
LA
LA
LA
LA
0
LAH
LAH
LAH
LAH
LAH
LAH
LAH
7
LEC
LEC
LEC
LEC
LEC
LEC
LEC
14
LOF
LOF
LOF
LOF
LOF
LOF
LOF
LOF
21
LPHG
LPHG
LPHG
LPHG
LPHG
LPHG
LPHG
29
LSTG
LSTG
LSTG
LSTG
LSTG
LSTG
LSTG
36
RA
RA
RA
RA
RA
RA
RA
RA
43
RAH
RAH
RAH
RAH
RAH
RAH
RAH
51
REC
REC
REC
REC
REC
REC
REC
58
ROF
ROF
ROF
ROF
ROF
ROF
ROF
65
6666666


In [33]:
for i, line in enumerate(bd):
    if np.sum(bd[i]) > 1:
        print(raw.ch_names[i], bd[i])

ROF3 [1.5 1.  2. ]
RAC1 [1.5 0.  0. ]
LA7 [2. 1. 2.]
LOF7 [0. 0. 2.]
LAC1 [2. 1. 2.]


In [38]:
window_width = 100  # in milliseconds

# Group the timestamps based on the window_width
groups = []
group = [arr[0]]
for i in range(1, arr.shape[0]):
    if group[0][0] + window_width > arr[i][0]:
        group.append(arr[i])
    else:
        groups.append(group)
        group = [arr[i, :]]

import csv
with open('output.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    for group in groups:
        hemisphers = set()
        stractures = set()
        size = len(group)
        fist_timestamp = group[0][0]
        last_timestamp = group[-1][0]
        first = sorted([spike for spike in group if spike[0] == fist_timestamp], key=lambda x: (x[1], x[2]))[-1]
        first = index_to_channel[first[1]]
        timing_diffrence = last_timestamp - fist_timestamp
        for record in group:
            channel = index_to_channel[record[1]]
            hemisphers.add(channel[0])
            stractures.add(channel[:-1])
        if size > 2:
            row = [size, first, (hemisphers), (stractures), timing_diffrence]
            writer.writerow(row)
#             print(f'Grouop size {size} | Focal: {first} | Hemisphers: {hemisphers} | Stractures: {stractures} | Time Difrences: {timing_diffrence}')


FileNotFoundError: [Errno 2] No such file or directory: '~/output.csv'

In [None]:
import os
from typing import Dict
from NirsLabProject import spikes_detection
from NirsLabProject.pipeline import channel_processing

spikes = spikes_detection.detect_spikes_of_subject(subject, raw)
channel_names = spikes.keys()
channels_spikes = Parallel(n_jobs=os.cpu_count(), backend='multiprocessing')(
        delayed(channel_processing)(subject, raw, dict(spikes), channel_name, i) for i, channel_name in enumerate(channel_names)
    )
channels_spikes_features = {channel_name: channel_spikes for channel_name, channel_spikes in zip(channel_names, channels_spikes)}
channel_spikes = {channel_name: channel_spikes[:,0] / SR for channel_name, channel_spikes in channels_spikes_features.items()}

print('Finished')

In [106]:
from NirsLabProject.group_spikes import MinHeap

class Group:
    def __init__(self, group, group_index, index_to_channel):
        self._group = group
        self.index = group_index
        self.size = len(group)
        self.fist_event_timestamp = group[0][TIMESTAMP_INDEX]
        self.last_event_timestamp = group[-1][TIMESTAMP_INDEX]
        self.group_event_duration = self.last_event_timestamp - self.fist_event_timestamp
        self.focal_channnel_index = sorted(
            [spike for spike in group if spike[0] == self.fist_event_timestamp],
            key=lambda x: (x[1], x[2])
        )[-1][1]
        self.focal_channnel_name = index_to_channel[self.focal_channnel_index]

        self.hemisphers = set()
        self.stractures = set()

        for record in group:
            channel = index_to_channel[record[1]]
            self.hemisphers.add(channel[0])
            self.stractures.add(channel[:-1])
            
        def __str__(self):
            return f'Grouop size {size} | Focal: {first} | Hemisphers: {hemisphers} | Stractures: {stractures} | Time Difrences: {timing_diffrence}'


def group_spikes(channels_spikes: Dict[str, np.ndarray]):
    index_to_channel = {}
    for i, channel_name in enumerate(channels_spikes.keys()):
        index_to_channel[i] = channel_name
    
    # Merge all the spikes into one sorted array
    all_spikes = [spikes for spikes in channels_spikes.values() if spikes.shape[0] > 0]
    all_spikes_flat = MinHeap.mergeKSortedArrays(all_spikes, len(all_spikes))
            
    # Group the timestamps based on the window_width
    groups_list = []
    group_index_to_group = {}
    group = [all_spikes_flat[0]]
    for i in range(1, all_spikes_flat.shape[0]):
        if group[0][0] + SPIKES_GROUPING_WINDOW_SIZE > all_spikes_flat[i][0]:
            group.append(all_spikes_flat[i])
        else:
            groups_list.append(group)
            group = [all_spikes_flat[i, :]]
    groups_list.append(group)
    
    spike_index = 0
    all_spikes_group_indexes = np.zeros(all_spikes_flat.shape[0], dtype=int)
    for group_index, group in enumerate(groups_list):
        group = Group(group, group_index, index_to_channel)
        group_index_to_group[group_index] = group
        for i in range(group.size):
            all_spikes_group_indexes[spike_index] = group.index
            spike_index +=1
        
    all_spikes_group_indexes = all_spikes_group_indexes.reshape((-1, 1))
    all_spikes_flat = np.concatenate((all_spikes_flat, all_spikes_group_indexes), axis=1)
    
    return group_index_to_group, all_spikes_flat


    
groups, flat_features = group_spikes(channels_spikes_features)

# print(flat_features)

[[4.00000000e+00 2.20000000e+01 4.91686426e+00 8.00000000e+00
  0.00000000e+00]
 [1.80000000e+01 8.00000000e+00 4.48488629e+00 1.00000000e+01
  0.00000000e+00]
 [2.00000000e+01 5.00000000e+00 4.43637427e+00 8.00000000e+00
  0.00000000e+00]
 ...
 [2.95827000e+07 5.60000000e+01 1.46590371e+01 7.00000000e+00
  2.90480000e+04]
 [2.95827020e+07 5.00000000e+00 2.20556207e+01 6.00000000e+00
  2.90480000e+04]
 [2.95827020e+07 2.20000000e+01 2.34305851e+01 6.00000000e+00
  2.90480000e+04]]
