In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from ntd.train_diffusion_model import init_dataset
from ntd.utils.plotting_utils import (
    FigureLayout,
    basic_plotting,
    plot_density,
    plot_overlapping_signal,
    plot_phase_line,
    plot_sd,
    polar_hist,
)
from ntd.utils.utils import (
    extract_sharp_wave_ripples,
    path_loader,
    permutation_test,
    phase_amplitude_coupling,
    phase_count_coupling,
    surrogate_dataset,
)

# Style sheet applied by every `plt.rc_context(..., fname=matplotlibrc_path)` below.
matplotlibrc_path = "../matplotlibrc"

# Shared figure-layout helper; its `get_rc(w, h)` output is passed as `rc=` to each plot cell.
FL = FigureLayout(
    base_font_size=6,
    scale_factor=3,
    width_grid=24,
    width_in_pt=1191,  # 3 * 397
)


In [None]:
# Placeholders: replace with the generated samples, the imputations and the run config
# (presumably torch tensors and the training config object - confirm against the data-loading cell).
samples, imputations, cfg = "TODO", "TODO", "TODO"


In [None]:
# Fail fast with a clear message if the placeholder cell above was not filled in
# (otherwise `.cpu()` on a str raises an opaque AttributeError).
_unset_inputs = [
    name
    for name, value in (("samples", samples), ("imputations", imputations), ("cfg", cfg))
    if isinstance(value, str) and value == "TODO"
]
if _unset_inputs:
    raise ValueError(
        f"Placeholder(s) {_unset_inputs} still set to 'TODO'; assign real values first."
    )

# Move model outputs to CPU and convert to numpy. `.detach()` makes `.numpy()` work
# even if the tensors still track gradients.
samples = samples.detach().cpu()
samples_numpy = samples.numpy()
imputations = imputations.detach().cpu()
imputations_numpy = imputations.numpy()

# Stack every dataset item's "signal" into one tensor per split.
train_dataset, test_dataset = init_dataset(cfg)
raw_signal = torch.stack([dic["signal"] for dic in train_dataset])
raw_signal_numpy = raw_signal.numpy()
raw_signal_test = torch.stack([dic["signal"] for dic in test_dataset])
raw_signal_test_numpy = raw_signal_test.numpy()

print(imputations.shape)
print(samples.shape)


In [None]:
# Plot one generated sample with its channels stacked vertically.
# A seeded generator keeps the chosen example stable under Restart & Run All.
example_rng = np.random.default_rng(seed=0)
rand_id = int(example_rng.integers(len(samples_numpy)))

signal_channels = 3
signal_colors = ["firebrick", "C1", "orangered"]
sig_color_one, sig_color_two, sig_color_three = signal_colors
# The colour list, the stacking offsets and the per-channel cells below assume exactly
# three channels; catch a mismatch here instead of failing later (or silently broadcasting).
assert samples_numpy.shape[1] == signal_channels, samples_numpy.shape

offset = 5.0  # vertical spacing between stacked channels, in standardized units
with plt.rc_context(rc=FL.get_rc(10, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots()
    plot_overlapping_signal(
        fig,
        ax,
        samples_numpy[rand_id] - offset * np.arange(signal_channels)[:, np.newaxis],
        colors=signal_colors,
    )
    basic_plotting(
        fig,
        ax,
        x_label="time (s)",
        y_axis_visibility=False,
        x_lim=(0, 1200),
        x_ticks=(0, 600, 1200),
        x_ticklabels=(0, 1, 2),  # 600 samples per second
    )
    fig.tight_layout()
    plt.show()


In [None]:
# Per-channel marginal amplitude densities: training data (grey) vs generated samples.
lower_marginal = -5.0
upper_marginal = 5.0
# Grid and ticks are identical for every panel, so build them once.
marginal_grid = np.linspace(lower_marginal, upper_marginal, 100)
marginal_ticks = [lower_marginal, 0, upper_marginal]
with plt.rc_context(rc=FL.get_rc(6, 6), fname=matplotlibrc_path):
    # squeeze=False keeps `axs` indexable even when signal_channels == 1.
    fig, axs = plt.subplots(signal_channels, 1, squeeze=False)
    axs = axs[:, 0]
    for i in range(signal_channels):
        for channel_values, density_color in (
            (raw_signal_numpy[:, i, :].flatten(), "grey"),
            (samples_numpy[:, i, :].flatten(), signal_colors[i]),
        ):
            plot_density(
                fig,
                axs[i],
                channel_values,
                x_range=marginal_grid,
                d_alpha=0.5,
                color=density_color,
            )
        # Only the bottom panel gets the x-label; the others get blank tick labels.
        if i == signal_channels - 1:
            panel_kwargs = {"x_label": "standardized voltage (a.u.)"}
        else:
            panel_kwargs = {"x_ticklabels": ["", "", ""]}
        basic_plotting(
            fig,
            axs[i],
            y_axis_visibility=False,
            x_lim=(lower_marginal, upper_marginal),
            x_ticks=marginal_ticks,
            **panel_kwargs,
        )
    fig.tight_layout()
    plt.show()



In [None]:
# Power spectra per channel: training data (grey) vs generated samples (channel colour).
# Aggregation across examples and the shaded quantile band are configured here.
agg_function = np.median
with_quantiles = True
lower_quantile = 0.1
upper_quantile = 0.9
alpha_boundary = 0.2

psd_fs = 600  # sampling rate passed to plot_sd (600 samples per second)
psd_nperseg = 1200  # segment length passed to plot_sd
psd_channel_colors = (sig_color_one, sig_color_two, sig_color_three)

for psd_channel, psd_color in enumerate(psd_channel_colors):
    with plt.rc_context(rc=FL.get_rc(6, 6), fname=matplotlibrc_path):
        fig, ax = plt.subplots()
        plot_sd(
            fig,
            ax,
            raw_signal_numpy[:, psd_channel, :],
            samples_numpy[:, psd_channel, :],
            fs=psd_fs,
            nperseg=psd_nperseg,
            color_one="grey",
            color_two=psd_color,
            agg_function=agg_function,
            with_quantiles=with_quantiles,
            lower_quantile=lower_quantile,
            upper_quantile=upper_quantile,
            alpha_boundary=alpha_boundary,
            # presumably trims the last 50 frequency bins - confirm in plot_sd
            x_ss=slice(0, -50),
        )
        # Only the first (left-most) panel carries the y-axis label.
        y_label_kwargs = {"y_label": "power (a.u.)"} if psd_channel == 0 else {}
        basic_plotting(fig, ax, x_label="freq (Hz)", y_ticks=[], **y_label_kwargs)
        fig.tight_layout()
        plt.show()



In [None]:
# Frequency windows (band edges, presumably in Hz) for the coupling analyses below.
delta_wn = [1.0, 4.0]
theta_wn = [4.0, 9.0]
spindle_wn = [6.0, 14.0]
high_wn = [100.0, 275.0]

# Surrogate of the training data, built channel by channel; it is the "surrogate"
# comparison in the coupling cells. Iterating over the channel axis covers every
# channel rather than a hard-coded three.
# NOTE(review): if `surrogate_dataset` is randomized, seed it here for reproducibility.
surro_dataset = np.zeros_like(raw_signal_numpy)
for surro_channel in range(raw_signal_numpy.shape[1]):
    surro_dataset[:, surro_channel, :] = surrogate_dataset(
        raw_signal_numpy[:, surro_channel, :]
    )


In [None]:
# Colour code shared by every coupling figure: real, generated, surrogate.
real_coup_color, gene_coup_color, surro_coup_color = "black", "red", "olivedrab"

# Angular tick positions every 45 degrees: k * pi / 4 for k = 0..7.
x_tick_angles = [k * np.pi / 4 for k in range(8)]


In [None]:
# SPINDLE PAC: spindle-band channel 0 vs high-band channel 1, for real data,
# a surrogate second channel, and generated samples.
spindle_real_pac, p_bins = phase_amplitude_coupling(
    raw_signal_numpy[:, 0, :], raw_signal_numpy[:, 1, :], spindle_wn, high_wn
)
# Midpoints of consecutive p_bins values (presumably bin edges -> bin centres).
mean_bin_phase = np.convolve(p_bins, np.array([0.5, 0.5]), mode="valid")
spindle_surrogate_pac, _p_bins = phase_amplitude_coupling(
    raw_signal_numpy[:, 0, :], surro_dataset[:, 1, :], spindle_wn, high_wn
)
spindle_generated_pac, _p_bins = phase_amplitude_coupling(
    samples_numpy[:, 0, :], samples_numpy[:, 1, :], spindle_wn, high_wn
)



In [None]:
# SPINDLE PAC Plot: real, generated and surrogate PAC curves on one polar axis.
with plt.rc_context(rc=FL.get_rc(4, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    for pac_curve, curve_color in (
        (spindle_real_pac, real_coup_color),
        (spindle_generated_pac, gene_coup_color),
        (spindle_surrogate_pac, surro_coup_color),
    ):
        plot_phase_line(fig, ax, mean_bin_phase, pac_curve, color=curve_color)
    basic_plotting(fig, ax, y_axis_visibility=False)
    # Fixed radial range so the three PAC figures share a scale.
    ax.set_ylim((0.025, 0.04))
    ax.set_xticks(x_tick_angles)
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])
    fig.tight_layout()
    plt.show()


In [None]:
# DELTA PAC: delta-band channel 0 vs high-band channel 1, for real data,
# a surrogate second channel, and generated samples.
delta_real_pac, p_bins = phase_amplitude_coupling(
    raw_signal_numpy[:, 0, :], raw_signal_numpy[:, 1, :], delta_wn, high_wn
)
# Midpoints of consecutive p_bins values (presumably bin edges -> bin centres).
mean_bin_phase = np.convolve(p_bins, np.array([0.5, 0.5]), mode="valid")
delta_surrogate_pac, _p_bins = phase_amplitude_coupling(
    raw_signal_numpy[:, 0, :], surro_dataset[:, 1, :], delta_wn, high_wn
)
delta_generated_pac, _p_bins = phase_amplitude_coupling(
    samples_numpy[:, 0, :], samples_numpy[:, 1, :], delta_wn, high_wn
)



In [None]:
# DELTA PAC Plot: real, generated and surrogate PAC curves on one polar axis.
with plt.rc_context(rc=FL.get_rc(4, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    for pac_curve, curve_color in (
        (delta_real_pac, real_coup_color),
        (delta_generated_pac, gene_coup_color),
        (delta_surrogate_pac, surro_coup_color),
    ):
        plot_phase_line(fig, ax, mean_bin_phase, pac_curve, color=curve_color)
    ax.set_ylim((0.025, 0.04))
    ax.set_xticks(x_tick_angles)
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])
    # NOTE(review): the spindle PAC plot calls basic_plotting before the axis tweaks
    # above; confirm basic_plotting does not override ylim/ticks.
    basic_plotting(fig, ax, y_axis_visibility=False)
    fig.tight_layout()
    plt.show()


In [None]:
# THETA PAC: theta-band channel 0 vs high-band channel 1, for real data,
# a surrogate second channel, and generated samples.
theta_real_pac, p_bins = phase_amplitude_coupling(
    raw_signal_numpy[:, 0, :],
    raw_signal_numpy[:, 1, :],
    theta_wn,
    high_wn,
)
# Midpoints of consecutive p_bins values (presumably bin edges -> bin centres).
mean_bin_phase = np.convolve(p_bins, np.array([0.5, 0.5]), mode="valid")
# Discard the bins from the surrogate/generated calls (as the spindle and delta cells do)
# so `p_bins` keeps describing the real-data computation it came from.
theta_surrogate_pac, _p_bins = phase_amplitude_coupling(
    raw_signal_numpy[:, 0, :],
    surro_dataset[:, 1, :],
    theta_wn,
    high_wn,
)
theta_generated_pac, _p_bins = phase_amplitude_coupling(
    samples_numpy[:, 0, :],
    samples_numpy[:, 1, :],
    theta_wn,
    high_wn,
)


In [None]:
# THETA PAC Plot: real, generated and surrogate PAC curves on one polar axis.
with plt.rc_context(rc=FL.get_rc(4, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    for pac_curve, curve_color in (
        (theta_real_pac, real_coup_color),
        (theta_generated_pac, gene_coup_color),
        (theta_surrogate_pac, surro_coup_color),
    ):
        plot_phase_line(fig, ax, mean_bin_phase, pac_curve, color=curve_color)
    ax.set_ylim((0.025, 0.04))
    ax.set_xticks(x_tick_angles)
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])
    # NOTE(review): the spindle PAC plot calls basic_plotting before the axis tweaks
    # above; confirm basic_plotting does not override ylim/ticks.
    basic_plotting(fig, ax, y_axis_visibility=False)
    fig.tight_layout()
    plt.show()


In [None]:
# Statistics passed to phase_count_coupling for generated samples only.
# TODO: document where these values were computed (presumably a reference signal's mean/std).
supervised_mean = 0.031925421208143234
supervised_std = 0.1935054510831833


# SPINDLE PCC: spindle-band channel 0 vs high-band channel 1, with `num_bins` phase bins.
num_bins = 21
spindle_real_pcc, p_bins = phase_count_coupling(
    raw_signal_numpy[:, 0, :], raw_signal_numpy[:, 1, :], spindle_wn, high_wn,
    num_bins=num_bins,
)
spindle_surrogate_pcc, _p_bins = phase_count_coupling(
    raw_signal_numpy[:, 0, :], surro_dataset[:, 1, :], spindle_wn, high_wn,
    num_bins=num_bins,
)
spindle_generated_pcc, _p_bins = phase_count_coupling(
    samples_numpy[:, 0, :], samples_numpy[:, 1, :], spindle_wn, high_wn,
    num_bins=num_bins,
    supervised_mean=supervised_mean,
    supervised_std=supervised_std,
)


In [None]:
# SPINDLE PCC Plot: surrogate, real and generated polar histograms (drawn in that order).
with plt.rc_context(rc=FL.get_rc(4, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    for pcc_counts, hist_color in (
        (spindle_surrogate_pcc, surro_coup_color),
        (spindle_real_pcc, real_coup_color),
        (spindle_generated_pcc, gene_coup_color),
    ):
        polar_hist(fig, ax, pcc_counts, p_bins, fillcolor=hist_color, spinecolor=hist_color)
    ax.set_xticks(x_tick_angles)
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])
    basic_plotting(fig, ax, y_axis_visibility=False)
    fig.tight_layout()
    plt.show()


In [None]:
# DELTA PCC: delta-band channel 0 vs high-band channel 1, with `num_bins` phase bins.
num_bins = 21
delta_real_pcc, p_bins = phase_count_coupling(
    raw_signal_numpy[:, 0, :], raw_signal_numpy[:, 1, :], delta_wn, high_wn,
    num_bins=num_bins,
)
delta_surrogate_pcc, _p_bins = phase_count_coupling(
    raw_signal_numpy[:, 0, :], surro_dataset[:, 1, :], delta_wn, high_wn,
    num_bins=num_bins,
)
# Only the generated-sample call receives the supervised statistics.
delta_generated_pcc, _p_bins = phase_count_coupling(
    samples_numpy[:, 0, :], samples_numpy[:, 1, :], delta_wn, high_wn,
    num_bins=num_bins,
    supervised_mean=supervised_mean,
    supervised_std=supervised_std,
)



In [None]:
# DELTA PCC Plot: surrogate, real and generated polar histograms (drawn in that order).
with plt.rc_context(rc=FL.get_rc(4, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    for pcc_counts, hist_color in (
        (delta_surrogate_pcc, surro_coup_color),
        (delta_real_pcc, real_coup_color),
        (delta_generated_pcc, gene_coup_color),
    ):
        polar_hist(fig, ax, pcc_counts, p_bins, fillcolor=hist_color, spinecolor=hist_color)
    ax.set_xticks(x_tick_angles)
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])
    basic_plotting(fig, ax, y_axis_visibility=False)
    fig.tight_layout()
    plt.show()


In [None]:
# THETA PCC: theta-band channel 0 vs high-band channel 1, with `num_bins` phase bins.
num_bins = 21
theta_real_pcc, p_bins = phase_count_coupling(
    raw_signal_numpy[:, 0, :], raw_signal_numpy[:, 1, :], theta_wn, high_wn,
    num_bins=num_bins,
)
theta_surrogate_pcc, _p_bins = phase_count_coupling(
    raw_signal_numpy[:, 0, :], surro_dataset[:, 1, :], theta_wn, high_wn,
    num_bins=num_bins,
)
# Only the generated-sample call receives the supervised statistics.
theta_generated_pcc, _p_bins = phase_count_coupling(
    samples_numpy[:, 0, :], samples_numpy[:, 1, :], theta_wn, high_wn,
    num_bins=num_bins,
    supervised_mean=supervised_mean,
    supervised_std=supervised_std,
)


In [None]:
# THETA PCC Plot: surrogate, real and generated polar histograms (drawn in that order).
with plt.rc_context(rc=FL.get_rc(4, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots(subplot_kw={"projection": "polar"})
    for pcc_counts, hist_color in (
        (theta_surrogate_pcc, surro_coup_color),
        (theta_real_pcc, real_coup_color),
        (theta_generated_pcc, gene_coup_color),
    ):
        polar_hist(fig, ax, pcc_counts, p_bins, fillcolor=hist_color, spinecolor=hist_color)
    ax.set_xticks(x_tick_angles)
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])
    basic_plotting(fig, ax, y_axis_visibility=False)
    fig.tight_layout()
    plt.show()


In [None]:
# SHARP WAVE RIPPLE IMPUTATION: spot-check randomly chosen imputed examples, plotting
# every row of `imputations_numpy[idx].T` as a line (presumably one line per channel).
# A seeded generator makes the selection reproducible under Restart & Run All.
n_imputation_examples = 10
imputation_rng = np.random.default_rng(seed=0)
example_ids = imputation_rng.permutation(len(imputations_numpy))[:n_imputation_examples]
for idx in example_ids:
    with plt.rc_context(fname=matplotlibrc_path):
        print(idx)
        fig, ax = plt.subplots(figsize=(20, 5))
        ax.set_ylim((-5, 5))
        ax.plot(imputations_numpy[idx, :, :].T)
        plt.show()
        # Close explicitly so the loop does not accumulate open figures.
        plt.close(fig)


In [None]:
print(imputations_numpy.shape)
print(raw_signal_test_numpy.shape)

# Sharp-wave-ripple extraction on channel 0 of the real test data and of the imputations.
# fuser_gap / min_length: presumably the merge gap and minimum event length in samples
# - confirm in extract_sharp_wave_ripples.
high_wn = [100.0, 275.0]
fuser_gap = 10
min_length = 11
real_swrs = extract_sharp_wave_ripples(
    raw_signal_test_numpy[:, 0, :],
    high_wn,
    fuser_gap=fuser_gap,
    min_length=min_length,
)
# Reuse the supervised statistics defined with the spindle PCC analysis instead of
# repeating the float literals, so the two cannot silently drift apart.
generated_swrs = extract_sharp_wave_ripples(
    imputations_numpy[:, 0, :],
    high_wn,
    fuser_gap=fuser_gap,
    min_length=min_length,
    supervised_mean=supervised_mean,
    supervised_std=supervised_std,
)



In [None]:
# Permutation test comparing SWRs detected in the real test data and in the imputations.
n_permutations = 1000
f1, perm_f1s, p_val = permutation_test(real_swrs, generated_swrs, n_permutations)

print(f1)
print(p_val)

# Density grid defaults to [0.25, 0.5] but widens so it always covers the permutation
# F1 scores and the observed F1.
f1_grid_lo = min(0.25, float(np.min(perm_f1s)), float(f1))
f1_grid_hi = max(0.5, float(np.max(perm_f1s)), float(f1))
with plt.rc_context(rc=FL.get_rc(6, 2.3), fname=matplotlibrc_path):
    fig, ax = plt.subplots()
    plot_density(
        fig,
        ax,
        perm_f1s,
        x_range=np.linspace(f1_grid_lo, f1_grid_hi, 100),
        color=surro_coup_color,
    )
    # axvline spans the whole y-range, unlike a fixed height that can be too short
    # or stretch the y-axis.
    ax.axvline(f1, color=gene_coup_color)
    basic_plotting(fig, ax, y_axis_visibility=False, x_label="F1-score")
    fig.tight_layout()
    plt.show()
