In [4]:
import torch
from labproject.data import get_dataset, DATASETS
from labproject.metrics import METRICS
import numpy as np

from labproject.metrics.utils import get_metric

import matplotlib.pyplot as plt

torch.manual_seed(0)

<torch._C.Generator at 0x7fcbe0c08ab0>

In [5]:
DATASETS

{'random': <function labproject.data.random_dataset(n=1000, d=10)>,
 'multivariate_normal': <function labproject.data.multivariate_normal(n=3000, dims=100, means=None, vars=None, distort=None)>,
 'toy_2d': <function labproject.data.toy_2d(n=1000, d=2)>,
 'cifar10_train': <function labproject.data.cifar10_train(n=1000, d=2048, save_path='data', device='cpu', return_labels=False)>,
 'cifar10_test': <function labproject.data.cifar10_test(n=1000, d=2048, save_path='data', device='cpu', return_labels=False)>,
 'imagenet_real_embeddings': <function labproject.data.imagenet_real_embeddings(n=1000, d=2048)>,
 'imagenet_uncond_embeddings': <function labproject.data.imagenet_uncond_embeddings(n=1000, d=2048)>,
 'imagenet_unconditional_model_embedding': <function labproject.data.imagenet_unconditional_model_embedding(n, d=2048, device='cpu', save_path='data', permute=False)>,
 'imagenet_test_embedding': <function labproject.data.imagenet_test_embedding(n, d=2048, device='cpu', save_path='data')>,

In [6]:
METRICS

{'mmd_rbf': <function labproject.metrics.MMD_torch.compute_rbf_mmd(x, y, bandwidth=1.0)>,
 'mmd_rbf_median_heuristic': <function labproject.metrics.MMD_torch.compute_rbf_mmd_median_heuristic(x, y)>,
 'mmd_rbf_auto': <function labproject.metrics.MMD_torch.compute_rbf_mmd_auto(x, y, bandwidth=1.0)>,
 'mmd_polynomial': <function labproject.metrics.MMD_torch.compute_polynomial_mmd(x, y, degree=2, bias=0)>,
 'mmd_linear_naive': <function labproject.metrics.MMD_torch.compute_linear_mmd_naive(x, y)>,
 'mmd_linear': <function labproject.metrics.MMD_torch.compute_linear_mmd(x, y)>,
 'mmd_energy': <function labproject.metrics.MMD_torch.compute_energy_mmd(x, y)>,
 'c2st_nn': <function labproject.metrics.c2st.c2st_nn(X: torch.Tensor, Y: torch.Tensor, seed: int = 1, n_folds: int = 5, metric: str = 'accuracy', z_score: bool = True, activation: Literal['identity', 'logistic', 'tanh', 'relu'] = 'relu', clf_kwargs: dict[str, typing.Any] = {}) -> torch.Tensor>,
 'c2st_rf': <function labproject.metrics.c

In [7]:
metric_fn = get_metric("wasserstein_gauss_squared")
metric_fn2 = get_metric("sliced_wasserstein")

In [8]:
datasets = ["imagenet_unconditional_model_embedding","imagenet_cs1_embedding", "imagenet_cs10_embedding", "imagenet_biggan_embedding", "imagenet_sdv4_embedding", "imagenet_sdv5_embedding", "imagenet_vqdm_embedding", "imagenet_wukong_embedding", "imagenet_adm_embedding", "imagenet_midjourney_embedding"]
metrics = ["wasserstein_gauss_squared", "sliced_wasserstein"]

In [9]:
testset_fn = get_dataset("imagenet_test_embedding")

In [10]:
testset = testset_fn(100_000, 2048)
idx = torch.randperm(len(testset))
testset = testset[idx]

In [23]:
results_fid = {}

In [24]:
torch.manual_seed(0)
metric = metrics[0]
metric_fn = get_metric(metric)
for dname in datasets:
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn)
        metric_values.append(m)
    results_fid[dname] = np.array(metric_values)

In [26]:
results_fid

{'imagenet_unconditional_model_embedding': array([6.2195807, 6.1363096, 6.1779237, 6.1323247, 6.1145873],
       dtype=float32),
 'imagenet_cs1_embedding': array([6.430565 , 6.3383527, 6.3908176, 6.392403 , 6.3483315],
       dtype=float32),
 'imagenet_cs10_embedding': array([6.9664445, 6.934556 , 6.985526 , 6.980446 , 6.945889 ],
       dtype=float32),
 'imagenet_biggan_embedding': array([12.701342, 12.732329, 12.704566, 12.742165, 12.572661],
       dtype=float32),
 'imagenet_sdv4_embedding': array([17.149105, 17.233383, 17.262657, 17.13218 , 17.054472],
       dtype=float32),
 'imagenet_sdv5_embedding': array([17.27084 , 17.254538, 17.438782, 17.217953, 17.170486],
       dtype=float32),
 'imagenet_vqdm_embedding': array([11.207735, 11.090878, 11.292017, 11.206196, 11.167131],
       dtype=float32),
 'imagenet_wukong_embedding': array([18.644157, 18.674633, 18.625652, 18.643795, 18.433655],
       dtype=float32),
 'imagenet_adm_embedding': array([12.630351 , 12.518478 , 12.667346 , 

In [25]:
np.save("results_fid.npy", results_fid)

In [31]:
results_sw = {}

In [35]:
torch.manual_seed(0)
metric = "sliced_wasserstein"
metric_fn = get_metric(metric)
for dname in datasets[-1:]:
    print("Starting ", dname)
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn, num_projections=5000)
        metric_values.append(m)
    results_sw[dname] = np.array(metric_values)

Starting  imagenet_midjourney_embedding


In [37]:
results_sw

{'imagenet_unconditional_model_embedding': array([0.02358142, 0.02302841, 0.02342947, 0.02363985, 0.02274173],
       dtype=float32),
 'imagenet_cs1_embedding': array([0.02153669, 0.02145257, 0.02175897, 0.02185666, 0.02145214],
       dtype=float32),
 'imagenet_cs10_embedding': array([0.02553654, 0.02503867, 0.02539671, 0.0255399 , 0.02508791],
       dtype=float32),
 'imagenet_biggan_embedding': array([0.05029042, 0.05024597, 0.05106073, 0.05075597, 0.05070167],
       dtype=float32),
 'imagenet_sdv4_embedding': array([0.05571224, 0.05621962, 0.05671528, 0.056228  , 0.05603858],
       dtype=float32),
 'imagenet_sdv5_embedding': array([0.05642236, 0.05595   , 0.0562194 , 0.05640132, 0.05647483],
       dtype=float32),
 'imagenet_vqdm_embedding': array([0.04172998, 0.04103179, 0.04130031, 0.04112783, 0.04078227],
       dtype=float32),
 'imagenet_wukong_embedding': array([0.05560957, 0.05517614, 0.05603101, 0.05579789, 0.05510735],
       dtype=float32),
 'imagenet_adm_embedding': arr

In [36]:
np.save("results_sw.npy", results_sw)

In [38]:
results_mmd_rbf64 = {}

In [43]:
torch.manual_seed(0)
metric = "mmd_rbf"
metric_fn = get_metric(metric)
for dname in datasets:
    print("Starting ", dname)
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn, bandwidth=64.0)
        metric_values.append(m)
    results_mmd_rbf64[dname] = np.array(metric_values)

Starting  imagenet_unconditional_model_embedding
Starting  imagenet_cs1_embedding
Starting  imagenet_cs10_embedding
Starting  imagenet_biggan_embedding
Starting  imagenet_sdv4_embedding
Starting  imagenet_sdv5_embedding
Starting  imagenet_vqdm_embedding
Starting  imagenet_wukong_embedding
Starting  imagenet_adm_embedding
Starting  imagenet_midjourney_embedding


In [45]:
results_mmd_rbf64

{'imagenet_unconditional_model_embedding': array([7.1763992e-05, 6.2704086e-05, 7.2360039e-05, 6.4253807e-05,
        6.0319901e-05], dtype=float32),
 'imagenet_cs1_embedding': array([7.1883202e-05, 5.7458878e-05, 5.9604645e-05, 6.4969063e-05,
        5.7697296e-05], dtype=float32),
 'imagenet_cs10_embedding': array([9.0479851e-05, 8.4877014e-05, 8.2612038e-05, 8.7141991e-05,
        8.1181526e-05], dtype=float32),
 'imagenet_biggan_embedding': array([0.00019705, 0.00018573, 0.00018668, 0.00018907, 0.00017405],
       dtype=float32),
 'imagenet_sdv4_embedding': array([0.00021684, 0.00020289, 0.00020945, 0.00021148, 0.00020015],
       dtype=float32),
 'imagenet_sdv5_embedding': array([0.00021875, 0.00020254, 0.00021124, 0.00021267, 0.00020623],
       dtype=float32),
 'imagenet_vqdm_embedding': array([0.00015748, 0.0001334 , 0.0001483 , 0.00015378, 0.00014079],
       dtype=float32),
 'imagenet_wukong_embedding': array([0.0002085 , 0.00019038, 0.00019991, 0.00020087, 0.00018454],
     

In [44]:
np.save("results_mmd_rbf64.npy", results_mmd_rbf64)

In [12]:
results_mmd_lin = {}

In [15]:
torch.manual_seed(0)
metric = "mmd_linear"
metric_fn = get_metric(metric)
for dname in datasets:
    print("Starting ", dname)
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn)
        metric_values.append(m)
    results_mmd_lin[dname] = np.array(metric_values)

Starting  imagenet_unconditional_model_embedding
Starting  imagenet_cs1_embedding
Starting  imagenet_cs10_embedding
Starting  imagenet_biggan_embedding
Starting  imagenet_sdv4_embedding
Starting  imagenet_sdv5_embedding
Starting  imagenet_vqdm_embedding
Starting  imagenet_wukong_embedding
Starting  imagenet_adm_embedding
Starting  imagenet_midjourney_embedding


In [16]:
np.save("results_mmd_lin.npy", results_mmd_lin)

In [17]:
results_mmd_lin

{'imagenet_unconditional_model_embedding': array([0.27161348, 0.23375021, 0.27461752, 0.23648134, 0.22176446],
       dtype=float32),
 'imagenet_cs1_embedding': array([0.2784213 , 0.2160877 , 0.2256283 , 0.24824372, 0.21793067],
       dtype=float32),
 'imagenet_cs10_embedding': array([0.3461548 , 0.32180998, 0.31091288, 0.33070827, 0.30549592],
       dtype=float32),
 'imagenet_biggan_embedding': array([0.6735612, 0.6250146, 0.6302967, 0.6348198, 0.574922 ],
       dtype=float32),
 'imagenet_sdv4_embedding': array([0.7234071 , 0.6639475 , 0.6890489 , 0.69806087, 0.6532722 ],
       dtype=float32),
 'imagenet_sdv5_embedding': array([0.73029166, 0.66163504, 0.6962174 , 0.70317215, 0.6768427 ],
       dtype=float32),
 'imagenet_vqdm_embedding': array([0.55959   , 0.45710978, 0.51890296, 0.5427842 , 0.4893559 ],
       dtype=float32),
 'imagenet_wukong_embedding': array([0.6912052 , 0.61376953, 0.6529302 , 0.6583701 , 0.5910167 ],
       dtype=float32),
 'imagenet_adm_embedding': array([0

In [22]:
results_mmd_poly_kid = {}

In [23]:
torch.manual_seed(0)
metric = "mmd_polynomial"
metric_fn = get_metric(metric)
for dname in datasets:
    print("Starting ", dname)
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn, degree=3, bias=1.)
        metric_values.append(m)
    results_mmd_poly_kid[dname] = np.array(metric_values)

Starting  imagenet_unconditional_model_embedding
Starting  imagenet_cs1_embedding
Starting  imagenet_cs10_embedding
Starting  imagenet_biggan_embedding
Starting  imagenet_sdv4_embedding
Starting  imagenet_sdv5_embedding
Starting  imagenet_vqdm_embedding
Starting  imagenet_wukong_embedding
Starting  imagenet_adm_embedding
Starting  imagenet_midjourney_embedding


In [24]:
results_mmd_poly_kid

{'imagenet_unconditional_model_embedding': array([13181., 10870., 12008., 10130.,  9628.], dtype=float32),
 'imagenet_cs1_embedding': array([21638., 15387., 14466., 17863., 15719.], dtype=float32),
 'imagenet_cs10_embedding': array([17963., 15701., 13769., 15598., 14699.], dtype=float32),
 'imagenet_biggan_embedding': array([33609., 29164., 29311., 30098., 26772.], dtype=float32),
 'imagenet_sdv4_embedding': array([39859., 34397., 34759., 37063., 34212.], dtype=float32),
 'imagenet_sdv5_embedding': array([39529., 34206., 34731., 36771., 35564.], dtype=float32),
 'imagenet_vqdm_embedding': array([25483., 19751., 22153., 24993., 22030.], dtype=float32),
 'imagenet_wukong_embedding': array([42367., 35982., 37357., 38585., 34325.], dtype=float32),
 'imagenet_adm_embedding': array([37571., 34480., 34681., 33519., 29859.], dtype=float32),
 'imagenet_midjourney_embedding': array([40835., 32620., 35570., 36513., 35667.], dtype=float32)}

In [26]:
np.save("results_mmd_poly_kid.npy", results_mmd_poly_kid)

In [37]:
results_c2st_knn = {}

In [38]:
torch.manual_seed(0)
metric = "c2st_knn"
metric_fn = get_metric(metric)
for dname in datasets:
    print("Starting ", dname)
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn, n_folds=2)
        metric_values.append(m)
    results_c2st_knn[dname] = np.array(metric_values)

Starting  imagenet_unconditional_model_embedding
Starting  imagenet_cs1_embedding
Starting  imagenet_cs10_embedding
Starting  imagenet_biggan_embedding
Starting  imagenet_sdv4_embedding
Starting  imagenet_sdv5_embedding
Starting  imagenet_vqdm_embedding
Starting  imagenet_wukong_embedding
Starting  imagenet_adm_embedding
Starting  imagenet_midjourney_embedding


In [41]:
results_c2st_knn

{'imagenet_unconditional_model_embedding': array([[0.65345 ],
        [0.64765 ],
        [0.650725],
        [0.652725],
        [0.652275]], dtype=float32),
 'imagenet_cs1_embedding': array([[0.63035 ],
        [0.62975 ],
        [0.63135 ],
        [0.629375],
        [0.63405 ]], dtype=float32),
 'imagenet_cs10_embedding': array([[0.6517 ],
        [0.6506 ],
        [0.6573 ],
        [0.654  ],
        [0.65605]], dtype=float32),
 'imagenet_biggan_embedding': array([[0.75415 ],
        [0.749725],
        [0.7522  ],
        [0.750975],
        [0.754475]], dtype=float32),
 'imagenet_sdv4_embedding': array([[0.7913  ],
        [0.78645 ],
        [0.78665 ],
        [0.78685 ],
        [0.790925]], dtype=float32),
 'imagenet_sdv5_embedding': array([[0.7921  ],
        [0.786525],
        [0.7872  ],
        [0.790975],
        [0.792425]], dtype=float32),
 'imagenet_vqdm_embedding': array([[0.774775],
        [0.76665 ],
        [0.76835 ],
        [0.773775],
        [0.77215 ]

In [40]:
np.save("results_c2st_knn.npy", results_c2st_knn)

In [45]:
results_c2st_nn = {}

In [46]:
torch.manual_seed(0)
metric = "c2st_nn"
metric_fn = get_metric(metric)
for dname in datasets:
    print("Starting ", dname)
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn, n_folds=2)
        metric_values.append(m)
    results_c2st_nn[dname] = np.array(metric_values)

Starting  imagenet_unconditional_model_embedding
Starting  imagenet_cs1_embedding
Starting  imagenet_cs10_embedding
Starting  imagenet_biggan_embedding
Starting  imagenet_sdv4_embedding
Starting  imagenet_sdv5_embedding
Starting  imagenet_vqdm_embedding
Starting  imagenet_wukong_embedding
Starting  imagenet_adm_embedding
Starting  imagenet_midjourney_embedding


In [47]:
np.save("results_c2st_nn.npy", results_c2st_nn)

In [48]:
results_c2st_nn

{'imagenet_unconditional_model_embedding': array([[0.7191  ],
        [0.71785 ],
        [0.72095 ],
        [0.716875],
        [0.7206  ]], dtype=float32),
 'imagenet_cs1_embedding': array([[0.771325],
        [0.761475],
        [0.771525],
        [0.767575],
        [0.7738  ]], dtype=float32),
 'imagenet_cs10_embedding': array([[0.7608  ],
        [0.7621  ],
        [0.764   ],
        [0.7643  ],
        [0.757775]], dtype=float32),
 'imagenet_biggan_embedding': array([[0.8693  ],
        [0.864775],
        [0.862925],
        [0.861475],
        [0.854925]], dtype=float32),
 'imagenet_sdv4_embedding': array([[0.92205 ],
        [0.920025],
        [0.932025],
        [0.92415 ],
        [0.92075 ]], dtype=float32),
 'imagenet_sdv5_embedding': array([[0.920525],
        [0.930475],
        [0.923075],
        [0.920125],
        [0.92495 ]], dtype=float32),
 'imagenet_vqdm_embedding': array([[0.8448 ],
        [0.8455 ],
        [0.84815],
        [0.8496 ],
        [0.84925]

In [49]:
results_c2st_rf = {}

In [50]:
torch.manual_seed(0)
metric = "c2st_rf"
metric_fn = get_metric(metric)
for dname in datasets:
    print("Starting ", dname)
    metric_values = []
    for j in range(5):
        data_test = testset[j*20_000:(j+1)*20_000]
        if dname == "imagenet_midjourney_embedding":
            data_syn = get_dataset(dname)(10_000, 2048, permute=False)
            data_syn = data_syn[torch.randint(0, 10_000, (20_000,))]
        else:
            data_syn = get_dataset(dname)(20_000, 2048, permute=True)
        m = metric_fn(data_test, data_syn, n_folds=2)
        metric_values.append(m)
    results_c2st_rf[dname] = np.array(metric_values)

Starting  imagenet_unconditional_model_embedding
Starting  imagenet_cs1_embedding


KeyboardInterrupt: 

In [None]:
np.save("results_c2st_rf.npy", results_c2st_rf)

In [None]:
results_c2st_rf

In [44]:
metric_fn = get_metric("c2st_nn")
m = metric_fn(data_test, data_syn, n_folds=2)
m

tensor([0.9410])