In [None]:
# Work from the parent directory (presumably the repository root) so the
# project imports and the relative `data/` and `figs/` paths used below resolve.
# A bare `%cd ../` moves up another level every time this cell is re-run, so
# resolve the target once and chdir to that absolute path on every run instead.
import os

if "REPO_ROOT" not in globals():
    REPO_ROOT = os.path.abspath("..")
os.chdir(REPO_ROOT)

In [None]:
from layout_eval.measures.mmd import estimate_mmd, convert_emd_to_affinity
from experiments.response_analysis import _load_xx_xy
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
def mmd_on_subset(xx, yy, xy, sample_size, rng=None):
    """Estimate MMD using a random subset of the samples indexed by ``xx``.

    ``sample_size`` indices are drawn without replacement from
    ``range(len(xx))``. The matching rows *and* columns of ``xx`` and the
    matching columns of ``xy`` are kept; ``yy`` is passed through unchanged.

    Args:
        xx: Square matrix over the samples being subsampled (indexed on both axes).
        yy: Matrix over the other sample set; used in full.
        xy: Cross matrix whose *columns* are indexed by the same samples as
            ``xx`` (presumably shape ``(len(yy), len(xx))`` -- confirm against
            ``estimate_mmd``).
        sample_size: Number of samples to keep; must satisfy
            ``1 <= sample_size <= len(xx)``.
        rng: Optional ``np.random.Generator`` (or ``RandomState``) used for the
            draw, for reproducible subsets. ``None`` uses the global
            ``np.random`` state.

    Returns:
        Whatever ``estimate_mmd`` returns for the subset.

    Raises:
        ValueError: If ``sample_size`` is outside ``[1, len(xx)]``.
    """
    n_samples = len(xx)
    if not 1 <= sample_size <= n_samples:
        raise ValueError(
            f"sample_size must be in [1, {n_samples}], got {sample_size}"
        )

    choice = np.random.choice if rng is None else rng.choice
    idx = choice(n_samples, size=sample_size, replace=False)

    # np.ix_ selects the idx x idx sub-block in a single indexing step.
    xx_sub = xx[np.ix_(idx, idx)]
    xy_sub = xy[:, idx]
    return estimate_mmd(xx_sub, yy, xy_sub)

def run_downsample_exp(
    base_dir,
    yy_file,
    noise_rates=(0.1, 0.2, 0.3, 0.4, 0.5),
    sample_sizes=(100, 200, 500, 1000, 1500),
    n_repeats=10,
):
    """Measure how MMD estimates vary with the number of samples used.

    For every noise rate, each ``.csv`` file in
    ``<base_dir>/elem_noise_rate_<noise_rate>`` is loaded with ``_load_xx_xy``
    and scored:

    * once on all of its samples (recorded with ``sample_size = len(xx)``), and
    * ``n_repeats`` times for each size in ``sample_sizes``, each time on a
      fresh random subset drawn by ``mmd_on_subset``.

    The matrix loaded from ``yy_file`` and the ``xx``/``xy`` matrices from each
    CSV are passed through ``convert_emd_to_affinity`` with a shared bandwidth
    ``sigma``: the median of the upper triangle (diagonal included) of the raw
    ``yy`` matrix.

    Args:
        base_dir: Directory holding one ``elem_noise_rate_<rate>``
            subdirectory per noise rate.
        yy_file: ``.npy`` file with the reference matrix (presumably pairwise
            EMD values, judging by the filename used in this notebook).
        noise_rates: Noise rates to evaluate; each must match a subdirectory.
        sample_sizes: Subset sizes to evaluate. A size larger than a file's
            sample count is skipped for that file (with a warning) instead of
            raising.
        n_repeats: Number of random subsets drawn per (file, sample size).

    Returns:
        dict with equal-length lists under ``"sample_size"``, ``"mmd"`` and
        ``"noise_rate"`` -- one entry per MMD estimate.
    """
    import warnings

    yy = np.load(yy_file)
    # NOTE(review): the median heuristic usually excludes the zero diagonal
    # (np.triu_indices(n, k=1)); kept unchanged so results match earlier runs.
    sigma = np.median(yy[np.triu_indices(len(yy))])
    yy = convert_emd_to_affinity(yy, sigma)

    results = {"sample_size": [], "mmd": [], "noise_rate": []}

    def _record(sample_size, mmd, noise_rate):
        # Append one row to the column-oriented results dict.
        results["sample_size"].append(sample_size)
        results["mmd"].append(mmd)
        results["noise_rate"].append(noise_rate)

    for noise_rate in noise_rates:
        file_dir = os.path.join(base_dir, f"elem_noise_rate_{noise_rate}")
        # os.listdir order is arbitrary; sort so that, under a fixed seed, the
        # random draws land on the same files every run.
        csv_files = sorted(f for f in os.listdir(file_dir) if f.endswith(".csv"))
        for file in csv_files:
            xx, xy = _load_xx_xy(os.path.join(file_dir, file))
            xx = convert_emd_to_affinity(xx, sigma)
            xy = convert_emd_to_affinity(xy, sigma)
            n_samples = len(xx)

            # Baseline: MMD on every available sample.
            _record(n_samples, estimate_mmd(xx, yy, xy), noise_rate)

            for sample_size in sample_sizes:
                if sample_size > n_samples:
                    # Sampling without replacement cannot exceed the population.
                    warnings.warn(
                        f"Skipping sample_size={sample_size} for {file}: "
                        f"only {n_samples} samples available."
                    )
                    continue
                for _ in range(n_repeats):
                    mmd = mmd_on_subset(xx, yy, xy, sample_size)
                    _record(sample_size, mmd, noise_rate)
    return results

In [None]:
# The subsampling in `mmd_on_subset` draws from NumPy's global RNG; seed it so
# Restart & Run All reproduces the same numbers and figure.
SEED = 0
np.random.seed(SEED)

RESPONSE_DIR = "data/dataflow/outputs/response_analysis/small_spatial_and_label_noise"
YY_FILE = "data/mmd/rico_val_emd_yy.npy"

# One row per MMD estimate, with columns sample_size, mmd and noise_rate.
results = pd.DataFrame(run_downsample_exp(RESPONSE_DIR, YY_FILE))

In [None]:
# MMD vs. sample size, one line per noise rate. seaborn aggregates the repeated
# estimates at each sample size (mean with an error band by default).
fig, ax = plt.subplots(figsize=(4, 4))
sns.lineplot(x="sample_size", y="mmd", hue="noise_rate", data=results, palette="crest", ax=ax)
ax.set_xlabel("Sample Size")
ax.set_ylabel("MMD")
legend = ax.get_legend()
if legend is not None:
    legend.set_title("Noise Rate")
sns.despine(ax=ax)

# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)
fig.savefig("figs/sample_size_analysis.pdf", bbox_inches="tight")