In [1]:
%load_ext autoreload
%autoreload 2

# import os
# from glob import glob
# import json
# import pickle
import json

import numpy as np
import pandas as pd
import scipy as sp
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.metrics import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from matplotlib import cm#, patches
# import matplotlib.gridspec as gridspec
from tqdm.auto import tqdm
# import pandarallel
from IPython.utils.capture import capture_output
with capture_output():
    tqdm.pandas()
#     pandarallel.pandarallel.initialize(progress_bar=True)

import elephant
from neo.core import AnalogSignal
import quantities as pq

from tbd_eeg.data_analysis.eegutils import EEGexp
from tbd_eeg.data_analysis.Utilities.utilities import get_stim_events, find_nearest_ind

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False
%matplotlib widget

In [2]:
# accessing the Google sheet with experiment metadata in python
# setting up the permissions:
# 1. install gspread (pip install gspread / conda install gspread)
# 2. copy the service_account.json file to '~/.config/gspread/service_account.json'
# 3. run the following:
import gspread
_gc = gspread.service_account() # need a key file to access the account (step 2)
_sh = _gc.open('Zap_Zip-log_exp') # open the spreadsheet
_df = pd.DataFrame(_sh.sheet1.get()) # load the first worksheet
gmetadata = _df.T.set_index(0).T # put it in a nicely formatted dataframe

In [3]:
gmetadata

Unnamed: 0,mouse_name,exp_name,brain states,stimulation,visual_stim,audio_stim,ISI (sec),stimulus duration (msec),Current (uA),Cortical Area stimulation,N trials per stimulus,EEG bad_channels,Npx,Units Sorted (X),Brain slices (X),Pupil tracking pre-processing,Brain areas assignment,Notes
1,mouse496220,audio_vis1_2020-06-10_14-54-43,awake/ISO,sensory,black/white,whitenoise/10000,5,250,,,60,,,,,,,
2,mouse496220,audio_vis2_2020-06-11_11-42-47,awake/ISO,sensory,black/white,whitenoise/10000,5,250,,,60,29.0,,,,,,
3,mouse496220,audio_vis3_2020-06-16_10-35-57,run/resting,sensory,black,whitenoise,5,250,,,20,,,,,,,
4,mouse496220,audio_vis4_2020-06-18_13-49-17,run/resting,sensory,black/white,whitenoise/10000,5,250,,,60,,,,,,,
5,mouse521885,audio_vis1_2020-07-08_12-37-58,awake/ISO,sensory,black/white,whitenoise/10000,[3.5 4.5],250,,,50,6.0,,,,,,
6,mouse521885,estim1_2020-07-09_14-23-49,awake/ISO,electrical,,,[3.5 4.5],0.2,20/50/100,M2,60,613141112.0,,,,,,
7,mouse521886,audio_vis1_2020-07-15_13-28-29,awake/ISO,sensory,black/white,whitenoise/10000,[3.5 4.5],250,,,60,,,,,,,
8,mouse521886,estim1_2020-07-16_13-37-02,awake/ISO/recovery,electrical,,,[3.5 4.5],0.2,20/50/100,M2,100,1.018111213141516e+17,,,,,,
9,mouse521887,audio1_2020-07-29_09-13-05,awake/awake/awake/ISO/ISO/ISO_low/recovery/rec...,sensory,,whitenoise/10000,2.5,250,,,100,187.0,,,,,,
10,mouse521887,estim1_2020-07-30_11-25-05,awake/awake,electrical,,,[3.5 4.5],0.2,20,M2,100,411121314187.0,,,,,,


In [5]:
rec_folder = '../tiny-blue-dot/zap-n-zip/EEG_exp/mouse569062/estim_vis_2021-02-18_11-17-51/experiment1/recording1/'
exp = EEGexp(rec_folder, preprocess=False, make_stim_csv=False)

print('What data is in here?')
print(exp.experiment_data)

Experiment type: electrical and sensory stimulation
What data is in here?
['probeB_sorted', 'probeD_sorted', 'probeF_sorted', 'recording1']


# Demarcating areas

In [6]:
# select probe
probe = 'probeB'

# load lfp
lfp = np.memmap(exp.ephys_params[probe]['lfp_continuous'], dtype='int16', mode='r')
lfp = np.reshape(lfp, (int(lfp.size/exp.ephys_params[probe]['num_chs']), exp.ephys_params[probe]['num_chs']))
samp_rate = exp.ephys_params[probe]['lfp_sample_rate']
timestamps = np.load(exp.ephys_params[probe]['lfp_timestamps'])

---
**unrelated stuff - CSD trial**

In [7]:
# put lfp into a dataframe and add stimulation metadata
lfp = pd.DataFrame(lfp, index=timestamps)

stim_log = pd.read_csv(exp.stimulus_log_file)
stim_log.rename_axis(index='stim_id', inplace=True)

# assign stimulus information at each timestamp, including -1s onwards relative to stimulus to that block
idx = stim_log.reset_index().set_index('onset')
idx.index = idx.index - 1
idx = idx.reindex(timestamps, method='ffill', limit=2500*4).reset_index()

def _reset_index_time(df):
    df['onset'] = (df.onset - df.onset.iloc[0] - 1).round(4)
    return df
idx = idx.groupby('stim_id').apply(_reset_index_time).drop(['offset', 'duration'], axis=1)

lfp.index = pd.MultiIndex.from_frame(idx)
lfp.columns = pd.MultiIndex.from_arrays([
    range(len(lfp.columns)),
    lfp.columns.map(lambda x: -3840+(x+1)//2*20)
], names=['channel', 'depth'])
# lfp = lfp.xs(True, level='good')
_t = idx.sweep.replace([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], True).fillna(False)
_t.index = lfp.index
lfp = lfp[_t]
lfp.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,channel,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,...,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,depth,-3840,-3820,-3820,-3800,-3800,-3780,-3780,-3760,-3760,-3740,-3740,-3720,-3720,-3700,-3700,-3680,-3680,-3660,-3660,-3640,-3640,-3620,-3620,-3600,-3600,...,-240,-240,-220,-220,-200,-200,-180,-180,-160,-160,-140,-140,-120,-120,-100,-100,-80,-80,-60,-60,-40,-40,-20,-20,0
onset,stim_id,stim_type,parameter,sweep,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2,Unnamed: 33_level_2,Unnamed: 34_level_2,Unnamed: 35_level_2,Unnamed: 36_level_2,Unnamed: 37_level_2,Unnamed: 38_level_2,Unnamed: 39_level_2,Unnamed: 40_level_2,Unnamed: 41_level_2,Unnamed: 42_level_2,Unnamed: 43_level_2,Unnamed: 44_level_2,Unnamed: 45_level_2,Unnamed: 46_level_2,Unnamed: 47_level_2,Unnamed: 48_level_2,Unnamed: 49_level_2,Unnamed: 50_level_2,Unnamed: 51_level_2,Unnamed: 52_level_2,Unnamed: 53_level_2,Unnamed: 54_level_2,Unnamed: 55_level_2
-1.0,0.0,biphasic,30,0.0,-457,12,-421,-204,-433,-180,-313,204,-541,-180,-469,-264,-168,60,-517,-60,-589,-589,-625,-385,-721,-216,-529,-685,-349,...,84,240,156,240,192,325,240,493,637,409,-36,276,288,180,276,385,553,216,373,445,-216,264,96,589,228
-0.9996,0.0,biphasic,30,0.0,-589,-192,-625,-409,-709,-481,-541,-156,-877,-589,-709,-673,-481,-445,-733,-397,-757,-938,-757,-613,-793,-481,-601,-805,-505,...,-36,60,-24,60,0,96,24,264,457,192,-204,132,120,36,84,240,433,84,132,313,-373,132,-84,529,96
-0.9992,0.0,biphasic,30,0.0,-1130,-1166,-745,-817,-709,-553,-421,-36,-589,-445,-505,-409,-288,-120,-673,-204,-673,-745,-637,-409,-613,-228,-469,-625,-565,...,120,12,-96,108,12,180,144,373,529,325,-96,204,228,48,84,216,409,72,180,409,-300,252,60,661,228
-0.9988,0.0,biphasic,30,0.0,-505,-216,-409,-216,-385,-120,-180,264,-349,-156,-276,-204,-36,84,-385,48,-421,-457,-373,-144,-300,72,-228,-349,-421,...,240,168,60,240,132,337,216,493,697,433,24,313,325,204,264,349,553,252,300,541,-180,397,180,817,421
-0.9984,0.0,biphasic,30,0.0,-96,325,-96,108,-132,168,24,457,-156,60,0,60,240,409,-276,288,-373,-132,-505,48,-264,313,-72,-120,-96,...,397,300,228,433,276,493,373,637,805,589,156,469,469,325,397,517,649,385,457,649,-36,517,337,865,493


In [7]:
print('Available parameters:')
aparams = list(stim_log.apply(lambda row: (row.stim_type, row.parameter), axis=1).unique())
print(aparams)

Available parameters:
[('biphasic', '30'), ('biphasic', '50'), ('biphasic', '70'), ('fullscreen', 'white')]


In [9]:
def smoothen_df(df, win=5):
    win = sp.signal.hann(win)
    df2 = df.copy()
    for t in df.columns:
        df2[t] = sp.signal.convolve(df[t], win, mode='same') / sum(win)
    return df2

def compute_csd(lfp_df, spacing, method='sl', sf=2500, smoothen=False, **kwargs):
    if method=='sl':
        # Need to pad lfp channels for Laplacian approx.
        padded_lfp = np.pad(
            lfp_df, pad_width=((1, 1), (0, 0)), mode='edge'
        )
        csd = (1 / (spacing ** 2)) * (
            padded_lfp[2:, :] - (2 * padded_lfp[1:-1, :]) + padded_lfp[:-2, :]
        )
    else:
        params = dict(diam=500E-6*pq.m, sigma=0.3*pq.S/pq.m, sigma_top=0.3*pq.S/pq.m, f_type='gaussian', f_order=(3, 1))
        if method=='StepiCSD':
            params.update(dict(h=20E-6*pq.m, tol=1E-12))
        if method=='SplineiCSD':
            params.update(dict(num_steps=len(lfp_df), tol=1E-12, f_order=(20, 5)))
        params.update(kwargs)
        coord = lfp_df.columns.remove_unused_levels().to_frame()[['depth']].values*pq.um
        # select only the vertical dimension
        coord = coord[:, np.newaxis]
        neo_lfp = AnalogSignal(lfp_df, units='V', sampling_rate=sf*pq.Hz)
        csd = elephant.current_source_density.estimate_csd(neo_lfp, coords=coord, method=method, **params)
    csd = pd.DataFrame(csd, index=lfp_df.index, columns=lfp_df.columns)
    if smoothen:
        csd = smoothen_df(csd.T).T
    return csd

In [10]:
# plot LFP for an example trial
_lfp = lfp.loc[(slice(None), 3, slice(None), slice(None), 0), :]
_lfp = smoothen_df(_lfp.T, win=10).T
_lfp

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,channel,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,...,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,depth,-3840,-3820,-3820,-3800,-3800,-3780,-3780,-3760,-3760,-3740,-3740,-3720,-3720,-3700,-3700,-3680,-3680,-3660,-3660,-3640,-3640,-3620,-3620,-3600,-3600,...,-240,-240,-220,-220,-200,-200,-180,-180,-160,-160,-140,-140,-120,-120,-100,-100,-80,-80,-60,-60,-40,-40,-20,-20,0
onset,stim_id,stim_type,parameter,sweep,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2,Unnamed: 33_level_2,Unnamed: 34_level_2,Unnamed: 35_level_2,Unnamed: 36_level_2,Unnamed: 37_level_2,Unnamed: 38_level_2,Unnamed: 39_level_2,Unnamed: 40_level_2,Unnamed: 41_level_2,Unnamed: 42_level_2,Unnamed: 43_level_2,Unnamed: 44_level_2,Unnamed: 45_level_2,Unnamed: 46_level_2,Unnamed: 47_level_2,Unnamed: 48_level_2,Unnamed: 49_level_2,Unnamed: 50_level_2,Unnamed: 51_level_2,Unnamed: 52_level_2,Unnamed: 53_level_2,Unnamed: 54_level_2,Unnamed: 55_level_2
-1.0000,3.0,biphasic,70,0.0,-281.451304,-397.975257,-476.713967,-492.392520,-447.058964,-376.300642,-321.050674,-302.793963,-326.953620,-373.178791,-402.345528,-406.421325,-398.784166,-398.604602,-432.507609,-498.043046,-569.537206,-615.720761,-621.534669,-599.881570,-569.341584,-547.038172,-537.480474,-512.022935,-454.138822,...,-614.038044,-651.949918,-673.508337,-666.754232,-637.343451,-594.956400,-545.675063,-507.567674,-497.529282,-515.889790,-554.058115,-596.982338,-622.082213,-610.750533,-576.737912,-550.022562,-534.994433,-538.087467,-567.238616,-606.000267,-625.291492,-607.949258,-554.263675,-462.163989,-341.938900
-0.9996,3.0,biphasic,70,0.0,-339.533506,-484.499331,-583.636175,-607.115661,-559.886450,-484.661699,-427.895765,-412.094832,-438.413516,-483.307146,-511.394224,-514.211606,-503.880453,-502.729505,-541.919485,-621.231524,-715.720327,-794.586984,-835.902067,-839.388168,-821.220720,-797.387051,-766.535576,-705.279227,-609.474763,...,-655.006953,-678.573799,-689.550499,-678.745704,-649.655392,-607.754790,-556.090802,-511.974884,-496.049365,-513.270447,-555.580194,-605.364990,-636.522509,-629.794691,-600.833181,-580.684069,-573.784295,-584.948040,-621.143347,-666.369852,-691.017176,-677.599692,-624.259136,-525.955846,-392.761848
-0.9992,3.0,biphasic,70,0.0,-438.294207,-619.358893,-739.635164,-763.458501,-699.048021,-599.339772,-517.851049,-477.400240,-482.539457,-512.433506,-529.049197,-522.116337,-500.493821,-481.161763,-489.412759,-526.065189,-574.592855,-612.913003,-624.654915,-617.763867,-611.087288,-619.940908,-646.758770,-662.459641,-645.659573,...,-599.971184,-669.629938,-721.161325,-735.756228,-715.033925,-668.378672,-606.365118,-550.835457,-522.469642,-527.989597,-561.693068,-606.874002,-636.464865,-627.261540,-590.459058,-556.405214,-531.675959,-524.748973,-547.001925,-582.719418,-599.877749,-578.568601,-519.690485,-424.538112,-305.764391
-0.9988,3.0,biphasic,70,0.0,-328.922691,-466.499331,-559.182812,-576.421069,-521.403991,-435.809654,-367.201173,-337.699113,-351.749596,-388.740909,-408.878679,-400.556633,-373.283683,-346.683842,-347.890680,-379.146972,-425.489615,-462.712851,-470.566051,-454.732840,-435.125376,-437.477065,-473.191150,-519.941680,-554.917741,...,-515.002276,-575.762832,-619.269661,-626.725137,-600.593629,-551.542657,-491.267853,-442.357536,-423.642156,-437.951973,-476.719621,-523.053049,-551.529043,-541.310795,-505.601937,-473.614666,-449.996142,-442.340390,-461.716201,-495.626069,-515.075684,-499.708492,-449.540337,-365.559727,-260.183230
-0.9984,3.0,biphasic,70,0.0,-280.756711,-401.922150,-490.318248,-519.270698,-490.231478,-436.154682,-396.094832,-390.979963,-426.010845,-480.428968,-514.708160,-519.047621,-504.206438,-490.542419,-506.934838,-556.319485,-620.891728,-676.364371,-706.139370,-714.403230,-717.591342,-725.467816,-727.484465,-697.404482,-630.274239,...,-668.628292,-666.478530,-650.847378,-619.066720,-582.263792,-548.586815,-513.491478,-485.600761,-481.708329,-505.195198,-550.062653,-602.491350,-638.261260,-637.203896,-612.005695,-591.639910,-581.892631,-588.598474,-622.059150,-670.361323,-703.729235,-702.806378,-660.823937,-566.491837,-428.335883
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2.6548,3.0,biphasic,70,0.0,-71.072707,-93.130351,-98.438173,-80.769295,-35.887556,19.261608,55.727889,49.559617,-12.603919,-105.730939,-177.797880,-204.442042,-188.507417,-158.713953,-150.552250,-175.015897,-219.098377,-251.229087,-249.199074,-219.599131,-184.701802,-156.130827,-127.862568,-80.630230,-10.617271,...,70.927980,57.715089,55.917368,75.734213,109.140864,151.137960,203.000616,242.333364,243.557337,205.164002,141.374739,75.174556,31.924321,28.092016,46.390443,56.395719,59.041196,49.007802,13.495597,-36.181857,-69.579294,-61.685466,-23.368618,24.011021,60.779518
2.6552,3.0,biphasic,70,0.0,-195.368415,-268.744673,-312.346080,-312.971685,-270.016821,-207.635774,-155.953153,-128.857786,-134.348784,-156.703348,-159.492462,-133.274614,-86.886078,-44.866028,-35.948388,-63.058059,-109.454895,-144.045082,-142.844970,-110.497329,-69.312617,-45.086378,-50.064167,-68.477374,-84.509622,...,111.871784,17.370062,-58.692352,-86.028001,-64.037916,-2.699561,81.477850,151.290356,177.191317,156.107667,103.519613,44.969513,8.789432,14.092016,44.332799,68.802929,88.218248,97.211407,82.814071,56.130084,44.644330,69.354154,114.642325,153.391130,165.909869
2.6556,3.0,biphasic,70,0.0,-194.512380,-278.736144,-337.244277,-348.275108,-303.209844,-220.925444,-136.466349,-71.857232,-38.898918,-26.446015,-0.942856,45.693545,101.649596,139.495077,129.329982,74.350969,-1.900257,-65.928485,-89.645839,-78.825502,-59.144513,-51.691569,-59.044158,-61.229063,-45.682135,...,258.827561,173.646502,105.631618,80.031081,97.567803,149.724216,221.791915,280.699497,299.819273,278.315346,233.424280,189.270511,169.321545,186.655013,222.840354,247.792544,262.001863,261.653076,236.303399,199.373931,179.712183,198.709789,237.781401,263.957767,255.008570
2.6560,3.0,biphasic,70,0.0,-216.475442,-310.544456,-380.911167,-404.632312,-371.544328,-301.210673,-228.412891,-172.500669,-144.922751,-134.823665,-105.362782,-45.941808,29.386910,88.833571,100.865128,67.988337,12.137367,-34.003734,-43.204106,-24.884183,-7.294325,-13.352230,-43.056799,-74.710417,-95.061706,...,346.503991,262.398739,191.702330,162.139416,175.975012,224.863644,294.378172,350.770208,367.341352,344.162294,299.279405,257.775114,242.241135,263.715275,302.765328,327.653897,338.589093,333.599619,304.102179,264.522834,244.611067,265.849503,306.108759,327.808641,308.005963


In [11]:
f, ax = plt.subplots(figsize=(6, 3), tight_layout=True)
im = ax.imshow(_lfp.T, aspect='auto', extent=[
    _lfp.index.remove_unused_levels().levels[0][0],
    _lfp.index.remove_unused_levels().levels[0][-1],
    _lfp.columns.levels[-1][0], _lfp.columns.levels[-1][-1]
], origin='lower', cmap='jet')
plt.colorbar(im, ax=ax, label='LFP')
ax.set_xlabel('time (s)')
ax.set_ylabel('depth along electrode (um)');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
csd = compute_csd(_lfp.loc[:, ::4], spacing=20, method='DeltaiCSD', smoothen=True)

discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]


In [13]:
f, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), tight_layout=True, sharex=True, sharey=True)

im = ax1.imshow(_lfp.T[::4], aspect='auto', extent=[
    _lfp.index.remove_unused_levels().levels[0][0],
    _lfp.index.remove_unused_levels().levels[0][-1],
    _lfp.columns.levels[-1][0], _lfp.columns.levels[-1][-1]
], origin='lower', cmap='jet')
plt.colorbar(im, ax=ax1, label='LFP')
ax1.set_ylabel('depth along electrode (um)');

vx = np.quantile(csd.abs().values, 0.999)
im = ax2.imshow((csd-csd.mean()).T, aspect='auto', extent=[
    csd.index.remove_unused_levels().levels[0][0],
    csd.index.remove_unused_levels().levels[0][-1],
    csd.columns.levels[-1][0], csd.columns.levels[-1][-1]
], origin='lower', cmap='jet', vmax=vx, vmin=-vx)
plt.colorbar(im, ax=ax2, label='CSD')
ax2.set_xlabel('time (s)')
ax2.set_ylabel('depth along electrode (um)');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
lfp.index.to_frame()[['stim_type', 'parameter']].set_index(['stim_type', 'parameter']).index.unique()

MultiIndex([(  'biphasic',    '30'),
            (  'biphasic',    '50'),
            (  'biphasic',    '70'),
            ('fullscreen', 'white')],
           names=['stim_type', 'parameter'])

In [14]:
try:
#     x
    estim_mean_lfps = pd.read_pickle(f'mouse_{exp.mouse}_{probe}_estim_mean_lfp.pkl')
except:
    estim_mean_lfps = lfp.iloc[:, ::4].groupby(['parameter', 'sweep', 'onset']).mean()
    estim_mean_lfps.to_pickle(f'mouse_{exp.mouse}_{probe}_estim_mean_lfp.pkl')

In [15]:
def compute_and_plot_csd(df):
    csd = None
    try:
        ax = axes[df.name]
        try:
            global estim_csds
            csd = estim_csds.loc[df.name]
            csd.index = pd.MultiIndex.from_arrays(
                [[df.name[0]]*len(csd), [df.name[1]]*len(csd), csd.index],
                names=['parameter', 'sweep', 'onset']
            )
        except Exception as e:
            csd = compute_csd(df, spacing=20, method='DeltaiCSD', smoothen=True)
        vx = np.quantile(csd.abs().values, 0.99)
        im = ax.imshow((csd-csd.mean()).T.fillna(0), aspect='auto', extent=[
            csd.index.remove_unused_levels().levels[-1][0],
            csd.index.remove_unused_levels().levels[-1][-1],
            csd.columns.levels[-1][0], csd.columns.levels[-1][-1]
        ], origin='lower', cmap='jet', vmax=vx, vmin=-vx)
        plt.colorbar(im, ax=ax, label='CSD')
        ax.set_ylabel('depth along electrode (um)')
        ax.set_title(df.name)
        return csd
    except Exception as e:
        print(e, 'for', df.name)
        return csd

In [16]:
ncol = len(lfp.index.levels[4])
f, axes1 = plt.subplots(
    (len(estim_mean_lfps.index.droplevel('onset').unique())+ncol-1)//ncol, ncol,
    figsize=(5*3, len(estim_mean_lfps.index.droplevel('onset').unique())),
    tight_layout=True, sharex=True, sharey=True
)

axes = {
    p : axes1.flatten()[i] for i, p in enumerate(estim_mean_lfps.index.droplevel('onset').unique())
}
for ax in axes1.flatten()[len(axes):]:
    ax.set_visible(False)

try:
#     x
    estim_csds = pd.read_pickle(f'mouse_{exp.mouse}_{probe}_estim_mean_csd.pkl')
    estim_csds = estim_mean_lfps.groupby(['parameter', 'sweep']).apply(compute_and_plot_csd)
except:
    estim_csds = estim_mean_lfps.groupby(['parameter', 'sweep']).apply(compute_and_plot_csd)
    estim_csds.to_pickle(f'mouse_{exp.mouse}_{probe}_estim_mean_csd.pkl')

f.suptitle(f'mouse {exp.mouse}, {probe}');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b = [ 0.607 1.000 0.607 ],                
a = [ 2.213 ]
discrete filter coefficients: 
b

---

In [6]:
# # remove lfp offsets
# with open(exp.ephys_params[probe]['probe_info'], 'r') as f:
#     probe_info = json.load(f)
# offsets = np.array(probe_info['offset'])
# lfp = lfp - offsets

## Final code
The idea is to compute the pairwise distance or correlation between all channels, and then cluster on this matrix.  
We will take a small window (length is a parameter, but ~2-4s) and compute the pairwise distances for that window; then cluster that matrix. In addition to the similarity/distance, we also add the channel number as a feature for clustering, to encourage clustering together nearby channels.  
This is repeated for 50 distinct windows, and the cluster definitions for all 50 windows are again clustered to get the final clusters.

In [7]:
# generic function that performs hierarchical clustering
def hierarchical_clusters(features, n_clusters=4, link='ward', pl=False):
    '''
    features is the matrix of features to cluster
    n_clusters: number of clusters
    link: do not change this parameter
    pl: plot the dendrogram tree (could be useful for debugging)
    '''
    cluster_data = {}
    cluster_data['z'] = linkage(
        features, link
    )
    if pl:
        f, ax = plt.subplots(1, 1, figsize=(4, 3), tight_layout=True)
        cluster_data['dendrogram'] = dendrogram(
            cluster_data['z'], p=n_clusters,
            truncate_mode='lastp', count_sort='ascending', show_contracted=True, ax=ax
        )
    clustering = AgglomerativeClustering(
        linkage=link, n_clusters=n_clusters
    )
    clustering.fit(features)
    cluster_data['clustering'] = clustering
    clusters = clustering.labels_
    cluster_data['clusters'] = clusters
    return cluster_data

In [8]:
def get_cluster_ids(lfp, t_start_s, win_length_s, n_clusters=4):
    '''
    Returns a pandas series object containing the cluster IDs for a given window (start time and window length are passed as parameters)
    '''
    # extract the slice of lfp within the window
    t_lim_s = np.array([t_start_s, t_start_s+win_length_s])
    sampled_lfp = pd.DataFrame(lfp[slice(*t_lim_s*samp_rate), EEGexp.NPX_lfp_channel_order].T)
    
    # first we will drop some channels that look abnormal (see the testing section below)
    distances = pd.DataFrame(pairwise_distances(sampled_lfp, metric='euclidean'))
    
    # compute the mean distances between any pair and the SD of this value
    mean_distances = distances.mean().rename('mean_distances')
    _mn = mean_distances.rolling(10, center=True).median()
    _sd = mean_distances.rolling(10, center=True).std()
    
    # only keep channels whose mean distance from others is within 1 SD of all mean distances
    idx_normal = (((mean_distances-_mn).abs()-_sd.mean())<0)
    sampled_lfp_normal = sampled_lfp[idx_normal]
    
    # compute the correlation distance
    correlations = pd.DataFrame(
        pairwise_distances(sampled_lfp_normal, metric='correlation'),
        index=sampled_lfp_normal.index, columns=sampled_lfp_normal.index
    )
    
    # add the channel index as a feature to the matrix (we weight this feature such that this dimension has approximately
    # the same importance as all other dimensions; that way, keeping clusters contiguous is given priority)
    _c = correlations.reset_index()
    _c['index'] = _c.index/_c.index.values.mean()*correlations.values.mean()*np.sqrt(correlations.shape[0])
    
    # run clustering and return clusters
    cdata = hierarchical_clusters(_c, n_clusters=n_clusters)
    return pd.Series(cdata['clusters'], index=sampled_lfp_normal.index)

In [9]:
# set parameters for the algorithm
n_clusters = 5 # depending on how many areas you think there are, set this number. might need to play with this
winsize = 4 # I tried 2/4 and both seem to work equally well

# repeat clustering for 50 distinct windows
cids = []
for t_start in tqdm(np.arange(50)*winsize):
    cids.append(get_cluster_ids(lfp, t_start, winsize, n_clusters=n_clusters))

clusters = pd.concat(cids, axis=1)
clusters.head()

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49
5,3.0,3.0,1.0,1.0,1.0,1.0,1.0,3.0,2.0,3.0,3.0,0.0,2.0,3.0,3.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,1.0,1.0,2.0,3.0,2.0,3.0,0.0,1.0,3.0,1.0,3.0,3.0,3.0,2.0,1.0,1.0,1.0,2.0,3.0,2.0,3.0,3.0,1.0,2.0,2.0,1.0,2.0,2.0
6,3.0,3.0,1.0,1.0,1.0,1.0,1.0,3.0,2.0,3.0,3.0,0.0,2.0,3.0,3.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,1.0,1.0,2.0,3.0,2.0,3.0,0.0,1.0,3.0,1.0,3.0,3.0,3.0,2.0,1.0,1.0,1.0,2.0,3.0,2.0,3.0,3.0,1.0,2.0,2.0,1.0,2.0,2.0
7,3.0,3.0,1.0,1.0,1.0,1.0,1.0,3.0,2.0,3.0,3.0,0.0,2.0,3.0,3.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,1.0,1.0,2.0,3.0,2.0,3.0,0.0,1.0,3.0,1.0,3.0,3.0,3.0,2.0,1.0,1.0,1.0,2.0,3.0,2.0,3.0,3.0,1.0,2.0,2.0,1.0,2.0,2.0
8,3.0,3.0,1.0,1.0,1.0,1.0,1.0,3.0,2.0,3.0,3.0,0.0,2.0,3.0,3.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,1.0,1.0,2.0,3.0,2.0,3.0,0.0,1.0,3.0,1.0,3.0,3.0,3.0,2.0,1.0,1.0,1.0,2.0,3.0,2.0,3.0,3.0,1.0,2.0,2.0,1.0,2.0,2.0
9,3.0,3.0,1.0,1.0,1.0,1.0,1.0,3.0,2.0,3.0,3.0,0.0,2.0,3.0,3.0,2.0,3.0,1.0,2.0,3.0,1.0,2.0,1.0,1.0,2.0,3.0,2.0,3.0,0.0,1.0,3.0,1.0,3.0,3.0,3.0,2.0,1.0,1.0,1.0,2.0,3.0,2.0,3.0,3.0,1.0,2.0,2.0,1.0,2.0,2.0


In [12]:
# cluster the clusters
final_clusters = hierarchical_clusters(clusters.bfill(limit=3).ffill(limit=3).dropna(axis=1), n_clusters=n_clusters)['clusters']

# plot results

# sample window to plot
t_lim_s = np.array([2, 2+winsize])
sampled_lfp = pd.DataFrame(lfp[slice(*t_lim_s*samp_rate), EEGexp.NPX_lfp_channel_order].T)
correlations = pd.DataFrame(pairwise_distances(sampled_lfp, metric='correlation'))

f, (ax, axa, ax2) = plt.subplots(
    1, 3, figsize=(9, 3.5), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 0.2, 1.5]), sharey=True
)
v = np.quantile(sampled_lfp, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.01, 0.99])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

axa.imshow(final_clusters[:, np.newaxis], aspect='auto', origin='lower', cmap=cm.Dark2)
axa.set_xticks([])
axa.set_xlabel('clusters')

ax.set_title('LFP')
ax2.set_title('similarity')

f.suptitle(f'mouse{exp.mouse}, {probe}');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Test on a sample window

In [13]:
t_lim_s = np.array([2, 6])
sampled_lfp = pd.DataFrame(lfp[slice(*t_lim_s*samp_rate), EEGexp.NPX_lfp_channel_order].T)

In [22]:
correlations = pd.DataFrame(pairwise_distances(sampled_lfp, metric='euclidean'))

f, (ax, ax2) = plt.subplots(
    1, 2, figsize=(8, 3), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 1.5]), sharey=True
)
v = np.quantile(sampled_lfp, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.05, 0.95])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

ax.set_title('LFP')
ax2.set_title('distance');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Remove channels that look very different (red lines in the top right plot)

In [15]:
# remove very different channels

mean_distances = correlations.mean().rename('mean_distances')
_mn = mean_distances.rolling(10, center=True).median()
_sd = mean_distances.rolling(10, center=True).std()

f, ax = plt.subplots(figsize=(7, 2.5), tight_layout=True)
ax.plot(mean_distances)
ax.fill_between(_mn.index, _mn-_sd.mean(), _mn+_sd.mean(), facecolor=cm.Greys(0.5, 0.5))
ax.set_xlabel('channel')
ax.set_ylabel('mean distance')

# which channels deviate by more than 2 local SD?
idx_normal = (((mean_distances-_mn).abs()-_sd.mean())<0)
sampled_lfp_normal = sampled_lfp[idx_normal]

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:
# compute correlations after removing funky channels
correlations = pd.DataFrame(
    pairwise_distances(sampled_lfp_normal, metric='cosine'),
    index=sampled_lfp_normal.index, columns=sampled_lfp_normal.index
)

f, (ax, ax2) = plt.subplots(
    1, 2, figsize=(8, 3), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 1.5]), sharey=True
)
v = np.quantile(sampled_lfp_normal, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp_normal, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp_normal.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.01, 0.99])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

ax.set_title('LFP')
ax2.set_title('correlations');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# example of a correlation between two channels (for debugging)
f, ax = plt.subplots(figsize=(3, 3), tight_layout=True)
ax.scatter(sampled_lfp_normal.iloc[275], sampled_lfp_normal.iloc[277]);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# add channel index as feature and run clustering
_c = correlations.reset_index()
_c['index'] = _c.index/_c.index.values.mean()*correlations.values.mean()*10
cdata = hierarchical_clusters(_c, pl=True)
clusters = cdata['clusters']

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# plot the clustering results

f, (ax, axa, ax2) = plt.subplots(
    1, 3, figsize=(9, 3), constrained_layout=True,
    gridspec_kw=dict(width_ratios=[3, 0.2, 1.5]), sharey=True
)
v = np.quantile(sampled_lfp_normal, q=[0.01, 0.99])
ax.imshow(
    sampled_lfp_normal, aspect='auto', cmap=cm.bwr, vmin=v[0], vmax=v[1],
    extent=[*t_lim_s, 0, sampled_lfp_normal.shape[0]], origin='lower'
)
ax.set_xlabel('time (s)')
ax.set_ylabel('channel')

vn, vx = np.quantile(correlations, [0.01, 0.99])
ax2.imshow(
    correlations, aspect='auto', cmap=cm.bwr, vmin=vn, vmax=vx, origin='lower'
)
ax2.set_xlabel('channel')

axa.imshow(clusters[:, np.newaxis], aspect='auto', origin='lower', cmap=cm.Dark2)
axa.set_xticks([])
axa.set_xlabel('clusters')

ax.set_title('LFP')
ax2.set_title('correlations');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …