In [13]:
%reload_ext ipy_dict_hierarchy
%reload_ext autoreload
%autoreload 2


import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
import zarr
import xarray as xr
import os
import re
import itertools

import sys
sys.path.append("../")

import warnings
# warnings.filterwarnings("ignore")

import logging
# Root logger: timestamped, column-aligned records, WARNING and above only.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(name)-12s | %(message)s",
    level=logging.WARNING,
)
# The notebook's own logger is set to DEBUG, i.e. more verbose than the root default.
log = logging.getLogger("notebook")
log.setLevel("DEBUG")

In [14]:
# --- where to read simulations from, and where to write the merged results ---
input_directory = "/data.nst/lucas/history_dependence/signatures_of_temporal_processing_paper_code/data/bn_code_cleaned/"
output_filepath = "/data.nst/share/projects/information_timescales/branching_network/res_dset_bn_code_cleaned_N=10_000_to_merge.zarr"

# --- storage format ---
# Spikes are stored in two groups; "spikes_by_neuron" is the faster one to read.
# The reading routine accepts either.
spike_group = "/spikes_by_neuron" # or spikes_as_list

# --- size of the analysis ---
# 100 neurons were recorded per realization; we only analyse the first few.
# A single analysis (one neuron, one core) takes ~ 420 sec
N_to_analyze = 20

# --- dask settings, used in the last cell ---
num_cores = 256
dump_progress = False # keep overwriting output_file as we make progress

# --- discover input files ---
# Filenames are meaningless; files are selected by their metadata later on.
files = glob.glob(input_directory + "*.zarr")
log.info(f"found {len(files)} files")

def get_dims_of_file(fname, dims_to_get='all'):
    """
    Read metadata attributes from the spike group of one zarr store.

    Parameters
    ----------
    fname :
        path of the zarr store; the module-level `spike_group` is appended to it
    dims_to_get :
        names of the attributes to read, or the string 'all' (default) to read
        every attribute present

    Returns
    -------
    dims : dict
        attribute name -> value, in the order requested. Names that are not
        stored with the file are silently left out.
    """
    attrs = zarr.open(fname + spike_group, mode="r").attrs

    # 'all' is a sentinel meaning "whatever the file has"
    wanted = list(attrs.keys()) if dims_to_get == 'all' else dims_to_get

    return {name: attrs[name] for name in wanted if name in attrs}

# sanity check: display every attribute stored with the first file found
get_dims_of_file(files[0], dims_to_get='all')

2023-03-09 20:50:43,693 | INFO     | notebook     | found 2400 files


<class 'dict'>
├── N .................................................................... int  10000
├── N_to_record .......................................................... int  100
├── adjacency_matrix ..................................................... str  scipy.spars...
├── current_timestep ..................................................... int  240000
├── dataformat_details ................................................... str           Sp...
├── dt ................................................................. float  0.005
├── duration_equil ....................................................... int  1200
├── duration_record ...................................................... int  1200
├── gamma .............................................................. float  -598.4561146858672
├── input_alpha .......................................................... str  None
├── input_tau .......................................................... float  0.03
├── input_typ

In [16]:
# create an xarray that maps coordinates to filenames

# those will become the axes of the xarray
dims_to_get = ["N", "k", "m", "rep", "input_type"]
occurrences = {d: [] for d in dims_to_get}

# We need two steps: first find out which coordinate values occur (this fixes
# the shape of the array), then fill in the filenames. The metadata of every
# file is read only once and remembered next to the filename, so the second step
# does not have to reopen each store.
filtered_files = []
dims_of_file = dict()
for fname in files:
    dims = get_dims_of_file(fname, dims_to_get)

    # A file lacking one of the requested attributes cannot be placed on the
    # grid: a partial indexer would address a whole slice instead of one cell.
    missing = [d for d in dims_to_get if d not in dims]
    if missing:
        log.warning(f"skipping {fname}, missing attributes: {missing}")
        continue

    # if need be, we could filter here.
    if dims["N"] != 10_000 or dims["k"] != 10:
        continue
    # else:
    #     dims.pop("N")

    filtered_files.append(fname)
    dims_of_file[fname] = dims

    for k, v in dims.items():
        occurrences[k].append(v)

log.info(f"{len(filtered_files)} files remaing after filtering")

# since we have repetitions / degeneracies, we need to filter unique occurrences
# (np.unique already returns its result sorted)
occurrences = {k: np.unique(v) for k, v in occurrences.items()}

# create an empty array to fill
cs_to_fname = xr.DataArray(
    data=None,
    dims = list(occurrences.keys()),
    coords = list(occurrences.values()),
)

# create the map, one cell per file
for fname in filtered_files:
    cs_to_fname.loc[dims_of_file[fname]] = fname

# add the neuron dimension, (multiple neurons share the same file)
cs_to_fname = cs_to_fname.expand_dims({"nid": np.arange(N_to_analyze)})

# for the results we will fill an xarray dataset
res_dset = xr.Dataset(
    coords = cs_to_fname.coords,
)

# now we can get filenames easily:
# cs_to_fname.sel(rep=0, k=10, m=0.95, nid=2).values[()]
print(cs_to_fname.coords)

# Note to future paul: maybe it's better to just crawl the files in each worker
# and pick the right filename. this seems really fast with zarr.
# currently we send the (large?) cs_to_fname frame to every worker.

2023-03-09 20:51:30,357 | INFO     | notebook     | 800 files remaing after filtering


Coordinates:
  * nid         (nid) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * N           (N) int64 10000
  * k           (k) int64 10
  * m           (m) float64 0.8 0.8538 0.8849 0.905 ... 0.9725 0.9738 0.975
  * rep         (rep) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * input_type  (input_type) <U8 'OU' 'constant'


In [4]:
from numba import jit

@jit(nopython=True, parallel=True, fastmath=False, cache=True)
def binned_spike_count(spiketimes, bin_size, length=None):
    """
    Similar to a population_rate, but we get a number of spike counts, per neuron,
    as needed for e.g. cross-correlations.

    Parameters
    ----------
    spiketimes :
        np array with first dim neurons, second dim spiketimes. nan-padded
    bin_size :
        float, in units of spiketimes
    length :
        duration of output trains, in units of spiketimes. Default: None,
        uses last spiketime

    Returns
    -------
    counts : 2d array
        time series of the counted number of spikes per bin,
        one row for each neuron, in steps of bin_size

    Notes
    -----
    Spikes with a negative time, or falling at/after the end of the last bin
    (only possible when `length` is given), are ignored. Compiled nopython code
    does no bounds checking by default, so such spikes would otherwise write
    outside of `counts` (or wrap around to the trailing bins).
    """

    num_n = spiketimes.shape[0]

    if length is not None:
        num_bins = int(np.ceil(length / bin_size))
    else:
        t_min = 0.0
        t_max = np.nanmax(spiketimes)
        num_bins = int(np.ceil((t_max - t_min) / bin_size)) + 1

    counts = np.zeros(shape=(num_n, num_bins))

    for n_id in range(0, num_n):
        train = spiketimes[n_id]
        for t in train:
            # rows are nan-padded at the end: the first non-finite entry ends the train
            if not np.isfinite(t):
                break
            t_idx = int(t / bin_size)
            # skip anything outside the binned window instead of indexing out of range
            if t < 0.0 or t_idx >= num_bins:
                continue
            counts[n_id, t_idx] += 1

    return counts

In [5]:
def mrestimation(
    single_neuron_spiketimes,
    bin_size=5,
    dtunit="ms",
    tmin=0,
    tmax=10000,
    plot_autocorrelation = False,
):
    """
    Estimate autocorrelation timescales of one neuron with the mrestimator toolbox.

    The spike times are binned, correlation coefficients are computed for the
    lags between `tmin` and `tmax`, and two models are fitted to them: a sum of
    two exponentials, and mrestimator's exponential with offset.

    Parameters
    ----------
    single_neuron_spiketimes :
        1d array of spike times of one neuron, in seconds
    bin_size :
        width of the time bins, in ms (converted to seconds for the binning)
    dtunit :
        unit label handed to mrestimator
    tmin, tmax :
        smallest and largest time lag, in the same unit as `bin_size`
    plot_autocorrelation :
        if True, show the coefficients together with both fits

    Returns
    -------
    res : dict
        "tau_C_twots_Amax" : timescale of the two-exponential fit whose
            amplitude is larger in absolute value
        "tau_C_twots_Amin" : timescale belonging to the smaller amplitude
        "tau_C_single" : timescale of the single-exponential fit
    """
    # imported here rather than at module level, and the toolbox is silenced
    import logging
    import mrestimator as mre

    mre.log.disabled = True
    logging.getLogger("mrestimator").disabled=True

    # model with two timescales: a sum of two exponentials
    def f_two_timescales(k, tau1, A1, tau2, A2):
        return np.abs(A1) * np.exp(-k / tau1) + np.abs(A2) * np.exp(-k / tau2)

    # starting points tried by the fit, columns are: tau1, A1, tau2, A2
    initial_guesses = np.array(
        [
            (0.1, 0.01, 10, 0.01),
            (0.1, 0.1, 10, 0.01),
            (0.5, 0.01, 10, 0.001),
            (0.5, 0.1, 10, 0.01),
            (0.1, 0.01, 10, 0),
            (0.1, 0.1, 10, 0),
            (0.5, 0.01, 10, 0),
            (0.5, 0.1, 10, 0),
        ]
    )

    # spike times come in seconds, so the bin size is converted from ms
    spike_counts = binned_spike_count(
        single_neuron_spiketimes[np.newaxis, :], bin_size / 1000
    )
    binned_spt = mre.input_handler(spike_counts)

    # from here on everything is expressed in ms
    lag_range = (int(tmin / bin_size), int(tmax / bin_size))
    rk = mre.coefficients(
        binned_spt,
        method="ts",
        steps=lag_range,
        dt=bin_size,
        dtunit=dtunit,
    )

    fit_two_timescales = mre.fit(
        rk, fitpars=initial_guesses, fitfunc=f_two_timescales
    )
    fit_single_timescale = mre.fit(rk, fitfunc=mre.f_exponential_offset)

    if plot_autocorrelation:
        fig, ax = plt.subplots()
        handler = mre.OutputHandler(rk, ax)
        handler.add_coefficients(rk, color="g", lw=0.5, alpha=0.6)
        handler.add_fit(fit_two_timescales, color="g")
        handler.add_fit(fit_single_timescale, color="b")
        ax.set_ylabel('C(T)')
        ax.set_xlabel('time lag $T$ (ms)')
        plt.show()

    # fitted parameters arrive in the order tau1, A1, tau2, A2;
    # the sign of an amplitude carries no meaning (see the model above)
    popt = fit_two_timescales.popt
    timescales = (popt[0], popt[2])
    amplitudes = (np.abs(popt[1]), np.abs(popt[3]))

    return {
        "tau_C_twots_Amax": timescales[np.argmax(amplitudes)],
        "tau_C_twots_Amin": timescales[np.argmin(amplitudes)],
        "tau_C_single": fit_single_timescale.tau,
    }

In [6]:
import sys
sys.path.append("/data.nst/pspitzner/information_timescales/branching_network/ana")

# these kind of imports often confuse dask
# from ana.hdestimator_wrapper import hde
from hdestimator_wrapper import hde

def analyse_neuron(coords, spike_group="/spikes_by_neuron", plot_autocorrelation = False):
    """
    The combined analysis for a single neuron, including hde and mre.

    Parameters
    ----------
    coords : dict
        coordinates selecting exactly one cell of the global `cs_to_fname`
        array; must contain "nid", the index of the neuron within the file
    spike_group : str
        storage format to read, "/spikes_by_neuron" or "/spikes_as_list"
    plot_autocorrelation : bool
        forwarded to `mrestimation`

    Returns
    -------
    (res, coords) : tuple
        a dict with everything we found, and the (passed) coordinates so we can
        asynchronously store into the xarray.

    Raises
    ------
    ValueError
        if `spike_group` is not one of the two known formats, or if no file is
        registered for `coords`
    """
    # fail early and readably, before any file is touched
    if spike_group not in ("/spikes_as_list", "/spikes_by_neuron"):
        raise ValueError(
            f"unknown spike_group '{spike_group}', "
            "expected '/spikes_as_list' or '/spikes_by_neuron'"
        )

    # in dask we do not want to infer the coordinates
    fname = cs_to_fname.sel(coords).values[()]
    if fname is None:
        # grid cells that no file was assigned to still hold the initial None
        raise ValueError(f"no file registered for coordinates {coords}")

    dset = zarr.open(fname + spike_group, mode="r")
    dt = dset.attrs["dt"]

    # get single-neuron spiketimes, depending on storage format.
    if spike_group == "/spikes_as_list":
        # row 0 holds neuron ids, row 1 the matching spike time steps
        nids = dset[0, :]
        idx = np.where(nids == coords["nid"])[0]
        spike_times = dset[1, idx] * dt
    else:
        # one row per neuron; negative entries are dropped
        # (presumably they are padding -- TODO confirm against the writer)
        spike_times = dset[coords["nid"], :] * dt
        spike_times = spike_times[spike_times >= 0]

    # res will be a dict with a bunch of keys
    res_c = mrestimation(spike_times, plot_autocorrelation = plot_autocorrelation)

    # cli args (including the location of the settings file) are hard-coded in the wrapper
    # currently, settings are adjusted in ana/hdestimator_settings.yaml
    res_r = hde(spike_times)

    # some extra observables that are easy to get
    res_m_AR = {"m_AR": dset.attrs["m_AR"]}
    # NOTE(review): 1200 is hard-coded; presumably the recording duration in
    # seconds (the files carry a `duration_record` attribute) -- confirm.
    res_rate = {"rate": len(spike_times)/1200.}

    # merge the dicts, and return the coordinates so we can write to the dataset
    return {**res_r, **res_c, **res_m_AR, **res_rate}, coords



In [7]:
import humanize
# memory footprint of the coordinate -> filename map (it is shipped to every worker)
humanize.naturalsize(cs_to_fname.nbytes, binary=True)


'125.0 KiB'

In [8]:
# Enumerate every point of the coordinate grid spanned by res_dset. Each point
# becomes a dict {dimension name -> plain scalar}, ready to be used as indexer.
dim_names = list(res_dset.dims)
coord_axes = [res_dset.coords[d] for d in dim_names]
combinations = [
    {name: value.values[()] for name, value in zip(dim_names, point)}
    for point in itertools.product(*coord_axes)
]

# A single analysis (one neuron, one core) takes ~ 420 sec
log.info(f"{len(combinations)} parameter combinations to analyse")

# to test the analysis, we could now run it on a single neuron
# this_cs = dict(rep=0, k=100, m=0.995, nid=0)
# this_cs = combinations[0]
# this_res, this_cs = analyse_neuron(this_cs, spike_group, plot_autocorrelation=True)
# this_res.items()

2023-03-09 12:50:45,071 | INFO     | notebook     | 16000 parameter combinations to analyse


In [None]:
import logging

import functools
import itertools
import tempfile
import numpy as np
from tqdm import tqdm
from dask_jobqueue import SGECluster
from dask.distributed import Client, SSHCluster, LocalCluster, as_completed
from contextlib import nullcontext, ExitStack

# silence dask, configure this in ~/.config/dask/logging.yaml to be reliable
# https://docs.dask.org/en/latest/how-to/debug.html#logs
import logging
# NOTE(review): logging.basicConfig was already called in the first cell. A
# repeated call does nothing once the root logger has handlers (unless
# force=True is passed), so this line likely does not change the root level
# -- confirm.
logging.basicConfig(level=logging.ERROR)
# only warnings and errors from dask and its distributed scheduler / workers
logging.getLogger("dask").setLevel(logging.WARNING)
logging.getLogger("distributed").setLevel(logging.WARNING)
logging.getLogger("distributed.worker").setLevel(logging.WARNING)


def main(dask_client):
    """
    Run `analyse_neuron` for every entry of `combinations` on the dask cluster
    and collect the results in the global `res_dset`.

    Results are written into the dataset as the futures complete. If
    `dump_progress` is set, the dataset is saved to `output_filepath` after
    every `num_cores` completed analyses, and in any case once at the end.
    A failing save is logged but does not abort the run, because the results
    are still held in memory and returned.

    Parameters
    ----------
    dask_client :
        a connected dask.distributed Client

    Returns
    -------
    res_dset : xr.Dataset
        the (global) dataset, with one data variable per observable
    """

    # dispatch, reading in parallel may be faster
    futures = dask_client.map(analyse_neuron, combinations)

    log.info("futures dispatched")

    progress = tqdm(as_completed(futures), total=len(futures))
    for idx, future in enumerate(progress, start=1):
        # a dict of results observable -> scalar
        # and a dict of coordinates
        this_res, this_cs = future.result()

        # add datavariables to dset if they do not exist yet
        for k in this_res.keys():
            if k not in res_dset.data_vars:
                res_dset[k] = xr.DataArray(
                    np.nan,
                    coords=res_dset.coords,
                )

        # write all results to the dataset
        for k, v in this_res.items():
            try:
                res_dset[k].loc[this_cs] = v
            except (ValueError, TypeError):
                # numeric types only, anything that cannot be cast is stored as nan
                res_dset[k].loc[this_cs] = np.nan

        # analysis might be slow, lets save the progress
        if dump_progress and idx % num_cores == 0:
            try:
                res_dset.to_zarr(output_filepath, mode="w")
            except Exception:
                # best effort: keep analysing, but leave a trace of what went wrong
                log.exception(f"could not dump progress to {output_filepath}")

    try:
        res_dset.to_zarr(output_filepath, mode="w")
    except Exception:
        # do not raise: the caller still gets the in-memory dataset back
        log.exception(
            f"could not write results to {output_filepath}, "
            "they are only held in memory (returned dataset)"
        )

    return res_dset

# Start the cluster, connect a client and run the whole analysis.
# Both the cluster and the client are entered into one ExitStack, so both are
# torn down when the block exits.
with ExitStack() as stack:
    # init dask using a context manager to ensure proper cleanup
    # when using remote compute
    dask_cluster = stack.enter_context(
        # rudabeh
        SGECluster(
            # resources of one submitted job; presumably cores == processes gives
            # one single-threaded worker process per core -- confirm in the
            # dask_jobqueue documentation
            cores=32,
            memory="192GB",
            processes=32,
            job_extra_directives=["-pe mvapich2-sam 32"],
            log_directory="/scratch01.local/pspitzner/dask/logs",
            local_directory="/scratch01.local/pspitzner/dask/scratch",
            # log_directory="/scratch02.local/johannes/dask/logs",
            # local_directory="/scratch02.local/johannes/dask/scratch",
            # log_directory="/scratch03.local/lucas/dask/logs",
            # local_directory="/scratch03.local/lucas/dask/scratch",
            interface="ib0",
            walltime='24:00:00',
            # extend sys.path on each worker so that the analysis modules
            # (e.g. hdestimator_wrapper) can be imported there
            worker_extra_args=[
                '--preload \'import sys; sys.path.append("/data.nst/pspitzner/information_timescales/branching_network/ana/"); sys.path.append("/data.nst/pspitzner/information_timescales/branching_network/");\''
            ],
        )
        # local cluster
        # LocalCluster(local_directory=f"{tempfile.gettempdir()}/dask/")
    )
    # request workers for num_cores cores in total (see the config cell)
    dask_cluster.scale(cores=num_cores)
    dask_client = stack.enter_context(Client(dask_cluster))

    # main() fills the global res_dset and returns it
    xr_dset = main(dask_client)

In [None]:
# Reload what was actually written to disk and list the grid positions for which
# no single-timescale estimate was stored. The check runs on the loaded copy, so
# it verifies the saved file and also works in a fresh kernel.
res_dset_loaded = xr.open_zarr(output_filepath)
np.where(np.isnan(res_dset_loaded['tau_C_single']))

In [None]:
# # this cell shows how we can add dimensions that we might have forgotten about in a previous analysis
# res_dset_loaded =xr.open_zarr(output_filepath.replace(".zarr", "_backup.zarr"))

# # load dims from an old simulation file
# files = glob.glob(input_directory + f"*230308*.zarr")
# print(f"{len(files)} files matched the old pattern")

# new_dims = get_dims_of_file(files[0], dims_to_get=["k", "N"])

# # expand_dims: "If provided as a dict, then the keys are the new dimensions and the
# # values are either integers (giving the length of the new dimensions) or
# # sequence/ndarray (giving the coordinates of the new dimensions)
# new_dims = {k: [v] for k, v in new_dims.items()}
# res_dset_loaded = res_dset_loaded.expand_dims(new_dims)
# res_dset_loaded


In [None]:
# save back to disk
# res_dset_loaded.to_zarr(output_filepath.replace(".zarr", "_added_dims.zarr"), mode="w")

In [13]:
# this cell is for merging two existing datasets

# ds1 = xr.open_zarr("/data.nst/share/projects/information_timescales/branching_network/res_dset_bn_code_cleaned.zarr")
# ds2 = xr.open_zarr("/data.nst/share/projects/information_timescales/branching_network/res_dset_bn_code_cleaned_N=1000_to_merge.zarr")
# ds1 = ds1.merge(ds2)
# ds1.to_zarr("/data.nst/share/projects/information_timescales/branching_network/res_dset_bn_code_cleaned_merged.zarr", mode="w")

<xarray.backends.zarr.ZarrStore at 0x7f62c512e570>