In [1]:
%load_ext autoreload
%autoreload 2

import os
from glob import glob
import json
import pickle

import numpy as np
import pandas as pd
import scipy as sp
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.metrics import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from matplotlib import cm, patches
import matplotlib.gridspec as gridspec
from tqdm.auto import tqdm
import pandarallel
from IPython.utils.capture import capture_output
with capture_output():
    tqdm.pandas()
    pandarallel.pandarallel.initialize(progress_bar=True)

from tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.data_analysis.Utilities.utilities import get_stim_events, find_nearest_ind

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False
%matplotlib widget

In [2]:
# accessing the Google sheet with experiment metadata in python
# setting up the permissions:
# 1. install gspread (pip install gspread / conda install gspread)
# 2. copy the service_account.json file to '~/.config/gspread/service_account.json'
# 3. run the following:
import gspread
_gc = gspread.service_account() # need a key file to access the account (step 2)
_sh = _gc.open('Zap_Zip-log_exp') # open the spreadsheet
_df = pd.DataFrame(_sh.sheet1.get()) # load the first worksheet
gmetadata = _df.T.set_index(0).T # put it in a nicely formatted dataframe

In [3]:
rec_folder = '../tiny-blue-dot/zap-n-zip/EEG_exp/mouse571618/estim1_2021-04-29_12-28-54/experiment1/recording1/'
exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)

Experiment type: electrical stimulation


In [4]:
# Let's print some meta data
print('Mouse: {}'.format(exp.mouse))
print('Experiment date: {}'.format(exp.date))
print('What data is in here?')
print(exp.experiment_data)

Mouse: 571618
Experiment date: 2021-04-29 12:28:54
What data is in here?
['probeB_sorted', 'probeC_sorted', 'probeF_sorted', 'recording1']


In [5]:
stim_log = pd.read_csv(exp.stimulus_log_file)
stim_log.head()

Unnamed: 0,stim_type,parameter,duration,onset,offset,sweep,good
0,biphasic,70,0.0004,326.809,326.8094,0,True
1,biphasic,40,0.0004,330.34202,330.34242,0,True
2,biphasic,70,0.0004,334.40364,334.40404,0,True
3,biphasic,90,0.0004,338.78593,338.78633,0,True
4,biphasic,90,0.0004,342.95554,342.95594,0,True


In [6]:
fig, ax = plt.subplots(figsize=(3.5, 3), constrained_layout=True)

ax.scatter(exp.EEG_channel_coordinates['ML'], exp.EEG_channel_coordinates['AP'], s=300, color='orange')
ax.scatter(-1.25, 1.5, marker='P', color='red')
ax.axis('equal')
    
for ind in range(len(exp.EEG_channel_coordinates)):
    ax.annotate(str(ind),  xy=(exp.EEG_channel_coordinates['ML'].iloc[ind], exp.EEG_channel_coordinates['AP'].iloc[ind]), ha='center', va='center', color="k")

ax.set_xlabel("ML axis (mm)\nmouse's left <--> right")
ax.set_ylabel('AP axis (mm)')
ax.set_title('NeuroNexus numbering');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
eeg_data = exp.load_eegdata(frequency=2500, return_type='pd')

In [8]:
plot_ch = 27 # choose which electrode to plot (zero-indexed, ch 30:31 do not exist)

fig, ax = plt.subplots(figsize=(7, 2.4), tight_layout=True)
eeg_data[plot_ch].plot(ax=ax)

# plot cosmetics
ax.set_xlabel('Time (s)')
ax.set_ylabel('Raw signal (uV)')
ax.set_title('EEG channel %d' % plot_ch);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Demarcating areas

In [13]:
exp.ephys_params.keys()

dict_keys(['probeB', 'probeC', 'probeF', 'EEG'])

In [14]:
probe = 'probeC'
lfp = np.memmap(exp.ephys_params[probe]['lfp_continuous'], dtype='int16', mode='r')
lfp = np.reshape(lfp, (int(lfp.size/exp.ephys_params[probe]['num_chs']), exp.ephys_params[probe]['num_chs']))
samp_rate = exp.ephys_params[probe]['lfp_sample_rate']
timestamps = np.load(exp.ephys_params[probe]['lfp_timestamps'])

t_lim_s = np.array([2, 6])
sampled_lfp = pd.DataFrame(lfp[slice(*t_lim_s*samp_rate), EEGexp.NPX_lfp_channel_order].T)

In [15]:
correlations = pd.DataFrame(pairwise_distances(sampled_lfp, metric='euclidean'))

f, (ax, ax2, axa, axp) = plt.subplots(
    1, 4, figsize=(10, 3), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 1.5, 0.2, 0.2]), sharey=True
)
v = np.quantile(sampled_lfp, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.05, 0.95])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

ax.set_title('LFP')
ax2.set_title('distance');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [57]:
# remove very different channels

mean_distances = correlations.mean().rename('mean_distances')
_mn = mean_distances.rolling(10, center=True).median()
_sd = mean_distances.rolling(10, center=True).std()

f, ax = plt.subplots(figsize=(7, 2.5), tight_layout=True)
ax.plot(mean_distances)
ax.fill_between(_mn.index, _mn-_sd.mean(), _mn+_sd.mean(), facecolor=cm.Greys(0.5, 0.5))
ax.set_xlabel('channel')
ax.set_ylabel('mean distance')

# which channels deviate by more than 2 local SD?
idx_normal = (((mean_distances-_mn).abs()-_sd.mean())<0)
sampled_lfp_normal = sampled_lfp[idx_normal]

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [58]:
correlations = pd.DataFrame(
    pairwise_distances(sampled_lfp_normal, metric='cosine'),
    index=sampled_lfp_normal.index, columns=sampled_lfp_normal.index
)

f, (ax, ax2, axa, axp) = plt.subplots(
    1, 4, figsize=(10, 3), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 1.5, 0.2, 0.2]), sharey=True
)
v = np.quantile(sampled_lfp_normal, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp_normal, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp_normal.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.01, 0.99])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

ax.set_title('LFP')
ax2.set_title('correlations');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
f, ax = plt.subplots(figsize=(3, 3), tight_layout=True)
ax.scatter(sampled_lfp_normal.iloc[275], sampled_lfp_normal.iloc[277]);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# lets look at the cosine distance along an offset diagonal
offsets = [30, 40, 50, 60, 70, 80, 90, 100]
f, ax = plt.subplots(figsize=(8, 2.5), tight_layout=True)
for offset in offsets:
    diag_similarity = pd.Series(
        np.diagonal(correlations, offset=offset),
        index=correlations.index[:-offset]
    )
    diag_similarity.rolling(20, center=True).median().plot(
        ax=ax, label=f'distance = {offset}',
        c=cm.Reds(0.3+0.7*offsets.index(offset)/len(offsets))
    )
ax.set_xlabel('channel')
ax.set_ylabel('similarity at distance')
ax.legend(fontsize=8);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
def hierarchical_clusters(features, n_clusters=4, link='ward', pl=False):
    cluster_data = {}
    cluster_data['z'] = linkage(
        features, link
    )
    if pl:
        f, ax = plt.subplots(1, 1, figsize=(4, 3), tight_layout=True)
        cluster_data['dendrogram'] = dendrogram(
            cluster_data['z'], p=n_clusters,
            truncate_mode='lastp', count_sort='ascending', show_contracted=True, ax=ax
        )
    clustering = AgglomerativeClustering(
        linkage=link, n_clusters=n_clusters
    )
    clustering.fit(features)
    cluster_data['clustering'] = clustering
    clusters = clustering.labels_
    cluster_data['clusters'] = clusters
    return cluster_data

In [17]:
_c = correlations.reset_index()
_c['index'] = _c.index/_c.index.values.mean()*correlations.values.mean()*10
cdata = hierarchical_clusters(_c, pl=True)
clusters = cdata['clusters']

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
f, (ax, axa, ax2, axp) = plt.subplots(
    1, 4, figsize=(10, 3), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 0.2, 1.5, 0.2]), sharey=True
)
v = np.quantile(sampled_lfp_normal, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp_normal, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp_normal.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.01, 0.99])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

axa.imshow(clusters[:, np.newaxis], aspect='auto', origin='lower', cmap=cm.Dark2)
axa.set_xticks([])
axa.set_xlabel('clusters')

ax.set_title('LFP')
ax2.set_title('correlations');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [70]:
# repeat hierarchical clustering on multiple short windows and then see what groupings of channels occurs consistently
def get_cluster_ids(lfp, t_start_s, win_length_s, n_clusters=4):
    t_lim_s = np.array([t_start_s, t_start_s+win_length_s])
    sampled_lfp = pd.DataFrame(lfp[slice(*t_lim_s*samp_rate), EEGexp.NPX_lfp_channel_order].T)
    
    distances = pd.DataFrame(pairwise_distances(sampled_lfp, metric='euclidean'))
    mean_distances = distances.mean().rename('mean_distances')
    _mn = mean_distances.rolling(10, center=True).median()
    _sd = mean_distances.rolling(10, center=True).std()
    idx_normal = (((mean_distances-_mn).abs()-_sd.mean())<0)
    sampled_lfp_normal = sampled_lfp[idx_normal]
    
    correlations = pd.DataFrame(
        pairwise_distances(sampled_lfp_normal, metric='correlation'),
        index=sampled_lfp_normal.index, columns=sampled_lfp_normal.index
    )
    
    _c = correlations.reset_index()
    _c['index'] = _c.index/_c.index.values.mean()*correlations.values.mean()*np.sqrt(correlations.shape[0])
    cdata = hierarchical_clusters(_c, n_clusters=n_clusters)
    return pd.Series(cdata['clusters'], index=sampled_lfp_normal.index)

In [94]:
n_clusters = 5
winsize = 4
cids = []
for t_start in tqdm(np.arange(50)*winsize):
    cids.append(get_cluster_ids(lfp, t_start, winsize, n_clusters=n_clusters))

clusters = pd.concat(cids, axis=1)
clusters.head()

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
0,1,1,1,1,1,1,1,4,3,3,1,3,3,3,3,3,1,3,3,3,1,3,1,3,4,3,3,4,1,1,3,3,3,3,3,1,3,1,3,3,3,3,3,3,3,3,3,3,3,1
1,1,1,1,1,1,1,1,4,3,3,1,3,3,3,3,3,1,3,3,3,1,3,1,3,4,3,3,4,1,1,3,3,3,3,3,1,3,1,3,3,3,3,3,3,3,3,3,3,3,1
2,1,1,1,1,1,1,1,4,3,3,1,3,3,3,3,3,1,3,3,3,1,3,1,3,4,3,3,4,1,1,3,3,3,3,3,1,3,1,3,3,3,3,3,3,3,3,3,3,3,1
3,1,1,1,1,1,1,1,4,3,3,1,3,3,3,3,3,1,3,3,3,1,3,1,3,4,3,3,4,1,1,3,3,3,3,3,1,3,1,3,3,3,3,3,3,3,3,3,3,3,1
4,1,1,1,1,1,1,1,4,3,3,1,3,3,3,3,3,1,3,3,3,1,3,1,3,4,3,3,4,1,1,3,3,3,3,3,1,3,1,3,3,3,3,3,3,3,3,3,3,3,1


In [95]:
final_clusters = hierarchical_clusters(clusters.bfill(limit=3).ffill(limit=3).dropna(axis=1), n_clusters=n_clusters)['clusters']
correlations = pd.DataFrame(pairwise_distances(sampled_lfp, metric='correlation'))

f, (ax, axa, ax2, axp) = plt.subplots(
    1, 4, figsize=(10, 3.5), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 0.2, 1.5, 0.2]), sharey=True
)
v = np.quantile(sampled_lfp, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.01, 0.99])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

axa.imshow(final_clusters[:, np.newaxis], aspect='auto', origin='lower', cmap=cm.Dark2)
axa.set_xticks([])
axa.set_xlabel('clusters')

ax.set_title('LFP')
ax2.set_title('similarity')

f.suptitle(f'mouse{exp.mouse}, {probe}');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [65]:
c = hierarchical_clusters(correlations, 4)['clusters']
axp.imshow(c[:, np.newaxis], aspect='auto', origin='lower', cmap=cm.Dark2)
axp.set_xticks([])
axp.set_xlabel('clusters')

  after removing the cwd from sys.path.
  return linkage(y, method='ward', metric='euclidean')


Text(0.5, 18.167000000000016, 'clusters')

# Let us look at evoked traces

In [9]:
# add event data to eeg_data frame for easy manipulalation
_stim_log = stim_log.rename_axis('eid').reset_index().set_index('onset', drop=False).reindex(
        eeg_data.index, method='nearest'
    )
_stim_log = _stim_log.join((eeg_data.index - _stim_log.onset).to_frame(name='time'))
_stim_log['time'] = (_stim_log.time*10000/4).astype(int)*4/10000 # align to a multiple of 0.0004
eeg_data.index = pd.MultiIndex.from_frame(_stim_log[['eid', 'time']])
eeg_data

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
eid,time,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1
0,-315.1956,-9.165000,-9.165000,-3.705000,-3.900000,-55.769998,71.174997,27.689999,-26.714999,64.154998,33.149999,4.875000,-8.580000,7.605000,28.079999,-19.694999,-72.929997,-37.439999,-19.499999,-28.859999,-61.814998,-14.624999,58.694998,-15.014999,1.950000,64.349998,15.989999,-12.090000,-0.390000,-0.390000,-2.730000
0,-315.1952,-59.084998,-40.169999,-26.324999,-71.564997,-159.704994,43.289998,-5.850000,-72.734997,19.889999,-15.794999,-34.904999,-52.454998,-32.759999,-40.559999,-92.429997,-6.435000,19.304999,30.614999,7.605000,-73.709997,7.020000,84.824997,4.095000,-20.279999,37.829999,4.680000,4.485000,-17.744999,-5.850000,18.134999
0,-315.1948,40.754999,57.329998,60.449998,-15.404999,-158.729994,153.464994,130.454995,46.409998,125.969995,71.564997,53.234998,16.574999,7.605000,48.944998,-2.535000,49.919998,76.439997,93.014997,73.514997,-20.669999,79.949997,159.314994,95.159997,129.674995,150.149994,122.264996,119.924996,78.584997,91.454997,105.299996
0,-315.1944,58.304998,62.009998,51.284998,-52.844998,-218.789992,140.594995,112.904996,35.489999,117.194996,76.244997,59.084998,30.614999,28.469999,72.734997,32.369999,32.954999,64.739998,80.339997,63.959998,-84.629997,39.389999,124.409995,31.004999,-5.655000,113.684996,39.974999,-24.764999,45.824998,41.924998,-4.875000
0,-315.1940,50.114998,32.564999,21.644999,-119.339996,-339.299988,115.634996,106.469996,20.864999,101.789996,56.549998,31.199999,-15.404999,15.989999,36.659999,11.505000,32.954999,68.054998,80.339997,58.499998,-55.379998,58.109998,123.044995,35.294999,20.279999,101.399996,36.269999,12.870000,29.639999,34.514999,9.750000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
719,68.0284,-11.310000,-18.524999,-17.939999,248.624991,537.029980,-20.474999,-23.204999,-16.769999,-10.920000,-11.310000,-17.549999,-45.434998,-38.219999,-11.700000,-12.480000,-19.499999,-14.039999,-23.009999,-21.644999,69.419997,-16.769999,-15.794999,-19.499999,-13.844999,-23.009999,-9.165000,-11.310000,-15.794999,-14.039999,-5.460000
719,68.0288,-8.775000,-21.059999,-14.819999,284.699990,607.229978,-17.939999,-13.844999,-16.964999,-13.844999,-15.404999,-15.014999,-57.914998,-17.549999,-16.964999,-17.159999,-12.285000,-14.039999,-21.839999,-16.964999,76.244997,-8.580000,-13.455000,-4.680000,-7.605000,-10.725000,-5.655000,-3.510000,-7.410000,-12.285000,-3.705000
719,68.0292,-13.455000,-18.914999,-18.524999,290.744989,600.599978,-16.574999,-11.505000,-15.404999,-16.379999,-9.750000,-17.549999,-38.999999,-27.689999,-11.700000,-16.574999,-7.020000,-11.700000,-16.769999,-10.920000,63.764998,-7.800000,-8.970000,-12.285000,-5.460000,-11.310000,-3.900000,-2.535000,-5.655000,-10.725000,1.560000
719,68.0296,-13.260000,-16.769999,-11.895000,282.359990,581.099979,-14.234999,-12.480000,-15.404999,-16.964999,-6.240000,-5.265000,-38.999999,-8.970000,-12.285000,-12.480000,-7.215000,-11.700000,-17.939999,-15.599999,49.919998,-5.850000,-12.285000,-6.825000,-6.045000,-11.310000,-4.485000,-1.950000,-9.945000,-9.750000,-4.875000


In [10]:
print(set(stim_log.parameter))
print(set(stim_log.sweep))
print(set(stim_log.stim_type))

{40, 90, 70}
{0, 1}
{'biphasic'}


In [11]:
pick_stim = 'biphasic' # stimulus type, we often do 'biphasic' (electrical) and 'fullscreen' (for visual)
pick_param = 70 # parameter is the amplitude if it is 'biphasic' or color ('white') if it is a visual stim
pick_sweep = 0 # we deliver multiple sessions during the recordings, in different states (0 is always awake, the ephys notes excel file will tell you which sweeps belong to which state)

events = stim_log[
    (stim_log.stim_type==pick_stim)&\
    (stim_log.parameter==pick_param)&\
    (stim_log.sweep==pick_sweep)&\
    stim_log.good
].index
events

Int64Index([  0,   2,  12,  14,  18,  22,  24,  27,  29,  35,
            ...
            321, 323, 330, 333, 336, 337, 339, 342, 348, 354],
           dtype='int64', length=119)

In [12]:
# keep only responses between -2 to 2 s around the event onset times
eeg_responses = eeg_data.loc[(events, slice(-2, 2)), :]

In [13]:
# plot mean responses for all channels
mean_responses = eeg_responses.groupby(level='time').mean()

def find_saturation_peaks(col):
    idx_peak = col.loc[0.0005:].abs().idxmax()
    val_peak = col.loc[idx_peak]
    return pd.Series(dict(idx_peak=idx_peak, val_peak=val_peak))
peak_locations = mean_responses.apply(find_saturation_peaks).T

f, (ax, ax2) = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True, sharey=True, sharex=True)
try:
    title = f'mouse{exp.mouse} {gmetadata[gmetadata.mouse_name==f"mouse{exp.mouse}"]["brain states"].values[0].split("/")[pick_sweep]}'
except:
    title = f'mouse{exp.mouse} -unknown state-'
f.suptitle(title)
ax.set_title('original')
ax2.set_title('original - exponential')

mean_responses.plot(ax=ax, lw=0.5, c='k', legend=None)
ax.scatter(peak_locations.idx_peak, peak_locations.val_peak, s=2, c='r')
ax.axvspan(0, 0.0004, facecolor='r')
ax.set_xlim(-0.1, 1.5);
# ax.legend(loc=(1.01, 0), fontsize=(8), ncol=3);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
# fit between first peak and a relaxed time
def extract_relaxation(col):
    baseline = col.loc[:0].mean()
    idx_peak = col.loc[0.001:].abs().idxmax()
    val_peak = col.loc[idx_peak]
    
    val_relaxed = baseline + (val_peak - baseline)/np.e**3
    idx_relaxed = (col.loc[0:] - val_relaxed).abs().idxmin()
    val_relaxed = col.loc[idx_relaxed]
    
    return col.loc[idx_peak:idx_relaxed]

def expfn(x, offset, scale, amplitude):
    return offset + amplitude*np.exp(-scale*x)

def fit_exponential(col):
    try:
        return pd.Series(
            sp.optimize.curve_fit(expfn, col.dropna().index, col.dropna().values, p0=(0, 1, 500))[0],
            index=['offset', 'scale', 'amplitude']
        )
    except:
        return (0, 0, 0)

In [15]:
_relaxation = mean_responses.apply(extract_relaxation)
_relaxation.plot(ax=ax, lw=1, c='r', linestyle='dotted', legend=None)
_params = _relaxation.progress_apply(fit_exponential)

def _plot_fit(col):
    x = mean_responses.loc[0:].index
    ax.plot(x, expfn(x, *col), c='b', lw=0.4, label='_'*int(col.name)+'exponential fit')
    return

_params.dropna(axis=1).apply(_plot_fit);

# ax.scatter(_x.idx_relaxed, _x.val_relaxed, s=10, c='r')

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))






In [16]:
# subtract the relaxation and see if the mean responses look realistic
_corrected_responses = mean_responses.apply(
    lambda col: col - expfn(
        _relaxation[col.name].dropna().index.to_series(),
        *_params[col.name]
    ).reindex(col.index, fill_value=0)
)
_corrected_responses.plot(ax=ax2, lw=0.5, c='k', legend=False)
ax2.axvspan(0, 0.0004, facecolor='r')
ax2.set_xlim(-0.1, 1.5)
ax2.set_ylim(-500, 1000);

In [17]:
fig, ax = plt.subplots(figsize=(4, 3), constrained_layout=True)

sc = ax.scatter(exp.EEG_channel_coordinates['ML'], exp.EEG_channel_coordinates['AP'], s=300, c=_params.T.scale, vmin=1, vmax=15)
ax.scatter(0, 0, marker='P', color='red')
ax.axis('equal')
plt.colorbar(sc, ax=ax, label='relaxation timescale (s)')
    
for ind in range(len(exp.EEG_channel_coordinates)):
    ax.annotate(f'{_params.T.scale[ind]:.0f}',  xy=(exp.EEG_channel_coordinates['ML'].iloc[ind], exp.EEG_channel_coordinates['AP'].iloc[ind]), ha='center', va='center', color="k")

ax.set_xlabel("ML axis (mm)\nmouse's left <--> right")
ax.set_ylabel('AP axis (mm)')
try:
    title = f'Relaxation timescale (s) ({gmetadata[gmetadata.mouse_name==f"mouse{exp.mouse}"]["brain states"].values[0].split("/")[pick_sweep]})'
except:
    title = f'Relaxation timescale (s) (-unknown state-)'
ax.set_title(title);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [34]:
fig, ax = plt.subplots(figsize=(4, 3), constrained_layout=True)

sc = ax.scatter(exp.EEG_channel_coordinates['ML'], exp.EEG_channel_coordinates['AP'], s=300, c=_params.T.amplitude)
ax.scatter(0, 0, marker='P', color='red')
ax.axis('equal')
plt.colorbar(sc, ax=ax, label='artifact amplitude')
    
for ind in range(len(exp.EEG_channel_coordinates)):
    ax.annotate(
        f'{_params.T.amplitude[ind]:.0f}',
        xy=(exp.EEG_channel_coordinates['ML'].iloc[ind],
            exp.EEG_channel_coordinates['AP'].iloc[ind]),
        ha='center', va='center', color="k", fontsize=8
    )

ax.set_xlabel("ML axis (mm)\nmouse's left <--> right")
ax.set_ylabel('AP axis (mm)')
try:
    title = f'Amplitude ({gmetadata[gmetadata.mouse_name==f"mouse{exp.mouse}"]["brain states"].values[0].split("/")[pick_sweep]})'
except:
    title = f'Amplitude (-unknown state-)'
ax.set_title(title);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

---

In [19]:
choose_probe = 'probeB' # leave off the '_sorted' for selecting the EEGexp.ephys_params associated with the probe

In [20]:
unit_meta = pd.read_csv(exp.ephys_params[choose_probe]['cluster_metrics'], index_col=0)

In [21]:
spike_times = np.load(exp.ephys_params[choose_probe]['spike_times'])
spike_clusters = np.load(exp.ephys_params[choose_probe]['spike_clusters'])
spike_times = pd.DataFrame(np.array([spike_clusters, spike_times]).T, columns=['unit', 'times'])
spike_times = spike_times.groupby('unit').apply(lambda df: df.times.values)

In [59]:
def bin_spikes(row, tmin=0, tmax=1000):
    row = row[(row>tmin)&(row<tmax)]
    bins = np.zeros(int((tmax-tmin)*1000), dtype=bool)
    bins[((row-tmin)*1000).astype(int)] = True
    times = np.arange(int(tmin*1000), int(tmax*1000), dtype=int)[:len(bins)]
    times = times/1e3
    return pd.Series(bins, index=times)

In [44]:
stim_log[(stim_log.parameter=='80')&(stim_log.sweep==0)]

Unnamed: 0,stim_type,parameter,onset,offset,duration,sweep,good
0,biphasic,80,33.63939,33.63999,0.0004,0,True
1,biphasic,80,37.74180,37.74240,0.0004,0,True
2,biphasic,80,41.50520,41.50580,0.0004,0,True
3,biphasic,80,45.65971,45.66031,0.0004,0,True
4,biphasic,80,49.84935,49.84996,0.0004,0,True
...,...,...,...,...,...,...,...
95,biphasic,80,410.96337,410.96397,0.0004,0,True
96,biphasic,80,415.10811,415.10871,0.0004,0,True
97,biphasic,80,418.98714,418.98774,0.0004,0,True
98,biphasic,80,423.29914,423.29974,0.0004,0,True


In [90]:
stim_id = 99
spikes = {}
tmin = stim_log.loc[stim_id, 'onset']-2
tmax = stim_log.loc[stim_id, 'offset']+2
for i, row in spike_times.items():
    spikes[i] = bin_spikes(row, tmin, tmax)
spikes = pd.concat(spikes, axis=1).T

In [100]:
f, ax = plt.subplots(1, 1, figsize=(6, 3), tight_layout=True)
ax.imshow(spikes, aspect='auto', cmap=cm.Greys_r, interpolation='none');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [92]:
spikes.index = pd.MultiIndex.from_arrays([
    spikes.index,
    unit_meta.set_index('cluster_id').loc[spikes.index, 'peak_channel']
], names=['unit', 'channel'])
spikes

Unnamed: 0_level_0,Unnamed: 1_level_0,425.332,425.333,425.334,425.335,425.336,425.337,425.338,425.339,425.340,425.341,425.342,425.343,425.344,425.345,425.346,425.347,425.348,425.349,425.350,425.351,425.352,425.353,425.354,425.355,425.356,...,429.307,429.308,429.309,429.310,429.311,429.312,429.313,429.314,429.315,429.316,429.317,429.318,429.319,429.320,429.321,429.322,429.323,429.324,429.325,429.326,429.327,429.328,429.329,429.330,429.331
unit,channel,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1
0.0,0,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False
1.0,0,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
2.0,0,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
3.0,3,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
4.0,1,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,...,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
530.0,185,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
531.0,209,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
532.0,149,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
533.0,153,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False


In [130]:
f, ax = plt.subplots(1, 1, figsize=(6, 3), tight_layout=True)
ax.imshow(_stacked_spikes.sort_index(level='channel'), aspect='auto', cmap=cm.Greys_r, interpolation='none');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [94]:
_spikes = spikes.sort_index(level='channel')

In [127]:
_stacked_spikes = []
for channel in range(384):
    try:
        _df = _spikes.xs(channel, level='channel', drop_level=False)
        _stacked_spikes.append(_df)
    except:
        _stacked_spikes.append(pd.DataFrame(0, index=pd.MultiIndex.from_arrays([[-1], [channel]]), columns=_spikes.columns))

In [128]:
_stacked_spikes = pd.concat(_stacked_spikes).sort_index(level='channel')

In [129]:
f, ax = plt.subplots(1, 1, figsize=(6, 3), tight_layout=True)
sns.heatmap(_stacked_spikes.droplevel(0), ax=ax);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [133]:
# f = plt.figure(figsize=(6, 3), tight_layout=True)
sns.clustermap(_stacked_spikes.droplevel(0), col_cluster=False, );

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

