In [1]:
# ------------------------------------------------------------------------------ #
# @Author:        F. Paul Spitzner
# @Email:         paul.spitzner@ds.mpg.de
# @Created:       2023-08-04 11:59:06
# @Last Modified: 2023-08-04 11:59:15
# ------------------------------------------------------------------------------ #
# Run on the cluster, using dask.
# Analyses all units and saves a large dataframe with everything
# that is needed.
# ------------------------------------------------------------------------------ #

%reload_ext autoreload
%autoreload 2
%reload_ext ipy_dict_hierarchy
%reload_ext watermark

import logging
# Root logger: only WARNING and above, each record prefixed with
# timestamp, level, logger name and the name of the emitting function.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-s | %(name)-s | %(funcName)-s | %(message)s",
    level=logging.WARNING,
)
# The notebook's own logger is more verbose than the root logger.
log = logging.getLogger("notebook")
log.setLevel("DEBUG")

import glob
import os
import re
import sys

import dask
import h5py
import numpy as np
import pandas as pd
import xarray as xr
from tqdm import tqdm

sys.path.append('../')

from ana import utility as utl
utl.log.setLevel("ERROR")

# Where result files are written. Resolved relative to the home directory of
# whoever runs the notebook, instead of hard-coding one user's absolute path.
output_dir = os.path.expanduser("~/Desktop/")
# specify the path as closely as possible, we search recursively through all subdirs
data_dir = "../../../gnode/experiment_analysis/dat/"

# Build the per-unit frame step by step; every helper receives the frame
# produced by the previous step and returns the updated one.
meta_df = utl.all_unit_metadata(data_dir, reload=False)
meta_df = utl.load_spikes(meta_df)
meta_df = utl.default_filter(meta_df, trim=True)
meta_df = utl.merge_blocks(meta_df)
meta_df.tail()


Fetching metadata from sessions: 100%|██████████| 58/58 [00:00<00:00, 62.98it/s]
Loading spikes for sessions: 100%|██████████| 58/58 [00:35<00:00,  1.63it/s]
Merging blocks for units: 100%|██████████| 11999/11999 [00:15<00:00, 779.86it/s]


Unnamed: 0,unit_id,stimulus,session,block,ecephys_structure_acronym,invalid_spiketimes_check,recording_length,firing_rate,filepath,num_spikes,spiketimes
23288,951190716,natural_movie_one_more_repeats,847657808,merged_3.0_and_8.0,LP,SUCCESS,1067.306152,3.486347,/Users/paul/para/2_Projects/information_timesc...,3721,"[<xarray.DataArray ()>\narray(3.1584473, dtype..."
23289,951190722,natural_movie_one_more_repeats,847657808,merged_3.0_and_8.0,LP,SUCCESS,1077.341553,2.770709,/Users/paul/para/2_Projects/information_timesc...,2985,"[<xarray.DataArray ()>\narray(0.40185547, dtyp..."
23290,951190724,natural_movie_one_more_repeats,847657808,merged_3.0_and_8.0,LP,SUCCESS,1076.488525,2.043682,/Users/paul/para/2_Projects/information_timesc...,2200,"[<xarray.DataArray ()>\narray(1.0895996, dtype..."
23291,951190819,natural_movie_one_more_repeats,847657808,merged_3.0_and_8.0,LP,SUCCESS,1075.182861,2.104758,/Users/paul/para/2_Projects/information_timesc...,2263,"[<xarray.DataArray ()>\narray(1.2353516, dtype..."
23292,951190848,natural_movie_one_more_repeats,847657808,merged_3.0_and_8.0,VISrl,SUCCESS,1071.963867,0.80693,/Users/paul/para/2_Projects/information_timesc...,865,"[<xarray.DataArray ()>\narray(1.5703125, dtype..."


In [2]:
# to reassemble after analysis, we need a non-ambiguous index
index_columns = ["unit_id", "stimulus", "session", "block"]
# Guard against re-running this cell: once these columns have become the index
# they no longer exist as columns, and a second set_index would raise a KeyError.
if list(meta_df.index.names) != index_columns:
    meta_df = meta_df.set_index(index_columns)
assert meta_df.index.is_unique, "unit_id/stimulus/session/block must identify a row uniquely"
meta_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,ecephys_structure_acronym,invalid_spiketimes_check,recording_length,firing_rate,filepath,num_spikes,spiketimes
unit_id,stimulus,session,block,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
951013153,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,900.668806,21.54288,/Users/paul/para/2_Projects/information_timesc...,19403,[[[<xarray.DataArray (spiketimes: 30719)>\narr...
951013143,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,900.718039,11.652925,/Users/paul/para/2_Projects/information_timesc...,10496,[[[<xarray.DataArray (spiketimes: 30719)>\narr...
951013133,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,894.867666,0.448111,/Users/paul/para/2_Projects/information_timesc...,401,[[[<xarray.DataArray (spiketimes: 30719)>\narr...
951013202,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,900.265072,2.384853,/Users/paul/para/2_Projects/information_timesc...,2147,[[[<xarray.DataArray (spiketimes: 30719)>\narr...
951013187,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,896.074534,0.706414,/Users/paul/para/2_Projects/information_timesc...,633,[[[<xarray.DataArray (spiketimes: 30719)>\narr...


In [3]:
# Settings handed to `hde.api.wrapper` in `full_analysis`.
# NOTE(review): the meaning of the individual keys is defined by hdestimator,
# confirm against its documentation.
hde_settings = {
    "number_of_bootstraps_R_tot": 0,
    "number_of_bootstraps_R_max": 250,
    "timescale_minimum_past_range": 30 / 1000,  # 0.03, same value as mre_settings["tmin"]
    "embedding_number_of_bins_set": [5],
    "estimation_method": "shuffling",
    "persistent_analysis": False,
    # fmt: off
    # increasing values from 0.005 to 5.0, more densely sampled between ~0.06 and ~0.22
    # (presumably past ranges in seconds - TODO confirm)
    "embedding_past_range_set": [
        0.005, 0.00561, 0.00629, 0.00706, 0.00792, 0.00889, 0.00998, 0.01119, 0.01256,
        0.01409, 0.01581, 0.01774, 0.01991, 0.02233, 0.02506, 0.02812, 0.03155, 0.0354,
        0.03972, 0.04456, 0.05, 0.0561, 0.06295, 0.06441, 0.06591, 0.06745, 0.06902,
        0.07063, 0.07227, 0.07396, 0.07568, 0.07744, 0.07924, 0.08109, 0.08298, 0.08491,
        0.08689, 0.08891, 0.09099, 0.0931, 0.09527, 0.09749, 0.09976, 0.10209, 0.10446,
        0.1069, 0.10939, 0.11194, 0.11454, 0.11721, 0.11994, 0.12274, 0.12559, 0.12852,
        0.13151, 0.13458, 0.13771, 0.14092, 0.1442, 0.14756, 0.151, 0.15451, 0.15811,
        0.1618, 0.16557, 0.16942, 0.17337, 0.17741, 0.18154, 0.18577, 0.19009, 0.19452,
        0.19905, 0.20369, 0.20843, 0.21329, 0.21826, 0.22334, 0.25059, 0.28117, 0.31548,
        0.35397, 0.39716, 0.44563, 0.5, 0.56101, 0.62946, 0.70627, 0.79245, 0.88914,
        0.99763, 1.11936, 1.25594, 1.40919, 1.58114, 1.77407, 1.99054, 2.23342, 2.50594,
        2.81171, 3.15479, 3.53973, 3.97164, 4.45625, 5.0,
    ],
    # fmt: on
}

# Settings handed to `mre_wrapper` in `full_analysis`.
# `bin_size` is used both for binning the spikes and as `dt` (with dtunit="s");
# `tmin` and `tmax` are divided by `bin_size` to get the first and last step.
mre_settings = {
    "bin_size": 0.005,  # 5 ms
    "tmin": 0.03,
    "tmax": 10.0,
}


def mre_wrapper(data, settings):
    """
    Estimate autocorrelation timescales of a single unit with mrestimator.

    Spike times are binned, correlation coefficients are computed for the
    step range given by `settings`, and two functions are fitted to them:
    `mre.f_exponential_offset` ("single") and `mre.f_two_timescales` ("double").

    # Parameters
    data : array-like
        spike times of one unit; has to be 1D after `squeeze()`.
    settings : dict
        needs the keys "bin_size", "tmin" and "tmax". "bin_size" is used for
        binning and passed on as `dt` (with `dtunit="s"`); "tmin" and "tmax"
        are divided by "bin_size" to obtain the first and last step.

    # Returns
    res : dict
        "tau_single", "tau_double" : `tau` of the respective fit.
        "details_tau_single", "details_tau_double" : the fit results as
            dicts, with "fitfunc" replaced by the function's name, "steps"
            removed, and all entries of `settings` added.
    """

    logging.getLogger("mrestimator").setLevel("ERROR")
    import mrestimator as mre

    data = data.squeeze()
    assert data.ndim == 1, "data must be 1D, this is the simple one-unit wrapper"

    bin_size = settings["bin_size"]
    binned_spikes = utl.binned_spike_count(data, bin_size=bin_size)

    def _as_step(time):
        # Float division can land just below the intended integer
        # (0.3 / 0.1 == 2.9999999999999996), and a bare int() would then drop
        # a whole step. Rounding away that noise first keeps the truncating
        # behaviour for ratios that are genuinely non-integer.
        return int(round(time / bin_size, 9))

    rk = mre.coefficients(
        binned_spikes,
        method="ts",  # method does not matter for single unit
        steps=(_as_step(settings["tmin"]), _as_step(settings["tmax"])),
        dt=bin_size,
        dtunit="s",
    )

    fit_single = mre.fit(rk, fitfunc=mre.f_exponential_offset)
    fit_double = mre.fit(rk, fitfunc=mre.f_two_timescales)

    details_single = fit_single._asdict()
    details_double = fit_double._asdict()

    for details in (details_single, details_double):
        # store the name instead of the function object
        details["fitfunc"] = details["fitfunc"].__name__
        # keep the settings next to the results they produced
        for key in settings.keys():
            details[key] = settings[key]
        # steps might get too big and are easy to reconstruct
        details.pop("steps", None)

    res = {
        "tau_single": fit_single.tau,
        "tau_double": fit_double.tau,
        "details_tau_single": details_single,
        "details_tau_double": details_double,
    }

    return res


def full_analysis(spikes):
    """
    Run both estimators (hdestimator first, then mrestimator) on one unit.

    # Parameters
    spikes : numpy array
        spike times of a single unit, possibly nan-padded. The array is
        squeezed and all non-finite entries are dropped before analysis.

    # Returns
    hde_res, mre_res : tuple of dict
        `hde_res` is what `hde.api.wrapper` returns for the module-level
        `hde_settings`, with the keys "plot_AIS", "plot_settings",
        "plot_color" and "ANALYSIS_DIR" removed from its "settings" entry.
        `mre_res` is what `mre_wrapper` returns for `mre_settings`.
    """

    logging.getLogger("hdestimator").setLevel("ERROR")
    import hdestimator as hde

    # get rid of the nan-padding (and anything else that is not finite)
    flat = spikes.squeeze()
    finite_spikes = flat[np.isfinite(flat)]

    hde_res = hde.api.wrapper(spike_times=finite_spikes, settings=hde_settings)
    mre_res = mre_wrapper(finite_spikes, mre_settings)

    # entries of the returned settings that we do not want to store
    unwanted = ("plot_AIS", "plot_settings", "plot_color", "ANALYSIS_DIR")
    returned_settings = hde_res["settings"]
    for name in unwanted:
        returned_settings.pop(name, None)

    return hde_res, mre_res


def dask_it(itertuple):
    """
    Worker entry point: analyse the unit described by one `(index, row)` pair.

    Meant to be mapped over `DataFrame.iterrows()` by dask, so that every
    task gets a single row and hands back dictionaries with everything we need.

    # Parameters
    itertuple : tuple
        `(index, row)`. `row["spiketimes"]` has to provide
        `.squeeze().values` (we save spikes as xarray).

    # Returns
    index, hde_res, mre_res : tuple
        `index` is passed through unchanged so that the results can be
        re-inserted into the outer dataframe; the other two are the dicts
        returned by `full_analysis`.
    """

    row_index, row = itertuple

    # xarray -> 1d numpy
    spike_array = row["spiketimes"].squeeze().values

    hde_res, mre_res = full_analysis(spike_array)

    return row_index, hde_res, mre_res


In [4]:
# manual smoke test: run the worker function on the first row only
first_pair = next(meta_df.iterrows(), None)
if first_pair is not None:
    itertuple = first_pair
    index, hde_res, mre_res = dask_it(itertuple)

                                     

In [5]:
# index returned for the manually tested unit: one entry per index level
index

(951013153, 'natural_movie_one_more_repeats', 787025148, '3.0')

In [6]:
# result dict of `mre_wrapper` for the manually tested unit
mre_res

<class 'dict'>
├── tau_single ....................................................... float64  1.7908380639059973
├── tau_double ....................................................... float64  0.05596782549734632
├── details_tau_single                                                        
│   ├── tau .......................................................... float64  1.7908380639059973
│   ├── mre .......................................................... float64  0.9972119050740276
│   ├── fitfunc .......................................................... str  f_exponenti...
│   ├── taustderr ................................................... NoneType
│   ├── mrestderr ................................................... NoneType
│   ├── tauquantiles ................................................ NoneType
│   ├── mrequantiles ................................................ NoneType
│   ├── quantiles ................................................... NoneType
│   ├── popt ......

In [7]:
# result dict of `hde.api.wrapper` for the manually tested unit
hde_res

<class 'dict'>
├── firing_rate ...................................................... float64  21.52952801803102
├── firing_rate_sd ................................................... float64  0.0
├── recording_length ................................................. float64  900.668701171875
├── recording_length_sd .............................................. float64  0.0
├── H_spiking ........................................................ float64  0.34156872139915495
├── R_tot ............................................................ float64  0.03332835049792322
├── R_tot_sd ........................................................ NoneType
├── AIS_tot .......................................................... float64  0.011383922065918525
├── T_D ................................................................ float  0.06441
├── tau_R ............................................................ float64  0.011799304491435247
├── T_vals ..........................................

In [8]:
import logging
import tempfile
import time
import numpy as np
from tqdm import tqdm
from dask_jobqueue import SGECluster
from dask.distributed import Client, SSHCluster, LocalCluster, as_completed
from contextlib import nullcontext, ExitStack

# silence dask, configure this in ~/.config/dask/logging.yaml to be reliable
# https://docs.dask.org/en/latest/how-to/debug.html#logs
import logging

# NOTE(review): the root logger was already configured by the basicConfig call
# in the first cell. A second basicConfig without force=True does nothing in
# that case, so the root level presumably stays at WARNING - confirm that this
# is what is wanted.
logging.basicConfig(level=logging.ERROR)
logging.getLogger("dask").setLevel(logging.WARNING)
logging.getLogger("distributed").setLevel(logging.WARNING)
logging.getLogger("distributed.worker").setLevel(logging.WARNING)

# number of cores requested from the dask cluster via `dask_cluster.scale`
num_cores = 4
df_in_progress = None # for debugging, keep the global variable available

def main(dask_client, prepared_df):
    """
    Analyse every row of `prepared_df` on the dask cluster and collect the results.

    One `dask_it` task is submitted per row. As tasks complete, their results
    are written into a copy of `prepared_df`. That copy is saved about once
    per hour (timestamped file) and once more at the end ("meta_df_final.h5").

    If collecting the results fails (a task raised, the run was interrupted,
    ...), everything gathered so far is saved to a timestamped file before
    the error is re-raised, so a late failure does not discard finished units.

    # Parameters
    dask_client : dask.distributed.Client
        client used to submit the tasks.
    prepared_df : pandas.DataFrame
        one row per unit, with a unique index and a "spiketimes" column.
        It is not modified, we work on a copy.

    # Returns
    df_in_progress : pandas.DataFrame
        copy of `prepared_df` with the additional columns
        "R_tot", "tau_R", "tau_single", "tau_double" and
        "tau_R_details", "tau_single_details", "tau_double_details" (dicts).

    # Notes
    `df_in_progress`, `index`, `hde_res` and `mre_res` are module-level
    globals on purpose, so the state can be inspected while debugging.
    Files are written to the module-level `output_dir`.
    """

    global df_in_progress, index, hde_res, mre_res
    df_in_progress = prepared_df.copy()
    for column in ("R_tot", "tau_R", "tau_single", "tau_double"):
        df_in_progress[column] = np.nan
    for column in ("tau_R_details", "tau_single_details", "tau_double_details"):
        df_in_progress[column] = None

    def save_checkpoint():
        # timestamped, so that earlier checkpoints are never overwritten
        time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
        utl.save_dataframe(
            df_in_progress,
            path=f"{output_dir}/meta_df_in_progress_{time_str}.h5",
        )

    # dispatch, reading in parallel may be faster
    # create a list from the iterator, as dask map "no longer supports iterators" -.-
    futures = dask_client.map(dask_it, list(prepared_df.iterrows()))

    log.info("futures dispatched")

    # save once an hour
    last_save = time.time()
    try:
        for future in tqdm(as_completed(futures), total=len(futures)):

            index, hde_res, mre_res = future.result()

            # we use the .at[] method instead of .loc[] to be able to set
            # cells to dictionaries. (df.at can only access a single value at a time
            # df.loc can select multiple rows and/or columns)
            # https://stackoverflow.com/questions/13842088/set-value-for-particular-cell-in-pandas-dataframe-using-index
            df_in_progress.at[index, "R_tot"] = hde_res["R_tot"]
            df_in_progress.at[index, "tau_R"] = hde_res["tau_R"]
            df_in_progress.at[index, "tau_single"] = mre_res["tau_single"]
            df_in_progress.at[index, "tau_double"] = mre_res["tau_double"]

            df_in_progress.at[index, "tau_R_details"] = hde_res
            df_in_progress.at[index, "tau_single_details"] = mre_res["details_tau_single"]
            df_in_progress.at[index, "tau_double_details"] = mre_res["details_tau_double"]

            # save every hour or so
            if time.time() - last_save > 3600:
                save_checkpoint()
                last_save = time.time()
    except BaseException:
        # also covers KeyboardInterrupt: keep what has been computed so far,
        # then let the original error propagate unchanged
        log.error("collecting results failed, saving the results gathered so far")
        save_checkpoint()
        raise

    utl.save_dataframe(df_in_progress, path=f"{output_dir}/meta_df_final.h5")

    return df_in_progress


# Context managers ensure proper cleanup of (remote) compute: on exit the
# client is closed first, then the cluster.
#
# rudabeh: use this instead of the LocalCluster below
#   SGECluster(
#       cores=32,
#       memory="192GB",
#       processes=32,
#       job_extra_directives=["-pe mvapich2-sam 32"],
#       log_directory="/scratch01.local/pspitzner/dask/logs",
#       local_directory="/scratch01.local/pspitzner/dask/scratch",
#       interface="ib0",
#       walltime="24:00:00",
#   )
with LocalCluster(local_directory=f"{tempfile.gettempdir()}/dask/") as dask_cluster:
    dask_cluster.scale(cores=num_cores)

    with Client(dask_cluster) as dask_client:
        # for test use a smaller dataframe
        final_df = main(dask_client, meta_df.iloc[:20])


2023-08-10 10:35:14,369 | INFO | notebook | main | futures dispatched
2023-08-10 10:38:18,352 | DEBUG    | its_utility  | Binning spiketimes: dtype <class 'numpy.ndarray'> with shape (1, 401)
2023-08-10 10:38:19,281 | DEBUG    | its_utility  | Binning spiketimes: dtype <class 'numpy.ndarray'> with shape (1, 2147)
2023-08-10 10:38:36,173 | DEBUG    | its_utility  | Binning spiketimes: dtype <class 'numpy.ndarray'> with shape (1, 10496)
2023-08-10 10:38:48,376 | DEBUG    | its_utility  | Binning spiketimes: dtype <class 'numpy.ndarray'> with shape (1, 19403)
2023-08-10 10:41:13,984 | DEBUG    | its_utility  | Binning spiketimes: dtype <class 'numpy.ndarray'> with shape (1, 633)
2023-08-10 10:41:28,175 | DEBUG    | its_utility  | Binning spiketimes: dtype <class 'numpy.ndarray'> with shape (1, 1266)
2023-08-10 10:41:43,778 | DEBUG    | its_utility  | Binning spiketimes: dtype <class 'numpy.ndarray'> with shape (1, 1072)
2023-08-10 10:42:20,338 | DEBUG    | its_utility  | Binning spiketime

In [10]:
# frame returned by `main` for the test subset, including the new result columns
final_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,ecephys_structure_acronym,invalid_spiketimes_check,recording_length,firing_rate,filepath,num_spikes,spiketimes,R_tot,tau_R,tau_single,tau_double,tau_R_details,tau_single_details,tau_double_details
unit_id,stimulus,session,block,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
951013153,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,900.668806,21.54288,/Users/paul/para/2_Projects/information_timesc...,19403,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.033412,0.011884,1.790838,0.055968,"{'firing_rate': 21.52952801803102, 'firing_rat...","{'tau': 1.7908380639059973, 'mre': 0.997211905...","{'tau': 0.05596782549734632, 'mre': 0.91453728..."
951013143,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,900.718039,11.652925,/Users/paul/para/2_Projects/information_timesc...,10496,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.075249,0.056636,-161.091734,0.594005,"{'firing_rate': 10.611510791366905, 'firing_ra...","{'tau': -161.09173360616265, 'mre': 1.00003103...","{'tau': 0.5940050457755989, 'mre': 0.991617890..."
951013133,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,894.867666,0.448111,/Users/paul/para/2_Projects/information_timesc...,401,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.073422,0.226062,1.120947,0.496862,"{'firing_rate': 0.448109781309017, 'firing_rat...","{'tau': 1.12094653581912, 'mre': 0.99554941728...","{'tau': 0.49686180983789935, 'mre': 0.98998730..."
951013202,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,900.265072,2.384853,/Users/paul/para/2_Projects/information_timesc...,2147,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.08516,0.161609,1.906939,1.613602,"{'firing_rate': 2.3848533487362054, 'firing_ra...","{'tau': 1.9069389187340104, 'mre': 0.997381431...","{'tau': 1.6136015170529685, 'mre': 0.996906137..."
951013187,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,896.074534,0.706414,/Users/paul/para/2_Projects/information_timesc...,633,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.233878,0.021783,0.040537,0.029344,"{'firing_rate': 0.7030661495968529, 'firing_ra...","{'tau': 0.04053663670050711, 'mre': 0.88395845...","{'tau': 0.029344318252677346, 'mre': 0.8433352..."
951013303,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,899.964772,1.406722,/Users/paul/para/2_Projects/information_timesc...,1266,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.070918,0.059447,0.475836,0.238142,"{'firing_rate': 1.404499063852483, 'firing_rat...","{'tau': 0.475836014503997, 'mre': 0.9895471925...","{'tau': 0.2381417000452154, 'mre': 0.979222976..."
951013292,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,900.648773,6.338764,/Users/paul/para/2_Projects/information_timesc...,5709,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.084842,0.065714,1.963956,0.020156,"{'firing_rate': 6.3309831788153, 'firing_rate_...","{'tau': 1.9639564312948419, 'mre': 0.997457356...","{'tau': 0.02015603876809622, 'mre': 0.78030952..."
951013360,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,899.435505,1.191859,/Users/paul/para/2_Projects/information_timesc...,1072,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.107437,0.029876,0.181691,0.100557,"{'firing_rate': 1.1885172996531175, 'firing_ra...","{'tau': 0.18169073644835218, 'mre': 0.97285591...","{'tau': 0.10055675762742168, 'mre': 0.95149279..."
951013556,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,898.482603,2.07127,/Users/paul/para/2_Projects/information_timesc...,1861,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.086762,0.035966,1.451338,0.06956,"{'firing_rate': 2.0679254522891313, 'firing_ra...","{'tau': 1.4513379620588718, 'mre': 0.996560830...","{'tau': 0.06956038567712379, 'mre': 0.93064257..."
951013545,natural_movie_one_more_repeats,787025148,3.0,VISam,SUCCESS,898.468003,0.98501,/Users/paul/para/2_Projects/information_timesc...,885,[[[<xarray.DataArray (spiketimes: 30719)>\narr...,0.05006,0.033046,0.121684,0.102071,"{'firing_rate': 0.9850078466726768, 'firing_ra...","{'tau': 0.12168427741232384, 'mre': 0.95974280...","{'tau': 0.10207080390199351, 'mre': 0.95219483..."


In [12]:
%watermark -v --iversions --packages mrestimator,hdestimator

Python implementation: CPython
Python version       : 3.11.4
IPython version      : 8.14.0

mrestimator: 0.1.9b2
hdestimator: 0.10b2

h5py          : 3.9.0
dask          : 2023.7.1
xarray        : 2023.7.0
IPython       : 8.14.0
sys           : 3.11.4 | packaged by conda-forge | (main, Jun 10 2023, 18:08:41) [Clang 15.0.7 ]
numpy         : 1.24.4
matplotlib    : 3.7.2
sqlite3       : 2.6.0
pandas        : 2.0.3
logging       : 0.5.1.2
prompt_toolkit: 3.0.39
re            : 2.2.1

