In [None]:
import numpy as np
import matplotlib.pyplot as plt
from brpylib import NsxFile
import os

# Recording session directory and the per-channel "matter" table used by this notebook.
path_to_data = r'Z:\neural\archive\keck\P066\20240731-PH2\20240805'
path_to_matter = r'Z:\neural\DICOM\P066\processed\P066_matter.csv'

# Stop right away if either location does not exist.
assert os.path.exists(path_to_data)
assert os.path.exists(path_to_matter)

# The neural data is within the "NSP1" folder; the associated task data is in the "Task" folder.
nsp1_folder = os.path.join(path_to_data, 'NSP1')
task_folder = os.path.join(path_to_data, 'Task')

nsp1_files = os.listdir(nsp1_folder)
# Only the .txt files in the task folder hold the relevant data.
task_files = [entry for entry in os.listdir(task_folder) if entry.endswith('.txt')]

In [None]:
# Read the matter file (CSV) into a DataFrame.
# Later cells read its 'ChannelNumber', 'ElectrodeName' and 'MatterType' columns
# to label plots and output files.
import pandas as pd
matter = pd.read_csv(path_to_matter)

In [None]:
# File names start with a "YYYYMMDD-HHMMSS" timestamp followed by '_'.
# Pick out the txt files with "cerestim" in the name (this folder is all variable_isi).
variable_isi_files = [name for name in task_files if 'cerestim' in name]

# The timestamp is everything before the first underscore: the rest of the
# name is snake case and the timestamp itself contains no underscore.
experiment_ids = [name.partition('_')[0] for name in variable_isi_files]

# Keep only the task files whose timestamp belongs to one of those experiments.
task_files = [name for name in task_files if name.partition('_')[0] in experiment_ids]

In [None]:
# Map each experiment id (timestamp) to the NSP1 files whose name contains it.
experiments = {
    experiment_id: [name for name in nsp1_files if experiment_id in name]
    for experiment_id in experiment_ids
}

In [None]:
# Display the task files that remain after filtering by experiment id.
task_files

In [None]:
# For every task file with "diary" in its name, print its last non-empty line.
diary_files = [name for name in task_files if 'diary' in name]
for diary_name in diary_files:
    with open(os.path.join(task_folder, diary_name), 'r') as diary_handle:
        diary_lines = diary_handle.readlines()
    # Walk backwards and stop at the first line that is not just whitespace.
    last_entry = next((entry for entry in reversed(diary_lines) if entry.strip()), None)
    if last_entry is not None:
        print(last_entry)


In [None]:
# Display the NSP1 files found for the first experiment id.
experiments[experiment_ids[0]]
# NS3 files are 2 kHz, NS4 files are 10 kHz, NS5 files are 30 kHz filtered, and NS6 files are 30 kHz unfiltered.
# We want the NS3 files for now.
# The Blackrock library (brpylib's NsxFile, imported at the top) is used to read the data;
# the file is opened in the next cell.

In [None]:
# Load the NS3 file of the first experiment.
# The file is selected by its extension rather than by a fixed list position:
# the list comes from os.listdir, whose ordering is arbitrary, so a hard-coded
# index can silently open a different NSx file.
first_experiment_files = experiments[experiment_ids[0]]
ns3_candidates = sorted(name for name in first_experiment_files if name.lower().endswith('.ns3'))
if not ns3_candidates:
    raise FileNotFoundError('No .ns3 file found for experiment ' + str(experiment_ids[0]) + ' in ' + nsp1_folder)
nsx = NsxFile(os.path.join(nsp1_folder, ns3_candidates[0]))

In [None]:
# Read the samples for the requested electrodes from the open NSx file.
# NOTE(review): stim_electrodes is not assigned in any cell shown above; confirm where it is defined.
nsx_data = nsx.getdata(elec_ids=stim_electrodes)

# Electrode name listed for channel 33 in the matter table; used as the
# stimulation-site label in plot titles and output file names below.
stimChannelMask = matter['ChannelNumber'] == 33
stimSite = matter.loc[stimChannelMask, 'ElectrodeName'].iloc[0]

In [None]:
# Construct the time vector
%matplotlib qt
points = nsx_data['data_headers'][0]['NumDataPoints']
fs = nsx_data['samp_per_s']
totalTimeSec = nsx_data['data_headers'][0]['data_time_s']
#print(points, fs, totalTimeSec)


time = np.linspace(0, totalTimeSec, points)
# Create a mapping from the int data to actual uV. The digital range is -32768 to 32767, and the analog range is -8191 to 8191 uV
# The factor is 8191/32768
conversionFactor = 8191/32768
elec_33 = (nsx_data['data'][0][nsx_data['elec_ids'][1]] * (conversionFactor))
elec_33 = (elec_33)
#elec_33 = np.abs(elec_33)
plt.close('all')
# Find the indices of the stimulation pulses
# The peaks are going to be above 7000 uV.
# We can use the find_peaks function from scipy to find the peaks
from scipy.signal import find_peaks
plt.plot(time, elec_33)
peakLocations = find_peaks(elec_33, height=4000, distance=100)
indices, values = peakLocations
print(len(indices))

# Check to see that each peak is at least 500 ms apart.
for i in range(len(indices)-1):
    if time[indices[i+1]] - time[indices[i]] < 0.5:
        print('Warning: Two stimulations are less than 500 ms apart')
        print('Time between stimulations: ', time[indices[i+1]] - time[indices[i]])
        print('Stimulation indices: ', indices[i], indices[i+1])

blockStarts = []
blockStarts.append(indices[0])
blockEnds = []
blockEnds.append(indices[-1])
for i in range(len(indices)-1):
    currentTime = time[indices[i]]
    nextTime = time[indices[i+1]]
    # 20 groups of 60 stimulations, we want to find the splits.
    if nextTime - currentTime > 5:
        blockStarts.append(indices[i+1])
        blockEnds.append(indices[i])

blockStarts.sort()
blockEnds.sort()

In [None]:
# Pair every block start with its block end (both are sample indices of pulses).
maxBlockDurationSec = 80
timeRanges = list(zip(blockStarts, blockEnds))

# Warn about blocks that last longer than expected. The message reports the
# threshold that is actually tested.
for t in timeRanges:
    blockDuration = time[t[1]] - time[t[0]]
    if blockDuration > maxBlockDurationSec:
        print('Warning: Time range is greater than ' + str(maxBlockDurationSec) + ' seconds')
        print('Time range: ', blockDuration)
        print('Time range indices: ', t)

# Build a dictionary keyed by block number whose value is the list of pulse
# indices that fall inside that block, first and last pulse included.
# The inclusive mask also handles a block made of a single pulse, which would
# otherwise be listed twice (once as start and once as end).
stimIndices = {}
for i, (blockStart, blockEnd) in enumerate(timeRanges):
    insideBlock = (indices >= blockStart) & (indices <= blockEnd)
    stimIndices[i] = list(indices[insideBlock])

In [None]:
# Display the sample index of the last stimulation in the first block.
stimIndices[0][-1]

In [None]:
# Time averaged version: for every channel, average the epochs around the
# stimulations of each block and overlay one trace per block.
stimsPerBlock = 60
preSamples = 10     # samples kept before each stimulation
postSamples = 200   # samples kept after each stimulation

# Check that each stimIndices list has the expected number of stimulations.
for k, v in stimIndices.items():
    if len(v) != stimsPerBlock:
        print('Warning: StimIndices list does not have ' + str(stimsPerBlock) + ' stimulations')
        print('Length of stimIndices list: ', len(v))
        print('Time range: ', timeRanges[k])
        print('Indices: ', v)

all_data = nsx_data['data'][0]
timeWindow = 0.200 # 200 ms; not used by this cell, the epoch is defined in samples above
# Epoch time axis in ms relative to the stimulation sample, so that the
# 'Time (ms)' label matches the plotted x values at any sampling rate.
epochTimeMs = np.arange(-preSamples, postSamples) / fs * 1000
rawBlocks = {}
os.makedirs('./figs_time', exist_ok=True)
for channelIdx, electrode in enumerate(all_data):
    plt.figure()
    for k, v in stimIndices.items():
        # NOTE(review): rawBlocks[k] is rebuilt for every channel, so after the
        # loop it only holds the epochs of the last channel; confirm that is intended.
        rawBlocks[k] = [electrode[s - preSamples:s + postSamples] for s in v[:stimsPerBlock]]
        blockAverage = np.mean(rawBlocks[k], axis=0)
        plt.plot(epochTimeMs, blockAverage)
    plt.xlabel('Time (ms)')
    # NOTE(review): the traces are not multiplied by conversionFactor here, so
    # the values may be raw counts rather than uV; confirm.
    plt.ylabel('uV')

    # Resolve the labels from the matter table. The defaults keep title and file
    # name unique per channel when the lookup fails, instead of reusing the
    # previous channel's number.
    # Assumes the first matter column holds the channel number of the channel
    # at that row position (TODO confirm).
    channelNumber = 'idx' + str(channelIdx)
    channelName = 'Unknown'
    matter_type = 'Unknown'
    try:
        channelNumber = matter.iloc[channelIdx, 0]
        channelRows = matter[matter['ChannelNumber'] == channelNumber]
        channelName = channelRows['ElectrodeName'].values[0]
        matter_type = channelRows['MatterType'].values[0]
    except (IndexError, KeyError):
        channelName = 'Unknown'
        matter_type = 'Unknown'
    plt.title('Channel ' + str(channelNumber) + ' ' + channelName + '_stimSite = ' + stimSite + '_' + matter_type + ' time averaged')
    plt.savefig('./figs_time/Channel_' + str(channelNumber) + '_' + channelName + '_timeAvgOverlaid_' + 'stim_site=' + stimSite + '_' + matter_type + '.png')
    plt.close('all')

In [None]:
# Block averaged version: for every channel, average the n-th epoch of every
# block across blocks and overlay one trace per position within the block.
# This cell deliberately does not `import time`: that would rebind `time`,
# which holds the sample-time vector used by earlier cells. `os` is already
# imported at the top of the notebook.
stimsPerBlock = 60
preSamples = 10     # samples kept before each stimulation
postSamples = 200   # samples kept after each stimulation

# Check that each stimIndices list has the expected number of stimulations.
for k, v in stimIndices.items():
    if len(v) != stimsPerBlock:
        print('Warning: StimIndices list does not have ' + str(stimsPerBlock) + ' stimulations')
        print('Length of stimIndices list: ', len(v))
        print('Time range: ', timeRanges[k])
        print('Indices: ', v)

# Only positions present in every block can be averaged across blocks.
nPositions = min([stimsPerBlock] + [len(v) for v in stimIndices.values()])
# Epoch time axis in ms relative to the stimulation sample, so that the
# 'Time (ms)' label matches the plotted x values at any sampling rate.
epochTimeMs = np.arange(-preSamples, postSamples) / fs * 1000
os.makedirs('./figs_block', exist_ok=True)
for channelIdx, electrode in enumerate(all_data):
    plt.figure()
    for stimPosition in range(nPositions):
        blockAverage = np.zeros(preSamples + postSamples)
        for k, v in stimIndices.items():
            blockAverage += electrode[v[stimPosition] - preSamples:v[stimPosition] + postSamples]
        blockAverage /= len(stimIndices)
        plt.plot(epochTimeMs, blockAverage)
    plt.xlabel('Time (ms)')
    # NOTE(review): the traces are not multiplied by conversionFactor here, so
    # the values may be raw counts rather than uV; confirm.
    plt.ylabel('uV')

    # Resolve the labels from the matter table. The defaults keep title and file
    # name unique per channel when the lookup fails, instead of reusing the
    # previous channel's number.
    # Assumes the first matter column holds the channel number of the channel
    # at that row position (TODO confirm).
    channelNumber = 'idx' + str(channelIdx)
    channelName = 'Unknown'
    matter_type = 'Unknown'
    try:
        channelNumber = matter.iloc[channelIdx, 0]
        channelRows = matter[matter['ChannelNumber'] == channelNumber]
        channelName = channelRows['ElectrodeName'].values[0]
        matter_type = channelRows['MatterType'].values[0]
    except (IndexError, KeyError):
        channelName = 'Unknown'
        matter_type = 'Unknown'
    plt.title('Channel ' + str(channelNumber) + ' ' + channelName + ' ' + 'stim site ' + stimSite + ' ' + matter_type + ' all blocks overlayed')
    plt.savefig('./figs_block/Channel_' + str(channelNumber) +  '_' + channelName + 'stim_site_' + stimSite + '_all_blocks_overlayed_' + matter_type + '.png')
    plt.close('all')

In [None]:
# We now want to analyze the data in the 10 seconds after the last stimulation of each block.
# Each window starts 50 ms after the block's last stimulation.
settleSamples = int(0.050 * fs)
lastStims = [blockStims[-1] + settleSamples for blockStims in stimIndices.values()]

# Each window ends `timeAfterStim` seconds (converted to samples) after its start.
timeAfterStim = 10
windowSamples = int(timeAfterStim * fs)
timeAfterStimIndices = [windowStart + windowSamples for windowStart in lastStims]

# (start, end) sample indices of every window, and the length of each window in samples.
analysis_intervals = list(zip(lastStims, timeAfterStimIndices))
step_length = [windowEnd - windowStart for windowStart, windowEnd in analysis_intervals]

In [None]:
# For every channel, average the post-block windows (analysis_intervals)
# across blocks and save one figure per channel.
x_time = np.linspace(0, timeAfterStim, step_length[0])
os.makedirs('./figs_last_stims', exist_ok=True)
for channelIdx, electrode in enumerate(all_data):
    plt.figure()
    average = np.zeros(step_length[0])
    for windowStart, windowEnd in analysis_intervals:
        average += electrode[windowStart:windowEnd]
    average /= len(analysis_intervals)

    # Resolve the labels from the matter table. The defaults keep title and file
    # name unique per channel when the lookup fails, instead of reusing the
    # previous channel's number.
    # Assumes the first matter column holds the channel number of the channel
    # at that row position (TODO confirm).
    channelNumber = 'idx' + str(channelIdx)
    channelName = 'Unknown'
    matter_type = 'Unknown'
    try:
        channelNumber = matter.iloc[channelIdx, 0]
        channelRows = matter[matter['ChannelNumber'] == channelNumber]
        channelName = channelRows['ElectrodeName'].values[0]
        matter_type = channelRows['MatterType'].values[0]
    except (IndexError, KeyError):
        channelName = 'Unknown'
        matter_type = 'Unknown'

    plt.plot(x_time, average)
    # NOTE(review): the trace is not multiplied by conversionFactor here, so the
    # values may be raw counts rather than uV; confirm.
    plt.ylabel('uV')
    plt.xlabel('Time (s)')
    # The figure shows a single block-averaged trace, so the title says so
    # (it previously carried the "all blocks overlayed" title of another plot).
    plt.title('Channel ' + str(channelNumber) + ' ' + channelName + ' ' + 'stim site ' + stimSite + ' ' + matter_type + ' block averaged after last stim')
    plt.savefig('./figs_last_stims/Channel_' + str(channelNumber) +  '_' + channelName + 'stim_site_' + stimSite + '_last_stims_' + matter_type + '.png')
    plt.close('all')


In [None]:
# Split every analysis interval into consecutive (start, end) tuples that are
# each one second (int(fs) samples) long; a trailing partial second is dropped.
# With 10 s windows that is 10 tuples per block (200 in total if there are 20 blocks).
samplesPerSecond = int(fs)
one_second_intervals = {}
for idx, (windowStart, windowEnd) in enumerate(analysis_intervals):
    wholeSeconds = (windowEnd - windowStart) // samplesPerSecond
    one_second_intervals[idx] = [
        (windowStart + n * samplesPerSecond, windowStart + (n + 1) * samplesPerSecond)
        for n in range(wholeSeconds)
    ]

In [343]:
# Display the one-second (start, end) sample ranges of the first block.
one_second_intervals[0]

[(np.int64(143986), np.int64(145986)),
 (np.int64(145986), np.int64(147986)),
 (np.int64(147986), np.int64(149986)),
 (np.int64(149986), np.int64(151986)),
 (np.int64(151986), np.int64(153986)),
 (np.int64(153986), np.int64(155986)),
 (np.int64(155986), np.int64(157986)),
 (np.int64(157986), np.int64(159986)),
 (np.int64(159986), np.int64(161986)),
 (np.int64(161986), np.int64(163986))]

In [353]:
# For every channel and every one-second slot after a block, overlay that slot
# from all blocks (faint traces) together with their average, and save the figure.
fs = int(fs)
# Number of one-second slots per block.
n_intervals = len(one_second_intervals[0])
x_time = np.linspace(0, 1, fs)
plt.close('all')
os.makedirs('./figs_ghost_stims', exist_ok=True)
for channelIdx, electrode in enumerate(all_data):
    # Resolve the labels once per channel. The defaults keep title and file name
    # unique per channel when the lookup fails, instead of reusing the previous
    # channel's number.
    # Assumes the first matter column holds the channel number of the channel
    # at that row position (TODO confirm).
    channelNumber = 'idx' + str(channelIdx)
    channelName = 'Unknown'
    matter_type = 'Unknown'
    try:
        channelNumber = matter.iloc[channelIdx, 0]
        channelRows = matter[matter['ChannelNumber'] == channelNumber]
        channelName = channelRows['ElectrodeName'].values[0]
        matter_type = channelRows['MatterType'].values[0]
    except (IndexError, KeyError):
        channelName = 'Unknown'
        matter_type = 'Unknown'

    # `idx` is the one-second slot; the inner loop runs over the blocks.
    for idx in range(n_intervals):
        plt.figure()
        average = np.zeros(fs)
        for blockIntervals in one_second_intervals.values():
            start, end = blockIntervals[idx]
            plt.plot(x_time, electrode[start:end], alpha=0.1)
            average += electrode[start:end]
        average /= len(one_second_intervals)
        plt.plot(x_time, average)
        # NOTE(review): the traces are not multiplied by conversionFactor here,
        # so the values may be raw counts rather than uV; confirm.
        plt.ylabel('uV')
        plt.xlabel('Time (s)')
        plt.title('Ghost Stim Channel ' + str(channelNumber) + ' ' + channelName + ' ' + 'stim site ' + stimSite +  '_ghost_EP_' + str(idx) + ' ' + matter_type + ' index averaged')
        plt.savefig('./figs_ghost_stims/Channel_' + str(channelNumber) +  '_' + channelName + 'stim_site_' + stimSite + 'ghost_EP_' + str(idx) + '_' + matter_type + '.png')
        # Close each figure once saved so open figures do not pile up within a channel.
        plt.close('all')

In [None]:
# Plot the raw 10 seconds after the last stimulation, one figure per channel.
# The window is taken from analysis_intervals so the cell runs on a fresh
# kernel; the three names below were not defined in any earlier cell.
# NOTE(review): assumes "the last stimulation" means the final block's window
# (analysis_intervals[-1]); confirm which block was intended.
rangeAfterStim = analysis_intervals[-1]
lastStimIndex, timeAfterStimIndex = rangeAfterStim
x_Time = np.linspace(0, timeAfterStim, timeAfterStimIndex - lastStimIndex)
os.makedirs('./figs_last_stim_raw', exist_ok=True)
for channelIdx, electrode in enumerate(all_data):
    plt.figure()
    plt.plot(x_Time, electrode[rangeAfterStim[0]:rangeAfterStim[1]])
    plt.xlabel('Time (s)')
    # NOTE(review): the trace is not multiplied by conversionFactor here, so the
    # values may be raw counts rather than uV; confirm.
    plt.ylabel('uV')

    # Resolve the labels from the matter table. The defaults keep title and file
    # name unique per channel when the lookup fails, instead of reusing the
    # previous channel's number.
    # Assumes the first matter column holds the channel number of the channel
    # at that row position (TODO confirm).
    channelNumber = 'idx' + str(channelIdx)
    channelName = 'Unknown'
    matter_type = 'Unknown'
    try:
        channelNumber = matter.iloc[channelIdx, 0]
        channelRows = matter[matter['ChannelNumber'] == channelNumber]
        channelName = channelRows['ElectrodeName'].values[0]
        matter_type = channelRows['MatterType'].values[0]
    except (IndexError, KeyError):
        channelName = 'Unknown'
        matter_type = 'Unknown'
    plt.title('Channel ' + str(channelNumber) + ' ' + channelName + ' ' + 'stim site ' + stimSite + ' ' + matter_type + ' raw 10 seconds after last stim')
    plt.savefig('./figs_last_stim_raw/Channel_' + str(channelNumber) +  '_' + channelName + 'stim_site_' + stimSite + '_10_seconds_after_last_stim_raw_' + matter_type + '.png')
    plt.close('all')

In [None]:
from dtw import dtw

# Compare the raw epochs of each block with one another using dynamic time
# warping (dtw package) and save one "twoway" alignment figure per ordered pair.
# This produces n*(n-1) figures per block, so it can take a long time.

# rawBlocks is rebuilt for every channel by the cell that fills it, so it holds
# the epochs of the last channel in all_data; label the figures with that
# channel (the block key is not a channel index).
dtwChannelIdx = len(all_data) - 1
# Assumes the first matter column holds the channel number of the channel at
# that row position (TODO confirm).
channelNumber = 'idx' + str(dtwChannelIdx)
channelName = 'Unknown'
matter_type = 'Unknown'
try:
    channelNumber = matter.iloc[dtwChannelIdx, 0]
    channelRows = matter[matter['ChannelNumber'] == channelNumber]
    channelName = channelRows['ElectrodeName'].values[0]
    matter_type = channelRows['MatterType'].values[0]
except (IndexError, KeyError):
    channelName = 'Unknown'
    matter_type = 'Unknown'

os.makedirs('./figs_dtw', exist_ok=True)
for block, blockEpochs in rawBlocks.items():
    # Iterate over the epochs stored for this block; the block key itself is
    # only a label and says nothing about how many epochs there are.
    for idx in range(len(blockEpochs)):
        for idx2 in range(len(blockEpochs)):
            if idx == idx2:
                continue
            alignment = dtw(blockEpochs[idx], blockEpochs[idx2], keep_internals=True)
            alignment.plot(type="twoway")
            plt.title('Block ' + str(block) + ' ' + 'Channel ' + str(channelNumber) + ' ' + channelName + ' ' + 'stim site ' + stimSite + ' ' + matter_type + ' ' + 'DTW between index ' + str(idx) + ' and index ' + str(idx2))
            # The block number is part of the file name so that figures from
            # different blocks do not overwrite each other.
            plt.savefig('./figs_dtw/Channel_' + str(channelNumber) +  '_' + channelName + '_block_' + str(block) + '_index0_' + str(idx) + '_index1_' + str(idx2) + '_DTW_' + matter_type + '.png')
            plt.close('all')

In [None]:
# Display the `distance` attribute of `alignment`, i.e. of the last DTW
# alignment computed by the loop in the previous cell.
alignment.distance