In [1]:
# imports
from pathlib import Path
import sys  

# Get my_package directory path from Notebook
parent_dir = str(Path().resolve().parents[0])

# Add to sys.path
sys.path.insert(0, parent_dir)

In [2]:
from utils.experiment_utils import get_all_experiments_info, load_best_model
import torch
import os
import hydra
from omegaconf import DictConfig, OmegaConf

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from notebooks.mnist_classifier.mnist_tiny_cnn import TinyCNN

from mixer.mixer import SetMixer
from datasets.mnist import MNISTDataset

from torch.utils.data import DataLoader
from itertools import product

from tqdm.notebook import tqdm

from datasets.distribution_datasets import GaussianMixtureModelDataset
from utils.gmm_utils import fit_gmm_batch

from ot.gmm import gmm_ot_loss

import warnings
warnings.filterwarnings("ignore")

In [3]:
device = 'cuda'
configs = get_all_experiments_info('/orcd/data/omarabu/001/njwfish/DistributionEmbeddings/outputs/', False)
cfg = [
    c for c in configs if 'gmm_sys' in c['name']
]

# load model and move to device
def load_model(cfg, path, device):
    enc = hydra.utils.instantiate(cfg['encoder'])
    gen = hydra.utils.instantiate(cfg['generator'])
    state = load_best_model(path)
    enc.load_state_dict(state['encoder_state_dict'])
    gen.model.load_state_dict(state['generator_state_dict'])
    enc.eval()
    gen.eval()
    enc.to(device)
    gen.to(device)
    return enc, gen

In [4]:
cfg

[{'name': 'gmm_systematic_exp_837abe69ce82174a9ce5decc8d39cdeb',
  'dir': '/orcd/data/omarabu/001/njwfish/DistributionEmbeddings/outputs/gmm_systematic_exp_837abe69ce82174a9ce5decc8d39cdeb',
  'config': {'dataset': {'_target_': 'datasets.distribution_datasets.LowRankMultivariateNormalDistributionDataset', 'n_sets': 50000, 'set_size': '${experiment.set_size}', 'data_shape': [100], 'seed': '${seed}', 'prior_mu': [0, 5], 'prior_cov_df': 10, 'prior_cov_scale': 1, 'rank': 2}, 'encoder': {'_target_': 'encoder.encoders.DistributionEncoderTx', 'in_dim': '${dataset.data_shape[0]}', 'latent_dim': '${experiment.latent_dim}', 'hidden_dim': '${experiment.hidden_dim}', 'set_size': '${experiment.set_size}', 'layers': 2, 'heads': 4}, 'model': {'_target_': 'layers.MLP', 'in_dims': [32, 32], 'hidden_dim': 128, 'out_dim': 100, 'layers': 4}, 'generator': {'_target_': 'generator.direct.DirectGenerator', 'model': '${model}', 'loss_type': 'swd', 'loss_params': {'n_projections': 100, 'p': 2}, 'noise_dim': '${

In [6]:
d = {
    "Encoder" : [],
    "Generator" : [],
    "N dims" : [],
    "OT reconstruction error" : []
}

N_sets = 40
set_size = 10**3


for c in tqdm(cfg):
    encoder_name = c['encoder']
    generator_name = c['generator']
    data_shape = c['config']['dataset']['data_shape']
    rank = c['config']['dataset']['rank']
    num_epochs = c['config']['training']['num_epochs']
    print(num_epochs, encoder_name, generator_name, data_shape)

    if num_epochs != 100:
        continue

    if 'KME' in encoder_name or 'Mean' in encoder_name:
        continue

    try:
        enc, gen = load_model(c['config'], c['dir'], device=device)
    except:
        print(encoder_name)
        continue

    if rank != 2 and data_shape[0] != 100:
        continue
    

    if 'Tx' not in encoder_name and 'Wormhole' not in encoder_name:
        set_size = 10**5
    else:
        set_size=10**3

    ds = hydra.utils.instantiate(c['config']['dataset'])
    projection_matrix = torch.tensor(ds.projection_matrix, dtype=torch.float).cuda()
    inv_projection_matrix = torch.linalg.pinv(projection_matrix)

    dataset = GaussianMixtureModelDataset(
        n_sets=N_sets,
        set_size=set_size,
        prior_mu=(0,5),
        data_shape=[rank]
    )
    
    ot_errors = []
    for idx in range(len(dataset)//8):

        with torch.no_grad():
            x = torch.tensor(dataset.data[idx*8:(idx+1)*8], dtype=torch.float).cuda()
            x = x @ projection_matrix
            z = enc(x)
            x_hat = gen.sample(z, num_samples=10**5)
            x_hat = x_hat @ inv_projection_matrix

        mus = dataset.mu[idx*8:(idx+1)*8]
        covs = dataset.cov[idx*8:(idx+1)*8]
        weights = dataset.weights[idx*8:(idx+1)*8]


        r_means, r_covs, r_weights = fit_gmm_batch(
            x_hat.detach().cpu().numpy(), 
            mus,
            covs,
            weights,
        )

        ot_dists = [
            gmm_ot_loss(r_m, m, r_c, c, r_w, w) 
            for r_m, m, r_c, c, r_w, w in zip(r_means, mus, r_covs, covs, r_weights, weights)
        ]
        ot_errors += ot_dists

    d['Encoder'].append(encoder_name)
    d['Generator'].append(generator_name)
    d['N dims'].append(data_shape[0])
    d['OT reconstruction error'].append(np.mean(ot_errors))
    print(f"Encoder: {encoder_name}, Generator: {generator_name}, OT error: {np.mean(ot_errors)}, data shape: {data_shape[0]}")

  0%|          | 0/10 [00:00<?, ?it/s]

100 DistributionEncoderTx DirectGenerator [100]
[[ 0.49671415 -0.1382643   0.64768854  1.52302986 -0.23415337 -0.23413696
   1.57921282  0.76743473 -0.46947439  0.54256004 -0.46341769 -0.46572975
   0.24196227 -1.91328024 -1.72491783 -0.56228753 -1.01283112  0.31424733
  -0.90802408 -1.4123037   1.46564877 -0.2257763   0.0675282  -1.42474819
  -0.54438272  0.11092259 -1.15099358  0.37569802 -0.60063869 -0.29169375
  -0.60170661  1.85227818 -0.01349722 -1.05771093  0.82254491 -1.22084365
   0.2088636  -1.95967012 -1.32818605  0.19686124  0.73846658  0.17136828
  -0.11564828 -0.3011037  -1.47852199 -0.71984421 -0.46063877  1.05712223
   0.34361829 -1.76304016  0.32408397 -0.38508228 -0.676922    0.61167629
   1.03099952  0.93128012 -0.83921752 -0.30921238  0.33126343  0.97554513
  -0.47917424 -0.18565898 -1.10633497 -1.19620662  0.81252582  1.35624003
  -0.07201012  1.0035329   0.36163603 -0.64511975  0.36139561  1.53803657
  -0.03582604  1.56464366 -2.6197451   0.8219025   0.08704707 -0

In [19]:
x @ projection_matrix.T

tensor([[[2379.0017, 2380.3040],
         [2531.9243, 2008.8501],
         [2790.7979, 1547.1091],
         ...,
         [2288.3740, 2117.5923],
         [1685.1665, 3173.9373],
         [1516.0747, 3662.7920]],

        [[2130.4600, 3933.5500],
         [2030.8882, 4149.9287],
         [4495.8633, 4730.3350],
         ...,
         [3970.3555, 4533.4814],
         [1636.8770, 4179.9028],
         [3832.7346, 4848.6421]],

        [[3415.9033, 1804.3011],
         [2366.6499, 1229.8231],
         [4798.2319, 1089.6832],
         ...,
         [2204.6362,  748.2170],
         [2252.5593,  719.6883],
         [3901.0920, 1741.2656]],

        ...,

        [[ 803.7308, 4815.3975],
         [3987.8801, 2819.7168],
         [3623.5322, 3197.5872],
         ...,
         [ 127.4956, 4332.1025],
         [ 834.9538, 4581.5815],
         [3337.7607, 3346.6328]],

        [[2358.6128, 1044.6616],
         [ 605.9817, 3454.8689],
         [1924.7236,  124.8062],
         ...,
         [1926.20

In [4]:
pd.DataFrame(d).sort_values(by='OT reconstruction error')

Unnamed: 0,Encoder,Generator,N dims,OT reconstruction error
1,DistributionEncoderTx,DirectGenerator,10,877.6316
0,DistributionEncoderTx,CVAE,10,1126.318
3,DistributionEncoderTx,DirectGenerator,10,6721.938
4,WormholeEncoder,WormholeGenerator,10,8771.564
5,DistributionEncoderGNN,CVAE,10,37441.57
2,DistributionEncoderTx,DDPM,10,77537.86
6,DistributionEncoderGNN,DirectGenerator,10,3222186.0
7,DistributionEncoderGNN,DirectGenerator,10,88017560.0
