In [None]:
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange
import os 
# Make sure the directory that the plotting cells save into exists.
os.makedirs("figures", exist_ok=True)

# Plot style per marker: name -> (hex colour, matplotlib linestyle).
# Insertion order is significant: the plotting cells pair the i-th entry
# with column i of the per-marker arrays.
MARKERS = dict(
    # first pass through the six-colour palette, drawn solid
    CAL1=("#0072B2", "solid"),  # blue
    CUB=("#D55E00", "solid"),  # vermillion
    LCAL=("#009E73", "solid"),  # bluish green
    LMAL=("#CC79A7", "solid"),  # reddish purple
    MCAL=("#E69F00", "solid"),  # orange
    MMAL=("#56B4E9", "solid"),  # sky blue
    # second pass through the same palette, drawn dashed
    MT1B=("#0072B2", "dashed"),  # blue
    MT1H=("#D55E00", "dashed"),  # vermillion
    MT2H=("#009E73", "dashed"),  # bluish green
    MT5B=("#CC79A7", "dashed"),  # reddish purple
    MT5H=("#E69F00", "dashed"),  # orange
    NAV=("#56B4E9", "dashed"),  # sky blue
    # one extra marker in neutral grey
    TOE=("#888888", "solid"),  # grey
)

# Display title of each run -> path of its aggregated-results archive.
# The order here fixes the index `j` used for every per-run list below.
runs = dict(
    [
        ("Gaussian Process", "test6/aggregate_results_low_freq.npz"),
        ("Forces", "test7/aggregate_results.npz"),
        ("Group Lasso", "test8/aggregate_results.npz"),
        ("Left-Only", "test5/aggregate_results.npz"),
        ("Right-Only", "test5a/aggregate_results.npz"),
    ]
)

# Open one archive per run, in the same order as `runs`.
files = [np.load(npz_path) for npz_path in runs.values()]

# Pull the arrays stored under each key into per-run lists (index j == run j).
lambdas, thetas, r2, glasso = (
    [archive[key] for archive in files] for key in ("l", "theta", "r2", "glasso")
)

# "r2_avg" is optional in the archives: when a run does not provide it,
# substitute an all-NaN array shaped like that run's "glasso" array.
r2_avg = []
for archive, norms in zip(files, glasso):
    if "r2_avg" in archive.files:
        r2_avg.append(archive["r2_avg"])
    else:
        r2_avg.append(np.full_like(norms, np.nan))

# Boolean mask per run: True where the "glasso" value exceeds 1e-4.
active = [norms > 1e-4 for norms in glasso]
# Wherever the mask is True the score is replaced by 1.0; elsewhere the
# stored r2_avg value is kept.
r2_adj = [np.where(mask, 1.0, scores) for mask, scores in zip(active, r2_avg)]

In [None]:
from matplotlib import markers

# For every run, order the marker names by the first row (along axis 0 of
# `active`) at which each marker is inactive, latest first, and print the
# result as the body of a LaTeX table (one column per run).
# NOTE(review): this reads as "most persistent marker first" only if the rows
# are ordered by increasing regularisation strength -- confirm against the
# code that writes the archives.
marker_names = list(MARKERS.keys())

markers = []
for j in range(len(runs)):
    is_active = active[j]
    # argmin over a boolean column is the index of its first False entry ...
    first_inactive = is_active.argmin(axis=0)
    # ... but it is also 0 for a column that is True everywhere. Test for that
    # case explicitly instead of checking `== 0`, which would also catch a
    # marker that is already inactive in row 0 and wrongly rank it first.
    never_inactive = is_active.all(axis=0)
    enters_at = np.where(never_inactive, np.inf, first_inactive)
    # Descending order: markers that are never inactive (inf) come first.
    markers.append([marker_names[i] for i in np.argsort(enters_at)[::-1]])


# Header row: one bold column title per run.
header = "".join(f"& \\textbf{{{title}}}" for title in runs.keys())
print(f"{header}\\\\")
print("\\hline")

# Body: row i holds the rank (1-based) followed by the i-th ranked marker of each run.
for i in range(len(MARKERS)):
    cells = "".join(f" & {markers[j][i]}" for j in range(len(runs)))
    print(f"{i+1}{cells}\\\\")

In [None]:
# One two-panel figure per run, saved as figures/parameters_trajectory_<title>.pdf:
#   top    - per-marker "glasso" values against lambda (log-log axes)
#   bottom - R^2 against lambda (log x axis, linear y axis)
for j, title in enumerate(runs.keys()):
    is_forces = title == "Forces"
    lam = lambdas[j]
    fig, (ax_norm, ax_r2) = plt.subplots(2, 1, figsize=(8, 9))

    # --- top panel: identical for every run --------------------------------
    for i, (m, (c, s)) in enumerate(MARKERS.items()):
        ax_norm.plot(lam, glasso[j][:, i], label=m, color=c, linestyle=s)
    ax_norm.grid()
    # Limits are set before switching to log scale, as in the original cell.
    ax_norm.set_ylim(1e-3, None)
    ax_norm.set_xlim(lam.min(), lam.max())
    ax_norm.set_xscale("log")
    ax_norm.set_yscale("log")
    ax_norm.set_ylabel("$||\\Theta_{m:::}||$")
    ax_norm.set_xlabel("$\\lambda$")
    if is_forces:
        # For this run the bottom legend lists X/Y/Z, so the marker legend
        # has to live on the top panel.
        ax_norm.legend(loc="lower left")

    # --- bottom panel: what is plotted depends on the run ------------------
    if is_forces:
        # r2 holds one column per component for this run.
        components = zip("XYZ", ("#0072B2", "#D55E00", "#009E73"))
        for k, (component, colour) in enumerate(components):
            ax_r2.plot(
                lam, r2[j][:, k], label=component, color=colour, linestyle="solid"
            )
    else:
        # One adjusted-R^2 curve per marker, styled like the top panel.
        for i, (m, (c, s)) in enumerate(MARKERS.items()):
            ax_r2.plot(lam, r2_adj[j][:, i], label=m, color=c, linestyle=s)
    ax_r2.set_ylim(0.0, 1.05)
    ax_r2.grid()
    ax_r2.set_xscale("log")
    ax_r2.set_xlim(lam.min(), lam.max())
    ax_r2.set_xlabel("$\\lambda$")
    ax_r2.set_ylabel("$R^2$")
    ax_r2.legend(loc="lower left")

    fig.tight_layout()
    fig.savefig(f"figures/parameters_trajectory_{title}.pdf")
    plt.show()
    # Release the figure so repeated runs of the loop do not accumulate them.
    plt.close(fig)

In [None]:
# Log-scale variant of the trajectory figures, saved as
# figures/parameters_trajectory_logscale_<title>.pdf.
# The bottom panel plots 1 - R^2 on an inverted log axis and relabels the
# ticks with the corresponding R^2 values, so values close to 1 are spread out.
for j, title in enumerate(runs.keys()):
    is_forces = title == "Forces"
    lam = lambdas[j]
    fig, (ax_norm, ax_r2) = plt.subplots(2, 1, figsize=(8, 9))

    # --- top panel: identical for every run --------------------------------
    for i, (m, (c, s)) in enumerate(MARKERS.items()):
        ax_norm.plot(lam, glasso[j][:, i], label=m, color=c, linestyle=s)
    ax_norm.grid()
    # Limits are set before switching to log scale, as in the original cell.
    ax_norm.set_ylim(1e-3, None)
    ax_norm.set_xlim(lam.min(), lam.max())
    ax_norm.set_xscale("log")
    ax_norm.set_yscale("log")
    ax_norm.set_ylabel("$||\\Theta_{m:::}||$")
    ax_norm.set_xlabel("$\\lambda$")
    if is_forces:
        # For this run the bottom legend lists X/Y/Z, so the marker legend
        # has to live on the top panel.
        ax_norm.legend(loc="lower left")

    # --- bottom panel: curves and tick positions depend on the run ---------
    if is_forces:
        components = zip("XYZ", ("#0072B2", "#D55E00", "#009E73"))
        for k, (component, colour) in enumerate(components):
            ax_r2.plot(
                lam, 1 - r2[j][:, k], label=component, color=colour, linestyle="solid"
            )
        # Tick positions are values of 1 - R^2; labels show the matching R^2.
        tick_positions = [0.05, 0.1, 0.2, 0.5, 1]
        tick_labels = ["0.95", "0.9", "0.8", "0.5", "0.0"]
    else:
        for i, (m, (c, s)) in enumerate(MARKERS.items()):
            ax_r2.plot(lam, 1 - r2_adj[j][:, i], label=m, color=c, linestyle=s)
        tick_positions = [1e-4, 1e-3, 1e-2, 1e-1, 1]
        tick_labels = ["0.9999", "0.999", "0.99", "0.9", "0.0"]
    ax_r2.grid()
    ax_r2.set_xscale("log")
    ax_r2.set_xlim(lam.min(), lam.max())
    ax_r2.set_xlabel("$\\lambda$")
    ax_r2.set_ylabel("$R^2$")
    ax_r2.set_yscale("log")
    # Inverted so that higher R^2 (smaller 1 - R^2) is towards the top.
    ax_r2.invert_yaxis()
    ax_r2.set_yticks(tick_positions)
    ax_r2.set_yticklabels(tick_labels)
    ax_r2.legend(loc="lower left")

    fig.tight_layout()
    fig.savefig(f"figures/parameters_trajectory_logscale_{title}.pdf")
    plt.show()
    # Release the figure so repeated runs of the loop do not accumulate them.
    plt.close(fig)

In [None]:
# Accuracy/sparsity trade-off: adjusted R^2 against the number of active
# markers, for the "Gaussian Process" and "Group Lasso" runs only.
# Two figures are produced from the same curves: a linear-scale one and a
# log-scale one that plots 1 - R^2 on an inverted axis.
colors = ["#0072B2", "#D55E00", "#CC79A7", "#009E73", "#E69F00", "#56B4E9"]
# x ticks: one per possible marker count, 1 .. len(MARKERS).
marker_counts = list(range(1, len(MARKERS) + 1))

# Compute the curves once; both figures below draw from this mapping:
# title -> (line colour, active-marker counts, per-count minimum scores).
tradeoff = {}
for j, title in enumerate(runs.keys()):
    if title not in ("Gaussian Process", "Group Lasso"):
        continue
    # Number of active markers in each row.
    n_active = active[j].sum(axis=1)
    # The same count can occur in several rows, so group rows by count and
    # keep, per marker, the lowest score among the rows sharing that count.
    # np.unique already returns the counts sorted; 0 active markers is skipped.
    ns = np.array([n for n in np.unique(n_active) if n > 0])
    r = np.array([r2_adj[j][n_active == n].min(axis=0) for n in ns])
    tradeoff[title] = (colors[j], ns, r)

# --- figure 1: linear R^2 axis --------------------------------------------
fig, ax = plt.subplots(figsize=(8, 5))
for title, (colour, ns, r) in tradeoff.items():
    # Thick line: mean over markers; thin lines: each marker individually.
    ax.plot(ns, r.mean(axis=1), "o-", label=title, color=colour, linewidth=3)
    for i in range(len(MARKERS)):
        ax.plot(ns, r[:, i], "-", color=colour, alpha=0.3)
ax.set_xlabel("Number of Markers used")
ax.set_ylabel("$R^2$")
ax.legend(loc="lower right")
ax.grid()
ax.set_ylim(0.0, 1.05)
ax.set_xticks(marker_counts)
ax.set_xticklabels([str(n) for n in marker_counts])
fig.tight_layout()
fig.savefig("figures/accuracy_sparsity_tradeoff.pdf")
plt.show()
plt.close(fig)


# --- figure 2: 1 - R^2 on an inverted log axis ------------------------------
fig, ax = plt.subplots(figsize=(8, 5))
for title, (colour, ns, r) in tradeoff.items():
    ax.plot(ns, 1 - r.mean(axis=1), "o-", label=title, color=colour, linewidth=3)
    for i in range(len(MARKERS)):
        ax.plot(ns, (1 - r[:, i]), "-", color=colour, alpha=0.3)
    print(f"{title}: {ns}, {r.mean(axis=1)}")
ax.set_xlabel("Number of Markers used")
ax.set_ylabel("$R^2$")
ax.legend(loc="lower right")
ax.grid()
ax.set_yscale("log")
# Inverted so that higher R^2 (smaller 1 - R^2) is towards the top.
ax.invert_yaxis()
# Tick positions are values of 1 - R^2; labels show the matching R^2.
ax.set_yticks([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1])
ax.set_yticklabels(["0.99999", "0.9999", "0.999", "0.99", "0.9", "0.0"])
ax.set_xticks(marker_counts)
ax.set_xticklabels([str(n) for n in marker_counts])
fig.tight_layout()
fig.savefig("figures/accuracy_sparsity_tradeoff_logscale.pdf")
plt.show()
plt.close(fig)