In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from ntd.train_diffusion_model import init_dataset
from ntd.utils.plotting_utils import (
    FigureLayout,
    basic_plotting,
    plot_overlapping_signal,
    plot_sd,
)
from ntd.utils.utils import path_loader

# Shared figure layout: total width 3 * 397 pt, divided into 24 grid units
# (the final decoding figure below spans all 24 units).
FL = FigureLayout(
    width_in_pt=3 * 397,
    width_grid=24,
    base_font_size=6,
    scale_factor=3,
)

# Style sheet passed as `fname=` to every plt.rc_context in this notebook.
matplotlibrc_path = "../matplotlibrc"


In [None]:
# Placeholders that must be replaced before running the notebook:
#   cfg               -- the run configuration; it is passed to init_dataset and
#                        must expose cfg.network.signal_channel.
#   full_impu_signals -- the imputed test signals; they need `.shape` and `.numpy()`
#                        and must have the same shape as the stacked test signals.
cfg, full_impu_signals = "TODO", "TODO"


In [None]:
# For plots from paper, condition on first half of the channels of P07
train_data, test_data = init_dataset(cfg)
signal_channel = cfg.network.signal_channel

# The lower half of the channels is conditioned on; the remainder is imputed.
num_cond_channel = signal_channel // 2
num_impu_channel = signal_channel - num_cond_channel

# Four highlighted channels (patient 07), sorted, each paired with its own color.
plot_channel = np.sort(np.array([35, 39, 45, 47]))
signal_colors = ["firebrick", "darkorange", "C1", "orangered"]
(channel_one, channel_two, channel_three, channel_four) = plot_channel
(sig_color_one, sig_color_two, sig_color_three, sig_color_four) = signal_colors

print(num_impu_channel)
print(num_cond_channel)
print(*plot_channel)


In [None]:
# Channel indices 0 .. num_cond_channel-1 are the conditioning channels.
cond_list = [*range(num_cond_channel)]

# Fetch every sample of the test split from its underlying dataset
# (test_data looks like a torch Subset: `.dataset` + `.indices`).
test_dicts = [test_data.dataset[idx] for idx in test_data.indices]
full_test_signals, full_test_cond = (
    torch.stack([sample[key] for sample in test_dicts]) for key in ("signal", "cond")
)
# Labels are concatenated, not stacked (presumably each sample's label is 1-D).
full_test_labels = torch.cat([sample["label"] for sample in test_dicts])
del test_dicts


In [None]:
# Sanity check: the imputed signals must line up one-to-one with the test signals.
print(full_test_signals.shape)
print(full_impu_signals.shape)
assert full_test_signals.shape == full_impu_signals.shape, (
    f"shape mismatch: test signals {tuple(full_test_signals.shape)} vs "
    f"imputed signals {tuple(full_impu_signals.shape)}"
)

# NumPy versions of both signal sets, used by the spectral plots below.
full_test_signals_numpy = full_test_signals.numpy()
full_impu_signals_numpy = full_impu_signals.numpy()


In [None]:
# Draw one random test example and show the highlighted channels stacked with a
# vertical offset: first the ground truth (grey), then the imputation (colored).
# NOTE: np.random is not seeded here, so each run shows a different example; call
# np.random.seed(...) beforehand to reproduce a specific figure.
plot_signal_id = np.random.randint(full_test_signals.shape[0])
offset = 9.0


def _plot_stacked_channels(signals, colors, signal_id, channels, channel_offset):
    """Plot ``channels`` of example ``signal_id`` from ``signals`` in one axis.

    Channel k is shifted up by ``k * channel_offset`` so the traces do not overlap.
    The x axis spans samples 0..1001, labelled 0, 2 and 4 s.

    Args:
        signals: Indexable as ``signals[example, channels]`` -> (n_channels, time).
        colors: Color list forwarded to ``plot_overlapping_signal``.
        signal_id: Index of the example to plot.
        channels: Channel indices to plot.
        channel_offset: Vertical distance between consecutive channels.
    """
    with plt.rc_context(rc=FL.get_rc(8, 5), fname=matplotlibrc_path):
        fig, ax = plt.subplots()
        plot_overlapping_signal(
            fig,
            ax,
            signals[signal_id, channels]
            + channel_offset * np.arange(len(channels))[:, np.newaxis],
            colors=colors,
        )
        basic_plotting(
            fig,
            ax,
            x_label="time (s)",
            y_axis_visibility=False,
            x_lim=(0, 1001),
            x_ticks=(0, 500, 1001),
            x_ticklabels=(0, 2, 4),
        )
        fig.tight_layout()
        plt.show()


_plot_stacked_channels(
    full_test_signals, ["grey"], plot_signal_id, plot_channel, offset
)
_plot_stacked_channels(
    full_impu_signals, signal_colors, plot_signal_id, plot_channel, offset
)


In [None]:
# Same example, a contiguous range of channels stacked top to bottom (each channel
# shifted down by 4 units): conditioning channels in C0, imputed channels in black,
# and the highlighted `plot_channel`s in their `signal_colors`.
num_plotted_signals_start = 16
num_plotted_signals_end = 48
num_plotted_signals = num_plotted_signals_end - num_plotted_signals_start
plotted_channels = np.arange(
    num_plotted_signals_start, num_plotted_signals_end, dtype=int
)

# Derive each channel's color from its role rather than hard-coding
# "16 conditioning + 16 imputed" and `channel - 16`. This stays correct when the
# plotted range or the conditioning split changes. A highlighted channel outside
# the plotted range is now ignored; previously a negative index could recolor a
# different trace.
highlight_colors = dict(zip(plot_channel.tolist(), signal_colors))
colors = [
    highlight_colors.get(ch, "C0" if ch < num_cond_channel else "black")
    for ch in plotted_channels.tolist()
]

with plt.rc_context(rc=FL.get_rc(4, 6), fname=matplotlibrc_path):
    fig, ax = plt.subplots()
    plot_overlapping_signal(
        fig,
        ax,
        full_impu_signals[plot_signal_id, plotted_channels, :]
        - 4 * np.arange(num_plotted_signals)[:, np.newaxis],
        colors=colors,
    )
    # Group labels left of the axes; positions are in axes-fraction coordinates
    # and are tuned for the current range/split.
    ax.annotate(
        "condition",
        xy=(-0.15, 0.60),
        rotation=90,
        xycoords="axes fraction",
        fontsize="medium",
    )
    ax.annotate(
        "imputation",
        xy=(-0.15, 0.12),
        rotation=90,
        xycoords="axes fraction",
        fontsize="medium",
    )
    basic_plotting(
        fig,
        ax,
        y_axis_visibility=False,
        x_label="time (s)",
        x_lim=(0, 1001),
        x_ticks=(0, 500, 1001),
        x_ticklabels=(0, 2, 4),
    )
    fig.tight_layout()
    plt.show()



In [None]:
# Power spectra of the four highlighted channels: test signals (grey) vs. imputed
# signals (the channel's color), aggregated over examples with `agg_function` and
# drawn with 0.1 / 0.9 quantile bands. One figure per channel; only the first
# figure carries the y-axis label.
agg_function = np.median

for plot_idx, (channel, color) in enumerate(zip(plot_channel, signal_colors)):
    print(channel)
    with plt.rc_context(rc=FL.get_rc(6, 5), fname=matplotlibrc_path):
        fig, ax = plt.subplots()
        plot_sd(
            fig,
            ax,
            full_test_signals_numpy[:, channel, :],
            full_impu_signals_numpy[:, channel, :],
            fs=250,
            nperseg=1001,
            agg_function=agg_function,
            with_quantiles=True,
            lower_quantile=0.1,
            upper_quantile=0.9,
            color_one="grey",
            color_two=color,
            alpha_boundary=0.2,
            x_ss=slice(0, -25),
        )
        label_kwargs = {"y_label": "power (a.u.)"} if plot_idx == 0 else {}
        basic_plotting(fig, ax, x_label="freq (Hz)", y_ticks=[], **label_kwargs)
        fig.tight_layout()
        plt.show()


In [None]:
# Overview of every channel, one small panel each: imputed spectrum (C0 for
# conditioning channels, C3 for imputed ones) against the test spectrum (black).
n_cols = 10
# Ceiling division gives just enough rows; no fully empty trailing row.
n_rows = max(1, -(-signal_channel // n_cols))

with plt.rc_context(fname=matplotlibrc_path):
    # squeeze=False keeps `axs` 2-D even when everything fits in a single row.
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(45, (n_rows / n_cols) * 45), squeeze=False
    )
    # axs.flat walks the grid row by row, i.e. panel idx is at (idx // n_cols, idx % n_cols).
    for idx, panel_ax in enumerate(axs.flat):
        if idx >= signal_channel:
            panel_ax.axis("off")
            continue
        plot_sd(
            fig,
            panel_ax,
            arr_one=full_impu_signals_numpy[:, idx, :],
            arr_two=full_test_signals_numpy[:, idx, :],
            fs=250,
            nperseg=1001,
            color_one="C0" if idx < num_cond_channel else "C3",
            color_two="black",
            agg_function=np.median,
            with_quantiles=True,
            lower_quantile=0.1,
            upper_quantile=0.9,
        )
        panel_ax.set_yticks([])
        panel_ax.set_xticks([])
    fig.tight_layout()
    plt.show()



In [None]:
def comp_ajile_plot(fig, axs, result_dicts, impu_col="red"):
    """Plot decoding accuracy per patient for several dropout levels.

    One axis per patient. Patients are ordered by decreasing ``full`` value at
    dropout level "0.5". For each dropout level "0.5", "0.7" and "0.9" (x ticks
    50/70/90), a panel shows:

    * ``zero`` as a royal-blue marker;
    * ``impu`` as an ``impu_col`` marker with a vertical error bar of ``diff_std``.

    ``full`` is drawn as a dashed black line. A solid black line marks 0.5
    (presumably chance level).

    Args:
        fig: Figure that receives the suptitle.
        axs: A single Axes or a sequence of Axes. Patients beyond the number of
            axes are not plotted.
        result_dicts: ``{patient_id: {dropout_level: (full, zero, impu, diff_std)}}``
            with dropout-level keys "0.5", "0.7" and "0.9".
        impu_col: Color of the imputation markers and error bars.

    Returns:
        The ``(fig, axs)`` that were passed in.
    """
    # Accept a lone Axes (e.g. from plt.subplots(1, 1)) as well as an array of Axes.
    axes_seq = axs if np.iterable(axs) else [axs]
    # Order patients by decreasing `full` accuracy at the lowest dropout level.
    pat_ids = sorted(result_dicts, key=lambda x: -result_dicts[x]["0.5"][0])
    markersize = 6
    left = -0.5
    right = 2.5
    dropout_lvls = ["0.5", "0.7", "0.9"]
    pat_labels = {
        "ec01": "P01",
        "ec02": "P02",
        "ec03": "P03",
        "ec04": "P04",
        "ec05": "P05",
        "ec06": "P06",
        "ec07": "P07",
        "ec08": "P08",
        "ec09": "P09",
        "ec10": "P10",
        "ec11": "P11",
        "ec12": "P12",
    }

    x_ticklabels = ["50", "70", "90"]
    markers = ["o", "o", "o"]
    rand_per_col = "black"

    for idx, (axis_abs, pat_id) in enumerate(zip(axes_seq, pat_ids)):
        axis_abs.set_ylim((0.45, 1.0))
        # Unknown patient ids are shown as-is instead of raising KeyError.
        axis_abs.set_title(pat_labels.get(pat_id, pat_id), x=0.5, y=0.9)
        axis_abs.hlines(0.5, left, right, color=rand_per_col)
        for jdx, (dl, m) in enumerate(zip(dropout_lvls, markers)):
            full, zero, impu, diff_std = result_dicts[pat_id][dl]
            axis_abs.plot(jdx, zero, color="royalblue", marker=m, markersize=markersize)
            axis_abs.errorbar(jdx, impu, diff_std, color=impu_col)
            axis_abs.plot(jdx, impu, color=impu_col, marker=m, markersize=markersize)
        # `full` is taken from the last dropout level ("0.9").
        axis_abs.hlines(full, left, right, color="black", linestyles="--")
        axis_abs.set_yticks([0.5, 0.75, 1.0])
        if idx == 0:
            # Only the left-most panel gets a y label and tick labels.
            axis_abs.set_ylabel("accuracy (%)")
            axis_abs.set_yticklabels(["50", "75", "100"])
        else:
            axis_abs.set_yticklabels([])
        axis_abs.set_xticks([0, 1, 2])
        axis_abs.set_xticklabels(x_ticklabels, rotation=45)
    fig.suptitle(
        "neural decoding performance under imputation",
        fontsize="large",
        x=0.5,
        y=0.95,
    )
    return fig, axs



In [None]:
# Patient ids, in the same order as their result files listed below.
pat_ids = [
    "ec01",
    "ec02",
    "ec03",
    "ec04",
    "ec05",
    "ec06",
    "ec07",
    "ec08",
    "ec09",
    "ec10",
    "ec11",
    "ec12",
]

file_paths = [
    "TODO",
]

file_names = [
    "TODO",
]

# zip() silently truncates to its shortest input. A missing entry in only one of
# the two file lists would therefore shift every later (name, path) pair, so fail
# loudly instead. pat_ids may be longer than the file lists: only the first
# len(file_names) patients are loaded.
if len(file_names) != len(file_paths):
    raise ValueError(
        f"file_names ({len(file_names)}) and file_paths ({len(file_paths)}) "
        "must have the same length"
    )

results_dict = {}
for pat_id, fn, fp in zip(pat_ids, file_names, file_paths):
    resi_dict_det = path_loader(fn, fp)
    pat_results = {}
    for dropout_lvl, arr in resi_dict_det.items():
        # Column means over rows of `arr`: columns 0/1/2 are read as full / zero /
        # imputed accuracy (assumed layout -- TODO confirm against path_loader).
        full, zero, impu = np.mean(arr, axis=0)[:3]
        # Spread of the per-row difference imputed - zero.
        diff_std = np.std(arr[:, 2] - arr[:, 1])
        pat_results[dropout_lvl] = (full, zero, impu, diff_std)
    results_dict[pat_id.lower()] = pat_results



In [None]:
# Full-width decoding figure: one panel per loaded patient.
grid_width, grid_height = 24, 6
n_patients = len(results_dict)
with plt.rc_context(rc=FL.get_rc(grid_width, grid_height), fname=matplotlibrc_path):
    fig, axs = plt.subplots(
        1,
        n_patients,
        figsize=FL.get_grid_in_inch(grid_width, grid_height),
    )
    comp_ajile_plot(fig, axs, results_dict)
    fig.tight_layout()
    plt.show()
