In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr

import ecephys as ec
import findlay2025a as f25a
import wisc_ecephys_tools as wet

  warn(


In [2]:
# Project handle plus the experiment selector shared by every later cell.
nbsh = wet.get_sglx_project("seahorse")
experiment = wet.rats.constants.SleepDeprivationExperiments.NOD

# Event types to load; each has its own "<evt_type>_zscored_peths.zarr" per subject.
evt_types = ["spw", "dspk"]

In [3]:
# Per-cluster metrics joined with cluster-quality labels, restricted to the NREM
# rows of the selected experiment. Clusters without a Petersen cell-type call
# are labelled "untyped" so they form their own group downstream.
cq = pd.read_parquet(nbsh.get_project_file("cluster_quality.pqt"))
ctm = (
    pd.read_parquet(nbsh.get_project_file("aggregated_cell_metrics.pqt"))
    # validate= raises if `cq` has duplicate keys, which would otherwise
    # silently duplicate metric rows.
    .merge(
        cq,
        on=["subject", "experiment", "probe", "cluster_id"],
        how="left",
        validate="many_to_one",
    )
    .loc[lambda d: (d["state"] == "NREM") & (d["experiment"] == experiment)]
    # assign() returns a new frame instead of writing into a filtered slice,
    # which is what triggers SettingWithCopyWarning.
    .assign(petersen_cell_type=lambda d: d["petersen_cell_type"].fillna("untyped"))
)

In [4]:
# Single-unit subset of `ctm`. `ctm` already contains only NREM rows, so no
# further state filter is needed here.
df = ctm[ctm["max_quality"].isin(["sua_moderate", "sua_conservative"])]

In [5]:
# Acronym variants merged before grouping. The patterns are anchored (^...$),
# so each one rewrites only an exact label.
ACRONYM_REPLACEMENTS = {
    "^mPPC$": "PPC",
    "^V$": "VO",
    "^CA3-DG-dl$": "CA3-DG",
    "^CA3-DG-vl$": "CA3-DG",
}
# `max_quality` labels kept as single units.
SUA_QUALITIES = ["sua_moderate", "sua_conservative"]


def _annotate_clusters(zpeths, cluster_info):
    """Attach per-cluster metadata as `cluster_id` coordinates and normalize acronyms.

    Adds `max_quality`, `petersen_cell_type` and `waxholm_structure` (taken from
    `cluster_info["acronym"]`). Then applies ACRONYM_REPLACEMENTS to both the
    PETH's own `acronym` coordinate and the new `waxholm_structure` coordinate.

    Args:
        zpeths: PETH dataset with a `cluster_id` dimension and an `acronym` coord.
        cluster_info: Metrics for one subject, indexed by `cluster_id`.

    Returns:
        A new dataset with the added and normalized coordinates.
    """
    # Single lookup, aligned to the PETH's cluster order.
    # NOTE(review): assumes cluster_id is unique within a subject in `ctm`; if a
    # subject contributes several probes with overlapping ids, this would not
    # align one row per cluster — confirm.
    info = cluster_info.loc[
        zpeths.cluster_id.values, ["max_quality", "petersen_cell_type", "acronym"]
    ]
    zpeths = zpeths.assign_coords(
        max_quality=("cluster_id", info["max_quality"].values),
        petersen_cell_type=("cluster_id", info["petersen_cell_type"].values),
        waxholm_structure=("cluster_id", info["acronym"].values),
    )
    for pattern, replacement in ACRONYM_REPLACEMENTS.items():
        zpeths = zpeths.assign_coords(
            acronym=(
                "cluster_id",
                zpeths.acronym.str.replace(pattern, replacement, regex=True).values,
            ),
            waxholm_structure=(
                "cluster_id",
                zpeths.waxholm_structure.str.replace(
                    pattern, replacement, regex=True
                ).values,
            ),
        )
    return zpeths


def _sua_state_means(zpeths, evt_type):
    """Average single-unit PETHs over events within each state.

    The result gains a length-1 `trigger` dimension labelled `evt_type`, so the
    per-trigger results can be concatenated along `trigger`.
    """
    is_sua = zpeths.max_quality.isin(SUA_QUALITIES)
    return (
        zpeths.sel(cluster_id=is_sua)
        .groupby("state")
        .mean(dim="event")
        .expand_dims(trigger=[evt_type])
    )


per_subject_peths = []
for subject, _probes in f25a.units.get_nod_sortings():
    cluster_info = ctm[ctm["subject"] == subject].set_index("cluster_id")
    per_trigger_peths = []
    for evt_type in evt_types:
        print(f"Processing {evt_type} for {subject}...")
        zscored_peths_file = nbsh.get_experiment_subject_file(
            experiment, subject, f"{evt_type}_zscored_peths.zarr"
        )
        # Read everything into memory, then release the store.
        with xr.open_zarr(zscored_peths_file) as store:
            zpeths = store.load()
        zpeths = _annotate_clusters(zpeths, cluster_info)
        per_trigger_peths.append(_sua_state_means(zpeths, evt_type))
    per_subject_peths.append(xr.concat(per_trigger_peths, dim="trigger"))
subject_peths = xr.concat(per_subject_peths, dim="cluster_id")

Processing spw for CNPIX2-Segundo...
Processing dspk for CNPIX2-Segundo...
Processing spw for CNPIX3-Valentino...
Processing dspk for CNPIX3-Valentino...
Processing spw for CNPIX4-Doppio...
Processing dspk for CNPIX4-Doppio...
Processing spw for CNPIX5-Alessandro...
Processing dspk for CNPIX5-Alessandro...
Processing spw for CNPIX6-Eugene...
Processing dspk for CNPIX6-Eugene...
Processing spw for CNPIX8-Allan...
Processing dspk for CNPIX8-Allan...
Processing spw for CNPIX9-Luigi...
Processing dspk for CNPIX9-Luigi...
Processing spw for CNPIX10-Charles...
Processing dspk for CNPIX10-Charles...
Processing spw for CNPIX11-Adrian...
Processing dspk for CNPIX11-Adrian...
Processing spw for CNPIX12-Santiago...
Processing dspk for CNPIX12-Santiago...
Processing spw for CNPIX14-Francis...
Processing dspk for CNPIX14-Francis...
Processing spw for CNPIX15-Claude...
Processing dspk for CNPIX15-Claude...
Processing spw for CNPIX17-Hans...
Processing dspk for CNPIX17-Hans...
Processing spw for CNPI

In [6]:
# Every `waxholm_structure` label expected in `semd_ws`, in the author's
# order. The following sanity-check cell compares this list against the labels
# actually present.
structures = [
    "CA1", "CA3", "DG", "Cg1", "EC",
    "HR", "IL", "M1", "M2", "MO",
    "PPC", "PrL", "V1", "V2", "VO",
    "CL", "DLG", "LD", "LP", "MG",
    "PCN", "Po", "VM", "VPM", "VPN",
    "???", "BRF", "CLA", "CPu", "HY",
    "NAc", "NAc-c", "NAc-sh", "OB", "SN",
    "Str", "Tel", "ZI", "ZI-c", "ac",
    "bsc", "cc-ec-cing-dwm", "cfp", "eml", "ml",
]

In [7]:
# NREM slice of the per-cluster PETHs, summarized by ec.units.get_peths_sem
# with variance taken over `cluster_id` within each
# (waxholm_structure, petersen_cell_type) group. The plotting cell reads the
# "mean" and "sem" variables of the result.
nrem_peth = subject_peths["peth"].sel(state="NREM")
semd_ws = ec.units.get_peths_sem(
    nrem_peth,
    group_variance_by=["waxholm_structure", "petersen_cell_type"],
    variance_dim="cluster_id",
)

In [8]:
# Sanity check: the labels present in `semd_ws` must match `structures`
# exactly, in both directions. Both lists are printed (and should be empty);
# the assert makes a mismatch stop Restart & Run All instead of passing silently.
present_structures = semd_ws.waxholm_structure.values
unexpected_structures = [s for s in present_structures if s not in structures]
absent_structures = [s for s in structures if s not in present_structures]
print(unexpected_structures)
print(absent_structures)
assert not unexpected_structures and not absent_structures, (
    f"Structure mismatch: unexpected={unexpected_structures}, "
    f"absent={absent_structures}"
)

[]
[]


In [9]:
# Per-cell-type colors as 0-255 RGB triplets.
pyramidal_rgb = (204, 51, 51)
narrow_rgb = (51, 51, 204)
wide_rgb = (51, 204, 204)
untyped_rgb = (128, 128, 128)


def _to_unit_rgb(rgb):
    """Scale a 0-255 RGB triplet to the 0-1 floats matplotlib expects."""
    return tuple(channel / 255 for channel in rgb)


# petersen_cell_type -> matplotlib color. Insertion order is also the order in
# which the plotting cell iterates (and therefore draws) the cell types.
color_map = {
    cell_type: _to_unit_rgb(rgb)
    for cell_type, rgb in [
        ("pyramidal", pyramidal_rgb),
        ("narrow interneuron", narrow_rgb),
        ("wide interneuron", wide_rgb),
        ("untyped", untyped_rgb),
    ]
}

In [None]:
# True: one tall figure with a row per structure.
# False: one figure (and one SVG) per structure.
megafigure = False

# Column order of every figure; each entry must be a `trigger` value in semd_ws.
triggers = [
    "spw",
    "dspk",
]

structure_names = semd_ws.waxholm_structure.values
nrows = len(structure_names)
ncols = len(triggers)
height = 4
width = 4


def _plot_peth_panel(ax, semd, structure, trigger, colors):
    """Draw mean ± SEM PETHs of every cell type in `colors` for one structure/trigger.

    Args:
        ax: Axes to draw on.
        semd: Dataset with "mean" and "sem" variables and a `time` coordinate,
            indexable by waxholm_structure, trigger and petersen_cell_type.
        structure: waxholm_structure label to select.
        trigger: trigger label to select.
        colors: Mapping of petersen_cell_type -> matplotlib color; its order is
            the drawing order.
    """
    for cell_type, color in colors.items():
        data = semd.sel(
            waxholm_structure=structure,
            trigger=trigger,
            petersen_cell_type=cell_type,
        )
        time = data.time.values
        mean = data["mean"].values
        sem = data["sem"].values
        # Mean as a solid line, SEM as a shaded band around it.
        ax.plot(time, mean, color=color, linewidth=2, label=cell_type)
        ax.fill_between(time, mean - sem, mean + sem, alpha=0.3, color=color)
    # Reference lines at time zero and at zero response.
    ax.axvline(x=0, color="black", linestyle="--", alpha=0.7)
    ax.axhline(y=0, color="black", linestyle="-", linewidth=0.5, alpha=0.7)


if megafigure:
    # squeeze=False keeps `axes` 2-D even when there is one row or one column.
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * width, nrows * height),
        sharex=True,
        squeeze=False,
    )

for i, structure in enumerate(structure_names):
    if megafigure:
        row_axes = axes[i]
    else:
        fig, axes = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * width, height), squeeze=False
        )
        row_axes = axes[0]

    for j, trigger in enumerate(triggers):
        ax = row_axes[j]
        _plot_peth_panel(ax, semd_ws, structure, trigger, color_map)

        # In the megafigure, titles go on the top row and x labels on the bottom row only.
        if (not megafigure) or (i == 0):
            ax.set_title(f"{trigger}")
        if j == 0:
            ax.set_ylabel(f"{structure}")
        if (not megafigure) or (i == nrows - 1):
            ax.set_xlabel("Time")

    if not megafigure:
        fig.tight_layout()
        savefile = nbsh.get_project_file(f"figures/spw_dspk_sua_peths/{structure}.svg")
        savefile.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savefile)
        plt.show()
        # Release each per-structure figure so dozens of them do not accumulate.
        plt.close(fig)

if megafigure:
    fig.tight_layout()
    print("Saving figure...")
    savefile = nbsh.get_project_file("figures/spw_dspk_sua_peths.svg")
    # On a fresh project the figures/ directory may not exist yet.
    savefile.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(savefile)
    plt.show()
    plt.close(fig)