Load Data

In [None]:
import torch
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import os
import json
from utils import tensor_to_image, get_ssim_all
from reconstruction import compute_recovery_order_MSE
from tqdm import tqdm

In [None]:
# Every run writes "<run_name>_info.json" (metadata, listed below) and
# "<run_name>_results.npz" (arrays, loaded per run) into this directory.
RUNS_DIR, INFO_EXTENSION = "outputs/", "_info.json"

In [None]:
# Output directory for every figure saved in this notebook. exist_ok=True keeps the cell
# re-runnable and avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs("./figures", exist_ok=True)

In [None]:
# Marker shapes, indexed by series number in the line plots below.
markers = ["o", "s", "^", "P"]
# Colour-universal-design palette: blue, orange, green, dark purple.
cud_colors = ["#377eb8", "#ff7f00", "#4daf4a", "#984ea3"]

# Global matplotlib styling shared by every figure in this notebook.
plt.rcParams.update({
    "font.size": 20,
    "legend.fontsize": 14,
    "axes.prop_cycle": plt.cycler(color=cud_colors),
})


In [None]:
def filter_df(df, filters_dict):
    """Return the rows of ``df`` that match every filter in ``filters_dict``.

    Args:
        df: DataFrame to filter; it is not modified.
        filters_dict: mapping ``column -> value``. A scalar value (int, float, bool,
            str, or a NumPy scalar such as ``np.int64``) keeps rows where the column
            equals it. A list, tuple or set keeps rows whose column value is one of
            its elements.

    Returns:
        The filtered DataFrame. Filters are combined with logical AND.

    Raises:
        ValueError: if a filter value is not one of the supported kinds.
    """
    for column, value in filters_dict.items():
        # Check collections first: a str is iterable but must be matched as a scalar.
        if isinstance(value, (list, tuple, set, frozenset)):
            df = df[df[column].isin(list(value))]
        # np.generic covers NumPy scalars (e.g. np.int64), which are not int subclasses.
        elif isinstance(value, (int, float, str, np.generic)):
            df = df[df[column] == value]
        else:
            raise ValueError("Filter values must be int, float, str or list")
    return df

In [None]:
# Collect every run's metadata file ("*<INFO_EXTENSION>" in RUNS_DIR) into one DataFrame.
# Filenames are sorted so that df_full's row order is reproducible rather than depending
# on the filesystem's directory-listing order.
data = []
for filename in sorted(os.listdir(RUNS_DIR)):
    if not filename.endswith(INFO_EXTENSION):
        continue
    # JSON is UTF-8 by spec; don't depend on the platform's default locale encoding.
    with open(os.path.join(RUNS_DIR, filename), "r", encoding="utf-8") as f:
        data.append(json.load(f))

# json_normalize flattens any nested dicts in the metadata into dotted column names.
df_full = pd.json_normalize(data)


In [None]:
def get_recovery_statistics(original_data, recovered_images, good_threshold=0.3):
    """Summarise how well ``recovered_images`` reconstruct ``original_data``.

    Uses DSSIM = (1 - SSIM) / 2 from ``get_ssim_all`` and MSE from
    ``compute_recovery_order_MSE``. NOTE(review): the indexing below assumes
    ``get_ssim_all(recovered, original)`` returns a (n_recovered, n_original) matrix,
    so reducing over dim 0 gives, for each original image, its closest reconstruction
    -- TODO confirm against utils.get_ssim_all.

    Args:
        original_data: tensor of original images; ``shape[0]`` is the image count.
        recovered_images: tensor of reconstructions; ``shape[0]`` is the image count.
        good_threshold: DSSIM at or below which a mutual-nearest-neighbour match
            counts as "good" (default 0.3, the previously hard-coded value).

    Returns:
        ``(summary, dssims_best, nns_dssim, mse_best, nns_mse, unique_reconstructions)``:
        ``summary`` is a dict with the percentage of mutual-NN matches
        ("unique_recoveries_percent"), the percentage of those that are good
        ("good_uniques_percent"), and the 25/50/75% quantiles of the best DSSIM and MSE
        ("dssim_q0".."dssim_q2", "mse_q0".."mse_q2", rounded to 3 decimals; q1 is the
        median). The next four are NumPy arrays of best distances and nearest-neighbour
        indices. ``unique_reconstructions`` lists the indices ``i`` whose DSSIM match
        ``j = nns_dssim[i]`` has ``i`` as its own nearest neighbour.
    """
    min_n = min(original_data.shape[0], recovered_images.shape[0])

    dssims_all = (1 - get_ssim_all(recovered_images, original_data)) / 2
    dssims_best, nns_dssim = dssims_all.min(dim=0)
    mse_best, nns_mse, _ = compute_recovery_order_MSE(recovered_images, original_data)

    def _quartiles(values):
        # torch.quantile requires q to have the same dtype as its input, so q is built
        # per tensor instead of reusing one cast to mse_best's dtype for the DSSIMs too.
        qs = torch.tensor([0.25, 0.5, 0.75], device=values.device, dtype=values.dtype)
        return torch.quantile(values, qs)

    mse_quantiles = _quartiles(mse_best[:min_n])
    dssim_quantiles = _quartiles(dssims_best[:min_n])

    # Mutual nearest neighbours: i keeps its match j = nns_dssim[i] only if i is also j's
    # closest entry along the other axis. Vectorised form of the per-element check.
    _, nns_dssim_reversed = dssims_all.min(dim=1)
    candidate_idx = torch.arange(nns_dssim.shape[0], device=nns_dssim.device)
    is_mutual = nns_dssim_reversed[nns_dssim] == candidate_idx
    unique_reconstructions = torch.nonzero(is_mutual).flatten().tolist()
    unique_recoveries = len(unique_reconstructions)
    good_unique_recoveries = torch.count_nonzero(dssims_best[is_mutual] <= good_threshold).item()

    def _to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    return {
        "unique_recoveries_percent": unique_recoveries / min_n * 100,
        "good_uniques_percent": good_unique_recoveries / min_n * 100,
        "dssim_q0": round(dssim_quantiles[0].item(), 3),
        "dssim_q1": round(dssim_quantiles[1].item(), 3),
        "dssim_q2": round(dssim_quantiles[2].item(), 3),
        "mse_q0": round(mse_quantiles[0].item(), 3),
        "mse_q1": round(mse_quantiles[1].item(), 3),
        "mse_q2": round(mse_quantiles[2].item(), 3),
    }, _to_numpy(dssims_best), _to_numpy(nns_dssim), _to_numpy(mse_best), _to_numpy(nns_mse), unique_reconstructions


def get_run_statistics_from_name(run_name, return_full=False):
    """Load ``RUNS_DIR/<run_name>_results.npz`` and compute its recovery statistics.

    Args:
        run_name: prefix of the results file (the callers pass a W&B run name).
        return_full: if False (default), return only the summary dict from
            ``get_recovery_statistics``. If True, return everything
            ``get_recovery_statistics`` returns, followed by the original and
            recovered images converted with ``tensor_to_image``.
    """
    results_path = os.path.join(RUNS_DIR, run_name + "_results.npz")
    # np.load on an .npz returns an NpzFile that holds the archive open until closed;
    # this function runs once per plotted point, so close it deterministically.
    with np.load(results_path) as saved_data:
        # "trainig_data" (sic) must match the key in the saved files -- do not "fix" it here.
        original_data = torch.tensor(saved_data['trainig_data'])
        recovered_images = torch.tensor(saved_data['reconstructions'])

    if return_full:
        return (
            *get_recovery_statistics(original_data, recovered_images),
            tensor_to_image(original_data),
            tensor_to_image(recovered_images),
        )
    return get_recovery_statistics(original_data, recovered_images)[0]

Reconstruction Quality: Laplace vs. RBF

In [None]:
# Run filters for the Laplace target model (CIFAR5M, 500 recovered images,
# no horizontal flip, recovery_pca_dim 0 = no PCA, target_regularization 0).
filters_lap = {
    "recovery_size": 500,
    "target_model": "laplace",
    "dataset": "CIFAR5M",
    "horiz_flip": False,
    "recovery_gamma": 0.15,
    "target_gamma": 0.15,
    "recovery_pca_dim": 0,
    "use_test_time_seed": True,
    "target_regularization": 0,
}
# Same setup for the Gaussian (RBF) target model; note it does not filter on target_gamma.
filters_rbf = {
    "recovery_size": 500,
    "target_model": "gaussian",
    "dataset": "CIFAR5M",
    "horiz_flip": False,
    "recovery_gamma": 0.003,
    "recovery_pca_dim": 0,
    "use_test_time_seed": True,
    "target_regularization": 0,
}

In [None]:
# Median DSSIM vs. number of query points for the Laplace and RBF target models,
# one line per (kernel, target task).
plt.figure(figsize=(12, 6))
marker_count = 0

for model_name, filters in [("Laplace", filters_lap), ("RBF", filters_rbf)]:
    df = filter_df(df_full, filters)
    df = df[df.batch_size < 500000]
    # Sorting first makes unique() yield the tasks in sorted order.
    df = df.sort_values("target_task")
    for task in df['target_task'].unique():
        cur_df = df[df['target_task'] == task].sort_values("batch_size")
        dssim_values = []
        x_ticks = []

        for _, row in cur_df.iterrows():
            stats = get_run_statistics_from_name(row['wandb_run_name'])
            dssim_values.append(stats['dssim_q1'])  # dssim_q1 is the 0.5 quantile (median)
            x_ticks.append(row.batch_size / (10 ** 5))

        # Wrap around the marker list so more than len(markers) series don't raise IndexError.
        plt.plot(x_ticks, dssim_values, label=f"{model_name} {task.upper()}",
                 marker=markers[marker_count % len(markers)], linewidth=5, markersize=12)
        marker_count += 1

# Reference line m = n(d + C) / C with n = 500 (recovery_size), d = 3072, C = 10,
# expressed in the same 1e5 units as the x-axis.
n_train, input_dim, n_classes = 500, 3072, 10
plt.axvline(x=n_train * (input_dim + n_classes) / n_classes / 10**5, color='black', linestyle='--', linewidth=2, label=r"$m=\frac{n(d+C)}{C}$")
plt.legend()
plt.grid(visible=True, which='major', linestyle='--', linewidth=0.5)
plt.xlabel(r"Query Points ($\times 10^5$)")
plt.ylabel("Median DSSIM")
plt.tight_layout()
plt.savefig("figures/dssim_vs_equations.png")

Equations vs. Reconstruction Quality (PCA)

In [None]:
# PCA sweep: CIFAR10, 100 recovered images, Laplace recovery kernel with gamma 0.15.
filters_dict = {
    "recovery_size": 100,
    "dataset": "CIFAR10",
    "horiz_flip": False,
    "recovery_gamma": 0.15,
    "recovery_model": "laplace",
}
df = filter_df(df_full, filters_dict)
# Keep batch sizes under 500k that are whole multiples of 10k.
below_cap = df.batch_size < 500000
on_grid = df.batch_size % 10000 == 0
df = df[below_cap & on_grid]

In [None]:
# Median DSSIM vs. number of query points, one line per recovery PCA dimension.
cmap = plt.get_cmap("tab10")
num_plots = len(df['recovery_pca_dim'].unique())
colors = [cmap(i / num_plots) for i in range(num_plots)]
# recovery_pca_dim == 0 means "no PCA"; relabel it as the full dimension 3072.
# Rebinding via assign (instead of .loc-assigning into a filtered frame) avoids
# in-place mutation and pandas' possible SettingWithCopyWarning.
df = df.assign(recovery_pca_dim=df['recovery_pca_dim'].replace(0, 3072))

plt.figure(figsize=(12, 6))

for i, dim in enumerate(np.sort(df['recovery_pca_dim'].unique())):
    cur_df = df[df['recovery_pca_dim'] == dim].sort_values("batch_size")

    dssim_values = []
    x_ticks = []
    for _, row in tqdm(cur_df.iterrows(), total=len(cur_df)):
        stats = get_run_statistics_from_name(row['wandb_run_name'])
        dssim_values.append(stats['dssim_q1'])  # dssim_q1 is the median
        x_ticks.append(row.batch_size / 10000)

    label = dim if dim != 3072 else "3072 (No PCA)"
    # Wrap around the marker list so more than len(markers) PCA dims don't raise IndexError.
    plt.plot(x_ticks, dssim_values, label=label, color=colors[i],
             marker=markers[i % len(markers)], linewidth=5, markersize=12)

# Reference line m = n(d + C) / C with n = 100 (recovery_size), d = 3072, C = 10,
# in the same 1e4 units as the x-axis.
plt.axvline(x=100 * (3072 + 10) / 10 / 10**4 , color='black', linestyle='--', linewidth=2, label=r"$m=\frac{n(d+C)}{C}$")
plt.legend(title="PCA Dimension")
plt.grid(visible=True, which='major', linestyle='--', linewidth=0.5)
plt.xlabel(r"Query Points ($\times 10^4$)")
plt.ylabel("Median DSSIM")
plt.tight_layout()
plt.savefig("figures/pca_cifar10_n100.png")

Reconstruction Quality vs Gamma

In [None]:
# Laplace KRR target on CIFAR5M at 500k query points; recovery_gamma is left free
# so the next cell can sweep over it.
filters_laplace = {
    "recovery_size": 500,
    "target_model": "laplace",
    "dataset": "CIFAR5M",
    "horiz_flip": False,
    "recovery_pca_dim": 0,
    "use_test_time_seed": True,
    "target_regularization": 0,
    "batch_size": 500000,
    "target_task": "krr",
}
df = filter_df(df_full, filters_laplace)

In [None]:
# Median DSSIM as a function of the recovery kernel's gamma.
plt.figure(figsize=(12, 6))

cur_df = df.sort_values("recovery_gamma")
dssims = []
gammas = []

for _, row in cur_df.iterrows():
    stats = get_run_statistics_from_name(row['wandb_run_name'])
    dssims.append(stats['dssim_q1'])  # dssim_q1 is the median
    gammas.append(row.recovery_gamma)

plt.plot(gammas, dssims, marker="o", linewidth=5, markersize=12)

# Single unlabelled series: no plt.legend() (it would only warn "No artists with labels").
plt.grid(visible=True, which='major', linestyle='--', linewidth=0.5)
plt.xlabel(r"$\gamma$")
plt.ylabel("Median DSSIM")
# Label every other gamma value to keep the tick labels readable.
ticks = cur_df.recovery_gamma.unique()[::2]
plt.xticks(ticks)
plt.gca().xaxis.set_major_formatter(plt.FuncFormatter(lambda val, _: f"{val:.2f}"))
# Lay out after the tick changes so the final tick labels are accounted for.
plt.tight_layout()
plt.savefig("figures/gamma_laplace.png")

Main CIFAR Results

In [None]:
# (label, run name) pairs for the main CIFAR results. Replace every WANDB_RUN_NAME
# placeholder with the relevant run name (e.g. "floral-morning-48"). Keeping each label
# next to its run prevents the two lists below from drifting out of alignment.
_cifar_runs = [
    ("Laplace", "WANDB_RUN_NAME"),
    ("RBF", "WANDB_RUN_NAME"),
    ("NTK", "WANDB_RUN_NAME"),
    ("Cubic Polynomial", "WANDB_RUN_NAME"),  # polynomial kernel
]
our_runs_cifar = [run_name for _, run_name in _cifar_runs]
our_labels_cifar = [label for label, _ in _cifar_runs]

In [None]:
# Load every main CIFAR run and index its per-image results by kernel label.
assert len(our_runs_cifar) == len(our_labels_cifar), "our_runs_cifar and our_labels_cifar must align"
# Fail fast with a readable message instead of a FileNotFoundError from np.load.
unset_runs = [label for label, run_name in zip(our_labels_cifar, our_runs_cifar) if run_name == "WANDB_RUN_NAME"]
if unset_runs:
    raise ValueError(f"Replace the WANDB_RUN_NAME placeholder for: {', '.join(unset_runs)}")

our_stats_cifar = {}
our_dssims_cifar = {}
our_nns_dssim_cifar = {}
our_mses_cifar = {}
our_nns_mse_cifar = {}
our_recovered_images_cifar = {}
ours_mutual_nn_cifar = {}
original_data = None

for label, run_name in zip(our_labels_cifar, our_runs_cifar):
    (stats, dssims, nns_dssim, mses, nns_mses, mutual_nns,
     cur_original_data, recovered_images) = get_run_statistics_from_name(run_name, return_full=True)
    # All runs must share the same original images: the cells below use one recovery
    # order and one "Dataset" row for every kernel.
    if original_data is None:
        original_data = cur_original_data
    else:
        assert np.allclose(original_data, cur_original_data), f"{label}: original data differs from the first run"
    our_stats_cifar[label] = stats
    our_dssims_cifar[label] = dssims
    our_nns_dssim_cifar[label] = nns_dssim
    our_recovered_images_cifar[label] = recovered_images
    our_mses_cifar[label] = mses
    our_nns_mse_cifar[label] = nns_mses
    ours_mutual_nn_cifar[label] = mutual_nns

# How many runs list each image index as a mutual nearest neighbour:
# mutual_nn_counts is the (indices, counts) pair returned by np.unique.
mutual_nn_counts = np.unique(
    np.concatenate([ids for ids in ours_mutual_nn_cifar.values()]),
    return_counts=True,
)
unique_ids, id_counts = mutual_nn_counts
# Indices that are a mutual nearest neighbour in at least two runs.
good_by_mutual_nn = unique_ids[id_counts >= 2]

# Order images by their best-match MSE averaged over all runs (lowest first).
# A DSSIM-based order would stack our_dssims_cifar's values instead.
all_dists = np.stack(list(our_mses_cifar.values()))
nns = our_nns_mse_cifar
recovery_order = np.argsort(all_dists.mean(axis=0))

# Grid of the first 20 images in recovery_order: one row of reconstructions per run,
# then the matching original images in the last row.
n_cols = 20
fig, axes = plt.subplots(5, n_cols, figsize=(40, 10))

for row_idx, label in enumerate(our_dssims_cifar):
    recon = our_recovered_images_cifar[label]
    for col in range(n_cols):
        ax = axes[row_idx, col]
        ax.imshow(recon[nns[label][recovery_order[col]]])
        ax.axis('off')

for col in range(n_cols):
    ax = axes[-1, col]
    ax.imshow(original_data[recovery_order[col]])
    ax.axis('off')

plt.subplots_adjust(wspace=0, hspace=0)
# Row labels in the left margin; y positions are hard-coded for this 5-row grid.
row_labels = [(0.8, 'Laplace'), (0.65, 'RBF'), (0.5, 'NTK'), (0.35, 'Cubic'), (0.19, 'Dataset')]
for y_pos, text in row_labels:
    fig.text(0.114, y_pos, text, ha='center', fontsize=20, rotation=90, va='center')
fig.savefig("figures/cifar top reconstructions.png")

In [None]:
# Each run's best-match DSSIM values, sorted ascending.
plt.figure(figsize=(12, 6))

for k, dssims in our_dssims_cifar.items():
    # x runs 1..len(dssims) rather than a hard-coded 1..500, so any run size works.
    plt.plot(np.arange(1, len(dssims) + 1), np.sort(dssims), label=k, linewidth=5)

plt.legend()
plt.grid(visible=True, which='major', linestyle='--', linewidth=0.5)
plt.xlabel("Image Index")
plt.ylabel("DSSIM")
plt.tight_layout()
plt.savefig('figures/recon_quality_plot.png')

In [None]:
# Empirical CDF of each run's best-match DSSIM, padded from 0 up to the largest DSSIM
# of any run so every curve spans the same x-range.
plt.figure(figsize=(12, 6))

max_dssim = np.max([np.max(dssims) for dssims in our_dssims_cifar.values()])

for k, dssims in our_dssims_cifar.items():
    n_images = len(dssims)
    sorted_dssims = np.concatenate([np.array([0]), np.sort(dssims), np.array([max_dssim])])
    # ECDF: 0 at the left pad, i/n at the i-th smallest value, 1 at the right pad.
    # (arange(n + 2) / (n + 2) never reached 1 and understated every point.)
    proportion = np.concatenate([[0.0], np.arange(1, n_images + 1) / n_images, [1.0]])
    plt.plot(sorted_dssims, proportion, label=k, linewidth=5)

plt.legend()
plt.grid(visible=True, which='major', linestyle='--', linewidth=0.5)
plt.xlabel("DSSIM")
plt.ylabel("Proportion")
plt.tight_layout()
plt.savefig('figures/recon_quality_edcf.png')

In [None]:
# Print the summary-statistics dict of every main CIFAR run, one blank line between runs.
for label, summary in our_stats_cifar.items():
    print(f"{label}: {summary}\n")