In [1]:
import logging
from importlib import reload

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from himalaya.backend import set_backend
from matplotlib import axes

from compare_variance_residual.simulation import generate_dataset

set_backend("cupy", on_error="warn")

# disable warnings
import warnings

warnings.filterwarnings("ignore")

In [2]:
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.DEBUG, datefmt='%I:%M:%S')
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
plt.style.use("nord")
sns.set_theme()

In [3]:
def generic_errorbar_plot(means, variances, labels, horizontal_positions, title, suptitle, ylabel, xlabel,
                          n_samples_train, n_samples_test, ax: axes.Axes = None, fig=None):
    """Generic plot function for shared parameters between vp_errorbar and rm_errorbar."""

    if ax is None or fig is None:
        fig, ax = plt.subplots()

    yerr = [np.sqrt(var) / np.sqrt(n_samples_test) for var in variances]
    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown"]

    # Create bar plot
    bars = ax.bar(range(len(means)), means, color=colors[:len(means)], yerr=yerr, capsize=5, ecolor="black")
    ax.bar_label(bars, fmt='{:,.2f}', label_type='center')
    ax.set_xticks(range(len(means)), labels=labels)

    # Function to plot horizontal lines
    def plot_horizontal_line(start, end, y_value, label, color):
        ax.plot([start, end], [y_value, y_value], linestyle='--', label=label, color=color)

    # Plot horizontal lines
    for i, (y_value, label, start, end) in enumerate(horizontal_positions):
        plot_horizontal_line(start, end, y_value, label, colors[i % len(colors)])

    ax.legend()

    ylims = ax.get_ylim()
    ax.set_ylim(ylims[0], 1.01)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    fig.suptitle(suptitle)
    ax.set_title(title)
    return ax, fig

In [4]:
feature_space_dimensions = [
    100,  # shared
    100,  # unique 0
    100,  # unique 1
]
scalars = [
    1 / 3, 1 / 3, 1 / 3
]
n_targets = 1000
n_samples_train = 10000
n_samples_test = 1000
n_samples = n_samples_train + n_samples_test
noise_scalar = 0.1

cv = 20
alphas = np.logspace(-4, 4, 10)

In [5]:
Xs, Y = generate_dataset(feature_space_dimensions, scalars, n_targets, n_samples, noise_scalar)

In [6]:
scalars_alt = [0.6, 0.3, 0.1]
scalars_alt = scalars
Xs_alt, Y_alt = generate_dataset(feature_space_dimensions, scalars_alt, n_targets, n_samples, noise_scalar,
                                 shuffle=False)

# Variance Partitioning

In [7]:
def vp_errorbar(score_0, score_1, joint_score, shared, x0_unique, x1_unique, feature_space_dimensions, scalars,
                n_targets, n_samples_train, n_samples_test, noise_scalar, ax: axes.Axes = None, fig=None):
    # Precompute means and variances
    scores = [score_0, score_1, joint_score, shared, x0_unique, x1_unique]
    means = [s.mean() for s in scores]
    variances = [s.var() for s in scores]

    bar_names = [r"$X_0$", r"$X_1$", r"$X_0 \cup X_1$", r"$X_0 \cap X_1$", r"$X_0 \setminus X_1$",
                 r"$X_1 \setminus X_0$"]

    # Horizontal line positions
    horizontal_positions = [
        (scalars[0] + scalars[1], r"$a_S + a_{U_0}$", -0.5, 0.5),
        (scalars[0] + scalars[2], r"$a_S + a_{U_1}$", 0.5, 1.5),
        (sum(scalars), r"$a_S + a_{U_0} + a_{U_1}$", 1.5, 2.5),
        (scalars[0], r"$a_S$", 2.5, 3.5),
        (scalars[1], r"$a_{U_0}$", 3.5, 4.5),
        (scalars[2], r"$a_{U_1}$", 4.5, 5.5)
    ]

    # Call the generic function
    generic_errorbar_plot(
        means=means,
        variances=variances,
        labels=bar_names,
        horizontal_positions=horizontal_positions,
        title=fr"$a_S$: {scalars[0]:.2f}, $a_{{U_0}}$: {scalars[1]:.2f}, $a_{{U_1}}$: {scalars[2]:.2f}, $|S|$: {feature_space_dimensions[0]}, $|U_0|$: {feature_space_dimensions[1]}, $|U_1|$: {feature_space_dimensions[2]}, $a_E$: {noise_scalar}",
        suptitle="Variance partitioning",
        ylabel=r"Average Variance Explained ($r^2$)",
        xlabel="Feature space",
        n_samples_test=n_samples_test,
        n_samples_train=n_samples_train,
        ax=ax,
        fig=fig
    )

In [8]:
from compare_variance_residual.variance_partitioning import variance_partitioning

In [9]:
(score_0, score_1, joint_score, shared, x0_unique, x1_unique) = variance_partitioning(
    Xs, Y, n_samples_train, cv=cv, alphas=alphas, logger=logger
)

In [10]:
(score_0_alt, score_1_alt, joint_score_alt, shared_alt, x0_unique_alt, x1_unique_alt) = variance_partitioning(
    Xs_alt, Y_alt, n_samples_train, alphas, cv, logger
)

TypeError: 'RootLogger' object is not callable

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

vp_errorbar(score_0, score_1, joint_score, shared, x0_unique, x1_unique, feature_space_dimensions, scalars, n_targets,
            n_samples_train, n_samples_test, noise_scalar, ax=axs[0], fig=fig)
axs[0].set_title("Shuffled")

vp_errorbar(score_0_alt, score_1_alt, joint_score_alt, shared_alt, x0_unique_alt, x1_unique_alt,
            feature_space_dimensions, scalars_alt, n_targets,
            n_samples_train, n_samples_test, noise_scalar, ax=axs[1], fig=fig)
axs[1].set_title("Descending")

plt.show()

# Residual Method

In [None]:
from compare_variance_residual.residual import residual_method

In [None]:
def rm_errorbar(full_score_0, full_score_1, feature_score_0, feature_score_1, residual_score_0, residual_scores_1,
                feature_space_dimensions, scalars,
                n_targets, n_samples_train, n_samples_test, noise_scalar, ax: axes.Axes = None, fig=None):
    if ax is None or fig is None:
        fig, ax = plt.subplots()

    # Precompute means and variances
    scores = [full_score_0, full_score_1, feature_score_0, feature_score_1, residual_score_0, residual_scores_1]
    means = [s.mean() for s in scores]
    variances = [s.var() for s in scores]

    bar_names = [r"$X_0$", r"$X_1$", r"$X_1 \rightarrow X_0$", r"$X_0 \rightarrow X_1$", r"$X_0 \setminus X_1$",
                 r"$X_1 \setminus X_0$"]

    horizontal_positions = [
        (scalars[0] + scalars[1], r"$a_S + a_{U_0}$", -0.5, 0.5),
        (scalars[0] + scalars[2], r"$a_S + a_{U_1}$", 0.5, 1.5),
        (feature_space_dimensions[0] / (feature_space_dimensions[0] + feature_space_dimensions[1]),
         r"$\frac{|S|}{|U_0|+|S|}$", 1.5, 2.5),
        (feature_space_dimensions[0] / (feature_space_dimensions[0] + feature_space_dimensions[2]),
         r"$\frac{|S|}{|U_1|+|S|}$", 2.5, 3.5),
        (scalars[1], r"$a_{U_0}$", 3.5, 4.5),
        (scalars[2], r"$a_{U_1}$", 4.5, 5.5)
    ]

    # Call the generic_errorbar function with the required data and parameters
    generic_errorbar_plot(
        means=means,
        variances=variances,
        labels=bar_names,
        horizontal_positions=horizontal_positions,
        title=fr"{n_targets} targets, $a_S$: {scalars[0]:.2f}, $|S|$: {feature_space_dimensions[0]}, $a_{{U_0}}$: {scalars[1]:.2f}, $|U_0|$: {feature_space_dimensions[1]}, $a_{{U_1}}$: {scalars[2]:.2f}, $|U_1|$: {feature_space_dimensions[2]}, $a_E$: {noise_scalar}",
        suptitle="Residual Method",
        ylabel=r"Average Variance Explained ($r^2$)",
        xlabel="Feature space/Model",
        n_samples_test=n_samples_test,
        n_samples_train=n_samples_train,
        ax=ax,
        fig=fig
    )

In [None]:
full_score_0, full_score_1, feature_score_0, feature_score_1, residual_score_0, residual_scores_1 = residual_method(
    Xs, Y, n_samples_train, cv=cv
)
full_score_0_alt, full_score_1_alt, feature_score_0_alt, feature_score_1_alt, residual_score_0_alt, residual_scores_1_alt = residual_method(
    Xs_alt, Y_alt, n_samples_train, cv=cv
)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

rm_errorbar(full_score_0, full_score_1, feature_score_0, feature_score_1, residual_score_0, residual_scores_1,
            feature_space_dimensions, scalars, n_targets, n_samples_train, n_samples_test, noise_scalar, ax=axs[0],
            fig=fig)
axs[0].set_title("Shuffled")
rm_errorbar(full_score_0_alt, full_score_1_alt, feature_score_0_alt, feature_score_1_alt, residual_score_0_alt,
            residual_scores_1_alt, feature_space_dimensions, scalars_alt, n_targets, n_samples_train, n_samples_test,
            noise_scalar, ax=axs[1], fig=fig)
axs[1].set_title("Descending")

plt.show()