# Plot MMD distances in feature space

In [None]:
import os

# Root directory that contains ``output/features/<experiment>/mmds_{GAN,DM}.pkl``.
# Edit the fallback path below, or set the WORKDIR environment variable to
# override it without modifying the notebook.
WORKDIR = os.environ.get("WORKDIR", "path/to/WORKDIR/")

In [None]:
from src.paper_utils import get_figsize, configure_matplotlib

In [None]:
# Overrides applied on top of the project's default matplotlib style:
# x tick labels are drawn on top instead of at the bottom, y ticks only on the
# left, constrained layout is switched off and saved figures get a tiny margin.
_rc_overrides = {
    "xtick.labelbottom": False,
    "xtick.bottom": True,
    "xtick.labeltop": True,
    "ytick.left": True,
    "ytick.right": False,
    "figure.constrained_layout.use": False,
    "savefig.pad_inches": 0.01,
}
configure_matplotlib(rc=_rc_overrides)

In [None]:
import pickle
from pathlib import Path

# Detector checkpoints whose feature-space MMDs are plotted below; each entry is
# a sub-folder of ``<WORKDIR>/output/features``.
experiments = [
    "gragnaniello2021/progan", "gragnaniello2021/stylegan2",
    "wang2020/blur_jpg_prob0.1", "wang2020/blur_jpg_prob0.5",
    "wang2020/finetuning_All", "wang2020/finetuning_GAN", "wang2020/finetuning_DM",
    "wang2020/scratch_All", "wang2020/scratch_GAN", "wang2020/scratch_DM",
]


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE: ``pickle.load`` can execute arbitrary code; only load result files
    you produced yourself.
    """
    with open(path, "rb") as input_file:
        return pickle.load(input_file)


# all_mmds[experiment] = {"mmds_DMs": <mapping>, "mmds_GANs": <mapping>}
all_mmds = {}
for experiment in experiments:
    # Adapt this layout if your results are stored elsewhere. Path handles the
    # trailing slash in WORKDIR, so no "//" ends up in the path.
    base_input_folder = Path(WORKDIR) / "output" / "features" / experiment
    mmds_GANs = _load_pickle(base_input_folder / "mmds_GAN.pkl")
    mmds_DMs = _load_pickle(base_input_folder / "mmds_DM.pkl")

    all_mmds[experiment] = {"mmds_DMs": mmds_DMs, "mmds_GANs": mmds_GANs}

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Specify title in the plots
def _detector_title(author, detail):
    """Build a two-line LaTeX panel title: *author* above a tiny '(detail)'."""
    return author + r'\vspace{-0.5em} \tiny{(' + detail + ')}'


# Panel titles keyed by experiment folder. The insertion order decides which
# mosaic panel each detector is drawn into.
detectors = {
    'gragnaniello2021/progan': _detector_title('Gragnaniello2021', 'ProGAN'),
    'gragnaniello2021/stylegan2': _detector_title('Gragnaniello2021', 'StyleGAN2'),
    'wang2020/blur_jpg_prob0.1': _detector_title('Wang2020', 'Blur+JPEG (0.1)'),
    'wang2020/blur_jpg_prob0.5': _detector_title('Wang2020', 'Blur+JPEG (0.5)'),
    'wang2020/finetuning_All': _detector_title('Wang2020', r'fine-tuned on \textbf{All}'),
    'wang2020/finetuning_GAN': _detector_title('Wang2020', r'fine-tuned on \textbf{GANs}'),
    'wang2020/finetuning_DM': _detector_title('Wang2020', r'fine-tuned on \textbf{DMs}'),
    'wang2020/scratch_All': _detector_title('Wang2020', r'trained on \textbf{All}'),
    'wang2020/scratch_GAN': _detector_title('Wang2020', r'trained on \textbf{GANs}'),
    'wang2020/scratch_DM': _detector_title('Wang2020', r'trained on \textbf{DMs}'),
}

# 3x4 grid of panels named "0".."9" (row-major, one per detector); "." is
# matplotlib's default empty sentinel and leaves the last cell of rows 2-3 blank.
_panel_layout = [
    ["0", "1", "2", "3"],
    ["4", "5", "6", "."],
    ["7", "8", "9", "."],
]
fig, axd = plt.subplot_mosaic(
    _panel_layout,
    height_ratios=[1] * 3,  # equal-height rows
    width_ratios=[1] * 4,   # equal-width columns
    constrained_layout=False,
    figsize=get_figsize(ratio=0.8),
    sharey=False,
    sharex=False,
)

# Diverging colormap; each model's colour is derived from its index in
# desired_order_list when the points are scattered.
cm = plt.get_cmap('coolwarm')

# Shared x-axis style: GAN points sit left of zero, DM points right of it.
GAN_x_center, DM_x_center = -0.1, 0.1
my_xticks_pos = np.array([GAN_x_center, DM_x_center])
my_xticks = ['GANs', 'DMs']

# Canonical model order, used for the legend order and for the colours.
desired_order_list = [
    # GANs
    "ProGAN", "StyleGAN", "ProjectedGAN", "Diff-StyleGAN2", "Diff-ProjectedGAN",
    # diffusion models
    "DDPM", "IDDPM", "ADM", "PNDM", "LDM",
]

# Row labels drawn to the left of the first panel in each mosaic row.
_ROW_LABELS = {0: "pre-trained", 4: "fine-tuned", 7: "trained from scratch"}


def _panel_ymax(panel):
    """Return the upper y-limit for mosaic panel *panel* (0-based, row-major).

    Panels 0-3 use 0.8, panels 4-6 use 1.2 and all later panels use 1.3.
    """
    if panel <= 3:
        return 0.8
    if panel in (4, 5, 6):
        return 1.2
    return 1.3


# One panel per detector; the dict order of `detectors` decides the panel.
for counter, (detector, name) in enumerate(detectors.items()):
    ax = axd[str(counter)]
    ax.set_title(name, loc="center")

    mmds = all_mmds[detector]
    mmds_DM = mmds['mmds_DMs']
    mmds_GANs = mmds['mmds_GANs']
    # Each model becomes one point at its family's x-position. Its colour comes
    # from its index in desired_order_list: with the current list, GANs map to
    # 0.0-0.4 and DMs to 1.0 down to 0.6 of the colormap. The labels feed the
    # shared figure legend.
    for DM, MMD_DM in mmds_DM.items():
        ax.scatter(DM_x_center, MMD_DM, label=f'{DM}', alpha=0.9, zorder=4,
                   color=cm((10 - desired_order_list.index(DM) + 5) / 10))
    for GAN, MMD_GAN in mmds_GANs.items():
        ax.scatter(GAN_x_center, MMD_GAN, label=f'{GAN}', alpha=0.9, zorder=4,
                   color=cm(desired_order_list.index(GAN) / 10))

    ax.set_xlim(left=GAN_x_center * 2., right=DM_x_center * 2.)

    row_label = _ROW_LABELS.get(counter)
    if row_label is not None:
        # If LaTeX is installed and enabled in rcParams, a y-label can be added:
        # ax.set_ylabel(r'$\mathrm{MMD}\big(f(X_\mathrm{real}), f(X_\mathrm{model})\big)$')
        # The row label is placed to the left of the y-axis label position.
        ax.annotate(row_label, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - 5, 0),
                    xycoords=ax.yaxis.label, textcoords='offset points',
                    size='large', ha='right', va='center', rotation=90)

    ax.set_ylim(bottom=0.0, top=_panel_ymax(counter))
    ax.set_xticks(my_xticks_pos)
    ax.set_xticklabels(my_xticks)
# Gather legend entries from every panel, not just fig.gca(), which is a single
# panel. A model missing from one detector's results still appears, and
# duplicate labels collapse to one entry.
by_label = {}
for panel_ax in fig.axes:
    handles, labels = panel_ax.get_legend_handles_labels()
    by_label.update(zip(labels, handles))

# Order the entries as in desired_order_list and skip models that were never
# plotted (indexing by_label[k] directly would raise KeyError for them).
reordered_dict = {k: by_label[k] for k in desired_order_list if k in by_label}
fig.legend(list(reordered_dict.values()), list(reordered_dict.keys()), ncol=2,
           borderpad=0.7, bbox_to_anchor=(1.01, 0.55), labelspacing=0.8,
           columnspacing=0.8)
# Uncomment to export the figure.
#plt.savefig("MMD.pdf")
plt.show()