In [None]:
import pyecap
import os
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.io as sio
from scipy.signal import find_peaks
from fnirs_functions import *

def filter_hr(data, cutoff = 10, downsample = False, downsample_factor = 10, check_plot=False):
    """Low-pass filter the first channel of an ECG recording and optionally downsample it.

    Parameters
    ----------
    data : pyecap Ephys-like object
        Must provide ``.array.compute()``, ``.filter_gaussian(Wn=..., btype=...)`` and
        ``.time().compute()``. Only the first channel (row 0) is used.
    cutoff : float
        Low-pass cutoff passed as ``Wn`` to ``data.filter_gaussian`` (presumably Hz —
        confirm against pyecap).
    downsample : bool
        If True, decimate the filtered signal by ``downsample_factor``.
    downsample_factor : int
        Integer decimation factor used when ``downsample`` is True.
    check_plot : bool
        If True, show a Plotly figure overlaying the raw and filtered traces.

    Returns
    -------
    numpy.ndarray
        Shape (2, n). Row 0 is the filtered signal scaled by 1e6. Row 1 is the matching
        time vector.
        NOTE: when ``downsample`` is True, the effective sample rate is
        ``original_rate / downsample_factor``. Pass that rate, not the original one, to
        ``calc_hr``.
    """
    # Scaled by 1e6 (presumably V -> µV; confirm the units of the pyecap array).
    data_f = data.filter_gaussian(Wn=cutoff, btype='lowpass')
    ecg_f = data_f.array.compute()[0] * 1e6

    time = data.time().compute()

    if downsample:
        d = signal.decimate(ecg_f, downsample_factor, zero_phase=True)
        # Subsample the time axis rather than low-pass filtering it. decimate keeps every
        # q-th sample (from index 0) of its filtered output, so slicing yields the exact
        # timestamps of the retained samples, with no filter edge distortion.
        t = np.asarray(time)[::downsample_factor]
    else:
        d = ecg_f
        t = time

    # Check alignment of the filtered (and possibly downsampled) signal against the raw trace.
    if check_plot:
        raw_ecg = data.array.compute()[0] * 1e6
        if downsample:
            raw_ecg = signal.decimate(raw_ecg, downsample_factor, zero_phase=True)
        fig = make_subplots(specs=[[{'secondary_y': True}]])
        fig.add_trace(go.Scatter(x=t, y=raw_ecg, name='raw'))
        fig.add_trace(go.Scatter(x=t, y=d, name='Filtered'), secondary_y=True)
        fig.show()

    return np.vstack((d, t))

def calc_hr(data, fs, peak_height, check_plot = False):
    """Estimate beat-to-beat heart rate from a filtered ECG-like trace.

    Parameters
    ----------
    data : numpy.ndarray
        Shape (2, n), as returned by ``filter_hr``. Row 0 is the signal; row 1 is the
        time vector (seconds).
    fs : float
        Sample rate of ``data`` in Hz. It must match the (possibly downsampled) data.
    peak_height : float
        Minimum peak height passed to ``find_peaks``.
    check_plot : bool
        If True, show a Plotly figure with the trace, detected peaks and beat-to-beat HR.

    Returns
    -------
    pandas.DataFrame
        One row per inter-beat interval, with columns:
        ``'peak_idx'`` — sample index midway between two consecutive peaks (not a peak
        index); ``'peak_dt'`` — interval in seconds; ``'Time (s)'`` — time at that
        midpoint; ``'b2b bpm'`` — 60 / interval; ``'smooth bpm'`` — 5-interval rolling
        mean (the first 4 rows are NaN).
    """
    d = data[0]
    time = data[1]

    # Require at least 0.25 s between detected beats (i.e. at most 240 bpm).
    peaks = find_peaks(d, height=peak_height, distance=fs * 0.25)[0]
    # Drop the first/last detections, which may be clipped at recording start/stop.
    peaks = peaks[1:-1]

    # Distance (time) between consecutive peaks, and the instantaneous heart rate.
    peak_dt = np.diff(peaks) / fs  # seconds between peaks
    bpm = 60 / peak_dt

    # Time-stamp each interval at the sample halfway between its two peaks.
    idx = (peaks[:-1] + peaks[1:]) // 2
    peak_time = time[idx]

    ecgDF = pd.DataFrame({'peak_idx': idx,
                          'peak_dt': peak_dt,
                          'Time (s)': peak_time,
                          'b2b bpm': bpm})
    ecgDF['smooth bpm'] = ecgDF['b2b bpm'].rolling(5).mean()

    if check_plot:
        fig = make_subplots(specs=[[{"secondary_y": True}]])
        fig.add_trace(go.Scatter(x=time[peaks], y=d[peaks], mode='markers', name='Peaks'), secondary_y=True)
        fig.add_trace(go.Scatter(x=time, y=d, name='ECG Trace'), secondary_y=True)
        fig.add_trace(go.Scatter(x=ecgDF['Time (s)'], y=ecgDF['b2b bpm'], name='HR (bpm)'))
        fig.show()

    return ecgDF

In [None]:
# Row of metadata.xlsx describing the recording analysed below.
meta_index = 1
# Session folder containing metadata.xlsx and the TDT block folder.
DATA_DIR = Path(r'D:\Data\TDT_fNIRs\20260213_fNIRs_QC')
metaDF = pd.read_excel(DATA_DIR / 'metadata.xlsx')
# Join with pathlib. The previous raw string r'...\\' kept BOTH backslashes, so the tank
# path contained a doubled separator before the tank name. str() keeps the argument type
# pyecap received before.
tank = str(DATA_DIR / 'CVP_SingleCh_Ramp-260213' / metaDF.at[meta_index, 'Tank'])
# NOTE(review): relative path — resolved against the notebook's working directory, not DATA_DIR.
flex_folder = metaDF.loc[meta_index, 'fNIRs Folder'] + '.mat'
fd = flexNIRs(flex_folder)

# TDT ECG store; channels 2-4 are removed (presumably leaving 'ECGG 1').
ecg_data = pyecap.Ephys(tank, stores = 'ECGG')
ecg_data = ecg_data.remove_ch(channels=['ECGG 2','ECGG 3', 'ECGG 4'])

# Stimulation parameter table from the same TDT tank.
stim = pyecap.Stim(tank)
stimDF = stim.parameters

In [None]:
"""Align fNIRs Artifact"""
alignDF = stimDF.loc[stimDF['pulse amplitude (μA)'] < 0]
fd.manual_alignment(stimDF = alignDF, stim_start_index = int(62*266))
fd.plot_artifact(channel = 'D3 Ambient', show_stim = True)

In [None]:
# Heart rate from the flexNIRs 'SS Red' channel, overlaid on the filtered D3 LS Red trace.
fd.calc_hr(
    channel='SS Red',
    filter_cutoffs=(0.5, 15),
    peak_height=2000,
    check_plot=True,
)
fd.plot_channel_interactive(
    data_type='Mua_filt',
    channel='D3 LS Red',
    show_stim=False,
    show_hr=True,
    hr_chan='smooth',
)

In [None]:
# Filter (and optionally downsample) the TDT ECG, then compute beat-to-beat heart rate.
DOWNSAMPLE_ECG = False
ECG_DS_FACTOR = 10
tdt_ecg = filter_hr(ecg_data, cutoff = 10, downsample=DOWNSAMPLE_ECG, downsample_factor = ECG_DS_FACTOR, check_plot=False)
# calc_hr needs the sample rate of the array it receives; decimation divides the rate by
# the factor. Passing the original rate after downsampling would scale the bpm wrongly.
tdt_fs = ecg_data.sample_rate / ECG_DS_FACTOR if DOWNSAMPLE_ECG else ecg_data.sample_rate
ecgDF = calc_hr(tdt_ecg, peak_height = 250, fs = tdt_fs)

In [None]:
"""ECG"""
fig = go.Figure()
fig.add_trace(go.Scatter(x = ecgDF['Time (s)'], y = ecgDF['b2b bpm']))

# for param in stimDF.index:
#         fig.add_vrect(x0=fd.stimDF.loc[param]['onset time (s)'],
#                   x1=fd.stimDF.loc[param]['offset time (s)'])

fig.show()

In [None]:
# Short-channel-regressed D1 LS Red trace with stimulation windows shown.
fd.plot_channel_interactive(
    data_type='SSR_filt',
    channel='D1 LS Red SSR',
    show_stim=True,
    show_hr=False,
    hr_chan='smooth',
)

In [None]:
data_type = 'Mua_filt'  # Data type of original data
# Run short-channel regression on data_type, then compare the D1 LS Red trace against it.
fd.ssr_regression(data_type=data_type)
fd.ssr_plot(
    data_type,
    channel='D1 LS Red',
    show_stim=True,
    show_ss=True,
)

In [None]:
"""ECG"""
data = fd.raw_data.copy()

fig = go.Figure()
fig.add_trace(go.Scatter(x = fd.time, y = data['D3 LS Red']))
fig.show()

In [None]:
"""High pass filter for ECG"""
channel = 'SS Red'

fs = 800/3
filter_cutoffs = (1,40)
transition_width = 1
numtaps = 3001

filter_weights = signal.firwin(numtaps, filter_cutoffs, width=transition_width, window='Hamming', pass_zero = 'bandpass', fs=fs)

#Plot Frequency Response
w,h = signal.freqz(filter_weights, worN = fft.next_fast_len(40000, real=True))
# plt.plot( (w / np.pi) * (fs/2), 20 * np.log10( np.abs(h)))
# plt.xlim((0,200))

data = fd.raw_data[channel].values
padded_data = pad_noise(data, numtaps, 5000)
filtered_data = np.flip(signal.fftconvolve(np.flip(signal.fftconvolve(padded_data, filter_weights, mode='same')), filter_weights, mode='same'))
plot_data = filtered_data[numtaps:-numtaps]

fig = go.Figure()
fig.add_trace(go.Scatter(x = fd.time, y = plot_data))
fig.show()



In [None]:
# Display the flexNIRs sample rate held by `fd`. It is used below as fd.fs for peak timing;
# compare with the hard-coded fs = 800/3 in the filter cells.
fd.fs

In [None]:
# Beat detection on the band-passed trace.
# distance=67 samples is ~0.25 s at the 800/3 Hz rate used in the filter cell.
peaks, _ = find_peaks(plot_data, distance=67, height=1000)

# Discard the first/last detections; the signal may behave oddly at recording start/stop.
peaks = peaks[1:-1]

fig = go.Figure(data=[
    go.Scatter(x=fd.time, y=plot_data),
    go.Scatter(x=fd.time[peaks], y=plot_data[peaks], mode='markers'),
])
fig.show()


In [None]:
"""Get distance (time) between peaks and calculate heart rate"""
peak_dt = np.diff(peaks) / fd.fs #Time between peaks in seconds

idx = [ int((peaks[i] + peaks[i + 1]) / 2) for i in np.arange(len(peaks) - 1)]
peak_time = fd.time[idx]
bpm = 60 / peak_dt #Instantaneous Heart rate in BPM based on peak_dt

#Construct ECG dataframe?
ecg_dct = {'peak_idx' : idx,
           'peak_dt' : peak_dt,
           'Time (s)' : peak_time,
           'b2b bpm' : bpm}

ecgDF = pd.DataFrame(ecg_dct)

#fig = go.Figure()
fig = make_subplots(specs = [[{"secondary_y" : True}]])
fig.add_trace(go.Scatter(x = fd.time, y = plot_data))
fig.add_trace(go.Scatter(x = fd.time[peaks], y = plot_data[peaks], mode = 'markers'))
fig.add_trace(go.Scatter(x = ecgDF['Time (s)'], y = ecgDF['b2b bpm']), secondary_y=True)

for param in fd.stimDF.index:
    fig.add_vrect(x0=fd.stimDF.loc[param]['fNIRs onset time (s)'],
                  x1=fd.stimDF.loc[param]['fNIRs offset time (s)'])

fig.show()

In [None]:
# 5-interval rolling mean of the beat-to-beat HR (the first 4 rows are NaN).
ecgDF['smooth bpm'] = ecgDF['b2b bpm'].rolling(window=5).mean()

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=fd.time, y=plot_data))
fig.add_trace(go.Scatter(x=fd.time[peaks], y=plot_data[peaks], mode='markers'))
fig.add_trace(go.Scatter(x=ecgDF['Time (s)'], y=ecgDF['smooth bpm']), secondary_y=True)

# Shade each stimulation window using its fNIRs-aligned onset/offset times.
for param in fd.stimDF.index:
    fig.add_vrect(
        x0=fd.stimDF.loc[param]['fNIRs onset time (s)'],
        x1=fd.stimDF.loc[param]['fNIRs offset time (s)'],
    )

fig.show()

In [None]:
"""Generate plot showing individual data channels with and without short-channel regression"""

data_type = 'Mua_filt' # Data type of original data
ssr_data_type = 'SSR_filt' # Data type of SSR data

#Run regression on data type specified. In this case matches plot for short-channel signal
fd.ssr_regression(data_type = data_type)

og_plot_dct = {'SS Red' : (data_type, 'SS Red'),
            'D1 LS Red' : (data_type, 'D1 LS Red'),
            'D1 LL Red' : (data_type, 'D1 LL Red'),
            'D3 LS Red' : (data_type, 'D3 LS Red'),
            'D3 LL Red' : (data_type, 'D3 LL Red'),
            }

ssr_plot_dct = {'D1 LS Red' : (ssr_data_type, 'D1 LS Red'),
            'D1 LL Red' : (ssr_data_type, 'D1 LL Red'),
            'D3 LS Red' : (ssr_data_type, 'D3 LS Red'),
            'D3 LL Red' : (ssr_data_type, 'D3 LL Red'),
            }

red_channels = ['D3 LS Red', 'D3 LS Red SSR', 'SS Red']

fig, ax = plt.subplots(figsize = (15,5))

fd.plot_full_trace(data_type = 'SSR_filt', channel = 'SS Red SSR',legend = False, show = False, axis = ax)
plt.show()

In [None]:
# Run fd's bandpass filter with its default arguments. The implementation is presumably in
# fnirs_functions and is not visible here.
fd.bandpass_filter()

In [None]:
# Re-run short-channel regression (SSR) on the 'Mua_filt' data.
fd.ssr_regression(data_type = 'Mua_filt')
#fd.plot_channel(data_type = 'SSR', channel = 'D3 LS',plot_style = 'Full', pre_time=0,post_time=0)
#fd.plot_channel(data_type = 'Mua_filt', channel = 'D3 LS',plot_style = 'Full', pre_time=0,post_time=0)

In [None]:
# Plot the full 'D3 LS Red SSR' trace onto a new Matplotlib axis.
# NOTE(review): `fig_size` is passed, but the figure was already created by plt.subplots()
# at default size. Confirm whether plot_full_recording honours fig_size when `axis` is given.
fig, ax = plt.subplots()
fd.plot_full_recording(data_type = 'SSR_filt', channel = 'D3 LS Red SSR',legend = False, fig_size=(15,5), show = False, axis = ax)

In [None]:
# Full-recording view of D1 LL (1 s before / 50 s after), zero-shifted; keep the returned frame.
plotDF = fd.plot_channel(
    data_type='Mua_filt',
    channel='D1 LL',
    plot_style='Full',
    pre_time=1,
    post_time=50,
    legend=True,
    zero_shift=True,
)

In [None]:
# Rows belonging to stimulation '1'. NOTE(review): compares against the string '1'; if
# 'Stim #' is numeric this returns an empty frame.
plotDF[plotDF['Stim #'].eq('1')]

In [None]:
fd.plot_channel_interactive(
    data_type='Mua_filt',
    channel='SS Red',
    show_stim=True,
)

In [None]:
fd.plot_channel_interactive(
    data_type='Mua_filt',
    channel='D3 LL IR',
    show_stim=True,
)

In [None]:
# NOTE(review): uses data_type 'SSR', whereas other cells use 'SSR_filt' — confirm intended.
fd.plot_channel_interactive(
    data_type='SSR',
    channel='D3 LL Red SSR',
    show_stim=True,
)

In [None]:
# NOTE(review): uses data_type 'SSR', whereas other cells use 'SSR_filt' — confirm intended.
fd.plot_channel_interactive(
    data_type='SSR',
    channel='D3 LL IR SSR',
    show_stim=True,
)

In [None]:
"""Calculated HB changes per channel"""
channel_pairs = {'SS': ['SS Red', 'SS IR'],
                 'D1 LS': ['D1 LS Red', 'D1 LS IR'],
                 'D1 LL': ['D1 LL Red', 'D1 LL IR'],
                 'D3 LS': ['D3 LS Red', 'D3 LS IR'],
                 'D3 LL': ['D3 LL Red', 'D3 LL IR'],}
#channel_pairs = [(0,1),(2,4),(3,5),(6,8),(7,9)]
exC = fd.extinctionCoeff[:,0:2]

hemo_data = []
hemo_chanLIST = []


for pair in channel_pairs:
    ch_red = channel_pairs[pair][0]
    ch_ir = channel_pairs[pair][1]

    data_red = fd.d_Mua[ch_red].values
    data_ir = fd.d_Mua[ch_red].values

    d = np.vstack((data_red, data_ir))
    Hb = np.matmul(exC**-1,d)

    hemo_data.append(Hb[0,:])
    hemo_data.append(Hb[1,:])

    hemo_chan_red = pair + ' HbO'
    hemo_chan_ir = pair + ' HbR'

    hemo_chanLIST.append(hemo_chan_red)
    hemo_chanLIST.append(hemo_chan_ir)
hemoDF = pd.DataFrame(np.column_stack(hemo_data), columns = hemo_chanLIST)

In [None]:
# Rebuild hemoDF from the lists accumulated in the previous cell (identical to that cell's last line).
hemoDF = pd.DataFrame(np.column_stack(hemo_data), columns = hemo_chanLIST)

In [None]:
# Stack the per-channel HbO/HbR arrays column-wise (one column per entry of hemo_chanLIST).
# Note: this overwrites `d` left over from the hemoglobin loop.
d = np.column_stack(hemo_data)

In [None]:
"""High pass filter for finding artifact"""
fs = 800/3
filter_cutoffs = 1
transition_width = 1
numtaps = 3001
#numtaps = int(( 3.3 * fs) / (2 * transition_width)) * 2 + 1

filter_weights = signal.firwin(numtaps, filter_cutoffs, width=transition_width, window='Hamming', pass_zero = 'highpass', fs=fs)

#Plot Frequency Response
w,h = signal.freqz(filter_weights, worN = fft.next_fast_len(40000, real=True))
# plt.plot( (w / np.pi) * (fs/2), 20 * np.log10( np.abs(h)))
# plt.xlim((0,200))

"""flexNIRs High-pass Artifact Filtering -- WIP"""
data_cols = ['D1 Ambient','D3 Ambient']
artDF = fd.raw_data[['D1 Ambient', 'D3 Ambient']].copy()
artDF['Time (s)'] = fd.time

for col in data_cols:
    data = artDF[col].to_numpy()
    padded_data = data #pad_noise(data, numtaps, 5000)
    filtered_data = np.flip(signal.fftconvolve(np.flip(signal.fftconvolve(padded_data, filter_weights, mode='same')), filter_weights, mode='same'))
    name = col + ' Filtered'
    artDF[name] = filtered_data#[numtaps:-numtaps]

"""Plotly plot for looking at stim artifact in flexNIRs data"""
channel = 'D3 Ambient Filtered'

fig = go.Figure()
fig.add_trace(go.Scatter(x = artDF['Time (s)'], y = artDF[channel], customdata = artDF.index, hovertemplate = '%{customdata:.1f}'))
fig.show()

# for param in stimDF.index:
#     fig.add_vrect(x0 = stimDF.loc[param]['fNIRs onset time (s)'], x1 = stimDF.loc[param]['fNIRs offset time (s)'])

In [None]:
"""Manual alignment function - requires manual ID and input of the first stimulation start index"""
DF2 = pd.DataFrame()
DF2 = stimDF.loc[ stimDF['pulse amplitude (μA)'] < 0]
alignDF = manual_alignment(artDF, DF2, stim_start_index=61*266)

"""Plotly plot for looking at stim artifact in flexNIRs data"""
channel = 'D3 Ambient Filtered'

fig = go.Figure()
fig.add_trace(go.Scatter(x = artDF['Time (s)'], y = artDF[channel], customdata = artDF.index, hovertemplate = '%{customdata:.1f}'))

for param in alignDF.index:
    fig.add_vrect(x0 = alignDF.loc[param]['fNIRs onset time (s)'], x1 = alignDF.loc[param]['fNIRs offset time (s)'])
fig.show()

In [None]:
# NOTE(review): `fnirsDF` is not defined in any earlier cell of this notebook, so this raises
# NameError on a fresh kernel. It also overwrites its own input, so re-running may filter
# already-filtered data. Confirm which frame should be passed, and assign the result to a new name.
fnirsDF = fnirs_filter(fnirsDF, device_type = 'flexNIRs', filter_cutoffs = (0.01,1))
#fnirs_plot(fnirsDF, alignDF, channel = 'SS', device_type = 'flexNIRs', pre_time = 0, post_time=0)

In [None]:
fnirs_plot(
    fnirsDF,
    alignDF,
    channel='D3 LL Filtered',
    device_type='flexNIRs',
    pre_time=0,
    post_time=0,
    plot_type='Full',
    zero_shift=True,
)

In [None]:
"""Trying to recreate Harvard data pipeline in python
- Pass each signal through a median filter
- Pass each signal through a bandpass filter (.001 - 0.18) on top of a -log(signal / (mean(signal)) to get delta OD
- Calculates differential path length factor for each channel
- Calculates change in absorption coefficient based on deltaOD / dpf
- Calculates change in HB using MBLL matrix multiplication against extinction coefficients
"""
wavelengths = [760,850]
sds = [2.8,3.3] # Source-detector-separation distances
msp0 = 6.666
mspb = 0.99
wv0 = 750
waterP = 0.75
msp = [msp0*(wv/wv0)**(-mspb) for wv in wavelengths]
extinctionC = np.array([[1349.558, 3910.494,0.0252],[2436.574,1888.46,0.043]])

In [None]:
# Load a MATLAB .mat file from another session (hard-coded absolute path). loadmat returns a
# dict mapping variable names to arrays. `fnirs_data` is not used by later visible cells.
fnirs_data = sio.loadmat(r'M:\Projects\fNIRs_QC\20260206_0-1Hz_fNIRS\fNIRS\run10.mat')

In [None]:
# Ambient-channel artifact with stimulation windows taken from alignDF.
fd.plot_artifact(
    channel='D3 Ambient',
    stim_data=alignDF,
    show_stim=True,
)

In [None]:
# Short-channel and D3 LS red/IR series from fd.d_Mua, as NumPy arrays.
ss_red, ss_ir, d_red, d_ir = (
    fd.d_Mua[name].values for name in ('SS Red', 'SS IR', 'D3 LS Red', 'D3 LS IR')
)

In [None]:
# Least-squares scale (no intercept) of the short channel onto the long channel: <s,d> / <s,s>.
alpha = ss_red.dot(d_red) / ss_red.dot(ss_red)

In [None]:
# Remove the scaled short-channel component from the D3 LS red signal.
d_red_cleaned = d_red - alpha * ss_red