In [1]:
# Plot LFP effects for stimulation on the different channels
# Select 2-3 stimulation and recording channels
# Plot LFPs, PSDs, STFT, and cross-channel coherence
# Pre-stim only

#Import Libraries
# | echo: false
# | warning: false
%run C:/Users/27707/Documents/jhu_master/lab/importrhsutilities.py
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import spikeinterface as si  # import core only
import spikeinterface.extractors as se
import spikeinterface.extractors as sse
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.qualitymetrics as sqm
import spikeinterface.comparison as sc
import spikeinterface.exporters as sexp
import spikeinterface.curation as scur
import spikeinterface.widgets as sw

from probeinterface import Probe, ProbeGroup
from probeinterface import generate_linear_probe, generate_multi_shank
from probeinterface import combine_probes
from probeinterface.plotting import plot_probe

import datetime
from uuid import uuid4

import numpy as np
from dateutil.tz import tzlocal

from pynwb import NWBHDF5IO, NWBFile
from pynwb.ecephys import LFP, ElectricalSeries
from pprint import pprint
from scipy import signal

In [2]:
# Stimulation amplitudes (uA) recorded on Day 1. Every per-recording list below is
# ordered by this amplitude, so index i always refers to the same session.
_stim_amplitudes_ua = [5, 10, 20, 30, 40, 50, 60]

# Merged NWB file name for each amplitude.
recording_names = [f'BO3_O5W_Sti_Day1_{amp}uA_m.nwb' for amp in _stim_amplitudes_ua]

# Short amplitude labels ('5ua', '10ua', ...).
amps = [f'{amp}ua' for amp in _stim_amplitudes_ua]

# Root folder of this organoid's Day-1 stimulation sessions.
intan_path = 'C:/Users/27707/Documents/jhu_master/LFP_analysis/2024-07-26_DG_3elec/Experimental Set_Final Round_Batch 27/5 Week Old Organoid/Stimulation Day1/'

# Trailing date/time stamp of each amplitude's Intan session folder name.
_session_stamps = ['240625_161619', '240625_162505', '240625_163352', '240625_164235',
                   '240625_165135', '240625_170022', '240625_170906']

# Per-amplitude sub-folder (relative to intan_path) containing the merged files.
intan_folders = [
    f'{amp}uA_233.3uS_Biphasic Pulse/BO3_O5W_Stimulation_Day1_{amp}uA_233uS_{stamp}/merged/'
    for amp, stamp in zip(_stim_amplitudes_ua, _session_stamps)
]



In [3]:
def all_intan_data(channel_names, signal_data_name, results):
    """Stack selected channels from several loaded Intan files into one time-indexed DataFrame.

    Args:
        channel_names: native channel names to extract (e.g. 'STIM_A-000'); each is
            looked up in every file's header with ``find_channel_in_header``.
        signal_data_name: key of the signal array inside each result dict (e.g.
            'stim_data'); indexed as ``res[signal_data_name][channel_index, :]``.
        results: sequence of per-file result dicts (as returned by ``pull_files``);
            each must provide the time vector under 't' and ``signal_data_name``.

    Returns:
        pandas.DataFrame indexed by 'time' with one column per channel name, files
        concatenated in the order given. Values are divided by 1000.

    Raises:
        ValueError: if ``results`` is empty, or a channel name is not found in a file's header.
    """
    frames = []
    for res in results:
        columns = {'time': res['t']}
        for chan in channel_names:
            channel_found, _signal_type, signal_index = find_channel_in_header(chan, res)
            # Without this check a missing channel silently read whatever index the
            # lookup returned and mislabelled it as `chan`.
            if not channel_found:
                raise ValueError(f"Channel {chan!r} not found in Intan header")
            columns[chan] = res[signal_data_name][signal_index, :] / 1000  # data is in mV but NWB electrical series type expects V
        frames.append(pd.DataFrame.from_dict(columns))
    if not frames:
        raise ValueError("all_intan_data: 'results' is empty, nothing to concatenate")
    rec_data = pd.concat(frames, axis=0)
    return rec_data.set_index('time')

def pull_files(recording_name, fpath='/media/t7/surpass/electrophysiology/Gracias Data/ExperimentalSet_Round2/'):
    '''Load every Intan ``.rhs`` file found in one recording folder.

    Args:
        recording_name: sub-folder holding the .rhs files. It is concatenated
            directly onto ``fpath``, so ``fpath`` should end with a path separator.
        fpath: base directory of the recordings.

    Returns:
        list of the result objects returned by ``load_file``, one per .rhs file,
        in sorted file-name order. Empty list when no file matches.
    '''
    filepath = fpath + recording_name + '/*.rhs'
    print(filepath)
    # glob order is filesystem dependent; sort so consecutive files are always
    # concatenated in the same (file-name) order downstream.
    fnames = sorted(glob.glob(filepath))
    print(fnames)
    results = []
    for fname in fnames:
        res, data_present = load_file(fname)
        if not data_present:
            # NOTE(review): the loader reports no data samples for this file; its result
            # may lack the arrays downstream code indexes. Confirm how to handle it.
            print(f'Warning: no data present in {fname}')
        results.append(res)
    return results

In [None]:
# For every stimulation amplitude, load that session's Intan files and keep the
# timestamps at which at least one of the three stimulation channels is non-zero.
# (Variable names say BO7, but the folders in intan_folders are the BO3 organoid.)
stim_names_BO7 = ['STIM_A-000', 'STIM_A-002', 'STIM_A-004']
stim_times_array = []
for i in range(len(recording_names)):
    BO7_res = pull_files(intan_folders[i], intan_path)
    stim_BO7 = all_intan_data(stim_names_BO7, 'stim_data', BO7_res)
    # Row-wise OR of "channel != 0" over the stimulation channels.
    any_channel_active = stim_BO7[stim_names_BO7].ne(0).any(axis=1)
    stim_times = stim_BO7.index[any_channel_active.to_numpy()]
    stim_times_array.append(stim_times.to_numpy())

In [None]:
def find_stim_blocks(stim_times, min_gap=1.0):
    """Group stimulation timestamps into blocks; return each block's first and last time.

    A gap larger than ``min_gap`` between consecutive timestamps starts a new block
    (same units as the timestamps: seconds for the Intan 't' vector used here).

    Args:
        stim_times: 1-D array of timestamps at which stimulation was non-zero,
            assumed sorted ascending.
        min_gap: gap that separates two blocks.

    Returns:
        (starts, ends): equal-length float arrays; block k spans starts[k]..ends[k].
        Both are empty when ``stim_times`` is empty.
    """
    stim_times = np.asarray(stim_times, dtype=float)
    if stim_times.size == 0:
        return stim_times.copy(), stim_times.copy()
    # k is a break when sample k ends one block and sample k+1 starts the next.
    breaks = np.flatnonzero(np.diff(stim_times) > min_gap)
    # The first sample always opens a block. A 0 sentinel would drop it whenever the
    # first stimulus falls within min_gap of t=0, leaving starts/ends misaligned.
    starts = np.concatenate(([stim_times[0]], stim_times[breaks + 1]))
    ends = np.concatenate((stim_times[breaks], [stim_times[-1]]))
    return starts, ends


stim_start_array = []
stim_end_array = []
for stim in stim_times_array:
    stim_start, stim_end = find_stim_blocks(stim)
    print(stim_start)
    print(stim_end)
    stim_start_array.append(stim_start)
    stim_end_array.append(stim_end)

In [6]:
def gen_probe(plot=True):
    """Build the three-contact probe layout shared by every recording.

    Contacts (probeinterface coordinates, presumably um - TODO confirm), in the
    order they are combined, which fixes the device-channel mapping:
        west  (-75,  0) -> device channel 0
        east  ( 75,  0) -> device channel 1
        north (  0, 75) -> device channel 2

    Args:
        plot: draw the layout with ``plot_probe`` and show it (default True, the
            previous behaviour). Pass False when building the probe repeatedly,
            e.g. once per recording, to avoid a figure per call.

    Returns:
        The probe returned by ``combine_probes`` with device channel indices [0, 1, 2].
    """
    # (rotation in degrees, contact position) of each single-contact shank.
    shank_specs = ((-90, [-75, 0]),   # west
                   (90, [75, 0]),     # east
                   (180, [0, 75]))    # north
    shanks = []
    for angle, position in shank_specs:
        shank = generate_linear_probe(num_elec=1)
        shank.rotate(angle)
        shank.set_contacts(positions=[position])
        shanks.append(shank)

    multi_shank = combine_probes(shanks)
    if plot:
        plot_probe(multi_shank)
        plt.show()
    multi_shank.set_device_channel_indices([0, 1, 2])
    return multi_shank

def filterbank(recording):
    """Return four band-passed copies of ``recording``.

    Order of the returned 4-tuple:
        delta (1-4 Hz), 100-200 Hz, 200-400 Hz, high gamma (70-110 Hz).
    """
    bands = ((1, 4),      # delta - Muotri organoid paper
             (100, 200),  # 'gamma' as used in the Muotri paper
             (200, 400),  # 'gamma' as used in the Muotri paper
             (70, 110))   # ECoG-style high gamma
    return tuple(spre.bandpass_filter(recording, freq_min=low, freq_max=high)
                 for low, high in bands)

def downsample_lfp_mua(recording):
    """Derive LFP and MUA-style signals from ``recording``, both resampled to 1000 Hz.

    LFP: 1-400 Hz band-pass, then resample.
    MUA: ``spre.rectify`` of ``recording`` as given, then resample.

    Returns:
        (recording_lfp, recording_mua)
    """
    # NOTE(review): no high-pass is applied before rectification in this function, so
    # the MUA estimate also contains low-frequency content; confirm this is intended.
    target_rate = 1000
    band_limited = spre.bandpass_filter(recording, freq_min=1, freq_max=400)
    lfp = spre.resample(band_limited, target_rate)
    mua = spre.resample(spre.rectify(recording), target_rate)
    return lfp, mua

def load_data(rec_name):
    """Read an NWB file with SpikeInterface and attach the probe from ``gen_probe()``.

    Args:
        rec_name: path to the .nwb file.
    Returns:
        the recording returned by ``set_probe``.
    """
    return se.read_nwb_recording(rec_name).set_probe(gen_probe())

In [None]:
# Load each amplitude's merged NWB recording (probe attached) and keep a
# 300-6000 Hz band-passed copy alongside it.
num_samples = []
num_channels = []
multirecording = []
multirecording_f = []

# NWB path of each recording relative to intan_path. Built once, since it does not
# depend on the loop index. This rebinds the global `recording_names` from bare file
# names to folder-prefixed paths (same length, same order).
recording_names = [folder + 'BO3_O5W_Sti_Day1_' + amp.split('u')[0] + 'uA_m.nwb'
                   for folder, amp in zip(intan_folders, amps)]

for i, relative_path in enumerate(recording_names):
    day_recording = load_data(intan_path + relative_path)
    # Read the shape from metadata; get_traces() would load the whole file just for .shape.
    num_pre_samples = day_recording.get_num_samples(segment_index=0)
    num_pre_channels = day_recording.get_num_channels()
    num_samples.append(num_pre_samples)
    num_channels.append(num_pre_channels)
    print('Presample:' + str(num_pre_samples))

    multirecording.append(day_recording)
    multirecording_f.append(spre.bandpass_filter(day_recording, freq_min=300, freq_max=6000))

In [None]:

# Step 1: per recording and channel, threshold = rms_multiplier x RMS of the
#         start_time..end_time window.
# Step 2: every sample whose absolute value exceeds its channel's threshold is
#         stored as an artifact trigger frame (used by the artifact-removal cell).
fs = 30000  # expected sampling rate (Hz); each recording is checked against it below
start_time = 5  # RMS window start (s)
end_time = 255  # RMS window end (s)
num_samples = fs * (end_time - start_time)  # number of samples in the 250 s RMS window
rms_multiplier = 5  # threshold = 5 x RMS (4.5 was also tried)
n_channels = 3  # only the first three channels are analysed

# Window (ms) blanked around each trigger when spre.remove_artifacts is applied later.
ms_before = [100]
ms_after = [100]

thresholds = []       # thresholds[i][ch]
artifact_frames = []  # artifact_frames[i][ch] -> list of sample indices

for i in range(len(recording_names)):
    recording = multirecording[i]
    recording_fs = recording.get_sampling_frequency()
    if recording_fs != fs:
        # The RMS window below is expressed in samples at `fs`.
        raise ValueError(f'Recording {i+1} is sampled at {recording_fs} Hz, expected fs={fs}')

    # Read the traces once per recording instead of once per channel in each pass.
    traces = recording.get_traces().astype(np.float32)
    rms_window = traces[start_time * fs:end_time * fs, :]
    if rms_window.shape[0] == 0:
        # An empty window would give a NaN threshold and silently detect nothing.
        raise ValueError(f'Recording {i+1} is shorter than {start_time} s; cannot compute RMS threshold')

    thresholds_per_recording = []
    for ch in range(n_channels):
        rms_value = np.sqrt(np.mean(rms_window[:, ch] ** 2))
        thresholds_per_recording.append(rms_value * rms_multiplier)
    thresholds.append(thresholds_per_recording)

    artifact_frames_per_recording = []
    for ch in range(n_channels):
        threshold = thresholds[i][ch]
        print(f"Threshold for recording {i+1}, channel {ch+1}: {threshold}")
        trigger_frames = np.flatnonzero(np.abs(traces[:, ch]) > threshold)
        if trigger_frames.size == 0:
            print(f"No artifacts detected for recording {i+1}, channel {ch+1}")
        artifact_frames_per_recording.append(trigger_frames.tolist())
    artifact_frames.append(artifact_frames_per_recording)


In [None]:
# Show the artifact thresholds: one list per recording, one value per channel
# (5 x RMS of the 5-255 s window).
thresholds

In [11]:
# from scipy import signal
# from spikeinterface.extractors import NumpyRecording
# recording_art = []
# recording_lfp = []
# recording_mua = []
# recording_delta = []
# recording_100_200 = []
# recording_200_400 = []
# recording_gamma = []
# recording_car = []
# artifact_frames_resample = []
# # Design a bandstop filter to remove frequencies between 57 Hz and 63 Hz
# lowcut = 55.0  # Lower bound of the bandstop filter
# highcut = 65.0  # Upper bound of the bandstop filter
# lowcut2 = 170.0  # Lower bound of the bandstop filter
# highcut2 = 190.0 
# lowcut3 = 295.0  # Lower bound of the bandstop filter
# highcut3 = 305.0 
# lowcut4 = 60.0  # Lower bound of the bandstop filter
# highcut4 = 70.0
# lowcut5 = 50.0  # Lower bound of the bandstop filter
# highcut5 = 60.0
# sampling_rate = 30000  # Sampling rate (30 kHz)

# # Design the Butterworth bandstop filter
# b_bandstop, a_bandstop = signal.butter(3, [lowcut, highcut], btype='bandstop',fs=sampling_rate)
# b_bandstop2, a_bandstop2 = signal.butter(3, [lowcut2, highcut2], btype='bandstop',fs=sampling_rate)
# b_bandstop3, a_bandstop3 = signal.butter(3, [lowcut3, highcut3], btype='bandstop',fs=sampling_rate)
# b_bandstop4, a_bandstop4 = signal.butter(3, [lowcut4, highcut4], btype='bandstop',fs=sampling_rate)
# b_bandstop5, a_bandstop5 = signal.butter(3, [lowcut5, highcut5], btype='bandstop',fs=sampling_rate)

# f0_120 = 120
# f0_180 = 180.0  # Power line noise at 60 Hz
# Q = 10  # Quality factor for the notch filter
# Q1=80
# # Q2 = 1000
# f_240 = 240
# f_300 =300
# # #Create a notch filter for 60 Hz

# b_60,a_60 = signal.iirnotch(60, 30, 30000)
# b_120,a_120 = signal.iirnotch(f0_120, 200, 30000)
# b_200, a_200 = signal.iirnotch(f0_180, 30, 30000)
# b_240,a_240 = signal.iirnotch(f_240, 100, 30000)
# b_300, a_300 = signal.iirnotch(f_300, Q1, 30000)
# for i in range(len(recording_names)):
#     rc = []
#     for ch in range(3):
#         # Remove artifacts for each channel
#         recording_artifact_removed = spre.remove_artifacts(
#             multirecording[i],
#             list_triggers=artifact_frames[i][ch],  # Combine triggers from all channels
#             ms_before=ms_before[0], ms_after=ms_after[0]  # Adjust based on your needs
#         )
#         channel_data_artifact_removed = recording_artifact_removed.get_traces()[:, ch]

#         # Append the artifact-removed channel data to rc list
#         rc.append(channel_data_artifact_removed)

#     # Combine the channels into one array
#     combined_data = np.stack(rc, axis=1)  # Stack along the channel axis

#     # Create a new RecordingExtractor object with combined data
#     sampling_frequency = multirecording[i].get_sampling_frequency()
#     combined_recording = NumpyRecording(traces_list=[combined_data], sampling_frequency=sampling_frequency)

#     # Append the filtered recording to the list
#     recording_art.append(combined_recording)
#     #  # Combine the channels into one array
#     # combined_data = np.stack(rc, axis=1)  # Stack along the channel axis

#     # # Create a new RecordingExtractor object with combined data
#     # sampling_frequency = multirecording[i].get_sampling_frequency()
#     # combined_recording = NumpyRecording(traces_list=[combined_data], sampling_frequency=sampling_frequency)

#     # # Append the filtered recording to the list
#     # recording_art.append(combined_recording)

#     # # Step 2: Apply CAR filter after artifact removal using numpy
#     # car_filtered_data = combined_data - np.mean(combined_data, axis=1, keepdims=True)
#     # recording_car_filtered = NumpyRecording(traces_list=[car_filtered_data], sampling_frequency=sampling_frequency)
#     # recording_car.append(recording_car_filtered)
#     # Apply the 57-63 Hz bandstop filter
#     traces_filtered = signal.filtfilt(b_60, a_60, recording_art[i].get_traces(), axis=0)
#     # traces_filtered = signal.filtfilt(b_bandstop5, a_bandstop5, traces_filtered, axis=0)
#     # traces_filtered = signal.filtfilt(b_bandstop4, a_bandstop4, traces_filtered, axis=0)
#     # Continue applying other notch filters (120 Hz, 180 Hz, etc.)
#     traces_filtered = signal.filtfilt(b_120, a_120, traces_filtered, axis=0)
#     traces_filtered = signal.filtfilt(b_200, a_200, traces_filtered, axis=0)
#     # traces_filtered = signal.filtfilt(b_bandstop2, a_bandstop2, traces_filtered, axis=0)
#     traces_filtered = signal.filtfilt(b_240, a_240, traces_filtered, axis=0)
#     traces_filtered = signal.filtfilt(b_300, a_300, traces_filtered, axis=0)
#     # traces_filtered = signal.filtfilt(b_bandstop3, a_bandstop3, traces_filtered, axis=0)

#     # Create a new RecordingExtractor with the filtered traces
#     recording_filtered = se.NumpyRecording([traces_filtered], sampling_frequency)

#     # Process LFP and MUA after filtering
#     lfp, mua = downsample_lfp_mua(recording_filtered)

#     # Apply filter bank to LFP
#     delta, s100_200, s200_400, gamma = filterbank(lfp)
#     recording_lfp.append(lfp)
#     recording_mua.append(mua)
#     recording_delta.append(delta)
#     recording_100_200.append(s100_200)
#     recording_200_400.append(s200_400)
#     recording_gamma.append(gamma)

#     # Resample artifact frames for further analysis
#     stim_frames = list(map(lambda x: int(x * (1000 / 30000)), np.concatenate(artifact_frames[i])))
#     artifact_frames_resample.append(stim_frames)

In [12]:
# from spikeinterface.preprocessing import common_reference

# # Step 1: Apply CAR filter after artifact removal
# recording_art = []
# recording_car = []  # Store recordings after applying CAR filter
# recording_lfp = []
# recording_mua = []
# recording_delta = []
# recording_100_200 = []
# recording_200_400 = []
# recording_gamma = []
# artifact_frames_resample = []

# # Design a bandstop filter to remove frequencies between 57 Hz and 63 Hz
# lowcut = 55.0  # Lower bound of the bandstop filter
# highcut = 65.0  # Upper bound of the bandstop filter
# lowcut2 = 170.0  # Lower bound of the bandstop filter
# highcut2 = 190.0 
# lowcut3 = 295.0  # Lower bound of the bandstop filter
# highcut3 = 305.0 
# lowcut4 = 60.0  # Lower bound of the bandstop filter
# highcut4 = 70.0
# lowcut5 = 50.0  # Lower bound of the bandstop filter
# highcut5 = 60.0
# sampling_rate = 30000  # Sampling rate (30 kHz)

# # Design the Butterworth bandstop filter
# b_bandstop, a_bandstop = signal.butter(3, [lowcut, highcut], btype='bandstop', fs=sampling_rate)
# b_bandstop2, a_bandstop2 = signal.butter(3, [lowcut2, highcut2], btype='bandstop', fs=sampling_rate)
# b_bandstop3, a_bandstop3 = signal.butter(3, [lowcut3, highcut3], btype='bandstop', fs=sampling_rate)

# f0_120 = 120
# f0_180 = 180.0  # Power line noise at 60 Hz
# Q = 10  # Quality factor for the notch filter
# Q1 = 50
# f_240 = 240
# f_300 = 300

# # Create a notch filter for 60 Hz
# b_120, a_120 = signal.iirnotch(f0_120, 200, 30000)
# b_240, a_240 = signal.iirnotch(f_240, 100, 30000)

# for i in range(len(recording_names)):
#     rc = []
#     for ch in range(3):
#         # Remove artifacts for each channel
#         recording_artifact_removed = spre.remove_artifacts(
#             multirecording[i],
#             list_triggers=artifact_frames[i][ch],  # Combine triggers from all channels
#             ms_before=ms_before[0], ms_after=ms_after[0]  # Adjust based on your needs
#         )
#         channel_data_artifact_removed = recording_artifact_removed.get_traces()[:, ch]

#         # Append the artifact-removed channel data to rc list
#         rc.append(channel_data_artifact_removed)

#     # Combine the channels into one array
#     combined_data = np.stack(rc, axis=1)  # Stack along the channel axis

#     # Create a new RecordingExtractor object with combined data
#     sampling_frequency = multirecording[i].get_sampling_frequency()
#     combined_recording = NumpyRecording(traces_list=[combined_data], sampling_frequency=sampling_frequency)

#     # Append the filtered recording to the list
#     recording_art.append(combined_recording)

#     # Step 2: Apply CAR filter after artifact removal using numpy
#     car_filtered_data = combined_data - np.mean(combined_data, axis=1, keepdims=True)
#     recording_car_filtered = NumpyRecording(traces_list=[car_filtered_data], sampling_frequency=sampling_frequency)
#     recording_car.append(recording_car_filtered)

#     # Step 3: Apply the 57-63 Hz bandstop filter
#     traces_filtered = signal.filtfilt(b_bandstop, a_bandstop, recording_car[i].get_traces(), axis=0)
#     traces_filtered = signal.filtfilt(b_120, a_120, traces_filtered, axis=0)
#     traces_filtered = signal.filtfilt(b_bandstop2, a_bandstop2, traces_filtered, axis=0)
#     traces_filtered = signal.filtfilt(b_240, a_240, traces_filtered, axis=0)
#     traces_filtered = signal.filtfilt(b_bandstop3, a_bandstop3, traces_filtered, axis=0)

#     # Create a new RecordingExtractor with the filtered traces
#     recording_filtered = se.NumpyRecording([traces_filtered], sampling_frequency)

#     # Process LFP and MUA after filtering
#     lfp, mua = downsample_lfp_mua(recording_filtered)

#     # Apply filter bank to LFP
#     delta, s100_200, s200_400, gamma = filterbank(lfp)
#     recording_lfp.append(lfp)
#     recording_mua.append(mua)
#     recording_delta.append(delta)
#     recording_100_200.append(s100_200)
#     recording_200_400.append(s200_400)
#     recording_gamma.append(gamma)

#     # Resample artifact frames for further analysis
#     stim_frames = list(map(lambda x: int(x * (1000 / 30000)), np.concatenate(artifact_frames[i])))
#     artifact_frames_resample.append(stim_frames)



In [13]:
from scipy import signal
from spikeinterface.extractors import NumpyRecording
import numpy as np

recording_art = []
recording_lfp = []
recording_mua = []
recording_delta = []
recording_100_200 = []
recording_200_400 = []
recording_gamma = []
artifact_frames_resample = []

# Notches for 60 Hz line noise and its 3rd and 5th harmonics. iirnotch's Q is
# f0 / bandwidth, so Q == f0 gives a 1 Hz-wide notch at every frequency.
# Earlier candidate list: 9.0, 18.2, 25.0, 27.2, 29.2, 36.4, 37.6, 41.4, 43.8, 45.4, 46.4, 60.0, 180.0, 300.0 Hz
frequencies = [60.0, 180.0, 300.0]
qs = [60, 180, 300]
# Designed for 30 kHz; the loop re-designs them for each recording's actual rate.
notch_filters = [signal.iirnotch(f, q, 30000) for f, q in zip(frequencies, qs)]

for i in range(len(recording_names)):
    recording = multirecording[i]
    sampling_frequency = recording.get_sampling_frequency()
    channel_ids = recording.get_channel_ids()
    # Same coefficients as notch_filters when the recording is sampled at 30 kHz.
    recording_notches = [signal.iirnotch(f, q, sampling_frequency) for f, q in zip(frequencies, qs)]

    # Per channel: remove artifacts around that channel's own trigger frames and keep
    # only that channel's cleaned trace.
    rc = []
    for ch in range(3):
        triggers = artifact_frames[i][ch]
        if triggers:
            cleaned = spre.remove_artifacts(
                recording,
                list_triggers=triggers,
                ms_before=ms_before[0], ms_after=ms_after[0]
            )
        else:
            # No triggers on this channel: nothing to remove.
            cleaned = recording
        # Request only the channel that is kept instead of computing all of them.
        rc.append(cleaned.get_traces(channel_ids=[channel_ids[ch]])[:, 0])

    combined_data = np.stack(rc, axis=1)  # (samples, channels)
    combined_recording = NumpyRecording(traces_list=[combined_data], sampling_frequency=sampling_frequency)
    recording_art.append(combined_recording)

    # Zero-phase notch filtering, applied sequentially.
    traces_filtered = combined_data
    for b, a in recording_notches:
        traces_filtered = signal.filtfilt(b, a, traces_filtered, axis=0)
    recording_filtered = NumpyRecording([traces_filtered], sampling_frequency)

    # LFP/MUA at 1 kHz, then the LFP filter bank.
    lfp, mua = downsample_lfp_mua(recording_filtered)
    delta, s100_200, s200_400, gamma = filterbank(lfp)
    recording_lfp.append(lfp)
    recording_mua.append(mua)
    recording_delta.append(delta)
    recording_100_200.append(s100_200)
    recording_200_400.append(s200_400)
    recording_gamma.append(gamma)

    # Map trigger frames (all channels, concatenated) from the wideband rate onto the
    # LFP rate with exact integer floor division.
    lfp_rate = int(round(lfp.get_sampling_frequency()))
    wideband_rate = int(round(sampling_frequency))
    all_frames = np.concatenate(artifact_frames[i]).astype(np.int64)
    stim_frames = (all_frames * lfp_rate // wideband_rate).tolist()
    artifact_frames_resample.append(stim_frames)

In [14]:
# import numpy as np
# import matplotlib.pyplot as plt
# from scipy import signal

# # Function to plot PSD for all channels of a given recording
# def plot_psd(recording, title, fs=1000):
#     num_channels = recording.get_traces().shape[1]
#     fig, axes = plt.subplots(1, num_channels, figsize=(15, 5))
#     for ch in range(num_channels):
#         # Extract channel data
#         channel_data = recording.get_traces()[:, ch]
#         # Compute PSD
#         f, Pxx = signal.welch(channel_data, fs=fs, nperseg=fs*2)
#         # Plot PSD
#         axes[ch].semilogy(f, Pxx)
#         axes[ch].set_xlim([1, 30])  # Limit frequency range
        
#         axes[ch].set_title(f'Channel {ch+1} PSD')
#         axes[ch].set_xlabel('Frequency [Hz]')
#         axes[ch].set_ylabel('PSD [V^2/Hz]')
#         axes[ch].grid(True)
#     plt.suptitle(title, fontsize=16)
#     plt.tight_layout()
#     plt.show()

# # Generate 21 PSD plots for recording 0-6 (each with 3 channels)
# for i in range(7):
#     plot_psd(recording_lfp[i], title=f'Recording {i}: PSD for 3 Channels')


In [None]:
# # import numpy as np
# # import scipy.signal as sp_signal  # Rename scipy.signal to avoid conflicts
# # from spikeinterface.extractors import NumpyRecording
# # import matplotlib.pyplot as plt
# # import seaborn as sns
# # from scipy.signal import hilbert 

# # Auxiliary function
# def extract_phase(signal):
#     """Extract the instantaneous phase of a signal using the Hilbert transform"""
#     analytic_signal = hilbert(signal)
#     phase = np.angle(analytic_signal)
#     return phase

# def calculate_plv(phases):
#     """Compute the phase-locking value (PLV) between multiple signals"""
#     n_signals = phases.shape[0]
#     plv_matrix = np.zeros((n_signals, n_signals))
#     for i in range(n_signals):
#         for j in range(i + 1, n_signals):
#             phase_diff = phases[i] - phases[j]
#             plv_matrix[i, j] = np.abs(np.mean(np.exp(1j * phase_diff)))
#             plv_matrix[j, i] = plv_matrix[i, j]
#     return plv_matrix

# def bandpass_filter(signal, lowcut, highcut, fs, order=4):
#     """Band-pass filter"""
#     nyquist = 0.5 * fs
#     low = lowcut / nyquist
#     high = highcut / nyquist
#     b, a = sp_signal.butter(order, [low, high], btype="band")
#     filtered_signal = sp_signal.filtfilt(b, a, signal, axis=0)
#     return filtered_signal

# def calculate_band_plv(traces_filtered, fs, lowcut, highcut):
#     """Compute PLV for a specific frequency band"""
#     band_filtered = np.array([bandpass_filter(traces_filtered[:, ch], lowcut, highcut, fs) for ch in range(traces_filtered.shape[1])])
#     phases = np.array([extract_phase(band_filtered[ch]) for ch in range(band_filtered.shape[0])])
#     plv_matrix = calculate_plv(phases)
#     return plv_matrix

# # Create Notch filters
# frequencies_to_remove = [60]  # Frequencies to be removed
# qs = [60]  # High-quality factor for narrow-band filtering
# notch_filters = [sp_signal.iirnotch(f, q, 1000) for f, q in zip(frequencies_to_remove, qs)]

# # Iterate over recordings
# # for i in range(len(recording_art)):
# for i in range(len([1, 2])):
#     # Extract LFP signals
#     lfp = recording_lfp[i].get_traces()
#     start_time = 30  # Start time (seconds)
#     end_time = start_time + 2  # End time (seconds)

#     start_sample = int(start_time * fs)
#     end_sample = int(end_time * fs)
#     lfp = lfp[start_sample:end_sample, :]  # Extract the first 10 seconds of data
#     fs = recording_lfp[i].get_sampling_frequency()

#     # Step 1: Compute PLV for the original 5-20 Hz signal
#     plv_5_20hz_original = calculate_band_plv(lfp, fs, 55, 65.0)
#     print(f"Recording {i}: Original PLV for 5-20 Hz:\n{plv_5_20hz_original}")

#     # Step 2: Apply Notch filter to remove 9 Hz and 18 Hz
#     lfp_notch_filtered = lfp.copy()
#     for b, a in notch_filters:
#         lfp_notch_filtered = sp_signal.filtfilt(b, a, lfp_notch_filtered, axis=0)

#     # Step 3: Compute PLV for 5-20 Hz after Notch filtering
#     plv_5_20hz_filtered = calculate_band_plv(lfp_notch_filtered, fs, 55.0, 65.0)
#     print(f"Recording {i}: PLV for 5-20 Hz after Notch Filtering:\n{plv_5_20hz_filtered}")

#     # Visualize PLV matrix
#     def visualize_plv(plv_matrix, title):
#         plt.figure(figsize=(6, 4))
#         sns.heatmap(plv_matrix, annot=True, cmap="viridis", xticklabels=["Ch1", "Ch2", "Ch3"], yticklabels=["Ch1", "Ch2", "Ch3"])
#         plt.title(title)
#         plt.xlabel("Channels")
#         plt.ylabel("Channels")
#         plt.show()

#     # Visualize the PLV matrices for the original and filtered signals
#     visualize_plv(plv_5_20hz_original, f"Original PLV for 5-20 Hz (Recording {i})")
#     visualize_plv(plv_5_20hz_filtered, f"PLV for 5-20 Hz after Notch Filtering (Recording {i})")


In [None]:
# # PLV
# plv_5_10hz = calculate_band_plv(lfp, fs, 55, 58.0)
# plv_10_20hz = calculate_band_plv(lfp, fs, 62, 65.0)
# plv_5_10hz_filtered = calculate_band_plv(lfp_notch_filtered, fs, 55.0, 58.0)
# plv_10_20hz_filtered = calculate_band_plv(lfp_notch_filtered, fs, 62.0, 65.0)

# print(f"Original PLV for 5-10 Hz:\n{plv_5_10hz}")
# print(f"Filtered PLV for 5-10 Hz:\n{plv_5_10hz_filtered}")
# print(f"Original PLV for 10-20 Hz:\n{plv_10_20hz}")
# print(f"Filtered PLV for 10-20 Hz:\n{plv_10_20hz_filtered}")


In [None]:
def plot_time_frequency(recording, title, fs=1000, freq_lim=(1, 20)):
    """Plot a spectrogram (dB) for every channel of a recording, side by side.

    Args:
        recording: object exposing ``get_traces()`` returning a (samples, channels) array.
        title: figure super-title.
        fs: sampling frequency of ``recording`` in Hz (default 1000, the rate of the
            resampled LFP recordings).
        freq_lim: (low, high) frequency range shown on the y axis, in Hz.
    """
    traces = recording.get_traces()  # read once, not once per channel
    num_channels = traces.shape[1]
    # squeeze=False keeps `axes` 2-D so a single-channel recording also works.
    fig, axes = plt.subplots(1, num_channels, figsize=(15, 5), squeeze=False)
    for ch, ax in enumerate(axes[0]):
        f, t, Sxx = signal.spectrogram(traces[:, ch], fs=fs, nperseg=256, noverlap=128)
        mesh = ax.pcolormesh(t, f, 10 * np.log10(Sxx), shading='gouraud')
        fig.colorbar(mesh, ax=ax, label='Power density [dB]')
        ax.set_ylim(freq_lim)
        ax.set_title(f'Channel {ch+1}')
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Frequency [Hz]')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release memory when called in a loop

# One spectrogram figure (one panel per channel) for every processed LFP recording.
for i, lfp_recording in enumerate(recording_lfp):
    plot_time_frequency(lfp_recording, title=f'Recording {i}: Time-Frequency for 3 Channels')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import stft

def plot_stft(recording, title, fs=1000, freq_lim=(1, 20), time_lim=(5, 400)):
    """Plot the STFT magnitude for every channel of a recording, side by side.

    Args:
        recording: object exposing ``get_traces()`` returning a (samples, channels) array.
        title: figure super-title.
        fs: sampling frequency of ``recording`` in Hz (default 1000, the rate of the
            resampled LFP recordings).
        freq_lim: (low, high) y-axis range in Hz.
        time_lim: (start, end) x-axis range in seconds.
    """
    traces = recording.get_traces()  # read once, not once per channel
    num_channels = traces.shape[1]
    # squeeze=False keeps `axes` 2-D so a single-channel recording also works.
    fig, axes = plt.subplots(1, num_channels, figsize=(15, 5), squeeze=False)
    for ch, ax in enumerate(axes[0]):
        f, t, Zxx = stft(traces[:, ch], fs=fs, nperseg=256, noverlap=128)
        mesh = ax.pcolormesh(t, f, np.abs(Zxx), shading='gouraud')
        fig.colorbar(mesh, ax=ax, label='|STFT|')
        ax.set_ylim(freq_lim)
        ax.set_xlim(time_lim)
        ax.set_title(f'Channel {ch+1}')
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Frequency [Hz]')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release memory when called in a loop

# One STFT figure (one panel per channel) for every processed LFP recording.
for i, lfp_recording in enumerate(recording_lfp):
    plot_stft(lfp_recording, title=f'Recording {i}: STFT Time-Frequency for 3 Channels')


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the artifact-removed wideband traces of the first recording (one panel per
# channel) and print each channel's mean absolute amplitude.
recording_to_plot = recording_art[0]
fs = recording_to_plot.get_sampling_frequency()  # wideband rate (30 kHz for these files)
time_window = 0      # first second of data to include
plot_seconds = 300   # width of the x-axis window, in seconds
start_sample = int(time_window * fs)

raw_data_post = recording_to_plot.get_traces()[start_sample:, :]
time_axis = np.arange(start_sample, start_sample + raw_data_post.shape[0]) / fs

num_channels = raw_data_post.shape[1]
fig, axs = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True, squeeze=False)
axs = axs[:, 0]  # 1-D array of Axes even for a single channel
# Clamp to the data: a fixed index into time_axis raises IndexError on short recordings.
x_end_index = min(int(plot_seconds * fs), len(time_axis) - 1)

for i in range(num_channels):
    axs[i].plot(time_axis, raw_data_post[:, i])
    axs[i].set_title(f'Channel {i+1}')
    axs[i].set_ylabel('Amplitude (V)')
    axs[i].set_xlim([time_axis[0], time_axis[x_end_index]])
axs[-1].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()

average_amplitudes = np.mean(np.abs(raw_data_post), axis=0)
for i, avg_amp in enumerate(average_amplitudes):
    print(f'Average amplitude for Channel {i+1}: {avg_amp:.6f} V')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the 1 kHz LFP of the first recording (one panel per channel) and print each
# channel's mean absolute amplitude.
lfp_to_plot = recording_lfp[0]
fs = lfp_to_plot.get_sampling_frequency()  # 1 kHz after downsample_lfp_mua
time_window = 0      # first second of data to include
plot_seconds = 300   # width of the x-axis window, in seconds
start_sample = int(time_window * fs)

raw_data_post = lfp_to_plot.get_traces()[start_sample:, :]
time_axis = np.arange(start_sample, start_sample + raw_data_post.shape[0]) / fs

num_channels = raw_data_post.shape[1]
fig, axs = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True, squeeze=False)
axs = axs[:, 0]  # 1-D array of Axes even for a single channel
# Clamp to the data: a fixed index into time_axis raises IndexError on short recordings.
x_end_index = min(int(plot_seconds * fs), len(time_axis) - 1)

for i in range(num_channels):
    axs[i].plot(time_axis, raw_data_post[:, i])
    axs[i].set_title(f'Channel {i+1}')
    axs[i].set_ylabel('Amplitude (V)')
    axs[i].set_xlim([time_axis[0], time_axis[x_end_index]])

axs[-1].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()

average_amplitudes = np.mean(np.abs(raw_data_post), axis=0)
for i, avg_amp in enumerate(average_amplitudes):
    print(f'Average amplitude for Channel {i+1}: {avg_amp:.6f} V')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# Pairwise channel coherence in a pre-stimulus window for every recording,
# plus the across-recording average in the last row of the figure.
fs_stim = 1000  # Sampling frequency of the LFP traces (Hz)
nperseg = 5*fs_stim  # Welch segment length: 5 s -> 0.2 Hz frequency resolution
overlap = 0.5*nperseg  # Segment overlap (50 %); cast to int when passed to scipy
titles = ['North-East', 'North-West', 'East-West']
ind1 = [0, 0, 1]  # First channel (column) of each pair
ind2 = [1, 2, 2]  # Second channel (column) of each pair
time_buffer = 0.7 #seconds before/after stimulus, to avoid artifacts
time_period = 2 #seconds
# x-range (Hz) for every panel. With sharex=True the last set_xlim wins for all
# axes, so a single value is used (the average row previously forced 1-20 Hz).
freq_lim = [1, 50]

n_rec = len(recording_names)
# One row per recording plus a final row for the average; one column per channel pair.
fig, axs = plt.subplots(n_rec + 1, 3, sharex=True, sharey=True, dpi=400,
                        figsize=(20, 3.75 * (n_rec + 1)), squeeze=False)
coherence_sum = np.zeros((nperseg // 2 + 1, 3))  # Running sum of coherence spectra per pair

# First stimulation onset of each recording (used by the alternative window below).
stim_start_array_new = []
for stim in stim_start_array:
    stim_start1 = stim[0] if len(stim) > 0 else np.nan
    stim_start_array_new.append(stim_start1)

for i in range(n_rec):
    # Pre-stimulus window: 5 s to 255 s of the LFP.
    # NOTE(review): assumes the first stimulus occurs after 255 s - confirm against stim_start_array_new.
    segment = recording_lfp[i].get_traces()[5 * fs_stim : 255 * fs_stim, :]
    #segment = recording_lfp[i].get_traces()[int(stim_start_array_new[i]* fs_stim-time_period* fs_stim-time_buffer* fs_stim):int(stim_start_array_new[i]* fs_stim-time_buffer* fs_stim),:]

    for j in range(3):
        # The *_post names are historical; these are pre-stimulus samples.
        x_post = segment[:, ind1[j]]
        y_post = segment[:, ind2[j]]
        f, Cxy_post = signal.coherence(x_post, y_post, fs_stim, nperseg=nperseg, noverlap=int(overlap))

        coherence_sum[:, j] += Cxy_post  # Accumulate coherence values

        axs[i, j].semilogy(f, Cxy_post)
        axs[i, j].set_title(f'{titles[j]} (Pre-stimulus)')
        axs[i, j].set_xlabel('Frequency [Hz]')
        axs[i, j].set_ylabel('Coherence')
        axs[i, j].set_xlim(freq_lim)
        axs[i, j].set_ylim([10**(-5), 10**(0)])

# Average coherence across recordings (divisor follows the recording count).
coherence_avg1 = coherence_sum / n_rec

# Plot the average coherence in the last row
for j in range(3):
    axs[n_rec, j].semilogy(f, coherence_avg1[:, j])
    axs[n_rec, j].set_title(f'{titles[j]} (Average)')
    axs[n_rec, j].set_xlabel('Frequency [Hz]')
    axs[n_rec, j].set_ylabel('Coherence')
    axs[n_rec, j].set_xlim(freq_lim)
    axs[n_rec, j].set_ylim([10**(-5), 10**(0)])

plt.tight_layout()
plt.show()

In [18]:
# import numpy as np
# import matplotlib.pyplot as plt
# from scipy import signal

# fs_stim = 1000  # Sampling frequency for stimulus data
# nperseg = 0.8*fs_stim  # Window size for coherence calculation
# fig, axs = plt.subplots(8, 3, sharex=True, sharey=True, dpi=400, figsize=(20, 30))  # Six rows, three columns
# titles = ['North-East', 'North-West', 'East-West']
# ind1 = [0, 0, 1]
# ind2 = [1, 2, 2]

# coherence_sum = np.zeros((int(nperseg/2+1), 3))  # Initialize an array to store sum of coherence values
# # Step 1: Calculate Pre-Stimulus Coherence (First 5 Minutes)

# # Step 2: Calculate Post-Stimulus Coherence (10 seconds before first stimulus, 4 seconds duration)
# stim_start_array_old = []
# stim_end_array_old = []
# for stim in stim_times_array:
#     stim_start = stim[0] - 10 
#     stim_end = stim_start + 4 
#     stim_start_array_old.append(stim_start)
#     stim_end_array_old.append(stim_end)

# snr_values = []  # Store SNR values for visualization

# for i in range(len(recording_names)):
#     segment = recording_lfp[i].get_traces()[int(stim_start_array_old[0]):int(stim_end_array_old[0]), :]
    
#     # Calculate SNR using RMS method
#     signal_rms = np.sqrt(np.mean(segment**2, axis=0))
#     noise_segment = recording_lfp[i].get_traces()[int(stim_start_array[0][0]):int(stim_start_array[0][0])+2, :]  # Take the first 5 seconds as baseline noise
#     noise_rms = np.sqrt(np.mean(noise_segment**2, axis=0))
#     snr_rms = 20 * np.log10(signal_rms / noise_rms)
#     snr_values.append(snr_rms)
    
#     for j in range(3):
#         x_post = segment[:, ind1[j]]
#         y_post = segment[:, ind2[j]]
#         f, Cxy_post = signal.coherence(x_post, y_post, fs_stim, nperseg=nperseg)
        
#         coherence_sum[:, j] += Cxy_post  # Accumulate coherence values
        
#         axs[i, j].semilogy(f, Cxy_post)
#         axs[i, j].set_title(f'{titles[j]} (Pre-stimulus)')
#         axs[i, j].set_xlabel('Frequency [Hz]')
#         axs[i, j].set_ylabel('Coherence')
#         axs[i, j].set_xlim([1, 400])
#         axs[i, j].set_ylim([10**(-5), 10**(0)])

# # Calculate the average coherence
# coherence_avg1 = coherence_sum / 7

# # Plot the average coherence in the last row
# for j in range(3):
#     axs[7, j].semilogy(f, coherence_avg1[:, j])
#     axs[7, j].set_title(f'{titles[j]} (Average)')
#     axs[7, j].set_xlabel('Frequency [Hz]')
#     axs[7, j].set_ylabel('Coherence')
#     axs[7, j].set_xlim([1, 400])
#     axs[7, j].set_ylim([10**(-5), 10**(0)])

# plt.tight_layout()
# plt.show()




In [19]:
# stim_start_array_old = []
# stim_end_array_old = []
# for stim in stim_times_array:
#     stim_start = stim[0] - 10 
#     stim_end = stim_start + 4 
#     stim_start_array_old.append(stim_start)
#     stim_end_array_old.append(stim_end)
# stim_start_array_old

In [20]:
# # Plot SNR values per channel for each recording
# plt.figure(dpi=200, figsize=(10, 6))
# snr_values_array = np.array(snr_values)
# for ch in range(snr_values_array.shape[1]):
#     plt.plot(snr_values_array[:, ch], label=f'Channel {ch+1}', marker='o')
# plt.xlabel('Recording Index')
# plt.ylabel('SNR (dB)')
# plt.title('SNR of Different Channels Using RMS Method')
# plt.legend()
# plt.show()


In [21]:
# Metadata for organoid BO11 (7 weeks old), Day-4 stimulation sweep.
# Position k of every list below refers to the same stimulation amplitude.
recording_names1 = [f'BO11_O7W_Sti_Day4_{ua}uA_m.nwb'
                    for ua in (5, 10, 20, 30, 40, 50, 60)]

amps1 = [f'{ua}ua' for ua in (5, 10, 20, 30, 40, 50, 60)]
intan_path1 = 'C:/Users/27707/Documents/jhu_master/LFP_analysis/2024-07-26_DG_3elec/Experimental Set_Final Round_Batch 27/7 Week Old Organoid/Stimulation Day1/'
# Intan session folders: (amplitude in uA, acquisition timestamp) per amplitude.
intan_folders1 = [
    f'{ua}uA_233.3uS_Biphasic Pulse/BO11_O7W_Stimulation_Day4_{ua}uA_{stamp}/merged/'
    for ua, stamp in (
        (5, '240708_144920'),
        (10, '240708_145755'),
        (20, '240708_150743'),
        (30, '240708_151657'),
        (40, '240708_152646'),
        (50, '240708_153556'),
        (60, '240708_154509'),
    )
]

In [None]:
# Time index of every stimulation sample in each BO11 recording: a sample
# counts as stimulation when any of the three STIM_A channels is non-zero.
stim_names_BO71 = ['STIM_A-000','STIM_A-002','STIM_A-004']
stim_times_array1 = []
for i in range(len(recording_names1)):
    BO7_res1 = pull_files(intan_folders1[i], intan_path1)
    stim_BO71 = all_intan_data(stim_names_BO71, 'stim_data', BO7_res1)
    is_stim1 = (stim_BO71[stim_names_BO71] != 0).any(axis=1)
    stim_times1 = stim_BO71.loc[is_stim1].index
    stim_times_array1.append(stim_times1.to_numpy())

In [None]:
# Split each recording's stimulation timestamps into bursts: consecutive
# timestamps more than 1 apart (in the units of the stim-data index -
# NOTE(review): appears to be seconds, confirm) start a new burst.
# Works on the raw array, so a first burst that begins within 1 of t=0 is kept
# and an empty input yields no starts and no ends.
stim_start_array1 = []
stim_end_array1 = []
for stim in stim_times_array1:
    stim = np.asarray(stim)
    if stim.size == 0:
        stim_start1 = stim[:0]
        stim_end1 = stim[:0]
    else:
        # gaps1[k] = index of the last sample before a gap; stim[gaps1 + 1] starts the next burst.
        gaps1 = np.flatnonzero(np.diff(stim) > 1)
        stim_start1 = np.concatenate(([stim[0]], stim[gaps1 + 1]))
        stim_end1 = np.concatenate((stim[gaps1], [stim[-1]]))
    print(stim_start1)
    print(stim_end1)
    stim_start_array1.append(stim_start1)
    stim_end_array1.append(stim_end1)

In [None]:
# Load every BO11 recording and build a 300-6000 Hz band-passed copy of each.
num_samples1 = []
num_channels1 = []
multirecording1 = []
multirecording_f1 = []
# Relative paths (folder + file name) of each recording; built once, not on every iteration.
recording_names1 = [folder + 'BO11_O7W_Sti_Day4_' + amp.split('u')[0] + 'uA_m.nwb'
                    for folder, amp in zip(intan_folders1, amps1)]
for i in range(len(recording_names1)):
    rec_names1 = intan_path1 + recording_names1[i]
    day_recording1 = load_data(rec_names1)
    # Shape from recording metadata; avoids materialising the whole trace array.
    num_pre_samples1 = day_recording1.get_num_samples()
    num_pre_channels1 = day_recording1.get_num_channels()
    num_samples1.append(num_pre_samples1)
    num_channels1.append(num_pre_channels1)
    print('Presample:' + str(num_pre_samples1))

    multirecording_1 = day_recording1  # Kept as a global for later cells (no concatenation is done here)
    multirecording1.append(day_recording1)
    multi_shank1 = gen_probe()  # NOTE(review): the generated probe is not used in this cell
    multirecording_filter1 = spre.bandpass_filter(day_recording1, freq_min=300, freq_max=6000)
    multirecording_f1.append(multirecording_filter1)

In [25]:
# Artifact-blanking window around each trigger (ms), used by remove_artifacts later.
ms_before = [100]
ms_after = [100]


# Step 1: RMS-based artifact thresholds for each BO11 recording and channel,
# estimated from the start_time..end_time window.
fs = 30000  # Nominal sampling rate (Hz); each recording's own rate is used below
start_time = 5  # Window start (s)
end_time = 255  # Window end (s)
# Samples in the 250 s window at `fs`.
# NOTE(review): this rebinds the global `num_samples`; an earlier cell may have stored a list under that name.
num_samples = fs * (end_time - start_time)
threshold_rms_factor = 5.0  # Threshold = factor x channel RMS

thresholds1 = []  # thresholds1[i][ch] -> threshold for recording i, channel ch

for i in range(len(recording_names1)):
    rec_fs1 = multirecording1[i].get_sampling_frequency()
    start_frame1 = int(start_time * rec_fs1)
    end_frame1 = min(int(end_time * rec_fs1), multirecording1[i].get_num_samples())
    # Read only the threshold window, once per recording (not the full recording once per channel).
    threshold_window1 = multirecording1[i].get_traces(start_frame=start_frame1, end_frame=end_frame1)

    thresholds_per_recording1 = []
    for ch in range(3):  # Assuming there are 3 channels
        # Cast to float32 before squaring (avoids overflow if traces are integer-typed).
        channel1 = threshold_window1[:, ch].astype(np.float32)
        rms_value = np.sqrt(np.mean(channel1**2))
        threshold = rms_value * threshold_rms_factor
        thresholds_per_recording1.append(threshold)

    thresholds1.append(thresholds_per_recording1)


In [None]:
# Display the per-recording, per-channel artifact thresholds (thresholds1[i][ch]).
thresholds1

In [None]:
# Artifact triggers per channel: every sample whose absolute value exceeds that
# channel's RMS threshold. artifact_frames1[i][ch] is a list of sample indices.
artifact_frames1 = []
for i in range(len(recording_names1)):
    artifact_frames_per_recording1 = []  # Store artifact frames for each channel in a recording
    # Load the traces once per recording instead of once per channel.
    traces_rec1 = multirecording1[i].get_traces()

    for ch in range(3):  # Assuming there are 3 channels
        channel1 = traces_rec1[:, ch].astype(np.float32)

        # Use predefined threshold for the current recording and channel
        threshold1 = thresholds1[i][ch]
        print(f"Threshold for recording {i+1}, channel {ch+1}: {threshold1}")

        # Apply the custom threshold
        stimulation_trigger_frames1 = np.where(np.abs(channel1) > threshold1)

        # Only call remove_artifacts when there is at least one trigger.
        if stimulation_trigger_frames1[0].size > 0:
            multirecording_art1 = spre.remove_artifacts(
                multirecording1[i],
                list_triggers=stimulation_trigger_frames1[0].tolist(),
                ms_before=ms_before[0],
                ms_after=ms_after[0]
            )
        else:
            print(f"No artifacts detected for recording {i+1}, channel {ch+1}")

        artifact_frames_per_recording1.append(stimulation_trigger_frames1[0].tolist())

    artifact_frames1.append(artifact_frames_per_recording1)

In [29]:
from scipy import signal
from spikeinterface.extractors import NumpyRecording
import numpy as np

# Per BO11 recording: blank artifacts channel by channel, notch out line noise,
# then derive the LFP/MUA and the LFP filter-bank bands.
recording_art1 = []
recording_lfp1 = []
recording_mua1= []
recording_delta1 = []
recording_100_2001 = []
recording_200_4001 = []
recording_gamma1 = []
artifact_frames_resample1 = []

# Notch frequencies (Hz) and quality factors; Q equal to the frequency gives a ~1 Hz wide notch.
#frequencies = [ 9.0, 18.2, 25.0, 27.2, 29.2, 36.4, 37.6, 41.4, 43.8, 45.4, 46.4, 60.0, 180.0, 300.0]
frequencies = [60.0, 180.0, 300.0]
qs = [60, 180, 300]

for i in range(len(recording_names1)):
    sampling_frequency = multirecording1[i].get_sampling_frequency()
    rc1 = []
    for ch in range(3):
        triggers1 = artifact_frames1[i][ch]  # This channel's own triggers
        if len(triggers1) > 0:
            recording_artifact_removed1 = spre.remove_artifacts(
                multirecording1[i],
                list_triggers=triggers1,
                ms_before=ms_before[0], ms_after=ms_after[0]
            )
            channel_data_artifact_removed1 = recording_artifact_removed1.get_traces()[:, ch]
        else:
            # No triggers on this channel: nothing to blank, keep the raw channel.
            channel_data_artifact_removed1 = multirecording1[i].get_traces()[:, ch]

        rc1.append(channel_data_artifact_removed1)

    # Combine the channels into one (samples, channels) array
    combined_data1 = np.stack(rc1, axis=1)
    combined_recording1 = NumpyRecording(traces_list=[combined_data1], sampling_frequency=sampling_frequency)
    recording_art1.append(combined_recording1)

    # Notch filters designed for this recording's actual sampling rate.
    notch_filters = [signal.iirnotch(f, q, sampling_frequency) for f, q in zip(frequencies, qs)]
    traces_filtered = combined_recording1.get_traces()
    for b, a in notch_filters:
        traces_filtered = signal.filtfilt(b, a, traces_filtered, axis=0)

    recording_filtered1 = NumpyRecording([traces_filtered], sampling_frequency)

    # Process LFP and MUA after filtering
    lfp1, mua1 = downsample_lfp_mua(recording_filtered1)

    # Filter bank on THIS recording's LFP. Previously `lfp` was passed; it is not
    # assigned in this loop, so it held another recording's LFP.
    delta1, s100_2001, s200_4001, gamma1 = filterbank(lfp1)
    recording_lfp1.append(lfp1)
    recording_mua1.append(mua1)
    recording_delta1.append(delta1)
    recording_100_2001.append(s100_2001)
    recording_200_4001.append(s200_4001)
    recording_gamma1.append(gamma1)

    # Map raw-rate artifact frames to LFP frames.
    # NOTE(review): assumes downsample_lfp_mua outputs 1 kHz (fs_stim = 1000 in the coherence cells) - confirm.
    stim_frames1 = [int(x * (1000 / sampling_frequency)) for x in np.concatenate(artifact_frames1[i])]
    artifact_frames_resample1.append(stim_frames1)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the artifact-removed traces of the third BO11 recording (index 2).
fs1 = 30000  # Sampling frequency (Hz) of the artifact-removed traces
time_window1 = 0  # Start of the plotted window (s)
start_sample1 = int(time_window1 * fs1)  # First sample included in the plot

# Extract the data from `time_window1` seconds to the end
raw_data_post1 = recording_art1[2].get_traces()[start_sample1:, :]

# Time axis in seconds
time_axis1 = np.arange(start_sample1, start_sample1 + raw_data_post1.shape[0]) / fs1

# One subplot per channel.
# NOTE(review): this rebinds num_channels1, which the loading cell built as a per-recording list.
num_channels1 = raw_data_post1.shape[1]
fig1, axs1 = plt.subplots(num_channels1, 1, figsize=(15, 10), sharex=True)

# Right edge: sample 9_000_000 (300 s at 30 kHz), clamped to the last available
# sample so shorter recordings do not raise IndexError.
x_end_idx1 = min(9000000, len(time_axis1) - 1)

for i in range(num_channels1):
    axs1[i].plot(time_axis1, raw_data_post1[:, i])
    axs1[i].set_title(f'Channel {i+1}')
    axs1[i].set_ylabel('Amplitude (V)')
    axs1[i].set_xlim([time_axis1[0], time_axis1[x_end_idx1]])
    # axs[i].set_xlim([time_axis[6900000], time_axis[8000000]])
    #axs1[i].set_ylim([-0.000025, 0.000025])
axs1[-1].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()

# Mean absolute amplitude of each channel over the extracted data
average_amplitudes1 = np.mean(np.abs(raw_data_post1), axis=0)
for i, avg_amp in enumerate(average_amplitudes1):
    print(f'Average amplitude for Channel {i+1}: {avg_amp:.6f} V')

In [32]:
# import numpy as np
# import matplotlib.pyplot as plt

# # Set the parameters
# fs1 = 1000  # Sampling frequency (30kHz)
# time_window1 = 0  # 5 minutes in seconds
# start_sample1 = int(time_window1 * fs1)  # Starting sample at the 5-minute mark

# # Extract the data from 5 minutes to the end
# raw_data_post1 = recording_lfp1[2].get_traces()[start_sample1:, :]

# # Calculate the time axis for the plot
# time_axis1 = np.arange(start_sample1, start_sample1 + raw_data_post1.shape[0]) / fs1

# # Plot the raw signal for each channel
# num_channels1 = raw_data_post1.shape[1]
# fig1, axs1 = plt.subplots(num_channels1, 1, figsize=(15, 10), sharex=True)

# for i in range(num_channels1):
#     axs1[i].plot(time_axis1, raw_data_post1[:, i])
#     axs1[i].set_title(f'Channel {i+1}')
#     axs1[i].set_ylabel('Amplitude (V)')
#     #axs1[i].set_xlim([time_axis1[135000], time_axis1[145000]])
#     axs1[i].set_xlim([time_axis1[0], time_axis1[300000]])
# axs1[-1].set_xlabel('Time (s)')
# plt.tight_layout()
# plt.show()

# # Calculate and print the average amplitude for each channel
# average_amplitudes1 = np.mean(np.abs(raw_data_post1), axis=0)
# for i, avg_amp in enumerate(average_amplitudes1):
#     print(f'Average amplitude for Channel {i+1}: {avg_amp:.6f} V')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# Pairwise channel coherence in a pre-stimulus window for every BO11 recording,
# plus the across-recording average in the last row of the figure.
fs_stim1 = 1000  # Sampling frequency of the LFP traces (Hz)
nperseg1 = 5*fs_stim1  # Welch segment length: 5 s -> 0.2 Hz frequency resolution
# Segment overlap (20 %); cast to int when passed to scipy.
# NOTE(review): the first dataset's coherence cell uses 50 % - confirm which is intended for comparability.
overlap = 0.2*nperseg1
titles = ['North-East', 'North-West', 'East-West']
ind1 = [0, 0, 1]  # First channel (column) of each pair
ind2 = [1, 2, 2]  # Second channel (column) of each pair
freq_lim1 = [1, 50]  # x-range (Hz) shared by every panel (sharex=True)

n_rec1 = len(recording_names1)
# One row per recording plus a final row for the average; one column per channel pair.
fig1, axs1 = plt.subplots(n_rec1 + 1, 3, sharex=True, sharey=True, dpi=400,
                          figsize=(20, 3.75 * (n_rec1 + 1)), squeeze=False)
coherence_sum1 = np.zeros((nperseg1 // 2 + 1, 3))  # Running sum of coherence spectra per pair

# First stimulation onset of each recording (used by the alternative window below).
stim_start_array_new1 = []
for stim in stim_start_array1:
    stim_start1 = stim[0] if len(stim) > 0 else np.nan
    stim_start_array_new1.append(stim_start1)

for i in range(n_rec1):
    # Pre-stimulus window: 5 s to 255 s of the LFP.
    # NOTE(review): assumes the first stimulus occurs after 255 s - confirm against stim_start_array_new1.
    segment1 = recording_lfp1[i].get_traces()[5 * fs_stim1 : 255 * fs_stim1, :]
    #segment1 = recording_lfp1[i].get_traces()[int(stim_start_array_new1[i]* fs_stim1-time_period* fs_stim1-time_buffer* fs_stim1):int(stim_start_array_new1[i]* fs_stim1-time_buffer* fs_stim1),:]

    for j in range(3):
        # The *_post names are historical; these are pre-stimulus samples.
        x_post1 = segment1[:, ind1[j]]
        y_post1 = segment1[:, ind2[j]]
        f_1, Cxy_post1 = signal.coherence(x_post1, y_post1, fs_stim1, nperseg=nperseg1, noverlap=int(overlap))

        coherence_sum1[:, j] += Cxy_post1  # Accumulate coherence values

        axs1[i, j].semilogy(f_1, Cxy_post1)
        axs1[i, j].set_title(f'{titles[j]} (Pre-stimulus)')
        axs1[i, j].set_xlabel('Frequency [Hz]')
        axs1[i, j].set_ylabel('Coherence')
        axs1[i, j].set_xlim(freq_lim1)
        axs1[i, j].set_ylim([10**(-5), 10**(0)])

# Average coherence across recordings (divisor follows the recording count).
coherence_avg2 = coherence_sum1 / n_rec1

# Plot the average coherence in the last row
for j in range(3):
    axs1[n_rec1, j].semilogy(f_1, coherence_avg2[:, j])
    axs1[n_rec1, j].set_title(f'{titles[j]} (Average)')
    axs1[n_rec1, j].set_xlabel('Frequency [Hz]')
    axs1[n_rec1, j].set_ylabel('Coherence')
    axs1[n_rec1, j].set_xlim(freq_lim1)
    axs1[n_rec1, j].set_ylim([10**(-5), 10**(0)])

plt.tight_layout()
plt.show()

In [34]:
# Metadata for organoid BO5 (7 weeks old), Day-4 stimulation sweep.
# Position k of every list below refers to the same stimulation amplitude.
recording_names = [f'BO5_O7W_Sti_Day4_{ua}uA_m.nwb'
                   for ua in (5, 10, 20, 30, 40, 50, 60)]

amps = [f'{ua}ua' for ua in (5, 10, 20, 30, 40, 50, 60)]
intan_path = 'C:/Users/27707/Documents/jhu_master/LFP_analysis/2024-07-26_DG_3elec/Experimental Set_Final Round_Batch 27/7 Week Old Organoid/Stimulation Day1/'
# Intan session folders: (amplitude in uA, acquisition timestamp) per amplitude.
intan_folders = [
    f'{ua}uA_233.3uS_Biphasic Pulse/BO5_O7W_Stimulation_Day4_{ua}uA_{stamp}/merged/'
    for ua, stamp in (
        (5, '240708_122554'),
        (10, '240708_123516'),
        (20, '240708_124435'),
        (30, '240708_125400'),
        (40, '240708_130304'),
        (50, '240708_131207'),
        (60, '240708_132110'),
    )
]

In [None]:
# Time index of every stimulation sample in each BO5 recording: a sample
# counts as stimulation when any of the three STIM_A channels is non-zero.
stim_names_BO7 = ['STIM_A-000','STIM_A-002','STIM_A-004']
stim_times_array = []
for i in range(len(recording_names)):
    BO7_res = pull_files(intan_folders[i], intan_path)
    stim_BO7 = all_intan_data(stim_names_BO7, 'stim_data', BO7_res)
    is_stim = (stim_BO7[stim_names_BO7] != 0).any(axis=1)
    stim_times = stim_BO7.loc[is_stim].index
    stim_times_array.append(stim_times.to_numpy())

In [None]:
# Split each recording's stimulation timestamps into bursts: consecutive
# timestamps more than 1 apart (in the units of the stim-data index -
# NOTE(review): appears to be seconds, confirm) start a new burst.
# Works on the raw array, so a first burst that begins within 1 of t=0 is kept
# and an empty input yields no starts and no ends.
stim_start_array = []
stim_end_array = []
for stim in stim_times_array:
    stim = np.asarray(stim)
    if stim.size == 0:
        stim_start = stim[:0]
        stim_end = stim[:0]
    else:
        # gaps[k] = index of the last sample before a gap; stim[gaps + 1] starts the next burst.
        gaps = np.flatnonzero(np.diff(stim) > 1)
        stim_start = np.concatenate(([stim[0]], stim[gaps + 1]))
        stim_end = np.concatenate((stim[gaps], [stim[-1]]))
    print(stim_start)
    print(stim_end)
    stim_start_array.append(stim_start)
    stim_end_array.append(stim_end)

In [None]:
# Load every BO5 recording and build a 300-6000 Hz band-passed copy of each.
num_samples = []
num_channels = []
multirecording = []
multirecording_f = []
# Relative paths (folder + file name) of each recording; built once, not on every iteration.
recording_names = [folder + 'BO5_O7W_Sti_Day4_' + amp.split('u')[0] + 'uA_240708_m.nwb'
                   for folder, amp in zip(intan_folders, amps)]
for i in range(len(recording_names)):
    rec_names = intan_path + recording_names[i]
    day_recording = load_data(rec_names)
    # Shape from recording metadata; avoids materialising the whole trace array.
    num_pre_samples = day_recording.get_num_samples()
    num_pre_channels = day_recording.get_num_channels()
    num_samples.append(num_pre_samples)
    num_channels.append(num_pre_channels)
    print('Presample:' + str(num_pre_samples))

    multirecording_1 = day_recording  # Kept as a global for later cells (no concatenation is done here)
    multirecording.append(day_recording)
    multi_shank = gen_probe()  # NOTE(review): the generated probe is not used in this cell
    multirecording_filter = spre.bandpass_filter(day_recording, freq_min=300, freq_max=6000)
    multirecording_f.append(multirecording_filter)

In [None]:

# Step 1: RMS-based artifact thresholds for each BO5 recording and channel,
# estimated from the start_time..end_time window.
fs = 30000  # Nominal sampling rate (Hz); each recording's own rate is used below
start_time = 5  # Window start (s)
end_time = 255  # Window end (s)
# Samples in the 250 s window at `fs`.
# NOTE(review): this rebinds `num_samples`, which the loading cell built as a per-recording list.
num_samples = fs * (end_time - start_time)
threshold_rms_factor = 5.0  # Threshold = factor x channel RMS

thresholds = []  # thresholds[i][ch] -> threshold for recording i, channel ch

for i in range(len(recording_names)):
    rec_fs = multirecording[i].get_sampling_frequency()
    start_frame = int(start_time * rec_fs)
    end_frame = min(int(end_time * rec_fs), multirecording[i].get_num_samples())
    # Read only the threshold window, once per recording (not the full recording once per channel).
    threshold_window = multirecording[i].get_traces(start_frame=start_frame, end_frame=end_frame)

    thresholds_per_recording = []
    for ch in range(3):  # Assuming there are 3 channels
        # Cast to float32 before squaring (avoids overflow if traces are integer-typed).
        channel = threshold_window[:, ch].astype(np.float32)
        rms_value = np.sqrt(np.mean(channel**2))
        threshold = rms_value * threshold_rms_factor
        thresholds_per_recording.append(threshold)

    thresholds.append(thresholds_per_recording)

# Step 2: artifact triggers per channel - every sample whose absolute value
# exceeds that channel's threshold. artifact_frames[i][ch] is a list of sample indices.
ms_before = [100]  # Blanking window before each trigger (ms)
ms_after = [100]   # Blanking window after each trigger (ms)
artifact_frames = []

for i in range(len(recording_names)):
    artifact_frames_per_recording = []  # Store artifact frames for each channel in a recording
    # Load the traces once per recording instead of once per channel.
    traces_rec = multirecording[i].get_traces()

    for ch in range(3):  # Assuming there are 3 channels
        channel = traces_rec[:, ch].astype(np.float32)

        # Use predefined threshold for the current recording and channel
        threshold = thresholds[i][ch]
        print(f"Threshold for recording {i+1}, channel {ch+1}: {threshold}")

        # Apply the custom threshold
        stimulation_trigger_frames = np.where(np.abs(channel) > threshold)

        # Only call remove_artifacts when there is at least one trigger.
        if stimulation_trigger_frames[0].size > 0:
            multirecording_art = spre.remove_artifacts(
                multirecording[i],
                list_triggers=stimulation_trigger_frames[0].tolist(),
                ms_before=ms_before[0],
                ms_after=ms_after[0]
            )
        else:
            print(f"No artifacts detected for recording {i+1}, channel {ch+1}")

        artifact_frames_per_recording.append(stimulation_trigger_frames[0].tolist())

    artifact_frames.append(artifact_frames_per_recording)


In [None]:
# Display the per-recording, per-channel artifact thresholds (thresholds[i][ch]).
thresholds

In [42]:
from scipy import signal
from spikeinterface.extractors import NumpyRecording
import numpy as np

# Per BO5 recording: blank artifacts channel by channel, notch out line noise,
# then derive the LFP/MUA and the LFP filter-bank bands.
recording_art = []
recording_lfp = []
recording_mua = []
recording_delta = []
recording_100_200 = []
recording_200_400 = []
recording_gamma = []
artifact_frames_resample = []

# Notch frequencies (Hz) and quality factors; Q equal to the frequency gives a ~1 Hz wide notch.
#frequencies = [ 9.0, 18.2, 25.0, 27.2, 29.2, 36.4, 37.6, 41.4, 43.8, 45.4, 46.4, 60.0, 180.0, 300.0]
frequencies = [60.0, 180.0, 300.0]
qs = [60, 180, 300]

for i in range(len(recording_names)):
    sampling_frequency = multirecording[i].get_sampling_frequency()
    rc = []
    for ch in range(3):
        triggers = artifact_frames[i][ch]  # This channel's own triggers
        if len(triggers) > 0:
            recording_artifact_removed = spre.remove_artifacts(
                multirecording[i],
                list_triggers=triggers,
                ms_before=ms_before[0], ms_after=ms_after[0]
            )
            channel_data_artifact_removed = recording_artifact_removed.get_traces()[:, ch]
        else:
            # No triggers on this channel: nothing to blank, keep the raw channel.
            channel_data_artifact_removed = multirecording[i].get_traces()[:, ch]

        rc.append(channel_data_artifact_removed)

    # Combine the channels into one (samples, channels) array
    combined_data = np.stack(rc, axis=1)
    combined_recording = NumpyRecording(traces_list=[combined_data], sampling_frequency=sampling_frequency)
    recording_art.append(combined_recording)

    # Notch filters designed for this recording's actual sampling rate.
    notch_filters = [signal.iirnotch(f, q, sampling_frequency) for f, q in zip(frequencies, qs)]
    traces_filtered = combined_recording.get_traces()
    for b, a in notch_filters:
        traces_filtered = signal.filtfilt(b, a, traces_filtered, axis=0)

    recording_filtered = NumpyRecording([traces_filtered], sampling_frequency)

    # Process LFP and MUA after filtering
    lfp, mua = downsample_lfp_mua(recording_filtered)

    # Apply filter bank to this recording's LFP
    delta, s100_200, s200_400, gamma = filterbank(lfp)
    recording_lfp.append(lfp)
    recording_mua.append(mua)
    recording_delta.append(delta)
    recording_100_200.append(s100_200)
    recording_200_400.append(s200_400)
    recording_gamma.append(gamma)

    # Map raw-rate artifact frames to LFP frames.
    # NOTE(review): assumes downsample_lfp_mua outputs 1 kHz (fs_stim = 1000 in the coherence cells) - confirm.
    stim_frames = [int(x * (1000 / sampling_frequency)) for x in np.concatenate(artifact_frames[i])]
    artifact_frames_resample.append(stim_frames)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the artifact-removed traces of the first BO5 recording.
fs = 30000  # Sampling frequency (Hz) of the artifact-removed traces
time_window = 0  # Start of the plotted window (s)
start_sample = int(time_window * fs)  # First sample included in the plot

# Extract the data from `time_window` seconds to the end
raw_data_post = recording_art[0].get_traces()[start_sample:, :]

# Time axis in seconds
time_axis = np.arange(start_sample, start_sample + raw_data_post.shape[0]) / fs

# One subplot per channel.
# NOTE(review): this rebinds num_channels, which the loading cell built as a per-recording list.
num_channels = raw_data_post.shape[1]
fig, axs = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True)

# Right edge: sample 9_000_000 (300 s at 30 kHz), clamped to the last available
# sample so shorter recordings do not raise IndexError.
x_end_idx = min(9000000, len(time_axis) - 1)

for i in range(num_channels):
    axs[i].plot(time_axis, raw_data_post[:, i])
    axs[i].set_title(f'Channel {i+1}')
    axs[i].set_ylabel('Amplitude (V)')
    axs[i].set_xlim([time_axis[0], time_axis[x_end_idx]])
    # axs[i].set_xlim([time_axis[6900000], time_axis[8000000]])
    #axs[i].set_ylim([-0.000025, 0.000025])
axs[-1].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()




In [45]:
# import numpy as np
# import matplotlib.pyplot as plt

# # Set the parameters
# fs = 1000  # Sampling frequency (30kHz)
# time_window = 0  # 5 minutes in seconds
# start_sample = int(time_window * fs)  # Starting sample at the 5-minute mark

# # Extract the data from 5 minutes to the end
# raw_data_post = recording_lfp[0].get_traces()[start_sample:, :]

# # Calculate the time axis for the plot
# time_axis = np.arange(start_sample, start_sample + raw_data_post.shape[0]) / fs

# # Plot the raw signal for each channel
# num_channels = raw_data_post.shape[1]
# fig, axs = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True)

# for i in range(num_channels):
#     axs[i].plot(time_axis, raw_data_post[:, i])
#     axs[i].set_title(f'Channel {i+1}')
#     axs[i].set_ylabel('Amplitude (V)')
#     axs[i].set_xlim([time_axis[0], time_axis[300000]])

# axs[-1].set_xlabel('Time (s)')
# plt.tight_layout()
# plt.show()

# # Calculate and print the average amplitude for each channel
# average_amplitudes = np.mean(np.abs(raw_data_post), axis=0)
# for i, avg_amp in enumerate(average_amplitudes):
#     print(f'Average amplitude for Channel {i+1}: {avg_amp:.6f} V')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# Pairwise channel coherence in a pre-stimulus window for every BO5 recording,
# plus the across-recording average in the last row of the figure.
fs_stim = 1000  # Sampling frequency of the LFP traces (Hz)
nperseg = 5*fs_stim  # Welch segment length: 5 s -> 0.2 Hz frequency resolution
# Segment overlap (20 %); cast to int when passed to scipy.
# NOTE(review): the first dataset's coherence cell uses 50 % - confirm which is intended for comparability.
overlap = 0.2*nperseg
titles = ['North-East', 'North-West', 'East-West']
ind1 = [0, 0, 1]  # First channel (column) of each pair
ind2 = [1, 2, 2]  # Second channel (column) of each pair
freq_lim = [1, 50]  # x-range (Hz) shared by every panel (sharex=True)

n_rec = len(recording_names)
# One row per recording plus a final row for the average; one column per channel pair.
fig, axs = plt.subplots(n_rec + 1, 3, sharex=True, sharey=True, dpi=400,
                        figsize=(20, 3.75 * (n_rec + 1)), squeeze=False)
coherence_sum = np.zeros((nperseg // 2 + 1, 3))  # Running sum of coherence spectra per pair

# First stimulation onset of each recording (used by the alternative window below).
stim_start_array_new = []
for stim in stim_start_array:
    stim_start1 = stim[0] if len(stim) > 0 else np.nan
    stim_start_array_new.append(stim_start1)

for i in range(n_rec):
    # Pre-stimulus window: 5 s to 255 s of the LFP.
    # NOTE(review): assumes the first stimulus occurs after 255 s - confirm against stim_start_array_new.
    segment = recording_lfp[i].get_traces()[5 * fs_stim : 255 * fs_stim, :]
    #segment = recording_lfp[i].get_traces()[int(stim_start_array_new[i]* fs_stim-time_period* fs_stim-time_buffer* fs_stim):int(stim_start_array_new[i]* fs_stim-time_buffer* fs_stim),:]

    for j in range(3):
        # The *_post names are historical; these are pre-stimulus samples.
        x_post = segment[:, ind1[j]]
        y_post = segment[:, ind2[j]]
        f, Cxy_post = signal.coherence(x_post, y_post, fs_stim, nperseg=nperseg, noverlap=int(overlap))

        coherence_sum[:, j] += Cxy_post  # Accumulate coherence values

        axs[i, j].semilogy(f, Cxy_post)
        axs[i, j].set_title(f'{titles[j]} (Pre-stimulus)')
        axs[i, j].set_xlabel('Frequency [Hz]')
        axs[i, j].set_ylabel('Coherence')
        axs[i, j].set_xlim(freq_lim)
        axs[i, j].set_ylim([10**(-5), 10**(0)])

# Average coherence across recordings (divisor follows the recording count).
coherence_avg3 = coherence_sum / n_rec

# Plot the average coherence in the last row
for j in range(3):
    axs[n_rec, j].semilogy(f, coherence_avg3[:, j])
    axs[n_rec, j].set_title(f'{titles[j]} (Average)')
    axs[n_rec, j].set_xlabel('Frequency [Hz]')
    axs[n_rec, j].set_ylabel('Coherence')
    axs[n_rec, j].set_xlim(freq_lim)
    axs[n_rec, j].set_ylim([10**(-5), 10**(0)])

plt.tight_layout()
plt.show()

In [47]:
# Stimulation amplitudes for organoid BO11 (5-week-old, stimulation day 1);
# every list below is ordered to match this one.
amps = ['5ua', '10ua', '20ua', '30ua', '40ua', '50ua', '60ua']

# Merged .nwb file name for each amplitude.
recording_names = [
    'BO11_O5W_Sti_Day1_5uA_m.nwb',
    'BO11_O5W_Sti_Day1_10uA_m.nwb',
    'BO11_O5W_Sti_Day1_20uA_m.nwb',
    'BO11_O5W_Sti_Day1_30uA_m.nwb',
    'BO11_O5W_Sti_Day1_40uA_m.nwb',
    'BO11_O5W_Sti_Day1_50uA_m.nwb',
    'BO11_O5W_Sti_Day1_60uA_m.nwb',
]

# Root directory of the Intan exports for this session.
intan_path = 'C:/Users/27707/Documents/jhu_master/LFP_analysis/2024-07-26_DG_3elec/Experimental Set_Final Round_Batch 27/5 Week Old Organoid/Stimulation Day1/'

# Per-amplitude sub-folders (relative to intan_path) holding the merged recordings.
intan_folders = [
    '5uA_233.3uS_Biphasic Pulse/BO11_O5W_Stimulation_Day1_5uA_233uS_240625_123952/merged/',
    '10uA_233.3uS_Biphasic Pulse/BO11_O5W_Stimulation_Day1_10uA_233uS_240625_125131/merged/',
    '20uA_233.3uS_Biphasic Pulse/BO11_O5W_Stimulation_Day1_20uA_233uS_240625_130048/merged/',
    '30uA_233.3uS_Biphasic Pulse/BO11_O5W_Stimulation_Day1_30uA_233uS_240625_130939/merged/',
    '40uA_233.3uS_Biphasic Pulse/BO11_O5W_Stimulation_Day1_40uA_233uS_240625_131833/merged/',
    '50uA_233.3uS_Biphasic Pulse/BO11_O5W_Stimulation_Day1_50uA_233uS_240625_132735/merged/',
    '60uA_233.3uS_Biphasic Pulse/BO11_O5W_Stimulation_Day1_60uA_233uS_240625_133646/merged/',
]

In [None]:
# For every recording, collect the index labels of rows in which at least one of
# the three stimulation channels is non-zero.
stim_names_BO7 = ['STIM_A-000','STIM_A-002','STIM_A-004']
stim_times_array = []
for i in range(len(recording_names)):
    BO7_res = pull_files(intan_folders[i], intan_path)
    stim_BO7 = all_intan_data(stim_names_BO7,'stim_data',BO7_res)
    any_channel_active = stim_BO7[stim_names_BO7].ne(0).any(axis=1)
    stim_times = stim_BO7.loc[any_channel_active].index
    stim_times_array.append(stim_times.to_numpy())

In [None]:
# Split each recording's stimulation indices into bursts: a new burst starts
# wherever two consecutive stimulation indices are more than 1 apart (in the
# units of the index collected above — presumably samples; confirm).
stim_start_array = []
stim_end_array = []
for stim in stim_times_array:
    stim = np.asarray(stim)
    if stim.size == 0:
        # No stimulation detected: record no bursts, but keep the per-recording
        # lists aligned with recording_names.
        stim_start = stim[:0]
        stim_end = stim[:0]
    else:
        gap_idx = np.flatnonzero(np.diff(stim) > 1)
        # The first burst always starts at stim[0]; later bursts start right after a gap.
        stim_start = np.concatenate((stim[:1], stim[gap_idx + 1]))
        # Each burst ends right before a gap; the last burst ends at stim[-1].
        stim_end = np.concatenate((stim[gap_idx], stim[-1:]))
    print(stim_start)
    print(stim_end)
    stim_start_array.append(stim_start)
    stim_end_array.append(stim_end)

In [None]:
# Load every recording for this organoid and build a 300-6000 Hz band-passed
# copy of each.
num_samples = []   # sample count per recording
num_channels = []  # channel count per recording
multirecording = []
multirecording_f = []

# Paths (relative to intan_path) of the merged .nwb file for each amplitude.
# Built once here rather than being rebuilt on every loop iteration.
recording_names = [folder + 'BO11_O5W_Sti_Day1_' + amp.split('u')[0] + 'uA_m.nwb'
                   for folder, amp in zip(intan_folders, amps)]

for i in range(len(recording_names)):
    rec_names = intan_path + recording_names[i]
    day_recording = load_data(rec_names)
    # Query the size directly instead of materialising the whole trace array.
    num_pre_samples = day_recording.get_num_samples()
    num_pre_channels = day_recording.get_num_channels()
    num_samples.append(num_pre_samples)
    num_channels.append(num_pre_channels)
    print('Presample:' + str(num_pre_samples))

    # Alias of the loaded recording; no concatenation is performed here.
    multirecording_1 = day_recording
    multirecording.append(day_recording)
    multi_shank = gen_probe()  # result is not used in this cell
    multirecording_filter = spre.bandpass_filter(day_recording, freq_min=300, freq_max=6000)
    multirecording_f.append(multirecording_filter)

In [None]:

# Step 1: per-recording, per-channel artifact threshold = rms_multiplier x RMS of
# the unfiltered signal over a baseline window [start_time, end_time) seconds.
fs = 30000  # Replace with your actual sampling rate
start_time = 5  # baseline window start (s)
end_time = 255  # baseline window end (s)
num_samples = fs * (end_time - start_time)  # samples in the 250 s baseline window (overwrites the per-recording list from loading)
rms_multiplier = 5.2

# Step 2 settings: blank this many ms before/after every supra-threshold sample.
ms_before = [100]
ms_after = [100]

thresholds = []       # thresholds[i][ch]
artifact_frames = []  # artifact_frames[i][ch] -> list of supra-threshold sample indices

for i in range(len(recording_names)):
    # Read the traces once per recording instead of once per channel and pass.
    traces = multirecording[i].get_traces()

    thresholds_per_recording = []
    for ch in range(3):  # Assuming there are 3 channels
        baseline = traces[start_time * fs:end_time * fs, ch].astype(np.float32)
        rms_value = np.sqrt(np.mean(baseline**2))
        thresholds_per_recording.append(rms_value * rms_multiplier)
    thresholds.append(thresholds_per_recording)

    artifact_frames_per_recording = []
    for ch in range(3):  # Assuming there are 3 channels
        channel = traces[:, ch].astype(np.float32)

        threshold = thresholds[i][ch]
        print(f"Threshold for recording {i+1}, channel {ch+1}: {threshold}")

        # Every sample whose absolute value exceeds the threshold is a trigger.
        stimulation_trigger_frames = np.where(np.abs(channel) > threshold)

        if stimulation_trigger_frames[0].size > 0:  # Check if there are any detected triggers
            # Only the last channel's result survives in multirecording_art; the
            # next cell recomputes artifact removal per channel from artifact_frames.
            multirecording_art = spre.remove_artifacts(
                multirecording[i],
                list_triggers=stimulation_trigger_frames[0].tolist(),
                ms_before=ms_before[0],
                ms_after=ms_after[0]
            )
        else:
            print(f"No artifacts detected for recording {i+1}, channel {ch+1}")

        artifact_frames_per_recording.append(stimulation_trigger_frames[0].tolist())

    artifact_frames.append(artifact_frames_per_recording)

In [54]:
from scipy import signal
from spikeinterface.extractors import NumpyRecording
import numpy as np

# For each recording:
# - blank artifacts channel by channel;
# - notch out line noise;
# - split into LFP / MUA and LFP sub-bands.
recording_art = []
recording_lfp = []
recording_mua = []
recording_delta = []
recording_100_200 = []
recording_200_400 = []
recording_gamma = []
artifact_frames_resample = []

# Notch centre frequencies (Hz) and quality factors. For iirnotch, Q = f0 / bandwidth,
# so each notch below is about 1 Hz wide.
#frequencies = [9.0, 18.2, 25.0, 27.2, 29.2, 36.4, 37.6, 41.4, 43.8, 45.4, 46.4, 60.0, 180.0, 300.0]
frequencies = [60.0, 180.0, 300.0]
qs = [60, 180, 300]
# Reference design at 30 kHz (kept for later cells). Each recording below is
# filtered with notches designed for its own sampling rate.
notch_filters = [signal.iirnotch(f0, q, fs=30000) for f0, q in zip(frequencies, qs)]

for i in range(len(recording_names)):
    sampling_frequency = multirecording[i].get_sampling_frequency()

    rc = []
    for ch in range(3):
        # Remove artifacts using the triggers detected on this channel only
        recording_artifact_removed = spre.remove_artifacts(
            multirecording[i],
            list_triggers=artifact_frames[i][ch],
            ms_before=ms_before[0], ms_after=ms_after[0]  # Adjust based on your needs
        )
        channel_data_artifact_removed = recording_artifact_removed.get_traces()[:, ch]
        rc.append(channel_data_artifact_removed)

    # Combine the per-channel cleaned traces into one (samples, channels) array
    combined_data = np.stack(rc, axis=1)
    combined_recording = NumpyRecording(traces_list=[combined_data], sampling_frequency=sampling_frequency)
    recording_art.append(combined_recording)

    # Apply notch filters sequentially, designed for this recording's sampling rate
    recording_notches = [signal.iirnotch(f0, q, fs=sampling_frequency) for f0, q in zip(frequencies, qs)]
    traces_filtered = recording_art[i].get_traces()
    for b, a in recording_notches:
        traces_filtered = signal.filtfilt(b, a, traces_filtered, axis=0)

    recording_filtered = NumpyRecording([traces_filtered], sampling_frequency)

    # Process LFP and MUA after filtering
    lfp, mua = downsample_lfp_mua(recording_filtered)

    # Apply filter bank to LFP
    delta, s100_200, s200_400, gamma = filterbank(lfp)
    recording_lfp.append(lfp)
    recording_mua.append(mua)
    recording_delta.append(delta)
    recording_100_200.append(s100_200)
    recording_200_400.append(s200_400)
    recording_gamma.append(gamma)

    # Map artifact frames from the raw rate to the LFP rate (truncating to int)
    resample_ratio = lfp.get_sampling_frequency() / sampling_frequency
    stim_frames = [int(x * resample_ratio) for x in np.concatenate(artifact_frames[i])]
    artifact_frames_resample.append(stim_frames)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the artifact-blanked traces of recording index 6 (the 60 uA recording).
fs = 30000  # Sampling frequency (30kHz)
time_window = 0  # plot start offset in seconds (0 = from the beginning)
start_sample = int(time_window * fs)  # first sample to plot

# Extract the data from start_sample to the end
raw_data_post = recording_art[6].get_traces()[start_sample:, :]

# Calculate the time axis for the plot (seconds)
time_axis = np.arange(start_sample, start_sample + raw_data_post.shape[0]) / fs

# Show at most 300 s (9,000,000 samples at 30 kHz), clamped to the available data
max_plot_samples = 9000000
last_plot_index = min(max_plot_samples, len(time_axis) - 1)

# Plot the raw signal for each channel
num_channels = raw_data_post.shape[1]
# squeeze=False keeps `axs` indexable even when there is only one channel.
fig, axs = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True, squeeze=False)
axs = axs[:, 0]

for i in range(num_channels):
    axs[i].plot(time_axis, raw_data_post[:, i])
    axs[i].set_title(f'Channel {i+1}')
    axs[i].set_ylabel('Amplitude (V)')
    axs[i].set_xlim([time_axis[0], time_axis[last_plot_index]])
    # axs[i].set_xlim([time_axis[6900000], time_axis[8000000]])
    #axs[i].set_ylim([-0.000025, 0.000025])
axs[-1].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot the LFP traces of recording index 6 (the 60 uA recording).
fs = 1000  # LFP sampling frequency (1 kHz)
time_window = 0  # plot start offset in seconds (0 = from the beginning)
start_sample = int(time_window * fs)  # first sample to plot

# Extract the data from start_sample to the end
raw_data_post = recording_lfp[6].get_traces()[start_sample:, :]

# Calculate the time axis for the plot (seconds)
time_axis = np.arange(start_sample, start_sample + raw_data_post.shape[0]) / fs

# Show at most 300 s (300,000 samples at 1 kHz), clamped to the available data
max_plot_samples = 300000
last_plot_index = min(max_plot_samples, len(time_axis) - 1)

# Plot the LFP for each channel
num_channels = raw_data_post.shape[1]
# squeeze=False keeps `axs` indexable even when there is only one channel.
fig, axs = plt.subplots(num_channels, 1, figsize=(15, 10), sharex=True, squeeze=False)
axs = axs[:, 0]

for i in range(num_channels):
    axs[i].plot(time_axis, raw_data_post[:, i])
    axs[i].set_title(f'Channel {i+1}')
    axs[i].set_ylabel('Amplitude (V)')
    axs[i].set_xlim([time_axis[0], time_axis[last_plot_index]])

axs[-1].set_xlabel('Time (s)')
plt.tight_layout()
plt.show()

# Mean absolute amplitude of each channel over the extracted span
average_amplitudes = np.mean(np.abs(raw_data_post), axis=0)
for i, avg_amp in enumerate(average_amplitudes):
    print(f'Average amplitude for Channel {i+1}: {avg_amp:.6f} V')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# Pre-stimulus coherence between each pair of the three recording channels:
# one row per recording, plus a final row with the average across recordings.
fs_stim = 1000  # Sampling frequency of the LFP traces (Hz)
nperseg = 5*fs_stim  # 5 s Welch segments -> 0.2 Hz frequency resolution
# Defined here so this cell no longer depends on `overlap` leaking from an earlier cell.
overlap = int(0.2*nperseg)  # 20 % overlap between segments, as an integer sample count
pre_window_s = (5, 255)  # analysis window in seconds from the start of each recording

n_recordings = len(recording_names)
# One row per recording plus one for the average; one column per channel pair.
fig, axs = plt.subplots(n_recordings + 1, 3, sharex=True, sharey=True, dpi=400, figsize=(20, 30))
titles = ['North-East', 'North-West', 'East-West']
ind1 = [0, 0, 1]  # first channel of each pair
ind2 = [1, 2, 2]  # second channel of each pair

# Running sum of the coherence spectra (nperseg // 2 + 1 one-sided frequency bins).
coherence_sum = np.zeros((nperseg // 2 + 1, 3))

# First stimulation onset per recording; only used by the commented-out
# stimulus-aligned window below.
stim_start_array_new = [stim[0] for stim in stim_start_array]

for i in range(n_recordings):
    segment = recording_lfp[i].get_traces()[pre_window_s[0] * fs_stim : pre_window_s[1] * fs_stim, :]
    #segment = recording_lfp[i].get_traces()[int(stim_start_array_new[i]* fs_stim-time_period* fs_stim-time_buffer* fs_stim):int(stim_start_array_new[i]* fs_stim-time_buffer* fs_stim),:]

    for j in range(3):
        x_pre = segment[:, ind1[j]]
        y_pre = segment[:, ind2[j]]
        f, Cxy_pre = signal.coherence(x_pre, y_pre, fs_stim, nperseg=nperseg, noverlap=overlap)

        coherence_sum[:, j] += Cxy_pre  # Accumulate coherence values

        axs[i, j].semilogy(f, Cxy_pre)
        axs[i, j].set_title(f'{titles[j]} (Pre-stimulus)')
        axs[i, j].set_xlabel('Frequency [Hz]')
        axs[i, j].set_ylabel('Coherence')
        axs[i, j].set_xlim([1, 50])
        axs[i, j].set_ylim([10**(-5), 10**(0)])

# Average coherence across all recordings of this dataset
coherence_avg4 = coherence_sum / n_recordings

# Plot the average coherence in the last row
for j in range(3):
    axs[n_recordings, j].semilogy(f, coherence_avg4[:, j])
    axs[n_recordings, j].set_title(f'{titles[j]} (Average)')
    axs[n_recordings, j].set_xlabel('Frequency [Hz]')
    axs[n_recordings, j].set_ylabel('Coherence')
    axs[n_recordings, j].set_xlim([1, 50])
    axs[n_recordings, j].set_ylim([10**(-5), 10**(0)])

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Grand average of the per-dataset average coherence spectra (four datasets).
per_dataset_coherence = (coherence_avg1, coherence_avg2, coherence_avg3, coherence_avg4)
average_coherence = sum(per_dataset_coherence) / len(per_dataset_coherence)

# One panel per channel pair, sharing the frequency axis `f` from the coherence cells.
fig, axs = plt.subplots(1, 3, figsize=(30, 5), dpi=400)
titles = ['North-East', 'North-West', 'East-West']

for ax, title, pair_coherence in zip(axs, titles, average_coherence.T):
    ax.semilogy(f, pair_coherence)
    ax.set_title(title, fontsize=38)
    ax.set_xlabel('Frequency [Hz]', fontsize=30)
    ax.set_ylabel('Coherence', fontsize=30)
    ax.tick_params(axis='both', which='major', labelsize=28)
    ax.set_xlim([1, 49])
    ax.set_ylim([10**(-3), 10**(0)])

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Grand-average coherence across the four datasets, recomputed so this cell stands alone.
average_coherence = (coherence_avg1 + coherence_avg2 + coherence_avg3 + coherence_avg4) / 4

# Figure styling
titles = ['North-East', 'North-West', 'East-West']
font_title = 27
font_label = 25
font_ticks = 23
plt.rcParams['font.family'] = 'Arial'
xticks_values = [10, 20, 30, 40, 50]

# One standalone, saved figure per channel pair
for j, title in enumerate(titles):
    fig, ax = plt.subplots(figsize=(5, 5), dpi=400)
    ax.semilogy(f, average_coherence[:, j], linewidth=1.5)

    ax.set_title(title, fontsize=font_title, fontweight='bold')
    ax.set_xlabel('Frequency (Hz)', fontsize=font_label, fontweight='bold')
    ax.set_ylabel('Coherence', fontsize=font_label, fontweight='bold')

    ax.set_xlim([1, 50])
    ax.set_ylim([10**(-3), 10**(0)])
    ax.set_xticks(xticks_values)
    ax.tick_params(axis='both', which='major', labelsize=font_ticks, width=2, length=5)
    #ax.tick_params(axis='both', which='minor', width=1, length=4)

    # Save to e.g. coherence_North-East.png, then display
    fig.tight_layout()
    fig.savefig(f'coherence_{title.replace(" ", "_")}.png', format='png', bbox_inches='tight')
    plt.show()
