In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# imports
from pathlib import Path
import sys  

# Get my_package directory path from Notebook
parent_dir = str(Path().resolve().parents[0])

# Add to sys.path
sys.path.insert(0, parent_dir)

In [3]:
from utils.experiment_utils import get_all_experiments_info, load_best_model
import torch
import os
import hydra
from omegaconf import DictConfig, OmegaConf

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from notebooks.mnist_classifier.mnist_tiny_cnn import TinyCNN

from mixer.mixer import SetMixer
from datasets.mnist import MNISTDataset

from torch.utils.data import DataLoader
from itertools import product

from tqdm.notebook import tqdm

from datasets.distribution_datasets import GaussianMixtureModelDistributionDataset
from utils.gmm_utils import fit_gmm_batch

from ot.gmm import gmm_ot_loss

import warnings
warnings.filterwarnings("ignore")



In [8]:
# Fall back to CPU so the notebook can still be executed on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Root directory holding the experiment outputs. Override it with the
# EXPERIMENT_OUTPUTS_DIR environment variable instead of editing this path.
OUTPUTS_DIR = os.environ.get(
    'EXPERIMENT_OUTPUTS_DIR',
    '/orcd/data/omarabu/001/gokul/DistributionEmbeddings/outputs/',
)
configs = get_all_experiments_info(OUTPUTS_DIR, False)

# Keep only the 'gmm_sys' runs on 5-D data whose training max_time is 600
# (presumably seconds — confirm against the training config).
cfg = [
    c for c in configs
    if 'gmm_sys' in c['name']
    and c['config']['dataset']['data_shape'] == [5]
    and c['config']['training']['max_time'] == 600
]

print(len(cfg))

def load_model(cfg, path, device):
    """Build an encoder/generator pair from a config and restore its best checkpoint.

    Args:
        cfg: experiment config; its ``'encoder'`` and ``'generator'`` entries are
            each passed to ``hydra.utils.instantiate``.
        path: run directory handed to ``load_best_model``.
        device: target device for both modules (e.g. ``'cuda'``).

    Returns:
        ``(encoder, generator)``, both put in eval mode and moved to ``device``.
    """
    encoder = hydra.utils.instantiate(cfg['encoder'])
    generator = hydra.utils.instantiate(cfg['generator'])

    checkpoint = load_best_model(path)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    # The generator checkpoint is restored into the wrapper's inner ``.model``,
    # not into the generator object itself.
    generator.model.load_state_dict(checkpoint['generator_state_dict'])

    # Switch both modules to inference mode first, then move both to the device.
    for module in (encoder, generator):
        module.eval()
    for module in (encoder, generator):
        module.to(device)
    return encoder, generator

3


In [9]:
# Seed numpy and torch so the synthetic GMM dataset and the generators' sampling are
# reproducible under Restart & Run All. Both are seeded because which RNG the dataset
# class uses is not visible here.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 8          # number of sets encoded/decoded per forward pass
N_GEN_SAMPLES = 10**4   # points drawn from the generator for every reconstructed set

d = {
    "Encoder" : [],
    "Generator" : [],
    "N dims" : [],
    "OT reconstruction error" : []
}

N_sets = 200
set_size = 10**4
dataset = GaussianMixtureModelDistributionDataset(
    n_sets=N_sets,
    set_size=set_size,
    prior_mu=(0,5),
    data_shape=[5]
)


def _ot_errors_for_model(enc, gen, data, dataset, device, batch_size=BATCH_SIZE, n_samples=N_GEN_SAMPLES):
    """Return one GMM-OT reconstruction error per set in ``dataset``.

    Steps:
    1. Encode each batch of sets, then sample ``n_samples`` points per set from the generator.
    2. Refit a GMM to those points with ``fit_gmm_batch``, which is also given the
       ground-truth means/covariances/weights.
    3. Score the refit GMM against the ground truth with ``gmm_ot_loss``.

    Args:
        enc, gen: encoder / generator pair as returned by ``load_model``.
        data: per-set input points, sliced as ``data[start:stop]`` along the first axis
            (assumed (n_sets, points_per_set, dim) — confirm against the dataset class).
        dataset: source of the ground-truth ``mu``, ``cov`` and ``weights``.
        device: device the input batch is moved to before encoding.
        batch_size: number of sets processed per forward pass.
        n_samples: points drawn from the generator for each set.

    Returns:
        list of per-set OT errors, in dataset order.
    """
    ot_errors = []
    # Stepping by batch_size also covers a trailing partial batch when len(dataset)
    # is not a multiple of batch_size.
    for start in range(0, len(dataset), batch_size):
        stop = start + batch_size
        with torch.no_grad():
            x = torch.tensor(data[start:stop], dtype=torch.float).to(device)
            z = enc(x)
            x_hat = gen.sample(z, num_samples=n_samples)

        mus = dataset.mu[start:stop]
        covs = dataset.cov[start:stop]
        weights = dataset.weights[start:stop]

        r_means, r_covs, r_weights = fit_gmm_batch(
            x_hat.detach().cpu().numpy(),
            mus,
            covs,
            weights,
        )

        # Argument order follows gmm_ot_loss(means_s, means_t, covs_s, covs_t, w_s, w_t):
        # refit GMM as source, ground truth as target.
        ot_errors.extend(
            gmm_ot_loss(r_m, m, r_c, cov, r_w, w)
            for r_m, m, r_c, cov, r_w, w in zip(r_means, mus, r_covs, covs, r_weights, weights)
        )
    return ot_errors


for c in tqdm(cfg):
    encoder_name = c['encoder']
    generator_name = c['generator']
    if generator_name == 'DirectGenerator':
        # Append the loss type so DirectGenerator variants are distinguishable in the table.
        generator_name += '-' + c['config']['generator']['loss_type']
    data_shape = c['config']['dataset']['data_shape']

    try:
        enc, gen = load_model(c['config'], c['dir'], device=device)
    except Exception as err:
        # Skip runs that cannot be rebuilt/restored, but report why.
        # Catching Exception rather than using a bare except lets KeyboardInterrupt through.
        print('broken: ', encoder_name, generator_name, f'({type(err).__name__}: {err})')
        continue

    # Tx / Wormhole encoders are fed 10^3 points per set instead of 10^4
    # (presumably to bound memory/compute for these encoders — confirm).
    if 'Tx' in encoder_name or 'Wormhole' in encoder_name:
        eval_set_size = 10**3
    else:
        eval_set_size = 10**4
    data = dataset.data[:, :eval_set_size, :]

    ot_errors = _ot_errors_for_model(enc, gen, data, dataset, device)
    mean_ot_error = np.mean(ot_errors)

    d['Encoder'].append(encoder_name)
    d['Generator'].append(generator_name)
    d['N dims'].append(data_shape[0])
    d['OT reconstruction error'].append(mean_ot_error)
    print(f"Encoder: {encoder_name}, Generator: {generator_name}, OT error: {mean_ot_error}, data shape: {data_shape[0]}")

  0%|          | 0/3 [00:00<?, ?it/s]

Encoder: WormholeEncoder, Generator: WormholeGenerator, OT error: 2.888592809758486, data shape: 5
Encoder: DistributionEncoderResNet, Generator: DDPM, OT error: 1.821606542061434, data shape: 5
Encoder: KMEEncoder, Generator: DDPM, OT error: 2.179117803930016, data shape: 5


In [10]:
# Rank the encoder/generator pairs from lowest (best) to highest OT reconstruction error.
pd.DataFrame.from_dict(d).sort_values('OT reconstruction error')

Unnamed: 0,Encoder,Generator,N dims,OT reconstruction error
1,DistributionEncoderResNet,DDPM,5,1.821607
2,KMEEncoder,DDPM,5,2.179118
0,WormholeEncoder,WormholeGenerator,5,2.888593
