In [14]:
from collections.abc import Callable

import torch
from torch import Tensor

import matplotlib.pyplot as plt

from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance
# from labproject.metrics.gaussian_kl import gaussian_kl_divergence
# from labproject.embeddings import *

In [15]:

# Load the two CIFAR-10 splits compared below: dataset1 <- train, dataset2 <- test.
# torch.load unpickles the file, so only point it at files from a trusted source.
# NOTE(review): the stored object's type/shape isn't visible here — presumably a
# tensor of images with samples along dim 0; confirm before relying on it.
dataset1, dataset2 = (
    torch.load('../../data/cifar10_train.pt'),
    torch.load('../../data/cifar10_test.pt'),
)

In [16]:
class Metric:
    """A named distance function with bound default keyword arguments.

    Calling a ``Metric`` forwards to ``func(x, y, **kwargs)``. The keyword
    arguments given at construction act as defaults and may be overridden
    per call.

    Args:
        name: Human-readable identifier for the metric (e.g. for result tables).
        func: Callable taking two inputs plus keyword options and returning
            the distance between them.
        **kwargs: Default keyword arguments forwarded to ``func`` on every call.
    """

    def __init__(self, name: str, func: Callable[..., Tensor], **kwargs):
        self.name = name
        self.func = func
        self.kwargs = kwargs

    def __call__(self, x: Tensor, y: Tensor, **kwargs) -> Tensor:
        """Evaluate the metric between ``x`` and ``y``.

        Per-call ``kwargs`` take precedence over the defaults bound at
        construction. A fresh dict is built, so the stored defaults are
        never modified by an override.
        """
        return self.func(x, y, **{**self.kwargs, **kwargs})

    def __repr__(self) -> str:
        return f"Metric(name={self.name!r}, kwargs={self.kwargs!r})"


class ComparisonExperiment:
    """Pairs two datasets with a list of metrics used to compare them.

    Args:
        dataset1: First sample set.
        dataset2: Second sample set (may be the same object as ``dataset1``).
        metrics: Metrics evaluated on the pair. Each is called as
            ``metric(dataset1, dataset2)``; :meth:`run` also reads ``metric.name``.
        descr: Free-text description of the comparison (e.g. for plot labels).
    """

    def __init__(self, dataset1: Tensor, dataset2: Tensor, metrics: "list[Metric]", descr: str):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.metrics = metrics
        self.descr = descr

    @staticmethod
    def get_distance(x: Tensor, y: Tensor, metric, **kwargs) -> Tensor:
        """Return ``metric(x, y, **kwargs)``.

        A staticmethod because it uses no instance state. Without the
        decorator, calling it on an instance would bind the instance to ``x``
        and fail with a ``TypeError``.
        """
        return metric(x, y, **kwargs)

    def run(self) -> dict[str, Tensor]:
        """Evaluate every metric on ``(dataset1, dataset2)``.

        Returns:
            Mapping from each metric's ``name`` to its result. If two metrics
            share a name, the later one overwrites the earlier.
        """
        return {
            metric.name: self.get_distance(self.dataset1, self.dataset2, metric)
            for metric in self.metrics
        }

In [17]:

# Settings forwarded to sliced_wasserstein_distance. Judging by the argument
# names, these are presumably the number of random projections and the order p
# of the Wasserstein distance — confirm against the metric's implementation.
# NOTE(review): no random seed is set before these metrics are evaluated; if the
# projections are random, results may differ between runs.
SWD_NUM_PROJECTIONS = 50
SWD_P = 2

metrics = [
    Metric(
        'sliced_wasserstein_distance',
        sliced_wasserstein_distance,
        num_projections=SWD_NUM_PROJECTIONS,
        p=SWD_P,
    ),
]

# NOTE(review): both "within-group" experiments pass the same tensor as both
# arguments, so they compare a sample with itself rather than two independent
# samples of one distribution; consider disjoint random subsets instead.
# NOTE(review): dataset2 is the CIFAR-10 *test* split (real images), yet its
# description says "generated data" — presumably a placeholder; confirm.
experiments = [
    ComparisonExperiment(
        dataset1, dataset1, metrics,
        'unconditional within-group real data'
    ),
    ComparisonExperiment(
        dataset2, dataset2, metrics,
        'unconditional within-group generated data'
    ),
    ComparisonExperiment(
        dataset1, dataset2, metrics,
        'unconditional between-group'
    ),
]
