In [46]:
from datetime import date
from glob import glob
import json
import math
import os
from pathlib import Path
import pickle
import sys
import time

import numpy as np
import pandas as pd

In [47]:
# Add the local code checkout to the import path; presumably this is where the
# `tbd_eeg` package imported further down lives — verify on your machine.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
code_dir = r'C:\Users\lesliec\code'
if code_dir not in sys.path:  # keep re-runs of this cell from stacking duplicates
    sys.path.append(code_dir)

In [48]:
from tbd_eeg.tbd_eeg.data_analysis.eegutils import EEGexp

### Load an experiment

In [49]:
# Recording folder for the session analysed in this notebook
# (machine-specific absolute path).
recfolder = r"F:\psi_exp\mouse735049\aw_sal_2024-05-22_11-05-25\experiment1\recording1"

# Build the experiment handle used by every cell below.
# NOTE(review): preprocess / make_stim_csv are switched off here; their exact
# effect is defined in EEGexp and is not visible from this notebook — confirm.
exp = EEGexp(
    recfolder,
    preprocess=False,
    make_stim_csv=False,
)

Experiment type: electrical and sensory stimulation


#### Load probe data

In [50]:
# Collect every experiment_data key that mentions a probe and drop the
# '_sorted' suffix so only the bare probe name remains.
probe_list = [
    entry.replace('_sorted', '')
    for entry in exp.experiment_data
    if 'probe' in entry
]
print(probe_list)

['probeB', 'probeC', 'probeD', 'probeF']


In [51]:
# Build one table with a row per "good" sorted unit across all probes.
#
# For each probe this cell:
#   1. reads probe_info.json and derives a per-channel depth relative to the
#      surface channel,
#   2. loads the cluster labels, cluster metrics and per-spike cluster IDs,
#   3. skips the probe (with a message) if the cluster ID sets disagree,
#   4. keeps units labelled 'good' whose spike count exceeds the threshold,
#   5. reshapes the selected metrics into a tidy frame.
# The per-probe frames are concatenated into `all_units_info_df` at the end.

# A unit is kept only if it has strictly more spikes than this.
spike_count_threshold = 20

all_units_info = []
for probe_name in probe_list:
    print('  {}'.format(probe_name))
    probe_files = exp.ephys_params[probe_name]

    ## Load probe_info.json ##
    with open(probe_files['probe_info']) as data_file:
        data = json.load(data_file)
    # NOTE(review): npx_allch is not used in this cell; kept in case other cells read it.
    npx_allch = np.array(data['channel'])
    surface_ch = int(data['surface_channel'])
    allch_z = np.array(data['vertical_pos'])
    # ref_mask = np.array(data['mask'])
    # npx_chs = np.array([x for x in npx_allch if ref_mask[x] and x <= surface_ch])
    # Vertical position of each channel expressed relative to the surface channel.
    probe_ch_depths = allch_z[surface_ch] - allch_z

    ## Load the unit info ##
    cluster_group = pd.read_csv(probe_files['cluster_group'], sep='\t')
    cluster_metrics = pd.read_csv(probe_files['cluster_metrics'])
    spike_clusters = np.load(probe_files['spike_clusters'])
    # NOTE(review): spike_times is not used in this cell; kept in case other cells read it.
    spike_times = np.load(probe_files['spike_times'])

    # A single pass over spike_clusters gives both the sorted unique cluster IDs
    # (for the consistency check) and the number of spikes per cluster.
    sorted_ids, spikes_per_id = np.unique(spike_clusters, return_counts=True)
    group_ids = cluster_group['cluster_id'].values.astype('int')

    if not np.array_equal(group_ids, sorted_ids):
        print('   IDs from cluster_group.tsv DO NOT match spike_clusters.npy. This may mean there are unsorted units, check in phy.')
        continue
    if not np.array_equal(group_ids, cluster_metrics['cluster_id'].values.astype('int')):
        print('   IDs from cluster_group DO NOT match cluster_metrics.')
        continue
    unit_metrics = pd.merge(
        cluster_group.rename(columns={'group': 'label'}),
        cluster_metrics,
        on='cluster_id',
    )

    ## Select only "good" units ##
    # Look the counts up by cluster ID instead of rescanning spike_clusters per unit.
    spike_count_by_id = dict(zip(sorted_ids.tolist(), spikes_per_id.tolist()))
    unit_metrics['spike_count'] = [
        spike_count_by_id.get(int(cid), 0) for cid in unit_metrics['cluster_id'].values
    ]
    keep = (unit_metrics['label'] == 'good') & (unit_metrics['spike_count'] > spike_count_threshold)
    good_units = unit_metrics[keep]
    if good_units.empty:
        # Without this guard the coordinate array below would be 1-D and the
        # column slicing would raise IndexError, aborting the remaining probes.
        print('   No "good" units with more than {} spikes, skipping.'.format(spike_count_threshold))
        continue

    # ccf_coord is stored as text such as "[262, 174, 179]"; strip the brackets
    # and spaces and parse the three integers of every unit.
    tempcoords = np.array([
        [int(part) for part in coord.replace('[', '').replace(']', '').replace(' ', '').split(',')]
        for coord in good_units['ccf_coord'].values
    ])

    ## Now reorganize metrics to save ##
    peak_channels = good_units['peak_channel'].values
    probe_units = pd.DataFrame({
        # unit_id = last letter of the probe name + cluster ID, e.g. 'B0'
        'unit_id': [probe_name[-1] + str(cid) for cid in good_units['cluster_id'].values],
        'probe': probe_name,
        'peak_ch': peak_channels,
        'depth': probe_ch_depths[peak_channels],
        'spike_duration': good_units['duration'].values,
        'region': good_units['area'].values,
        'CCF_AP': tempcoords[:, 0],
        'CCF_DV': tempcoords[:, 1],
        'CCF_ML': tempcoords[:, 2],
        'firing_rate': good_units['firing_rate'].values,
        'presence_ratio': good_units['presence_ratio'].values,
        'isi_viol': good_units['isi_viol'].values,
        'amplitude_cutoff': good_units['amplitude_cutoff'].values,
        'spike_count': good_units['spike_count'].values,
    })

    ## Add parent region column ##
    # probe_units = add_parent_region_to_df(probe_units, str_tree, annot)
    all_units_info.append(probe_units)

## Now combine all probe units dfs ##
# pd.concat raises ValueError if every probe was skipped above.
all_units_info_df = pd.concat(all_units_info, ignore_index=True)

  probeB
  probeC
  probeD
  probeF


In [52]:
# Show the combined per-unit table (pandas elides the middle rows in the rendering).
all_units_info_df

Unnamed: 0,unit_id,probe,peak_ch,depth,spike_duration,region,CCF_AP,CCF_DV,CCF_ML,firing_rate,presence_ratio,isi_viol,amplitude_cutoff,spike_count
0,B0,probeB,0,3700,0.494472,VAL,262,174,179,5.416058,0.99,0.003752,0.000110,55327
1,B1,probeB,0,3700,0.549414,VAL,262,174,179,1.570380,0.96,0.223173,0.001627,16042
2,B5,probeB,0,3700,0.755444,VAL,262,174,179,0.190693,0.91,2.017995,0.350076,1948
3,B6,probeB,0,3700,0.508208,VAL,262,174,179,0.028095,0.73,0.000000,0.345428,287
4,B7,probeB,0,3700,0.563149,VAL,262,174,179,0.008027,0.53,0.000000,0.500000,82
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1295,F570,probeF,289,820,0.769179,MOs5,156,89,166,0.074789,0.94,0.000000,0.079541,764
1296,F571,probeF,306,640,0.590620,MOs5,154,82,163,0.030542,0.91,0.000000,0.000829,312
1297,F572,probeF,55,3160,0.549414,LSr,174,180,211,0.034262,0.56,0.000000,0.139314,350
1298,F573,probeF,49,3220,0.604355,LSr,174,182,212,0.397735,0.99,0.463878,0.018443,4063


In [53]:
# Sorted list of the distinct region labels present among the selected units.
np.unique(all_units_info_df['region'].to_numpy())

array(['CA1', 'CA2', 'CA3', 'CP', 'DG-mo', 'DG-po', 'DG-sg', 'HY', 'LD',
       'LGd-co', 'LGd-sh', 'LSr', 'MGd', 'MGm', 'MGv', 'MOs2/3', 'MOs5',
       'MOs6a', 'MOs6b', 'PO', 'SSp-tr5', 'SSp-tr6a', 'SSp-tr6b', 'SSs5',
       'SSs6a', 'SSs6b', 'TH', 'VAL', 'VISa2/3', 'VISa4', 'VISa5',
       'VISal2/3', 'VISal4', 'VISal5', 'VISp2/3', 'VISp4', 'VISp5',
       'VISp6a', 'ZI', 'alv', 'ccb', 'ccg', 'cing', 'fa', 'int', 'ml',
       'or', 'root'], dtype=object)

In [54]:
# Restrict the unit table to a single region label and report how many units
# it contains (0 when the label does not occur in the table).
in_region = all_units_info_df['region'].eq('RT')
reg_units = all_units_info_df[in_region]
print(len(reg_units))

0
