In [1]:
from __future__ import division
import pandas as pd
import numpy as np
import time
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import MaxNLocator

from matplotlib.ticker import FormatStrFormatter

import scipy.stats as sstat
import scipy.signal as ssig
import h5py
from mpl_toolkits.mplot3d import Axes3D
import os
from sklearn import preprocessing
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA as sklearnPCA
import re

# import ephys_unit_analysis as ena
import mz_ephys_unit_analysis as mz_ena

#import resampy
import fnmatch
import seaborn as sns
%matplotlib inline
%load_ext autoreload
%autoreload 2



---

# Neuropixels (NP) probe inserted through V1 and hippocampus

In [2]:
# Probe model name, passed to the project helper below.
probe = 'Neuropixels'
# NOTE(review): channel_groups is presumably the channel -> depth layout for `probe` (name suggests so);
# it is passed to mz_ena.ksort_get_tmt in the unit loop, which returns each unit's depth -- confirm.
channel_groups = mz_ena.get_channel_depth(probe)

In [None]:
def find_sorted_folders(source_folder, include='novel', exclude='bad'):
    """Find Kilosort/phy output folders under ``source_folder``.

    A folder qualifies when it contains at least one ``*rez.mat`` file and at
    least one ``*cluster_group.tsv`` file (phy2 writes .tsv instead of .csv),
    its path contains ``include`` and does not contain ``exclude``.

    Each qualifying folder is reported once, even if it holds several matching
    files, so its spikes are never loaded twice downstream.

    Returns
    -------
    (root_files, matches)
        root_files: list of qualifying folder paths (one entry per folder).
        matches: list of the ``*cluster_group.tsv`` paths found in them.
    """
    root_files = []
    matches = []
    for root, _dirnames, filenames in os.walk(source_folder):
        if not fnmatch.filter(filenames, '*rez.mat'):
            continue
        group_files = fnmatch.filter(filenames, '*cluster_group.tsv')
        if not group_files:
            continue
        if include not in root or exclude in root:
            continue
        matches.extend(os.path.join(root, fn) for fn in group_files)
        root_files.append(root)
        print(root)
    return root_files, matches


source_folder = r"G:\Neuropixels\Sam_multi_brain_region_paper\SORTED"
root_files, matches = find_sorted_folders(source_folder)

print('\nIMPORTANT: This has "cluster_group.tsv" already appended to the matches list')
print ('How many files?',len(matches))

In [41]:



# Which clusters count as "good" units:
#   0 -- clusters labelled 'good' during manual curation in phy
#   1 -- clusters Kilosort itself auto-labelled 'good' (KSLabel) -- NOT every Kilosort cluster
all_units_or_good = 1
assert all_units_or_good in (0, 1), 'all_units_or_good must be 0 or 1'




In [None]:
# Load every spike from each experiment folder in root_files.
#   data_df -- spikes from all clusters
#   df_rez  -- only spikes from clusters labelled 'good' (see all_units_or_good)
SAMPLE_RATE_HZ = 30000.0  # converts spike_times.npy (sample indices) to seconds

# Label file and column that define a 'good' cluster for each all_units_or_good mode.
# Manual phy curation is stored in cluster_group.tsv ('group' column); Kilosort's own labels are in
# cluster_KSLabel.tsv ('KSLabel' column), which normally has no 'group' column.
label_sources = {
    0: ('cluster_group.tsv', 'group'),
    1: ('cluster_KSLabel.tsv', 'KSLabel'),
}
if all_units_or_good not in label_sources:
    raise ValueError('all_units_or_good must be 0 or 1, got {!r}'.format(all_units_or_good))
label_file, label_col = label_sources[all_units_or_good]

data_df = []
df_rez = []

for f in root_files:
    path = f
    # Folder name is expected to look like '<stim>_<cc>_<et>[_...]'; basename works with the
    # separator of whatever OS produced the path.
    folder_parts = os.path.basename(os.path.normpath(f)).split('_')
    stim_type, cc_num, et_num = folder_parts[0], folder_parts[1], folder_parts[2]

    cluster_groups = pd.read_csv(os.path.join(path, label_file), sep='\t')
    good = cluster_groups.loc[cluster_groups[label_col] == 'good', 'cluster_id'].values

    spike_clusters = np.load(os.path.join(path, 'spike_clusters.npy'))
    spike_times = np.load(os.path.join(path, 'spike_times.npy'))
    # NOTE: `templates` is overwritten on every iteration, so after the loop it holds only the LAST
    # folder's templates -- anything using it must not assume it matches other recordings.
    templates = np.load(os.path.join(path, 'templates.npy'))
    spike_templates = np.load(os.path.join(path, 'spike_templates.npy'))

    spikes_all = pd.DataFrame({'stim_type': stim_type,
                               'et': et_num,
                               'cc': cc_num,
                               'cluster_id': spike_clusters.flatten(),
                               'times': spike_times.flatten() / SAMPLE_RATE_HZ,
                               'templates': spike_templates.flatten(),
                               'path': f})

    data_df.append(spikes_all)
    df_rez.append(spikes_all[spikes_all.cluster_id.isin(good)])

data_df = pd.concat(data_df, axis=0, ignore_index=True)
df_rez = pd.concat(df_rez, axis=0, ignore_index=True)

print('total units df shape:', data_df.shape)
print('"good" units df shape:', df_rez.shape)

print('Total paths:', df_rez.path.nunique())

In [None]:
# Unit id = '<et>_<cluster_id>'. Cluster ids are only unique within one sorting, so the et prefix is
# what keeps units from different recordings apart.
data_df['cuid'] = data_df.et.astype(str) + '_' + data_df.cluster_id.astype(str)
df_rez['cuid'] = df_rez.et.astype(str) + '_' + df_rez.cluster_id.astype(str)

# Guard: two folders sharing an et number would silently merge different units under one cuid,
# and the unit loop would then mix their spikes.
paths_per_cuid = df_rez.groupby('cuid')['path'].nunique()
colliding_cuids = paths_per_cuid[paths_per_cuid > 1]
assert colliding_cuids.empty, (
    'cuid is not unique across recordings, e.g. {}'.format(list(colliding_cuids.index[:5])))

print("Total units:", data_df['cuid'].nunique())
print("Good units:", df_rez['cuid'].nunique())

df_rez.head()

---

# Trial parameters (continue from here after loading)

In [44]:

# --- Trial / PSTH parameters -------------------------------------------------------------
# ~~~~~~~~~~~~~~~~~~~~~~IMPORTANT TO CHANGE THIS~~~~~~~~~~~~~~~~~~~~~~
# NOTE(review): trials_number is not referenced by the cells shown here; the PSTH uses the count of
# non-empty trials instead -- confirm whether it is still needed.
trials_number = 50

# Length of one trial in seconds (the yellow highlight in OpenEphys); spikes are assigned to
# trials by integer-dividing their time by this value.
trial_length = 3.0

# PSTH bin width passed to mz_ena.PSTH (presumably seconds, matching trial_length).
th_bin = 0.01

# Unit-quality bookkeeping lists (each one a separate, empty list).
ls_rawcount, ls_lowspikecount, ls_refract_violators, ls_lowamp_waveforms = [], [], [], []

In [None]:
# Trial index of every spike = floor(spike time in s / trial_length), for both spike tables.
for spike_table in (data_df, df_rez):
    spike_table['trial'] = np.floor_divide(spike_table['times'], trial_length).astype(int)

data_df.head()

---

# Build the per-unit PSTH, spike, and waveform DataFrames

In [None]:
# Build per-unit PSTH, spike, and template-waveform DataFrames for every 'good' unit in df_rez.
ls_spikes = []
ls_tmt = []
ls_psth = []

n_baseline_bins = 50  # leading PSTH bins used as baseline (50 * th_bin s) -- must match the experiment time course
n_errors = 0          # units skipped because mz_ena.ksort_get_tmt raised
failed_units = []     # (cuid, error repr) for every skipped unit, for inspection afterwards
loaded_templates_path = None  # folder whose templates.npy is currently held in unit_templates
unit_templates = None

num_units = df_rez['cuid'].nunique()

# groupby(sort=False) visits units in the same order as df_rez['cuid'].unique(), but splits df_rez
# once instead of re-scanning the whole frame for every unit.
for iii, (unit, tmp2) in enumerate(df_rez.groupby('cuid', sort=False)):
    cuid = str(unit)
    stim_id = tmp2.stim_type.values[0]
    cluster_id = tmp2.cluster_id.values[0]
    et = tmp2.et.values[0]
    cc = tmp2.cc.values[0]
    path = tmp2.path.values[0]
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # The spike template ids in tmp2 index the templates.npy of the SAME sorting, so load the file
    # from this unit's own folder (reloaded only when the folder changes; units of one folder are
    # contiguous in df_rez). The global `templates` only holds the last loaded folder.
    if path != loaded_templates_path:
        unit_templates = np.load(os.path.join(path, 'templates.npy'))
        loaded_templates_path = path

    try:
        tmt, depth, ch_idx = mz_ena.ksort_get_tmt(tmp2, cluster_id, unit_templates, channel_groups)
    except Exception as err:  # best-effort: skip this unit, but keep a record of why
        n_errors += 1
        failed_units.append((cuid, repr(err)))
        continue
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    df = mz_ena.getRaster_kilosort(tmp2, unit, trial_length)
    trials_number_not_empty = len(df.trial.unique())

    # NOTE(review): PSTH receives the number of trials that contain at least one spike, not
    # trials_number -- confirm this is the count it should use.
    h, ttr = mz_ena.PSTH(df.times, th_bin, trial_length, trials_number_not_empty)

    zscore = sstat.mstats.zscore(h)
    baseline = h[0:n_baseline_bins]
    mean = np.mean(baseline)
    std = np.std(baseline)
    # A silent or perfectly flat baseline has zero spread; fall back to 1 so ztc is the
    # baseline-subtracted rate instead of inf/nan.
    if mean <= 0 or std == 0:
        std = 1
    ztc = (h - mean) / std
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    if iii % 200 == 0:
        print('done with {0} out of {1}'.format(iii, num_units))
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    df_psth_tmp = pd.DataFrame({'times': ttr,
                                'stim': stim_id,
                                'Hz': h,
                                'cluster_id': cluster_id,
                                'depth': depth,
                                'zscore': zscore,
                                'ztc': ztc,
                                'et': et,
                                'cc': cc,
                                'cuid': cuid,
                                'path': path})
    ls_psth.append(df_psth_tmp)
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    df_spikes_tmp = pd.DataFrame({'cluster_id': cluster_id,
                                  'spikes': tmp2.times.values,
                                  'trial': df.trial,
                                  'trial_spikes': df.times,
                                  'stim': stim_id,
                                  'depth': depth,
                                  'et': et,
                                  'cc': cc,
                                  'cuid': cuid,
                                  'path': path})
    ls_spikes.append(df_spikes_tmp)
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    df_tmt_tmp = pd.DataFrame({'tmt': tmt,
                               'stim': stim_id,
                               'depth': depth,
                               'et': et,
                               'cc': cc,
                               'cuid': cuid,
                               'path': path})
    ls_tmt.append(df_tmt_tmp)
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
print(n_errors, "errors")
if failed_units:
    print('first failures:', failed_units[:5])
df_psth = pd.concat(ls_psth)
df_spikes = pd.concat(ls_spikes)
df_tmt = pd.concat(ls_tmt)

In [None]:
# Depth span of the units that made it into the PSTH table.
depth_extent = df_psth['depth'].agg(['min', 'max'])
print('Min unit depth on probe:', depth_extent['min'])
print('Max unit depth on probe:', depth_extent['max'])

# print('\n',np.sort(df_psth['depth'].unique()))

In [None]:
# Keep whole units whose PSTH z-score never reaches 20. Filtering individual rows (time bins)
# would punch gaps into the PSTH of any unit with a strong peak while still counting that unit.
# .copy() gives an independent frame so later column assignments don't write into a slice.
unit_max_z = df_psth.groupby('cuid')['zscore'].transform('max')
null_df = df_psth[(unit_max_z < 20).to_numpy()].copy()

print('# good units: %d' % df_psth['cuid'].nunique())
print('# good units w/ max z-score < 20: %d' % null_df['cuid'].nunique())

In [None]:
null_df.iloc[:5]  # first five rows (same as .head())

---

# Add region labels to the DataFrames (group labelling is currently commented out)

In [50]:
# def label_et_group(row):
#     if (row['et']=="et1")|(row['et']=="et2")|(row['et']=="et10")|(row['et']=="et20")|(row['et']=="et200")|(row['et']=="et30")|(row['et']=="et3"):
#         return "wt"
#     else:
#         return "no_con"

# def label_group(row):
#     if (row['cc'] == "CC082263") | (row['cc'] == "CC067489") | (row['cc'] == "CC082260") | (row['cc'] == "CC084621"):
#         return "wt"
#     elif (row['cc'] == "CC082257") | (row['cc'] == "CC067431") | (row['cc'] == "CC067432") | (row['cc'] == "CC082255"):
#         return "fx"
    
def label_region(row, v1_range=(2000, 3100), hippo_range=(600, 1800), thal_max=600):
    """Return the brain-region label for a unit from its probe depth.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row or dict)
        Only ``row['depth']`` is read; depth uses the same units as the 'depth' column.
    v1_range : (low, high)
        Inclusive depth range labelled 'v1'.
    hippo_range : (low, high)
        Half-open depth range [low, high) labelled 'hippo'.
    thal_max : number
        Depths strictly below this are labelled 'thal'.

    Returns
    -------
    str
        'v1', 'hippo', 'thal', or 'none' for any depth outside those bands
        (including the gap between hippo and v1, and NaN depths).
    """
    depth = row['depth']
    if v1_range[0] <= depth <= v1_range[1]:
        return 'v1'
    if hippo_range[0] <= depth < hippo_range[1]:
        return 'hippo'
    if depth < thal_max:
        return 'thal'
    return 'none'

In [None]:
# label_region only reads 'depth', so evaluate it once per distinct depth and map the result
# (far cheaper than a row-wise apply). assign() returns a new frame, avoiding a write into a slice of df_psth.
region_by_depth = {d: label_region({'depth': d}) for d in null_df['depth'].unique()}
null_df = null_df.assign(region=null_df['depth'].map(region_by_depth))

null_df.head()

In [None]:
# df_spikes has one row per spike, so a row-wise apply would call label_region millions of times.
# label_region only reads 'depth': evaluate it once per distinct depth and map the result.
region_by_depth = {d: label_region({'depth': d}) for d in df_spikes['depth'].unique()}
df_spikes['region'] = df_spikes['depth'].map(region_by_depth)

df_spikes.head()

In [None]:
# df_tmt['group'] = df_tmt.apply(lambda row: label_et_group(row), axis=1)
# label_region only reads 'depth': evaluate it once per distinct depth and map, instead of a row-wise apply.
region_by_depth = {d: label_region({'depth': d}) for d in df_tmt['depth'].unique()}
df_tmt['region'] = df_tmt['depth'].map(region_by_depth)

df_tmt.head()

---

# Last reordering of the columns for easy viewing

In [None]:
# just a last reordering of the columns for easy viewing
# Put identifying fields and PSTH values first for easy viewing.
cols = [
    'times', 'cuid', 'depth', 'Hz', 'zscore', 'ztc',
    'region', 'stim', 'cc', 'et', 'cluster_id', 'path',
]
null_df = null_df.loc[:, cols]

null_df.head()

In [None]:
# Put trial/spike timing columns first for easy viewing.
cols = [
    'trial', 'trial_spikes', 'spikes', 'cuid', 'stim',
    'depth', 'region', 'et', 'cc', 'cluster_id', 'path',
]
df_spikes = df_spikes.loc[:, cols]

df_spikes.head()

In [None]:
# Put the template waveform and its labels first for easy viewing.
cols = [
    'tmt', 'stim', 'region', 'depth',
    'et', 'cc', 'cuid', 'path',
]
df_tmt = df_tmt.loc[:, cols]

df_tmt.head()

---

# Only keep the WT pre & post & novel (?) data (currently disabled: the cell below is commented out)

In [20]:
# df_psth = null_df[null_df.group.isin(['wt'])]
# df_spikes = df_spikes[df_spikes.group.isin(['wt'])]
# df_tmt = df_tmt[df_tmt.group.isin(['wt'])]

# print(df_psth.group.unique())
# print(df_spikes.group.unique())
# print(df_tmt.group.unique())

# Save the DataFrames (plotting is done elsewhere)

In [57]:


# Pickle the per-unit DataFrames so they can be plotted in a separate notebook.
save_dir = r"D:\mz_Data\saved_dfs\Multi_brain_regions\redo_brain_regions"
os.makedirs(save_dir, exist_ok=True)  # to_pickle fails if the target folder does not exist

# NOTE(review): this saves df_psth, not the region-labelled null_df, so the saved PSTH has no
# 'region' column (the `df_psth = null_df[...]` line above is commented out) -- confirm intended.
df_psth.to_pickle(os.path.join(save_dir, "V1HPC_novel_psth.pkl"))
df_spikes.to_pickle(os.path.join(save_dir, "V1HPC_novel_spikes.pkl"))
df_tmt.to_pickle(os.path.join(save_dir, "V1HPC_novel_waveforms.pkl"))



---