In [1]:
import pathlib
import shutil

import xarray as xr

import ecephys as ec
import findlay2025a as f25a
import wisc_ecephys_tools as wet

In [2]:
def overwrite_zarr(data: "xr.DataArray", filepath: pathlib.Path):
    """Write ``data`` to a Zarr store at ``filepath``, replacing anything already there.

    Parameters
    ----------
    data : xr.DataArray
        Array to write; only its ``to_zarr`` method is used.
    filepath : pathlib.Path or str
        Destination of the Zarr store. An existing directory tree at this path
        is deleted first; an existing plain file or symlink is unlinked (the
        symlink's target is left untouched).

    Returns
    -------
    Whatever ``data.to_zarr(filepath)`` returns.

    Raises
    ------
    FileExistsError
        If ``filepath`` still exists after the removal attempt.
    """
    filepath = pathlib.Path(filepath)
    # shutil.rmtree refuses plain files and symlinks (even ones pointing at a
    # directory), and a broken symlink reports exists() == False, so handle
    # those cases with unlink before falling back to rmtree for real directories.
    if filepath.is_symlink() or filepath.is_file():
        filepath.unlink()
    elif filepath.exists():
        shutil.rmtree(filepath)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if filepath.exists():
        raise FileExistsError(f"Failed to remove extant {filepath}.")
    return data.to_zarr(filepath)

In [None]:
# Configuration for the PETH export below.
nbsh = wet.get_sglx_project("seahorse")
experiment = wet.rats.constants.SleepDeprivationExperiments.NOD
bin_size = 0.025  # PETH bin width (presumably seconds, matching event times)
zscore_window_size = 10.0  # half-width of the peri-event window computed / z-scored
kept_window_size = 1.0  # half-width of the window actually written to disk
kept_time = slice(-kept_window_size, kept_window_size)
kept_states = ["Wake", "NREM", "IS", "REM"]  # events in any other state are dropped

# get_threshold_kwargs() takes no per-subject argument, so fetch the MUA
# thresholds once instead of on every subject iteration.
threshold_kwargs = f25a.units.get_threshold_kwargs()["mua"]

for subject, probes in f25a.units.get_nod_sortings():
    print(f"Doing {subject}...")

    # Event times per event type; the dict keys become the output filename prefixes.
    spws = f25a.sharp_waves.read_spws(subject, experiment)
    ds = f25a.dentate_spikes.read_dspks(subject, experiment)
    ripples = f25a.ripples.read_ripples(subject, experiment)
    cx_sps, cx_trghs = f25a.spindles.read_spindles(
        subject, experiment, region="cortical"
    )
    evt_trains = {
        "spw": spws["pk_time"].values,
        "ripple": ripples["pk_time"].values,
        "dspk": ds["peak_time"].values,
        "cortical_spindle": cx_sps["Peak"].values,
        # NOTE(review): used as-is (no `.values`); presumably already an array — confirm.
        "cortical_spindle_trough": cx_trghs,
    }

    # Label every event with its state from the "full_conservative" hypnogram,
    # after dropping artifact states and "NoData".
    hgs = f25a.hypnograms.load_statistical_condition_hypnograms(
        experiment, subject, include_full_conservative=True
    )
    hg = hgs.pop("full_conservative").drop_states(
        f25a.core.ARTIFACT_STATES + ["NoData"]
    )
    evt_states = {k: hg.get_states(v) for k, v in evt_trains.items()}

    mps = f25a.units.load_nod_multiprobe_sorting(subject, **threshold_kwargs)
    spike_trains = mps.get_cluster_trains()

    for evt_type, evt_times in evt_trains.items():
        print(f"Computing {evt_type} peths...")
        peths = ec.units.get_peths_from_trains(
            spike_trains,
            evt_times,
            event_labels=evt_states[evt_type],
            train_keys="cluster_id",
            property_frame=mps.properties,
            property_names=["acronym"],
            pre_time=zscore_window_size,
            post_time=zscore_window_size,
            bin_size=bin_size,
            return_fr=False,
        ).rename("peth")
        peths = peths.rename({"event_type": "state"})
        peths = peths.sel(event=peths.state.isin(kept_states))

        # Z-score on the full +/- zscore_window_size PETH, then keep only the
        # central +/- kept_window_size portion.
        zscored_peths = ec.units.zscore_peths_by_peri_event_window(peths).sel(
            time=kept_time
        )
        zscored_peths_file = nbsh.get_experiment_subject_file(
            experiment, subject, f"{evt_type}_zscored_peths.zarr"
        )
        overwrite_zarr(zscored_peths, zscored_peths_file)

        # NOTE(review): with return_fr=False these are presumably per-bin spike
        # counts; the uint16 cast assumes they are non-negative, NaN-free and
        # below 65536 — confirm, since the cast does not check.
        peths = peths.sel(time=kept_time).astype("uint16")
        peths_file = nbsh.get_experiment_subject_file(
            experiment, subject, f"{evt_type}_peths.zarr"
        )
        overwrite_zarr(peths, peths_file)