In [41]:
from IPython.display import display, HTML
# Widen the notebook container and give scrolled outputs a taller viewport.
_NOTEBOOK_CSS = (
    "<style>.container { width:75% !important; }</style>",
    "<style>div.output_scroll { height: 44em; }</style>",
)
for _css in _NOTEBOOK_CSS:
    display(HTML(_css))

In [43]:
#Import functions you will need for running this script
# %matplotlib widget
import os
import numpy as np
import numpy.matlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
import matplotlib.lines as mlines
from scipy import io
from scipy import signal
from scipy.io import loadmat
from scipy.stats import wilcoxon
from scipy import stats
from glob import glob
from datetime import datetime
import pandas as pd
import seaborn as sns
from IPython.display import clear_output
import warnings


# Figure defaults for the whole notebook: Arial 18 pt text, 2-pt axes lines,
# and fonts embedded as Type 42 (TrueType) in PDF/PS exports.
# (plt.rcParams and mpl.rcParams are the same object.)
for _key, _value in {
    'font.family': 'Arial',
    'font.size': 18,
    'axes.linewidth': 2,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
}.items():
    mpl.rcParams[_key] = _value

In [45]:
## Function to load events and spikes:
def get_events(dataLoc, blockStartID=None):
    """Load event, message and spike-sorting output for one recording session.

    Parameters
    ----------
    dataLoc : str
        Session data folder. It must contain ``messages/text.npy``,
        ``messages/timestamps.npy``, ``events/channel_states.npy``,
        ``events/timestamps.npy``, ``sync_messages.txt``, ``spike_times.npy``,
        ``spike_clusters.npy``, ``mean_waveforms.mat`` and ``cluster_info.tsv``.
    blockStartID : int, optional
        0-based index of the message that starts the block to analyse.
        When None (the default), the user is prompted with ``input()``.

    Returns
    -------
    tuple or None
        ``(evState, evTS, msgTS, blockEvTS, stimulusOn, spikes, clust,
        clust_Depth, clustInfo, stimulusOff)``.
        Returns None if one of the message/event/sync files fails to load;
        the error is printed.
        Event and message sample numbers are shifted by the recording start
        sample and divided by the sampling rate parsed from ``sync_messages.txt``.
    """
    # Attempt to load required files; report and return None if any is missing or corrupted
    try:
        msgText = np.load(os.path.join(dataLoc, 'messages', 'text.npy'))
        msgSample = np.load(os.path.join(dataLoc, 'messages', 'timestamps.npy'))
        evState = np.load(os.path.join(dataLoc, 'events', 'channel_states.npy'))
        evSample = np.load(os.path.join(dataLoc, 'events', 'timestamps.npy'))
        recInfo = np.genfromtxt(os.path.join(dataLoc, 'sync_messages.txt'), dtype='str', skip_header=1, delimiter=' ')
    except Exception as e:
        print(f"Error loading files: {e}")
        return None

    # The last space-separated token of the sync line looks like "<startSample>@<rate>Hz"
    syncParts = recInfo[-1].split('@')
    startTime = int(syncParts[0])
    fs = int(syncParts[1][:-2])  # strip the trailing "Hz"

    evTS = (evSample - startTime) / fs
    msgTS = (msgSample - startTime) / fs

    # Print each message with its index so the block of interest can be chosen.
    # Messages may be stored as bytes (decoded as UTF-8) or already as str.
    for index, text in enumerate(msgText):
        label = text.decode('utf-8') if isinstance(text, bytes) else str(text)
        print(f"{index}: {label}")

    print("Total number of blocks:", len(msgTS))

    if blockStartID is None:
        blockStartID = int(input("Indicate Block Start here (remember python syntax were 0 = 1):"))
    blockStart = msgTS[blockStartID]

    # A block runs from its message up to the next message
    # (or to the end of the recording for the last message)
    inBlock = evTS > blockStart
    if blockStartID + 1 < len(msgTS):
        inBlock &= evTS < msgTS[blockStartID + 1]
    blockEvTS = evTS[inBlock]
    blockEvState = evState[inBlock]

    # Channel state +1 / -1 (presumably TTL rising / falling edges);
    # the first edge of each kind inside the block is dropped
    stimulusOn = blockEvTS[blockEvState == 1][1:]
    stimulusOff = blockEvTS[blockEvState == -1][1:]

    # Event extraction is done; now load spike data.
    # NOTE(review): spike sample indices are divided by fs but, unlike evTS,
    # are not offset by startTime. Confirm spike_times.npy is already
    # relative to the recording start.
    spikes = np.load(os.path.join(dataLoc, 'spike_times.npy')) / fs
    clust = np.load(os.path.join(dataLoc, 'spike_clusters.npy'))

    # Cluster depth = y coordinate of each cluster's channel ('mn' is 1-based),
    # shifted so that the largest y coordinate maps to 0
    mean_waves = loadmat(os.path.join(dataLoc, 'mean_waveforms.mat'), appendmat=True)
    ycoords = mean_waves['chanMap']['ycoords'][0][0][0]
    clust_Depth = ycoords[np.squeeze(mean_waves['mn']) - 1]
    clust_Depth = np.int32(clust_Depth) - np.int32(np.asarray(np.max(ycoords)))

    # Load cluster information
    clustInfo = pd.read_csv(os.path.join(dataLoc, 'cluster_info.tsv'), delimiter='\t')

    # Remove noise clusters. Assumes clust_Depth rows line up with
    # cluster_info.tsv rows (TODO confirm).
    keep = (clustInfo.group != 'noise').to_numpy()
    clust_Depth = clust_Depth[keep]
    clustInfo = clustInfo[keep]

    return evState, evTS, msgTS, blockEvTS, stimulusOn, spikes, clust, clust_Depth, clustInfo, stimulusOff

In [47]:
# Functions to load stimulus information. Both require 2 inputs:
# 1) the folder holding the stimulus file and 2) the name of that file
# (".mat" is appended if it is missing).
def _load_stim_fields(stim_Loc, name, fields):
    """Read the ``stimInfo`` struct from ``stim_Loc/name`` and return the requested fields as a dict."""
    # Join the path instead of calling os.chdir, so loading a stimulus file
    # does not silently change the working directory for the rest of the notebook.
    file = loadmat(os.path.join(stim_Loc, name), appendmat=True)
    val = file['stimInfo'][0, 0]  # the single record of the MATLAB struct
    return {field: val[field] for field in fields}


def stimuli(stim_Loc, name):
    """Load a tone+laser stimulus description.

    Returns
    -------
    (dict, str)
        A dict with 'trialOrder', 'tDur' (tone duration), 'ITI' (inter-trial
        interval) and 'laserDur' (laser duration), plus ``name`` unchanged.
    """
    return _load_stim_fields(stim_Loc, name, ('trialOrder', 'tDur', 'ITI', 'laserDur')), name


def laser_stimuli(stim_Loc, name):
    """Load a laser-only stimulus description.

    Returns
    -------
    (dict, str)
        A dict with 'ITI' (inter-trial interval) and 'laserDur' (laser
        duration), plus ``name`` unchanged.
    """
    return _load_stim_fields(stim_Loc, name, ('ITI', 'laserDur')), name

In [49]:
# This function applies a smoothing Gaussian filter to your data
def SmoothGauss(X, M):
    """Smooth a 1-D signal with a truncated, normalised Gaussian kernel.

    The kernel spans offsets ``-M..M`` and is ``exp(-g**2 / (2*sqrt(M))**2)``.
    Near the edges, where part of the kernel falls outside ``X``, each output
    is divided by the kernel weight that overlapped ``X``. A constant input
    therefore stays constant.

    Parameters
    ----------
    X : 1-D array_like
        Signal to smooth.
    M : int
        Kernel half-width in samples. ``M <= 0`` returns an unsmoothed float copy.

    Returns
    -------
    numpy.ndarray
        Float array with the same length as ``X``.

    Raises
    ------
    ValueError
        If ``X`` is not 1-D.
    """
    X = np.asarray(X)
    if X.ndim != 1:
        raise ValueError("X must be a 1D array.")
    if M <= 0 or X.size == 0:
        # Nothing to smooth (for M == 0 the kernel below would be 0/0 = NaN)
        return X.astype(float)

    sigma = M ** 0.5
    G = np.arange(-M, M + 1)
    # NOTE(review): the denominator is (2*sigma)**2 = 4*M rather than the
    # textbook 2*sigma**2, so the effective standard deviation is sqrt(2*M)
    # samples. Kept unchanged so smoothed results stay comparable.
    F = np.exp(-G**2 / ((2 * sigma)**2))
    F /= np.sum(F)

    n = len(X)
    # A full convolution sliced from M aligns the kernel centre with each sample
    Y = np.convolve(X, F, mode='full')[M:M + n]
    # Kernel weight that overlapped X at each output sample (1 away from the
    # edges). Dividing by it reproduces the old start/end correction, and it
    # also works when len(X) < 2*M. In that case the separate corrections
    # overlapped or failed to broadcast.
    overlap = np.convolve(np.ones(n), F, mode='full')[M:M + n]
    return Y / overlap

# This function extracts spikes and returns your raster, trials and psth variable:
def spike_data(spikes, stimulusOnset, edges, smVar):
    """Align spike times to each stimulus onset and bin them.

    Parameters
    ----------
    spikes : array_like
        Spike times of one cluster.
    stimulusOnset : sequence
        Onset time of every trial.
    edges : array_like
        Histogram bin edges relative to onset.
    smVar : int
        Half-width passed to SmoothGauss.

    Returns
    -------
    raster : list
        Aligned spike times inside the open interval (edges[0], edges[-1]),
        concatenated over trials.
    trials : list
        1-based trial number (as float) for each raster entry.
    psth : ndarray, shape (n_trials, len(edges) - 1)
        Spike counts per bin divided by the mean bin width.
    psth_S : ndarray
        ``psth`` smoothed row by row with SmoothGauss.
    """
    n_trials = len(stimulusOnset)
    n_bins = len(edges) - 1
    bin_width = np.diff(edges).mean()  # same for every trial, so computed once
    window_lo, window_hi = edges[0], edges[-1]

    psth = np.empty((n_trials, n_bins))
    psth_S = np.empty((n_trials, n_bins))
    raster = []
    trials = []

    for trial_no, onset in enumerate(stimulusOnset, start=1):
        row = trial_no - 1
        aligned = spikes - onset

        counts, _ = np.histogram(aligned, bins=edges)
        psth[row] = counts / bin_width
        psth_S[row] = SmoothGauss(psth[row], smVar)

        in_window = aligned[(aligned > window_lo) & (aligned < window_hi)]
        raster.extend(in_window)
        trials.extend(np.full(len(in_window), float(trial_no)))

    return raster, trials, psth, psth_S

In [51]:
def ismember(A, B):
    """MATLAB-style ``[~, loc] = ismember(A, B)`` with 0-based locations.

    Both inputs are cast to int, so float values are truncated. For each
    element of ``A``, returns the index of its first occurrence in the
    flattened ``B``. The result is a float array shaped like ``A``; elements
    that do not occur in ``B`` get NaN (the previous loop raised on them and
    on duplicate entries in ``B``).
    """
    A = np.asarray(A).astype(int)
    flatA = A.ravel()
    flatB = np.asarray(B).astype(int).ravel()
    res = np.full(flatA.shape, np.nan)
    if flatA.size == 0 or flatB.size == 0:
        return res.reshape(A.shape)

    # unique(..., return_index=True) gives the first occurrence of each value in B
    uniq_B, first_idx = np.unique(flatB, return_index=True)
    pos = np.clip(np.searchsorted(uniq_B, flatA), 0, len(uniq_B) - 1)
    found = uniq_B[pos] == flatA
    res[found] = first_idx[pos[found]]
    return res.reshape(A.shape)

In [53]:
# Data root. Set the MGB_DATA_DIR environment variable to run on another machine.
data = os.environ.get('MGB_DATA_DIR', '/Users/solymarrolon/Data/OpenEphys/MGB_Recordings/')  # data folder
data_paper = 'Data_Paper'

# Recording group (keep exactly one uncommented)
# cellType = 'PV_Recordings'
cellType = 'SOM_Recordings'

# Virus / condition (keep exactly one uncommented)
# Virus = 'retro_stGtACR2'
# Virus = 'stGtACR1'
# Virus = 'ChR2'
Virus = 'Controls'

# Mouse list matching cellType + Virus (keep exactly one uncommented)
# mouse = ['S095','S097','S098'] #PV Mouse List_Chr2
# mouse = ['S0157','S0158','S0163','S0164'] #PV Mouse List
# mouse = ['S0154','S0155','S0161','S0162']  #SOM Mouse List
# mouse = ['S0169','S0170'] #PV Mouse List Controls
mouse = ['S0171', 'S0172']  # SOM Mouse List Controls

In [55]:
# Stimulus file to analyse (keep exactly one folder / name pair uncommented)
# stim_loc = '/Users/solymarrolon/Data/Stimuli/07302020/' #Stimulus folder location chr2
stim_loc = '/Users/solymarrolon/Data/Stimuli/02232022/'  # Stimulus folder location

# stim = "TuningCurve_50ms_Laser_50ms_Frequencies3_80Hz_073020_stimInfo" #Stimulus name chr2
stim = "TuningCurve_50ms_Laser_50ms_Frequencies3_80Hz_02232022_stimInfo"  # Stimulus name


StimInfo, StimName = stimuli(stim_loc, stim)  # extracts the stimulus information needed below


nreps = 2  # number of stimulus repetitions in the GUI
# Stack the trial-order table nreps times along the rows. np.tile(A, (nreps, 1))
# returns the same array as numpy.matlib.repmat(A, nreps, 1) but avoids the
# discouraged numpy.matlib module.
trialOrder = np.tile(StimInfo['trialOrder'], (nreps, 1))  # Order of trials in the stimulus presented
ITI = StimInfo['ITI']  # Inter-stimulus time interval
laserDur = StimInfo['laserDur']  # Laser duration
laserStart = 0  # laser start time
tDur = StimInfo['tDur']  # tone duration
tStart = 0  # tone start time


binSize = .002  # bin size for your spikes

# Histogram bin edges relative to stimulus onset. np.arange excludes the stop
# value, so the last edge falls short of .20.
edges = np.arange(-.050, .20, binSize)
time = edges[:-1]  # time vector (left edge of each bin)


In [57]:
# Find unique frequencies (column 0 of trialOrder) and each trial's frequency index
uniq_Freq, stim_Ind = np.unique(trialOrder[:, 0], return_inverse=True)
# Unique laser conditions (column 1 of trialOrder) and each trial's condition index
laserCond, laser_Ind = np.unique(trialOrder[:, 1], return_inverse=True)

# Indices of laser-on (== 1) and laser-off (== 0) trials. np.flatnonzero
# always returns a 1-D array; np.array(np.where(...)).squeeze() collapsed a
# single match to a 0-d array.
laserOn = np.flatnonzero(trialOrder[:, 1] == 1)
laserOff = np.flatnonzero(trialOrder[:, 1] == 0)

# Trial order sorted by laser condition (primary key) then frequency. The +1
# makes it 1-based, matching the trial numbers produced by spike_data.
sortI = np.lexsort((trialOrder[:, 0], trialOrder[:, 1])) + 1

In [59]:
##### Setup directories and date format
date = datetime.now().strftime('%m%d%Y')
data_folder = 'PSTH/'
psth_folder_name = f'PSTH_{StimName}_{date}'
psth_dir = os.path.join(data, cellType, Virus, data_paper, data_folder, psth_folder_name)

# Ensure the output directory exists
os.makedirs(psth_dir, exist_ok=True)


smoothingVariable = 2  # half-width passed to SmoothGauss

# Sessions whose selected block has more than MAX_BLOCK_ONSETS onsets are
# re-segmented from the full event stream. Gaps of BLOCK_GAP (+/- BLOCK_GAP_TOL)
# between consecutive onsets are located. Onsets from the one preceding gap
# number TARGET_GAP up to (excluding) the one preceding the next such gap are
# used, and the first onset of that window is dropped.
MAX_BLOCK_ONSETS = 760
BLOCK_GAP = 5
BLOCK_GAP_TOL = .1
TARGET_GAP = 5


def _as_object_array(items):
    """Pack per-cell results into a 1-D object array with one entry per cell, whatever their lengths."""
    # np.array(items, dtype=object) silently builds a 2-D array when all
    # entries happen to have the same length
    packed = np.empty(len(items), dtype=object)
    for k, item in enumerate(items):
        packed[k] = item
    return packed


# Processing each mouse and session
for mID in range(len(mouse)):
    mouseList = os.path.join(data, cellType, Virus, data_paper, mouse[mID])
    _, dirs, _ = next(os.walk(mouseList), ([], [], []))
    dirs = sorted(dirs)

    for sessionID in range(len(dirs)):
        clear_output(wait=True)
        session_tag = f'{mouse[mID]}_{dirs[sessionID]}'
        dat = os.path.join(mouseList, dirs[sessionID], 'data')

        events = get_events(dat)
        if events is None:
            # get_events already printed the load error; unpacking None would abort the whole run
            print(f'Skipping {session_tag}: event files could not be loaded')
            continue
        evState, evTS, msgTS, blockEvts, stimOn, spikes, clust, clust_Depth, clustInfo, stimOff = events

        # Pick the stimulus onsets to analyse
        if len(stimOn) > MAX_BLOCK_ONSETS:
            evOn = evTS[evState == 1]
            dOn = np.diff(evOn)
            blockI = np.flatnonzero(np.abs(dOn - BLOCK_GAP) < BLOCK_GAP_TOL)
            if len(blockI) < TARGET_GAP + 2:
                print(f'Skipping {session_tag}: only {len(blockI)} block gaps found')
                continue
            inWindow = (evTS >= evOn[blockI[TARGET_GAP]]) & (evTS < evOn[blockI[TARGET_GAP + 1]])
            t = evTS[inWindow]
            w = evState[inWindow]
            newstimOn = t[w == 1][1:]  # skip the first onset in the window
            onsets = newstimOn
        else:
            onsets = stimOn

        # Per-cell PSTHs (trials x bins x cells), rasters and sort indices
        cellID = np.asarray(clustInfo.id)
        allPSTH = np.empty([len(onsets), len(edges) - 1, len(cellID)])
        allPSTH_S = np.empty_like(allPSTH)
        RastersAll, TrialsAll, spikeSortIAll = [], [], []

        for i, ID in enumerate(cellID):
            cSpks = spikes[clust == ID]
            Rasters, Trials, allPSTH[:, :, i], allPSTH_S[:, :, i] = spike_data(cSpks, onsets, edges, smoothingVariable)
            spikeSortI = ismember(Trials, sortI)
            RastersAll.append(Rasters)
            TrialsAll.append(Trials)
            spikeSortIAll.append(spikeSortI)

        psth_name = f'PSTH for {session_tag} {StimName}'
        np.savez_compressed(os.path.join(psth_dir, psth_name),
                            allPSTH=allPSTH,
                            allPSTH_S=allPSTH_S,
                            clustDepth=np.array(clust_Depth, dtype=object),
                            rasters=_as_object_array(RastersAll),
                            trials=_as_object_array(TrialsAll),
                            spikeSortI=_as_object_array(spikeSortIAll))


0: laser only 10ms -10 reps
1: laser only 25ms -10 reps
2: laser only 50ms -10 reps
3: laser only 100ms -10 reps
4: tc 25ms tone 25ms laser - 2 reps
5: tc 50ms tone 50ms laser - 2 reps
6: tc 100ms tone 100ms laser - 2 reps
7: attenuated tc 25ms tone 25ms laser - 1 reps
8: attenuated tc 100ms tone 100ms laser - 1 reps
Total number of blocks: 9


Indicate Block Start here (remember python syntax were 0 = 1): 5
