In [1]:
from utils.experiment_utils import get_all_experiments_info, load_best_model
import torch
import os
import hydra
from omegaconf import DictConfig, OmegaConf

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from notebooks.mnist_classifier.mnist_tiny_cnn import TinyCNN

from mixer.mixer import SetMixer
from datasets.mnist import MNISTDataset

from torch.utils.data import DataLoader
from itertools import product

from tqdm.notebook import tqdm

from datasets.distribution_datasets import GaussianMixtureModelDataset
from utils.gmm_utils import fit_gmm_batch

from ot.gmm import gmm_ot_loss

import warnings
warnings.filterwarnings("ignore")



In [2]:
# Use the GPU when one is available; otherwise fall back to CPU so the
# notebook can still be executed on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Root directory holding the experiment outputs. The original cluster path is
# kept as the default; set DISTRIBUTION_EMBEDDINGS_OUTPUTS to point elsewhere.
outputs_dir = os.environ.get(
    'DISTRIBUTION_EMBEDDINGS_OUTPUTS',
    '/orcd/data/omarabu/001/njwfish/DistributionEmbeddings/outputs/',
)
configs = get_all_experiments_info(outputs_dir, False)

# Keep only the experiments whose name contains 'gmm_sys'.
cfg = [c for c in configs if 'gmm_sys' in c['name']]

# load model and move to device
def load_model(cfg, path, device):
    """Instantiate an encoder/generator pair and restore their saved weights.

    Args:
        cfg: Mapping with 'encoder' and 'generator' entries; each entry is
            handed to ``hydra.utils.instantiate``.
        path: Passed to ``load_best_model`` to obtain the checkpoint dict.
        device: Device passed to ``.to(...)`` on both modules.

    Returns:
        Tuple ``(encoder, generator)``, both switched to eval mode and moved
        to ``device``.
    """
    encoder = hydra.utils.instantiate(cfg['encoder'])
    generator = hydra.utils.instantiate(cfg['generator'])

    checkpoint = load_best_model(path)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    # The generator weights are loaded into the wrapped ``.model`` attribute,
    # not into the generator object itself.
    generator.model.load_state_dict(checkpoint['generator_state_dict'])

    for module in (encoder, generator):
        module.eval()
        module.to(device)

    return encoder, generator

In [3]:
# One row per evaluated (encoder, generator) pair; converted to a DataFrame
# in a later cell.
d = {
    "Encoder" : [],
    "Generator" : [],
    "N dims" : [],
    "OT reconstruction error" : []
}

N_sets = 40
set_size = 10**3          # default; overridden per encoder inside the loop
TARGET_DIM = 10           # only experiments with this data dimension are evaluated
BATCH_SIZE = 8            # number of sets encoded/generated per forward pass
NUM_GEN_SAMPLES = 10**5   # samples drawn from the generator per set
SEED = 0

# Seed the global RNGs so repeated runs are comparable.
# NOTE(review): assumes the dataset and generators draw from the numpy/torch
# global RNGs -- confirm against GaussianMixtureModelDataset and gen.sample.
np.random.seed(SEED)
torch.manual_seed(SEED)

for c in tqdm(cfg):
    encoder_name = c['encoder']
    generator_name = c['generator']
    data_shape = c['config']['dataset']['data_shape']

    if 'KME' in encoder_name or 'Mean' in encoder_name:
        continue

    # Filter on dimensionality *before* loading, so checkpoints that would be
    # discarded anyway are never read from disk or moved to the device.
    if data_shape[0] != TARGET_DIM:
        continue

    try:
        enc, gen = load_model(c['config'], c['dir'], device=device)
    except Exception as err:
        # Best-effort: skip experiments whose checkpoint cannot be loaded, but
        # report why. KeyboardInterrupt/SystemExit are no longer swallowed.
        print(f"{encoder_name}: failed to load model ({type(err).__name__}: {err})")
        continue

    # Tx and Wormhole encoders are evaluated on smaller sets than the others.
    if 'Tx' not in encoder_name and 'Wormhole' not in encoder_name:
        set_size = 10**5
    else:
        set_size = 10**3

    dataset = GaussianMixtureModelDataset(n_sets=N_sets,
                                          set_size=set_size,
                                          prior_mu=(0, 5),
                                          data_shape=data_shape)

    ot_errors = []
    # Step by BATCH_SIZE so a trailing partial batch is evaluated too instead
    # of being silently dropped by integer division.
    for start in range(0, len(dataset), BATCH_SIZE):
        stop = start + BATCH_SIZE

        with torch.no_grad():
            x = torch.tensor(dataset.data[start:stop], dtype=torch.float, device=device)
            z = enc(x)
            x_hat = gen.sample(z, num_samples=NUM_GEN_SAMPLES)

        # Ground-truth mixture parameters for the sets in this batch.
        mus = dataset.mu[start:stop]
        covs = dataset.cov[start:stop]
        weights = dataset.weights[start:stop]

        # Fit a GMM to the generated samples for each set in the batch.
        r_means, r_covs, r_weights = fit_gmm_batch(x_hat.detach().cpu().numpy(),
                                                   mus, covs, weights)

        # Per-set OT loss between the refit mixture and the ground truth.
        # Distinct names avoid shadowing the outer loop variable `c`.
        ot_errors += [
            gmm_ot_loss(r_mu, mu, r_cov, cov, r_w, w)
            for r_mu, mu, r_cov, cov, r_w, w
            in zip(r_means, mus, r_covs, covs, r_weights, weights)
        ]

    mean_ot_error = np.mean(ot_errors)

    d['Encoder'].append(encoder_name)
    d['Generator'].append(generator_name)
    d['N dims'].append(data_shape[0])
    d['OT reconstruction error'].append(mean_ot_error)
    print(f"Encoder: {encoder_name}, Generator: {generator_name}, OT error: {mean_ot_error}, data shape: {data_shape[0]}")

  0%|          | 0/56 [00:00<?, ?it/s]

Encoder: DistributionEncoderTx, Generator: CVAE, OT error: 1126.3175941346253, data shape: 10
Encoder: DistributionEncoderTx, Generator: DirectGenerator, OT error: 877.6315606951615, data shape: 10
Encoder: DistributionEncoderTx, Generator: DDPM, OT error: 77537.86160033333, data shape: 10
Encoder: DistributionEncoderTx, Generator: DirectGenerator, OT error: 6721.938157891901, data shape: 10
DistributionEncoderGNN
Encoder: WormholeEncoder, Generator: WormholeGenerator, OT error: 8771.564253388122, data shape: 10
Encoder: DistributionEncoderGNN, Generator: CVAE, OT error: 37441.57370713449, data shape: 10
Encoder: DistributionEncoderGNN, Generator: DirectGenerator, OT error: 3222186.047737793, data shape: 10
Encoder: DistributionEncoderGNN, Generator: DirectGenerator, OT error: 88017556.23550084, data shape: 10


In [4]:
# Tabulate the collected results, lowest OT reconstruction error first.
pd.DataFrame(data=d).sort_values(by='OT reconstruction error', ascending=True)

Unnamed: 0,Encoder,Generator,N dims,OT reconstruction error
1,DistributionEncoderTx,DirectGenerator,10,877.6316
0,DistributionEncoderTx,CVAE,10,1126.318
3,DistributionEncoderTx,DirectGenerator,10,6721.938
4,WormholeEncoder,WormholeGenerator,10,8771.564
5,DistributionEncoderGNN,CVAE,10,37441.57
2,DistributionEncoderTx,DDPM,10,77537.86
6,DistributionEncoderGNN,DirectGenerator,10,3222186.0
7,DistributionEncoderGNN,DirectGenerator,10,88017560.0
