In [1]:
# ------------------------------------------------------------------------------ #
# @Author:        F. Paul Spitzner
# @Email:         paul.spitzner@ds.mpg.de
# @Created:       2023-08-04 11:59:06
# @Last Modified: 2024-03-23 16:30:22
# ------------------------------------------------------------------------------ #
# Run on the cluster, using dask.
# Analyses all units and saves a large dataframe with everything
# that is needed.
# This is the notebook that analyses spontaneous activity.
# ------------------------------------------------------------------------------ #

%load_ext ipy_dict_hierarchy
# watermark is used in the last cell to record interpreter / package versions
%load_ext watermark

import logging
# root logging config for all libraries: WARNING and above, with a format that
# includes logger and function name. basicConfig only has an effect while the
# root logger has no handlers yet, so this needs to run before anything else
# configures logging.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-s | %(name)-s | %(funcName)-s | %(message)s",
    level=logging.WARNING,
)
# this notebook's own logger is more verbose than the libraries
log = logging.getLogger("notebook")
log.setLevel("DEBUG")

import glob
import os
import re
import subprocess
import sys

import dask
import h5py
import numpy as np
import pandas as pd
import xarray as xr
from tqdm import tqdm

# also needs to be added for each dask-worker
# (done through the `--preload` worker argument in the cluster cell).
# the parent of the current working directory is the project directory,
# which makes the `ana` package importable.
extra_path = os.path.abspath('../')
sys.path.append(extra_path)
log.info(f"project directory: {extra_path}")

from ana import utility as utl
# keep the project helpers quiet, only show errors
utl.log.setLevel("ERROR")

# specify the path as closely as possible, we search for the spike files recursively through all subdirs
# (placeholder paths, adapt to your setup)
data_dir = "/path/to/repo/experiment_analysis/dat/"
# analysis results can sit in the same directory
output_dir = "/path/to/repo/experiment_analysis/dat/"
# base name of all result files written below (hourly/final .feather snapshots and the .h5)
output_name = "all_units_spontaneous_for_merged"

2024-04-16 12:12:45,081 | INFO | notebook | <module> | project directory: /data.nst/lucas/projects/mouse_visual_timescales_predictability/paper_code/experiment_analysis


In [None]:
meta_df = utl.all_unit_metadata(data_dir, reload=False)
meta_df = utl.load_spikes(meta_df, format="numpy")
meta_df = utl.default_filter(meta_df, trim=False)  # only updates the status column

# limit to spontaneous.
# `.copy()` so that the column assignment below acts on an independent frame,
# not on a slice of the unfiltered one (avoids SettingWithCopyWarning /
# ambiguous chained assignment).
meta_df = meta_df[meta_df["stimulus"] == "spontaneous"].copy()

# prepare the spike times of every unit. Mapping over the one column we need gives
# the same result as a row-wise apply, but also works when the selection is empty
# (a row-wise apply on an empty frame can return a DataFrame, which cannot be
# assigned to a single column).
meta_df["spiketimes"] = meta_df["spiketimes"].map(
    lambda spikes: utl.prepare_spike_times(spikes, "spontaneous_for_merged")
)
meta_df.tail()

Fetching metadata from sessions: 100%|████████| 58/58 [01:03<00:00,  1.09s/it]
Loading spikes for sessions:  86%|██████████▎ | 50/58 [02:42<00:21,  2.67s/it]

In [None]:
# number of units: total, passing and not passing the invalid-spiketimes check
n_units = len(meta_df)
n_valid = len(meta_df.query("invalid_spiketimes_check == 'SUCCESS'"))
n_invalid = len(meta_df.query("invalid_spiketimes_check != 'SUCCESS'"))
for count in (n_units, n_valid, n_invalid):
    print(count)

In [None]:
# settings passed to `hde.api.wrapper` in `full_analysis` (history dependence R_tot
# and its timescale tau_R)
hde_settings = {
    # NOTE(review): 0 presumably disables bootstrapping of R_tot — confirm in hdestimator docs
    "number_of_bootstraps_R_tot": 0,
    "number_of_bootstraps_R_max": 250,
    # 30 ms, the same lower bound as `mre_settings["tmin"]` below
    "timescale_minimum_past_range": 30 / 1000,
    "embedding_number_of_bins_set": [5],
    "estimation_method": "shuffling",
    "persistent_analysis": False,
    # candidate past ranges (presumably in seconds): roughly log-spaced from 0.005
    # to 5.0, with a much finer grid between ~0.063 and ~0.223
    # fmt: off
    "embedding_past_range_set": [
        0.005, 0.00561, 0.00629, 0.00706, 0.00792, 0.00889, 0.00998, 0.01119, 0.01256,
        0.01409, 0.01581, 0.01774, 0.01991, 0.02233, 0.02506, 0.02812, 0.03155, 0.0354,
        0.03972, 0.04456, 0.05, 0.0561, 0.06295, 0.06441, 0.06591, 0.06745, 0.06902,
        0.07063, 0.07227, 0.07396, 0.07568, 0.07744, 0.07924, 0.08109, 0.08298, 0.08491,
        0.08689, 0.08891, 0.09099, 0.0931, 0.09527, 0.09749, 0.09976, 0.10209, 0.10446,
        0.1069, 0.10939, 0.11194, 0.11454, 0.11721, 0.11994, 0.12274, 0.12559, 0.12852,
        0.13151, 0.13458, 0.13771, 0.14092, 0.1442, 0.14756, 0.151, 0.15451, 0.15811,
        0.1618, 0.16557, 0.16942, 0.17337, 0.17741, 0.18154, 0.18577, 0.19009, 0.19452,
        0.19905, 0.20369, 0.20843, 0.21329, 0.21826, 0.22334, 0.25059, 0.28117, 0.31548,
        0.35397, 0.39716, 0.44563, 0.5, 0.56101, 0.62946, 0.70627, 0.79245, 0.88914,
        0.99763, 1.11936, 1.25594, 1.40919, 1.58114, 1.77407, 1.99054, 2.23342, 2.50594,
        2.81171, 3.15479, 3.53973, 3.97164, 4.45625, 5.0,
    ],
    # fmt: on
}

# settings for `mre_wrapper`: spikes are binned with `bin_size`, and the
# autocorrelation is fitted over lags from `tmin` to `tmax` (seconds, converted
# to a range of steps of size `bin_size`)
mre_settings = {
    "bin_size": 0.005,  # 5 ms
    "tmin": 0.03,
    "tmax": 10.0,
}


def mre_wrapper(data, settings):
    """
    Fit the autocorrelation of a single unit's binned spike train with mrestimator.

    # Parameters
    data : numpy array
        spike times of one unit. Squeezed, and must be 1d afterwards.
    settings : dict
        needs `bin_size`, `tmin` and `tmax`, in seconds (see `mre_settings`).
        All entries are also copied into the returned details.

    # Returns
    dict with
        `tau_single` / `tau_double` : `tau` of the exponential-with-offset fit
            and of the two-timescale fit,
        `details_tau_single` / `details_tau_double` : the respective fit results
            as plain dicts (fit function stored by name, `steps` dropped,
            `settings` merged in).
    """

    logging.getLogger("mrestimator").setLevel("ERROR")
    import mrestimator as mre
    mre.disable_progressbar()

    data = data.squeeze()
    assert data.ndim == 1, "data must be 1D, this is the simple one-unit wrapper"

    bin_size = settings["bin_size"]
    binned_spikes = utl.binned_spike_count(data, bin_size=bin_size)

    # convert the fit range from seconds to a number of bins. round() instead of a
    # bare int() cast: int() truncates, so a quotient that lands just below an
    # integer due to floating point (e.g. 0.3 / 0.1 = 2.9999999999999996) would
    # silently lose one step.
    step_min = int(round(settings["tmin"] / bin_size))
    step_max = int(round(settings["tmax"] / bin_size))

    rk = mre.coefficients(
        binned_spikes,
        method="ts",  # method does not matter for single unit
        steps=(step_min, step_max),
        dt=bin_size,
        dtunit="s",
    )

    def fit_details(fit):
        # fit result as a plain dict, annotated with the settings used
        details = fit._asdict()
        # keep only the name of the fit function (presumably so the dict stays
        # easy to store in a dataframe / file)
        details["fitfunc"] = details["fitfunc"].__name__
        details.update(settings)
        # steps might get too big and are easy to reconstruct
        details.pop("steps", None)
        return details

    fit_single = mre.fit(rk, fitfunc=mre.f_exponential_offset)
    fit_double = mre.fit(rk, fitfunc=mre.f_two_timescales)

    return {
        "tau_single": fit_single.tau,
        "tau_double": fit_double.tau,
        "details_tau_single": fit_details(fit_single),
        "details_tau_double": fit_details(fit_double),
    }


def full_analysis(spikes):
    """
    Take one set of spikes, run hdestimator and mrestimator.

    # Parameters
    spikes : 1d numpy array
        flat list of spike times for a single unit. nans are removed.

    # Returns
    hde_res : dict
        result of `hde.api.wrapper` (with the plotting / output-dir entries
        removed from its `settings`), or `{"R_tot": nan, "tau_R": nan}` if the
        estimator raised.
    mre_res : dict
        result of `mre_wrapper`, or nan timescales with empty details if it raised.

    A failing estimator is logged as a warning and replaced by the nan fallback,
    so that one bad unit does not break a whole batch.
    """

    # may run on a dask worker; fetch the logger by name instead of relying on
    # the notebook global
    unit_log = logging.getLogger("notebook")

    logging.getLogger("hdestimator").setLevel("ERROR")
    import hdestimator as hde

    # remove nan-padding
    spikes = spikes.squeeze()
    spikes = spikes[np.isfinite(spikes)]

    # `except Exception` rather than a bare `except:`, so that KeyboardInterrupt /
    # SystemExit still propagate; the reason for a failure is logged.
    try:
        hde_res = hde.api.wrapper(spike_times=spikes, settings=hde_settings)
    except Exception as e:
        unit_log.warning(f"hdestimator failed for unit with {spikes.size} spikes: {e!r}")
        hde_res = dict()
        hde_res["R_tot"] = np.nan
        hde_res["tau_R"] = np.nan
    else:
        # only the estimator call is guarded, so a missing `settings` entry no
        # longer throws away otherwise valid results
        stored_settings = hde_res.get("settings") or {}
        for key in ["plot_AIS", "plot_settings", "plot_color", "ANALYSIS_DIR"]:
            stored_settings.pop(key, None)

    try:
        mre_res = mre_wrapper(spikes, mre_settings)
    except Exception as e:
        unit_log.warning(f"mrestimator failed for unit with {spikes.size} spikes: {e!r}")
        mre_res = dict()
        mre_res["tau_single"] = np.nan
        mre_res["tau_double"] = np.nan
        mre_res["details_tau_single"] = dict()
        mre_res["details_tau_double"] = dict()

    return hde_res, mre_res


def dask_it(itertuple):
    """
    Worker-side entry point: run `full_analysis` on the spikes of one dataframe row.

    # Parameters
    itertuple : tuple
        one `(index, row)` pair as yielded by `DataFrame.iterrows()`. The row's
        `spiketimes` must already be prepared (done in the loading cell via
        `utl.prepare_spike_times`).

    # Returns
    (index, hde_res, mre_res). `index` is the row label, so the caller can
    write the results back into the outer dataframe; the other two are the
    result dicts of `full_analysis`.
    """
    row_label, row = itertuple
    hde_res, mre_res = full_analysis(row["spiketimes"].squeeze())
    return row_label, hde_res, mre_res


In [None]:
# manually test the worker function on the first unit only, before using the cluster
first_row = next(meta_df.iterrows(), None)
if first_row is not None:
    index, hde_res, mre_res = dask_it(first_row)

In [None]:
# row label of the unit analysed in the manual single-unit test
index

In [None]:
# mrestimator result dict (timescales and fit details) of that unit
mre_res

In [None]:
# hdestimator result dict of that unit
hde_res

In [None]:
import logging
import tempfile
import time
import numpy as np
from tqdm.auto import tqdm
from dask_jobqueue import SGECluster
from dask.distributed import Client, SSHCluster, LocalCluster, as_completed
from contextlib import nullcontext, ExitStack

# silence dask, configure this in ~/.config/dask/logging.yaml to be reliable
# https://docs.dask.org/en/latest/how-to/debug.html#logs
import logging

# NOTE: logging.basicConfig does nothing if the root logger already has handlers
# (unless force=True). Once the first cell has run, this line is therefore a
# no-op and the root level stays at WARNING; the per-logger levels below are
# what actually quiets dask.
logging.basicConfig(level=logging.ERROR)
logging.getLogger("dask").setLevel(logging.WARNING)
logging.getLogger("distributed").setLevel(logging.WARNING)
logging.getLogger("distributed.worker").setLevel(logging.WARNING)

# total number of cores the cluster is scaled to below
num_cores = 128
df_in_progress = None # for debugging, keep the global variable available (main() fills it while running)

def _save_snapshot(df, suffix):
    """
    Write `df` to `{output_dir}/{output_name}_{suffix}.feather`.

    The index is moved into a regular column first, because `to_feather`
    requires a default index. Errors are logged and not raised, so a failed
    save does not abort a long-running collection loop.
    """
    try:
        df.reset_index().to_feather(f"{output_dir}/{output_name}_{suffix}.feather")
    except Exception as e:
        log.error(e)


def main(dask_client, prepared_df, save_interval=3600):
    """
    Run `dask_it` for every row of `prepared_df` on the cluster and collect the results.

    # Parameters
    dask_client : dask.distributed.Client
        client of the cluster that executes the tasks
    prepared_df : pandas.DataFrame
        one row per unit, with the `spiketimes` column already prepared
    save_interval : float, default 3600
        seconds between intermediate snapshots written to `output_dir`

    # Returns
    a copy of `prepared_df` with `R_tot`, `tau_R`, `tau_single`, `tau_double`
    and the corresponding `*_details` columns filled in. Rows whose task
    failed keep nan / None. The result is also written to
    `{output_dir}/{output_name}_final.feather`.
    """

    # globals, so that partial results (and the last processed unit) remain
    # inspectable if the run is interrupted
    global df_in_progress, index, hde_res, mre_res
    df_in_progress = prepared_df.copy()
    for col in ["R_tot", "tau_R", "tau_single", "tau_double"]:
        df_in_progress[col] = np.nan
    for col in ["tau_R_details", "tau_single_details", "tau_double_details"]:
        df_in_progress[col] = None

    # dispatch, reading in parallel may be faster
    # create a list from the iterator, as dask map "no longer supports iterators" -.-
    futures = dask_client.map(dask_it, list(prepared_df.iterrows()))

    log.info("futures dispatched")

    num_failed = 0
    last_save = time.time()
    for future in tqdm(as_completed(futures), total=len(futures), leave=True):

        try:
            index, hde_res, mre_res = future.result()
        except Exception as e:
            # a single failed task (e.g. a lost worker) must not abort the loop
            # and discard everything collected so far; its row keeps nan / None
            num_failed += 1
            log.error(f"task {future.key} failed: {e!r}")
            continue

        # we use the .at[] method instead of .loc[] to be able to set
        # cells to dictionaries. (df.at can only access a single value at a time
        # df.loc can select multiple rows and/or columns)
        # https://stackoverflow.com/questions/13842088/set-value-for-particular-cell-in-pandas-dataframe-using-index
        df_in_progress.at[index, "R_tot"] = hde_res.get("R_tot", np.nan)
        df_in_progress.at[index, "tau_R"] = hde_res.get("tau_R", np.nan)
        df_in_progress.at[index, "tau_single"] = mre_res["tau_single"]
        df_in_progress.at[index, "tau_double"] = mre_res["tau_double"]

        df_in_progress.at[index, "tau_R_details"] = hde_res
        df_in_progress.at[index, "tau_single_details"] = mre_res["details_tau_single"]
        df_in_progress.at[index, "tau_double_details"] = mre_res["details_tau_double"]

        # intermediate snapshot every `save_interval` seconds (an hour by default)
        if time.time() - last_save > save_interval:
            _save_snapshot(df_in_progress, time.strftime("%Y-%m-%d_%H-%M-%S"))
            last_save = time.time()

    if num_failed > 0:
        log.warning(f"{num_failed} of {len(futures)} tasks failed, their rows keep nan results")

    _save_snapshot(df_in_progress, "final")

    return df_in_progress


with ExitStack() as stack:
    # init dask using a context manager to ensure proper cleanup of remote compute.
    # adapt as needed, details depend on your setup
    # (client and cluster are closed in reverse order when the with-block is left,
    # i.e. after main() returned; if main() raises, partial results remain in the
    # global `df_in_progress`)
    dask_cluster = stack.enter_context(
        # rudabeh
        # each SGE job: 32 cores split into 32 worker processes (one thread per
        # worker) sharing 192GB.
        SGECluster(
            cores=32,
            memory="192GB",
            processes=32,
            job_extra_directives=["-pe mvapich2-sam 32"],
            log_directory="/scratch01.local/pspitzner/dask/logs",
            local_directory="/scratch01.local/pspitzner/dask/scratch",
            # NOTE(review): presumably the cluster's InfiniBand interface — site specific
            interface="ib0",
            walltime="24:00:00",
            # make the project directory importable on every worker, so that the
            # pickled functions can resolve `ana.utility`
            worker_extra_args=[
                f'--preload \'import sys; sys.path.append("{extra_path}");\''
            ],
        )
        # or use local compute:
        # LocalCluster(local_directory=f"{tempfile.gettempdir()}/dask/")
    )
    # request workers with `num_cores` cores in total (with 32 cores per job: 4 jobs)
    dask_cluster.scale(cores=num_cores)
    dask_client = stack.enter_context(Client(dask_cluster))

    final_df = main(dask_client, meta_df)


In [None]:
# store the complete result frame.
# NOTE(review): the dict-valued *_details columns are object dtype; pandas
# presumably pickles them into the HDF file (PerformanceWarning) — confirm
final_df.to_hdf(f"{output_dir}/{output_name}.h5", key="meta_df")

# to combine the frames from spontaneous and stimulated activity see the notebook
# combine_dataframes.ipynb

In [None]:
# fix group permissions for collaborative directories.
# run chmod without a shell and pass the path as its own argument, so that an
# `output_dir` containing spaces or shell metacharacters is handled correctly.
# the displayed value is chmod's exit code (0 on success).
subprocess.run(["chmod", "-R", "g+rwx", output_dir], check=False).returncode

In [None]:
# record python and package versions used for this run (for reproducibility)
%watermark -v --iversions --packages mrestimator,hdestimator,scipy