In [1]:
import os, sys
sys.path.append("..")

import torch
from torch.utils.data import DataLoader

from benchmark.rotated_gaussian_benchmark import RotatedGaussiansBenchmark
from benchmark.metrics import compute_BW_UVP_with_gt_stats, compute_BW_UVP_by_gt_samples
from src.distributions import LoaderSampler

### Benchmark loading

In [2]:
# --- Experiment configuration -------------------------------------------
device = "cpu"
dim = 128
eps = 1
batch_size = 100000

# Build the benchmark object; download=True is passed so the data under
# benchmark_data_path can be fetched if needed.
benchmark = RotatedGaussiansBenchmark(
    dim=dim,
    eps=eps,
    benchmark_data_path="../benchmark/benchmark_data",
    download=True,
)

X_dataset, Y_dataset = benchmark.X_dataset, benchmark.Y_dataset

# Ground-truth statistics used by the BW-UVP metric calls below.
mu_X = benchmark.mu_X
mu_Y = benchmark.mu_Y
covariance_X = benchmark.covariance_X
covariance_Y = benchmark.covariance_Y

optimal_plan_mu = benchmark.optimal_plan_mu
optimal_plan_covariance = benchmark.optimal_plan_covariance

# Both samplers wrap a DataLoader configured with the same options.
loader_options = dict(shuffle=False, num_workers=8, batch_size=batch_size)
X_sampler = LoaderSampler(DataLoader(X_dataset, **loader_options), device)
Y_sampler = LoaderSampler(DataLoader(Y_dataset, **loader_options), device)

Downloading...
From: https://drive.google.com/uc?id=1ZOUXFdkssPbGJb1jPhVK1dh8lkwu0Sx0
To: /home/n.gushchin/EntropicOTBenchmark/benchmark/benchmark_data/rotated_gaussians.zip
100%|██████████| 172k/172k [00:00<00:00, 352kB/s]


### Examples of computing the BW-UVP metric

BW-UVP between samples from the source distribution and the source distribution itself, using ground-truth statistics

In [3]:
# Draw one batch from the source sampler as a NumPy array, then score it
# against the source mean / covariance taken from the benchmark.
source_batch = X_sampler.sample(batch_size).detach().cpu().numpy()
compute_BW_UVP_with_gt_stats(
    source_batch,
    true_samples_mu=mu_X,
    true_samples_covariance=covariance_X,
)

0.02465806845301834

BW-UVP between samples from the source distribution and the target distribution, using ground-truth statistics

In [4]:
# Draw one batch from the source sampler as a NumPy array, then score it
# against the target mean / covariance taken from the benchmark.
source_batch = X_sampler.sample(batch_size).detach().cpu().numpy()
compute_BW_UVP_with_gt_stats(
    source_batch,
    true_samples_mu=mu_Y,
    true_samples_covariance=covariance_Y,
)

26.15500945715382

BW-UVP between samples from the source distribution and the source distribution itself, using ground-truth samples

In [5]:
# Two successive draws from the source sampler: the first is passed as the
# evaluated samples, the second as the ground-truth samples.
first_source_batch = X_sampler.sample(batch_size).detach().cpu().numpy()
second_source_batch = X_sampler.sample(batch_size).detach().cpu().numpy()
compute_BW_UVP_by_gt_samples(first_source_batch, second_source_batch)

0.05025316264511702

BW-UVP between samples from the source distribution and the target distribution, using ground-truth samples

In [6]:
# One draw from the source sampler is passed as the evaluated samples and
# one draw from the target sampler as the ground-truth samples.
source_batch = X_sampler.sample(batch_size).detach().cpu().numpy()
target_batch = Y_sampler.sample(batch_size).detach().cpu().numpy()
compute_BW_UVP_by_gt_samples(source_batch, target_batch)

26.166787182183697

BW-UVP between samples from the trivial plan (a source batch concatenated with a target batch) and the optimal plan, using ground-truth statistics

In [7]:
# Build "trivial plan" samples: one source batch and one target batch
# joined along dim=1, then converted to a NumPy array.
source_batch = X_sampler.sample(batch_size)
target_batch = Y_sampler.sample(batch_size)
X_Y = torch.cat([source_batch, target_batch], dim=1).detach().cpu().numpy()

# Score the joined samples against the optimal plan's mean and covariance.
compute_BW_UVP_with_gt_stats(X_Y, optimal_plan_mu, optimal_plan_covariance)

12.428462597450906