In [None]:
from utils.experiment_utils import get_all_experiments_info, load_best_model
import torch
import os
import hydra
from omegaconf import DictConfig, OmegaConf

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from tqdm.notebook import tqdm

from datasets.distribution_datasets import LowRankMultivariateNormalDistributionDataset

import scipy as sp

: 

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing when the models are moved to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Root directory that is scanned for experiment runs.
OUTPUTS_DIR = '/orcd/data/omarabu/001/gokul/DistributionEmbeddings/outputs/'
configs = get_all_experiments_info(OUTPUTS_DIR, False)

# Keep only the runs whose name contains 'gmm_sys'.
cfg = [c for c in configs if 'gmm_sys' in c['name']]

def load_model(cfg, path, device):
    """Build the encoder/generator pair for a run and restore its best weights.

    Args:
        cfg: mapping whose 'encoder' and 'generator' entries are handed to
            `hydra.utils.instantiate`.
        path: location passed to `load_best_model` to fetch the checkpoint.
        device: target passed to `.to(...)` on both modules.

    Returns:
        Tuple `(encoder, generator)`, both switched to eval mode and moved to
        `device`. The generator weights are loaded into `generator.model`.
    """
    encoder = hydra.utils.instantiate(cfg['encoder'])
    generator = hydra.utils.instantiate(cfg['generator'])

    checkpoint = load_best_model(path)
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    generator.model.load_state_dict(checkpoint['generator_state_dict'])

    # Inference only: eval mode first, then move to the target device.
    for module in (encoder, generator):
        module.eval()
    for module in (encoder, generator):
        module.to(device)
    return encoder, generator

In [None]:

def wasserstein_zip(mu1, mu2, cov1, cov2):
    """Pairwise squared 2-Wasserstein distance between Gaussians.

    For each index i this evaluates the closed form

        W2^2 = ||mu1_i - mu2_i||^2
               + tr(C1_i + C2_i - 2 * (C1_i^{1/2} C2_i C1_i^{1/2})^{1/2})

    Args:
        mu1, mu2: arrays of shape (n, d) holding the means.
        cov1, cov2: sequences of n (d, d) covariance matrices.

    Returns:
        Float array of shape (n,) with one squared distance per pair.
    """
    # Import the submodule explicitly; `import scipy as sp` does not guarantee
    # that `sp.linalg` is loaded on every SciPy version.
    from scipy.linalg import sqrtm

    mu1 = np.asarray(mu1)
    mu2 = np.asarray(mu2)

    # The mean term must be the *squared* Euclidean distance: adding an
    # unsquared norm to the (squared) Bures term mixes units.
    mean_term = np.sum((mu1 - mu2) ** 2, axis=1)

    cov_term = np.zeros(len(mu1))
    for i, (c1, c2) in enumerate(zip(cov1, cov2)):
        root_c1 = sqrtm(c1)
        cross = sqrtm(root_c1 @ c2 @ root_c1)
        # sqrtm can return a complex matrix whose imaginary part is pure
        # round-off; keep the real part so it fits the float output array.
        bures = np.trace(c1 + c2 - 2 * cross).real
        # The exact value is non-negative; clamp tiny negative round-off.
        cov_term[i] = max(bures, 0.0)
    return mean_term + cov_term  # (n,)


def batch_cov(x):
    """Unbiased sample covariance for every set in a batch.

    Args:
        x: tensor laid out as (b, n, d) -- b sets of n samples in d dimensions.

    Returns:
        Tensor of shape (b, d, d): per-set covariance with the n - 1
        (Bessel-corrected) denominator.
    """
    n_samples = x.shape[1]
    # Remove each set's own mean before forming the scatter matrix.
    centered = x - x.mean(dim=1, keepdim=True)
    scatter = centered.transpose(1, 2) @ centered
    return scatter / (n_samples - 1)

In [None]:
n_sets = 100
batch_size = 8  # number of sets pushed through the encoder/generator at once

# One row per successfully loaded model; turned into a DataFrame afterwards.
d = {
    "Encoder" : [],
    "Generator" : [],
    "N dims" : [],
    "OT reconstruction error" : [],
    'Parameter MSE' : []
}

for c in tqdm(cfg):
    encoder_name = c['encoder']
    generator_name = c['generator']
    data_shape = c['config']['dataset']['data_shape']
    print(encoder_name, generator_name, data_shape)
    try:
        enc, gen = load_model(c['config'], c['dir'], device=device)
    except Exception as err:
        # Skip runs whose checkpoint cannot be loaded, but report why. A bare
        # `except:` would also swallow KeyboardInterrupt and hide the cause.
        print("broken: ", encoder_name, repr(err))
        continue

    # 'Tx' and 'Wormhole' encoders are evaluated on smaller sets.
    if 'Tx' not in encoder_name and 'Wormhole' not in encoder_name:
        set_size = 10**4
    else:
        set_size = 10**3

    # load dataset
    dataset = LowRankMultivariateNormalDistributionDataset(set_size=set_size,
                                                           n_sets=n_sets,
                                                           data_shape=data_shape
                                                           )

    # Map the ground-truth parameters through the dataset's projection matrix,
    # the same projection that is applied to the samples below.
    # NOTE(review): the original comments gave mu as (batch, d) and
    # projection_matrix as (rank, d); confirm the actual shapes.
    proj = dataset.projection_matrix
    mus = dataset.mu @ proj
    covs = proj.T @ dataset.cov @ proj

    rec_error = []  # per-set Wasserstein errors, one array per batch
    mses = []       # per-set parameter errors, one array per batch
    # Step over *all* sets; `len(dataset) // batch_size` full batches would
    # silently drop the trailing partial batch (e.g. 4 of 100 sets).
    for start in range(0, len(dataset), batch_size):
        batch = slice(start, start + batch_size)

        with torch.no_grad():
            # Use the configured device instead of a hard-coded .cuda().
            x = torch.tensor(dataset.data[batch] @ proj, dtype=torch.float).to(device)
            z = enc(x)
            x_hat = gen.sample(z, num_samples=set_size)

        # Moments of the regenerated samples.
        r_mus = x_hat.mean(axis=1).cpu().numpy()
        r_covs = batch_cov(x_hat).cpu().numpy()

        # Per-set errors: Euclidean norm for means, Frobenius norm for covariances.
        mu_err = np.linalg.norm(mus[batch] - r_mus, axis=1)
        cov_err = np.linalg.norm(covs[batch] - r_covs, axis=(1, 2))

        rec_error.append(wasserstein_zip(mus[batch], r_mus, covs[batch], r_covs))
        mses.append(mu_err + cov_err)

    d["Encoder"].append(encoder_name)
    d["Generator"].append(generator_name)
    d["N dims"].append(data_shape[0])
    # Average over individual sets so a shorter final batch is weighted correctly.
    d["OT reconstruction error"].append(np.mean(np.concatenate(rec_error)))
    d['Parameter MSE'].append(np.mean(np.concatenate(mses)))








In [None]:
# Render the collected per-model metrics as a table (one row per model).
pd.DataFrame(data=d)