In [1]:
from typing import Optional

import numpy as np
import pandas as pd
import xarray as xr

from findlay2025a import core
from findlay2025a.constants import Experiments as Exps

In [2]:
# Coarse ROI label -> the fine-grained structure acronyms that belong to it.
_ROI_SUBREGIONS = {
    "CX": ["mPPC", "Cx", "PPC", "V2", "V1", "???"],
    "CA1-so": ["CA1-so"],
    "CA1-sp": ["CA1-sp"],
    "CA1-sr": ["CA1-sr"],
    "CA1-slm": ["CA1-slm"],
    "DG": [
        "DG-dl-ml",
        "DG-dl-gcl",
        "DG-dl-pol",
        "DG-vl-ml",
        "DG-vl-gcl",
        "DG-vl-pol",
        "DG-pol",
    ],
    "CA3": ["CA3-sp"],
}

# Inverted once so each lookup is a single dict access; `get_roi` is applied
# row-by-row to whole DataFrame columns. Acronyms are unique across ROIs, so
# the inversion is unambiguous.
_ACRONYM_TO_ROI = {
    acronym: roi
    for roi, subregions in _ROI_SUBREGIONS.items()
    for acronym in subregions
}


def get_roi(acronym: str) -> Optional[str]:
    """Map a fine-grained structure acronym to its coarse ROI label.

    Parameters
    ----------
    acronym
        Structure acronym, e.g. ``"V1"`` or ``"DG-dl-gcl"``.

    Returns
    -------
    str or None
        One of ``"CX"``, ``"CA1-so"``, ``"CA1-sp"``, ``"CA1-sr"``,
        ``"CA1-slm"``, ``"DG"`` or ``"CA3"``; ``None`` when the acronym does
        not belong to any ROI (including missing values such as NaN).
    """
    return _ACRONYM_TO_ROI.get(acronym)


def _aggregate_psds(
    region, drop_cols=["x", "y", "structure", "channel", "notes"]
) -> pd.DataFrame:
    df = pd.DataFrame()
    freqs = None
    for subject, experiment in core.yield_subject_name_experiment_pairs():
        if region == "cx":
            ds = xr.load_dataset(core.get_cortical_psds_file(subject, experiment)).sel(
                frequency=slice(0.5, 100)
            )
        elif region == "hipp":
            ds = xr.load_dataset(
                core.get_hippocampal_psds_file(subject, experiment)
            ).sel(frequency=slice(1, 300))
        else:
            raise ValueError(f"Invalid region: {region}")

        # All subjects should use the same EXACT frequency bins
        if freqs is None:
            freqs = ds["frequency"].values
        elif np.allclose(ds["frequency"].values, freqs):
            ds["frequency"] = freqs
        else:
            raise ValueError(f"Frequency mismatch for {subject}: {experiment}.")

        _df = ds.to_dataframe().reset_index()
        _df["subject"] = subject
        _df["experiment"] = experiment

        # Label each channel with its coarse ROI
        _df["roi"] = _df["acronym"].apply(get_roi)
        _df = _df.loc[_df["roi"].notnull()].reset_index(drop=True)

        if region == "hipp":
            # Some deep cortical channels may appear in the hippocampal_psds file. Drop them.
            _df = _df.loc[_df["roi"] != "CX"]
        elif region == "cx":
            # Some non-cortical channels may appear in the cortical_psds file. Drop them.
            _df = _df.loc[_df["roi"] == "CX"]

        df = pd.concat([df, _df], axis=0, ignore_index=True)

    # Drop columns that are not needed for the analysis, just to save memory
    df.drop(columns=drop_cols, inplace=True)

    return df.reset_index(drop=True)


def get_psd_contrast_as_percent_change(df: pd.DataFrame) -> pd.DataFrame:
    """Compute percent-change contrasts between condition columns, in long form.

    Parameters
    ----------
    df
        Wide-form frame with one column per condition. Its index levels must
        be named ``subject``, ``experiment``, ``frequency`` and ``roi``; they
        become the identifier columns of the result.

    Returns
    -------
    pd.DataFrame
        Long-form frame with columns ``subject``, ``experiment``,
        ``frequency``, ``roi``, ``contrast`` and ``percent_change``.
    """
    # contrast name -> (minuend, subtrahend, reference). Each contrast is
    # (minuend - subtrahend) / reference * 100. Note that the two "decline"
    # contrasts are positive when the value goes DOWN.
    contrast_specs = {
        "nrem_rebound": (
            "early_rec_nrem",
            "early_rec_nrem_match",
            "early_rec_nrem_match",
        ),
        "nrem_rec_decline": ("early_rec_nrem", "late_rec_nrem", "early_rec_nrem"),
        "ext_wake_incline": ("late_ext_wake", "early_ext_wake", "early_ext_wake"),
        "ext_incline": ("late_ext", "early_ext", "early_ext"),
        "nrem_bsl_decline": (
            "early_bsl_nrem",
            "early_rec_nrem_match",
            "early_rec_nrem_match",
        ),
        "nrem_surge": ("early_rec_nrem", "early_bsl_nrem", "early_bsl_nrem"),
    }

    wide = pd.DataFrame(index=df.index)
    for contrast, (minuend, subtrahend, reference) in contrast_specs.items():
        wide[contrast] = (df[minuend] - df[subtrahend]) / df[reference] * 100

    # Move the index levels into columns, then go from wide to long.
    id_columns = ["subject", "experiment", "frequency", "roi"]
    return wide.reset_index().melt(
        id_vars=id_columns,
        var_name="contrast",
        value_name="percent_change",
    )


def get_exclusions():
    """Return the (subject, experiment) pairs whose data must be excluded.

    Returns
    -------
    tuple[list, list]
        ``(lbsl_exclusions, ewk_exclusions)``, each a list of
        ``(subject, experiment)`` tuples.
    """
    eugene_nod = ("CNPIX6-Eugene", Exps.NOD)

    # Excluded because some data is missing, so our E_BSL may actually come a
    # bit later than the true E_BSL. We are being conservative here.
    ewk_exclusions = [
        ("CNPIX5-Alessandro", Exps.NOD),
        eugene_nod,
        ("CNPIX10-Charles", Exps.NOD),
        ("CNPIX11-Adrian", Exps.NOD),
        ("CNPIX11-Adrian", Exps.COW),
        ("CNPIX17-Hans", Exps.COW),
    ]

    # Data is missing during what would be Eugene's REM/NREM circadian match.
    lbsl_exclusions = [eugene_nod]

    return lbsl_exclusions, ewk_exclusions


def apply_condition_exclusions(df: pd.DataFrame) -> pd.DataFrame:
    """Drop the condition rows that ``get_exclusions`` marks as bad data.

    Parameters
    ----------
    df
        Long-form frame with ``subject``, ``experiment`` and ``condition``
        columns.

    Returns
    -------
    pd.DataFrame
        ``df`` without the excluded rows, re-indexed 0..n-1.
    """
    lbsl_exclusions, ewk_exclusions = get_exclusions()

    # Each exclusion list is paired with the conditions it invalidates.
    rules = (
        (
            ewk_exclusions,
            [
                "early_ext_wake",
                "late_ext_wake",
                "ext_wake",
                "early_ext",
                "late_ext",
            ],
        ),
        (lbsl_exclusions, ["early_rec_nrem_match"]),
    )

    kept = df
    for pairs, bad_conditions in rules:
        for subject, experiment in pairs:
            is_bad = (
                (kept["subject"] == subject)
                & (kept["experiment"] == experiment)
                & kept["condition"].isin(bad_conditions)
            )
            kept = kept[~is_bad]
    return kept.reset_index(drop=True)


def apply_contrast_exclusions(dfc: pd.DataFrame) -> pd.DataFrame:
    """Drop the contrast rows that were computed from excluded conditions.

    This is the contrast-level counterpart of ``apply_condition_exclusions``:
    whenever a condition is excluded for a (subject, experiment) pair, every
    contrast derived from that condition is excluded for the same pair.

    Parameters
    ----------
    dfc
        Long-form frame with ``subject``, ``experiment`` and ``contrast``
        columns, as returned by ``get_psd_contrast_as_percent_change``.

    Returns
    -------
    pd.DataFrame
        ``dfc`` without the excluded rows, re-indexed 0..n-1.
    """
    lbsl_exclusions, ewk_exclusions = get_exclusions()

    # Contrasts built from the extended-wake / extension conditions.
    ext_contrasts = ["ext_wake_incline", "ext_incline"]

    # Contrasts built from `early_rec_nrem_match`. `nrem_bsl_decline` uses that
    # condition as both subtrahend and reference, so it must be dropped along
    # with `nrem_rebound`; otherwise the excluded data leaks into the output.
    # `rem_rebound` is not produced by `get_psd_contrast_as_percent_change`
    # but is kept here in case a caller supplies it.
    match_contrasts = ["nrem_rebound", "rem_rebound", "nrem_bsl_decline"]

    rules = (
        (ewk_exclusions, ext_contrasts),
        (lbsl_exclusions, match_contrasts),
    )
    for exclusions, bad_contrasts in rules:
        for sub, exp in exclusions:
            drop = (
                (dfc["subject"] == sub)
                & (dfc["experiment"] == exp)
                & (dfc["contrast"].isin(bad_contrasts))
            )
            dfc = dfc[~drop]
    return dfc.reset_index(drop=True)


def aggregate_roi_psds(region, apply_exclusions=True):
    """Build per-ROI PSDs and percent-change contrasts for one region.

    Parameters
    ----------
    region
        Passed straight through to ``_aggregate_psds``.
    apply_exclusions
        When true, filter both outputs with ``apply_condition_exclusions`` and
        ``apply_contrast_exclusions`` respectively.

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        ``(psds, contrasts)``. ``psds`` is long form with a ``condition``
        column and a log10-transformed ``psd`` column. ``contrasts`` is long
        form with ``contrast`` and ``percent_change`` columns, computed from
        the untransformed ROI means.
    """
    keys = ["subject", "experiment", "frequency", "roi"]

    # Average the PSD across all channels in an ROI. The result is wide form:
    # one column per condition, indexed by `keys`.
    roi_means = _aggregate_psds(region).groupby(keys).mean(numeric_only=True)

    # Contrasts are taken on the linear values, before the log transform below.
    dfc = get_psd_contrast_as_percent_change(roi_means)

    # Convert from wide to long
    psd_long = roi_means.reset_index().melt(
        id_vars=keys,
        var_name="condition",
        value_name="psd",
    )

    # This is important for doing stats. Dramatically improves model assumptions.
    psd_long["psd"] = np.log10(psd_long["psd"])

    if not apply_exclusions:
        return psd_long, dfc

    psd_long = apply_condition_exclusions(psd_long)
    dfc = apply_contrast_exclusions(dfc)
    return psd_long, dfc


In [3]:
# Aggregate cortex and hippocampus separately, then stack them.
cx_psds, cx_pct_changes = aggregate_roi_psds("cx")
hc_psds, hc_pct_changes = aggregate_roi_psds("hipp")

psds = pd.concat([cx_psds, hc_psds], ignore_index=True)
pct_changes = pd.concat([cx_pct_changes, hc_pct_changes], ignore_index=True)

# Conditions and contrasts that are written to disk.
conditions = [
    "early_bsl_nrem",
    "early_rec_nrem_match",
    "early_ext_wake",
    "late_ext_wake",
    "early_rec_nrem",
    "late_rec_nrem",
    "early_ext",
    "late_ext",
]
contrasts = [
    "nrem_rebound",
    "nrem_surge",
    "nrem_rec_decline",
    "nrem_bsl_decline",
    "ext_wake_incline",
    "ext_incline",
]

nb = core.get_project("seahorse")

# Only rows from the NOD experiment are exported.
is_exported_psd = (psds["experiment"] == Exps.NOD) & psds["condition"].isin(
    conditions
)
psds.loc[is_exported_psd].to_parquet(nb.get_project_file("psds_by_condition.pqt"))

is_exported_contrast = (pct_changes["experiment"] == Exps.NOD) & pct_changes[
    "contrast"
].isin(contrasts)
pct_changes.loc[is_exported_contrast].to_parquet(
    nb.get_project_file("psd_pct_changes.pqt")
)

