In [190]:
import torch
import numpy as np
from torch import Tensor
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance
from labproject.experiments import Experiment


In [191]:

dataset1 = torch.load('../../data/cifar10_train.pt') # torch.Size([50000, 2048])
dataset2 = torch.load('../../data/cifar10_test.pt') # torch.Size([10000, 2048])

In [192]:
class Metric:
    def __init__(self, name: str, func: callable, **kwargs):
        self.name = name
        self.func = func
        self.kwargs = kwargs

    def __call__(self, x: Tensor, y: Tensor) -> Tensor:
        return self.func(x, y, **self.kwargs)


class DistComp(Experiment):
    def __init__(self, dataset1: Tensor, dataset2: Tensor, metrics: list[Metric], 
                 n_perms: int = 100, perm_size=1000, descr=""):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.metrics = metrics
        self.n_perms = n_perms
        self.perm_size = perm_size
        self.descr = descr

        columns = [metric.name for metric in metrics]
        self.results_df = pd.DataFrame(np.nan, index=range(self.n_perms), columns=columns)

    def run_experiment(self):
        for i in range(self.n_perms):
            dataset2_perm = self.dataset2[torch.randperm(self.perm_size)]
            dataset1_perm = self.dataset1[torch.randperm(self.perm_size)]
            for metric in self.metrics:
                self.results_df.loc[i, metric.name] = metric(dataset1_perm, dataset2_perm).numpy()

    def plot_results(self):
        pass


In [193]:

metrics = [
    Metric('sliced_wasserstein_distance', 
           sliced_wasserstein_distance,
           num_projections=50,
           p=2),
    Metric('sliced_wasserstein_distance', 
           sliced_wasserstein_distance,
           num_projections=50,
           p=1),
    Metric('sliced_wasserstein_distance', 
           sliced_wasserstein_distance,
           num_projections=50,
           p=1),
]

experiments = [
    DistComp(
        dataset1, dataset1, metrics, n_perms=10, perm_size=1,
        descr='unconditional within-group real data'
    ),
    DistComp(
        dataset2, dataset2, metrics, n_perms=10, perm_size=1,
        descr='unconditional within-group generated data'
    ),
    DistComp(
        dataset1, dataset2, metrics, n_perms=10, perm_size=1,
        descr='unconditional between-group'
    ),
]

for experiment in experiments:
    experiment.run_experiment()


In [194]:
for experiment in experiments:
    print(f"Plotting experiment: {experiment.descr}")
    print(experiment.results_df)
    # experiment.plot_results()

Plotting experiment: unconditional within-group real data
   sliced_wasserstein_distance  sliced_wasserstein_distance  \
0                          0.0                          0.0   
1                          0.0                          0.0   
2                          0.0                          0.0   
3                          0.0                          0.0   
4                          0.0                          0.0   
5                          0.0                          0.0   
6                          0.0                          0.0   
7                          0.0                          0.0   
8                          0.0                          0.0   
9                          0.0                          0.0   

   sliced_wasserstein_distance  
0                          0.0  
1                          0.0  
2                          0.0  
3                          0.0  
4                          0.0  
5                          0.0  
6                