In [1]:
import torch
import numpy as np
from torch import Tensor
import pandas as pd
from pathlib import Path
from labproject.plotting import  place_violin

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.font_manager import fontManager
from matplotlib import rc_file
rc_file("../../matplotlibrc")

from labproject.metrics.MMD_torch import *
from labproject.metrics.sliced_wasserstein import sliced_wasserstein_distance
from labproject.metrics.c2st import *
from labproject.metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance
from labproject.plotting import tiled_ticks, get_lims

import warnings
# Install a blanket "ignore" filter: every warning from every module is
# suppressed for the rest of the kernel session.
# NOTE(review): presumably done to keep cell output clean, but it also hides
# deprecation/numerical warnings from the metric code — consider narrowing it.
warnings.filterwarnings("ignore")

# Register the Arial font file located at <cwd>/../../fonts/arial.ttf with
# matplotlib's font manager, then make Arial the default sans-serif family.
fontManager.addfont(Path(".").absolute().parent.parent / "fonts/arial.ttf")
matplotlib.rc("font", **{"family": "sans-serif", "sans-serif": "Arial"})

In [2]:
# Load the pre-computed tensors compared in the cells below. All paths are
# relative to the notebook's working directory.
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects — only load files from a trusted source.

# "cxr" tensors; the comparison labels further down pair them as
# cxr1 = real, cxr2 = PGGAN, cxr3 = SD (Stable Diffusion, per the file name).
cxr1 = torch.load('../../data/cxr/encs_real.pt')
cxr2 = torch.load('../../data/cxr/encs_fake_pggan.pt')
cxr3 = torch.load('../../data/cxr/encs_fake_stable_diffusion.pt')
# "ddm" tensors; the comparison labels further down pair them as
# ddm1 = real, ddm2 = DDM (generated), ddm3 = Gaussian.
ddm1 = torch.load('../../data/ddm/real_data.pt')                    # recorded shape: torch.Size([587, 1])
ddm2 = torch.load('../../data/ddm/generated_data.pt')
ddm3 = torch.load('../../data/ddm/gaussian_data.pt')

In [3]:

class Metric:
    """Named wrapper that binds fixed keyword arguments to a two-sample function.

    Attributes are ``name`` (the label under which results are stored),
    ``func`` (the wrapped callable) and ``kwargs`` (the keyword arguments
    forwarded to ``func`` on every call).
    """

    def __init__(self, name: str, func: callable, **kwargs):
        self.name, self.func, self.kwargs = name, func, kwargs

    def __call__(self, x: Tensor, y: Tensor) -> Tensor:
        """Evaluate ``func(x, y, **kwargs)`` and return its result unchanged."""
        wrapped, extra = self.func, self.kwargs
        return wrapped(x, y, **extra)
    

class DistComp:
    """Repeatedly compare random subsamples of two datasets with one metric.

    Each repetition independently draws up to ``perm_size`` rows (without
    replacement) from each dataset, evaluates ``metric`` on the pair and stores
    the scalar result in ``results_df`` (one row per repetition, one column
    named after the metric).

    Args:
        dataset1: First dataset; indexed along its first dimension.
        dataset2: Second dataset; indexed along its first dimension.
        metric: Callable with a ``name`` attribute, invoked as ``metric(x, y)``.
        n_perms: Number of subsampling repetitions.
        perm_size: Rows drawn from each dataset per repetition. A dataset with
            fewer rows contributes all of its rows (in random order).
        descr: Free-text label shown by ``repr``.
        seed: Optional integer seed for the subsampling. ``None`` (default)
            draws from torch's global RNG, exactly as before. The seed only
            controls the subsampling; randomness inside ``metric`` is untouched.
    """

    def __init__(self, dataset1: Tensor, dataset2: Tensor, metric: Metric,
                 n_perms: int = 100, perm_size=1000, descr="", seed=None):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.metric = metric
        self.n_perms = n_perms
        self.perm_size = perm_size
        self.descr = descr
        self.seed = seed

        # Pre-filled with NaN so run_experiment can detect missing results.
        self.results_df = pd.DataFrame(
            np.nan, index=range(self.n_perms), columns=[metric.name]
        )

    def _subsample(self, dataset, generator):
        """Return up to ``perm_size`` randomly chosen rows of ``dataset``."""
        chosen = torch.randperm(len(dataset), generator=generator)[:self.perm_size]
        return dataset[chosen]

    def run_experiment(self):
        """Fill ``results_df`` with one metric value per repetition.

        Raises:
            AssertionError: if any stored result is NaN after the run.
        """
        generator = None
        if self.seed is not None:
            # A private generator keeps seeded runs independent of global RNG state.
            generator = torch.Generator().manual_seed(self.seed)

        for i in range(self.n_perms):
            # dataset2 is drawn first, then dataset1, so that the sequence of
            # random draws under the global RNG is the same as it always was.
            dataset2_perm = self._subsample(self.dataset2, generator)
            dataset1_perm = self._subsample(self.dataset1, generator)

            value = self.metric(dataset1_perm, dataset2_perm)
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.results_df.loc[i, self.metric.name] = value

        assert not np.any(np.isnan(self.results_df.values)), (
            f"metric '{self.metric.name}' produced NaN results"
        )

    def reformat_df(self, data):
        """Convert a wide frame into the long format seaborn plots expect.

        Args:
            data: DataFrame with one column per metric and one row per split.

        Returns:
            DataFrame with columns ``metric`` (source column name),
            ``distance`` (cell value) and ``split_ind`` (source row label),
            ordered row by row.
        """
        # Single pass over the frame instead of three separate iterrows() loops.
        records = [
            (column_name, value, row_label)
            for row_label, row in data.iterrows()
            for column_name, value in row.items()
        ]
        if records:
            metric, distance, split_ind = map(list, zip(*records))
        else:
            metric, distance, split_ind = [], [], []

        return pd.DataFrame({"metric": metric, "distance": distance, "split_ind": split_ind})

    def __repr__(self):
        return f"{self.__class__.__name__}\nDescription:{self.descr}"

In [4]:
def generate_palette(hex_color, n_colors=5, saturation="light"):
    """Build a sequential seaborn palette around a single base colour.

    Args:
        hex_color: Base colour, passed straight to seaborn.
        n_colors: Number of colours in the palette.
        saturation: ``"light"`` to use ``sns.light_palette`` or ``"dark"`` to
            use ``sns.dark_palette``.

    Returns:
        The palette object returned by seaborn (``as_cmap=False``).

    Raises:
        ValueError: if ``saturation`` is neither ``"light"`` nor ``"dark"``.
    """
    # Validate first: an unknown value used to fall through both branches and
    # fail with an UnboundLocalError on the return statement.
    if saturation not in ("light", "dark"):
        raise ValueError(
            f"saturation must be 'light' or 'dark', got {saturation!r}"
        )

    build_palette = sns.light_palette if saturation == "light" else sns.dark_palette
    return build_palette(hex_color, n_colors=n_colors, as_cmap=False)

# Base colour (hex string) assigned to each metric name.
color_dict = dict(
    SW="#cc241d",
    MMD="#eebd35",
    C2ST="#458588",
    FID="#8ec07c",
)

In [5]:
# Experiment configuration, keyed by dataset name. Each entry holds:
#   "metrics"     - Metric objects evaluated on every comparison
#   "comparisons" - pairs of tensors to compare
#   "descr"       - one label per comparison; must stay in the same order and
#                   have the same length as "comparisons"
#   "kwargs"      - n_perms (repetitions) and perm_size (rows per subsample)
datasets = {
    "ddm": {
        "metrics": [
            Metric('SW', sliced_wasserstein_distance),
            Metric('C2ST', c2st_nn),
            # RBF bandwidth is set per dataset (0.5 here, 50.0 for "cxr").
            Metric('MMD', compute_rbf_mmd, bandwidth=0.5),
            # FID is disabled for this dataset (reason not recorded here):
            # Metric("FID", gaussian_squared_w2_distance)
        ],
        "comparisons":[
            (ddm1, ddm1),
            (ddm1, ddm2),
            (ddm1, ddm3),
            (ddm2, ddm3),
        ],
        "descr": [
            "real vs real",
            "real vs DDM",
            "real vs Gauss.",
            "DDM vs Gauss.",
        ],
        "kwargs":{
            "n_perms": 10,
            "perm_size": 300
        }
    },
    "cxr": {
        "metrics": [
            Metric('SW', sliced_wasserstein_distance),
            Metric('C2ST', c2st_nn),
            # RBF bandwidth is set per dataset (50.0 here, 0.5 for "ddm").
            Metric('MMD', compute_rbf_mmd, bandwidth=50.0),
            Metric("FID", gaussian_squared_w2_distance)
        ],
        "comparisons":[
            (cxr1, cxr1),
            (cxr1, cxr2),
            (cxr1, cxr3),
            (cxr2, cxr3),
        ],
        "descr": [
            "real vs real",
            "real vs PGGAN",
            "real vs SD",
            "PGGAN vs SD",
        ],
        "kwargs":{
            "n_perms": 10,
            "perm_size": 1000
        }
    }
}

In [6]:

# Seed the global torch and NumPy generators so the random subsampling done by
# DistComp is repeatable under Restart & Run All.
# NOTE(review): whether the metric implementations also draw from these global
# generators is not visible here — confirm if exact reproducibility matters.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# experiments[dataset] is an array of shape (n_metrics, n_comparisons, n_perms)
# holding one metric value per (metric, comparison, repetition).
experiments = {}
for dataset in datasets:
    n_metrics = len(datasets[dataset]["metrics"])
    n_comparisons = len(datasets[dataset]["comparisons"])
    n_perms = datasets[dataset]["kwargs"]["n_perms"]
    perm_size = datasets[dataset]["kwargs"]["perm_size"]

    # zip() below stops at the shorter list, which would silently leave
    # all-zero rows in the result array; fail loudly instead.
    if len(datasets[dataset]["descr"]) != n_comparisons:
        raise ValueError(
            f"dataset '{dataset}': 'descr' and 'comparisons' differ in length"
        )

    experiments[dataset] = np.zeros([n_metrics, n_comparisons, n_perms])

    for i, metric in enumerate(datasets[dataset]["metrics"]):
        for j, (comp, descr) in enumerate(zip(datasets[dataset]["comparisons"], datasets[dataset]["descr"])):
            exp = DistComp(comp[0], comp[1], metric, n_perms=n_perms, perm_size=perm_size, descr=descr)
            exp.run_experiment()
            experiments[dataset][i, j, :] = exp.results_df.to_numpy().flatten()


-----------------------------------------------------------

In [7]:

# datasets_long_name = {"ddm": "Drift diffusion data", 
#                       "cxr": "X-Ray data"}
# metrics_names = ["SW", "C2ST", "MMD", "FID"]
# dataset_list = ["ddm", "cxr"]
# comparisons_lists = {"ddm": ["real vs real",
#                              "real vs DDM", 
#                              #"DDM vs DDM", 
#                              "real vs Gauss.", 
#                              "DDM vs Gauss.",
#                              #"Gaussian vs Gaussian"
#                              ], 
#                      "cxr": ["real vs real", 
#                              "real vs PGGAN", 
#                              #"PGGAN vs PGGAN", 
#                              "real vs SD", 
#                              "PGGAN vs SD",
#                              #"Stable Diffusion vs Stable Diffusion"
#                              ]}

In [8]:
# # Set the formatter
# formatter = ticker.ScalarFormatter(useMathText=True) # Use mathematical text for scientific notation
# formatter.set_scientific(True) # Enable scientific notation
# formatter.set_powerlimits((-1,1)) # This will force scientific notation

# # comparisons_lists_xticks = {"ddm": ["real vs real", "real vs DDM", "DDM vs DDM", "real vs Gauss.", "DDM vs Gauss.", "Gauss. vs Gauss."], 
# #                      "cxr": ["real vs real", "real vs PGGAN", "PGGAN vs PGGAN", "real vs Stable Diffusion", "PGGAN vs Stable Diffusion", "Stable Diffusion vs Stable Diffusion"]}

# def type_akronym(name):
#     akronyms = {"Stable Diffusion": "SD"}
#     if name in akronyms:
#         return akronyms[name]
#     return name

# datasets_long_name = {"ddm": "Drift diffusion data", 
#                       "cxr": "X-Ray data"}
# metrics_names = ["SW", "C2ST", "MMD", "FID"]
# dataset_list = ["ddm", "cxr"]
# comparisons_lists = {"ddm": ["real vs real",
#                              "real vs DDM", 
#                              #"DDM vs DDM", 
#                              "real vs Gauss.", 
#                              "DDM vs Gauss.",
#                              #"Gaussian vs Gaussian"
#                              ], 
#                      "cxr": ["real vs real", 
#                              "real vs PGGAN", 
#                              #"PGGAN vs PGGAN", 
#                              "real vs SD", 
#                              "PGGAN vs SD",
#                              #"Stable Diffusion vs Stable Diffusion"
#                              ]}
# fig, axes = plt.subplots(2, 4, figsize=[8, 5.6])

# for i, dataset in enumerate(datasets):
#     exper = experiments[dataset]
#     for j, metric in enumerate(datasets[dataset]["metrics"]):
#         for k, comparison in enumerate(datasets[dataset]["comparisons"]):
#             X = tiled_ticks(0, 2, n_major_ticks=1, n_minor_ticks=len(datasets[dataset]["comparisons"]), offset=0.175)
#             x, Y = X[0][k], exper[j, k]
#             ax = axes[i, j]
#             body_colors = generate_palette(color_dict[metrics_names[j]], n_colors=n_comparisons)
#             place_violin(ax, x, Y, 
#                             scatter_face_color="none",
#                         scatter_edge_color="k",
#                         scatter_lw=0.5,
#                         scatter_radius=5,
#                         scatter_alpha=1,
#                         scatter_width=0.5,
#                         scatter=True,
#                         scatter_zorder=3,
#                         width=0.09, 
#                         median_color='k', #color_dict[metrics_names[i]],
#                         median_bar_length=0.5, 
#                         median_lw=2, 
#                         whisker_color='k', #color_dict[metrics_names[i]],
#                         whisker_alpha=1, 
#                         whisker_lw=1,
#                         median_alpha=1,
#                         #    body_face_color=body_colors[k])
#                         body_face_color=color_dict[metrics_names[j]],
#                         body_edge_color="none",
#                         body_alpha=0.5)


# # for i, dataset_id in enumerate(dataset_list):
# #     for j in range(n_metrics):
# #         ax = axes[j, i]
# #         body_colors = generate_palette(color_dict[metrics_names[j]], n_colors=n_comparisons)
# #         for k, comparison in enumerate(comparisons_lists[dataset_id]):
# #             x, Y = X[0][k], experiments[i, j, k]
# #             place_violin(ax, x, Y, 
# #                             scatter_face_color="none",
# #                         scatter_edge_color="k",
# #                         scatter_lw=0.5,
# #                         scatter_radius=5,
# #                         scatter_alpha=1,
# #                         scatter_width=0.5,
# #                         scatter=True,
# #                         scatter_zorder=3,
# #                         width=0.09, 
# #                         median_color='k', #color_dict[metrics_names[i]],
# #                         median_bar_length=0.5, 
# #                         median_lw=2, 
# #                         whisker_color='k', #color_dict[metrics_names[i]],
# #                         whisker_alpha=1, 
# #                         whisker_lw=1,
# #                         median_alpha=1,
# #                         #    body_face_color=body_colors[k])
# #                         body_face_color=color_dict[metrics_names[j]],
# #                         body_edge_color="none",
# #                         body_alpha=0.5)

# # cosmetics
# ylims = [get_lims(x, 0.1) for x in experiments]
# for row, _axes in enumerate(ax):
#     for column, ax in enumerate(_axes):
        
#         # ax.set_ylim(ylims[row])

#         # # in first column, label the y axis
#         # if column == 0:
#         ax.set_ylabel(metrics_names[row])
#         # # in second column, remove the y axis
#         # if column == 1:
#         #     rm_spines(ax,
#         #              spines=("left",),
#         #              visible=False,
#         #              rm_xticks=True,
#         #              rm_yticks=True)

#         ax.yaxis.set_major_formatter(formatter)
#         # Adjust the position of the exponent (scientific notation)
#         # Move the offset text to the top
#         ax.yaxis.get_offset_text().set_position((0, 10))
#         # Optionally, adjust the alignment if needed
#         ax.yaxis.get_offset_text().set_verticalalignment('bottom')
#         ax.yaxis.get_offset_text().set_horizontalalignment('right')

#         if row == 0:
#             ax.set_title(datasets_long_name[dataset_list[column]])
        
#         comparison_name = comparisons_lists[dataset_list[column]]
#         xticklabel = [[type_akronym(_c) for _c in c.split(" vs ")] for c in comparison_name]
#         xticklabel = ["\nvs\n".join(xtl) for xtl in xticklabel] 
#         ax.set_xticks(X[0], xticklabel, rotation=0)
#         ax.patch.set_alpha(0)

# plt.subplots_adjust(hspace=0.8, wspace=0.3)

# plt.savefig("distances_violin.pdf", bbox_inches="tight")



In [9]:
# n, bins, _ = plt.hist(ddm1.flatten(), bins=50, alpha=0.5, label='real', histtype='step', lw=2, density=True)
# n, bins, _ = plt.hist(ddm2.flatten(), bins=bins, alpha=0.5, label='DDM', histtype='step', lw=2, density=True)
# n, bins, _ = plt.hist(ddm3.flatten(), bins=bins, alpha=0.5, label='Gaussian', histtype='step', lw=2, density=True)

# plt.legend(loc='upper right')