In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from ntd.utils.plotting_utils import (
    FigureLayout,
    basic_plotting,
    plot_overlapping_signal,
    plot_sd,
)
from ntd.train_diffusion_model import init_dataset
from ntd.utils.utils import path_loader

# Path of the matplotlib style file applied to every figure via plt.rc_context.
matplotlibrc_path = "../matplotlibrc"

# Shared figure-layout helper; each figure below obtains its rc parameters
# from FL.get_rc(...). Argument semantics live in ntd.utils.plotting_utils.
FL = FigureLayout(
    base_font_size=6,
    scale_factor=3,
    width_grid=24,
    width_in_pt=3 * 397,
)


In [None]:
# Placeholders to fill in before running the notebook:
#   samples -> the generated samples (the next cell calls .cpu() / .numpy() on it)
#   cfg     -> the configuration object handed to init_dataset
samples, cfg = "TODO", "TODO"


In [None]:
# Load the real training data, convert the generated samples to NumPy, and derive
# the per-example condition masks and plotting colours used by all later cells.

# Fail fast if the placeholder cell above was not filled in; otherwise the error
# surfaces later as an obscure failure inside init_dataset or on "TODO".cpu().
if isinstance(samples, str) or (isinstance(cfg, str) and cfg == "TODO"):
    raise ValueError(
        "Set `samples` (generated samples) and `cfg` (config for init_dataset) "
        "before running this cell."
    )

train, _test = init_dataset(cfg)

samples = samples.cpu()
samples_numpy = samples.numpy()
num_samples = samples_numpy.shape[0]
raw_signal = torch.stack([dic["signal"] for dic in train])
raw_signal_numpy = raw_signal.numpy()

# Later cells index samples_numpy with boolean masks and ids computed from `train`,
# so generated sample i is assumed to correspond to training example i.
if num_samples != len(train):
    raise ValueError(
        f"Expected one generated sample per training example, got {num_samples} "
        f"samples for {len(train)} training examples."
    )

cond = torch.stack([dic["cond"] for dic in train])

# Collapse each example's condition tensor to one flag: any non-zero mean counts as
# anesthetized. NOTE(review): assumes cond is a constant 0/1 indicator per example.
anesthetized = torch.mean(cond, dim=(1, 2)).numpy().astype(bool)
awake = np.logical_not(anesthetized)
anesthetized_ids = np.flatnonzero(anesthetized)
awake_ids = np.flatnonzero(awake)

# One colour per highlighted channel, for each condition.
awake_colors = ["firebrick", "red", "darkred", "indianred"]
awake_color_one, awake_color_two, awake_color_three, awake_color_four = awake_colors
anes_colors = ["goldenrod", "sandybrown", "orange", "darkorange"]
anes_color_one, anes_color_two, anes_color_three, anes_color_four = anes_colors

# Number of channels drawn in the full-signal plots, and the channels highlighted
# (in colour) in every figure.
signal_channel = 12
plot_channels = np.array([1, 4, 7, 10])
channel_one, channel_two, channel_three, channel_four = plot_channels

# Per-channel colours for the full awake plot: highlighted channels get their
# awake colour, all other channels are drawn in black.
awake_full_colors = signal_channel * ["black"]
for highlight_channel, highlight_color in zip(plot_channels, awake_colors):
    awake_full_colors[highlight_channel] = highlight_color



In [None]:
# Example traces: one awake and one anesthetized generated sample with all
# `signal_channel` channels stacked vertically. Channels in `plot_channels` are
# drawn in condition-specific colours; the remaining channels are drawn in black.

# Seeded generator so the chosen examples (and hence these figures and the ids
# reused by the next cell) are identical on every Restart & Run All.
# Change the seed to inspect other examples.
example_rng = np.random.default_rng(0)

# Shift channel k up by 2*k so the stacked traces do not overlap.
full_signal_offsets = 2 * np.arange(signal_channel)[:, np.newaxis]

awake_id = example_rng.choice(awake_ids)
print(awake_id)
with plt.rc_context(rc=FL.get_rc(3.8, 4.2), fname=matplotlibrc_path):
    fig, ax = plt.subplots()
    plot_overlapping_signal(
        fig,
        ax,
        samples_numpy[awake_id, :signal_channel, :] + full_signal_offsets,
        colors=awake_full_colors,
    )
    # Both axes hidden; the anesthetized figure that follows carries the time axis.
    basic_plotting(
        fig, ax, y_axis_visibility=False, x_axis_visibility=False, x_lim=[0, 1000]
    )
    fig.tight_layout()
    plt.show()

# Per-channel colours for the full anesthetized plot.
anes_full_colors = signal_channel * ["black"]
for highlight_channel, highlight_color in zip(plot_channels, anes_colors):
    anes_full_colors[highlight_channel] = highlight_color

anesthetized_id = example_rng.choice(anesthetized_ids)
print(anesthetized_id)
with plt.rc_context(rc=FL.get_rc(4, 5), fname=matplotlibrc_path):
    fig, ax = plt.subplots()
    plot_overlapping_signal(
        fig,
        ax,
        samples_numpy[anesthetized_id, :signal_channel, :] + full_signal_offsets,
        colors=anes_full_colors,
    )
    # Tick labels map sample indices 0/500/1000 to 0/0.5/1 s.
    basic_plotting(
        fig,
        ax,
        y_axis_visibility=False,
        x_label="time (s)",
        x_lim=(0, 1000),
        x_ticks=(0, 500, 1000),
        x_ticklabels=(0, 0.5, 1),
    )
    fig.tight_layout()
    plt.show()


In [None]:
# Close-up of only the highlighted channels of the same two examples chosen in the
# previous cell, with a larger vertical spacing between the traces.

# Shift highlighted channel k up by 2.5*k so the traces do not overlap.
highlight_offsets = 2.5 * np.arange(len(plot_channels))[:, np.newaxis]

with plt.rc_context(rc=FL.get_rc(7.8, 4), fname=matplotlibrc_path):
    fig, ax = plt.subplots()
    plot_overlapping_signal(
        fig,
        ax,
        # Use the NumPy copy (like the full-channel plots) rather than the torch
        # tensor `samples`, so plotting always receives a plain ndarray.
        samples_numpy[awake_id, plot_channels, :] + highlight_offsets,
        colors=awake_colors,
    )
    basic_plotting(
        fig, ax, y_axis_visibility=False, x_axis_visibility=False, x_lim=[0, 1000]
    )
    fig.tight_layout()
    plt.show()

with plt.rc_context(rc=FL.get_rc(8, 5), fname=matplotlibrc_path):
    fig, ax = plt.subplots()
    plot_overlapping_signal(
        fig,
        ax,
        samples_numpy[anesthetized_id, plot_channels, :] + highlight_offsets,
        colors=anes_colors,
    )
    # Tick labels map sample indices 0/500/1000 to 0/0.5/1 s.
    basic_plotting(
        fig,
        ax,
        y_axis_visibility=False,
        x_label="time (s)",
        x_lim=(0, 1000),
        x_ticks=(0, 500, 1000),
        x_ticklabels=(0, 0.5, 1),
    )
    fig.tight_layout()
    plt.show()



In [None]:
# anesthetized vs. awake spectra
#
# For each highlighted channel, compare the spectrum of the real training signals
# (grey) with that of the generated samples (condition colour): first over the
# anesthetized examples, then over the awake ones.

agg_function = np.median
with_quantiles = True
lower_quantile = 0.1
upper_quantile = 0.9
alpha_boundary = 1.0

# Sampling rate handed to plot_sd; consistent with the time-axis labels above,
# where 1000 samples span 1 s. nperseg is the segment length passed to plot_sd.
spectrum_fs = 1000
spectrum_nperseg = 1000


def plot_condition_spectrum(
    channel, condition_mask, color, rc_height, y_label=None, x_label=None
):
    """Plot real-vs-generated spectra of one channel for one condition.

    Args:
        channel: index into the channel axis of raw_signal_numpy / samples_numpy.
        condition_mask: boolean mask over examples (`anesthetized` or `awake`).
        color: colour of the generated-sample curve; real data is drawn in grey.
        rc_height: second argument to FL.get_rc (first is fixed at 6); presumably
            the figure height.
        y_label, x_label: optional axis labels; not passed to basic_plotting
            when None.
    """
    label_kwargs = {}
    if y_label is not None:
        label_kwargs["y_label"] = y_label
    if x_label is not None:
        label_kwargs["x_label"] = x_label

    with plt.rc_context(rc=FL.get_rc(6, rc_height), fname=matplotlibrc_path):
        fig, ax = plt.subplots()
        plot_sd(
            fig,
            ax,
            raw_signal_numpy[condition_mask, channel, :],
            samples_numpy[condition_mask, channel, :],
            fs=spectrum_fs,
            nperseg=spectrum_nperseg,
            agg_function=agg_function,
            with_quantiles=with_quantiles,
            lower_quantile=lower_quantile,
            upper_quantile=upper_quantile,
            color_one="grey",
            color_two=color,
            alpha_boundary=alpha_boundary,
            x_ss=slice(0, -1),
        )
        basic_plotting(fig, ax, y_ticks=[], **label_kwargs)
        fig.tight_layout()
        plt.show()


for position, (channel, anes_color, awake_color) in enumerate(
    zip(plot_channels, anes_colors, awake_colors)
):
    print(channel)
    # Only the first highlighted channel's panels carry the power label.
    power_label = "power (a.u.)" if position == 0 else None
    # Anesthetized panels carry the frequency label; awake panels have none.
    plot_condition_spectrum(
        channel, anesthetized, anes_color, 5, y_label=power_label, x_label="freq (Hz)"
    )
    plot_condition_spectrum(channel, awake, awake_color, 4.5, y_label=power_label)



In [None]:
# correlation matrices
#
# Mean channel-by-channel Pearson correlation matrix of the real training signals
# and of the generated samples, drawn on a shared colour scale for comparison.

cmap_name = "plasma"
# Shared colour limits so the two matrices are directly comparable.
corr_vmin, corr_vmax = -0.1, 1


def mean_channel_correlation(signals):
    """Average each example's channel correlation matrix over all examples.

    Args:
        signals: array indexed as (example, channel, time); np.corrcoef treats
            each channel (row of one example) as a variable.

    Returns:
        (channel, channel) array of averaged Pearson correlation coefficients.
    """
    return np.mean([np.corrcoef(example) for example in signals], axis=0)


def plot_correlation_matrix(corr):
    """Show one correlation matrix with a colour bar in the notebook's figure style."""
    with plt.rc_context(rc=FL.get_rc(7, 7), fname=matplotlibrc_path):
        fig, ax = plt.subplots()
        image = ax.imshow(corr, cmap=cmap_name, vmin=corr_vmin, vmax=corr_vmax)
        fig.colorbar(image, ax=ax)
        plt.show()


# Each dataset is averaged over all of its own examples, so the real-data matrix
# never depends on (or indexes past) the number of generated samples.
real_corrs = mean_channel_correlation(raw_signal_numpy)
samp_corrs = mean_channel_correlation(samples_numpy)

plot_correlation_matrix(real_corrs)
plot_correlation_matrix(samp_corrs)
