In [None]:
%pip install gadjid

In [None]:
!python -V

---

# Social Preference Sandbox and Analyses

## Visualizing the Social Preference Model (version of 03/31/2025)

In [None]:
!pip3 install torcheeg

In [None]:
!pip3 install torch_scatter

In [None]:
# ----------------------------------------------------------------------------------------------------------------------
# FILE DESCRIPTION
# ----------------------------------------------------------------------------------------------------------------------

# File:  plot.py
# Author:  anonymous
# Date written:  01-18-2022
# Last modified:  10-25-2023

r"""
Description:
"""


# ----------------------------------------------------------------------------------------------------------------------
# IMPORT STATEMENTS
# ----------------------------------------------------------------------------------------------------------------------

# Import statements
import numpy as np
#from scipy.stats import pearsonr
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba, LinearSegmentedColormap
from matplotlib.patches import Polygon
import seaborn as sns

# Constants
R1 = 1.0  # inner radius of power plots


# ----------------------------------------------------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# ----------------------------------------------------------------------------------------------------------------------

# Chord plot
def chord_plot(
    x,
    rois=None,
    freqs=None,
    freq_ticks=None,
    max_alpha=0.7,
    buffer_percent=1.0,
    outer_radius=1.2,
    min_max_quantiles=(0.5, 0.9),
    color=None,
    cmap=None,
    roi_fontsize=13,
    roi_extent=0.28,
    tick_extent=0.03,
    tick_label_extent=0.11,
    tick_label_fontsize=10.0,
    fontfamily='sans-serif',
    figsize=(7, 7)):
    r"""
    Draw a chord diagram of a square, frequency-resolved connectivity array.

    Values are rescaled to transparencies: the two ``min_max_quantiles``
    quantiles of ``x`` map to 0 and ``max_alpha`` respectively, and values
    outside that range are clipped. Diagonal entries ``x[i, i, :]`` are drawn
    as power arcs on the ring; off-diagonal entries ``x[i, j, :]`` are drawn
    as chords between sectors ``i`` and ``j``.

    Parameters
    ----------
    x : numpy.ndarray, shape (n_roi, n_roi, n_freq)
        Connectivity values.
    rois : sequence of str, optional
        One name per ROI; underscores are displayed as spaces. When omitted,
        no ROI labels are drawn.
    freqs : sequence of float, optional
        Frequency of each of the ``n_freq`` bins. Ticks are drawn only when
        both ``freqs`` and ``freq_ticks`` are given.
    freq_ticks : sequence of float, optional
        Frequencies at which tick marks are drawn in every sector.
    max_alpha : float in [0, 1]
        Transparency assigned to values at or above the upper quantile.
    buffer_percent : float
        Gap left at each end of a sector, as a percentage of a full turn.
    outer_radius : float
        Outer radius of the power annulus (the inner radius is ``R1``).
    min_max_quantiles : (float, float)
        Quantiles of ``x`` mapped to zero and ``max_alpha``.
    color : matplotlib colour, optional
        Power-arc colour. Defaults to ``'tab:blue'``; ignored when ``cmap``
        is given.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap sampled per frequency bin for the power arcs.
    roi_fontsize, roi_extent, tick_extent, tick_label_extent,
    tick_label_fontsize, fontfamily, figsize
        Text sizes, radial label offsets, font family and figure size.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axis holding the plot.
    """

    # Check arguments
    assert x.ndim == 3
    assert x.shape[0] == x.shape[1]
    assert max_alpha >= 0.0 and max_alpha <= 1.0, f"{max_alpha}"
    n_roi, n_freq = x.shape[1:]
    assert freqs is None or len(freqs) == n_freq, f"{len(freqs)} != {n_freq}"

    # Replace ROI underscores with spaces. Leave pretty_rois as None when no
    # names are given; _set_up_chord_plot then skips the labels. Previously
    # this name was unbound and the call below raised NameError.
    pretty_rois = None
    if rois is not None:
        assert len(rois) == n_roi, f"{len(rois)} != {n_roi}"
        pretty_rois = [roi.replace("_", " ") for roi in rois]

    # Default color
    if color is None and cmap is None:
        color = 'tab:blue'

    # Set color to None if color map is provided
    if cmap is not None:
        color = None

    # Sector geometry: n_roi equal sectors, each shrunk by `buffer` at both ends
    r2 = outer_radius
    center_angles = np.linspace(0, 2 * np.pi, n_roi + 1)
    buffer = buffer_percent / 100.0 * 2.0 * np.pi
    start_angles = center_angles[:-1] + buffer
    stop_angles = center_angles[1:] - buffer
    freq_diff = (stop_angles[0] - start_angles[0]) / (n_freq + 1)

    # Rescale values to transparencies in [0, max_alpha]
    min_val, max_val = np.quantile(x, min_max_quantiles)
    if max_val == min_val:
        # Degenerate range: the general formula would divide by zero. This
        # reproduces its visible outcome without NaN/inf warnings: values
        # above the threshold get full alpha and everything else is not drawn.
        x = max_alpha * (x > min_val).astype(float)
    else:
        x = max_alpha * np.clip((x - min_val) / (max_val - min_val), 0.0, 1.0)

    # Set up axes and labels and ticks
    _, ax = _set_up_chord_plot(
        start_angles=start_angles,
        stop_angles=stop_angles,
        r1=R1,
        r2=r2,
        pretty_rois=pretty_rois,
        freqs=freqs,
        freq_ticks=freq_ticks,
        tick_extent=tick_extent,
        tick_label_extent=tick_label_extent,
        tick_label_fontsize=tick_label_fontsize,
        roi_fontsize=roi_fontsize,
        roi_extent=roi_extent,
        fontfamily=fontfamily,
        figsize=figsize)

    # Add the power and chord plots
    _update_chord_plot(
        x=x,
        ax=ax,
        start_angles=start_angles,
        stop_angles=stop_angles,
        freq_diff=freq_diff,
        outer_radius=outer_radius,
        color=color,
        cmap=cmap)

    # Return plot axis
    return ax


# Update chord plot
def _update_chord_plot(
    x,
    ax,
    start_angles,
    stop_angles,
    freq_diff,
    outer_radius,
    color,
    cmap=None):
    r"""
    Draw the data-dependent patches (power arcs and chords) of a chord plot.

    Parameters
    ----------
    x : numpy.ndarray
        Array indexed as ``x[i, j, k]`` with ``n_roi, n_freq = x.shape[1:]``.
        Entries are used directly as patch alpha values (``chord_plot``
        rescales them to ``[0, max_alpha]``). Only strictly positive entries
        are drawn.
    ax : matplotlib.axes.Axes
        Target axis.
    start_angles, stop_angles : sequence of float
        Angular extent (radians) of each ROI sector.
    freq_diff : float
        Angular width of one frequency bin at a chord end.
    outer_radius : float
        Outer radius of the power annulus (the inner radius is ``R1``).
    color : matplotlib colour or None
        Power-arc colour used when no usable ``cmap`` is given.
    cmap : str or matplotlib.colors.Colormap, optional
        If given, power arcs are coloured per frequency bin by sampling the
        colormap at ``n_freq`` evenly spaced points. Any other non-None value
        falls back to the colour ``'b'``.

    Returns
    -------
    handles : list of matplotlib.patches.Polygon
        Power-arc patches followed by chord patches, in drawing order.
    """

    # Initialize variables
    r2 = outer_radius
    handles = []
    n_roi, n_freq = x.shape[1:]

    # One RGB colour per frequency bin when a colormap is supplied
    cmap_color_arr = None
    if cmap is not None:
        if isinstance(cmap, str):
            # mpl.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
            # 3.9. Use the colormap registry, and fall back only on versions
            # that predate it.
            registry = getattr(mpl, "colormaps", None)
            cmap_obj = registry[cmap] if registry is not None else mpl.cm.get_cmap(cmap)
        elif isinstance(cmap, mpl.colors.Colormap):
            # Accept any Colormap (e.g. ListedColormap such as viridis), not
            # only LinearSegmentedColormap.
            cmap_obj = cmap
        else:
            cmap_obj = None
            color = 'b'
        if cmap_obj is not None:
            cmap_color_arr = cmap_obj(np.linspace(0, 1, n_freq))[:, :3]

    # Draw the power plots: sector i is split into n_freq equal slices, and
    # slice j is filled with alpha x[i, i, j]
    for i, (c1, c2) in enumerate(zip(start_angles, stop_angles)):
        for j in range(n_freq):
            arc_color = color if cmap_color_arr is None else cmap_color_arr[j]
            alpha = x[i, i, j]
            if alpha > 0:
                diff1 = j * (c2 - c1) / n_freq
                diff2 = (j + 1) * (c2 - c1) / n_freq
                h = _arc_patch(
                    r1=R1,
                    r2=r2,
                    theta1=c1 + diff1,
                    theta2=c1 + diff2,
                    ax=ax,
                    color=arc_color,
                    cmap=cmap,
                    n=5,
                    alpha=alpha)
                handles.append(h)

    # Draw the chords. Chords are coloured by direction ('b' when i > j,
    # 'r' when i < j), never by frequency, so cmap colours do not apply here.
    print("WARNING - CROSS-POWER PLOT HAS BEEN OVER-WRITTEN TO REPRESENT DIRECTIONAL ADJACENCY MATRICES!!!")
    for i in range(n_roi):
        for j in range(n_roi):
            if i == j:
                # Self-connections are the power arcs drawn above; a chord
                # from a sector to itself is a zero-area polygon.
                continue
            chord_color = 'b' if i > j else 'r'
            for k in range(n_freq):
                alpha = x[i, j, k]
                if alpha > 0.0:
                    h = _plot_poly_chord(
                        theta1=start_angles[i] + freq_diff * k,
                        theta2=start_angles[j] + freq_diff * k,
                        diff=freq_diff,
                        ax=ax,
                        color=chord_color,
                        alpha=alpha)
                    handles.append(h)

    # Return collection of power arc and chord connection handles
    return handles


# Set up chord plot
def _set_up_chord_plot(
    start_angles,
    stop_angles,
    r1,
    r2,
    pretty_rois,
    freqs,
    freq_ticks,
    tick_extent,
    tick_label_extent,
    tick_label_fontsize,
    roi_fontsize,
    roi_extent,
    fontfamily,
    figsize):
    r"""
    Create a new figure and draw the static scaffolding of a chord plot.

    For every ROI sector, this outlines the power annulus. It then adds
    frequency ticks, but only when both ``freqs`` and ``freq_ticks`` are
    given, and writes the ROI name outside the ring when ``pretty_rois`` is
    not None. Both axes are limited to [-1.5, 1.5] and hidden.

    Parameters
    ----------
    start_angles, stop_angles : sequence of float
        Angular extent (radians) of each sector.
    r1, r2 : float
        Inner and outer radius of the power annulus.
    pretty_rois : sequence of str or None
        Display name of each sector.
    freqs, freq_ticks : sequence of float or None
        Bin frequencies and the frequencies to mark with ticks.
    tick_extent, tick_label_extent, tick_label_fontsize : float
        Tick geometry and label size.
    roi_fontsize, roi_extent : float
        ROI label size and radial offset.
    fontfamily : str
        Font family for all text.
    figsize : (float, float)
        Figure size in inches.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    """
    fig = plt.figure(figsize=figsize)
    ax = plt.gca()

    want_ticks = freqs is not None and freq_ticks is not None
    want_names = pretty_rois is not None

    for sector, (theta_lo, theta_hi) in enumerate(zip(start_angles, stop_angles)):
        _draw_power_axis(r1=r1, r2=r2, theta1=theta_lo, theta2=theta_hi, ax=ax)

        if want_ticks:
            _plot_ticks(
                r=r2,
                theta1=theta_lo,
                theta2=theta_hi,
                ax=ax,
                freqs=freqs,
                freq_ticks=freq_ticks,
                tick_extent=tick_extent,
                tick_label_extent=tick_label_extent,
                tick_label_fontsize=tick_label_fontsize,
                fontfamily=fontfamily)

        if want_names:
            _plot_roi_name(
                r=r2,
                theta=(theta_lo + theta_hi) / 2,
                ax=ax,
                roi=pretty_rois[sector],
                extent=roi_extent,
                fontsize=roi_fontsize,
                fontfamily=fontfamily)

    # Fixed square viewport with no visible axes
    limits = (-1.5, 1.5)
    ax.set_ylim(*limits)
    ax.set_xlim(*limits)
    plt.axis("off")

    return fig, ax


# Plot poly chord
def _plot_poly_chord(
    theta1,
    theta2,
    diff,
    ax,
    color,
    n=50,
    alpha=0.5):
    r"""
    Add one filled chord (a band of angular width ``diff``) to ``ax``.

    One edge of the band is the chord from ``theta1`` to ``theta2`` returned
    by ``_chord_helper``. The other edge is that same curve rotated about
    the origin by ``diff`` radians. The two edges are joined into a closed
    polygon.

    Parameters
    ----------
    theta1, theta2 : float
        Angles (radians) of the chord's first edge endpoints.
    diff : float
        Rotation (radians) giving the band's second edge.
    ax : matplotlib.axes.Axes
        Target axis.
    color : matplotlib colour
        Fill colour.
    n : int
        Number of samples per edge.
    alpha : float
        Fill transparency.

    Returns
    -------
    matplotlib.patches.Polygon
        The patch that was added.
    """
    cos_d, sin_d = np.cos(diff), np.sin(diff)
    rotation = np.array([[cos_d, -sin_d], [sin_d, cos_d]])

    near_edge = _chord_helper(theta1, theta2, n=n)
    far_edge = rotation @ near_edge

    # Walk the near edge forwards and the far edge backwards to close the band
    outline = np.hstack([near_edge, np.fliplr(far_edge)]).T

    patch = Polygon(outline, closed=True, fc=to_rgba(c=color, alpha=alpha))
    ax.add_patch(patch)
    return patch


# Chord helper
def _chord_helper(theta1, theta2, n=50):
    r"""
    Sample ``n`` points along the chord joining two points on the unit circle.

    The chord is an arc of the circle that passes through both endpoints and
    is orthogonal to the unit circle. The arc spans at most pi radians about
    that circle's own centre. When the endpoints coincide or are (nearly)
    antipodal, the circle degenerates and a straight segment is returned.

    Parameters
    ----------
    theta1, theta2 : float
        Angles (radians) of the endpoints on the unit circle.
    n : int
        Number of sample points.

    Returns
    -------
    numpy.ndarray, shape (2, n)
        Row 0 holds x coordinates and row 1 holds y coordinates.
    """
    px, py = np.cos(theta1), np.sin(theta1)
    qx, qy = np.cos(theta2), np.sin(theta2)

    # Equals sin(theta2 - theta1); ~0 for coincident or antipodal endpoints
    cross = px * qy - py * qx
    if np.abs(cross) < 1e-5:
        return np.vstack([np.linspace(px, qx, n), np.linspace(py, qy, n)])

    # Centre of the circle through both endpoints orthogonal to the unit
    # circle; orthogonality gives radius**2 = |centre|**2 - 1
    cx = (qy - py) / cross
    cy = (px - qx) / cross
    rad = np.sqrt(cx ** 2 + cy ** 2 - 1.0)

    start = np.arctan2(py - cy, px - cx)
    stop = np.arctan2(qy - cy, qx - cx)
    lo, hi = sorted((start, stop))
    if hi - lo > np.pi:
        # Take the short way round by moving the smaller angle past 2*pi
        lo, hi = hi, lo + 2 * np.pi

    sweep = np.linspace(lo, hi, n)
    return np.vstack([rad * np.cos(sweep) + cx, rad * np.sin(sweep) + cy])


# Arc patch
def _arc_patch(
    r1,
    r2,
    theta1,
    theta2,
    ax,
    color,
    cmap=None,
    n=50,
    alpha=1.0,
    **kwargs):
    r"""
    Add a filled annular sector between radii ``r1`` and ``r2`` to ``ax``.

    Parameters
    ----------
    r1, r2 : float
        Inner and outer radius.
    theta1, theta2 : float
        Angular extent in radians.
    ax : matplotlib.axes.Axes
        Target axis.
    color : matplotlib colour
        Fill colour.
    cmap : any
        Accepted for call compatibility; not used here.
    n : int
        Number of samples along each arc.
    alpha : float
        Fill transparency.
    **kwargs
        Forwarded to ``matplotlib.patches.Polygon``.

    Returns
    -------
    matplotlib.patches.Polygon
        The patch that was added.
    """
    angles = np.linspace(theta1, theta2, n)
    unit = np.column_stack([np.cos(angles), np.sin(angles)])

    # Inner arc forwards, outer arc backwards -> closed sector outline
    outline = np.vstack([r1 * unit, r2 * unit[::-1]])

    patch = Polygon(outline, closed=True, fc=to_rgba(color, alpha=alpha), **kwargs)
    ax.add_patch(patch)
    return patch


# Draw power axis
def _draw_power_axis(
    r1,
    r2,
    theta1,
    theta2,
    ax,
    n=50,
    **kwargs):
    r"""
    Outline an annular sector in black.

    The outline traces the inner arc, the outer arc in reverse, and then
    returns to the first inner point so the path is closed.

    Parameters
    ----------
    r1, r2 : float
        Inner and outer radius.
    theta1, theta2 : float
        Angular extent in radians.
    ax : matplotlib.axes.Axes
        Target axis.
    n : int
        Number of samples along each arc.
    **kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    Whatever ``ax.plot`` returns (a list of Line2D for matplotlib axes).
    """
    angles = np.linspace(theta1, theta2, n)
    unit = np.column_stack([np.cos(angles), np.sin(angles)])
    ring = np.vstack([r1 * unit, r2 * unit[::-1], r1 * unit[:1]])
    return ax.plot(ring[:, 0], ring[:, 1], c='k', **kwargs)


# Plot ticks
def _plot_ticks(
    r,
    theta1,
    theta2,
    ax,
    freqs,
    freq_ticks,
    tick_extent=0.03,
    tick_label_extent=0.11,
    tick_label_fontsize=10.0,
    n=5,
    fontfamily='sans-serif',
    **kwargs):
    r"""
    Draw radial frequency ticks and their labels for one sector.

    Each tick frequency is mapped linearly from ``[freqs[0], freqs[-1]]``
    onto ``[theta1, theta2]``. A short black tick is drawn outward from
    radius ``r``, and the frequency is written at radius
    ``r + tick_label_extent``, rotated to follow the radius.

    Parameters
    ----------
    r : float
        Radius where ticks start.
    theta1, theta2 : float
        Angular extent of the sector in radians.
    ax : matplotlib.axes.Axes
        Target axis.
    freqs : sequence of float
        Bin frequencies; only the first and last values are used.
    freq_ticks : iterable of float
        Frequencies to mark.
    tick_extent, tick_label_extent : float
        Radial length of the tick and radial offset of its label.
    tick_label_fontsize : float
        Label font size.
    n : int
        Unused.
    fontfamily : str
        Label font family.
    **kwargs
        Forwarded to ``ax.plot`` for the tick lines.
    """
    # Sectors whose midpoint lies on the left half (cos <= 0) get their
    # labels turned an extra 180 degrees so they are not upside down.
    flip = 180.0
    if np.cos((theta1 + theta2) / 2.0) > 0.0:
        flip = 0.0

    tick_outer = r + tick_extent
    label_radius = r + tick_label_extent

    for tick in freq_ticks:
        angle = theta1 + (theta2 - theta1) * (tick - freqs[0]) / (freqs[-1] - freqs[0])
        cos_a, sin_a = np.cos(angle), np.sin(angle)

        ax.plot(
            [r * cos_a, tick_outer * cos_a],
            [r * sin_a, tick_outer * sin_a],
            c="k",
            **kwargs)

        ax.text(
            x=label_radius * cos_a,
            y=label_radius * sin_a,
            s=str(tick),
            rotation=(angle * 180.0 / np.pi) + flip,
            fontfamily=fontfamily,
            fontsize=tick_label_fontsize,
            ha='center',
            va='center')


# Plot ROI name
def _plot_roi_name(
    r,
    theta,
    ax,
    roi,
    extent=0.3,
    fontsize=13,
    fontfamily='sans-serif'):
    r"""
    Write an ROI label just outside the ring at angle ``theta``.

    The text sits at radius ``r + extent`` and is rotated tangentially. On
    the lower half of the circle (sin(theta) < 0) it is turned a further
    180 degrees so it reads upright.

    Parameters
    ----------
    r : float
        Ring radius.
    theta : float
        Label angle in radians.
    ax : matplotlib.axes.Axes
        Target axis.
    roi : str
        Label text.
    extent : float
        Radial offset beyond ``r``.
    fontsize : float
        Font size.
    fontfamily : str
        Font family.
    """
    radius = r + extent
    x_pos = radius * np.cos(theta)
    y_pos = radius * np.sin(theta)

    lower_half = np.sin(theta) < 0.0
    tilt = (theta * 180.0 / np.pi) - 90.0 + (180.0 if lower_half else 0.0)

    ax.text(
        x=x_pos,
        y=y_pos,
        s=roi,
        rotation=tilt,
        ha='center',
        va='center',
        fontfamily=fontfamily,
        fontsize=fontsize)




In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Social Preference (SP)", "Object Preference (OP)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)",
                "UNKNOWN 7 (U7)", "UNKNOWN 8 (U8)", "UNKNOWN 9 (U9)",
                "UNKNOWN 10 (U10)", "UNKNOWN 11 (U11)", "UNKNOWN 12 (U12)",
                "UNKNOWN 13 (U13)", "UNKNOWN 14 (U14)", "UNKNOWN 15 (U15)",
                "UNKNOWN 16 (U16)", ]
channel_names = ['Amy_BLA', 'Amy_CeA', 'Cg_Cx_R', 'Hipp', 'NAc_Core', 'NAc_Shell', 'PrL_Cx_R', 'VTA_L', 'VTA_R']

# Number of cross-validation folds whose models are averaged below. The
# per-fold aliases (model0..model4, curr_gc_factor_ests0..4) assume 5 folds.
NUM_FOLDS = 5
# Arguments passed to model.GC for every fold.
GC_KWARGS = dict(X=None, threshold=False, ignore_lag=False,
                 combine_wavelet_representations=True, rank_wavelets=False)
# Number of leading factors compared pairwise in the difference plots.
NUM_COMPARED_FACTORS = 3


def _normalize_by_max(arr):
    """Divide arr by its maximum; an array whose maximum is 0 is returned unchanged (avoids 0/0 -> NaN)."""
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr


def _load_fold_model(fold):
    """Load the saved model object for one CV fold onto the CPU."""
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
    # which can execute code. Only load checkpoint files from a trusted source.
    return torch.load(f"final_best_model_FOLD{fold}.bin",
                      map_location=torch.device('cpu'), weights_only=False)


def _normalized_gc_estimates(model):
    """Return the first output of model.GC(...) as max-normalized NumPy arrays, one per entry."""
    factor_tensors = model.GC("fixed_factor_exclusive", **GC_KWARGS)[0]
    return [_normalize_by_max(t.detach().numpy()) for t in factor_tensors]


def _plot_gc_heatmap(matrix, title, **imshow_kwargs):
    """Show one receiving-by-driving channel heatmap with a colour bar."""
    fig, ax = plt.subplots()
    fig.set_size_inches(9, 9)  # see https://stackoverflow.com/questions/14770735/how-do-i-change-the-figure-size-with-subplots
    im = ax.imshow(matrix, cmap='RdGy_r', **imshow_kwargs)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_factor_maps(factor_ests):
    """Plot each factor's estimate summed over its last axis."""
    for idx, est in enumerate(factor_ests):
        _plot_gc_heatmap(est.sum(axis=2), "Est. Causality: " + factor_names[idx] + "\n")


def _plot_pairwise_differences(factor_ests, n_compare=NUM_COMPARED_FACTORS):
    """Plot both signed differences for every distinct pair among the first n_compare factors."""
    # Clamp so fewer available factors do not cause an IndexError
    n_compare = min(n_compare, len(factor_ests))
    summed = [est.sum(axis=2) for est in factor_ests[:n_compare]]
    # Start j at i + 1: an i == j "difference" is identically zero
    for i in range(n_compare):
        for j in range(i + 1, n_compare):
            _plot_gc_heatmap(summed[j] - summed[i],
                             "Diff. in Estimated Granger Causality:\n" + factor_names[j] + " - " + factor_names[i],
                             vmin=-1., vmax=1.)
            _plot_gc_heatmap(summed[i] - summed[j],
                             "Diff. in Estimated Granger Causality:\n" + factor_names[i] + " - " + factor_names[j],
                             vmin=-1., vmax=1.)


models = [_load_fold_model(fold) for fold in range(NUM_FOLDS)]
model0, model1, model2, model3, model4 = models

per_fold_gc_factor_ests = [_normalized_gc_estimates(model) for model in models]
(curr_gc_factor_ests0, curr_gc_factor_ests1, curr_gc_factor_ests2,
 curr_gc_factor_ests3, curr_gc_factor_ests4) = per_fold_gc_factor_ests

# Average each factor's normalized estimate across folds
curr_gc_factor_ests = [sum(fold_ests) / float(NUM_FOLDS)
                       for fold_ests in zip(*per_fold_gc_factor_ests)]

_plot_factor_maps(curr_gc_factor_ests)

print(f"PAIRWISE DIFFERENCES BETWEEN THE FIRST {NUM_COMPARED_FACTORS} FACTORS ------------------------")
_plot_pairwise_differences(curr_gc_factor_ests)


print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Zero the self-connection (diagonal) entries, then renormalize
curr_offDiag_gc_factor_ests = [x - x*np.expand_dims(np.eye(x.shape[0]), axis=2) for x in curr_gc_factor_ests]
curr_offDiag_gc_factor_ests = [_normalize_by_max(x) for x in curr_offDiag_gc_factor_ests]

_plot_factor_maps(curr_offDiag_gc_factor_ests)

print(f"PAIRWISE DIFFERENCES BETWEEN THE FIRST {NUM_COMPARED_FACTORS} FACTORS ------------------------")
_plot_pairwise_differences(curr_offDiag_gc_factor_ests)

### Identifying the Social Preference Model (Selecting the Number of Factors)

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Per-fold performance scores of the social-preference model for each
# candidate number of factors (5 entries each).
# NOTE(review): whether lower or higher is better is not stated here — confirm.

#2-factor
nk2_perfs = [4.001132041613262,
3.4969392490386957,
3.559678573608399,
3.870644167264303,
3.327145659128825, ]

#4-factor
nk4_perfs = [3.6608014297485347,
3.8341940339406326,
3.976153351465861,
3.234516398111979,
3.4700066741307576, ]

#6-factor
nk6_perfs = [8.950920476913453,
10.350576349894206,
10.162520256042482,
9.243894300460816,
10.51370607058207, ]

#9-factor
nk9_perfs = [3.516380707422892,
3.3691013908386225,
4.0688151931762695,
3.250964085261027,
3.072849206924438, ]

#18-factor
nk18_perfs = [3.4153781859079997,
3.2008409102757773,
2.9720478153228758,
3.280436124801636,
3.933880195617675, ]

#36-factor
nk36_perfs = [3.753356472651164,
4.154083093007405,
3.461827759742737,
3.7385053253173837,
3.3300649420420334, ]

perfs_by_num_factors = {
    2: nk2_perfs,
    4: nk4_perfs,
    6: nk6_perfs,
    9: nk9_perfs,
    18: nk18_perfs,
    36: nk36_perfs,
}


def _plot_factor_selection(factor_counts):
    """Plot (line + markers) the fold-averaged score against factor count; return the averages."""
    averages = [np.mean(perfs_by_num_factors[k]) for k in factor_counts]
    plt.plot(factor_counts, averages)
    plt.scatter(factor_counts, averages)
    plt.xlabel("Number of Factors Tested")
    plt.ylabel("Mean Performance")
    plt.title("Social Preference Factor Selection")
    plt.show()
    return averages


# Every candidate
num_factors_tested = list(perfs_by_num_factors)
means_perf = _plot_factor_selection(num_factors_tested)

# Same plot without the 6-factor run, whose scores (~9-10) sit far from the rest (~3-4)
num_factors_tested = [k for k in perfs_by_num_factors if k != 6]
means_perf = _plot_factor_selection(num_factors_tested)

In [None]:
import numpy as np
# Naive baseline adjacency tensor: entry [i, j, k] is 1.0 when i == j and 0.0
# otherwise, for every k. Each of the 12 nodes is connected only to itself.
# NOTE(review): the meaning of the last axis (10 entries, e.g. lags or
# frequency bins) is not stated in this cell — confirm against its consumer.
_NAIVE_NUM_NODES = 12
_NAIVE_DEPTH = 10
naive_adjacency_tensor = np.repeat(np.eye(_NAIVE_NUM_NODES)[:, :, np.newaxis], _NAIVE_DEPTH, axis=2)
print(naive_adjacency_tensor.shape)

print(naive_adjacency_tensor[:9,:9,:])

In [None]:
import random
import numpy as np
import pickle as pkl

# Seed Python's and NumPy's global RNGs; random.shuffle below uses the former,
# so the holdout/CV assignment is reproducible across runs.
np.random.seed(42)
random.seed(42)

unique_mice_names = [
    'Mouse8881', 'Mouse6991', 'Mouse5333', 'Mouse699L', 'Mouse6664', 'Mouse8884',
    'Mouse8893', 'Mouse6990', 'Mouse0641', 'Mouse6992', 'Mouse8891', 'Mouse0642',
    'Mouse0643', 'Mouse533L', 'Mouse0631', 'Mouse0632', 'Mouse0640', 'Mouse6662',
    'Mouse0644', 'Mouse0633', 'Mouse6674', 'Mouse0630', 'Mouse8882', 'Mouse8894',
    'Mouse5321', 'Mouse0634', 'Mouse5331', 'Mouse5332',
]
keys_of_interest = [
    'Amy_BLA_03', 'Amy_CeA_01', 'Amy_CeA_02', 'Cg_Cx_R_01', 'Hipp_01', 'NAc_Core_01',
    'NAc_Core_02', 'NAc_Shell_01', 'NAc_Shell_02', 'NAc_Shell_03', 'NAc_Shell_04',
    'PrL_Cx_R_01', 'VTA_L_01', 'VTA_L_02', 'VTA_R_01', 'VTA_R_02'
]
print("len(unique_mice_names) == ", len(unique_mice_names))
print("len(keys_of_interest) == ", len(keys_of_interest))

# Split sizes: NUM_HOLDOUT_MICE are set aside entirely; the remaining CV pool is
# cut into NUM_CV_FOLDS disjoint validation folds of NUM_VAL_MICE mice, and the
# other NUM_TRAIN_MICE mice of the pool are used for training in each fold.
NUM_HOLDOUT_MICE = 8
NUM_TRAIN_MICE = 16
NUM_VAL_MICE = 4
NUM_CV_FOLDS = 5

# Fail fast if the constants do not exactly partition the cohort; otherwise some
# mice would silently never be validated on, or later folds would be short/empty.
num_cv_mice = len(unique_mice_names) - NUM_HOLDOUT_MICE
if num_cv_mice != NUM_CV_FOLDS * NUM_VAL_MICE:
    raise ValueError(
        f"{num_cv_mice} CV mice cannot be split into {NUM_CV_FOLDS} folds "
        f"of {NUM_VAL_MICE} validation mice"
    )
if num_cv_mice - NUM_VAL_MICE != NUM_TRAIN_MICE:
    raise ValueError(
        f"each fold would have {num_cv_mice - NUM_VAL_MICE} training mice, "
        f"expected NUM_TRAIN_MICE == {NUM_TRAIN_MICE}"
    )

print("unique_mice_names == ", unique_mice_names)
random.shuffle(unique_mice_names)
print("post-shuffle unique_mice_names == ", unique_mice_names)

# Slice at an explicit index: `names[-k:]` with k == 0 would be `names[0:]`,
# putting *every* mouse in the holdout set.
split_idx = len(unique_mice_names) - NUM_HOLDOUT_MICE
holdout_mice = unique_mice_names[split_idx:]
mice_for_cv_exps = unique_mice_names[:split_idx]
print("holdout_mice == ", holdout_mice)
print("mice_for_cv_exps == ", mice_for_cv_exps)

for i in range(NUM_CV_FOLDS):
    print("i == ", i)
    # Fold i validates on a contiguous slice of the shuffled CV pool and trains
    # on everything else in the pool (order preserved).
    val_mice = mice_for_cv_exps[i*NUM_VAL_MICE:(i+1)*NUM_VAL_MICE]
    val_set = set(val_mice)
    train_mice = [mouse for mouse in mice_for_cv_exps if mouse not in val_set]
    print("\t val_mice == ", val_mice)
    print("\t train_mice == ", train_mice)

In [None]:
proper_s_label_files = [
    'Mouse8881_100617_SocialPreference_Class.mat', 'Mouse699L_061616_SocialPreference_Class.mat', 'Mouse6992_061416_SocialPreference_Class.mat', 'Mouse6662_032618_SocialPreference_Class.mat', 'Mouse8894_100517_SocialPreference_Class.mat', 'Mouse8881_092917_SocialPreference_Class.mat', 'Mouse8893_090917_SocialPreference_Class.mat', 'Mouse8884_100617_SocialPreference_Class.mat', 'Mouse6992_060916_SocialPreference_Class.mat', 'Mouse5321_052616_SocialPreference_Class.mat', 'Mouse6662_040218_SocialPreference_Class.mat', 'Mouse0632_092517_SocialPreference_Class.mat', 'Mouse8882_092217_SocialPreference_Class.mat', 'Mouse6674_041118_SocialPreference_Class.mat', 'Mouse0634_092017_SocialPreference_Class.mat', 'Mouse8882_100217_SocialPreference_Class.mat', 'Mouse8894_091417_SocialPreference_Class.mat', 'Mouse0633_100417_SocialPreference_Class.mat', 'Mouse0631_100617_SocialPreference_Class.mat', 'Mouse0643_091917_SocialPreference_Class.mat', 'Mouse0632_091817_SocialPreference_Class.mat', 'Mouse0634_100917_SocialPreference_Class.mat', 'Mouse533L_070716_SocialPreference_Class.mat', 'Mouse6662_040418_SocialPreference_Class.mat', 'Mouse5333_060916_SocialPreference_Class.mat', 'Mouse6992_052416_SocialPreference_Class.mat', 'Mouse5331_070716_SocialPreference_Class.mat', 'Mouse6662_040918_SocialPreference_Class.mat', 'Mouse5333_061416_SocialPreference_Class.mat', 'Mouse6664_040218_SocialPreference_Class.mat', 'Mouse0634_092917_SocialPreference_Class.mat', 'Mouse6991_060216_SocialPreference_Class.mat', 'Mouse8893_091417_SocialPreference_Class.mat', 'Mouse0640_100517_SocialPreference_Class.mat', 'Mouse5321_060716_SocialPreference_Class.mat', 'Mouse0644_101017_SocialPreference_Class.mat', 'Mouse0631_092017_SocialPreference_Class.mat', 'Mouse8893_091617_SocialPreference_Class.mat', 'Mouse6664_040618_SocialPreference_Class.mat', 'Mouse0632_092917_SocialPreference_Class.mat', 'Mouse533L_052616_SocialPreference_Class.mat', 'Mouse8884_100417_SocialPreference_Class.mat', 
'Mouse0631_092217_SocialPreference_Class.mat', 'Mouse5321_060216_SocialPreference_Class.mat', 'Mouse0643_091417_SocialPreference_Class.mat', 'Mouse8882_091117_SocialPreference_Class.mat', 'Mouse8893_092617_SocialPreference_Class.mat', 'Mouse8884_092917_SocialPreference_Class.mat', 'Mouse6662_033018_SocialPreference_Class.mat', 'Mouse0632_100617_SocialPreference_Class.mat', 'Mouse8891_100317_SocialPreference_Class.mat', 'Mouse5321_060916_SocialPreference_Class.mat', 'Mouse6674_041618_SocialPreference_Class.mat', 'Mouse6662_041118_SocialPreference_Class.mat', 'Mouse0641_091917_SocialPreference_Class.mat', 'Mouse533L_060216_SocialPreference_Class.mat', 'Mouse6664_032618_SocialPreference_Class.mat', 'Mouse0642_091617_SocialPreference_Class.mat', 'Mouse8893_093017_SocialPreference_Class.mat', 'Mouse6992_060716_SocialPreference_Class.mat', 'Mouse6991_070716_SocialPreference_Class.mat', 'Mouse8884_091117_SocialPreference_Class.mat', 'Mouse0644_093017_SocialPreference_Class.mat', 'Mouse0640_100717_SocialPreference_Class.mat', 'Mouse8891_090917_SocialPreference_Class.mat', 'Mouse5332_060216_SocialPreference_Class.mat', 'Mouse0642_093017_SocialPreference_Class.mat', 'Mouse0634_091817_SocialPreference_Class.mat', 'Mouse6990_060916_SocialPreference_Class.mat', 'Mouse0634_092517_SocialPreference_Class.mat', 'Mouse0644_100517_SocialPreference_Class.mat', 'Mouse6992_053116_SocialPreference_Class.mat', 'Mouse8891_091617_SocialPreference_Class.mat', 'Mouse8884_092217_SocialPreference_Class.mat', 'Mouse5321_061416_SocialPreference_Class.mat', 'Mouse8881_092217_SocialPreference_Class.mat', 'Mouse0633_092017_SocialPreference_Class.mat', 'Mouse0633_092917_SocialPreference_Class.mat', 'Mouse0640_092617_SocialPreference_Class.mat', 'Mouse0632_100417_SocialPreference_Class.mat', 'Mouse8891_092617_SocialPreference_Class.mat', 'Mouse533L_061416_SocialPreference_Class.mat', 'Mouse0633_092217_SocialPreference_Class.mat', 'Mouse6674_040418_SocialPreference_Class.mat', 
'Mouse0631_100217_SocialPreference_Class.mat', 'Mouse0632_091517_SocialPreference_Class.mat', 'Mouse0631_100417_SocialPreference_Class.mat', 'Mouse6674_040618_SocialPreference_Class.mat', 'Mouse0633_091517_SocialPreference_Class.mat', 'Mouse6990_060716_SocialPreference_Class.mat', 'Mouse8881_091117_SocialPreference_Class.mat', 'Mouse0631_091817_SocialPreference_Class.mat', 'Mouse0642_092117_SocialPreference_Class.mat', 'Mouse6674_040218_SocialPreference_Class.mat', 'Mouse5321_061616_SocialPreference_Class.mat', 'Mouse8893_091217_SocialPreference_Class.mat', 'Mouse0634_092217_SocialPreference_Class.mat', 'Mouse8882_092517_SocialPreference_Class.mat', 'Mouse5331_070516_SocialPreference_Class.mat', 'Mouse6990_061416_SocialPreference_Class.mat', 'Mouse8882_092917_SocialPreference_Class.mat', 'Mouse0640_091417_SocialPreference_Class.mat', 'Mouse0642_100717_SocialPreference_Class.mat', 'Mouse5331_060716_SocialPreference_Class.mat', 'Mouse6990_052616_SocialPreference_Class.mat', 'Mouse6991_052416_SocialPreference_Class.mat', 'Mouse0644_092317_SocialPreference_Class.mat', 'Mouse6674_032618_SocialPreference_Class.mat', 'Mouse6662_041618_SocialPreference_Class.mat', 'Mouse8884_091517_SocialPreference_Class.mat', 'Mouse8891_093017_SocialPreference_Class.mat', 'Mouse0633_100917_SocialPreference_Class.mat', 'Mouse0643_092617_SocialPreference_Class.mat', 'Mouse8894_091617_SocialPreference_Class.mat', 'Mouse0630_091517_SocialPreference_Class.mat', 'Mouse8893_092317_SocialPreference_Class.mat', 'Mouse8882_100417_SocialPreference_Class.mat', 'Mouse0630_100217_SocialPreference_Class.mat', 'Mouse0640_092317_SocialPreference_Class.mat', 'Mouse6990_061616_SocialPreference_Class.mat', 'Mouse8882_091317_SocialPreference_Class.mat', 'Mouse0632_100217_SocialPreference_Class.mat', 'Mouse0630_092217_SocialPreference_Class.mat', 'Mouse0642_091917_SocialPreference_Class.mat', 'Mouse8884_100217_SocialPreference_Class.mat', 'Mouse6992_061616_SocialPreference_Class.mat', 
'Mouse0644_100717_SocialPreference_Class.mat', 'Mouse0642_091417_SocialPreference_Class.mat', 'Mouse0641_100517_SocialPreference_Class.mat', 'Mouse0632_092017_SocialPreference_Class.mat', 'Mouse0631_092517_SocialPreference_Class.mat', 'Mouse8894_093017_SocialPreference_Class.mat', 'Mouse0631_092917_SocialPreference_Class.mat', 'Mouse699L_060216_SocialPreference_Class.mat', 'Mouse6991_060716_SocialPreference_Class.mat', 'Mouse0633_100217_SocialPreference_Class.mat', 'Mouse8884_092517_SocialPreference_Class.mat', 'Mouse0633_100617_SocialPreference_Class.mat', 'Mouse6674_041318_SocialPreference_Class.mat', 'Mouse5332_052616_SocialPreference_Class.mat', 'Mouse8891_091417_SocialPreference_Class.mat', 'Mouse6664_040918_SocialPreference_Class.mat', 'Mouse699L_061416_SocialPreference_Class.mat', 'Mouse8891_100517_SocialPreference_Class.mat', 'Mouse0641_100317_SocialPreference_Class.mat', 'Mouse0634_091517_SocialPreference_Class.mat', 'Mouse6662_040618_SocialPreference_Class.mat', 'Mouse0630_100617_SocialPreference_Class.mat', 'Mouse8894_092617_SocialPreference_Class.mat', 'Mouse8894_092317_SocialPreference_Class.mat', 'Mouse6992_052616_SocialPreference_Class.mat', 'Mouse8882_090817_SocialPreference_Class.mat', 'Mouse8893_100517_SocialPreference_Class.mat', 'Mouse5333_060716_SocialPreference_Class.mat', 'Mouse533L_070516_SocialPreference_Class.mat', 'Mouse8884_090817_SocialPreference_Class.mat', 'Mouse8882_100617_SocialPreference_Class.mat', 'Mouse0634_100617_SocialPreference_Class.mat', 'Mouse0630_092517_SocialPreference_Class.mat', 'Mouse0632_092217_SocialPreference_Class.mat', 'Mouse0642_100317_SocialPreference_Class.mat', 'Mouse5333_070516_SocialPreference_Class.mat', 'Mouse699L_053116_SocialPreference_Class.mat', 'Mouse8894_100717_SocialPreference_Class.mat', 'Mouse8881_100217_SocialPreference_Class.mat', 'Mouse0641_091417_SocialPreference_Class.mat', 'Mouse0630_091817_SocialPreference_Class.mat', 'Mouse6991_052616_SocialPreference_Class.mat', 
'Mouse0633_091817_SocialPreference_Class.mat', 'Mouse0642_101017_SocialPreference_Class.mat', 'Mouse0641_092117_SocialPreference_Class.mat', 'Mouse0641_092617_SocialPreference_Class.mat', 'Mouse5332_061416_SocialPreference_Class.mat', 'Mouse699L_070516_SocialPreference_Class.mat', 'Mouse0641_100717_SocialPreference_Class.mat', 'Mouse0632_100917_SocialPreference_Class.mat', 'Mouse0640_091617_SocialPreference_Class.mat', 'Mouse5333_060216_SocialPreference_Class.mat', 'Mouse699L_052616_SocialPreference_Class.mat', 'Mouse5332_052416_SocialPreference_Class.mat', 'Mouse6664_040418_SocialPreference_Class.mat', 'Mouse0641_091617_SocialPreference_Class.mat', 'Mouse8881_091317_SocialPreference_Class.mat', 'Mouse0643_092317_SocialPreference_Class.mat', 'Mouse5332_070716_SocialPreference_Class.mat', 'Mouse6992_060216_SocialPreference_Class.mat', 'Mouse0633_092517_SocialPreference_Class.mat', 'Mouse6674_040918_SocialPreference_Class.mat', 'Mouse6662_041318_SocialPreference_Class.mat', 'Mouse6991_070516_SocialPreference_Class.mat', 'Mouse0634_100417_SocialPreference_Class.mat', 'Mouse533L_060716_SocialPreference_Class.mat', 'Mouse8891_092317_SocialPreference_Class.mat', 'Mouse5321_053116_SocialPreference_Class.mat', 'Mouse0643_092117_SocialPreference_Class.mat', 'Mouse0633_091417_SocialPreference_Class.mat', 'Mouse8881_090817_SocialPreference_Class.mat', 'Mouse8893_100317_SocialPreference_Class.mat', 'Mouse0644_091617_SocialPreference_Class.mat', 'Mouse5331_061616_SocialPreference_Class.mat', 'Mouse0632_091417_SocialPreference_Class.mat', 'Mouse0642_092317_SocialPreference_Class.mat', 'Mouse0631_091417_SocialPreference_Class.mat', 'Mouse0642_100517_SocialPreference_Class.mat', 'Mouse0634_100217_SocialPreference_Class.mat', 'Mouse0630_100917_SocialPreference_Class.mat', 'Mouse0643_100317_SocialPreference_Class.mat', 'Mouse0630_091417_SocialPreference_Class.mat', 'Mouse0643_100717_SocialPreference_Class.mat', 'Mouse5333_061616_SocialPreference_Class.mat', 
'Mouse0643_093017_SocialPreference_Class.mat', 'Mouse0644_092117_SocialPreference_Class.mat', 'Mouse6992_070716_SocialPreference_Class.mat', 'Mouse5333_070716_SocialPreference_Class.mat', 'Mouse0630_100417_SocialPreference_Class.mat', 'Mouse6664_041118_SocialPreference_Class.mat', 'Mouse8884_091317_SocialPreference_Class.mat', 'Mouse0634_091417_SocialPreference_Class.mat', 'Mouse0641_101017_SocialPreference_Class.mat', 'Mouse0644_100317_SocialPreference_Class.mat', 'Mouse8894_100317_SocialPreference_Class.mat', 'Mouse533L_053116_SocialPreference_Class.mat', 'Mouse0644_091917_SocialPreference_Class.mat', 'Mouse5331_061416_SocialPreference_Class.mat', 'Mouse5332_061616_SocialPreference_Class.mat', 'Mouse8894_090917_SocialPreference_Class.mat', 'Mouse5332_053116_SocialPreference_Class.mat', 'Mouse8893_100717_SocialPreference_Class.mat', 'Mouse6991_061616_SocialPreference_Class.mat', 'Mouse0630_092017_SocialPreference_Class.mat', 'Mouse8891_091217_SocialPreference_Class.mat', 'Mouse8881_091517_SocialPreference_Class.mat', 'Mouse5331_052416_SocialPreference_Class.mat', 'Mouse0640_092117_SocialPreference_Class.mat', 'Mouse699L_052416_SocialPreference_Class.mat', 'Mouse5331_053116_SocialPreference_Class.mat', 'Mouse8882_091517_SocialPreference_Class.mat', 'Mouse0640_093017_SocialPreference_Class.mat', 'Mouse8891_100717_SocialPreference_Class.mat', 'Mouse0631_100917_SocialPreference_Class.mat', 'Mouse0640_101017_SocialPreference_Class.mat', 'Mouse0640_100317_SocialPreference_Class.mat', 'Mouse0641_092317_SocialPreference_Class.mat', 'Mouse0641_093017_SocialPreference_Class.mat', 'Mouse6664_041318_SocialPreference_Class.mat', 'Mouse6664_041618_SocialPreference_Class.mat', 'Mouse8894_091217_SocialPreference_Class.mat', 'Mouse8881_092517_SocialPreference_Class.mat', 'Mouse5333_052616_SocialPreference_Class.mat', 'Mouse0640_091917_SocialPreference_Class.mat', 'Mouse5333_052416_SocialPreference_Class.mat', 'Mouse6991_061416_SocialPreference_Class.mat', 
'Mouse6991_060916_SocialPreference_Class.mat', 'Mouse5332_070516_SocialPreference_Class.mat', 'Mouse0644_092617_SocialPreference_Class.mat', 'Mouse8881_100417_SocialPreference_Class.mat', 'Mouse0631_091517_SocialPreference_Class.mat', 'Mouse0644_091417_SocialPreference_Class.mat', 'Mouse5332_060716_SocialPreference_Class.mat', 'Mouse0630_092917_SocialPreference_Class.mat', 'Mouse6991_053116_SocialPreference_Class.mat', 'Mouse0643_100517_SocialPreference_Class.mat', 'Mouse0643_091617_SocialPreference_Class.mat', 'Mouse0642_092617_SocialPreference_Class.mat', 'Mouse6992_070516_SocialPreference_Class.mat', 'Mouse0643_101017_SocialPreference_Class.mat', 'Mouse699L_070716_SocialPreference_Class.mat'
]
proper_o_label_files = [
    'Mouse8881_100617_SocialPreference_Class.mat', 'Mouse699L_061616_SocialPreference_Class.mat', 'Mouse6992_061416_SocialPreference_Class.mat', 'Mouse6662_032618_SocialPreference_Class.mat', 'Mouse8894_100517_SocialPreference_Class.mat', 'Mouse8881_092917_SocialPreference_Class.mat', 'Mouse8893_090917_SocialPreference_Class.mat', 'Mouse8884_100617_SocialPreference_Class.mat', 'Mouse6992_060916_SocialPreference_Class.mat', 'Mouse5321_052616_SocialPreference_Class.mat', 'Mouse6662_040218_SocialPreference_Class.mat', 'Mouse0632_092517_SocialPreference_Class.mat', 'Mouse8882_092217_SocialPreference_Class.mat', 'Mouse6674_041118_SocialPreference_Class.mat', 'Mouse0634_092017_SocialPreference_Class.mat', 'Mouse8882_100217_SocialPreference_Class.mat', 'Mouse8894_091417_SocialPreference_Class.mat', 'Mouse0633_100417_SocialPreference_Class.mat', 'Mouse0631_100617_SocialPreference_Class.mat', 'Mouse0643_091917_SocialPreference_Class.mat', 'Mouse0632_091817_SocialPreference_Class.mat', 'Mouse0634_100917_SocialPreference_Class.mat', 'Mouse533L_070716_SocialPreference_Class.mat', 'Mouse6662_040418_SocialPreference_Class.mat', 'Mouse5333_060916_SocialPreference_Class.mat', 'Mouse6992_052416_SocialPreference_Class.mat', 'Mouse5331_070716_SocialPreference_Class.mat', 'Mouse6662_040918_SocialPreference_Class.mat', 'Mouse5333_061416_SocialPreference_Class.mat', 'Mouse6664_040218_SocialPreference_Class.mat', 'Mouse0634_092917_SocialPreference_Class.mat', 'Mouse6991_060216_SocialPreference_Class.mat', 'Mouse8893_091417_SocialPreference_Class.mat', 'Mouse0640_100517_SocialPreference_Class.mat', 'Mouse5321_060716_SocialPreference_Class.mat', 'Mouse0644_101017_SocialPreference_Class.mat', 'Mouse0631_092017_SocialPreference_Class.mat', 'Mouse8893_091617_SocialPreference_Class.mat', 'Mouse6664_040618_SocialPreference_Class.mat', 'Mouse0632_092917_SocialPreference_Class.mat', 'Mouse533L_052616_SocialPreference_Class.mat', 'Mouse8884_100417_SocialPreference_Class.mat', 
'Mouse0631_092217_SocialPreference_Class.mat', 'Mouse5321_060216_SocialPreference_Class.mat', 'Mouse0643_091417_SocialPreference_Class.mat', 'Mouse8882_091117_SocialPreference_Class.mat', 'Mouse8893_092617_SocialPreference_Class.mat', 'Mouse8884_092917_SocialPreference_Class.mat', 'Mouse6662_033018_SocialPreference_Class.mat', 'Mouse0632_100617_SocialPreference_Class.mat', 'Mouse8891_100317_SocialPreference_Class.mat', 'Mouse5321_060916_SocialPreference_Class.mat', 'Mouse6674_041618_SocialPreference_Class.mat', 'Mouse6662_041118_SocialPreference_Class.mat', 'Mouse0641_091917_SocialPreference_Class.mat', 'Mouse533L_060216_SocialPreference_Class.mat', 'Mouse6664_032618_SocialPreference_Class.mat', 'Mouse0642_091617_SocialPreference_Class.mat', 'Mouse8893_093017_SocialPreference_Class.mat', 'Mouse6992_060716_SocialPreference_Class.mat', 'Mouse6991_070716_SocialPreference_Class.mat', 'Mouse8884_091117_SocialPreference_Class.mat', 'Mouse0644_093017_SocialPreference_Class.mat', 'Mouse0640_100717_SocialPreference_Class.mat', 'Mouse8891_090917_SocialPreference_Class.mat', 'Mouse5332_060216_SocialPreference_Class.mat', 'Mouse0642_093017_SocialPreference_Class.mat', 'Mouse0634_091817_SocialPreference_Class.mat', 'Mouse6990_060916_SocialPreference_Class.mat', 'Mouse0634_092517_SocialPreference_Class.mat', 'Mouse0644_100517_SocialPreference_Class.mat', 'Mouse6992_053116_SocialPreference_Class.mat', 'Mouse8891_091617_SocialPreference_Class.mat', 'Mouse8884_092217_SocialPreference_Class.mat', 'Mouse5321_061416_SocialPreference_Class.mat', 'Mouse8881_092217_SocialPreference_Class.mat', 'Mouse0633_092017_SocialPreference_Class.mat', 'Mouse0633_092917_SocialPreference_Class.mat', 'Mouse0640_092617_SocialPreference_Class.mat', 'Mouse0632_100417_SocialPreference_Class.mat', 'Mouse8891_092617_SocialPreference_Class.mat', 'Mouse533L_061416_SocialPreference_Class.mat', 'Mouse0633_092217_SocialPreference_Class.mat', 'Mouse6674_040418_SocialPreference_Class.mat', 
'Mouse0631_100217_SocialPreference_Class.mat', 'Mouse0632_091517_SocialPreference_Class.mat', 'Mouse0631_100417_SocialPreference_Class.mat', 'Mouse6674_040618_SocialPreference_Class.mat', 'Mouse0633_091517_SocialPreference_Class.mat', 'Mouse6990_060716_SocialPreference_Class.mat', 'Mouse8881_091117_SocialPreference_Class.mat', 'Mouse0631_091817_SocialPreference_Class.mat', 'Mouse0642_092117_SocialPreference_Class.mat', 'Mouse6674_040218_SocialPreference_Class.mat', 'Mouse5321_061616_SocialPreference_Class.mat', 'Mouse8893_091217_SocialPreference_Class.mat', 'Mouse0634_092217_SocialPreference_Class.mat', 'Mouse8882_092517_SocialPreference_Class.mat', 'Mouse5331_070516_SocialPreference_Class.mat', 'Mouse6990_061416_SocialPreference_Class.mat', 'Mouse8882_092917_SocialPreference_Class.mat', 'Mouse0640_091417_SocialPreference_Class.mat', 'Mouse0642_100717_SocialPreference_Class.mat', 'Mouse5331_060716_SocialPreference_Class.mat', 'Mouse6990_052616_SocialPreference_Class.mat', 'Mouse6991_052416_SocialPreference_Class.mat', 'Mouse0644_092317_SocialPreference_Class.mat', 'Mouse6674_032618_SocialPreference_Class.mat', 'Mouse6662_041618_SocialPreference_Class.mat', 'Mouse8884_091517_SocialPreference_Class.mat', 'Mouse8891_093017_SocialPreference_Class.mat', 'Mouse0633_100917_SocialPreference_Class.mat', 'Mouse0643_092617_SocialPreference_Class.mat', 'Mouse8894_091617_SocialPreference_Class.mat', 'Mouse0630_091517_SocialPreference_Class.mat', 'Mouse8893_092317_SocialPreference_Class.mat', 'Mouse8882_100417_SocialPreference_Class.mat', 'Mouse0630_100217_SocialPreference_Class.mat', 'Mouse0640_092317_SocialPreference_Class.mat', 'Mouse6990_061616_SocialPreference_Class.mat', 'Mouse8882_091317_SocialPreference_Class.mat', 'Mouse0632_100217_SocialPreference_Class.mat', 'Mouse0630_092217_SocialPreference_Class.mat', 'Mouse0642_091917_SocialPreference_Class.mat', 'Mouse8884_100217_SocialPreference_Class.mat', 'Mouse6992_061616_SocialPreference_Class.mat', 
'Mouse0644_100717_SocialPreference_Class.mat', 'Mouse0642_091417_SocialPreference_Class.mat', 'Mouse0641_100517_SocialPreference_Class.mat', 'Mouse0632_092017_SocialPreference_Class.mat', 'Mouse0631_092517_SocialPreference_Class.mat', 'Mouse8894_093017_SocialPreference_Class.mat', 'Mouse0631_092917_SocialPreference_Class.mat', 'Mouse699L_060216_SocialPreference_Class.mat', 'Mouse6991_060716_SocialPreference_Class.mat', 'Mouse0633_100217_SocialPreference_Class.mat', 'Mouse8884_092517_SocialPreference_Class.mat', 'Mouse0633_100617_SocialPreference_Class.mat', 'Mouse6674_041318_SocialPreference_Class.mat', 'Mouse5332_052616_SocialPreference_Class.mat', 'Mouse8891_091417_SocialPreference_Class.mat', 'Mouse6664_040918_SocialPreference_Class.mat', 'Mouse699L_061416_SocialPreference_Class.mat', 'Mouse8891_100517_SocialPreference_Class.mat', 'Mouse0641_100317_SocialPreference_Class.mat', 'Mouse0634_091517_SocialPreference_Class.mat', 'Mouse6662_040618_SocialPreference_Class.mat', 'Mouse0630_100617_SocialPreference_Class.mat', 'Mouse8894_092617_SocialPreference_Class.mat', 'Mouse8894_092317_SocialPreference_Class.mat', 'Mouse6992_052616_SocialPreference_Class.mat', 'Mouse8882_090817_SocialPreference_Class.mat', 'Mouse8893_100517_SocialPreference_Class.mat', 'Mouse5333_060716_SocialPreference_Class.mat', 'Mouse533L_070516_SocialPreference_Class.mat', 'Mouse8884_090817_SocialPreference_Class.mat', 'Mouse8882_100617_SocialPreference_Class.mat', 'Mouse0634_100617_SocialPreference_Class.mat', 'Mouse0630_092517_SocialPreference_Class.mat', 'Mouse0632_092217_SocialPreference_Class.mat', 'Mouse0642_100317_SocialPreference_Class.mat', 'Mouse5333_070516_SocialPreference_Class.mat', 'Mouse699L_053116_SocialPreference_Class.mat', 'Mouse8894_100717_SocialPreference_Class.mat', 'Mouse8881_100217_SocialPreference_Class.mat', 'Mouse0641_091417_SocialPreference_Class.mat', 'Mouse0630_091817_SocialPreference_Class.mat', 'Mouse6991_052616_SocialPreference_Class.mat', 
'Mouse0633_091817_SocialPreference_Class.mat', 'Mouse0642_101017_SocialPreference_Class.mat', 'Mouse0641_092117_SocialPreference_Class.mat', 'Mouse0641_092617_SocialPreference_Class.mat', 'Mouse5332_061416_SocialPreference_Class.mat', 'Mouse699L_070516_SocialPreference_Class.mat', 'Mouse0641_100717_SocialPreference_Class.mat', 'Mouse0632_100917_SocialPreference_Class.mat', 'Mouse0640_091617_SocialPreference_Class.mat', 'Mouse5333_060216_SocialPreference_Class.mat', 'Mouse699L_052616_SocialPreference_Class.mat', 'Mouse5332_052416_SocialPreference_Class.mat', 'Mouse6664_040418_SocialPreference_Class.mat', 'Mouse0641_091617_SocialPreference_Class.mat', 'Mouse8881_091317_SocialPreference_Class.mat', 'Mouse0643_092317_SocialPreference_Class.mat', 'Mouse5332_070716_SocialPreference_Class.mat', 'Mouse6992_060216_SocialPreference_Class.mat', 'Mouse0633_092517_SocialPreference_Class.mat', 'Mouse6674_040918_SocialPreference_Class.mat', 'Mouse6662_041318_SocialPreference_Class.mat', 'Mouse6991_070516_SocialPreference_Class.mat', 'Mouse0634_100417_SocialPreference_Class.mat', 'Mouse533L_060716_SocialPreference_Class.mat', 'Mouse8891_092317_SocialPreference_Class.mat', 'Mouse5321_053116_SocialPreference_Class.mat', 'Mouse0643_092117_SocialPreference_Class.mat', 'Mouse0633_091417_SocialPreference_Class.mat', 'Mouse8881_090817_SocialPreference_Class.mat', 'Mouse8893_100317_SocialPreference_Class.mat', 'Mouse0644_091617_SocialPreference_Class.mat', 'Mouse5331_061616_SocialPreference_Class.mat', 'Mouse0632_091417_SocialPreference_Class.mat', 'Mouse0642_092317_SocialPreference_Class.mat', 'Mouse0631_091417_SocialPreference_Class.mat', 'Mouse0642_100517_SocialPreference_Class.mat', 'Mouse0634_100217_SocialPreference_Class.mat', 'Mouse0630_100917_SocialPreference_Class.mat', 'Mouse0643_100317_SocialPreference_Class.mat', 'Mouse0630_091417_SocialPreference_Class.mat', 'Mouse0643_100717_SocialPreference_Class.mat', 'Mouse5333_061616_SocialPreference_Class.mat', 
'Mouse0643_093017_SocialPreference_Class.mat', 'Mouse0644_092117_SocialPreference_Class.mat', 'Mouse6992_070716_SocialPreference_Class.mat', 'Mouse5333_070716_SocialPreference_Class.mat', 'Mouse0630_100417_SocialPreference_Class.mat', 'Mouse6664_041118_SocialPreference_Class.mat', 'Mouse8884_091317_SocialPreference_Class.mat', 'Mouse0634_091417_SocialPreference_Class.mat', 'Mouse0641_101017_SocialPreference_Class.mat', 'Mouse0644_100317_SocialPreference_Class.mat', 'Mouse8894_100317_SocialPreference_Class.mat', 'Mouse533L_053116_SocialPreference_Class.mat', 'Mouse0644_091917_SocialPreference_Class.mat', 'Mouse5331_061416_SocialPreference_Class.mat', 'Mouse5332_061616_SocialPreference_Class.mat', 'Mouse8894_090917_SocialPreference_Class.mat', 'Mouse5332_053116_SocialPreference_Class.mat', 'Mouse8893_100717_SocialPreference_Class.mat', 'Mouse6991_061616_SocialPreference_Class.mat', 'Mouse0630_092017_SocialPreference_Class.mat', 'Mouse8891_091217_SocialPreference_Class.mat', 'Mouse8881_091517_SocialPreference_Class.mat', 'Mouse5331_052416_SocialPreference_Class.mat', 'Mouse0640_092117_SocialPreference_Class.mat', 'Mouse699L_052416_SocialPreference_Class.mat', 'Mouse5331_053116_SocialPreference_Class.mat', 'Mouse8882_091517_SocialPreference_Class.mat', 'Mouse0640_093017_SocialPreference_Class.mat', 'Mouse8891_100717_SocialPreference_Class.mat', 'Mouse0631_100917_SocialPreference_Class.mat', 'Mouse0640_101017_SocialPreference_Class.mat', 'Mouse0640_100317_SocialPreference_Class.mat', 'Mouse0641_092317_SocialPreference_Class.mat', 'Mouse0641_093017_SocialPreference_Class.mat', 'Mouse6664_041318_SocialPreference_Class.mat', 'Mouse6664_041618_SocialPreference_Class.mat', 'Mouse8894_091217_SocialPreference_Class.mat', 'Mouse8881_092517_SocialPreference_Class.mat', 'Mouse5333_052616_SocialPreference_Class.mat', 'Mouse0640_091917_SocialPreference_Class.mat', 'Mouse5333_052416_SocialPreference_Class.mat', 'Mouse6991_061416_SocialPreference_Class.mat', 
'Mouse6991_060916_SocialPreference_Class.mat', 'Mouse5332_070516_SocialPreference_Class.mat', 'Mouse0644_092617_SocialPreference_Class.mat', 'Mouse8881_100417_SocialPreference_Class.mat', 'Mouse0631_091517_SocialPreference_Class.mat', 'Mouse0644_091417_SocialPreference_Class.mat', 'Mouse5332_060716_SocialPreference_Class.mat', 'Mouse0630_092917_SocialPreference_Class.mat', 'Mouse0643_100517_SocialPreference_Class.mat', 'Mouse0643_091617_SocialPreference_Class.mat', 'Mouse0642_092617_SocialPreference_Class.mat', 'Mouse6992_070516_SocialPreference_Class.mat', 'Mouse0643_101017_SocialPreference_Class.mat', 'Mouse699L_070716_SocialPreference_Class.mat'
]

# Keep only label files present in BOTH lists. A set intersection avoids the
# O(n^2) cost of list-membership tests, and sorting gives a run-to-run stable
# order (iteration order of a set of str depends on hash randomization).
proper_label_files = sorted(set(proper_s_label_files) & set(proper_o_label_files))
print(len(proper_label_files))

---

# TST Experiment Analyses

## Visualizing TST REDCLIFF-S GC Models

In [None]:
!pip3 install torcheeg

In [None]:
!pip3 install torch_scatter

In [None]:
# ----------------------------------------------------------------------------------------------------------------------
# FILE DESCRIPTION
# ----------------------------------------------------------------------------------------------------------------------

# File:  plot.py
# Author:  anonymous
# Date written:  01-18-2022
# Last modified:  10-25-2023

r"""
Description:
"""


# ----------------------------------------------------------------------------------------------------------------------
# IMPORT STATEMENTS
# ----------------------------------------------------------------------------------------------------------------------

# Import statements
import numpy as np
from scipy.stats import pearsonr
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba, LinearSegmentedColormap
from matplotlib.patches import Polygon
import seaborn as sns

# Constants
R1 = 1.0  # inner radius of power plots


# ----------------------------------------------------------------------------------------------------------------------
# FUNCTION DEFINITIONS
# ----------------------------------------------------------------------------------------------------------------------

# Chord plot
def chord_plot(
    x,
    rois=None,
    freqs=None,
    freq_ticks=None,
    max_alpha=0.7,
    buffer_percent=1.0,
    outer_radius=1.2,
    min_max_quantiles=(0.5, 0.9),
    color=None,
    cmap=None,
    roi_fontsize=13,
    roi_extent=0.28,
    tick_extent=0.03,
    tick_label_extent=0.11,
    tick_label_fontsize=10.0,
    fontfamily='sans-serif',
    figsize=(7, 7)):
    r"""
    Draw a chord plot of a square ``[n_roi, n_roi, n_freq]`` array.

    Each ROI gets an arc segment around a circle, subdivided into ``n_freq``
    bins. Values of ``x`` are rescaled to transparencies in ``[0, max_alpha]``
    using two quantiles of ``x``; the axes/labels are built by
    ``_set_up_chord_plot`` and the patches/chords by ``_update_chord_plot``.

    Parameters
    ----------
    x : numpy.ndarray
        Array of shape ``[n_roi, n_roi, n_freq]``. Diagonal entries
        ``x[i, i, :]`` are drawn as arc patches; off-diagonal entries are
        handed to the chord-drawing part of ``_update_chord_plot``.
    rois : list of str, optional
        ROI names, length ``n_roi``. Underscores are replaced with spaces.
    freqs : sequence, optional
        Frequency values, length ``n_freq``.
    freq_ticks : optional
        Forwarded to ``_set_up_chord_plot``.
    max_alpha : float, optional
        Largest transparency value; must be in ``[0, 1]``.
    buffer_percent : float, optional
        Gap on each side of an ROI arc, as a percent of the full circle.
    outer_radius : float, optional
        Outer radius of the arc patches (the inner radius is ``R1``).
    min_max_quantiles : tuple of float, optional
        Quantiles of ``x`` mapped to alpha ``0`` and ``max_alpha``.
    color : optional
        Single patch color; defaults to ``'tab:blue'`` when ``cmap`` is None
        and is ignored when ``cmap`` is given.
    cmap : str or matplotlib colormap, optional
        Colormap used to color the frequency bins.
    roi_fontsize, roi_extent, tick_extent, tick_label_extent,
    tick_label_fontsize, fontfamily, figsize : optional
        Layout and typography options forwarded to ``_set_up_chord_plot``.

    Returns
    -------
    ax
        The axis returned (second element) by ``_set_up_chord_plot``.
    """

    # Check arguments
    assert x.ndim == 3
    assert x.shape[0] == x.shape[1]
    assert max_alpha >= 0.0 and max_alpha <= 1.0, f"{max_alpha}"
    n_roi, n_freq = x.shape[1:]
    assert freqs is None or len(freqs) == n_freq, f"{len(freqs)} != {n_freq}"

    # Replace ROI underscores with spaces. Bind the name unconditionally so the
    # default rois=None does not raise UnboundLocalError below.
    # NOTE(review): confirm _set_up_chord_plot handles pretty_rois=None.
    pretty_rois = None
    if rois is not None:
        assert len(rois) == n_roi, f"{len(rois)} != {n_roi}"
        pretty_rois = [roi.replace("_", " ") for roi in rois]

    # Default color
    if color is None and cmap is None:
        color = 'tab:blue'

    # Set color to None if color map is provided
    if cmap is not None:
        color = None

    # Angular layout: one equal sector per ROI, shrunk by a buffer on each side
    r2 = outer_radius
    center_angles = np.linspace(0, 2 * np.pi, n_roi + 1)
    buffer = buffer_percent / 100.0 * 2.0 * np.pi
    start_angles = center_angles[:-1] + buffer
    stop_angles = center_angles[1:] - buffer
    freq_diff = (stop_angles[0] - start_angles[0]) / (n_freq + 1)

    # Rescale values to transparencies in [0, max_alpha]
    min_val, max_val = np.quantile(x, min_max_quantiles)
    if max_val != min_val:
        x = max_alpha * np.clip((x - min_val) / (max_val - min_val), 0.0, 1.0)
    else:
        # Degenerate range (e.g. a mostly 0/1 adjacency tensor whose two
        # quantiles coincide): avoid 0/0 and v/0, which produced NaN/inf.
        # Values above the threshold get full alpha, everything else zero --
        # the same thresholding the division produced, minus the NaNs.
        x = np.where(x > min_val, max_alpha, 0.0)

    # Set up axes and labels and ticks
    _, ax = _set_up_chord_plot(
        start_angles=start_angles,
        stop_angles=stop_angles,
        r1=R1,
        r2=r2,
        pretty_rois=pretty_rois,
        freqs=freqs,
        freq_ticks=freq_ticks,
        tick_extent=tick_extent,
        tick_label_extent=tick_label_extent,
        tick_label_fontsize=tick_label_fontsize,
        roi_fontsize=roi_fontsize,
        roi_extent=roi_extent,
        fontfamily=fontfamily,
        figsize=figsize)

    # Add the power and chord plots
    _update_chord_plot(
        x=x,
        ax=ax,
        start_angles=start_angles,
        stop_angles=stop_angles,
        freq_diff=freq_diff,
        outer_radius=outer_radius,
        color=color,
        cmap=cmap)

    # Return plot axis
    return ax


# Update chord plot
def _update_chord_plot(
    x,
    ax,
    start_angles,
    stop_angles,
    freq_diff,
    outer_radius,
    color,
    cmap=None):
    r"""
    Draw the power arcs and the directional chords for ``x`` onto ``ax``.

    Parameters
    ----------
    x : numpy.ndarray, shape (n_roi, n_roi, n_freq)
        Values are used directly as alpha (transparency), so they should lie
        in ``[0, 1]``; ``chord_plot`` rescales its input into
        ``[0, max_alpha]`` before calling this function. The diagonal
        ``x[i, i, :]`` is drawn as ROI ``i``'s power arcs. Every strictly
        positive entry ``x[i, j, k]`` is also drawn as a chord between ROI
        ``i`` and ROI ``j``.
    ax : matplotlib.axes.Axes
        Axis to draw into.
    start_angles, stop_angles : sequence of float
        Start / stop angle (radians) of each ROI's segment.
    freq_diff : float
        Angular width (radians) of one frequency slot at a chord end.
    outer_radius : float
        Outer radius of the power arcs; the inner radius is ``R1``.
    color : color spec or None
        Power-arc fill color when no colormap is in effect.
    cmap : str, matplotlib.colors.Colormap, or None, optional
        Colormap sampled evenly over the frequency bins to color the power
        arcs. Any other (unsupported) type falls back to ``color='b'``.

    Returns
    -------
    list of matplotlib.patches.Polygon
        Every patch added to ``ax``: power arcs first, then chords.
    """

    # Initialize variables
    r2 = outer_radius
    handles = []
    n_roi, n_freq = x.shape[1:]

    # Resolve the colormap (if any) into one RGB row per frequency bin
    cmap_color_arr = None
    if cmap is not None:
        if isinstance(cmap, str):
            # `mpl.cm.get_cmap` was removed in Matplotlib 3.9; use the
            # `mpl.colormaps` registry (Matplotlib >= 3.5) when it exists.
            if hasattr(mpl, "colormaps"):
                cmap_obj = mpl.colormaps[cmap]
            else:
                cmap_obj = mpl.cm.get_cmap(cmap)
        elif isinstance(cmap, mpl.colors.Colormap):
            # Accept any Colormap (e.g. ListedColormap), not only
            # LinearSegmentedColormap.
            cmap_obj = cmap
        else:
            cmap_obj = None
            color = 'b'
        if cmap_obj is not None:
            cmap_color_arr = cmap_obj(np.linspace(0, 1, n_freq))[:, :3]

    # Draw the power plots: ROI i's segment is split into n_freq equal slices,
    # and slice j is shaded with alpha x[i, i, j].
    for i, (c1, c2) in enumerate(zip(start_angles, stop_angles)):
        for j in range(n_freq):
            # Retrieve colormap color for corresponding frequency value
            if cmap_color_arr is not None:
                color = cmap_color_arr[j]

            if x[i, i, j] > 0:
                # Angular offsets of this slice within the ROI segment
                diff1 = j * (c2 - c1) / n_freq
                diff2 = (j + 1) * (c2 - c1) / n_freq

                h = _arc_patch(
                    r1=R1,
                    r2=r2,
                    theta1=c1 + diff1,
                    theta2=c1 + diff2,
                    ax=ax,
                    color=color,
                    cmap=cmap,
                    n=5,
                    alpha=x[i, i, j])
                handles.append(h)

    # Draw the chords to represent cross-power
    print("WARNING - CROSS-POWER PLOT HAS BEEN OVER-WRITTEN TO REPRESENT DIRECTIONAL ADJACENCY MATRICES!!!")
    for i in range(n_roi):
        for j in range(n_roi):
            # Direction is encoded by color only: entries with i >= j (lower
            # triangle, diagonal included) are blue and i < j (upper triangle)
            # are red. This deliberately replaces any per-frequency colormap
            # color for the chords.
            chord_color = 'b' if i >= j else 'r'
            for k in range(n_freq):
                if x[i, j, k] > 0.0:
                    # Chord ends sit at the k-th frequency slot of each segment
                    theta1 = start_angles[i] + freq_diff * k
                    theta2 = start_angles[j] + freq_diff * k

                    # NOTE(review): for i == j, theta1 == theta2, so the chord
                    # degenerates to a zero-area patch.
                    h = _plot_poly_chord(
                        theta1=theta1,
                        theta2=theta2,
                        diff=freq_diff,
                        ax=ax,
                        color=chord_color,
                        alpha=x[i, j, k])
                    handles.append(h)

    # Return collection of power arc and chord connection handles
    return handles


# Set up chord plot
def _set_up_chord_plot(
    start_angles,
    stop_angles,
    r1,
    r2,
    pretty_rois,
    freqs,
    freq_ticks,
    tick_extent,
    tick_label_extent,
    tick_label_fontsize,
    roi_fontsize,
    roi_extent,
    fontfamily,
    figsize):
    r"""
    Create a new figure and draw the static scaffolding of a chord plot.

    For every ROI segment this outlines the power-arc annulus section. It adds
    frequency ticks only when both ``freqs`` and ``freq_ticks`` are given, and
    writes the ROI label only when ``pretty_rois`` is given. Both axis limits
    are fixed to ``[-1.5, 1.5]`` and the axis frame is hidden.

    Parameters
    ----------
    start_angles, stop_angles : sequence of float
        Start / stop angle (radians) of each ROI segment.
    r1, r2 : float
        Inner and outer radius of the power-arc annulus.
    pretty_rois : sequence of str or None
        Display label for each segment.
    freqs, freq_ticks : sequence or None
        Frequency axis and the frequencies at which to draw ticks.
    tick_extent, tick_label_extent, tick_label_fontsize :
        Tick length, label offset, and label font size.
    roi_fontsize, roi_extent :
        ROI label font size and radial offset.
    fontfamily : str
        Font family for all text.
    figsize : tuple
        Figure size passed to ``plt.figure``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    """

    fig = plt.figure(figsize=figsize)
    ax = plt.gca()

    # Ticks need both the frequency axis and the requested tick values
    want_ticks = freqs is not None and freq_ticks is not None

    for seg_idx, (seg_start, seg_stop) in enumerate(zip(start_angles, stop_angles)):
        # Outline of this segment's power annulus
        _draw_power_axis(
            r1=r1,
            r2=r2,
            theta1=seg_start,
            theta2=seg_stop,
            ax=ax)

        if want_ticks:
            _plot_ticks(
                r=r2,
                theta1=seg_start,
                theta2=seg_stop,
                ax=ax,
                freqs=freqs,
                freq_ticks=freq_ticks,
                tick_extent=tick_extent,
                tick_label_extent=tick_label_extent,
                tick_label_fontsize=tick_label_fontsize,
                fontfamily=fontfamily)

        if pretty_rois is not None:
            # Label sits at the angular midpoint of the segment
            _plot_roi_name(
                r=r2,
                theta=0.5 * (seg_start + seg_stop),
                ax=ax,
                roi=pretty_rois[seg_idx],
                extent=roi_extent,
                fontsize=roi_fontsize,
                fontfamily=fontfamily)

    ax.set_ylim(-1.5, 1.5)
    ax.set_xlim(-1.5, 1.5)
    plt.axis("off")

    return fig, ax


# Plot poly chord
def _plot_poly_chord(
    theta1,
    theta2,
    diff,
    ax,
    color,
    n=50,
    alpha=0.5):
    r"""
    Add a filled ribbon-shaped chord between two angles on the unit circle.

    One edge of the ribbon is the curve returned by ``_chord_helper`` for
    ``(theta1, theta2)``. The other edge is that same curve rotated
    counter-clockwise by ``diff`` radians about the origin.

    Parameters
    ----------
    theta1, theta2 : float
        Angles (radians) of the chord endpoints.
    diff : float
        Angular width (radians) of the ribbon.
    ax : matplotlib.axes.Axes
        Axis the patch is added to.
    color : color spec
        Fill color.
    n : int, optional
        Number of samples per edge.
    alpha : float, optional
        Fill transparency.

    Returns
    -------
    matplotlib.patches.Polygon
        The patch that was added to ``ax``.
    """

    # First edge of the ribbon
    near_edge = _chord_helper(theta1, theta2, n=n)

    # Second edge: first edge rotated by `diff`
    cos_d, sin_d = np.cos(diff), np.sin(diff)
    rotation = np.array([[cos_d, -sin_d], [sin_d, cos_d]])
    far_edge = rotation @ near_edge

    # Go out along one edge and back along the other to close the outline
    outline = np.hstack([near_edge, np.fliplr(far_edge)]).T

    ribbon = Polygon(outline, closed=True, fc=to_rgba(c=color, alpha=alpha))
    ax.add_patch(ribbon)
    return ribbon


# Chord helper
def _chord_helper(theta1, theta2, n=50):
    r"""

    Parameters
    ----------

    Returns
    -------
    """

    # Chord helper coordinates calculations
    a1, a2 = np.cos(theta1), np.sin(theta1)
    b1, b2 = np.cos(theta2), np.sin(theta2)
    denom = a1 * b2 - a2 * b1
    if np.abs(denom) < 1e-5:
        xs = np.linspace(a1, b1, n)
        ys = np.linspace(a2, b2, n)

        return np.vstack([xs, ys])
    v, w = 2.0 * (a2 - b2) / denom, 2.0 * (b1 - a1) / denom
    center = (-v / 2.0, -w / 2.0)
    radius = np.sqrt(((v ** 2.0 + w ** 2.0) / 4.0) - 1.0)
    angle1 = np.arctan2(a2 - center[1], a1 - center[0])
    angle2 = np.arctan2(b2 - center[1], b1 - center[0])
    angle1, angle2 = min(angle1, angle2), max(angle1, angle2)
    if angle2 - angle1 > np.pi:
        angle1, angle2 = angle2, angle1 + 2 * np.pi
    theta = np.linspace(angle1, angle2, n)
    xs = radius * np.cos(theta) + center[0]
    ys = radius * np.sin(theta) + center[1]

    # Return coordinates
    return np.vstack([xs, ys])


# Arc patch
def _arc_patch(
    r1,
    r2,
    theta1,
    theta2,
    ax,
    color,
    cmap=None,
    n=50,
    alpha=1.0,
    **kwargs):
    r"""
    Add a filled annular sector between radii ``r1`` and ``r2``.

    Parameters
    ----------
    r1, r2 : float
        Inner and outer radius of the sector.
    theta1, theta2 : float
        Angular extent (radians) of the sector.
    ax : matplotlib.axes.Axes
        Axis the patch is added to.
    color : color spec
        Fill color, combined with ``alpha`` through ``to_rgba``.
    cmap : optional
        Accepted for interface compatibility; not used by this function.
    n : int, optional
        Number of samples along each curved edge.
    alpha : float, optional
        Fill transparency.
    **kwargs
        Forwarded to ``matplotlib.patches.Polygon``.

    Returns
    -------
    matplotlib.patches.Polygon
        The patch that was added to ``ax``.
    """

    # Unit-circle samples across the angular range, one (x, y) row each
    angles = np.linspace(theta1, theta2, n)
    unit = np.column_stack([np.cos(angles), np.sin(angles)])

    # Inner edge forwards, outer edge backwards -> closed sector outline
    outline = np.vstack([r1 * unit, r2 * unit[::-1]])

    sector = Polygon(
        outline,
        closed=True,
        fc=to_rgba(color, alpha=alpha),
        **kwargs)
    ax.add_patch(sector)
    return sector


# Draw power axis
def _draw_power_axis(
    r1,
    r2,
    theta1,
    theta2,
    ax,
    n=50,
    **kwargs):
    r"""
    Draw the black outline of an annular sector between radii ``r1`` and ``r2``.

    Parameters
    ----------
    r1, r2 : float
        Inner and outer radius.
    theta1, theta2 : float
        Angular extent (radians).
    ax : matplotlib.axes.Axes
        Axis to draw into.
    n : int, optional
        Number of samples along each curved edge.
    **kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Whatever ``ax.plot`` returns for the outline.
    """

    angles = np.linspace(theta1, theta2, n)
    unit = np.column_stack([np.cos(angles), np.sin(angles)])

    # Inner edge, outer edge reversed, then back to the first inner point
    # so the drawn line closes on itself.
    outline = np.vstack([r1 * unit, r2 * unit[::-1], r1 * unit[:1]])

    return ax.plot(outline[:, 0], outline[:, 1], c='k', **kwargs)


# Plot ticks
def _plot_ticks(
    r,
    theta1,
    theta2,
    ax,
    freqs,
    freq_ticks,
    tick_extent=0.03,
    tick_label_extent=0.11,
    tick_label_fontsize=10.0,
    n=5,
    fontfamily='sans-serif',
    **kwargs):
    r"""
    Draw radial frequency ticks and labels along one ROI segment.

    Each tick frequency is mapped linearly from ``[freqs[0], freqs[-1]]`` onto
    ``[theta1, theta2]``.

    Parameters
    ----------
    r : float
        Radius at which the ticks start.
    theta1, theta2 : float
        Angular extent (radians) of the segment.
    ax : matplotlib.axes.Axes
        Axis to draw into.
    freqs : sequence of numbers
        Frequency axis; only its first and last entries are used.
    freq_ticks : iterable of numbers
        Frequencies to mark; each is also used as the label text.
    tick_extent : float, optional
        Radial length of each tick mark.
    tick_label_extent : float, optional
        Radial offset of the label from ``r``.
    tick_label_fontsize : float, optional
        Label font size.
    n : int, optional
        Unused.
    fontfamily : str, optional
        Label font family.
    **kwargs
        Forwarded to ``ax.plot`` for the tick marks.
    """

    # Labels on segments whose midpoint has cos <= 0 are turned an extra
    # 180 degrees.
    flip = 0.0 if np.cos((theta1 + theta2) / 2.0) > 0.0 else 180.0
    span = theta2 - theta1

    for tick in freq_ticks:
        theta = theta1 + span * (tick - freqs[0]) / (freqs[-1] - freqs[0])
        cos_t, sin_t = np.cos(theta), np.sin(theta)

        # Radial tick mark
        ax.plot(
            [r * cos_t, (r + tick_extent) * cos_t],
            [r * sin_t, (r + tick_extent) * sin_t],
            c="k",
            **kwargs)

        # Label just outside the tick, rotated to follow the radius
        ax.text(
            x=(r + tick_label_extent) * cos_t,
            y=(r + tick_label_extent) * sin_t,
            s=str(tick),
            rotation=(theta * 180.0 / np.pi) + flip,
            fontfamily=fontfamily,
            fontsize=tick_label_fontsize,
            ha='center',
            va='center')


# Plot ROI name
def _plot_roi_name(
    r,
    theta,
    ax,
    roi,
    extent=0.3,
    fontsize=13,
    fontfamily='sans-serif'):
    r"""
    Write an ROI label at radius ``r + extent`` and angle ``theta``.

    The text is rotated to lie tangent to the circle. Labels where
    ``sin(theta) < 0`` get an additional 180 degrees of rotation.

    Parameters
    ----------
    r : float
        Base radius.
    theta : float
        Angle (radians) of the label center.
    ax : matplotlib.axes.Axes
        Axis to draw into.
    roi : str
        Label text.
    extent : float, optional
        Radial offset added to ``r``.
    fontsize : float, optional
        Font size.
    fontfamily : str, optional
        Font family.
    """

    radius = r + extent
    extra_turn = 180.0 if np.sin(theta) < 0.0 else 0.0
    rotation = (theta * 180.0 / np.pi) - 90.0 + extra_turn

    ax.text(
        x=radius * np.cos(theta),
        y=radius * np.sin(theta),
        s=roi,
        rotation=rotation,
        ha='center',
        va='center',
        fontfamily=fontfamily,
        fontsize=fontsize)




## Visualization of All Folds of TST 9-factor REDCLIFF-S Models, 01/15/2025

### Fold 0

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Home Cage (HC)", "Open Field (OF)", "Tail Suspended (TS)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)", ]
channel_names = [
    'Acb_Core', 'Acb_Sh', 'IL_Cx', 'L_VTA', 'Md_Thal', 'PrL_Cx', 'R_VTA', 'aILH_Hab', 'IDHip', 'lSNC', 'mDHip', 'mSNC'
]
# Number of leading (named) factors whose pairwise differences are plotted
N_COMPARED_FACTORS = 3


def _normalize_by_max(arr):
    """Scale ``arr`` so its maximum is 1; return it unchanged if the maximum is 0 (avoids an all-NaN image)."""
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr


def _show_gc_heatmap(mat, title, vmin=None, vmax=None):
    """Show one (receiving rows x driving columns) channel matrix as a 9x9-inch heatmap."""
    fig, ax = plt.subplots(figsize=(9, 9))
    im = ax.imshow(mat, cmap='RdGy_r', vmin=vmin, vmax=vmax)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_gc_factor_set(gc_ests):
    """Plot every factor's GC matrix, then signed pairwise differences among the first N_COMPARED_FACTORS factors."""
    # Collapse the third axis (presumably lag -- TODO confirm) into one matrix per factor
    summed = [est.sum(axis=2) for est in gc_ests]
    for i, mat in enumerate(summed):
        _show_gc_heatmap(mat, "Est. Causality: "+factor_names[i]+"\n")

    print("DIFFERENCE BETWEEN FACTOR 1 AND FACTOR 2 ------------------------")
    # j starts at i + 1: an i == j pair only ever yields an all-zero plot
    for i in range(N_COMPARED_FACTORS):
        for j in range(i + 1, N_COMPARED_FACTORS):
            _show_gc_heatmap(summed[j] - summed[i],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[j]+" - "+factor_names[i],
                             vmin=-1., vmax=1.)
            _show_gc_heatmap(summed[i] - summed[j],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[i]+" - "+factor_names[j],
                             vmin=-1., vmax=1.)


# NOTE(review): torch.load unpickles a whole model object, which can execute
# arbitrary code -- only load checkpoints from a trusted source. torch >= 2.6
# defaults to weights_only=True, which refuses to restore a pickled model
# object, so the full-unpickling behaviour is requested explicitly.
model = torch.load("final_best_model_FOLD0.bin", map_location=torch.device('cpu'), weights_only=False)

# Query GC once; both views below derive from these arrays (assumes model.GC
# is deterministic for X=None -- previously it was called twice).
raw_gc_factor_ests = [x.detach().numpy() for x in model.GC("fixed_factor_exclusive", X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)[0]]

curr_gc_factor_ests = [_normalize_by_max(x) for x in raw_gc_factor_ests]
_plot_gc_factor_set(curr_gc_factor_ests)


print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Zero the channel-to-itself entries (diagonal of the first two axes), then re-normalize
curr_offDiag_gc_factor_ests = [x - x*np.expand_dims(np.eye(x.shape[0]), axis=2) for x in raw_gc_factor_ests]
curr_offDiag_gc_factor_ests = [_normalize_by_max(x) for x in curr_offDiag_gc_factor_ests]
_plot_gc_factor_set(curr_offDiag_gc_factor_ests)

### Fold 1

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Home Cage (HC)", "Open Field (OF)", "Tail Suspended (TS)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)", ]
channel_names = [
    'Acb_Core', 'Acb_Sh', 'IL_Cx', 'L_VTA', 'Md_Thal', 'PrL_Cx', 'R_VTA', 'aILH_Hab', 'IDHip', 'lSNC', 'mDHip', 'mSNC'
]
# Number of leading (named) factors whose pairwise differences are plotted
N_COMPARED_FACTORS = 3


def _normalize_by_max(arr):
    """Scale ``arr`` so its maximum is 1; return it unchanged if the maximum is 0 (avoids an all-NaN image)."""
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr


def _show_gc_heatmap(mat, title, vmin=None, vmax=None):
    """Show one (receiving rows x driving columns) channel matrix as a 9x9-inch heatmap."""
    fig, ax = plt.subplots(figsize=(9, 9))
    im = ax.imshow(mat, cmap='RdGy_r', vmin=vmin, vmax=vmax)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_gc_factor_set(gc_ests):
    """Plot every factor's GC matrix, then signed pairwise differences among the first N_COMPARED_FACTORS factors."""
    # Collapse the third axis (presumably lag -- TODO confirm) into one matrix per factor
    summed = [est.sum(axis=2) for est in gc_ests]
    for i, mat in enumerate(summed):
        _show_gc_heatmap(mat, "Est. Causality: "+factor_names[i]+"\n")

    print("DIFFERENCE BETWEEN FACTOR 1 AND FACTOR 2 ------------------------")
    # j starts at i + 1: an i == j pair only ever yields an all-zero plot
    for i in range(N_COMPARED_FACTORS):
        for j in range(i + 1, N_COMPARED_FACTORS):
            _show_gc_heatmap(summed[j] - summed[i],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[j]+" - "+factor_names[i],
                             vmin=-1., vmax=1.)
            _show_gc_heatmap(summed[i] - summed[j],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[i]+" - "+factor_names[j],
                             vmin=-1., vmax=1.)


# NOTE(review): torch.load unpickles a whole model object, which can execute
# arbitrary code -- only load checkpoints from a trusted source. torch >= 2.6
# defaults to weights_only=True, which refuses to restore a pickled model
# object, so the full-unpickling behaviour is requested explicitly.
model = torch.load("final_best_model_FOLD1.bin", map_location=torch.device('cpu'), weights_only=False)

# Query GC once; both views below derive from these arrays (assumes model.GC
# is deterministic for X=None -- previously it was called twice).
raw_gc_factor_ests = [x.detach().numpy() for x in model.GC("fixed_factor_exclusive", X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)[0]]

curr_gc_factor_ests = [_normalize_by_max(x) for x in raw_gc_factor_ests]
_plot_gc_factor_set(curr_gc_factor_ests)


print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Zero the channel-to-itself entries (diagonal of the first two axes), then re-normalize
curr_offDiag_gc_factor_ests = [x - x*np.expand_dims(np.eye(x.shape[0]), axis=2) for x in raw_gc_factor_ests]
curr_offDiag_gc_factor_ests = [_normalize_by_max(x) for x in curr_offDiag_gc_factor_ests]
_plot_gc_factor_set(curr_offDiag_gc_factor_ests)

### Fold 2

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Home Cage (HC)", "Open Field (OF)", "Tail Suspended (TS)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)", ]
channel_names = [
    'Acb_Core', 'Acb_Sh', 'IL_Cx', 'L_VTA', 'Md_Thal', 'PrL_Cx', 'R_VTA', 'aILH_Hab', 'IDHip', 'lSNC', 'mDHip', 'mSNC'
]
# Number of leading (named) factors whose pairwise differences are plotted
N_COMPARED_FACTORS = 3


def _normalize_by_max(arr):
    """Scale ``arr`` so its maximum is 1; return it unchanged if the maximum is 0 (avoids an all-NaN image)."""
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr


def _show_gc_heatmap(mat, title, vmin=None, vmax=None):
    """Show one (receiving rows x driving columns) channel matrix as a 9x9-inch heatmap."""
    fig, ax = plt.subplots(figsize=(9, 9))
    im = ax.imshow(mat, cmap='RdGy_r', vmin=vmin, vmax=vmax)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_gc_factor_set(gc_ests):
    """Plot every factor's GC matrix, then signed pairwise differences among the first N_COMPARED_FACTORS factors."""
    # Collapse the third axis (presumably lag -- TODO confirm) into one matrix per factor
    summed = [est.sum(axis=2) for est in gc_ests]
    for i, mat in enumerate(summed):
        _show_gc_heatmap(mat, "Est. Causality: "+factor_names[i]+"\n")

    print("DIFFERENCE BETWEEN FACTOR 1 AND FACTOR 2 ------------------------")
    # j starts at i + 1: an i == j pair only ever yields an all-zero plot
    for i in range(N_COMPARED_FACTORS):
        for j in range(i + 1, N_COMPARED_FACTORS):
            _show_gc_heatmap(summed[j] - summed[i],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[j]+" - "+factor_names[i],
                             vmin=-1., vmax=1.)
            _show_gc_heatmap(summed[i] - summed[j],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[i]+" - "+factor_names[j],
                             vmin=-1., vmax=1.)


# NOTE(review): torch.load unpickles a whole model object, which can execute
# arbitrary code -- only load checkpoints from a trusted source. torch >= 2.6
# defaults to weights_only=True, which refuses to restore a pickled model
# object, so the full-unpickling behaviour is requested explicitly.
model = torch.load("final_best_model_FOLD2.bin", map_location=torch.device('cpu'), weights_only=False)

# Query GC once; both views below derive from these arrays (assumes model.GC
# is deterministic for X=None -- previously it was called twice).
raw_gc_factor_ests = [x.detach().numpy() for x in model.GC("fixed_factor_exclusive", X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)[0]]

curr_gc_factor_ests = [_normalize_by_max(x) for x in raw_gc_factor_ests]
_plot_gc_factor_set(curr_gc_factor_ests)


print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Zero the channel-to-itself entries (diagonal of the first two axes), then re-normalize
curr_offDiag_gc_factor_ests = [x - x*np.expand_dims(np.eye(x.shape[0]), axis=2) for x in raw_gc_factor_ests]
curr_offDiag_gc_factor_ests = [_normalize_by_max(x) for x in curr_offDiag_gc_factor_ests]
_plot_gc_factor_set(curr_offDiag_gc_factor_ests)

### Fold 3

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Home Cage (HC)", "Open Field (OF)", "Tail Suspended (TS)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)", ]
channel_names = [
    'Acb_Core', 'Acb_Sh', 'IL_Cx', 'L_VTA', 'Md_Thal', 'PrL_Cx', 'R_VTA', 'aILH_Hab', 'IDHip', 'lSNC', 'mDHip', 'mSNC'
]
# Number of leading (named) factors whose pairwise differences are plotted
N_COMPARED_FACTORS = 3


def _normalize_by_max(arr):
    """Scale ``arr`` so its maximum is 1; return it unchanged if the maximum is 0 (avoids an all-NaN image)."""
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr


def _show_gc_heatmap(mat, title, vmin=None, vmax=None):
    """Show one (receiving rows x driving columns) channel matrix as a 9x9-inch heatmap."""
    fig, ax = plt.subplots(figsize=(9, 9))
    im = ax.imshow(mat, cmap='RdGy_r', vmin=vmin, vmax=vmax)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_gc_factor_set(gc_ests):
    """Plot every factor's GC matrix, then signed pairwise differences among the first N_COMPARED_FACTORS factors."""
    # Collapse the third axis (presumably lag -- TODO confirm) into one matrix per factor
    summed = [est.sum(axis=2) for est in gc_ests]
    for i, mat in enumerate(summed):
        _show_gc_heatmap(mat, "Est. Causality: "+factor_names[i]+"\n")

    print("DIFFERENCE BETWEEN FACTOR 1 AND FACTOR 2 ------------------------")
    # j starts at i + 1: an i == j pair only ever yields an all-zero plot
    for i in range(N_COMPARED_FACTORS):
        for j in range(i + 1, N_COMPARED_FACTORS):
            _show_gc_heatmap(summed[j] - summed[i],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[j]+" - "+factor_names[i],
                             vmin=-1., vmax=1.)
            _show_gc_heatmap(summed[i] - summed[j],
                             "Diff. in Estimated Granger Causality:\n"+factor_names[i]+" - "+factor_names[j],
                             vmin=-1., vmax=1.)


# NOTE(review): torch.load unpickles a whole model object, which can execute
# arbitrary code -- only load checkpoints from a trusted source. torch >= 2.6
# defaults to weights_only=True, which refuses to restore a pickled model
# object, so the full-unpickling behaviour is requested explicitly.
model = torch.load("final_best_model_FOLD3.bin", map_location=torch.device('cpu'), weights_only=False)

# Query GC once; both views below derive from these arrays (assumes model.GC
# is deterministic for X=None -- previously it was called twice).
raw_gc_factor_ests = [x.detach().numpy() for x in model.GC("fixed_factor_exclusive", X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)[0]]

curr_gc_factor_ests = [_normalize_by_max(x) for x in raw_gc_factor_ests]
_plot_gc_factor_set(curr_gc_factor_ests)


print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Zero the channel-to-itself entries (diagonal of the first two axes), then re-normalize
curr_offDiag_gc_factor_ests = [x - x*np.expand_dims(np.eye(x.shape[0]), axis=2) for x in raw_gc_factor_ests]
curr_offDiag_gc_factor_ests = [_normalize_by_max(x) for x in curr_offDiag_gc_factor_ests]
_plot_gc_factor_set(curr_offDiag_gc_factor_ests)

### Fold 4

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Home Cage (HC)", "Open Field (OF)", "Tail Suspended (TS)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)", ]
channel_names = [
    'Acb_Core', 'Acb_Sh', 'IL_Cx', 'L_VTA', 'Md_Thal', 'PrL_Cx', 'R_VTA', 'aILH_Hab', 'IDHip', 'lSNC', 'mDHip', 'mSNC'
]
# Only the first three factors carry known labels (HC, OF, TS); pairwise differences are limited to them.
NUM_LABELED_FACTORS = 3


def _raw_gc_factor_ests(gc_model):
    """Return ``gc_model``'s per-factor GC estimates as numpy arrays (one array per factor)."""
    gc_tensors = gc_model.GC("fixed_factor_exclusive", X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)[0]
    return [x.detach().numpy() for x in gc_tensors]


def _normalize_by_max(arr):
    """Return ``arr / max(arr)``.

    An array whose maximum is exactly 0 would otherwise become NaN/inf-filled;
    it is returned as an unscaled copy instead.
    """
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr.copy()


def _zero_diagonal(arr):
    """Return a copy of ``arr`` with the self-connection entries ``arr[c, c, :]`` zeroed.

    Assumes ``arr`` is (channels, channels, k), i.e. square in its first two axes.
    """
    return arr - arr * np.expand_dims(np.eye(arr.shape[0]), axis=2)


def _plot_channel_matrix(matrix, title, vmin=None, vmax=None):
    """Show one channel-by-channel heatmap: columns are driving channels, rows are receiving channels."""
    fig, ax = plt.subplots()
    fig.set_size_inches(9, 9)  # see https://stackoverflow.com/questions/14770735/how-do-i-change-the-figure-size-with-subplots
    im = ax.imshow(matrix, cmap='RdGy_r', vmin=vmin, vmax=vmax)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_factor_maps_and_differences(factor_ests, num_compared=NUM_LABELED_FACTORS):
    """Plot each factor's summed GC map, then signed differences among the first ``num_compared`` factors.

    Difference maps use a fixed [-1, 1] colour range so they share one scale.
    """
    # Sum over the third axis (presumably lags, since GC is called with ignore_lag=False).
    summed_maps = [est.sum(axis=2) for est in factor_ests]
    for i, summed in enumerate(summed_maps):
        _plot_channel_matrix(summed, "Est. Causality: "+factor_names[i]+"\n")

    print("DIFFERENCE BETWEEN FACTOR 1 AND FACTOR 2 ------------------------")
    # Only distinct pairs i < j: a factor minus itself is identically zero.
    for i in range(num_compared):
        for j in range(i + 1, num_compared):
            for minuend, subtrahend in ((j, i), (i, j)):
                _plot_channel_matrix(
                    summed_maps[minuend] - summed_maps[subtrahend],
                    "Diff. in Estimated Granger Causality:\n"+factor_names[minuend]+" - "+factor_names[subtrahend],
                    vmin=-1., vmax=1.,
                )


model = torch.load("final_best_model_FOLD4.bin", map_location=torch.device('cpu'))
# Query GC once and reuse the arrays for both the full and the off-diagonal views.
raw_gc_factor_ests = _raw_gc_factor_ests(model)

curr_gc_factor_ests = [_normalize_by_max(x) for x in raw_gc_factor_ests]
_plot_factor_maps_and_differences(curr_gc_factor_ests)



print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Zero self-connections before normalizing, so the colour scale reflects cross-channel edges only.
curr_offDiag_gc_factor_ests = [_normalize_by_max(_zero_diagonal(x)) for x in raw_gc_factor_ests]
_plot_factor_maps_and_differences(curr_offDiag_gc_factor_ests)

### Avg. Across Folds

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Home Cage (HC)", "Open Field (OF)", "Tail Suspended (TS)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)", ]
channel_names = [
    'Acb_Core', 'Acb_Sh', 'IL_Cx', 'L_VTA', 'Md_Thal', 'PrL_Cx', 'R_VTA', 'aILH_Hab', 'IDHip', 'lSNC', 'mDHip', 'mSNC'
]
# Only the first three factors carry known labels (HC, OF, TS); pairwise differences are limited to them.
NUM_LABELED_FACTORS = 3


def _raw_gc_factor_ests(gc_model):
    """Return ``gc_model``'s per-factor GC estimates as numpy arrays (one array per factor)."""
    gc_tensors = gc_model.GC("fixed_factor_exclusive", X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)[0]
    return [x.detach().numpy() for x in gc_tensors]


def _normalize_by_max(arr):
    """Return ``arr / max(arr)``.

    An array whose maximum is exactly 0 would otherwise become NaN/inf-filled;
    it is returned as an unscaled copy instead.
    """
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr.copy()


def _zero_diagonal(arr):
    """Return a copy of ``arr`` with the self-connection entries ``arr[c, c, :]`` zeroed.

    Assumes ``arr`` is (channels, channels, k), i.e. square in its first two axes.
    """
    return arr - arr * np.expand_dims(np.eye(arr.shape[0]), axis=2)


def _average_across_folds(ests_by_fold):
    """Element-wise mean, factor by factor, of per-fold lists of equally shaped arrays."""
    return [sum(fold_arrays) / len(fold_arrays) for fold_arrays in zip(*ests_by_fold)]


def _plot_channel_matrix(matrix, title, vmin=None, vmax=None):
    """Show one channel-by-channel heatmap: columns are driving channels, rows are receiving channels."""
    fig, ax = plt.subplots()
    fig.set_size_inches(9, 9)  # see https://stackoverflow.com/questions/14770735/how-do-i-change-the-figure-size-with-subplots
    im = ax.imshow(matrix, cmap='RdGy_r', vmin=vmin, vmax=vmax)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_factor_maps_and_differences(factor_ests, num_compared=NUM_LABELED_FACTORS):
    """Plot each factor's summed GC map, then signed differences among the first ``num_compared`` factors.

    Difference maps use a fixed [-1, 1] colour range so they share one scale.
    """
    # Sum over the third axis (presumably lags, since GC is called with ignore_lag=False).
    summed_maps = [est.sum(axis=2) for est in factor_ests]
    for i, summed in enumerate(summed_maps):
        _plot_channel_matrix(summed, "Est. Causality: "+factor_names[i]+"\n")

    print("DIFFERENCE BETWEEN FACTOR 1 AND FACTOR 2 ------------------------")
    # Only distinct pairs i < j: a factor minus itself is identically zero.
    for i in range(num_compared):
        for j in range(i + 1, num_compared):
            for minuend, subtrahend in ((j, i), (i, j)):
                _plot_channel_matrix(
                    summed_maps[minuend] - summed_maps[subtrahend],
                    "Diff. in Estimated Granger Causality:\n"+factor_names[minuend]+" - "+factor_names[subtrahend],
                    vmin=-1., vmax=1.,
                )


# One checkpoint per cross-validation fold (FOLD0 .. FOLD4).
model0, model1, model2, model3, model4 = [
    torch.load(f"final_best_model_FOLD{k}.bin", map_location=torch.device('cpu')) for k in range(5)
]
fold_models = [model0, model1, model2, model3, model4]
raw_gc_by_fold = [_raw_gc_factor_ests(fold_model) for fold_model in fold_models]

# Normalize each fold to a maximum of 1 before averaging, so every fold contributes on the same scale.
curr_gc_factor_ests0, curr_gc_factor_ests1, curr_gc_factor_ests2, curr_gc_factor_ests3, curr_gc_factor_ests4 = [
    [_normalize_by_max(x) for x in fold_ests] for fold_ests in raw_gc_by_fold
]
curr_gc_factor_ests = _average_across_folds(
    [curr_gc_factor_ests0, curr_gc_factor_ests1, curr_gc_factor_ests2, curr_gc_factor_ests3, curr_gc_factor_ests4]
)
_plot_factor_maps_and_differences(curr_gc_factor_ests)



print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Off-diagonal view, averaged the same way as above: zero each fold's self-connections,
# renormalize that fold, then average across folds. It is built from every fold's model,
# since this cell defines no single `model` to query.
curr_offDiag_gc_factor_ests = _average_across_folds(
    [[_normalize_by_max(_zero_diagonal(x)) for x in fold_ests] for fold_ests in raw_gc_by_fold]
)
_plot_factor_maps_and_differences(curr_offDiag_gc_factor_ests)

## Visualization of Fold4 TST 9-factor REDCLIFF-S Model, 01/04/2025

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pickle as pkl

from general_utils.misc import get_topk_graph_mask

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

plt.rc('font', size=FONT_SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=FONT_BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=FONT_MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=FONT_SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=FONT_SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)  # fontsize of the figure title


factor_names = ["Home Cage (HC)", "Open Field (OF)", "Tail Suspended (TS)",
                "UNKNOWN 1 (U1)", "UNKNOWN 2 (U2)", "UNKNOWN 3 (U3)",
                "UNKNOWN 4 (U4)", "UNKNOWN 5 (U5)", "UNKNOWN 6 (U6)", ]
channel_names = [
    'Acb_Core', 'Acb_Sh', 'IL_Cx', 'L_VTA', 'Md_Thal', 'PrL_Cx', 'R_VTA', 'aILH_Hab', 'IDHip', 'lSNC', 'mDHip', 'mSNC'
]
# Only the first three factors carry known labels (HC, OF, TS); pairwise differences are limited to them.
NUM_LABELED_FACTORS = 3


def _raw_gc_factor_ests(gc_model):
    """Return ``gc_model``'s per-factor GC estimates as numpy arrays (one array per factor)."""
    gc_tensors = gc_model.GC("fixed_factor_exclusive", X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)[0]
    return [x.detach().numpy() for x in gc_tensors]


def _normalize_by_max(arr):
    """Return ``arr / max(arr)``.

    An array whose maximum is exactly 0 would otherwise become NaN/inf-filled;
    it is returned as an unscaled copy instead.
    """
    peak = np.max(arr)
    return arr / peak if peak != 0 else arr.copy()


def _zero_diagonal(arr):
    """Return a copy of ``arr`` with the self-connection entries ``arr[c, c, :]`` zeroed.

    Assumes ``arr`` is (channels, channels, k), i.e. square in its first two axes.
    """
    return arr - arr * np.expand_dims(np.eye(arr.shape[0]), axis=2)


def _plot_channel_matrix(matrix, title, vmin=None, vmax=None):
    """Show one channel-by-channel heatmap: columns are driving channels, rows are receiving channels."""
    fig, ax = plt.subplots()
    fig.set_size_inches(9, 9)  # see https://stackoverflow.com/questions/14770735/how-do-i-change-the-figure-size-with-subplots
    im = ax.imshow(matrix, cmap='RdGy_r', vmin=vmin, vmax=vmax)
    fig.colorbar(im, orientation='vertical')
    plt.xticks(range(len(channel_names)), channel_names, rotation=70)
    plt.yticks(range(len(channel_names)), channel_names, rotation=0)
    plt.xlabel("Driving Channels")
    plt.ylabel("Receiving Channels")
    plt.title(title)
    plt.show()


def _plot_factor_maps_and_differences(factor_ests, num_compared=NUM_LABELED_FACTORS):
    """Plot each factor's summed GC map, then signed differences among the first ``num_compared`` factors.

    Difference maps use a fixed [-1, 1] colour range so they share one scale.
    """
    # Sum over the third axis (presumably lags, since GC is called with ignore_lag=False).
    summed_maps = [est.sum(axis=2) for est in factor_ests]
    for i, summed in enumerate(summed_maps):
        _plot_channel_matrix(summed, "Est. Causality: "+factor_names[i]+"\n")

    print("DIFFERENCE BETWEEN FACTOR 1 AND FACTOR 2 ------------------------")
    # Only distinct pairs i < j: a factor minus itself is identically zero.
    for i in range(num_compared):
        for j in range(i + 1, num_compared):
            for minuend, subtrahend in ((j, i), (i, j)):
                _plot_channel_matrix(
                    summed_maps[minuend] - summed_maps[subtrahend],
                    "Diff. in Estimated Granger Causality:\n"+factor_names[minuend]+" - "+factor_names[subtrahend],
                    vmin=-1., vmax=1.,
                )


# NOTE(review): the markdown heading above this cell says "Fold4", but the checkpoint loaded
# here is the regularized-average model; confirm which one the figures should show.
model = torch.load("RegAvg_final_best_model.bin", map_location=torch.device('cpu'))
# Query GC once and reuse the arrays for both the full and the off-diagonal views.
raw_gc_factor_ests = _raw_gc_factor_ests(model)

curr_gc_factor_ests = [_normalize_by_max(x) for x in raw_gc_factor_ests]
_plot_factor_maps_and_differences(curr_gc_factor_ests)



print("\n\n\n # OFF-DIAGONAL VISUALIZATIONS ##########################################################################################################")

# Zero self-connections before normalizing, so the colour scale reflects cross-channel edges only.
curr_offDiag_gc_factor_ests = [_normalize_by_max(_zero_diagonal(x)) for x in raw_gc_factor_ests]
_plot_factor_maps_and_differences(curr_offDiag_gc_factor_ests)

## TST REDCLIFF-S Factor Selection

In [None]:
# 01/02/2025: Comparing cross-validated stopping criteria performance of different numbers of factors for REDCLIFF-S (with smoothing) models on TST FULL
#
# Each key is a number of factors; each value lists that model's stopping-criterion
# score on the 5 cross-validation folds. The plotted quantity is the plain fold average.

from matplotlib import pyplot as plt
import numpy as np

fold_scores_by_n_factors = {
    3: [1.8152634328206378, 1.6735604640007018, 1.7691205581029255, 2.0231672916412355, 1.8211184668540952],
    4: [1.644940238793691, 1.5510279677708942, 1.5539443830490112, 1.6118970626195268, 1.5942805566628777],
    5: [1.6701837122599286, 1.4978212252934775, 1.5763399982452393, 1.3798458410898844, 1.5405776233673096],
    6: [1.3647625046412146, 1.4716529840628305, 1.521159366607666, 1.538966859380404, 1.6362474794705708],
    9: [1.5955952912012736, 1.560348970413208, 1.1719687076687815, 1.4126835092902184, 1.4325757805506387],
    18: [1.4241380001703898, 1.4003433541456858, 1.5324883454640708, 1.3831266793568928, 1.4104654241402943],
}
selected_n_factors = 9  # the model highlighted in the plot

n_factor_options = list(fold_scores_by_n_factors)
fold_avg_by_n_factors = {
    n_factors: sum(scores) / 5. for n_factors, scores in fold_scores_by_n_factors.items()
}
a, b, c, d, e, f = (fold_avg_by_n_factors[n_factors] for n_factors in n_factor_options)

for n_factors in n_factor_options:
    print(f"REDCLIFF_S_CMLP_nK{n_factors}_nsK3: ", fold_avg_by_n_factors[n_factors])

other_options = [n_factors for n_factors in n_factor_options if n_factors != selected_n_factors]
plt.plot(n_factor_options, [fold_avg_by_n_factors[k] for k in n_factor_options], color='grey', alpha=0.5)
plt.scatter(other_options, [fold_avg_by_n_factors[k] for k in other_options], marker="+", color='k')
plt.scatter([selected_n_factors], [fold_avg_by_n_factors[selected_n_factors]], marker="^", label="Selected Model", color="orangered")
plt.xlabel("Number of Factors in Model")
plt.ylabel("Avg. Stopping Criteria Performance Across Folds")
plt.title("Determining the Number of Factors for TST (Full) REDCLIFF-S Model")
plt.legend()
plt.show()

In [None]:
# 01/02/2025: Comparing cross-validated stopping criteria performance of different numbers of factors for REDCLIFF-S (with smoothing) models on TST SUBSET
#
# Each key is a number of factors; each value lists that model's stopping-criterion
# score on the 5 cross-validation folds. Every score is written exactly once, so the
# printed averages and the plotted averages cannot drift apart.

from matplotlib import pyplot as plt
import numpy as np

fold_scores_by_n_factors = {
    3: [2.4684973526000977, 2.954598375956217, 2.9015039634704594, 2.4891041628519694, 2.678949788411458],
    4: [2.564681212107341, 2.581714407602946, 2.5638554255167647, 2.7155224482218423, 2.8519512557983395],
    5: [2.271371332804362, 2.059069341023763, 2.5489660135904946, 2.4858938280741376, 2.601390247344971],
    6: [2.195953181584676, 2.5948117383321128, 2.5258967526753744, 2.5001147365570064, 2.7355075772603357],
    9: [2.130304797490438, 2.8107280000050863, 2.250448354085287, 2.1140146795908614, 2.0482943630218506],
    18: [2.5113415718078613, 2.098219045003255, 2.7476954587300617, 2.7606082439422606, 2.124862407048543],
    30: [2.176120545069377, 2.146807723045349, 2.4143293126424155, 2.2496908124287924, 2.1703793573379517],
    45: [2.1492295328776043, 2.4391604391733805, 2.133251234690348, 2.4489138730367026, 2.380597751935323],
}
selected_n_factors = 9  # the model highlighted in the plot

n_factor_options = list(fold_scores_by_n_factors)
# Plain mean over folds (same left-to-right summation as writing the five terms out).
fold_avg_by_n_factors = {
    n_factors: sum(scores) / 5. for n_factors, scores in fold_scores_by_n_factors.items()
}
# Keep the single-letter names the original cell defined.
a, b, c, d, e, f, g, h = (fold_avg_by_n_factors[n_factors] for n_factors in n_factor_options)

for n_factors in n_factor_options:
    print(f"REDCLIFF_S_CMLP_nK{n_factors}_nsK3: ", fold_avg_by_n_factors[n_factors])

other_options = [n_factors for n_factors in n_factor_options if n_factors != selected_n_factors]
plt.plot(n_factor_options, [fold_avg_by_n_factors[k] for k in n_factor_options], color='grey', alpha=0.5)
plt.scatter(other_options, [fold_avg_by_n_factors[k] for k in other_options], marker="+", color='k')
plt.scatter([selected_n_factors], [fold_avg_by_n_factors[selected_n_factors]], marker="^", label="Selected Model", color="orangered")
plt.xlabel("Number of Factors in Model")
plt.ylabel("Avg. Stopping Criteria Performance Across Folds")
plt.title("Determining the Number of Factors for TST (Subset) REDCLIFF-S Model")
plt.legend()
plt.show()


## Preparing REDCLIFF-S TST Reg. Avg. Model Parameters for Appendices

In [None]:
import numpy as np
# model parameters (rescaled values for the REDCLIFF-S TST Reg. Avg. appendix)
# tau_in (i.e. gen_lag_and_input_len) -> USE ORIGINAL VALUE IN CACHED ARGS
# omega (i.e. FORECAST_COEFF)         -> USE ORIGINAL VALUE IN CACHED ARGS
# gamma (i.e. FACTOR_WEIGHT_L1_COEFF) -> USE ORIGINAL VALUE IN CACHED ARGS
# lambda (i.e. FACTOR_SCORE_COEFF)    -> USE ORIGINAL VALUE IN CACHED ARGS

# Raw settings; presumably copied from the model's cached args -- confirm before publishing.
num_factors = 9
num_channels = 12
FACTOR_COS_SIM_COEFF = 1.0
ADJ_L1_REG_COEFF = 0.1

# rho (i.e. FACTOR_COS_SIM_COEFF) -> FACTOR_COS_SIM_COEFF/sum([1.*i for i in range(1,args_dict["num_factors"])])
rho = FACTOR_COS_SIM_COEFF / sum([1. * i for i in range(1, num_factors)])
print("rho is ", rho)
# eta (i.e. ADJ_L1_REG_COEFF) -> ADJ_L1_REG_COEFF*(1./(1.*args_dict["num_factors"]))*(1./np.sqrt(args_dict["num_channels"]**2. - 1.))
eta = ADJ_L1_REG_COEFF * (1. / (1. * num_factors)) * (1. / np.sqrt(num_channels**2. - 1.))
print("eta is ", eta)


---

# Synthetic Systems Experiment Analyses

## SynSys 12-11-2 MSNR Analyses

NON-TRANSPOSED PREDICTION STATS

In [None]:
import numpy as np
import pickle as pkl

# Aggregate per-fold, per-factor evaluation statistics for every algorithm into
# per-algorithm summaries (means and standard errors across folds and factors).

NUM_FOLDS = 5
# Algorithm entries in the fold-0 stats dict carry this prefix; the remainder is the algorithm key.
FOLD0_ALG_KEY_PREFIX = "SynSys12112MSNRFold0_model_name_"
# Statistics whose names contain one of these markers are aggregated as two parallel
# sub-statistics, "normalized" and "raw_count" (taken from value[0] and value[1]).
PAIRED_STAT_MARKERS = ("ancestor_aid", "oset_aid", "parent_aid", "_shd")


def _load_fold_stats(fold_index):
    """Load ``stats_by_alg_key_dict_fold{fold_index}.pkl`` and close the file afterwards."""
    # NOTE: unpickling can execute arbitrary code -- only load locally produced, trusted files.
    with open(f"stats_by_alg_key_dict_fold{fold_index}.pkl", "rb") as stats_file:
        return pkl.load(stats_file)


def _is_paired_stat(stat_key):
    """Return True when ``stat_key`` is aggregated as normalized/raw_count sub-statistics."""
    return any(marker in stat_key for marker in PAIRED_STAT_MARKERS)


def _split_paired_value(value):
    """Return ``(value[0], value[1])``, or ``(value, value)`` when ``value`` is not a 2-item sequence.

    Both items are extracted before anything is appended, so the "normalized" and
    "raw_count" lists always stay the same length.
    """
    try:
        return value[0], value[1]
    except (TypeError, IndexError, KeyError):
        return value, value


def _empty_summary():
    """Return a fresh summary record (key order is the order printed below)."""
    return {
        "fold_means": [],
        "stat_vals_across_folds_and_factors": [],
        "combo_stats_mean": None,
        "combo_stats_sem": None,
        "mean_of_fold_means": None,
        "sem_of_fold_means": None,
    }


def _accumulate(summary, fold_values):
    """Add one fold's list of per-factor values to ``summary``."""
    summary["fold_means"].append(np.mean(fold_values))
    summary["stat_vals_across_folds_and_factors"].extend(fold_values)


def _finalize(summary):
    """Fill in the mean / standard-error fields of ``summary`` from its accumulated values."""
    pooled = summary["stat_vals_across_folds_and_factors"]
    fold_means = summary["fold_means"]
    # NOTE(review): np.std defaults to ddof=0 (population std); kept unchanged so values
    # match previously computed results -- consider ddof=1 for a sample standard error.
    summary["combo_stats_mean"] = np.mean(pooled)
    summary["combo_stats_sem"] = np.std(pooled) / np.sqrt(len(pooled))
    summary["mean_of_fold_means"] = np.mean(fold_means)
    summary["sem_of_fold_means"] = np.std(fold_means) / np.sqrt(len(fold_means))


og_stats_by_fold = [_load_fold_stats(f_ind) for f_ind in range(NUM_FOLDS)]
og_stats_fold0, og_stats_fold1, og_stats_fold2, og_stats_fold3, og_stats_fold4 = og_stats_by_fold

print("og_stats_fold0.keys() == ", og_stats_fold0.keys())
performance_across_folds_by_alg = {
    key[len(FOLD0_ALG_KEY_PREFIX):]: {"fold_"+str(i): dict() for i in range(NUM_FOLDS)}
    for key in og_stats_fold0.keys() if key.startswith(FOLD0_ALG_KEY_PREFIX)
}
print("performance_across_folds_by_alg.keys() == ", performance_across_folds_by_alg.keys())

# Collect, per algorithm and fold, the list of per-factor values of every statistic.
for f_ind, og_stats in enumerate(og_stats_by_fold):
    print("f_ind == ", f_ind)
    for og_alg_key in og_stats.keys():
        for alg_key in performance_across_folds_by_alg.keys():
            # NOTE(review): substring match -- an alg_key contained in another model's name
            # would also collect that model's stats; confirm the model names are distinct.
            if alg_key not in og_alg_key:
                continue
            fold_stats = performance_across_folds_by_alg[alg_key][f"fold_{f_ind}"]
            for factor_key in sorted(og_stats[og_alg_key].keys()):
                for stat_key, stat_val in og_stats[og_alg_key][factor_key].items():
                    if _is_paired_stat(stat_key):
                        normalized, raw_count = _split_paired_value(stat_val)
                        paired_lists = fold_stats.setdefault(stat_key, {"normalized": [], "raw_count": []})
                        paired_lists["normalized"].append(normalized)
                        paired_lists["raw_count"].append(raw_count)
                    else:
                        fold_stats.setdefault(stat_key, []).append(stat_val)

print("performance_across_folds_by_alg == ", performance_across_folds_by_alg)

# Summarize each statistic across folds and factors.
summary_stats_by_alg = {alg_key: dict() for alg_key in performance_across_folds_by_alg.keys()}
for alg_key, stats_by_fold in performance_across_folds_by_alg.items():
    alg_summary = summary_stats_by_alg[alg_key]
    for fold_stats in stats_by_fold.values():
        for stat_key, fold_values in fold_stats.items():
            if _is_paired_stat(stat_key):
                stat_summary = alg_summary.setdefault(stat_key, dict())
                for substat_key, substat_values in fold_values.items():
                    _accumulate(stat_summary.setdefault(substat_key, _empty_summary()), substat_values)
            else:
                _accumulate(alg_summary.setdefault(stat_key, _empty_summary()), fold_values)
    # Finalize from the accumulated summaries themselves, so a statistic missing from
    # some fold cannot cause a KeyError here.
    for stat_key, stat_summary in alg_summary.items():
        if _is_paired_stat(stat_key):
            for substat_summary in stat_summary.values():
                _finalize(substat_summary)
        else:
            _finalize(stat_summary)

print("summary_stats_by_alg == ", summary_stats_by_alg)
for alg, alg_summary in summary_stats_by_alg.items():
    print("alg == ", alg)
    for stat, stat_summary in alg_summary.items():
        print("\t stat == ", stat)
        for summary_key, summary_val in stat_summary.items():
            # Same predicate as the aggregation above, so nesting always matches.
            if _is_paired_stat(stat):
                print("\t\t sub_stat == ", summary_key)
                for sub_stat_summary, sub_stat_val in summary_val.items():
                    print("\t\t\t summary_key == ", sub_stat_summary, " == ", sub_stat_val)
            else:
                print("\t\t summary_key == ", summary_key, " == ", summary_val)

TRANSPOSED PREDICTION STATS

In [None]:
import numpy as np
import pickle as pkl

def _load_fold_stats(fold_idx):
    """Unpickle and return the stats dict saved for cross-validation fold ``fold_idx``.

    The file is opened in a ``with`` block so the handle is closed after loading.
    NOTE: ``pickle.load`` can execute arbitrary code -- only load files produced by
    this project.
    """
    with open(f"stats_by_alg_key_dict_fold{fold_idx}.pkl", "rb") as stats_file:
        return pkl.load(stats_file)


og_stats_fold0 = _load_fold_stats(0)
og_stats_fold1 = _load_fold_stats(1)
og_stats_fold2 = _load_fold_stats(2)
og_stats_fold3 = _load_fold_stats(3)
og_stats_fold4 = _load_fold_stats(4)

# Algorithm names are taken from fold 0's keys by stripping this prefix:
# "SynSys12112MSNRFold0_model_name_<alg>" -> "<alg>".
FOLD0_ALG_KEY_PREFIX = "SynSys12112MSNRFold0_model_name_"

print("og_stats_fold0.keys() == ", og_stats_fold0.keys())
# alg name -> {"fold_0": {}, ..., "fold_4": {}}; filled with per-factor stat values below.
performance_across_folds_by_alg = {
    key[len(FOLD0_ALG_KEY_PREFIX):]: {"fold_" + str(i): dict() for i in range(5)}
    for key in og_stats_fold0.keys()
    if key.startswith(FOLD0_ALG_KEY_PREFIX)
}
print("performance_across_folds_by_alg.keys() == ", performance_across_folds_by_alg.keys())

# Statistics whose per-factor value may be a (normalized, raw_count) pair; they are
# stored as {"normalized": [...], "raw_count": [...]} instead of a flat list.
_PAIRED_STAT_MARKERS = ("ancestor_aid", "oset_aid", "parent_aid", "_shd")


def _is_paired_stat(stat_key):
    """Return True when ``stat_key`` names one of the AID/SHD statistics above."""
    return any(marker in stat_key for marker in _PAIRED_STAT_MARKERS)


def _split_paired_stat(stat_val):
    """Return ``(normalized, raw_count)`` for an AID/SHD value.

    Indexable values give ``(stat_val[0], stat_val[1])``; values that cannot be indexed
    that way (plain scalars) are used for both entries. Both parts are extracted before
    anything is appended, so a one-element value can no longer leave a duplicate entry
    in "normalized" (which the old append-then-fallback code did).
    """
    try:
        return stat_val[0], stat_val[1]
    except (TypeError, IndexError, KeyError):
        return stat_val, stat_val


def _alg_name_from_key(og_alg_key):
    """Return the text after ``"_model_name_"`` in ``og_alg_key``, or None if it is absent."""
    _, sep, alg_name = og_alg_key.partition("_model_name_")
    return alg_name if sep else None


for f_ind, og_stats in enumerate([og_stats_fold0, og_stats_fold1, og_stats_fold2, og_stats_fold3, og_stats_fold4]):
    print("f_ind == ", f_ind)
    for og_alg_key in og_stats.keys():
        # Exact match on the algorithm name. The previous substring test
        # (``alg_key in og_alg_key``) also credited e.g. "..._model_name_NAVAR_CMLP"
        # to the "CMLP" algorithm.
        alg_key = _alg_name_from_key(og_alg_key)
        if alg_key not in performance_across_folds_by_alg:
            continue
        fold_stats = performance_across_folds_by_alg[alg_key][f"fold_{f_ind}"]
        for factor_key in sorted(og_stats[og_alg_key].keys()):
            for stat_key, stat_val in og_stats[og_alg_key][factor_key].items():
                if _is_paired_stat(stat_key):
                    normalized, raw_count = _split_paired_stat(stat_val)
                    paired = fold_stats.setdefault(stat_key, {"normalized": [], "raw_count": []})
                    paired["normalized"].append(normalized)
                    paired["raw_count"].append(raw_count)
                else:
                    fold_stats.setdefault(stat_key, []).append(stat_val)

def _new_stat_summary(vals):
    """Start a summary record from one fold's list of per-factor values ``vals``."""
    return {
        "fold_means": [np.mean(vals)],
        "stat_vals_across_folds_and_factors": list(vals),
        "combo_stats_mean": None,
        "combo_stats_sem": None,
        "mean_of_fold_means": None,
        "sem_of_fold_means": None,
    }


def _add_fold_to_stat_summary(summary, vals):
    """Add one more fold's mean and per-factor values ``vals`` to ``summary``."""
    summary["fold_means"].append(np.mean(vals))
    # The pooled list is the record's own copy (see _new_stat_summary), so extending is safe.
    summary["stat_vals_across_folds_and_factors"].extend(vals)


def _finalize_stat_summary(summary):
    """Fill in the pooled and fold-level mean / standard-error entries of ``summary``.

    The SEM is ``np.std(x) / sqrt(len(x))``; ``np.std`` uses ddof=0 by default.
    """
    pooled = summary["stat_vals_across_folds_and_factors"]
    fold_means = summary["fold_means"]
    summary["combo_stats_mean"] = np.mean(pooled)
    summary["combo_stats_sem"] = np.std(pooled) / np.sqrt(len(pooled))
    summary["mean_of_fold_means"] = np.mean(fold_means)
    summary["sem_of_fold_means"] = np.std(fold_means) / np.sqrt(len(fold_means))


print("performance_across_folds_by_alg == ", performance_across_folds_by_alg)
summary_stats_by_alg = {alg_key: dict() for alg_key in performance_across_folds_by_alg.keys()}
for alg_key, folds in performance_across_folds_by_alg.items():
    alg_summary = summary_stats_by_alg[alg_key]
    for fold_key, fold_stats in folds.items():
        for stat_key, stat_vals in fold_stats.items():
            if isinstance(stat_vals, dict):
                # AID/SHD statistics were stored as {"normalized": [...], "raw_count": [...]}:
                # keep one summary record per sub-statistic.
                stat_summary = alg_summary.setdefault(stat_key, dict())
                for substat_key, substat_vals in stat_vals.items():
                    if substat_key in stat_summary:
                        _add_fold_to_stat_summary(stat_summary[substat_key], substat_vals)
                    else:
                        stat_summary[substat_key] = _new_stat_summary(substat_vals)
            elif stat_key in alg_summary:
                _add_fold_to_stat_summary(alg_summary[stat_key], stat_vals)
            else:
                alg_summary[stat_key] = _new_stat_summary(stat_vals)
    # Finalize from the summary itself. The old code looked sub-statistics up in
    # performance_across_folds_by_alg[alg_key][fold_key] -- whichever fold the loop
    # above happened to end on -- which raises KeyError if that fold lacks the stat.
    for stat_summary in alg_summary.values():
        if "fold_means" in stat_summary:
            _finalize_stat_summary(stat_summary)
        else:
            for substat_summary in stat_summary.values():
                _finalize_stat_summary(substat_summary)

# Dump the aggregated summary: the raw dict first, then an indented listing per
# algorithm and statistic. Flat records (built above for ordinary statistics) hold
# "combo_stats_mean" directly; AID/SHD statistics map each sub-statistic
# ("normalized", "raw_count") to such a record and are printed one level deeper.
# Nesting is read from the record's shape instead of the older, looser
# "_aid"/"_shd" substring test, which could disagree with the aggregation's
# explicit ancestor/oset/parent-AID and SHD classification.
print("summary_stats_by_alg == ", summary_stats_by_alg)
for alg, stats_for_alg in summary_stats_by_alg.items():
    print("alg == ", alg)
    for stat, stat_summary in stats_for_alg.items():
        print("\t stat == ", stat)
        if "combo_stats_mean" in stat_summary:
            for summary_key, summary_val in stat_summary.items():
                print("\t\t summary_key == ", summary_key, " == ", summary_val)
        else:
            for sub_stat, sub_summary in stat_summary.items():
                print("\t\t sub_stat == ", sub_stat)
                for sub_stat_summary, sub_val in sub_summary.items():
                    print("\t\t\t summary_key == ", sub_stat_summary, " == ", sub_val)

## Visualizing Estimated Synthetic Systems Factors 01/29/2025

In [None]:
!pip3 install torcheeg

In [None]:
!pip3 install torch_scatter

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt


# configure model
# The .bin file holds a whole pickled model object (its .GC method is called below),
# not a state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses such files, so ask for full unpickling explicitly.
# NOTE: full unpickling can run arbitrary code -- only load model files you trust.
model_645Fold4 = torch.load("645Fold4_final_model.bin", weights_only=False)
model_645Fold4.primary_gc_est_mode = "fixed_factor_exclusive"

In [None]:
# get gc factor ests
gc_ests_by_sample = model_645Fold4.GC(model_645Fold4.primary_gc_est_mode, X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)
assert len(gc_ests_by_sample) == 1
# .cpu() is a no-op for CPU tensors and lets .numpy() work if the model was restored onto a GPU.
gc_ests = [x.detach().cpu().numpy() for x in gc_ests_by_sample[0]]

print("len(gc_ests) == ", len(gc_ests))
print("gc_ests[0].shape == ", gc_ests[0].shape)

# Sum out axis 2 (presumably the lag axis, given ignore_lag=False above -- TODO confirm).
gc_ests_noLags = [np.sum(x, axis=2) for x in gc_ests]
print("gc_ests_noLags[0].shape == ", gc_ests_noLags[0].shape)

# Zero the diagonal (x - I*x keeps only off-diagonal entries), then divide each matrix by its maximum.
off_diag_nolag_ests = [x-np.eye(x.shape[0])*x for x in gc_ests_noLags]
off_diag_nolag_ests = [x/np.max(x) for x in off_diag_nolag_ests]


def _plot_thresholded_gc(gc_est, threshold, title_prefix):
    """Zero entries of ``gc_est`` not above ``threshold``, show them as a heatmap, and return the result.

    The title is built from the same ``threshold`` used for masking, so the two cannot drift apart.
    """
    thresholded = gc_est * (gc_est > threshold)
    im = plt.imshow(thresholded)
    plt.title(title_prefix + ', Threshold: ' + str(threshold))
    plt.ylabel('Affected series')
    plt.xlabel('Causal series')
    plt.xticks([])
    plt.yticks([])
    plt.colorbar(im)
    plt.show()
    return thresholded


F1_GC_THRESHOLD = 0.34
F3_GC_THRESHOLD = 0.4
f1_gc_nolag_est = _plot_thresholded_gc(off_diag_nolag_ests[1], F1_GC_THRESHOLD, 'Factor 1 Top-5 GC Estimate')
f3_gc_nolag_est = _plot_thresholded_gc(off_diag_nolag_ests[3], F3_GC_THRESHOLD, 'Factor 3 Top-5 GC Estimate')

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt


# configure model
# The .bin file holds a whole pickled model object (its .GC method is called below),
# not a state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses such files, so ask for full unpickling explicitly.
# NOTE: full unpickling can run arbitrary code -- only load model files you trust.
model_12_11_5_Fold4 = torch.load("12_11_5_Fold4_final_model.bin", weights_only=False)
model_12_11_5_Fold4.primary_gc_est_mode = "fixed_factor_exclusive"

In [None]:
# get gc factor ests
gc_ests_by_sample = model_12_11_5_Fold4.GC(model_12_11_5_Fold4.primary_gc_est_mode, X=None, threshold=False, ignore_lag=False, combine_wavelet_representations=True, rank_wavelets=False)
assert len(gc_ests_by_sample) == 1
# .cpu() is a no-op for CPU tensors and lets .numpy() work if the model was restored onto a GPU.
gc_ests = [x.detach().cpu().numpy() for x in gc_ests_by_sample[0]]

print("len(gc_ests) == ", len(gc_ests))
print("gc_ests[0].shape == ", gc_ests[0].shape)

# Sum out axis 2 (presumably the lag axis, given ignore_lag=False above -- TODO confirm).
gc_ests_noLags = [np.sum(x, axis=2) for x in gc_ests]
print("gc_ests_noLags[0].shape == ", gc_ests_noLags[0].shape)

# Zero the diagonal (x - I*x keeps only off-diagonal entries), then divide each matrix by its maximum.
off_diag_nolag_ests = [x-np.eye(x.shape[0])*x for x in gc_ests_noLags]
off_diag_nolag_ests = [x/np.max(x) for x in off_diag_nolag_ests]

# Show every factor's estimate at each candidate threshold.
masking_thresholds = [0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19]
for t in masking_thresholds:
    print("Threshold: ", t, " ---------------------------------------------")
    for i, gc_nolag_est in enumerate(off_diag_nolag_ests):
        mask = gc_nolag_est > t
        gc_nolag_est = gc_nolag_est * mask
        im1 = plt.imshow(gc_nolag_est)
        # Name the factor: every factor shares the threshold, so without the index the
        # figures produced for one threshold are indistinguishable.
        plt.title('Factor ' + str(i) + ' GC Estimate, Threshold: ' + str(t))
        plt.ylabel('Affected series')
        plt.xlabel('Causal series')
        plt.xticks([])
        plt.yticks([])
        plt.colorbar(im1)
        plt.show()

## Highlighted Synth Data Results (01/21/2025)

In [None]:

import numpy as np
import pickle as pkl
from matplotlib import pyplot as plt
from matplotlib import gridspec

# Font sizes shared by every figure drawn in this cell.
FONT_SMALL_SIZE, FONT_MEDIUM_SIZE, FONT_BIGGER_SIZE = 18, 20, 22

plt.rcParams.update({
    'font.size': FONT_SMALL_SIZE,          # default text size
    'axes.titlesize': FONT_BIGGER_SIZE,    # axes title
    'axes.labelsize': FONT_MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': FONT_SMALL_SIZE,    # x tick labels
    'ytick.labelsize': FONT_SMALL_SIZE,    # y tick labels
    'legend.fontsize': FONT_SMALL_SIZE,    # legend text
    'figure.titlesize': FONT_BIGGER_SIZE,  # figure title
})



def get_data_name_alias_for_plot_axes(orig_name):
    """Return ``orig_name`` followed by three spaces of padding, for use as an axis tick label."""
    return "".join([orig_name, "   "])

def get_data_name_alias(orig_name):
    """Shorten a dataset name such as ``"nN6_nE2_nF2"`` to ``"6-2-2"``.

    Each ``_``-separated field is a two-character tag followed by an integer; the tag
    is dropped and the integer kept (so ``"nN06"`` contributes ``"6"``).
    """
    counts = (int(field[2:]) for field in orig_name.split("_"))
    return "-".join(map(str, counts))

def get_alg_name_alias(orig_name):
    """Return the plot display name for an internal algorithm key.

    Keys without an entry are returned unchanged. Both REDCLIFF-S variants (with and
    without smoothing) share the display name 'REDCLIFF-S (cMLP)'.
    """
    # (internal key, display name) pairs, compared in order.
    display_names = (
        ('REDCLIFF_S_CMLP', 'REDCLIFF-S (cMLP)'),
        ('REDCLIFF_S_CMLP_WithSmoothing', 'REDCLIFF-S (cMLP)'),
        ('CMLP', 'cMLP'),
        ('CLSTM', 'cLSTM'),
        ('DCSFA', 'dCSFA-NMF'),
        ('DYNOTEARS_Vanilla', 'DYNOTEARS'),
        ('NAVAR_CMLP', 'NAVAR-P'),
        ('NAVAR_CLSTM', 'NAVAR-R'),
    )
    return next((alias for key, alias in display_names if orig_name == key), orig_name)



# Synthetic systems to compare, in plotting order; each name encodes
# (n_c, n_e, n_k) as nN<n_c>_nE<n_e>_nF<n_k>.
dataset_names = [
    f"nN{n_c}_nE{n_e}_nF{n_k}"
    for n_c, n_e, n_k in ((6, 2, 2), (6, 4, 2), (12, 11, 2), (12, 11, 5))
]
# Algorithm display names; filled in from the first dataset's results.
alg_names = []
# Bar (mean) and whisker (SEM) colours, one per algorithm in results order.
mean_colors = ["darkorange", "darkred", "mediumvioletred", "darkslateblue", "indigo"]
sem_colors = ["orangered", "tomato", "lightcoral", "slategrey", "mediumpurple"]

# Flat per-(dataset, algorithm) results; allocated once the algorithm count is known.
alg_performance_means_tOpt = alg_performance_sems_tOpt = None
alg_vREDC_performance_means_tOpt = alg_vREDC_performance_sems_tOpt = None


def _finite_or_nan(value):
    """Return ``value`` when it is None or finite; map inf / nan values to ``np.nan``."""
    if value is None or np.isfinite(value):
        return value
    return np.nan


for d, dataset in enumerate(dataset_names):
    # Per-algorithm results for this dataset, in the order the pickle lists the algorithms.
    curr_alg_performance_means_tOpt = []
    curr_alg_performance_sems_tOpt = []
    curr_alg_vREDC_performance_means_tOpt = []
    curr_alg_vREDC_performance_sems_tOpt = []

    # read in bsOH results
    with open(dataset+"_full_comparrisson_summary.pkl", "rb") as f:
        curr_synth_results = pkl.load(f)

    assert len(curr_synth_results.keys()) == 1
    cv_key = list(curr_synth_results.keys())[0]
    stats_by_alg_key = curr_synth_results[cv_key]['key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag']
    for i, alg_key in enumerate(stats_by_alg_key.keys()):
        display_name = get_alg_name_alias(alg_key)
        # The first dataset fixes the algorithm order; later datasets must match it.
        if d == 0:
            alg_names.append(display_name)
        else:
            assert display_name == alg_names[i]

        curr_alg_eval_stats = stats_by_alg_key[alg_key]
        for stat_key, stat_val in curr_alg_eval_stats.items():
            if 'f1' not in stat_key:
                continue
            if "mean_across_factors" in stat_key:
                curr_alg_performance_means_tOpt.append(stat_val)
            elif "mean_std_err_across_factors" in stat_key:
                curr_alg_performance_sems_tOpt.append(stat_val)
            elif "vals_across_factors" in stat_key:
                # Paired per-factor differences: smoothed REDCLIFF-S minus this algorithm.
                redcliff_vals = stats_by_alg_key['REDCLIFF_S_CMLP_WithSmoothing'][stat_key]
                curr_diffs = [x-y for x, y in zip(redcliff_vals, stat_val)]
                curr_alg_vREDC_performance_means_tOpt.append(np.mean(curr_diffs))
                curr_alg_vREDC_performance_sems_tOpt.append(np.std(curr_diffs)/np.sqrt(1.*len(curr_diffs)))

        # Exactly one value of each kind must have been collected per algorithm so far.
        for collected in (curr_alg_performance_means_tOpt, curr_alg_performance_sems_tOpt,
                          curr_alg_vREDC_performance_means_tOpt, curr_alg_vREDC_performance_sems_tOpt):
            assert len(collected) == i+1

    if d == 0:
        # Flat, dataset-major storage: (dataset d, algorithm i) lives at d*len(alg_names) + i.
        n_slots = len(dataset_names) * len(alg_names)
        alg_performance_means_tOpt = [None] * n_slots
        alg_performance_sems_tOpt = [None] * n_slots
        alg_vREDC_performance_means_tOpt = [None] * n_slots
        alg_vREDC_performance_sems_tOpt = [None] * n_slots

    base = d * len(alg_names)
    for i in range(len(alg_names)):
        alg_performance_means_tOpt[base + i] = _finite_or_nan(curr_alg_performance_means_tOpt[i])
        alg_performance_sems_tOpt[base + i] = _finite_or_nan(curr_alg_performance_sems_tOpt[i])
        alg_vREDC_performance_means_tOpt[base + i] = _finite_or_nan(curr_alg_vREDC_performance_means_tOpt[i])
        alg_vREDC_performance_sems_tOpt[base + i] = _finite_or_nan(curr_alg_vREDC_performance_sems_tOpt[i])


# remainder of code drafted with help from ChatGPT
# Layout: each dataset gets len(alg_names) + 1 row slots -- one bar per algorithm
# plus an empty spacer slot separating consecutive datasets.
bar_width = 0.9
index = np.arange((len(alg_names)+1)*len(dataset_names))
total_num_bars_per_dset = len(alg_names)
width_of_bars_per_dset = bar_width

# Pad with a None "algorithm" for the spacer slot; None entries are skipped when drawing.
alg_names = alg_names + [None]
mean_colors = mean_colors + [None]
sem_colors = sem_colors + [None]
# Dataset position of every row slot (slot ``ind`` belongs to dataset ``ind // len(alg_names)``).
alg_index_offset = [d for d in range(len(dataset_names)) for _ in range(len(alg_names))]


def _plot_alg_bar_chart(means_by_slot, sems_by_slot, results_name, title, xlabel, xlim):
    """Draw one horizontal bar chart (one bar per algorithm, grouped by dataset) and show it.

    ``means_by_slot`` / ``sems_by_slot`` are flat dataset-major lists whose entry for
    (dataset d, algorithm a) is at ``d*n_algs + a``; ``slot - alg_index_offset[slot]``
    maps a padded row slot back to that entry. ``results_name`` only labels the debug
    printout. Returns the ``(figure, axes)`` pair.
    """
    fig = plt.figure(figsize=(9.5, 7))
    gs = gridspec.GridSpec(1, 1)
    ax = plt.subplot(gs[0])
    n_slots = len(alg_names)

    # Horizontal bar plot with whiskers
    for a, (alg_name, mean_color, sem_color) in enumerate(zip(alg_names, mean_colors, sem_colors)):
        if alg_name is None:
            continue  # spacer slot between datasets
        print("alg_name == ", alg_name)
        print("\t len(" + results_name + ") == ", len(means_by_slot))

        curr_inds = [ind for ind in index if ind % n_slots == a]
        print("\t curr_inds == ", curr_inds)
        curr_means = [means_by_slot[ind-offset] for ind, offset in zip(index, alg_index_offset) if ind % n_slots == a]
        curr_sems = [sems_by_slot[ind-offset] for ind, offset in zip(index, alg_index_offset) if ind % n_slots == a]
        print("\t len(curr_sems) == ", len(curr_sems))

        ax.barh([ind - width_of_bars_per_dset/2 for ind in curr_inds], curr_means, xerr=curr_sems, ecolor=sem_color, height=bar_width, color=mean_color, capsize=5, label=alg_name)

    # One tick per row slot; only the first slot of each dataset carries a label.
    ax.set_yticks([i-1 for i in index])
    ylabels = [
        get_data_name_alias(dataset_names[offset]) if ind % n_slots == 0 else ""
        for ind, offset in zip(index, alg_index_offset)
    ]
    ax.set_yticklabels([get_data_name_alias_for_plot_axes(lab) for lab in ylabels], rotation=90)
    ax.yaxis.set_ticks_position('none')

    # Customize the grid: Add both major and minor grid lines
    ax.grid(True, axis='x', which='major', linestyle=':', linewidth=0.75, color='grey')  # Major grid lines
    ax.minorticks_on()  # Enable minor ticks
    ax.grid(True, axis='x', which='minor', linestyle=':', linewidth=0.5, color='lightgray')  # Minor grid lines

    ax.invert_yaxis()  # Invert y-axis to display the first category at the top
    ax.legend()

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Synthetic System Name ("+r'$n_c$'+"-"+r'$n_e$'+"-"+r'$n_k$'+")")
    ax.set_xlim(*xlim)
    plt.tight_layout()
    plt.show()
    return fig, ax


# Absolute optimal F1-scores per algorithm.
fig, ax1 = _plot_alg_bar_chart(
    alg_performance_means_tOpt,
    alg_performance_sems_tOpt,
    "alg_performance_means_tOpt",
    'Synthetic System Edge Prediction',
    'Avg. Optimal F1-Score '+r'$\pm$'+' Std. Err. of the Mean',
    (.0, 0.70),
)

# Paired improvement of smoothed REDCLIFF-S over each algorithm.
fig, ax1 = _plot_alg_bar_chart(
    alg_vREDC_performance_means_tOpt,
    alg_vREDC_performance_sems_tOpt,
    "alg_vREDC_performance_means_tOpt",
    'Pair-wise Improvement by REDCLIFF-S for\nSynthetic System Edge Prediction',
    'Avg. Difference In Optimal F1-Score '+r'$\pm$'+' Std. Err. of the Mean',
    (-.005, 0.4),
)


## Summarizing All Synthetic Systems Experiments 01/21/2025

In [None]:
# REDCLIFF-S WITH NO SIGMOID ACTIVATION Improvements On Gaussian-noise/Innov Data

import numpy as np
from matplotlib import pyplot as plt

# Font sizes shared by every figure drawn in this cell.
FONT_SMALL_SIZE, FONT_MEDIUM_SIZE, FONT_BIGGER_SIZE = 18, 20, 22

plt.rcParams.update({
    'font.size': FONT_SMALL_SIZE,          # default text size
    'axes.titlesize': FONT_BIGGER_SIZE,    # axes title
    'axes.labelsize': FONT_MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': FONT_SMALL_SIZE,    # x tick labels
    'ytick.labelsize': FONT_SMALL_SIZE,    # y tick labels
    'legend.fontsize': FONT_SMALL_SIZE,    # legend text
    'figure.titlesize': FONT_BIGGER_SIZE,  # figure title
})

# Synthetic systems in x-axis order, named "<n_c>-<n_e>-<n_k>".
systems = [
    "3-1-2", "3-1-3", "3-1-4", "3-1-5", "3-2-2",
    "6-2-2", "6-2-3", "6-2-4", "6-2-5", "6-2-6", "6-2-7", "6-2-8", "6-2-9", "6-2-10",
    "6-4-2", "6-4-3", "6-4-4", "6-4-5", "6-4-6",
    "6-6-2", "6-6-3", "6-6-4", "6-8-2", "6-8-3", "6-10-2", "6-12-2",
    "12-11-2", "12-11-3", "12-11-4", "12-11-5", "12-11-6", "12-11-7", "12-11-8", "12-11-9", "12-11-10",
    "12-12-2", "12-12-3", "12-12-4", "12-12-5", "12-12-6", "12-12-7", "12-12-8", "12-12-9",
    "12-22-2", "12-22-3", "12-22-4", "12-22-5",
    "12-33-2", "12-33-3", "12-44-2", "12-55-2",
]

# Result category of each system (this cell only tests key membership).
systems_where_no_significant_improvement_shown = {
    "3-1-2": [-1], "3-1-3": [-1], "3-1-4": [-1], "3-1-5": [-1], "3-2-2": [-1],
    "6-2-5": [-1], "6-4-3": [-1], "6-6-3": [-1], "6-8-2": [-1], "6-8-3": [-1], "6-12-2": [-1],
    "12-22-3": [-1], "12-22-4": [-1], "12-22-5": [-1],
    "12-33-2": [-1], "12-33-3": [-1], "12-44-2": [-1], "12-55-2": [-1],
}
systems_where_signif_improv_shown_withOptThresh = {
    "6-2-2": [1], "6-2-3": [1], "6-2-6": [1], "6-2-7": [1], "6-2-8": [1], "6-2-9": [1],
    "6-4-2": [1], "6-4-4": [1], "6-4-5": [1], "6-4-6": [1], "6-6-2": [1], "6-6-4": [1], "6-10-2": [1],
    "12-11-2": [1], "12-11-3": [1], "12-11-4": [1], "12-11-5": [1], "12-11-6": [1], "12-11-7": [1], "12-11-8": [1], "12-11-9": [1],
    "12-12-2": [1], "12-12-3": [1], "12-12-4": [1], "12-12-5": [1], "12-12-6": [1], "12-12-7": [1], "12-12-8": [1], "12-12-9": [1], "12-22-2": [1],
}
systems_with_no_results = {
    "6-2-4": [0],
    "6-2-10": [0],
    "12-11-10": [0],
}

# Styling per category, indexed in the same order as the category dicts above.
system_colors = ['grey','darkorange', 'blue']
system_markers = ['X', 'P', 'o']
system_labels = ['Not Significant for All Baselines', "Significant for All Baselines", "No Result"]


def C(x):
    """Return the complexity of system ``x = (n_c, n_e, n_k)``, i.e. ``(n_c**2 - n_c) / n_e``.

    Only ``n_c`` and ``n_e`` enter the formula. The expression is kept in its original
    ``(n_e / (n_c**2 - n_c)) ** -1`` form so plotted values are bit-identical.
    """
    return (x[1]/(x[0]**2 - x[0]))**(-1)


def _plot_system_complexity(title, ylim=None, print_first_labelled_keys=False):
    """Scatter every system in ``systems`` at its complexity, styled by result category, and show it.

    Each category's legend label is attached to its first plotted point only. When
    ``print_first_labelled_keys`` is true, the key of that first point is also printed.
    """
    plt.rcParams["figure.figsize"] = (17,8)
    category_maps = [
        systems_where_no_significant_improvement_shown,
        systems_where_signif_improv_shown_withOptThresh,
        systems_with_no_results,
    ]
    labels_plotted = set()
    for key in systems:
        split_key = key.split("-")
        nc, ne, nk = int(split_key[0]), int(split_key[1]), int(split_key[2])
        curr_complexity = C((nc, ne, nk))

        print("nc ", nc, ", ne ", ne, ", nk ", nk, ": C == ", curr_complexity)

        for i, dataset_to_score_map in enumerate(category_maps):
            if key not in dataset_to_score_map:
                continue
            scatter_kwargs = dict(color=system_colors[i], marker=system_markers[i], s=100, alpha=1)
            if i not in labels_plotted:
                if print_first_labelled_keys:
                    print("\t key == ", key)
                scatter_kwargs["label"] = system_labels[i]
                labels_plotted.add(i)
            plt.scatter([key], [curr_complexity], **scatter_kwargs)

    # Horizontal reference lines at complexity 7 and 13 across the categorical x positions.
    plt.plot([7. for _ in range(len(systems))], '-k')
    plt.plot([13. for _ in range(len(systems))], '-k')
    plt.xticks(rotation=70)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.ylabel("System Complexity ("+r'$\mathfrak{C}$'+")")
    plt.xlabel("Synthetic System Name ("+r'$n_c$'+"-"+r'$n_e$'+"-"+r'$n_k$'+")")
    plt.title(title)
    plt.legend()
    plt.show()


_plot_system_complexity(
    "Smoothed REDCLIFF-S Improvement Over Baselines (Gaussian Noise): Pair-wise Optimal F1-Scores",
    ylim=(6., 14.),
)
# NOTE(review): this figure uses exactly the same category dicts as the one above; only
# the title (un-smoothed REDCLIFF-S) and the y-limits differ -- confirm whether it
# should be driven by separate results.
_plot_system_complexity(
    "REDCLIFF-S Improvement Over Baselines (Gaussian Noise): Pair-wise Optimal F1-Scores",
    print_first_labelled_keys=True,
)

## Preparing REDCLIFF-S Synthetic Systems Model Parameters for Appendices

In [None]:
import numpy as np
# model parameters
# tau_in (i.e. gen_lag_and_input_len) -> USE ORIGINAL VALUE IN CACHED ARGS
# omega (i.e. FORECAST_COEFF)         -> USE ORIGINAL VALUE IN CACHED ARGS
# gamma (i.e. FACTOR_WEIGHT_L1_COEFF) -> USE ORIGINAL VALUE IN CACHED ARGS
# lambda (i.e. FACTOR_SCORE_COEFF)    -> USE ORIGINAL VALUE IN CACHED ARGS
# rho and eta are rescaled from the cached args by the two functions below.


def compute_rho(factor_cos_sim_coeff, num_factors):
    """Return rho = FACTOR_COS_SIM_COEFF / sum(1, 2, ..., num_factors - 1).

    The denominator equals num_factors * (num_factors - 1) / 2, the number of distinct
    factor pairs.

    Raises:
        ValueError: if num_factors < 2, where the sum is empty and rho is undefined.
    """
    num_factor_pairs = sum(1. * i for i in range(1, num_factors))
    if num_factor_pairs == 0:
        raise ValueError(f"rho needs num_factors >= 2, got {num_factors}")
    return factor_cos_sim_coeff / num_factor_pairs


def compute_eta(adj_l1_reg_coeff, num_factors, num_channels):
    """Return eta = ADJ_L1_REG_COEFF * (1 / num_factors) * (1 / sqrt(num_channels**2 - 1)).

    Raises:
        ValueError: unless num_factors >= 1 and num_channels >= 2 (otherwise a
            denominator is zero).
    """
    if num_factors < 1 or num_channels < 2:
        raise ValueError(
            f"eta needs num_factors >= 1 and num_channels >= 2, got {num_factors} and {num_channels}"
        )
    return adj_l1_reg_coeff*(1./(1.*num_factors))*(1./np.sqrt(num_channels**2. - 1.))


# Fill these in from the cached args of the run being reported. They replace the bare
# ``_`` placeholders used previously, which in IPython silently picked up the last
# cell output instead of the intended values.
cached_args = {
    "FACTOR_COS_SIM_COEFF": None,
    "ADJ_L1_REG_COEFF": None,
    "num_factors": None,
    "num_channels": None,
}
missing_args = [name for name, val in cached_args.items() if val is None]
if missing_args:
    print("Fill in cached_args before computing rho/eta; missing: ", missing_args)
else:
    print("rho is ", compute_rho(cached_args["FACTOR_COS_SIM_COEFF"], cached_args["num_factors"]))
    print("eta is ", compute_eta(cached_args["ADJ_L1_REG_COEFF"], cached_args["num_factors"], cached_args["num_channels"]))


### Highlighted Synth Data Results (01/18/2025)

In [None]:

import numpy as np
import pickle as pkl
from matplotlib import pyplot as plt
from matplotlib import gridspec

# Font sizes shared by every figure drawn in this cell.
FONT_SMALL_SIZE, FONT_MEDIUM_SIZE, FONT_BIGGER_SIZE = 18, 20, 22

plt.rcParams.update({
    'font.size': FONT_SMALL_SIZE,          # default text size
    'axes.titlesize': FONT_BIGGER_SIZE,    # axes title
    'axes.labelsize': FONT_MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': FONT_SMALL_SIZE,    # x tick labels
    'ytick.labelsize': FONT_SMALL_SIZE,    # y tick labels
    'legend.fontsize': FONT_SMALL_SIZE,    # legend text
    'figure.titlesize': FONT_BIGGER_SIZE,  # figure title
})



def get_data_name_alias_for_plot_axes(orig_name):
    """Return ``orig_name`` followed by three spaces of padding, for use as an axis tick label."""
    return "".join([orig_name, "   "])

def get_data_name_alias(orig_name):
    """Shorten a dataset name such as ``"nN6_nE2_nF2"`` to ``"6-2-2"``.

    Each ``_``-separated field is a two-character tag followed by an integer; the tag
    is dropped and the integer kept (so ``"nN06"`` contributes ``"6"``).
    """
    counts = (int(field[2:]) for field in orig_name.split("_"))
    return "-".join(map(str, counts))

def get_alg_name_alias(orig_name):
    """Map an internal algorithm key to the display name used in figures.

    Keys without a registered alias are returned unchanged. Both the plain and
    the "_WithSmoothing" REDCLIFF-S keys map to the same display name.
    """
    display_names = {
        'REDCLIFF_S_CMLP': 'REDCLIFF-S (cMLP)',
        'REDCLIFF_S_CMLP_WithSmoothing': 'REDCLIFF-S (cMLP)',
        'CMLP': 'cMLP',
        'CLSTM': 'cLSTM',
        'DCSFA': 'dCSFA-NMF',
        'DYNOTEARS_Vanilla': 'DYNOTEARS',
        'NAVAR_CMLP': 'NAVAR-P',
        'NAVAR_CLSTM': 'NAVAR-R',
    }
    return display_names.get(orig_name, orig_name)



dataset_names = [
    "nN6_nE2_nF2",
    "nN6_nE4_nF2",
    "nN12_nE11_nF2",
    "nN12_nE11_nF5",
]
alg_names = []  # display names, in the order the first summary pickle lists them
mean_colors = ["darkorange", "darkred", "mediumvioletred", "darkslateblue", "indigo"]
sem_colors = ["orangered", "tomato", "lightcoral", "slategrey", "mediumpurple"]

# Flat result lists indexed by dataset_idx * len(alg_names) + alg_idx; they are
# allocated once the algorithm list is known (after reading the first dataset).
alg_performance_means_tOpt = None
alg_performance_sems_tOpt = None
alg_vREDC_performance_means_tOpt = None
alg_vREDC_performance_sems_tOpt = None

_GC_STATS_KEY = 'key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag'
_REFERENCE_ALG_KEY = 'REDCLIFF_S_CMLP_WithSmoothing'


def _finite_or_nan(value):
    """Return *value* if it is None or finite; otherwise (NaN, +/-inf) return np.nan."""
    if value is None or np.isfinite(value):
        return value
    return np.nan


for d, dataset in enumerate(dataset_names):
    curr_alg_performance_means_tOpt = []
    curr_alg_performance_sems_tOpt = []
    curr_alg_vREDC_performance_means_tOpt = []
    curr_alg_vREDC_performance_sems_tOpt = []

    # Load this dataset's comparison summary; exactly one top-level key is expected.
    with open(dataset + "_full_comparrisson_summary.pkl", "rb") as f:
        curr_synth_results = pkl.load(f)
    assert len(curr_synth_results.keys()) == 1
    cv_key = next(iter(curr_synth_results))
    stats_by_alg = curr_synth_results[cv_key][_GC_STATS_KEY]

    for i, alg_key in enumerate(stats_by_alg.keys()):
        alias = get_alg_name_alias(alg_key)
        if d == 0:
            alg_names.append(alias)
        else:
            # Every dataset must list the algorithms in the same order.
            assert alias == alg_names[i]

        curr_alg_eval_stats = stats_by_alg[alg_key]
        for stat_key in curr_alg_eval_stats.keys():
            if 'f1' not in stat_key:
                continue
            if "mean_across_factors" in stat_key:
                curr_alg_performance_means_tOpt.append(curr_alg_eval_stats[stat_key])
            elif "mean_std_err_across_factors" in stat_key:
                curr_alg_performance_sems_tOpt.append(curr_alg_eval_stats[stat_key])
            elif "vals_across_factors" in stat_key:
                # Pair-wise differences: reference REDCLIFF-S value minus this algorithm's value.
                reference_vals = stats_by_alg[_REFERENCE_ALG_KEY][stat_key]
                curr_diffs = [ref - val for ref, val in zip(reference_vals, curr_alg_eval_stats[stat_key])]
                curr_alg_vREDC_performance_means_tOpt.append(np.mean(curr_diffs))
                curr_alg_vREDC_performance_sems_tOpt.append(np.std(curr_diffs)/np.sqrt(1.*len(curr_diffs)))

        # Each algorithm must contribute exactly one value to every list.
        assert len(curr_alg_performance_means_tOpt) == i+1
        assert len(curr_alg_performance_sems_tOpt) == i+1
        assert len(curr_alg_vREDC_performance_means_tOpt) == i+1
        assert len(curr_alg_vREDC_performance_sems_tOpt) == i+1

    if d == 0:
        n_result_slots = len(dataset_names) * len(alg_names)
        alg_performance_means_tOpt = [None] * n_result_slots
        alg_performance_sems_tOpt = [None] * n_result_slots
        alg_vREDC_performance_means_tOpt = [None] * n_result_slots
        alg_vREDC_performance_sems_tOpt = [None] * n_result_slots

    base_slot = d * len(alg_names)
    for i in range(len(alg_names)):
        alg_performance_means_tOpt[base_slot + i] = _finite_or_nan(curr_alg_performance_means_tOpt[i])
        alg_performance_sems_tOpt[base_slot + i] = _finite_or_nan(curr_alg_performance_sems_tOpt[i])
        alg_vREDC_performance_means_tOpt[base_slot + i] = _finite_or_nan(curr_alg_vREDC_performance_means_tOpt[i])
        alg_vREDC_performance_sems_tOpt[base_slot + i] = _finite_or_nan(curr_alg_vREDC_performance_sems_tOpt[i])


# remainder of code drafted with help from ChatGPT
# Layout: each dataset gets len(alg_names) + 1 consecutive y-slots, one per
# algorithm plus one empty spacer slot (represented by a trailing None entry).
bar_width = 0.9
index = np.arange((len(alg_names)+1)*len(dataset_names))
total_num_bars_per_dset = len(alg_names)
width_of_bars_per_dset = bar_width

alg_names = alg_names + [None]
# Pad the colour lists to cover every slot. zip() below stops at the shortest
# list, so with more algorithms than listed colours the extra algorithms would
# be silently dropped from both figures. A None colour lets matplotlib pick one.
mean_colors = mean_colors + [None] * max(1, len(alg_names) - len(mean_colors))
sem_colors = sem_colors + [None] * max(1, len(alg_names) - len(sem_colors))
# Dataset index that owns each y-slot.
alg_index_offset = [d for d in range(len(dataset_names)) for _ in range(len(alg_names))]


def _plot_alg_performance_barh(values_name, means_by_slot, sems_by_slot, title, xlabel, xlim):
    """Draw one grouped horizontal bar chart (one bar group per dataset) and show it.

    Args:
        values_name: name echoed in the debug prints, identifying the plotted list.
        means_by_slot: flat per-(dataset, algorithm) bar lengths, indexed by
            dataset_idx * n_algs + alg_idx.
        sems_by_slot: matching flat list of x-error (whisker) sizes.
        title: axes title.
        xlabel: x-axis label.
        xlim: (left, right) x-axis limits.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    plt.figure(figsize=(9.5, 7))
    ax = plt.subplot(gridspec.GridSpec(1, 1)[0])
    n_slots = len(alg_names)  # algorithms + spacer slot per dataset

    # Horizontal bar plot with whiskers
    for a, (alg_name, mean_color, sem_color) in enumerate(zip(alg_names, mean_colors, sem_colors)):
        if alg_name is None:
            continue  # spacer slot: nothing to draw
        print("alg_name == ", alg_name)
        print("\t len(" + values_name + ") == ", len(means_by_slot))
        curr_inds = [ind for ind in index if ind % n_slots == a]
        print("\t curr_inds == ", curr_inds)
        # Subtracting the dataset index skips the spacer slots of the preceding
        # datasets, mapping a y-slot back onto the flat results list.
        curr_means = [means_by_slot[ind - offset] for ind, offset in zip(index, alg_index_offset) if ind % n_slots == a]
        curr_sems = [sems_by_slot[ind - offset] for ind, offset in zip(index, alg_index_offset) if ind % n_slots == a]
        print("\t len(curr_sems) == ", len(curr_sems))
        ax.barh([ind - width_of_bars_per_dset/2 for ind in curr_inds], curr_means, xerr=curr_sems, ecolor=sem_color, height=bar_width, color=mean_color, capsize=5, label=alg_name)

    # One tick per y-slot (shifted by one); only the first slot of each dataset is labelled.
    ax.set_yticks([i-1 for i in index])
    ylabels = [
        get_data_name_alias(dataset_names[offset]) if ind % n_slots == 0 else ""
        for ind, offset in zip(index, alg_index_offset)
    ]
    ax.set_yticklabels([get_data_name_alias_for_plot_axes(lab) for lab in ylabels], rotation=90)
    ax.yaxis.set_ticks_position('none')

    # Major and minor dotted grid lines along x.
    ax.grid(True, axis='x', which='major', linestyle=':', linewidth=0.75, color='grey')
    ax.minorticks_on()
    ax.grid(True, axis='x', which='minor', linestyle=':', linewidth=0.5, color='lightgray')

    ax.invert_yaxis()  # first dataset at the top
    ax.legend()

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Synthetic System Name ("+r'$n_c$'+"-"+r'$n_e$'+"-"+r'$n_k$'+")")
    ax.set_xlim(*xlim)
    plt.tight_layout()
    plt.show()
    return ax


ax1 = _plot_alg_performance_barh(
    "alg_performance_means_tOpt",
    alg_performance_means_tOpt,
    alg_performance_sems_tOpt,
    'Synthetic System Edge Prediction',
    'Avg. Optimal F1-Score '+r'$\pm$'+' Std. Err. of the Mean',
    (.0, 0.70),
)

ax1 = _plot_alg_performance_barh(
    "alg_vREDC_performance_means_tOpt",
    alg_vREDC_performance_means_tOpt,
    alg_vREDC_performance_sems_tOpt,
    'Pair-wise Improvement by REDCLIFF-S for\nSynthetic System Edge Prediction',
    'Avg. Difference In Optimal F1-Score '+r'$\pm$'+' Std. Err. of the Mean',
    (-.005, 0.4),
)


### Summarizing All Synthetic Systems Experiments 01/17/2025

In [None]:
# REDCLIFF-S WITH NO SIGMOID ACTIVATION: Improvements On Gaussian-noise/Innov Data

import numpy as np
from matplotlib import pyplot as plt

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

# Font sizes grouped by matplotlib rc group; each group is applied with one plt.rc call.
_font_rc_by_group = {
    'font': dict(size=FONT_SMALL_SIZE),                                    # default text
    'axes': dict(titlesize=FONT_BIGGER_SIZE, labelsize=FONT_MEDIUM_SIZE),  # titles, axis labels
    'xtick': dict(labelsize=FONT_SMALL_SIZE),                              # x tick labels
    'ytick': dict(labelsize=FONT_SMALL_SIZE),                              # y tick labels
    'legend': dict(fontsize=FONT_SMALL_SIZE),                              # legend text
    'figure': dict(titlesize=FONT_BIGGER_SIZE),                            # figure suptitle
}
for _group_name, _group_sizes in _font_rc_by_group.items():
    plt.rc(_group_name, **_group_sizes)

systems = [
    "3-1-2", "3-1-3", "3-1-4", "3-1-5", "3-2-2",
    "6-2-2", "6-2-3", "6-2-4", "6-2-5", "6-2-6", "6-2-7", "6-2-8", "6-2-9", "6-2-10",
    "6-4-2", "6-4-3", "6-4-4", "6-4-5", "6-4-6",
    "6-6-2", "6-6-3", "6-6-4", "6-8-2", "6-8-3", "6-10-2", "6-12-2",
    "12-11-2", "12-11-3", "12-11-4", "12-11-5", "12-11-6", "12-11-7", "12-11-8", "12-11-9", "12-11-10",
    "12-12-2", "12-12-3", "12-12-4", "12-12-5", "12-12-6", "12-12-7", "12-12-8", "12-12-9",
    "12-22-2", "12-22-3", "12-22-4", "12-22-5",
    "12-33-2", "12-33-3", "12-44-2", "12-55-2",
]

# Outcome category of each system (only membership is used below; values are unused).
systems_where_no_significant_improvement_shown = {
    "3-1-2": [-1], "3-1-3": [-1], "3-1-4": [-1], "3-1-5": [-1], "3-2-2": [-1],
    "6-2-5": [-1], "6-4-3": [-1], "6-6-3": [-1], "6-8-2": [-1], "6-8-3": [-1], "6-12-2": [-1],
    "12-12-7": [-1], "12-22-3": [-1], "12-22-4": [-1], "12-22-5": [-1],
    "12-33-2": [-1], "12-33-3": [-1], "12-44-2": [-1], "12-55-2": [-1],
}
systems_where_signif_improv_shown_withOptThresh = {
    "6-2-2": [1], "6-2-3": [1], "6-2-6": [1], "6-2-7": [1], "6-2-8": [1], "6-2-9": [1],
    "6-4-2": [1], "6-4-4": [1], "6-4-5": [1], "6-4-6": [1], "6-6-2": [1], "6-6-4": [1], "6-10-2": [1],
    "12-11-2": [1], "12-11-3": [1], "12-11-4": [1], "12-11-5": [1], "12-11-6": [1], "12-11-7": [1], "12-11-8": [1], "12-11-9": [1],
    "12-12-2": [1], "12-12-3": [1], "12-12-4": [1], "12-12-5": [1], "12-12-6": [1], "12-12-8": [1], "12-12-9": [1], "12-22-2": [1],
}
systems_with_no_results = {
    "6-2-4": [0],
    "6-2-10": [0],
    "12-11-10": [0],
}

# Plot style per outcome category, in the same order as the three maps above.
system_colors = ['grey','darkorange', 'blue']
system_markers = ['X', 'P', 'o']
system_labels = ['Not Significant for All Baselines', "Significant for All Baselines", "No Result"]


def C(x):
    """Return the complexity of system x = (n_c, n_e, n_k), i.e. (n_c**2 - n_c) / n_e.

    Only n_c and n_e enter the formula; n_k is ignored.
    """
    return (x[1]/(x[0]**2 - x[0]))**(-1)


def _plot_system_complexity(title, ylim=None, print_first_keys=False):
    """Scatter every synthetic system at its complexity, styled by outcome category, and show it.

    Args:
        title: figure title.
        ylim: optional (bottom, top) y-axis limits; None keeps autoscaling.
        print_first_keys: if True, print the system key that introduces each
            legend entry.
    """
    plt.rcParams["figure.figsize"] = (17,8)
    category_maps = [
        systems_where_no_significant_improvement_shown,
        systems_where_signif_improv_shown_withOptThresh,
        systems_with_no_results,
    ]
    labels_plotted = []
    for key in systems:
        split_key = key.split("-")
        nc, ne, nk = int(split_key[0]), int(split_key[1]), int(split_key[2])
        curr_complexity = C((nc, ne, nk))
        print("nc ", nc, ", ne ", ne, ", nk ", nk, ": C == ", curr_complexity)

        for i, dataset_to_score_map in enumerate(category_maps):
            if key not in dataset_to_score_map:
                continue
            scatter_kwargs = dict(color=system_colors[i], marker=system_markers[i], s=100, alpha=1)
            # Attach the legend label only the first time a category is drawn.
            if i not in labels_plotted:
                if print_first_keys:
                    print("\t key == ", key)
                scatter_kwargs["label"] = system_labels[i]
                labels_plotted.append(i)
            plt.scatter([key], [curr_complexity], **scatter_kwargs)

    # Horizontal reference lines at complexity 7 and 13 across all system positions.
    plt.plot([7.] * len(systems), '-k')
    plt.plot([13.] * len(systems), '-k')
    plt.xticks(rotation=70)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.ylabel("System Complexity ("+r'$\mathfrak{C}$'+")")
    plt.xlabel("Synthetic System Name ("+r'$n_c$'+"-"+r'$n_e$'+"-"+r'$n_k$'+")")
    plt.title(title)
    plt.legend()
    plt.show()


# NOTE(review): both figures draw the same three category maps; only the title,
# y-limits and debug printing differ. Confirm the "Smoothed" figure is not meant
# to use separate results. The (6, 14) y-range also clips systems whose
# complexity falls outside it (e.g. 3-2-2 at 3, the 6-2-* systems at 15).
_plot_system_complexity(
    "Smoothed REDCLIFF-S Improvement Over Baselines (Gaussian Noise): Pair-wise Optimal F1-Scores",
    ylim=(6., 14.),
)
_plot_system_complexity(
    "REDCLIFF-S Improvement Over Baselines (Gaussian Noise): Pair-wise Optimal F1-Scores",
    print_first_keys=True,
)

---

# D4IC Experiment Analyses


## Ablation Summaries

In [None]:
import numpy as np

# CosSim (rho) ablation: key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag, f1_vals_across_factors
print("CosSim (rho) ablation:")
curr_results_by_alg = {
    'REDCLIFF_S_CMLP': [0.4, 0.28846153846153844, 0.27027027027027023, 0.30303030303030304, 0.31884057971014496, 0.30508474576271183, 0.3, 0.2553191489361702, 0.29411764705882354, 0.32, 0.3555555555555555, 0.297029702970297, 0.2666666666666667, 0.2777777777777778, 0.3137254901960785, 0.3414634146341463, 0.30927835051546393, 0.26829268292682923, 0.3050847457627119, 0.3018867924528302, 0.3157894736842105, 0.30303030303030304, 0.2531645569620253, 0.30985915492957744, 0.3287671232876712],
    'CMLP': [0.36363636363636365, 0.30985915492957744, 0.2962962962962963, 0.37735849056603776, 0.35, 0.3846153846153846, 0.2857142857142857, 0.23529411764705882, 0.27777777777777773, 0.37037037037037035, 0.4444444444444445, 0.33898305084745756, 0.24999999999999997, 0.4000000000000001, 0.3333333333333333, 0.4, 0.29032258064516125, 0.24489795918367346, 0.4375, 0.3137254901960785, 0.31578947368421056, 0.3488372093023256, 0.2758620689655173, 0.2857142857142857, 0.3018867924528302], 'CLSTM': [0.3448275862068966, 0.29166666666666663, 0.2553191489361702, 0.3225806451612903, 0.32941176470588235, 0.3384615384615385, 0.32558139534883723, 0.2553191489361702, 0.2758620689655173, 0.42857142857142855, 0.3, 0.37499999999999994, 0.26373626373626374, 0.3111111111111111, 0.3076923076923077, 0.34615384615384615, 0.3157894736842105, 0.3333333333333333, 0.33333333333333337, 0.37209302325581395, 0.34146341463414637, 0.339622641509434, 0.26666666666666666, 0.30769230769230776, 0.3529411764705882], 'DGCNN': [0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325], 'DCSFA': [0.3037974683544304, 0.32608695652173914, 0.2857142857142857, 0.28, 0.4166666666666667, 0.2912621359223301, 0.2912621359223301, 0.30303030303030304, 0.2857142857142857, 0.3018867924528302, 0.3157894736842105, 0.325, 0.2784810126582279, 0.32, 0.3018867924528302, 0.3137254901960784, 0.2857142857142857, 0.32432432432432434, 0.3181818181818182, 0.3902439024390244, 0.3466666666666666, 0.29411764705882354, 0.3703703703703704, 0.2857142857142857, 0.34285714285714286], 'DYNOTEARS_Vanilla': [0.27522935779816515, 0.36363636363636365, 
0.2439024390243903, 0.24074074074074076, 0.3684210526315789, 0.3, 0.3703703703703704, 0.3448275862068966, 0.24074074074074076, 0.2926829268292683, 0.2803738317757009, 0.42857142857142855, 0.2962962962962963, 0.2476190476190476, 0.45714285714285713, 0.28846153846153844, 0.38709677419354843, 0.3333333333333333, 0.24528301886792453, 0.36363636363636365, 0.2857142857142857, 0.38095238095238093, 0.22727272727272727, 0.24299065420560748, 0.326530612244898], 'NAVAR_CLSTM': [0.3076923076923077, 0.3023255813953489, 0.24000000000000002, 0.3055555555555555, 0.3106796116504854, 0.3, 0.3146067415730337, 0.25641025641025644, 0.3157894736842105, 0.3106796116504854, 0.39344262295081966, 0.29268292682926833, 0.25, 0.3636363636363637, 0.30476190476190473, 0.3055555555555555, 0.303030303030303, 0.26829268292682923, 0.32, 0.3137254901960785, 0.3, 0.28846153846153844, 0.2439024390243903, 0.29411764705882354, 0.30927835051546393], 'NAVAR_CMLP': [0.3384615384615385, 0.28846153846153844, 0.25925925925925924, 0.3428571428571428, 0.34090909090909094, 0.3661971830985915, 0.30303030303030304, 0.26373626373626374, 0.3076923076923077, 0.4307692307692308, 0.36363636363636365, 0.29411764705882354, 0.28571428571428575, 0.3333333333333333, 0.375, 0.3728813559322034, 0.325, 0.25, 0.2978723404255319, 0.4057971014492754, 0.3478260869565218, 0.3, 0.28571428571428575, 0.3225806451612903, 0.39999999999999997]}

# Mean and standard error of the mean of REDCLIFF-S's per-factor F1 scores.
# NOTE(review): np.std defaults to ddof=0 (population std); a sample-based SEM
# would use ddof=1 -- confirm which convention the reported numbers assume.
_redcliff_f1_scores = curr_results_by_alg['REDCLIFF_S_CMLP']
print(np.mean(_redcliff_f1_scores))
print(np.std(_redcliff_f1_scores)/np.sqrt(len(_redcliff_f1_scores)))


# 1-Factor Ablation: key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag, f1_vals_across_factors
print("\n 1-Factor Ablation:")
curr_results_by_alg = {
    'REDCLIFF_S_CMLP': [0.3333333333333333, 0.2912621359223301, 0.25, 0.26000000000000006, 0.4166666666666667, 0.32352941176470584, 0.3611111111111111, 0.2758620689655173, 0.3870967741935484, 0.3333333333333333, 0.3188405797101449, 0.325, 0.23529411764705882, 0.2702702702702703, 0.4210526315789473, 0.30612244897959184, 0.2921348314606742, 0.2424242424242424, 0.2758620689655173, 0.392156862745098, 0.325, 0.3294117647058824, 0.24175824175824176, 0.2894736842105263, 0.4],
    'CMLP': [0.36363636363636365, 0.30985915492957744, 0.2962962962962963, 0.37735849056603776, 0.35, 0.3846153846153846, 0.2857142857142857, 0.23529411764705882, 0.27777777777777773, 0.37037037037037035, 0.4444444444444445, 0.33898305084745756, 0.24999999999999997, 0.4000000000000001, 0.3333333333333333, 0.4, 0.29032258064516125, 0.24489795918367346, 0.4375, 0.3137254901960785, 0.31578947368421056, 0.3488372093023256, 0.2758620689655173, 0.2857142857142857, 0.3018867924528302], 'CLSTM': [0.3448275862068966, 0.29166666666666663, 0.2553191489361702, 0.3225806451612903, 0.32941176470588235, 0.3384615384615385, 0.32558139534883723, 0.2553191489361702, 0.2758620689655173, 0.42857142857142855, 0.3, 0.37499999999999994, 0.26373626373626374, 0.3111111111111111, 0.3076923076923077, 0.34615384615384615, 0.3157894736842105, 0.3333333333333333, 0.33333333333333337, 0.37209302325581395, 0.34146341463414637, 0.339622641509434, 0.26666666666666666, 0.30769230769230776, 0.3529411764705882], 'DGCNN': [0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325], 'DCSFA': [0.3037974683544304, 0.32608695652173914, 0.2857142857142857, 0.28, 0.4166666666666667, 0.2912621359223301, 0.2912621359223301, 0.30303030303030304, 0.2857142857142857, 0.3018867924528302, 0.3157894736842105, 0.325, 0.2784810126582279, 0.32, 0.3018867924528302, 0.3137254901960784, 0.2857142857142857, 0.32432432432432434, 0.3181818181818182, 0.3902439024390244, 0.3466666666666666, 0.29411764705882354, 0.3703703703703704, 0.2857142857142857, 0.34285714285714286], 'DYNOTEARS_Vanilla': [0.27522935779816515, 0.36363636363636365, 
0.2439024390243903, 0.24074074074074076, 0.3684210526315789, 0.3, 0.3703703703703704, 0.3448275862068966, 0.24074074074074076, 0.2926829268292683, 0.2803738317757009, 0.42857142857142855, 0.2962962962962963, 0.2476190476190476, 0.45714285714285713, 0.28846153846153844, 0.38709677419354843, 0.3333333333333333, 0.24528301886792453, 0.36363636363636365, 0.2857142857142857, 0.38095238095238093, 0.22727272727272727, 0.24299065420560748, 0.326530612244898], 'NAVAR_CLSTM': [0.3076923076923077, 0.3023255813953489, 0.24000000000000002, 0.3055555555555555, 0.3106796116504854, 0.3, 0.3146067415730337, 0.25641025641025644, 0.3157894736842105, 0.3106796116504854, 0.39344262295081966, 0.29268292682926833, 0.25, 0.3636363636363637, 0.30476190476190473, 0.3055555555555555, 0.303030303030303, 0.26829268292682923, 0.32, 0.3137254901960785, 0.3, 0.28846153846153844, 0.2439024390243903, 0.29411764705882354, 0.30927835051546393], 'NAVAR_CMLP': [0.3384615384615385, 0.28846153846153844, 0.25925925925925924, 0.3428571428571428, 0.34090909090909094, 0.3661971830985915, 0.30303030303030304, 0.26373626373626374, 0.3076923076923077, 0.4307692307692308, 0.36363636363636365, 0.29411764705882354, 0.28571428571428575, 0.3333333333333333, 0.375, 0.3728813559322034, 0.325, 0.25, 0.2978723404255319, 0.4057971014492754, 0.3478260869565218, 0.3, 0.28571428571428575, 0.3225806451612903, 0.39999999999999997]}

# Mean and standard error of the mean of REDCLIFF-S's per-factor F1 scores.
# NOTE(review): np.std defaults to ddof=0 (population std); a sample-based SEM
# would use ddof=1 -- confirm which convention the reported numbers assume.
_redcliff_f1_scores = curr_results_by_alg['REDCLIFF_S_CMLP']
print(np.mean(_redcliff_f1_scores))
print(np.std(_redcliff_f1_scores)/np.sqrt(len(_redcliff_f1_scores)))

# Fixed-Factor (alpha) ablation: key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag, f1_vals_across_factors
print("\n Fixed-Factor (alpha) ablation:")
curr_results_by_alg = {
    'REDCLIFF_S_CMLP': [0.31578947368421056, 0.4067796610169491, 0.2857142857142857, 0.3333333333333333, 0.34210526315789475, 0.34285714285714286, 0.4, 0.25974025974025977, 0.3589743589743589, 0.35714285714285715, 0.30769230769230765, 0.4067796610169491, 0.29411764705882354, 0.35000000000000003, 0.31884057971014496, 0.3103448275862069, 0.4, 0.2608695652173913, 0.3888888888888889, 0.3404255319148936, 0.3272727272727273, 0.3793103448275862, 0.27586206896551724, 0.31250000000000006, 0.3846153846153846], 'CMLP': [0.36363636363636365, 0.30985915492957744, 0.2962962962962963, 0.37735849056603776, 0.35, 0.3846153846153846, 0.2857142857142857, 0.23529411764705882, 0.27777777777777773, 0.37037037037037035, 0.4444444444444445, 0.33898305084745756, 0.24999999999999997, 0.4000000000000001, 0.3333333333333333, 0.4, 0.29032258064516125, 0.24489795918367346, 0.4375, 0.3137254901960785, 0.31578947368421056, 0.3488372093023256, 0.2758620689655173, 0.2857142857142857, 0.3018867924528302], 'CLSTM': [0.3448275862068966, 0.29166666666666663, 0.2553191489361702, 0.3225806451612903, 0.32941176470588235, 0.3384615384615385, 0.32558139534883723, 0.2553191489361702, 0.2758620689655173, 0.42857142857142855, 0.3, 0.37499999999999994, 0.26373626373626374, 0.3111111111111111, 0.3076923076923077, 0.34615384615384615, 0.3157894736842105, 0.3333333333333333, 0.33333333333333337, 0.37209302325581395, 0.34146341463414637, 0.339622641509434, 0.26666666666666666, 0.30769230769230776, 0.3529411764705882], 'DGCNN': [0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325], 'DCSFA': [0.3037974683544304, 0.32608695652173914, 
0.2857142857142857, 0.28, 0.4166666666666667, 0.2912621359223301, 0.2912621359223301, 0.30303030303030304, 0.2857142857142857, 0.3018867924528302, 0.3157894736842105, 0.325, 0.2784810126582279, 0.32, 0.3018867924528302, 0.3137254901960784, 0.2857142857142857, 0.32432432432432434, 0.3181818181818182, 0.3902439024390244, 0.3466666666666666, 0.29411764705882354, 0.3703703703703704, 0.2857142857142857, 0.34285714285714286], 'DYNOTEARS_Vanilla': [0.27522935779816515, 0.36363636363636365, 0.2439024390243903, 0.24074074074074076, 0.3684210526315789, 0.3, 0.3703703703703704, 0.3448275862068966, 0.24074074074074076, 0.2926829268292683, 0.2803738317757009, 0.42857142857142855, 0.2962962962962963, 0.2476190476190476, 0.45714285714285713, 0.28846153846153844, 0.38709677419354843, 0.3333333333333333, 0.24528301886792453, 0.36363636363636365, 0.2857142857142857, 0.38095238095238093, 0.22727272727272727, 0.24299065420560748, 0.326530612244898], 'NAVAR_CLSTM': [0.3076923076923077, 0.3023255813953489, 0.24000000000000002, 0.3055555555555555, 0.3106796116504854, 0.3, 0.3146067415730337, 0.25641025641025644, 0.3157894736842105, 0.3106796116504854, 0.39344262295081966, 0.29268292682926833, 0.25, 0.3636363636363637, 0.30476190476190473, 0.3055555555555555, 0.303030303030303, 0.26829268292682923, 0.32, 0.3137254901960785, 0.3, 0.28846153846153844, 0.2439024390243903, 0.29411764705882354, 0.30927835051546393], 'NAVAR_CMLP': [0.3384615384615385, 0.28846153846153844, 0.25925925925925924, 0.3428571428571428, 0.34090909090909094, 0.3661971830985915, 0.30303030303030304, 0.26373626373626374, 0.3076923076923077, 0.4307692307692308, 0.36363636363636365, 0.29411764705882354, 0.28571428571428575, 0.3333333333333333, 0.375, 0.3728813559322034, 0.325, 0.25, 0.2978723404255319, 0.4057971014492754, 0.3478260869565218, 0.3, 0.28571428571428575, 0.3225806451612903, 0.39999999999999997]}

# Mean and standard error of the mean of REDCLIFF-S's per-factor F1 scores.
# NOTE(review): np.std defaults to ddof=0 (population std); a sample-based SEM
# would use ddof=1 -- confirm which convention the reported numbers assume.
_redcliff_f1_scores = curr_results_by_alg['REDCLIFF_S_CMLP']
print(np.mean(_redcliff_f1_scores))
print(np.std(_redcliff_f1_scores)/np.sqrt(len(_redcliff_f1_scores)))

# Unsupervised (lambda) ablation: key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag, f1_vals_across_factors
print("\n Unsupervised (lambda) ablation:")
curr_results_by_alg = {
    'REDCLIFF_S_CMLP': [0.32608695652173914, 0.3829787234042553, 0.2962962962962963, 0.31999999999999995, 0.3376623376623376, 0.32183908045977017, 0.3661971830985915, 0.3225806451612903, 0.326530612244898, 0.33333333333333337, 0.32967032967032966, 0.37037037037037035, 0.30303030303030304, 0.30434782608695654, 0.37681159420289856, 0.3181818181818181, 0.4, 0.275, 0.31999999999999995, 0.3376623376623376, 0.31111111111111106, 0.3793103448275862, 0.30303030303030304, 0.3333333333333333, 0.31683168316831684], 'CMLP': [0.36363636363636365, 0.30985915492957744, 0.2962962962962963, 0.37735849056603776, 0.35, 0.3846153846153846, 0.2857142857142857, 0.23529411764705882, 0.27777777777777773, 0.37037037037037035, 0.4444444444444445, 0.33898305084745756, 0.24999999999999997, 0.4000000000000001, 0.3333333333333333, 0.4, 0.29032258064516125, 0.24489795918367346, 0.4375, 0.3137254901960785, 0.31578947368421056, 0.3488372093023256, 0.2758620689655173, 0.2857142857142857, 0.3018867924528302], 'CLSTM': [0.3448275862068966, 0.29166666666666663, 0.2553191489361702, 0.3225806451612903, 0.32941176470588235, 0.3384615384615385, 0.32558139534883723, 0.2553191489361702, 0.2758620689655173, 0.42857142857142855, 0.3, 0.37499999999999994, 0.26373626373626374, 0.3111111111111111, 0.3076923076923077, 0.34615384615384615, 0.3157894736842105, 0.3333333333333333, 0.33333333333333337, 0.37209302325581395, 0.34146341463414637, 0.339622641509434, 0.26666666666666666, 0.30769230769230776, 0.3529411764705882], 'DGCNN': [0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325, 0.3636363636363636, 0.2857142857142857, 0.2608695652173913, 0.34285714285714286, 0.325], 'DCSFA': [0.3037974683544304, 0.32608695652173914, 
0.2857142857142857, 0.28, 0.4166666666666667, 0.2912621359223301, 0.2912621359223301, 0.30303030303030304, 0.2857142857142857, 0.3018867924528302, 0.3157894736842105, 0.325, 0.2784810126582279, 0.32, 0.3018867924528302, 0.3137254901960784, 0.2857142857142857, 0.32432432432432434, 0.3181818181818182, 0.3902439024390244, 0.3466666666666666, 0.29411764705882354, 0.3703703703703704, 0.2857142857142857, 0.34285714285714286], 'DYNOTEARS_Vanilla': [0.27522935779816515, 0.36363636363636365, 0.2439024390243903, 0.24074074074074076, 0.3684210526315789, 0.3, 0.3703703703703704, 0.3448275862068966, 0.24074074074074076, 0.2926829268292683, 0.2803738317757009, 0.42857142857142855, 0.2962962962962963, 0.2476190476190476, 0.45714285714285713, 0.28846153846153844, 0.38709677419354843, 0.3333333333333333, 0.24528301886792453, 0.36363636363636365, 0.2857142857142857, 0.38095238095238093, 0.22727272727272727, 0.24299065420560748, 0.326530612244898], 'NAVAR_CLSTM': [0.3076923076923077, 0.3023255813953489, 0.24000000000000002, 0.3055555555555555, 0.3106796116504854, 0.3, 0.3146067415730337, 0.25641025641025644, 0.3157894736842105, 0.3106796116504854, 0.39344262295081966, 0.29268292682926833, 0.25, 0.3636363636363637, 0.30476190476190473, 0.3055555555555555, 0.303030303030303, 0.26829268292682923, 0.32, 0.3137254901960785, 0.3, 0.28846153846153844, 0.2439024390243903, 0.29411764705882354, 0.30927835051546393], 'NAVAR_CMLP': [0.3384615384615385, 0.28846153846153844, 0.25925925925925924, 0.3428571428571428, 0.34090909090909094, 0.3661971830985915, 0.30303030303030304, 0.26373626373626374, 0.3076923076923077, 0.4307692307692308, 0.36363636363636365, 0.29411764705882354, 0.28571428571428575, 0.3333333333333333, 0.375, 0.3728813559322034, 0.325, 0.25, 0.2978723404255319, 0.4057971014492754, 0.3478260869565218, 0.3, 0.28571428571428575, 0.3225806451612903, 0.39999999999999997]}

# Mean and standard error of the mean of REDCLIFF-S's per-factor F1 scores.
# NOTE(review): np.std defaults to ddof=0 (population std); a sample-based SEM
# would use ddof=1 -- confirm which convention the reported numbers assume.
_redcliff_f1_scores = curr_results_by_alg['REDCLIFF_S_CMLP']
print(np.mean(_redcliff_f1_scores))
print(np.std(_redcliff_f1_scores)/np.sqrt(len(_redcliff_f1_scores)))

## Summary Analyses V03312025

NON-TRANSPOSED PREDICTION RESULTS

In [None]:
import numpy as np
import pickle as pkl

# Prefix of fold-0 result keys; the remainder of each key is used as the algorithm name.
FOLD0_ALG_KEY_PREFIX = "D4ICMSNRFold0_model_name_"
# Stats whose per-factor value may be a (normalized, raw_count) pair instead of a scalar.
PAIRED_STAT_MARKERS = ("ancestor_aid", "oset_aid", "parent_aid", "_shd")


def _load_pickle(path):
    """Load and return the object pickled at *path*, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code -- only load locally produced result files.
    """
    with open(path, "rb") as infile:
        return pkl.load(infile)


def _is_paired_stat(stat_key):
    """Return True for stats stored as {"normalized": [...], "raw_count": [...]}."""
    return any(marker in stat_key for marker in PAIRED_STAT_MARKERS)


def _split_paired_stat(stat_val):
    """Split a paired stat value into (normalized, raw_count).

    Values that cannot be indexed as a pair (e.g. a bare float, or a length-1
    sequence) are used unchanged for both parts.
    """
    try:
        return stat_val[0], stat_val[1]
    except (TypeError, IndexError, KeyError):
        return stat_val, stat_val


def _summarize(fold_means, pooled_vals):
    """Build the summary record for one stat from its per-fold means and pooled values.

    SEMs use np.std's default ddof=0, matching the rest of this notebook.
    """
    return {
        "fold_means": fold_means,
        "stat_vals_across_folds_and_factors": pooled_vals,
        "combo_stats_mean": np.mean(pooled_vals),
        "combo_stats_sem": np.std(pooled_vals) / np.sqrt(len(pooled_vals)),
        "mean_of_fold_means": np.mean(fold_means),
        "sem_of_fold_means": np.std(fold_means) / np.sqrt(len(fold_means)),
    }


og_stats_fold0 = _load_pickle("stats_by_alg_key_dict_fold0.pkl")
og_stats_fold1 = _load_pickle("stats_by_alg_key_dict_fold1.pkl")
og_stats_fold2 = _load_pickle("stats_by_alg_key_dict_fold2.pkl")
og_stats_fold3 = _load_pickle("stats_by_alg_key_dict_fold3.pkl")
# Fold 4's stats file is missing (job killed 03/31/2025), so only folds 0-3 are aggregated.
og_stats_by_fold = [og_stats_fold0, og_stats_fold1, og_stats_fold2, og_stats_fold3]

print("og_stats_fold0.keys() == ", og_stats_fold0.keys())
performance_across_folds_by_alg = {
    key[len(FOLD0_ALG_KEY_PREFIX):]: {f"fold_{i}": dict() for i in range(len(og_stats_by_fold))}
    for key in og_stats_fold0.keys()
    if key.startswith(FOLD0_ALG_KEY_PREFIX) and "Fold4" not in key
}
print("performance_across_folds_by_alg.keys() == ", performance_across_folds_by_alg.keys())

# Collect every per-factor stat value, grouped as algorithm -> fold -> stat.
for f_ind, og_stats in enumerate(og_stats_by_fold):
    print("f_ind == ", f_ind)
    fold_key = f"fold_{f_ind}"
    for og_alg_key, stats_by_factor in og_stats.items():
        if "Fold4" in og_alg_key:
            continue
        for alg_key, perf_by_fold in performance_across_folds_by_alg.items():
            # Substring match: an algorithm claims every result key containing its name.
            if alg_key not in og_alg_key:
                continue
            fold_perf = perf_by_fold[fold_key]
            for factor_key in sorted(stats_by_factor.keys()):
                for stat_key, stat_val in stats_by_factor[factor_key].items():
                    if _is_paired_stat(stat_key):
                        normalized, raw_count = _split_paired_stat(stat_val)
                        paired = fold_perf.setdefault(stat_key, {"normalized": [], "raw_count": []})
                        paired["normalized"].append(normalized)
                        paired["raw_count"].append(raw_count)
                    else:
                        fold_perf.setdefault(stat_key, []).append(stat_val)

print("performance_across_folds_by_alg == ", performance_across_folds_by_alg)

# Summarize each stat both per fold (mean of fold means) and pooled over all folds/factors.
summary_stats_by_alg = {}
for alg_key, perf_by_fold in performance_across_folds_by_alg.items():
    # stat_key -> (fold_means, pooled_vals); paired stats map to {substat_key: (fold_means, pooled_vals)}.
    collected = {}
    for fold_perf in perf_by_fold.values():
        for stat_key, stat_vals in fold_perf.items():
            if _is_paired_stat(stat_key):
                per_substat = collected.setdefault(stat_key, {})
                for substat_key, substat_vals in stat_vals.items():
                    fold_means, pooled_vals = per_substat.setdefault(substat_key, ([], []))
                    fold_means.append(np.mean(substat_vals))
                    pooled_vals.extend(substat_vals)
            else:
                fold_means, pooled_vals = collected.setdefault(stat_key, ([], []))
                fold_means.append(np.mean(stat_vals))
                pooled_vals.extend(stat_vals)
    summary_stats_by_alg[alg_key] = {}
    for stat_key, accumulated in collected.items():
        if _is_paired_stat(stat_key):
            summary_stats_by_alg[alg_key][stat_key] = {
                substat_key: _summarize(*substat_acc) for substat_key, substat_acc in accumulated.items()
            }
        else:
            summary_stats_by_alg[alg_key][stat_key] = _summarize(*accumulated)

print("summary_stats_by_alg == ", summary_stats_by_alg)
for alg, stats in summary_stats_by_alg.items():
    print("alg == ", alg)
    for stat, stat_summary in stats.items():
        print("\t stat == ", stat)
        for summary_key, summary_val in stat_summary.items():
            # Use the same paired-stat test that built the data so nesting always matches.
            if _is_paired_stat(stat):
                print("\t\t sub_stat == ", summary_key)
                for sub_stat_summary, sub_val in summary_val.items():
                    print("\t\t\t summary_key == ", sub_stat_summary, " == ", sub_val)
            else:
                print("\t\t summary_key == ", summary_key, " == ", summary_val)

TRANSPOSED PREDICTION RESULTS

## R-PCMCI D4IC HSNR Experiments

Resources:
 - tutorial: https://github.com/jakobrunge/tigramite/blob/master/tutorials/causal_discovery/tigramite_tutorial_regime_pcmci.ipynb
 - "Reconstructing regime-dependent causal relationships from observational time series" by Elena Saggioro; Jana de Wiljes; Marlene Kretschmer; Jakob Runge (https://pubs.aip.org/aip/cha/article/30/11/113115/595926/Reconstructing-regime-dependent-causal)

In [None]:
%pip install tigramite
%pip install ortools
%pip install dcor

In [None]:
# Imports
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
import sklearn
from sklearn.linear_model import LinearRegression


import tigramite
from tigramite import data_processing as pp
from tigramite.toymodels import structural_causal_processes as toys
from tigramite import plotting as tp
from tigramite.pcmci import PCMCI
from tigramite.rpcmci import RPCMCI


from tigramite.independence_tests.parcorr import ParCorr
from tigramite.independence_tests.gpdc import GPDC
from tigramite.independence_tests.cmiknn import CMIknn
from tigramite.independence_tests.cmisymb import CMIsymb

In [None]:
import pickle as pkl
from sklearn.metrics import confusion_matrix, f1_score, precision_recall_curve

# Ground truth for the D4IC HSNR regimes: (row, col) positions of the 1s in each
# 10x10 adjacency matrix. Every supported repeat shares the same ground truth.
_D4IC_HSNR_NUM_NODES = 10
_D4IC_HSNR_VALID_REPEAT_IDS = (0, 1, 2, 3, 4)
_D4IC_HSNR_TRUE_EDGES_BY_REGIME = {
    0: ((0, 6), (1, 0), (1, 2), (2, 5), (3, 1), (3, 2), (3, 6), (3, 7),
        (4, 2), (5, 1), (6, 4), (6, 5), (6, 8), (6, 9), (7, 6)),
    1: ((1, 0), (1, 5), (1, 7), (2, 0), (2, 3), (2, 6), (2, 9), (3, 0),
        (3, 2), (3, 6), (3, 9), (4, 0), (5, 7), (6, 2), (9, 8)),
    2: ((0, 8), (1, 0), (1, 3), (1, 6), (2, 9), (3, 2), (4, 6), (5, 4),
        (6, 7), (7, 5), (8, 5), (8, 7)),
    3: ((0, 3), (0, 4), (0, 6), (1, 0), (2, 1), (2, 5), (2, 7), (3, 2),
        (5, 7), (6, 5), (6, 7), (8, 1), (8, 9)),
    4: ((1, 0), (1, 4), (2, 3), (2, 4), (2, 9), (3, 1), (3, 2), (4, 1),
        (4, 2), (4, 3), (4, 8), (5, 1), (5, 6), (6, 1), (6, 5), (7, 1)),
}


def get_d4ic_HSNR_repeat_true_standardized_graphs(repeat_id):
    """Return the ground-truth adjacency matrix of every regime for a D4IC HSNR repeat.

    Parameters
    ----------
    repeat_id : int
        Repeat index; must compare equal to one of 0, 1, 2, 3, 4.

    Returns
    -------
    dict
        Maps regime index (0-4) to a freshly allocated (10, 10) float64 array of
        0./1. entries. The row/column orientation is presumably "column drives row",
        matching get_standardized_off_diagonal_relation_predictions -- TODO confirm.

    Raises
    ------
    ValueError
        If ``repeat_id`` is not a supported repeat.
    """
    if repeat_id not in _D4IC_HSNR_VALID_REPEAT_IDS:
        raise ValueError(
            f"unsupported repeat_id {repeat_id!r}; expected one of {_D4IC_HSNR_VALID_REPEAT_IDS}"
        )
    true_graphs_by_regime = {}
    for regime, edges in _D4IC_HSNR_TRUE_EDGES_BY_REGIME.items():
        # Build a new array per call so callers can mutate results without side effects.
        adjacency = np.zeros((_D4IC_HSNR_NUM_NODES, _D4IC_HSNR_NUM_NODES))
        rows, cols = zip(*edges)
        adjacency[list(rows), list(cols)] = 1.0
        true_graphs_by_regime[regime] = adjacency
    return true_graphs_by_regime

def prepare_data_for_rpcmci_modeling(orig_data):
    """Stack windowed (x, y) samples into one long series plus per-regime masks.

    Parameters
    ----------
    orig_data : sized iterable of samples
        Each sample is indexable with ``samp[0] = x`` of shape (T_window_size, N) and
        ``samp[1] = y`` holding one score per regime; the window's dominant regime is
        ``argmax(y.squeeze())``. The label construction requires ``y.T`` to be 2-D
        with shape (1, num_regimes) -- presumably ``y`` is (num_regimes, 1); confirm
        against the dataset.

    Returns
    -------
    rpcmci_data : (T, N) array of all windows concatenated in order.
    rpcmci_labels : (T, num_regimes) array; each window's ``y.T`` repeated per time step.
    masks_by_regime_index : dict regime -> (T, N) float array, 1 on the rows of windows
        whose dominant regime is that regime and 0 elsewhere.
        NOTE(review): tigramite's DataFrame treats mask==1/True as *excluded*, so
        passing these masks directly excludes the regime's own samples -- verify intent.
    T_window_size, T, N, num_regimes : ints, with T = T_window_size * len(orig_data).

    Raises
    ------
    ValueError
        If ``orig_data`` is empty.
    """
    num_samps = len(orig_data)
    if num_samps == 0:
        raise ValueError("orig_data must contain at least one (x, y) sample")

    T_window_size = N = num_regimes = None
    data_windows = []
    label_windows = []
    dominant_regimes = []
    for samp in orig_data:
        x = samp[0]
        y = samp[1]
        if T_window_size is None:
            # The first window fixes the window length, channel count and regime count.
            T_window_size = x.shape[0]
            N = x.shape[1]
            assert x.shape == (T_window_size, N)
            num_regimes = y.shape[0]
        data_windows.append(x)
        label_windows.append(np.concatenate([y.T for _ in range(T_window_size)], axis=0))
        dominant_regimes.append(np.argmax(y.squeeze()))

    # Concatenate once at the end instead of re-copying the growing arrays every sample.
    rpcmci_data = np.concatenate(data_windows, axis=0)
    rpcmci_labels = np.concatenate(label_windows, axis=0)
    masks_by_regime_index = {r: np.zeros(rpcmci_data.shape) for r in range(num_regimes)}
    for window_ind, regime in enumerate(dominant_regimes):
        start = window_ind * T_window_size
        masks_by_regime_index[regime][start:start + T_window_size, :] += 1

    T = T_window_size * num_samps
    assert rpcmci_data.shape == (T, N)
    assert rpcmci_labels.shape == (T, num_regimes)
    assert np.max([np.max(masks_by_regime_index[r]) for r in range(num_regimes)]) == 1
    assert np.min([np.min(masks_by_regime_index[r]) for r in range(num_regimes)]) == 0
    return rpcmci_data, rpcmci_labels, masks_by_regime_index, T_window_size, T, N, num_regimes

def get_standardized_off_diagonal_relation_predictions(A_tensor, transpose=True):
    """Collapse a square lagged relation tensor into a matrix with a zeroed diagonal.

    Parameters
    ----------
    A_tensor : array of shape (N, N, L)
        Relation strengths; the third axis is summed over after taking absolute values.
    transpose : bool
        If True, transpose the collapsed (N, N) matrix. The standard convention used
        here is that columns drive rows.

    Returns
    -------
    array of shape (N, N)
        Summed absolute strengths with self-relations (the diagonal) set to zero.
    """
    assert len(A_tensor.shape) == 3
    assert A_tensor.shape[0] == A_tensor.shape[1]
    collapsed = np.abs(A_tensor).sum(axis=2)
    oriented = collapsed.T if transpose else collapsed
    num_nodes = oriented.shape[0]
    # Multiply by 1 - I so every diagonal entry becomes zero.
    return oriented * (1.0 - np.eye(num_nodes))

def compute_optimal_f1(labels, pred_logits):
    """Return the best F1 achievable by thresholding ``pred_logits`` against ``labels``.

    Every threshold reported by ``precision_recall_curve`` is evaluated. Points where
    precision + recall is zero or undefined (NaN) score an F1 of 0.

    See:
     - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html
     - https://stackoverflow.com/questions/70902917/how-to-calculate-precision-recall-and-f1-for-entity-prediction#:~:text=The%20F1%20score%20of%20a,a%20class%20in%20one%20metric.
     - https://stackoverflow.com/questions/57060907/compute-maximum-f1-score-using-precision-recall-curve

    Parameters
    ----------
    labels : array-like
        Binary ground-truth labels.
    pred_logits : array-like
        Prediction scores; higher means more likely positive.

    Returns
    -------
    opt_threshold : int
        INDEX into the ``thresholds`` array of ``precision_recall_curve`` at which F1
        is maximal. This is not the threshold value itself.
    opt_f1 : float
        The maximal F1 score.
    """
    precision, recall, _thresholds = precision_recall_curve(labels, pred_logits)
    # The final (precision=1, recall=0) point has no matching threshold, so drop it.
    precision = precision[:-1]
    recall = recall[:-1]
    numerator = 2.0 * precision * recall
    denominator = precision + recall
    # Masked division: 0/0 and NaN points get F1 = 0 without RuntimeWarnings.
    f1_scores_by_threshold = np.divide(
        numerator,
        denominator,
        out=np.zeros_like(numerator, dtype=float),
        where=denominator > 0,
    )
    opt_threshold = int(np.argmax(f1_scores_by_threshold))
    opt_f1 = f1_scores_by_threshold[opt_threshold]
    assert np.isfinite(opt_f1)
    return opt_threshold, opt_f1

def get_pcmci_edge_preds_from_graph(graph):
    """Binarize a PCMCI ``graph`` array into directed-edge indicators.

    Parameters
    ----------
    graph : np.ndarray of shape (N, N, L)
        Edge-mark strings as returned in ``run_pcmci(...)['graph']``. Presumably L is
        tau_max + 1 lag slots -- confirm against tigramite.

    Returns
    -------
    np.ndarray of shape (N, N, L), float64
        1.0 where the entry is exactly ``"-->"`` and 0.0 everywhere else.
    """
    assert len(graph.shape) == 3
    assert graph.shape[0] == graph.shape[1]
    # Vectorized elementwise comparison replaces the per-entry Python triple loop.
    return np.where(graph == "-->", 1.0, 0.0)


In [None]:
def run_d4ic_experiment(data_file_name, repeat_id, pred_source="graph", transpose=True):
    """Fit one PCMCI model per regime on a D4IC HSNR repeat and score it by optimal F1.

    Parameters
    ----------
    data_file_name : str
        Path to a pickled training subset accepted by ``prepare_data_for_rpcmci_modeling``.
    repeat_id : int
        D4IC HSNR repeat index; selects the ground-truth graphs and labels the printout.
    pred_source : {"graph", "val_matrix"}
        ``"graph"``: binary ``"-->"`` indicators from PCMCI's graph;
        ``"val_matrix"``: PCMCI's ``val_matrix`` values.
    transpose : bool
        Forwarded to ``get_standardized_off_diagonal_relation_predictions``.

    Returns
    -------
    (dict, float, float)
        Per-regime optimal F1 scores, their mean, and std/sqrt(n) across regimes.

    Raises
    ------
    ValueError
        If ``pred_source`` is not one of the supported values.
    """
    # Validate first so a bad pred_source fails before the slow data load and PCMCI fits.
    if pred_source not in ("graph", "val_matrix"):
        raise ValueError("pred_source must be 'graph' or 'val_matrix'; got " + repr(pred_source))

    # load data (with-block closes the file handle); pickle must come from a trusted source
    with open(data_file_name, "rb") as data_file:
        orig_train_data = pkl.load(data_file)
    train_data, _, masks_by_regime_index, _, _, _, num_regimes = prepare_data_for_rpcmci_modeling(orig_train_data)
    regimes = range(num_regimes)

    # NOTE(review): per tigramite's DataFrame docs, samples with mask value 1/True are
    # *excluded*. These masks are 1 on the windows where regime r dominates, so each
    # per-regime fit may be using the other regimes' samples -- confirm the intended polarity.
    dataframe_by_regime = {r: pp.DataFrame(train_data, mask=masks_by_regime_index[r]) for r in regimes}

    # Case where causal regimes are known
    pcmci_by_regime = {r: PCMCI(dataframe=dataframe_by_regime[r], cond_ind_test=ParCorr(mask_type='y')) for r in regimes}
    pcmci_results_by_regime = {r: pcmci_by_regime[r].run_pcmci(tau_min=1, tau_max=2, pc_alpha=0.2, alpha_level=0.01) for r in regimes}

    if pred_source == "graph":
        pcmci_edge_preds_by_regime = {r: get_pcmci_edge_preds_from_graph(pcmci_results_by_regime[r]['graph']) for r in regimes}
    else:  # pred_source == "val_matrix" (validated above)
        pcmci_edge_preds_by_regime = {r: pcmci_results_by_regime[r]['val_matrix'] for r in regimes}
    pcmci_standardizedRelationPreds_by_regime = {
        r: get_standardized_off_diagonal_relation_predictions(pcmci_edge_preds_by_regime[r], transpose=transpose)
        for r in regimes
    }

    true_graphs_by_regime = get_d4ic_HSNR_repeat_true_standardized_graphs(repeat_id)
    pcmci_optF1Scores_by_regime = dict()
    for r in regimes:
        _, opt_f1 = compute_optimal_f1(true_graphs_by_regime[r].flatten(), pcmci_standardizedRelationPreds_by_regime[r].flatten())
        pcmci_optF1Scores_by_regime[r] = opt_f1
    regime_f1s = list(pcmci_optF1Scores_by_regime.values())
    pcmci_optF1Score_cross_regime_mean = np.mean(regime_f1s)
    # Population std (ddof=0) / sqrt(n): the SEM convention used throughout this notebook.
    pcmci_optF1Score_cross_regime_sem = np.std(regime_f1s) / np.sqrt(len(regime_f1s))

    # Label the summary with the actual repeat (previously always printed "REPEAT0").
    repeat_label = "REPEAT" + str(repeat_id)
    print("pcmci_optF1Scores_by_regime == ", pcmci_optF1Scores_by_regime)
    print("mean_optF1Score for PCMCI on D4IC HSNR " + repeat_label + ": ", pcmci_optF1Score_cross_regime_mean)
    print("sem_optF1Score for PCMCI on D4IC HSNR " + repeat_label + ": ", pcmci_optF1Score_cross_regime_sem)
    return pcmci_optF1Scores_by_regime, pcmci_optF1Score_cross_regime_mean, pcmci_optF1Score_cross_regime_sem



In [None]:
# Experiment using results' graph key as edge prediction matrix

np.random.seed(0)

# Run each D4IC HSNR repeat (0-4) in order; every entry is (per-regime F1 dict, mean F1, SEM).
d4ic_repeat_results = [
    run_d4ic_experiment("d4ic_hsnr_repeat" + str(rep) + "_train_subset_0.pkl", rep, pred_source="graph", transpose=True)
    for rep in range(5)
]
(pcmci_optF1Scores_d4icHSNR_repeat0, pcmci_optF1Scores_d4icHSNR_repeat1, pcmci_optF1Scores_d4icHSNR_repeat2,
 pcmci_optF1Scores_d4icHSNR_repeat3, pcmci_optF1Scores_d4icHSNR_repeat4) = [res[0] for res in d4ic_repeat_results]
(rep0_f1ScoreMean, rep1_f1ScoreMean, rep2_f1ScoreMean,
 rep3_f1ScoreMean, rep4_f1ScoreMean) = [res[1] for res in d4ic_repeat_results]

print("\n\n FULL EXPERIMENT SUMMARY STATISTICS -------------------------")
# Statistics over the five per-repeat means
cross_exp_pcmci_optF1ScoreMeans = [res[1] for res in d4ic_repeat_results]
print("cross_exp_pcmci_optF1ScoreMeans == ", cross_exp_pcmci_optF1ScoreMeans)
print("np.mean(cross_exp_pcmci_optF1ScoreMeans) == ", np.mean(cross_exp_pcmci_optF1ScoreMeans))
print("sem(cross_exp_pcmci_optF1ScoreMeans) == ", np.std(cross_exp_pcmci_optF1ScoreMeans) / np.sqrt(len(cross_exp_pcmci_optF1ScoreMeans)))

# Statistics over every (repeat, regime) F1 score pooled together
cross_exp_pcmci_optF1Scores = [f1 for res in d4ic_repeat_results for f1 in res[0].values()]
print("cross_exp_pcmci_optF1Scores == ", cross_exp_pcmci_optF1Scores)
print("np.mean(cross_exp_pcmci_optF1Scores) == ", np.mean(cross_exp_pcmci_optF1Scores))
print("sem(cross_exp_pcmci_optF1Scores) == ", np.std(cross_exp_pcmci_optF1Scores)/np.sqrt(len(cross_exp_pcmci_optF1Scores)))

In [None]:
# Experiment using results' graph key as edge prediction matrix and NOT transposing it

np.random.seed(0)

# Run each D4IC HSNR repeat (0-4) in order; every entry is (per-regime F1 dict, mean F1, SEM).
d4ic_repeat_results = [
    run_d4ic_experiment("d4ic_hsnr_repeat" + str(rep) + "_train_subset_0.pkl", rep, pred_source="graph", transpose=False)
    for rep in range(5)
]
(pcmci_optF1Scores_d4icHSNR_repeat0, pcmci_optF1Scores_d4icHSNR_repeat1, pcmci_optF1Scores_d4icHSNR_repeat2,
 pcmci_optF1Scores_d4icHSNR_repeat3, pcmci_optF1Scores_d4icHSNR_repeat4) = [res[0] for res in d4ic_repeat_results]
(rep0_f1ScoreMean, rep1_f1ScoreMean, rep2_f1ScoreMean,
 rep3_f1ScoreMean, rep4_f1ScoreMean) = [res[1] for res in d4ic_repeat_results]

print("\n\n FULL EXPERIMENT SUMMARY STATISTICS -------------------------")
# Statistics over the five per-repeat means
cross_exp_pcmci_optF1ScoreMeans = [res[1] for res in d4ic_repeat_results]
print("cross_exp_pcmci_optF1ScoreMeans == ", cross_exp_pcmci_optF1ScoreMeans)
print("np.mean(cross_exp_pcmci_optF1ScoreMeans) == ", np.mean(cross_exp_pcmci_optF1ScoreMeans))
print("sem(cross_exp_pcmci_optF1ScoreMeans) == ", np.std(cross_exp_pcmci_optF1ScoreMeans) / np.sqrt(len(cross_exp_pcmci_optF1ScoreMeans)))

# Statistics over every (repeat, regime) F1 score pooled together
cross_exp_pcmci_optF1Scores = [f1 for res in d4ic_repeat_results for f1 in res[0].values()]
print("cross_exp_pcmci_optF1Scores == ", cross_exp_pcmci_optF1Scores)
print("np.mean(cross_exp_pcmci_optF1Scores) == ", np.mean(cross_exp_pcmci_optF1Scores))
print("sem(cross_exp_pcmci_optF1Scores) == ", np.std(cross_exp_pcmci_optF1Scores)/np.sqrt(len(cross_exp_pcmci_optF1Scores)))

In [None]:
# Experiment using results' val_matrix key as edge prediction matrix

np.random.seed(0)

# Run each D4IC HSNR repeat (0-4) in order; every entry is (per-regime F1 dict, mean F1, SEM).
d4ic_repeat_results = [
    run_d4ic_experiment("d4ic_hsnr_repeat" + str(rep) + "_train_subset_0.pkl", rep, pred_source="val_matrix", transpose=True)
    for rep in range(5)
]
(pcmci_optF1Scores_d4icHSNR_repeat0, pcmci_optF1Scores_d4icHSNR_repeat1, pcmci_optF1Scores_d4icHSNR_repeat2,
 pcmci_optF1Scores_d4icHSNR_repeat3, pcmci_optF1Scores_d4icHSNR_repeat4) = [res[0] for res in d4ic_repeat_results]
(rep0_f1ScoreMean, rep1_f1ScoreMean, rep2_f1ScoreMean,
 rep3_f1ScoreMean, rep4_f1ScoreMean) = [res[1] for res in d4ic_repeat_results]

print("\n\n FULL EXPERIMENT SUMMARY STATISTICS -------------------------")
# Statistics over the five per-repeat means
cross_exp_pcmci_optF1ScoreMeans = [res[1] for res in d4ic_repeat_results]
print("cross_exp_pcmci_optF1ScoreMeans == ", cross_exp_pcmci_optF1ScoreMeans)
print("np.mean(cross_exp_pcmci_optF1ScoreMeans) == ", np.mean(cross_exp_pcmci_optF1ScoreMeans))
print("sem(cross_exp_pcmci_optF1ScoreMeans) == ", np.std(cross_exp_pcmci_optF1ScoreMeans) / np.sqrt(len(cross_exp_pcmci_optF1ScoreMeans)))

# Statistics over every (repeat, regime) F1 score pooled together
cross_exp_pcmci_optF1Scores = [f1 for res in d4ic_repeat_results for f1 in res[0].values()]
print("cross_exp_pcmci_optF1Scores == ", cross_exp_pcmci_optF1Scores)
print("np.mean(cross_exp_pcmci_optF1Scores) == ", np.mean(cross_exp_pcmci_optF1Scores))
print("sem(cross_exp_pcmci_optF1Scores) == ", np.std(cross_exp_pcmci_optF1Scores)/np.sqrt(len(cross_exp_pcmci_optF1Scores)))

In [None]:
# Experiment using results' val_matrix key as edge prediction matrix and NOT transposing it

np.random.seed(0)

# Run each D4IC HSNR repeat (0-4) in order; every entry is (per-regime F1 dict, mean F1, SEM).
d4ic_repeat_results = [
    run_d4ic_experiment("d4ic_hsnr_repeat" + str(rep) + "_train_subset_0.pkl", rep, pred_source="val_matrix", transpose=False)
    for rep in range(5)
]
(pcmci_optF1Scores_d4icHSNR_repeat0, pcmci_optF1Scores_d4icHSNR_repeat1, pcmci_optF1Scores_d4icHSNR_repeat2,
 pcmci_optF1Scores_d4icHSNR_repeat3, pcmci_optF1Scores_d4icHSNR_repeat4) = [res[0] for res in d4ic_repeat_results]
(rep0_f1ScoreMean, rep1_f1ScoreMean, rep2_f1ScoreMean,
 rep3_f1ScoreMean, rep4_f1ScoreMean) = [res[1] for res in d4ic_repeat_results]

print("\n\n FULL EXPERIMENT SUMMARY STATISTICS -------------------------")
# Statistics over the five per-repeat means
cross_exp_pcmci_optF1ScoreMeans = [res[1] for res in d4ic_repeat_results]
print("cross_exp_pcmci_optF1ScoreMeans == ", cross_exp_pcmci_optF1ScoreMeans)
print("np.mean(cross_exp_pcmci_optF1ScoreMeans) == ", np.mean(cross_exp_pcmci_optF1ScoreMeans))
print("sem(cross_exp_pcmci_optF1ScoreMeans) == ", np.std(cross_exp_pcmci_optF1ScoreMeans) / np.sqrt(len(cross_exp_pcmci_optF1ScoreMeans)))

# Statistics over every (repeat, regime) F1 score pooled together
cross_exp_pcmci_optF1Scores = [f1 for res in d4ic_repeat_results for f1 in res[0].values()]
print("cross_exp_pcmci_optF1Scores == ", cross_exp_pcmci_optF1Scores)
print("np.mean(cross_exp_pcmci_optF1Scores) == ", np.mean(cross_exp_pcmci_optF1Scores))
print("sem(cross_exp_pcmci_optF1Scores) == ", np.std(cross_exp_pcmci_optF1Scores)/np.sqrt(len(cross_exp_pcmci_optF1Scores)))

### Sandbox

In [None]:
def prepare_data_for_rpcmci_modeling(orig_data):
    """Stack windowed (x, y) samples into one long series plus per-regime masks.

    Parameters
    ----------
    orig_data : sequence of samples
        Each sample ``samp`` provides ``x = samp[0]``, a (T_window_size, N) window,
        and ``y = samp[1]`` with one entry per regime. ``y.T`` is repeated once per
        time step to build the labels and ``argmax(y.squeeze())`` picks the
        window's dominant regime (presumably ``y`` is (num_regimes, 1) -- TODO confirm).

    Returns
    -------
    rpcmci_data : np.ndarray, shape (T, N)
        All windows concatenated in order.
    rpcmci_labels : np.ndarray, shape (T, num_regimes)
        Each window's ``y.T`` repeated for every time step of that window.
    masks_by_regime_index : dict[int, np.ndarray]
        For each regime, a (T, N) float array that is 1. on the rows of windows
        whose dominant regime it is and 0. elsewhere.
    T_window_size, T, N, num_regimes : int
        Window length, total length (T_window_size * len(orig_data)), number of
        variables, number of regimes.

    Raises
    ------
    ValueError
        If ``orig_data`` is empty or a window's shape differs from the first window's.
    """
    num_samps = len(orig_data)
    if num_samps == 0:
        raise ValueError("orig_data must contain at least one (x, y) sample")

    first_x = orig_data[0][0]
    first_y = orig_data[0][1]
    T_window_size = first_x.shape[0]
    N = first_x.shape[1]
    num_regimes = first_y.shape[0]
    T = T_window_size * num_samps

    # Collect pieces and concatenate once at the end; concatenating inside the loop
    # re-copies the growing arrays for every sample (quadratic in num_samps).
    data_chunks = []
    label_chunks = []
    masks_by_regime_index = {r: np.zeros((T, N)) for r in range(num_regimes)}
    for i, samp in enumerate(orig_data):
        x = samp[0]
        y = samp[1]
        # Ragged windows would silently misalign the per-window mask slices below.
        if x.shape != (T_window_size, N):
            raise ValueError(
                "sample " + str(i) + " has window shape " + str(x.shape)
                + "; expected " + str((T_window_size, N))
            )
        data_chunks.append(x)
        label_chunks.extend([y.T] * T_window_size)
        curr_dom_regime = np.argmax(y.squeeze())
        masks_by_regime_index[curr_dom_regime][i * T_window_size:(i + 1) * T_window_size, :] = 1.

    rpcmci_data = np.concatenate(data_chunks, axis=0)
    rpcmci_labels = np.concatenate(label_chunks, axis=0)

    assert rpcmci_data.shape == (T, N)
    assert rpcmci_labels.shape == (T, num_regimes)
    assert np.max([np.max(masks_by_regime_index[r]) for r in range(num_regimes)]) == 1
    assert np.min([np.min(masks_by_regime_index[r]) for r in range(num_regimes)]) == 0
    return rpcmci_data, rpcmci_labels, masks_by_regime_index, T_window_size, T, N, num_regimes


# load data
import pickle as pkl

# Use with-blocks so the pickle file handles are closed; pickle must come from a trusted source.
with open("d4ic_hsnr_repeat0_train_subset_0.pkl", "rb") as f:
    d4icHSNR_rep0_train_data = pkl.load(f)
with open("d4ic_hsnr_repeat0_val_subset_0.pkl", "rb") as f:
    d4icHSNR_rep0_val_data = pkl.load(f)

print("len(d4icHSNR_rep0_train_data) == ", len(d4icHSNR_rep0_train_data))
print(len(d4icHSNR_rep0_train_data[0]))
print(d4icHSNR_rep0_train_data[0][0].shape)
print(d4icHSNR_rep0_train_data[0][1].shape)

rpcmci_data, rpcmci_labels, masks_by_regime_index, T_window_size, T, N, num_regimes = prepare_data_for_rpcmci_modeling(d4icHSNR_rep0_train_data)
print("rpcmci_data.shape == ", rpcmci_data.shape)
print("rpcmci_labels.shape == ", rpcmci_labels.shape)
print("masks_by_regime_index == ", masks_by_regime_index)
print("T_window_size == ", T_window_size)
print("T == ", T)
print("N == ", N)
print("num_regimes == ", num_regimes)

var_names = ["c"+str(i) for i in range(N)]
rpcmci_datatime = np.arange(T)

# One masked tigramite DataFrame per regime; masked samples are drawn in grey.
dataframe_plotting_by_regime = dict()
for r in range(num_regimes):
    dataframe_plotting_by_regime[r] = pp.DataFrame(rpcmci_data, mask=masks_by_regime_index[r])
    tp.plot_timeseries(dataframe_plotting_by_regime[r], figsize=(N,N), grey_masked_samples='data')
    plt.xlabel("Regime "+str(r)+" Data (Grey)")
    plt.show()


In [None]:
def get_standardized_off_diagonal_relation_predictions(A_tensor, transpose=True):
    """Collapse a 3-D relation tensor into an off-diagonal (N, N) score matrix.

    This cell redefines the helper. Its signature is kept identical to the
    ``transpose``-aware version used by ``run_d4ic_experiment(..., transpose=...)``,
    so re-running the experiment cells after this one does not raise a TypeError.
    The default ``transpose=True`` reproduces this cell's previous behaviour.

    Parameters
    ----------
    A_tensor : np.ndarray, shape (N, N, L)
        Relation scores; axis 2 is summed in absolute value (presumably the lag
        axis of a tigramite result -- TODO confirm).
    transpose : bool
        If True, transpose so that columns drive rows (standard convention).

    Returns
    -------
    np.ndarray, shape (N, N)
        Summed absolute scores with the diagonal zeroed (by multiplication).
    """
    assert len(A_tensor.shape) == 3
    assert A_tensor.shape[0] == A_tensor.shape[1]
    standard_A = np.sum(np.abs(A_tensor), axis=2)
    if transpose:
        standard_A = standard_A.T # standard convention is that columns drive rows
    off_diag_mask = np.ones(standard_A.shape) - np.eye(standard_A.shape[0])
    off_diag_standard_A = standard_A*off_diag_mask
    return off_diag_standard_A

# Case where causal regimes are known: one PCMCI model per regime-masked dataframe
pcmci_by_regime = {
    r: PCMCI(dataframe=dataframe_plotting_by_regime[r], cond_ind_test=ParCorr(mask_type='y'))
    for r in range(num_regimes)
}
pcmci_results_by_regime = {
    r: pcmci_by_regime[r].run_pcmci(tau_min=1, tau_max=2, pc_alpha=0.2, alpha_level=0.01)
    for r in range(num_regimes)
}

# Draw each regime's estimated causal graph
for r, regime_results in pcmci_results_by_regime.items():
    tp.plot_graph(
        val_matrix=regime_results['val_matrix'],
        graph=regime_results['graph'],
        var_names=var_names,
        node_aspect=0.5,
        node_size=0.5,
    )
    plt.title("PCMCI Results for Regime "+str(r))
    plt.show()

# Collapse each val_matrix into an off-diagonal score matrix and show it as a heatmap
pcmci_standardizedRelationPreds_by_regime = {
    r: get_standardized_off_diagonal_relation_predictions(pcmci_results_by_regime[r]['val_matrix'])
    for r in range(num_regimes)
}
for r, relation_preds in pcmci_standardizedRelationPreds_by_regime.items():
    plt.imshow(relation_preds)
    plt.colorbar()
    plt.title("PCMCI Standardized Inter-Variable Relation Preds for Regime "+str(r))
    plt.show()

In [None]:
from sklearn.metrics import confusion_matrix, f1_score, precision_recall_curve

# evaluating standardized predictions

def get_d4ic_HSNR_rep0_true_standardized_graphs():
    """Return the five ground-truth D4IC HSNR repeat-0 adjacency matrices.

    Returns
    -------
    dict[int, np.ndarray]
        Regime index 0-4 (networks net1-net5, in order) mapped to a (10, 10) float
        0/1 matrix. Each literal below is written as (10, 10, 1) and squeezed.
    """
    raw_networks = (
        # net1
        [[[0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]], [[1.], [0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]], [[0.], [1.], [1.], [0.], [0.], [0.], [1.], [1.], [0.], [0.]], [[0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [1.], [1.], [0.], [0.], [1.], [1.]], [[0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]]],
        # net2
        [[[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[1.], [0.], [0.], [0.], [0.], [1.], [0.], [1.], [0.], [0.]], [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [1.]], [[1.], [0.], [1.], [0.], [0.], [0.], [1.], [0.], [0.], [1.]], [[1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.]], [[0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.]]],
        # net3
        [[[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.]], [[1.], [0.], [0.], [1.], [0.], [0.], [1.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.]], [[0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [1.], [0.], [1.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]]],
        # net4
        [[[0.], [0.], [0.], [1.], [1.], [0.], [1.], [0.], [0.], [0.]], [[1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [1.], [0.], [0.], [0.], [1.], [0.], [1.], [0.], [0.]], [[0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [1.], [0.], [1.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]]],
        # net5
        [[[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [1.], [1.], [0.], [0.], [0.], [0.], [1.]], [[0.], [1.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [1.], [1.], [1.], [0.], [0.], [0.], [0.], [1.], [0.]], [[0.], [1.], [0.], [0.], [0.], [0.], [1.], [0.], [0.], [0.]], [[0.], [1.], [0.], [0.], [0.], [1.], [0.], [0.], [0.], [0.]], [[0.], [1.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]]],
    )
    return {regime: np.array(raw).squeeze() for regime, raw in enumerate(raw_networks)}

def compute_optimal_f1(labels, pred_logits):
    """Return the best F1 score reachable by thresholding ``pred_logits``.

    Every decision threshold produced by ``precision_recall_curve`` is tried and
    the one with the highest F1 is kept.

    Parameters
    ----------
    labels : array-like
        Binary ground-truth labels.
    pred_logits : array-like
        Real-valued prediction scores (higher means "more likely positive").

    Returns
    -------
    opt_threshold : int
        INDEX into the ``thresholds`` array returned by ``precision_recall_curve``
        at which F1 is maximal (not the threshold value itself).
    opt_f1 : float
        The maximal F1 score; non-finite F1 values (e.g. 0/0 when precision and
        recall are both zero) are scored as 0.

    See:
     - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html
     - https://stackoverflow.com/questions/70902917/how-to-calculate-precision-recall-and-f1-for-entity-prediction#:~:text=The%20F1%20score%20of%20a,a%20class%20in%20one%20metric.
     - https://stackoverflow.com/questions/57060907/compute-maximum-f1-score-using-precision-recall-curve
    """
    # Local import so this helper does not depend on another cell having imported it.
    from sklearn.metrics import precision_recall_curve

    precision, recall, thresholds = precision_recall_curve(labels, pred_logits)
    # The last precision/recall pair has no corresponding threshold; drop it so the
    # arrays align with `thresholds` (see the sklearn docs linked above).
    precision = precision[:-1]
    recall = recall[:-1]
    # precision + recall == 0 yields 0/0; silence the warning and score such points as 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        f1_scores_by_threshold = (2.0 * precision * recall) / (precision + recall)
    f1_scores_by_threshold = np.where(np.isfinite(f1_scores_by_threshold), f1_scores_by_threshold, 0.)
    opt_threshold = np.argmax(f1_scores_by_threshold)
    opt_f1 = f1_scores_by_threshold[opt_threshold]
    assert np.isfinite(opt_f1)
    return opt_threshold, opt_f1


# Score each regime's standardized PCMCI predictions against the repeat-0 ground truth.
true_graphs_by_regime = get_d4ic_HSNR_rep0_true_standardized_graphs()
pcmci_optF1Scores_by_regime = {
    r: compute_optimal_f1(true_graphs_by_regime[r].flatten(), pcmci_standardizedRelationPreds_by_regime[r].flatten())[1]
    for r in range(num_regimes)
}
rep0_regime_f1s = list(pcmci_optF1Scores_by_regime.values())
print(pcmci_optF1Scores_by_regime)
print("mean_optF1Score for PCMCI on D4IC HSNR REPEAT0: ", np.mean(rep0_regime_f1s))
print("sem_optF1Score for PCMCI on D4IC HSNR REPEAT0: ", np.std(rep0_regime_f1s)/np.sqrt(len(rep0_regime_f1s)))

## Computing Complexity Score of D4IC HSNR Networks 01/27/2025

In [None]:
def c(x):
    """Complexity score of a network described by ``x = (nC, nE)``.

    Evaluates ``(nE / (nC**2 - nC))**(-1)``, i.e. the number of possible directed
    off-diagonal edges among nC channels divided by the nE true edges.
    """
    return ((x[1]) / (x[0]**2. - x[0]))**(-1)


# (nC, nE) = (number of channels, number of true edges) for each D4IC HSNR network;
# all five folds share the same five networks.
_HSNR_NET_SIZES = {
    1: (10, 15),
    2: (10, 15),
    3: (10, 12),
    4: (10, 13),
    5: (10, 16),
}

# Publish the hsnr_fold{F}_net{K}_{nE,nC,c} module-level names and print each score;
# the last network of every fold is followed by a blank line.
for _fold in range(5):
    for _net, (_n_channels, _n_edges) in _HSNR_NET_SIZES.items():
        _prefix = "hsnr_fold" + str(_fold) + "_net" + str(_net)
        _score = c((_n_channels, _n_edges))
        globals()[_prefix + "_nE"] = _n_edges
        globals()[_prefix + "_nC"] = _n_channels
        globals()[_prefix + "_c"] = _score
        _trailer = ("\n",) if _net == 5 else ()
        print(_prefix + "_c == ", _score, *_trailer)

## Visualizing LSNR D4IC Results 01/27/2025

In [None]:
# D4IC LSNR Summary (01/27/2025):

import numpy as np
import pickle as pkl
from matplotlib import pyplot as plt
from matplotlib import gridspec

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

# (rc group, settings) pairs: default text, axes title, axes labels, tick labels,
# legend and figure title font sizes, applied in this order.
for rc_group, rc_settings in (
    ('font', {'size': FONT_SMALL_SIZE}),
    ('axes', {'titlesize': FONT_BIGGER_SIZE}),
    ('axes', {'labelsize': FONT_MEDIUM_SIZE}),
    ('xtick', {'labelsize': FONT_SMALL_SIZE}),
    ('ytick', {'labelsize': FONT_SMALL_SIZE}),
    ('legend', {'fontsize': FONT_SMALL_SIZE}),
    ('figure', {'titlesize': FONT_BIGGER_SIZE}),
):
    plt.rc(rc_group, **rc_settings)


def get_alg_name_alias_for_plot_axes(orig_name):
    """Map a display algorithm name to the y-axis label used in the bar plots.

    REDCLIFF-S names are split over two lines; 'cMLP' and 'cLSTM' are padded with
    trailing spaces. Any name without a special label is returned unchanged.

    The padding previously keyed only on 'CLSTM', but the names reaching this
    function come from ``get_alg_name_alias``, which already maps 'CLSTM' to
    'cLSTM', so that padding was never applied. Both spellings are handled now.
    """
    diamond = " " + "\U000025C7"
    redcliff_label = 'REDCLIFF-S\n(cMLP)'
    axis_labels = {
        'REDCLIFF-S (cMLP)': redcliff_label,
        'REDCLIFF-S (cMLP)-vCosSim': redcliff_label,
        'REDCLIFF-S (cMLP)-vSC': redcliff_label + " " + "*",
        'REDCLIFF-S (cMLP)' + diamond: redcliff_label + diamond,
        'cMLP': 'cMLP  ',
        'CLSTM': 'cLSTM ',
        'cLSTM': 'cLSTM ',
        'DCSFA-NMF': 'dCSFA-NMF',
    }
    return axis_labels.get(orig_name, orig_name)

def get_alg_name_alias(orig_name):
    """Translate an internal algorithm key into its human-readable display name.

    Keys without a registered alias are returned unchanged.
    """
    # (internal key, display name) pairs, checked in order with ==.
    key_to_display_name = (
        ('REDCLIFF_S_CMLP', 'REDCLIFF-S (cMLP)'),
        ('REDCLIFF_S_CMLP_WithSmoothing', 'REDCLIFF-S (cMLP)'),
        ('CMLP', 'cMLP'),
        ('CLSTM', 'cLSTM'),
        ('DCSFA', 'dCSFA-NMF'),
        ('DYNOTEARS_Vanilla', 'DYNOTEARS'),
        ('NAVAR_CMLP', 'NAVAR-P'),
        ('NAVAR_CLSTM', 'NAVAR-R'),
    )
    for internal_key, display_name in key_to_display_name:
        if orig_name == internal_key:
            return display_name
    return orig_name


mean_colors = ["darkorange", "darkred", 'darkred', "mediumvioletred", "darkslateblue", 'darkslategrey', "grey", "black"]
sem_colors = ["orangered", "tomato", 'tomato', "lightcoral", "slategrey", "lightblue", "lightgrey", "darkgrey"]
alg_names = []
alg_performance_means = []
alg_performance_sems = []


def _append_f1_summary_stats(alg_eval_stats):
    """Append one algorithm's cross-factor F1 mean and F1 std. err. to the global result lists."""
    for stat_key, stat_value in alg_eval_stats.items():
        if 'f1' in stat_key and "mean_across_factors" in stat_key:
            alg_performance_means.append(stat_value)
        elif 'f1' in stat_key and "mean_std_err_across_factors" in stat_key:
            alg_performance_sems.append(stat_value)
    # Every algorithm must contribute exactly one mean and one SEM, or the lists misalign.
    assert len(alg_performance_means) == len(alg_performance_sems) == len(alg_names), \
        "expected exactly one F1 mean and one F1 std. err. per algorithm"


# read in bCgs1v1223_REDCvNEWcMLP LSNR results
with open("bCgs1v1223_REDCvNEWcMLP_full_comparrisson_summary.pkl", "rb") as f:
    bsOH_lsnr_results = pkl.load(f)

print(bsOH_lsnr_results.keys())
bsOH_lsnr_stats = bsOH_lsnr_results['dream4_insilicoCombo_size10_LSNR']['key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag']
for alg_key, alg_eval_stats in bsOH_lsnr_stats.items():
    if alg_key not in ['REDCLIFF_S_CMLP','REDCLIFF_S_CMLP_WithSmoothing','CMLP']:
        continue
    print("alg_key == ", alg_key)
    if alg_key == 'CMLP':
        alg_names.append(get_alg_name_alias(alg_key)+'-v2') # 'newer' CMLP, therefore dubbed v2
    else:
        assert alg_key == 'REDCLIFF_S_CMLP'
        alg_names.append(get_alg_name_alias(alg_key)+'-vCosSim') # this REDCLIFF-S model (from bCgs1v1223_REDCvNEWcMLP) has a better/smaller cosSim penalty
    _append_f1_summary_stats(alg_eval_stats)
# The reordering below assumes this first file supplied exactly two algorithms.
assert len(alg_names) == 2, "expected REDCLIFF-S and cMLP from the REDCvNEWcMLP summary"

# read in bCgs1v1223_REDCvOGcMLP LSNR results
with open("bCgs1v1223_REDCvOGcMLP_full_comparrisson_summary.pkl", "rb") as f:
    bsd4ic_lsnr_results = pkl.load(f)

bsd4ic_lsnr_stats = bsd4ic_lsnr_results['dream4_insilicoCombo_size10_LSNR']['key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag']
for alg_key, alg_eval_stats in bsd4ic_lsnr_stats.items():
    print("alg_key == ", alg_key)
    if alg_key == "REDCLIFF_S_CMLP":
        continue  # REDCLIFF-S is taken from the REDCvNEWcMLP summary above
    if alg_key == "CMLP":
        alg_names.append(get_alg_name_alias(alg_key)+'-v1') # 'older' CMLP, therefore dubbed v1
    else:
        alg_names.append(get_alg_name_alias(alg_key))
    _append_f1_summary_stats(alg_eval_stats)


# ORGANIZE LISTS INTO SENSIBLE ORDERING FOR VISUALIZATION (swap entries 1 and 2)
display_order = [0, 2, 1] + list(range(3, len(alg_names)))
alg_names = [alg_names[k] for k in display_order]
alg_performance_means = [alg_performance_means[k] for k in display_order]
alg_performance_sems = [alg_performance_sems[k] for k in display_order]


# remainder of code drafted with help from ChatGPT
# Create figure and axis
fig = plt.figure(figsize=(9.5, 8))
gs = gridspec.GridSpec(1, 1)
ax1 = plt.subplot(gs[0])


bar_width = 0.45
index = np.arange(len(alg_names))

# zip() below stops at the shortest list, so algorithms without a color would vanish silently.
if len(alg_names) > len(mean_colors):
    print("WARNING: only", len(mean_colors), "colors are defined;", len(alg_names) - len(mean_colors), "algorithm(s) will not be drawn")

# Horizontal bar plot with whiskers: algorithm a gets one bar at y = a - bar_width/2
for a, (alg_name, mean_color, sem_color) in enumerate(zip(alg_names, mean_colors, sem_colors)):
    curr_mean = alg_performance_means[a]
    curr_sem = alg_performance_sems[a]
    if not np.isfinite(curr_mean):
        print("WARNING: m==", curr_mean, " when alg_name == ", alg_name)
    if not np.isfinite(curr_sem):
        print("WARNING: s==", curr_sem, " when alg_name == ", alg_name)
    ax1.barh([a - bar_width/2], [curr_mean], xerr=[curr_sem], ecolor=sem_color, height=bar_width, color=mean_color, capsize=5, label=alg_name)

ax1.set_yticks(index-.25)
ax1.set_yticklabels([get_alg_name_alias_for_plot_axes(a) for a in alg_names], rotation=0)

ax1.set_xlim(0.26, 0.35)

# draw y tick marks on the left spine only
ax1.yaxis.tick_left()

# Customize the grid: Add both major and minor grid lines
ax1.grid(True, axis='x', which='major', linestyle=':', linewidth=0.75, color='grey')  # Major grid lines
ax1.minorticks_on()  # Enable minor ticks
ax1.grid(True, axis='x', which='minor', linestyle=':', linewidth=0.5, color='lightgray')  # Minor grid lines

# Optional: Customize appearance
ax1.invert_yaxis()  # Invert y-axis to display the first category at the top

# Show plot
ax1.set_title('D4IC LSNR Edge Prediction')
ax1.set_xlabel('Avg. Optimal F1-Score '+r'$\pm$'+' Std. Err. of the Mean')
plt.tight_layout()
plt.show()


## Visualizing MSNR D4IC Results 01/27/2025

In [None]:
# D4IC MSNR Summary (01/27/2025):

import numpy as np
import pickle as pkl
from matplotlib import pyplot as plt
from matplotlib import gridspec

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

# Notebook-wide matplotlib font sizes: "small" for body text, tick labels and
# legends, "medium" for axis labels, "bigger" for axes and figure titles.
plt.rc('font', size=FONT_SMALL_SIZE)                                      # default text size
plt.rc('axes', titlesize=FONT_BIGGER_SIZE, labelsize=FONT_MEDIUM_SIZE)    # axes title, x/y labels
plt.rc(('xtick', 'ytick'), labelsize=FONT_SMALL_SIZE)                     # tick labels on both axes
plt.rc('legend', fontsize=FONT_SMALL_SIZE)                                # legend text
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)                              # figure (suptitle) title


def get_alg_name_alias_for_plot_axes(orig_name):
    """Return the y-axis tick label to use for an algorithm display name.

    REDCLIFF-S (cMLP) variants are split over two lines, a few names are
    padded or re-capitalised, and any name not listed is returned unchanged.
    """
    diamond = " " + "\U000025C7"
    two_line_redcliff = 'REDCLIFF-S\n(cMLP)'
    # (display name, axis label) pairs, checked in order; first match wins.
    # NOTE(review): get_alg_name_alias yields 'cLSTM' (lower-case c) and this
    # cell builds 'cMLP-v1'/'cMLP-v2', none of which hit the padded entries
    # below -- confirm that is intended.
    axis_labels = (
        ('REDCLIFF-S (cMLP)', two_line_redcliff),
        ('REDCLIFF-S (cMLP)-vCosSim', two_line_redcliff),
        ('REDCLIFF-S (cMLP)-vSC', two_line_redcliff + " " + "*"),
        ('REDCLIFF-S (cMLP)' + diamond, two_line_redcliff + diamond),
        ('DYNOTEARS' + diamond, 'DYNOTEARS' + diamond),
        ('NAVAR-P' + diamond, 'NAVAR-P' + diamond),
        ('NAVAR-R' + diamond, 'NAVAR-R' + diamond),
        ('cMLP', 'cMLP  '),
        ('CLSTM', 'cLSTM '),
        ('DCSFA-NMF', 'dCSFA-NMF'),
    )
    for known_name, axis_label in axis_labels:
        if orig_name == known_name:
            return axis_label
    return orig_name

def get_alg_name_alias(orig_name):
    """Translate an algorithm key read from a summary pickle into its display name.

    Keys that are not listed are returned unchanged.
    """
    redcliff_s_cmlp = 'REDCLIFF-S (cMLP)'
    # (internal key, display name) pairs, checked in order; first match wins.
    display_names = (
        ('REDCLIFF_S_CMLP', redcliff_s_cmlp),
        ('REDCLIFF_S_CMLP_WithSmoothing', redcliff_s_cmlp),
        ('CMLP', 'cMLP'),
        ('CLSTM', 'cLSTM'),
        ('DCSFA', 'dCSFA-NMF'),
        ('DYNOTEARS_Vanilla', 'DYNOTEARS'),
        ('NAVAR_CMLP', 'NAVAR-P'),
        ('NAVAR_CLSTM', 'NAVAR-R'),
    )
    for internal_key, display_name in display_names:
        if orig_name == internal_key:
            return display_name
    return orig_name


# Bar colours in final plotting order (after the reordering step below); the
# plotting loop zips these with alg_names, so only as many algorithms as there
# are colours get a bar.
mean_colors = ["darkorange", "darkred", 'darkred', "mediumvioletred", "darkslateblue", "indigo", 'darkslategrey', "grey", "black"]
sem_colors = ["orangered", "tomato", 'tomato', "lightcoral", "slategrey", "mediumpurple", "lightblue", "lightgrey", "darkgrey"]
# Parallel lists: entry k of each list describes the same algorithm.
alg_names = []
alg_performance_means = []
alg_performance_sems = []

# Key, inside each dataset entry, holding the per-algorithm edge-prediction stats.
GC_EVAL_STATS_KEY = 'key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag'


def _extract_f1_mean_and_sem(alg_eval_stats):
    """Return the across-factor F1 (mean, std. err. of the mean) for one algorithm.

    Parameters
    ----------
    alg_eval_stats : dict
        Stat-name -> value mapping for a single algorithm, as read from a
        ``*_full_comparrisson_summary.pkl`` file.

    Returns
    -------
    tuple
        The value whose key contains 'f1' and 'mean_across_factors', and the
        value whose key contains 'f1' and 'mean_std_err_across_factors'.

    Raises
    ------
    ValueError
        If either statistic is missing or appears more than once -- appending
        a wrong number of values would silently misalign the parallel lists.
    """
    means, sems = [], []
    for stat_key, stat_val in alg_eval_stats.items():
        if 'f1' in stat_key and "mean_across_factors" in stat_key:
            means.append(stat_val)
        elif 'f1' in stat_key and "mean_std_err_across_factors" in stat_key:
            sems.append(stat_val)
    if len(means) != 1 or len(sems) != 1:
        raise ValueError(
            f"expected exactly one F1 mean and one F1 SEM, found {len(means)} and {len(sems)}"
        )
    return means[0], sems[0]


# read in bCgs1v1223_REDCvNEWcMLP_v0120 MSNR results
# (pickle can execute arbitrary code on load -- only open trusted summary files)
with open("bCgs1v1223_REDCvNEWcMLP_v0120_full_comparrisson_summary.pkl", "rb") as f:
    bsOH_msnr_results = pkl.load(f)

print(bsOH_msnr_results.keys())
# Algorithm keys and their stats are both read from this single dataset entry.
# NOTE(review): confirm this is the intended MSNR dataset entry in the v0120 file.
bsOH_msnr_alg_stats = bsOH_msnr_results['dream4_insilicoCombo_size10_v01192024_RERUN10242024_MSNR'][GC_EVAL_STATS_KEY]
for alg_key, alg_eval_stats in bsOH_msnr_alg_stats.items():
    # Only the plain REDCLIFF-S (cMLP) model and the 'newer' cMLP come from this file.
    if alg_key not in ('REDCLIFF_S_CMLP', 'CMLP'):
        continue
    print("alg_key == ", alg_key)
    if alg_key == 'CMLP':
        alg_names.append(get_alg_name_alias(alg_key)+'-v2') # 'newer' CMLP, therefore dubbed v2
    else:
        alg_names.append(get_alg_name_alias(alg_key)+'-vCosSim') # this REDCLIFF-S model (from bCgs1v1223_REDCvNEWcMLP) has a better/smaller cosSim penalty
    mean, sem = _extract_f1_mean_and_sem(alg_eval_stats)
    alg_performance_means.append(mean)
    alg_performance_sems.append(sem)

# The reordering step below assumes exactly two entries come from this file.
if len(alg_names) != 2:
    raise ValueError(f"expected 2 algorithms from the v0120 MSNR summary, got {alg_names}")

# read in bCgs1v1223_REDCvOGcMLP MSNR results
with open("bCgs1v1223_REDCvOGcMLP_full_comparrisson_summary.pkl", "rb") as f:
    bsd4ic_msnr_results = pkl.load(f)

for alg_key, alg_eval_stats in bsd4ic_msnr_results['dream4_insilicoCombo_size10_MSNR'][GC_EVAL_STATS_KEY].items():
    print("alg_key == ", alg_key)
    if alg_key == "REDCLIFF_S_CMLP":
        continue  # the REDCLIFF-S (cMLP) entry of this file is not plotted
    if alg_key == "CMLP":
        alg_names.append(get_alg_name_alias(alg_key)+'-v1') # 'older' CMLP, therefore dubbed v1
    else:
        alg_names.append(get_alg_name_alias(alg_key))
    mean, sem = _extract_f1_mean_and_sem(alg_eval_stats)
    alg_performance_means.append(mean)
    alg_performance_sems.append(sem)


# ORGANIZE LISTS INTO SENSIBLE ORDERING FOR VISUALIZATION
# Swap positions 1 and 2 in all three parallel lists, i.e. show the first
# algorithm read from the second summary file right after the first algorithm
# read from the first file.
alg_names = [alg_names[0]]+[alg_names[2]]+[alg_names[1]]+alg_names[3:]
alg_performance_means = [alg_performance_means[0]]+[alg_performance_means[2]]+[alg_performance_means[1]]+alg_performance_means[3:]
alg_performance_sems = [alg_performance_sems[0]]+[alg_performance_sems[2]]+[alg_performance_sems[1]]+alg_performance_sems[3:]


# remainder of code drafted with help from ChatGPT
# Create figure and axis
fig = plt.figure(figsize=(9.5, 8))
gs = gridspec.GridSpec(1, 1)
ax1 = plt.subplot(gs[0])


bar_width = 0.45
index = np.arange(len(alg_names))
# Each bar is centred bar_width/2 below its integer slot; the y ticks are
# placed at exactly the same positions so every label lines up with its bar.
bar_centers = index - bar_width/2

# zip() below stops at the shortest list, so algorithms beyond the number of
# defined colours would get a tick label but no bar -- make that visible.
n_colors = min(len(mean_colors), len(sem_colors))
if len(alg_names) > n_colors:
    print("WARNING: only", n_colors, "bar colors defined for", len(alg_names), "algorithms; extra algorithms will not be drawn")

# Horizontal bar plot with whiskers: one bar (mean) + error bar (SEM) per algorithm
for a, (alg_name, mean_color, sem_color) in enumerate(zip(alg_names, mean_colors, sem_colors)):
    m = alg_performance_means[a]
    s = alg_performance_sems[a]
    if not np.isfinite(m):
        print("WARNING: m==", m, " when alg_name == ", alg_name)
    if not np.isfinite(s):
        print("WARNING: s==", s, " when alg_name == ", alg_name)
    ax1.barh([bar_centers[a]], [m], xerr=[s], ecolor=sem_color, height=bar_width, color=mean_color, capsize=5, label=alg_name)

ax1.set_yticks(bar_centers)
ax1.set_yticklabels([get_alg_name_alias_for_plot_axes(a) for a in alg_names], rotation=0)

ax1.set_xlim(0.28, 0.38)

# Draw the y-axis tick marks on the left side only
ax1.yaxis.tick_left()

# Customize the grid: Add both major and minor grid lines
ax1.grid(True, axis='x', which='major', linestyle=':', linewidth=0.75, color='grey')  # Major grid lines
ax1.minorticks_on()  # Enable minor ticks
ax1.grid(True, axis='x', which='minor', linestyle=':', linewidth=0.5, color='lightgray')  # Minor grid lines

# Invert y-axis so the first algorithm in the lists is displayed at the top
ax1.invert_yaxis()

# Show plot
ax1.set_title('D4IC MSNR Edge Prediction')
ax1.set_xlabel('Avg. Optimal F1-Score '+r'$\pm$'+' Std. Err. of the Mean')
plt.tight_layout()
plt.show()


## Visualizing HSNR D4IC Results 01/27/2025

In [None]:
# D4IC HSNR Summary (01/27/2025):

import numpy as np
import pickle as pkl
from matplotlib import pyplot as plt
from matplotlib import gridspec

FONT_SMALL_SIZE = 18
FONT_MEDIUM_SIZE = 20
FONT_BIGGER_SIZE = 22

# Notebook-wide matplotlib font sizes: "small" for body text, tick labels and
# legends, "medium" for axis labels, "bigger" for axes and figure titles.
plt.rc('font', size=FONT_SMALL_SIZE)                                      # default text size
plt.rc('axes', titlesize=FONT_BIGGER_SIZE, labelsize=FONT_MEDIUM_SIZE)    # axes title, x/y labels
plt.rc(('xtick', 'ytick'), labelsize=FONT_SMALL_SIZE)                     # tick labels on both axes
plt.rc('legend', fontsize=FONT_SMALL_SIZE)                                # legend text
plt.rc('figure', titlesize=FONT_BIGGER_SIZE)                              # figure (suptitle) title


def get_alg_name_alias_for_plot_axes(orig_name):
    """Return the y-axis tick label to use for an algorithm display name.

    REDCLIFF-S (cMLP) variants are split over two lines, a few names are
    padded or re-capitalised, and any name not listed is returned unchanged.
    """
    diamond = " " + "\U000025C7"
    two_line_redcliff = 'REDCLIFF-S\n(cMLP)'
    # (display name, axis label) pairs, checked in order; first match wins.
    # NOTE(review): get_alg_name_alias yields 'cLSTM' (lower-case c) and this
    # cell builds 'cMLP-v1'/'cMLP-v2', none of which hit the padded entries
    # below -- confirm that is intended.
    axis_labels = (
        ('REDCLIFF-S (cMLP)', two_line_redcliff),
        ('REDCLIFF-S (cMLP)-vCosSim', two_line_redcliff),
        ('REDCLIFF-S (cMLP)-vSC', two_line_redcliff + " " + "*"),
        ('REDCLIFF-S (cMLP)' + diamond, two_line_redcliff + diamond),
        ('DYNOTEARS' + diamond, 'DYNOTEARS' + diamond),
        ('NAVAR-P' + diamond, 'NAVAR-P' + diamond),
        ('NAVAR-R' + diamond, 'NAVAR-R' + diamond),
        ('cMLP', 'cMLP  '),
        ('CLSTM', 'cLSTM '),
        ('DCSFA-NMF', 'dCSFA-NMF'),
    )
    for known_name, axis_label in axis_labels:
        if orig_name == known_name:
            return axis_label
    return orig_name

def get_alg_name_alias(orig_name):
    """Translate an algorithm key read from a summary pickle into its display name.

    Keys that are not listed are returned unchanged.
    """
    redcliff_s_cmlp = 'REDCLIFF-S (cMLP)'
    # (internal key, display name) pairs, checked in order; first match wins.
    display_names = (
        ('REDCLIFF_S_CMLP', redcliff_s_cmlp),
        ('REDCLIFF_S_CMLP_WithSmoothing', redcliff_s_cmlp),
        ('CMLP', 'cMLP'),
        ('CLSTM', 'cLSTM'),
        ('DCSFA', 'dCSFA-NMF'),
        ('DYNOTEARS_Vanilla', 'DYNOTEARS'),
        ('NAVAR_CMLP', 'NAVAR-P'),
        ('NAVAR_CLSTM', 'NAVAR-R'),
    )
    for internal_key, display_name in display_names:
        if orig_name == internal_key:
            return display_name
    return orig_name


# Bar colours in final plotting order (after the reordering step below); the
# plotting loop zips these with alg_names, so only as many algorithms as there
# are colours get a bar.
mean_colors = ["darkorange", "darkred", 'darkred', "mediumvioletred", "darkslateblue", "indigo", 'darkslategrey', "grey", "black"]
sem_colors = ["orangered", "tomato", 'tomato', "lightcoral", "slategrey", "mediumpurple", "lightblue", "lightgrey", "darkgrey"]
# Parallel lists: entry k of each list describes the same algorithm.
alg_names = []
alg_performance_means = []
alg_performance_sems = []

# Key, inside each dataset entry, holding the per-algorithm edge-prediction stats.
GC_EVAL_STATS_KEY = 'key_stats_estGC_normOffDiag_vs_trueGC_normOffDiag'


def _extract_f1_mean_and_sem(alg_eval_stats):
    """Return the across-factor F1 (mean, std. err. of the mean) for one algorithm.

    Parameters
    ----------
    alg_eval_stats : dict
        Stat-name -> value mapping for a single algorithm, as read from a
        ``*_full_comparrisson_summary.pkl`` file.

    Returns
    -------
    tuple
        The value whose key contains 'f1' and 'mean_across_factors', and the
        value whose key contains 'f1' and 'mean_std_err_across_factors'.

    Raises
    ------
    ValueError
        If either statistic is missing or appears more than once -- appending
        a wrong number of values would silently misalign the parallel lists.
    """
    means, sems = [], []
    for stat_key, stat_val in alg_eval_stats.items():
        if 'f1' in stat_key and "mean_across_factors" in stat_key:
            means.append(stat_val)
        elif 'f1' in stat_key and "mean_std_err_across_factors" in stat_key:
            sems.append(stat_val)
    if len(means) != 1 or len(sems) != 1:
        raise ValueError(
            f"expected exactly one F1 mean and one F1 SEM, found {len(means)} and {len(sems)}"
        )
    return means[0], sems[0]


# read in bCgs1v1223_REDCvNEWcMLP HSNR results
# (pickle can execute arbitrary code on load -- only open trusted summary files)
with open("bCgs1v1223_REDCvNEWcMLP_full_comparrisson_summary.pkl", "rb") as f:
    bsOH_hsnr_results = pkl.load(f)

print(bsOH_hsnr_results.keys())
for alg_key, alg_eval_stats in bsOH_hsnr_results['dream4_insilicoCombo_size10_HSNR'][GC_EVAL_STATS_KEY].items():
    # Only the plain REDCLIFF-S (cMLP) model and the 'newer' cMLP come from this file.
    if alg_key not in ('REDCLIFF_S_CMLP', 'CMLP'):
        continue
    print("alg_key == ", alg_key)
    if alg_key == 'CMLP':
        alg_names.append(get_alg_name_alias(alg_key)+'-v2') # 'newer' CMLP, therefore dubbed v2
    else:
        alg_names.append(get_alg_name_alias(alg_key)+'-vCosSim') # this REDCLIFF-S model (from bCgs1v1223_REDCvNEWcMLP) has a better/smaller cosSim penalty
    mean, sem = _extract_f1_mean_and_sem(alg_eval_stats)
    alg_performance_means.append(mean)
    alg_performance_sems.append(sem)

# The reordering step below assumes exactly two entries come from this file.
if len(alg_names) != 2:
    raise ValueError(f"expected 2 algorithms from the REDCvNEWcMLP HSNR summary, got {alg_names}")

# read in bCgs1v1223_REDCvOGcMLP HSNR results
with open("bCgs1v1223_REDCvOGcMLP_full_comparrisson_summary.pkl", "rb") as f:
    bsd4ic_hsnr_results = pkl.load(f)

for alg_key, alg_eval_stats in bsd4ic_hsnr_results['dream4_insilicoCombo_size10_HSNR'][GC_EVAL_STATS_KEY].items():
    print("alg_key == ", alg_key)
    if alg_key == "REDCLIFF_S_CMLP":
        continue  # the REDCLIFF-S (cMLP) entry of this file is not plotted
    if alg_key == "CMLP":
        alg_names.append(get_alg_name_alias(alg_key)+'-v1') # 'older' CMLP, therefore dubbed v1
    else:
        alg_names.append(get_alg_name_alias(alg_key))
    mean, sem = _extract_f1_mean_and_sem(alg_eval_stats)
    alg_performance_means.append(mean)
    alg_performance_sems.append(sem)


# ORGANIZE LISTS INTO SENSIBLE ORDERING FOR VISUALIZATION
# Swap positions 1 and 2 in all three parallel lists, i.e. show the first
# algorithm read from the second summary file right after the first algorithm
# read from the first file.
alg_names = [alg_names[0]]+[alg_names[2]]+[alg_names[1]]+alg_names[3:]
alg_performance_means = [alg_performance_means[0]]+[alg_performance_means[2]]+[alg_performance_means[1]]+alg_performance_means[3:]
alg_performance_sems = [alg_performance_sems[0]]+[alg_performance_sems[2]]+[alg_performance_sems[1]]+alg_performance_sems[3:]


# remainder of code drafted with help from ChatGPT
# Create figure and axis
fig = plt.figure(figsize=(9.5, 8))
gs = gridspec.GridSpec(1, 1)
ax1 = plt.subplot(gs[0])


bar_width = 0.45
index = np.arange(len(alg_names))
# Each bar is centred bar_width/2 below its integer slot; the y ticks are
# placed at exactly the same positions so every label lines up with its bar.
bar_centers = index - bar_width/2

# zip() below stops at the shortest list, so algorithms beyond the number of
# defined colours would get a tick label but no bar -- make that visible.
n_colors = min(len(mean_colors), len(sem_colors))
if len(alg_names) > n_colors:
    print("WARNING: only", n_colors, "bar colors defined for", len(alg_names), "algorithms; extra algorithms will not be drawn")

# Horizontal bar plot with whiskers: one bar (mean) + error bar (SEM) per algorithm
for a, (alg_name, mean_color, sem_color) in enumerate(zip(alg_names, mean_colors, sem_colors)):
    m = alg_performance_means[a]
    s = alg_performance_sems[a]
    if not np.isfinite(m):
        print("WARNING: m==", m, " when alg_name == ", alg_name)
    if not np.isfinite(s):
        print("WARNING: s==", s, " when alg_name == ", alg_name)
    ax1.barh([bar_centers[a]], [m], xerr=[s], ecolor=sem_color, height=bar_width, color=mean_color, capsize=5, label=alg_name)

ax1.set_yticks(bar_centers)
ax1.set_yticklabels([get_alg_name_alias_for_plot_axes(a) for a in alg_names], rotation=0)

ax1.set_xlim(0.28, 0.39)

# Draw the y-axis tick marks on the left side only
ax1.yaxis.tick_left()

# Customize the grid: Add both major and minor grid lines
ax1.grid(True, axis='x', which='major', linestyle=':', linewidth=0.75, color='grey')  # Major grid lines
ax1.minorticks_on()  # Enable minor ticks
ax1.grid(True, axis='x', which='minor', linestyle=':', linewidth=0.5, color='lightgray')  # Minor grid lines

# Invert y-axis so the first algorithm in the lists is displayed at the top
ax1.invert_yaxis()

# Show plot
ax1.set_title('D4IC HSNR Edge Prediction')
ax1.set_xlabel('Avg. Optimal F1-Score '+r'$\pm$'+' Std. Err. of the Mean')
plt.tight_layout()
plt.show()


#    

---
