In [None]:
# Importing Python and external packages
import os
import sys
import importlib
import json
import csv
from dataclasses import dataclass, field, fields
from collections import namedtuple
from typing import Any
from itertools import compress
from pathlib import Path
import pandas as pd
import numpy as np
import sklearn as sk
import scipy
from scipy.stats import pearsonr, mannwhitneyu

import matplotlib.pyplot as plt
from  matplotlib import __version__ as plt_version
from scipy import signal, stats
# from array import array
# import datetime as dt
# #mne
# import mne_bids
# import mne


In [None]:
# check some package versions for documentation and reproducibility
print('Python sys', sys.version)
print('pandas', pd.__version__)
print('numpy', np.__version__)
# mne / mne_bids are not imported (see the commented imports in the first cell);
# re-enable these prints together with those imports
# print('mne_bids', mne_bids.__version__)
# print('mne', mne.__version__)
print('sci-py', scipy.__version__)
print('sci-kit learn', sk.__version__)
print('matplotlib', plt_version)
## FEB 2022:
# Python sys 3.9.7 (default, Sep 16 2021, 08:50:36) 
# [Clang 10.0.0 ]
# pandas 1.3.4
# numpy 1.20.3
# mne_bids 0.9
# mne 0.24.1
# sci-py 1.7.1
# sci-kit learn 1.0.1

In [None]:
def get_project_path_in_notebook(
    subfolder: str = '',
):
    """
    Find the path of the project folder ('dyskinesia_neurophys') from a notebook.

    Walks up from the current working directory until a folder whose path
    ends with 'dyskinesia_neurophys' is reached. Run this once at the start
    so the other project modules/functions can be located.

    Args:
        subfolder: currently unused; kept so existing calls stay valid.

    Returns:
        Path of the project folder (str).

    Raises:
        FileNotFoundError: if the filesystem root is reached without finding
            the project folder (instead of looping forever).
    """
    project_name = 'dyskinesia_neurophys'
    path = os.getcwd()

    while not path.endswith(project_name):
        parent = os.path.dirname(path)
        # dirname() of the filesystem root returns the root itself
        if parent == path:
            raise FileNotFoundError(
                f'no folder "{project_name}" found above {os.getcwd()}'
            )
        path = parent

    return path

In [None]:
# define local storage directories
# define local storage directories, all located directly in the project root
projectpath = get_project_path_in_notebook()
codepath, figpath, datapath = (
    os.path.join(projectpath, folder_name)
    for folder_name in ('code', 'figures', 'data')
)

In [None]:
# Move into the project's code folder before importing the project packages.
# NOTE(review): presumably required so `utils` and the `lfpecog_*` packages below
# resolve from the working directory — confirm; adding codepath to sys.path would
# achieve this without changing the cwd for the rest of the notebook.
os.chdir(codepath)
# own utility functions
import utils.utils_fileManagement as utilsFiles
import utils.utils_windowing as utilsWindows
from utils.utils_fileManagement import (get_project_path,
                                        load_class_pickle,
                                        save_class_pickle,
                                        mergedData,
                                        correct_acc_class)
# own data preprocessing functions
import lfpecog_preproc.preproc_data_management as dataMng
# same module is imported again as `filters` in the TMSi-loading cell
import lfpecog_preproc.preproc_filters as fltrs
# own data exploration functions
import lfpecog_features.feats_read_proc_data as read_data
import lfpecog_features.feats_spectral_baseline as specBase
import lfpecog_features.feats_spectral_features as spectral
import lfpecog_features.feats_spectral_helpers as specHelp
import lfpecog_features.feats_helper_funcs as ftHelp

# clinical scores / annotations and previously computed ephys results
import lfpecog_preproc.preproc_import_scores_annotations as importClin
import lfpecog_analysis.import_ephys_results as importResults
# import lfpecog_analysis.get_acc_derivs as accDerivs


# own plotting helpers
import lfpecog_plotting.plotHelpers as plotHelp

### Load TMSi File

In [None]:
import lfpecog_preproc.tmsi_poly5reader as tmsiReader
import lfpecog_preproc.preproc_load_raw as loadRaw
import lfpecog_preproc.preproc_get_mne_data as loadData
import lfpecog_preproc.preproc_resample as resample
import lfpecog_preproc.preproc_filters as filters

In [None]:
def get_download_folder(user='habetsj'):
    """
    Return the 'Downloads' folder inside the home folder of `user`.

    Walks up from the current working directory until a folder named exactly
    `user` is reached, then appends 'Downloads'. Existence of the returned
    folder is not checked.

    Args:
        user: name of the (home) folder to look for among the cwd's ancestors.

    Returns:
        Path to <user folder>/Downloads (str).

    Raises:
        FileNotFoundError: if no ancestor folder is named `user`.
    """
    path = os.getcwd()
    # compare whole folder names, so e.g. 'xhabetsj' does not match 'habetsj'
    while os.path.basename(path) != user:
        parent = os.path.dirname(path)
        # dirname() of the filesystem root returns the root itself
        if parent == path:
            raise FileNotFoundError(
                f'no folder "{user}" found above {os.getcwd()}'
            )
        path = parent

    return os.path.join(path, 'Downloads')

In [None]:
# root folder of the entrainment data, next to the OneDrive folder
entrain_path = os.path.join(os.path.dirname(utilsFiles.get_onedrive_path(folder='onedrive')),
                            'Dokumente', 'data', 'entrain')

sub = 'test02'
# select ONE recording; the alternative is kept commented out for reference
# rec_name = '578EA61_ELMedOn2_Task_StimOffDopa05_1 - 20231205T102635'  # REST OFF
rec_name = '578EA61_ELMedOn2_StimOn130_Dopa52_1 - 20231205T111224'  # STIM 130 ON

rec_dir = os.path.join(entrain_path, 'source_data', sub, rec_name)
# sorted() makes the choice deterministic if the folder holds several Poly5 files
poly5_files = sorted(fname for fname in os.listdir(rec_dir)
                     if fname.endswith('Poly5'))
if not poly5_files:
    # fail here with a clear message instead of a NameError on `raw` later
    raise FileNotFoundError(f'no Poly5 file found in {rec_dir}')

raw = tmsiReader.Poly5Reader(os.path.join(rec_dir, poly5_files[0]))

In [None]:
# filename = '514IL50_M0S0_Rest_test.Poly5'

# path = os.path.join(get_download_folder(), filename)
# print(os.path.exists(path))
# raw = tmsiReader.Poly5Reader(path)



In [None]:
@dataclass(init=True,)
class ephys_data:
    """
    Container for one recording, split into per-electrode preprocessed data.

    On construction the recording edges are trimmed, the signal is divided
    by 1e6, a time axis is built, and one `electrode_data` attribute is
    created per key of `ch_coding` (e.g. `dat.STN_L`).

    Attributes:
        sub: subject identifier.
        data: 2D array, channels x samples (sliced as data[:, samples]).
        ch_names: channel names, one per row of `data`.
        sfreq: sampling frequency in Hz.
        ch_coding: electrode attribute name -> substrings that must ALL occur
            in a channel name for that channel to belong to the electrode.
        remove_start_sec / remove_end_sec: seconds cut from the start / end
            of the recording (artefacts); 0 keeps that edge.

    Set in __post_init__: `data` (trimmed, rescaled copy), `run_duration`
    (sec), `times` (sec, one value per sample), one electrode attribute per
    `ch_coding` key.

    Raises:
        ValueError: if the trimmed recording would contain no samples.
    """
    sub: str
    data: Any
    ch_names: list
    sfreq: int
    ch_coding: dict = field(
        default_factory= lambda: {
            'STN_L': ['STN', 'L'],
            'STN_R': ['STN', 'R'],
            'ECOG': ['ECX']
        }
    )
    remove_start_sec: int = 5
    remove_end_sec: int = 5
    
    def __post_init__(self,):
        data = np.asarray(self.data)
        n_samples = data.shape[1]
        # seconds -> sample counts; int(round()) also handles a float sfreq
        n_start = int(round(self.remove_start_sec * self.sfreq))
        n_end = int(round(self.remove_end_sec * self.sfreq))
        if n_start + n_end >= n_samples:
            raise ValueError(
                f'removing {self.remove_start_sec} + {self.remove_end_sec} sec '
                f'leaves no samples ({n_samples} samples at {self.sfreq} Hz)'
            )
        # remove start and end seconds for artefacts; an explicit stop index
        # is used because a slice ending at -0 would be empty
        trimmed = data[:, n_start:n_samples - n_end]
        # rescale (original note: "correct unit to microVolt").
        # NOTE(review): dividing by 1e6 turns µV into V, not V into µV —
        # confirm the unit delivered by the reader.
        # Out-of-place on purpose: `/=` on the slice would write through the
        # view into the caller's array (e.g. raw.samples) on every re-run.
        self.data = trimmed / 1e6
        # get timestamps: exactly one per remaining sample
        self.run_duration = (1 / self.sfreq) * self.data.shape[1]
        self.times = np.arange(self.data.shape[1]) / self.sfreq

        for electrode in self.ch_coding.keys():

            print(f'select {electrode}')
            setattr(self,
                    electrode,
                    electrode_data(electrode_type=electrode,
                                   main_class=self))



In [None]:
@dataclass(init=True,)
class electrode_data:
    """
    Preprocessed signals of one electrode group of an `ephys_data` recording.

    Selects the channels of `main_class` whose names contain ALL substrings
    listed in `main_class.ch_coding[electrode_type]`, then band-pass filters,
    notch filters and resamples them to `sfreq_new`.

    Attributes:
        electrode_type: key of `main_class.ch_coding` to select.
        main_class: parent object providing `data`, `ch_names`, `ch_coding`,
            `sfreq` and `times`.
        sfreq_new: sampling rate (Hz) after resampling.
        l_freq, h_freq: band-pass edges in Hz. NOTE(review): with the defaults
            h_freq (500) stays below the new Nyquist frequency (512); keep
            that relation when changing sfreq_new.

    Set in __post_init__: `data` (selected channels x samples), `ch_names`,
    `sfreq` (equal to sfreq_new after resampling) and `times` (sec).
    """
    electrode_type: str
    main_class: Any
    sfreq_new: int = 1024
    l_freq: float = 2
    h_freq: float = 500

    def __post_init__(self,):
        assert (
            self.electrode_type in self.main_class.ch_coding.keys()
        ), 'incorrect electrode type given'

        self.sfreq = self.main_class.sfreq
        self.times = self.main_class.times.copy()
        # select data: channels whose name contains every coding substring
        sel_strs = self.main_class.ch_coding[self.electrode_type]
        elec_sel = [all(s in c for s in sel_strs)
                    for c in self.main_class.ch_names]
        self.data = self.main_class.data[elec_sel, :]
        self.ch_names = list(compress(self.main_class.ch_names, elec_sel))

        # BandPass-Filtering
        self.data = filters.bp_filter(data=self.data,
                                      n_timeRows=0,
                                      Fs=self.sfreq,
                                      l_freq=self.l_freq,
                                      h_freq=self.h_freq,
                                      method='iir',)

        # Notch-Filtering
        self.data = filters.notch_filter(data=self.data,
                                        n_timeRows=0,
                                        Fs=self.sfreq,
                                        transBW=10,
                                        notchW=4,
                                        method='fir',
                                        verbose=False,)

        # Resampling
        self.data = resample.resample(
            data=self.data,
            n_timeRows=0,
            Fs_orig=self.sfreq,
            Fs_new=self.sfreq_new,
        )
        # time axis with exactly 1 / sfreq_new spacing, starting at the first
        # original sample, so `times` stays consistent with `sfreq`
        self.times = (self.times[0]
                      + np.arange(self.data.shape[1]) / self.sfreq_new)
        self.sfreq = self.sfreq_new


In [None]:
importlib.reload(resample)  # pick up source edits to the resampling module

# split the raw recording into per-electrode preprocessed data
dat = ephys_data(
    sub=sub,
    data=raw.samples,
    ch_names=raw.ch_names,
    sfreq=raw.sample_rate,
)



In [None]:
%matplotlib qt

SOURCE = 'STN_L'  # electrode group to inspect (attribute name on `dat`)

source_elec = getattr(dat, SOURCE)
ax = plt.gca()
for i_ch, ch in enumerate(source_elec.ch_names):
    ax.plot(source_elec.times, source_elec.data[i_ch, :], label=ch)

ax.legend(loc='upper left', ncol=3)
plt.show()


In [None]:
%matplotlib inline

In [None]:
def get_psd_window_channels(tempdat, WIN_MS=250):
    """
    Compute a z-scored time-frequency grid (windowed Welch PSD) per channel.

    Each channel is cut into consecutive, non-overlapping windows of WIN_MS
    milliseconds (trailing samples that do not fill a window are dropped).
    A Welch PSD is computed per window, then every frequency bin is z-scored
    across windows.

    Args:
        tempdat: object with `.data` (2D, channels x samples), `.ch_names`
            (one per row of data) and `.sfreq` (Hz), e.g. an electrode_data.
        WIN_MS: window length in milliseconds.

    Returns:
        list with one (f, pxx) tuple per channel: `f` are the frequency bins
        (Hz); `pxx` has shape (n_windows, n_freqs) and holds z-scored power.
        Bins without variance across windows are set to 0.

    Raises:
        ValueError: if a window is shorter than one sample, or the recording
            is shorter than one window.
    """
    data = np.asarray(tempdat.data)
    # window length in samples (ms -> sec * Hz); identical for every channel
    win_samples = int(round(tempdat.sfreq * WIN_MS / 1000))
    if win_samples < 1:
        raise ValueError(f'WIN_MS={WIN_MS} is shorter than one sample')
    n_windows = data.shape[-1] // win_samples
    if n_windows < 1:
        raise ValueError('recording is shorter than one window')

    out_list = []

    for i_ch in range(len(tempdat.ch_names)):
        # one row per window: (n_windows, win_samples)
        welch_dat = np.reshape(data[i_ch, :win_samples * n_windows],
                               (n_windows, win_samples))
        f, pxx = signal.welch(welch_dat, tempdat.sfreq,
                              axis=1, noverlap=win_samples // 2,
                              nperseg=win_samples,)

        # z score per freq bin across windows: (x - mean) / sd
        m = pxx.mean(axis=0)
        sd = pxx.std(axis=0)
        pxx = np.divide(pxx - m, sd, out=np.zeros_like(pxx), where=sd > 0)

        out_list.append((f, pxx))

    return out_list

In [None]:
# windowed PSD grids of the selected electrode group, one (f, pxx) per channel
source_psx = get_psd_window_channels(getattr(dat, SOURCE))

In [None]:
FS = 16  # font size of axis labels and tick labels
# time points (sec) marked with vertical lines. NOTE(review): presumably the
# moments the stimulation was stepped up in this recording — confirm
stim_ups = [54, 132, 240]

source_dat = getattr(dat, SOURCE)
# squeeze=False keeps a 2D axes array, so indexing also works for one channel
fig, axes = plt.subplots(len(source_psx), 1,
                         sharex='col',
                         figsize=(8, 12),
                         squeeze=False)
axes = axes[:, 0]

for i_ch, ch in enumerate(source_dat.ch_names):
    f, psdGrid = source_psx[i_ch]  # psdGrid: (n_windows, n_freqs)
    # spread the window positions evenly over the recording time
    win_times = np.linspace(source_dat.times[0],
                            source_dat.times[-1],
                            psdGrid.shape[0])

    axes[i_ch].pcolormesh(win_times, f, psdGrid.T, cmap='viridis',
                          vmin=-6, vmax=6)
    axes[i_ch].set_ylim(0, 100)

    axes[i_ch].set_ylabel(f'{ch}\n\nFreq. (Hz)', rotation=90,
                          size=FS,)

    for x in stim_ups:
        axes[i_ch].axvline(x, ymin=0, ymax=1,
                           lw=3, alpha=.5,
                           color='orange')
axes[-1].set_xlabel('Recording time (sec)', size=FS,)

# NOTE(review): `size` sets the tick LENGTH (points); `labelsize` the font size
for ax in axes:
    ax.tick_params(axis='both', size=FS, labelsize=FS,)
fig.tight_layout()

fig_dir = os.path.join(entrain_path, 'figures')
os.makedirs(fig_dir, exist_ok=True)  # savefig does not create missing folders
fig.savefig(os.path.join(fig_dir,
                         f'test2_{SOURCE}_timefreq_130HzStimDopa50'),
            dpi=150, facecolor='w',)

plt.close(fig)