# Evaluation Template

In [None]:
import os
import copy
import hydra
import lpips
import torch
from hydra import compose, initialize
from models import evaluate, get_encodings
from core.custom_dataset import CustomDataset
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import torchvision
from PIL import Image
import seaborn as sns
from utils import sci_notation, cdist_mean, ssim_dist, alex_lpips, mse_dist
from torchmetrics.image import StructuralSimilarityIndexMeasure
from core.manipulation_set import FrequencyManipulationSet, RGBManipulationSet
from data_loader import MNIST_CLASSES

In [None]:
from plotting import (
    fv_2d_grid_model_vs_parameters,
    update_font,
    collect_fv_data,
    fv_similarity_boxplots_by_dist_func,
    fv_2d_grid_step_vs_model,
    fv_mnist_output,
    collect_fv_data_by_step,
    activation_max_top_k,
    fv_2d_grid_model_vs_defense,
    fv_2d_grid_model_by_step_similarity,
)

In [None]:
mpl.rcParams.update(mpl.rcParamsDefault)
plt.ioff()

# Seeds NumPy's global RNG only.
# NOTE(review): torch is not seeded here; FV runs that draw torch noise may not be reproducible.
np.random.seed(27)

# Make the TeX binaries (presumably the MacTeX install location) visible for
# text.usetex. Only append once so re-running this cell does not keep growing PATH.
tex_bin = "/Library/TeX/texbin"
if tex_bin not in os.environ.get("PATH", "").split(os.pathsep):
    os.environ["PATH"] = os.environ.get("PATH", "") + os.pathsep + tex_bin

In [None]:
# Set palette and font scale in ONE call: set_theme() (and its alias sns.set())
# resets the palette to "deep" unless one is passed, so a later
# sns.set(font_scale=...) would silently undo set_palette("pastel").
sns.set_theme(palette="pastel", font_scale=1.2)

In [None]:
# Compose the Hydra experiment config; use config_name="config_mnist" for the MNIST setup.
with initialize(version_base=None, config_path="../config"):
    cfg = compose(config_name="config_cifar", overrides=[])

In [None]:
# Unpack the Hydra config into the notebook-level names used by later cells.
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
original_weights = cfg.model.get("original_weights_path", None)
if original_weights:
    original_weights = "{}/{}".format(cfg.model_dir, original_weights)
data_dir = cfg.data_dir
model_dir = cfg.model_dir
output_dir = cfg.output_dir
dataset = cfg.data
dataset_str = cfg.data.dataset_name
default_layer_str = cfg.model.layer
n_out = cfg.model.n_out
image_dims = cfg.data.image_dims
n_channels = cfg.data.n_channels
class_dict_file = cfg.data.get("class_dict_file", None)
if class_dict_file is not None:
    # NOTE(review): the "." prefix presumably re-roots a "./..."-style config path
    # to the parent directory (this notebook reads "../config") — confirm.
    class_dict_file = "." + class_dict_file
fv_sd = float(cfg.fv_sd)  # NOTE: rebound by the checkpoint-loading loop further down
fv_dist = cfg.fv_dist
fv_domain = cfg.fv_domain
target_img_path = cfg.target_img_path
batch_size = cfg.batch_size
train_original = cfg.train_original
replace_relu = cfg.replace_relu
alpha = cfg.alpha
w = cfg.w
# Run name defaults to the target image's file stem.
img_str = cfg.img_str
if img_str is None:
    img_str = os.path.splitext(os.path.basename(target_img_path))[0]
gamma = cfg.gamma
lr = cfg.lr
man_batch_size = cfg.man_batch_size
zero_rate = cfg.get("zero_rate", 0.5)
tunnel = cfg.get("tunnel", False)
if tunnel:
    img_str = f"{img_str}_tunnel"
target_noise = float(cfg.get("target_noise", 0.0))
data = cfg.data.dataset_name
n_epochs = cfg.epochs
layer_str = cfg.model.layer
target_neuron = int(cfg.model.target_neuron)

In [None]:
# Build the transform pipelines declared in the data config (instantiated in this order).
image_transforms, normalize, denormalize, resize_transforms = (
    hydra.utils.instantiate(node)
    for node in (
        dataset.fv_transforms,
        cfg.data.normalize,
        cfg.data.denormalize,
        cfg.data.resize_transforms,
    )
)

In [None]:
# Output directory for every figure and table written by this notebook.
save_path = f"../results/smas/{dataset_str}/"
os.makedirs(save_path, exist_ok=True)

In [None]:
# Manipulation-set class chosen by `fv_domain`; later FV helpers receive this class.
noise_ds_type = FrequencyManipulationSet if fv_domain == "freq" else RGBManipulationSet
# NOTE(review): `noise_dataset` itself is not referenced again in this notebook.
noise_dataset = noise_ds_type(
    image_dims,
    target_img_path,
    normalize,
    denormalize,
    image_transforms,
    resize_transforms,
    n_channels,
    fv_sd,
    fv_dist,
    zero_rate,
    tunnel,
    target_noise,
    device,
)
train_dataset, test_dataset = hydra.utils.instantiate(
    cfg.data.load_function, path=data_dir + cfg.data.data_path
)

train_loader = torch.utils.data.DataLoader(
    CustomDataset(train_dataset, class_dict_file),
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# The test set is only evaluated, never trained on. Keep a fixed order so repeated
# passes (e.g. the before/after get_encodings calls used by the Jaccard index, which
# compares positions) see the images in the same order.
test_loader = torch.utils.data.DataLoader(
    CustomDataset(test_dataset, class_dict_file),
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Manipulation strengths (alpha) of the checkpoints to load. They are kept as strings
# and converted with float() where they are used.
alphas1 = (
    "1e-4 3.33e-4 6.66e-4 "
    "1e-3 3.33e-3 6.66e-3 "
    "1e-2 3.33e-2 6.66e-2 "
    "1e-1 1.0"
).split()

In [None]:
# Baseline (unmanipulated) model, loaded from the original weights when a path is configured.
default_model = hydra.utils.instantiate(cfg.model.model)
# Same truthiness test used when the path was built; an empty path would make torch.load fail.
if original_weights:
    default_model.load_state_dict(torch.load(original_weights, map_location=device))
default_model.to(device)
default_model.eval()

In [None]:
# Test-set score of the unmanipulated model; it becomes models[0].
before_acc = evaluate(default_model, test_loader, device)

models = [
    {
        "model_str": "Original",
        # "\\%" is a literal backslash-percent so LaTeX (text.usetex) prints "%";
        # the former "\%" was an invalid escape sequence (SyntaxWarning on 3.12+).
        "model_str_acc": "Original\n {:0.2f} \\%".format(before_acc),
        "model": default_model,
        "acc": before_acc,
        "loss_m": 0,
        "loss_p": 0,
    }
]

In [None]:
# Load one manipulated checkpoint per alpha in `alphas1` and append it to `models`,
# so models[k] holds alphas1[k - 1] (models[0] is the original model).
# NOTE(review): the outer loop rebinds the notebook-level `fv_sd`, overriding the
# config value; later cells (e.g. `eval_fv_tuples`) use the value left here.
i = 0
for fv_sd in [1e-1]:
    for alpha1 in alphas1:
        print("distribution=", (fv_dist, fv_sd, man_batch_size))
        # The file name encodes the training hyper-parameters (fv_dist appears twice);
        # presumably this mirrors how the checkpoints were saved — keep in sync.
        PATH = "{}/{}/{}/{}/{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_model.pth".format(
            output_dir,
            dataset_str,
            cfg.model.model_name,
            "softplus" if replace_relu else "relu",
            img_str,
            fv_domain,
            str(fv_sd),
            fv_dist,
            str(float(alpha1)),
            str(w),
            gamma,
            lr,
            fv_dist,
            batch_size,
            man_batch_size,
        )

        model = hydra.utils.instantiate(cfg.model.model)
        model.to(device)
        print(alpha1)
        model_dict = torch.load(PATH, map_location=torch.device(device))
        model.load_state_dict(model_dict["model"])
        alpha_label = r"$\alpha =$" + str(sci_notation(float(alpha1)))
        # "\\%" is a literal backslash-percent for LaTeX; "\%" was an invalid escape.
        acc_label = "\n {:0.2f} \\%".format(model_dict["after_acc"])
        mdict = {
            "model_str": alpha_label,
            "model_str_acc": alpha_label + acc_label,
            "model": model,
            "acc": model_dict["after_acc"],
            "loss_m": model_dict["loss_m"],
            "loss_p": model_dict["loss_p"],
        }
        models.append(mdict)
        print("Model accuracy: ", acc_label)
        i += 1

## Manipulation

## Define Similarity Functions

In [None]:
# Render figure text with LaTeX; the labels below use math mode.
plt.rcParams["text.usetex"] = True

# (column/axis label, distance function, short name) per similarity metric.
# The arrow marks the better direction: SSIM higher is better; LPIPS and MSE lower is better.
dist_funcs = [
    (r"SSIM $\uparrow$", ssim_dist, r"SSIM"),
    (r"LPIPS $\downarrow$", alex_lpips, r"LPIPS"),
    (r"MSE $\downarrow$", mse_dist, r"MSE"),
]

In [None]:
# Feature-visualization settings for the evaluation runs below.
# NOTE: this overrides the config `lr`, which was already used to build checkpoint paths.
lr = 0.1
nsteps = 100
nvis = 10
n_fv_obs = 10  # TODO: Change to 100

# (distribution, std) pairs passed to the FV helpers; other settings that were tried:
# ("normal", 0.001), ("normal", 0.1), ("normal", 1.0)
eval_fv_tuples = [(fv_dist, fv_sd)]

### Qualitative Analysis: Plot 1

In [None]:
# Per-step FV results for every model in `models` with a single observation each
# (n_fv_obs=1), scored with every metric in `dist_funcs`; feeds the qualitative plots.
results_df_by_step_basic = collect_fv_data_by_step(
    models=models,
    fv_kwargs={"lr": lr, "n_steps": nsteps},
    eval_fv_tuples=eval_fv_tuples,
    noise_gen_class=noise_ds_type,
    image_dims=image_dims,
    target_str=target_img_path,
    normalize=normalize,
    denormalize=denormalize,
    resize_transforms=resize_transforms,
    n_channels=n_channels,
    layer_str=layer_str,
    target_neuron=target_neuron,
    nvis=nvis,
    n_fv_obs=1,
    dist_funcs=dist_funcs,
    device=device,
)

In [None]:
# Append each model's scores (3 decimals) to its label for the grid plot. The best
# score per metric among the five showcased models (first FV run, step == nsteps) is
# bolded. MSE (the last entry of dist_funcs) is skipped.
df = results_df_by_step_basic  # alias: columns are added to results_df_by_step_basic itself
df['model_dist'] = df['model']
showcase_rows = (
    (df.iter == 0)
    & (df.step == nsteps)
    & df.model.isin([models[s]["model_str"] for s in [0, 1, 4, 7, 10]])
)
for dist_str, dist_func, dist_str2 in dist_funcs[:-1]:
    # Cast first so the equality test against the best score is float-to-float.
    df[dist_str] = df[dist_str].astype(float)
    showcase_scores = df.loc[showcase_rows, dist_str]
    # SSIM: higher is better; the other metrics are distances (lower is better).
    best = showcase_scores.max() if dist_str2 == 'SSIM' else showcase_scores.min()
    corr_col = dist_str + "_corr"
    df[corr_col] = df[dist_str].map('{:,.3f}'.format)
    # Write through .loc: the former chained df[col][mask] = ... may target a temporary
    # copy (and never writes back under pandas Copy-on-Write).
    df.loc[df[dist_str] == best, corr_col] = r'\textbf{' + '{:,.3f}'.format(best) + r'}'
    df['model_dist'] = df['model_dist'] + "\n" + dist_str2 + ": " + df[corr_col]

In [None]:
# Keep the last optimisation step and the first FV run, then plot the five showcased
# models with their score annotations (dist=True).
# .max() picks the final step regardless of row order; unique()[-1] relied on ordering.
final_step = results_df_by_step_basic["step"].max()
results_df_basic = results_df_by_step_basic[results_df_by_step_basic["step"] == final_step]
results_df_basic_ex = results_df_basic[results_df_basic["iter"] == 0]
showcase_models = [models[s]["model_str"] for s in [0, 1, 4, 7, 10]]
grid = fv_2d_grid_model_vs_parameters(
    results_df_basic_ex[results_df_basic_ex.model.isin(showcase_models)],
    dist=True,
)

plt.savefig(f"{save_path}/ssim_alpha_demo.png", bbox_inches="tight")
plt.show()

# Select Manipulation Model

In [None]:
# Index into `models` of the manipulated model analysed from here on
# (models[0] is the original, models[k] corresponds to alphas1[k - 1]).
man_model = 7
man_model_str = models[man_model]["model_str"]
# Show that model as "Manipulated" in the per-step results used for plotting.
manipulated_relabel = {"model": {man_model_str: "Manipulated"}}
results_df_by_step_basic = results_df_by_step_basic.replace(manipulated_relabel)

# Save Plot Images

In [None]:
# Save the original model's FV (row 0 of the showcase frame) as an 8-bit PNG.
# Clip to [0, 1] first: values outside that range would wrap around in the uint8 cast.
im = Image.fromarray(
    (np.clip(results_df_basic_ex.picture.values[0], 0, 1) * 255).squeeze().astype(np.uint8)
)
im.save(f"{save_path}/original_fv.png")

In [None]:
# Save the manipulated model's FV as an 8-bit PNG. This assumes the rows of
# results_df_basic_ex follow the order of `models`.
# Clip to [0, 1] first: values outside that range would wrap around in the uint8 cast.
im = Image.fromarray(
    (np.clip(results_df_basic_ex.picture.values[man_model], 0, 1) * 255).squeeze().astype(np.uint8),
)
im.save(f"{save_path}/manipulated_fv.png")

### Qualitative Analysis: Plot 2

In [None]:
# FV optimisation progress (steps vs. model) for the original and manipulated model only.
# NOTE(review): unlike the other figures this is saved without bbox_inches="tight".
grid = fv_2d_grid_step_vs_model(
    results_df_by_step_basic[
        results_df_by_step_basic.model.isin(
            [models[0]["model_str"], "Manipulated"]
        )
    ],
    nvis,
)
plt.savefig(f"{save_path}/man_am_progress.png")
plt.show()

### Qualitative Analysis: Plot 3

In [None]:
# FVs of neurons 0-9 of `layer_str`, one observation each, for the original and the
# manipulated model. Frames are collected and concatenated once: repeated concat onto
# an empty DataFrame is quadratic and triggers pandas' empty-entry FutureWarning.
neuron_frames = []
for neuron in range(10):
    df_neuron = collect_fv_data(
        models=models[0:1] + models[man_model : man_model + 1],
        fv_kwargs={"lr": lr, "n_steps": nsteps},
        eval_fv_tuples=[("normal", 0.01)],
        noise_gen_class=noise_ds_type,
        image_dims=image_dims,
        target_str=target_img_path,
        normalize=normalize,
        denormalize=denormalize,
        resize_transforms=resize_transforms,
        n_channels=n_channels,
        layer_str=layer_str,
        target_neuron=neuron,
        n_fv_obs=1,
        device=device,
    )
    neuron_frames.append(df_neuron)
df = pd.concat(neuron_frames, ignore_index=True)

In [None]:
# Per-neuron FVs before/after manipulation, with neuron ids mapped through MNIST_CLASSES.
# NOTE(review): MNIST_CLASSES is applied even when the CIFAR config is loaded — confirm labels.
grid = fv_mnist_output(
    df.replace({"neuron": MNIST_CLASSES}).replace(
        {"model": {man_model_str: "Manipulated"}}
    )
)
plt.savefig(f"{save_path}/10_classes_before_after.png", bbox_inches="tight")
plt.show()

In [None]:
# Compact variant of the per-neuron grid: rows are "Before"/"After", columns are neurons,
# and each cell shows the first picture of its facet in grayscale with axis decorations removed.
update_font(35)
grid = sns.FacetGrid(df.replace({"neuron": MNIST_CLASSES}).replace(
        {"model": {man_model_str: "After", "Original":"Before"}}
    ), row='model', col='neuron', margin_titles=True, aspect=0.56)
grid.map(lambda x, **kwargs: (plt.imshow(x.values[0], cmap="gray"), plt.grid(False)), 'picture')
grid.set_titles(col_template="{col_name}", row_template="{row_name}")
grid.set(xlabel=None, xticklabels=[], yticklabels=[])
plt.subplots_adjust(hspace=0.04, wspace=0.04)
plt.savefig(f"{save_path}/small_10_classes_before_after.png", bbox_inches="tight")
plt.show()

### Quantitative Analysis: Plot 4

In [None]:
# Final FVs for every model, n_fv_obs observations each, scored with every metric in
# `dist_funcs`; the source for the box plots and the evaluation tables.
results_df_basic_100 = collect_fv_data(
    models=models,
    fv_kwargs={"lr": lr, "n_steps": nsteps},
    eval_fv_tuples=eval_fv_tuples,
    noise_gen_class=noise_ds_type,
    image_dims=image_dims,
    target_str=target_img_path,
    normalize=normalize,
    denormalize=denormalize,
    resize_transforms=resize_transforms,
    n_channels=n_channels,
    layer_str=layer_str,
    target_neuron=target_neuron,
    n_fv_obs=n_fv_obs,
    dist_funcs=dist_funcs,
    device=device,
)

In [None]:
# Box plots of each similarity metric per model over all FV observations.
grid = fv_similarity_boxplots_by_dist_func(results_df_basic_100, dist_funcs)
grid.savefig(f"{save_path}/boxplot.png", bbox_inches="tight")
plt.show()

In [None]:
# Per-model mean and std of every float column, flattened to "<column>_<stat>" names.
eval_table = results_df_basic_100.groupby(["model"]).describe(include=[float])
eval_table = eval_table.loc[:, (slice(None), ["mean", "std"])]
eval_table.columns = ["_".join(pair) for pair in eval_table.columns]

In [None]:
# Save the raw per-model mean/std table as CSV in `save_path`, next to the figures.
eval_table.to_csv(f"{save_path}/eval_table.csv")

In [None]:
# Build two LaTeX tables from the per-model FV results in `results_df_basic_100`:
#   1. accuracy + mean/std of every metric in `dist_funcs` (3 decimals);
#   2. accuracy + MSE only, metric values scaled by 100 (1 decimal).
# The second pass also extracts `ssim_means` for the SSIM-dynamics plot. `eval_table`
# is left bound to the second table, whose alpha column supplies that plot's ticks.

# Model labels are r"$\alpha =$" + <value> (built in the model-loading cell) or "Original".
alpha_prefix = r"$\alpha =$"


def _eval_stats():
    """Return per-model mean and std of every float column as "<column>_<stat>" columns."""
    stats = (
        results_df_basic_100.groupby(["model"])
        .describe(include=[float])
        .loc[:, (slice(None), ["mean", "std"])]
    )
    stats.columns = stats.columns.map("_".join)
    return stats


def _finalize_eval_table(table, metric_cols):
    """Keep accuracy plus ``metric_cols``, relabel models by alpha and order the rows.

    Alpha labels lose their LaTeX prefix; the "Original" row keeps its name. The old
    ``.str[10:]`` slice turned "Original" into an empty string.
    """
    table = table[["acc"] + metric_cols].copy()
    table["acc"] = table["acc"].map("{:,.3f}".format).astype(str)
    table = table.reset_index(drop=False)
    table["model"] = table["model"].str.replace(alpha_prefix, "", n=1, regex=False)
    table.columns = [r"$\alpha$", "Accuracy"] + metric_cols
    # groupby sorts its keys and "Original" sorts after the "$..." labels, so it is the
    # last row; put it first, followed by the alpha models in a hand-picked order.
    # NOTE(review): hard-coded for 1 original + 11 alpha models.
    return table.reindex([len(table) - 1] + [3, 7, 10, 2, 6, 9, 1, 5, 8, 0, 4])


# Table 1: every metric, 3 decimals.
eval_table = _eval_stats()
for s in [d[0] for d in dist_funcs]:
    eval_table[s + "_mean"] = eval_table[s + "_mean"].map("${:,.3f}".format).astype(str)
    eval_table[s + "_std"] = eval_table[s + "_std"].map("{:,.3f}$".format).astype(str)
    eval_table[s] = eval_table[s + "_mean"] + r"\pm" + eval_table[s + "_std"]
for s in ["acc", "model_loss_m", "model_loss_p"]:
    eval_table[s] = eval_table[s + "_mean"]
eval_table = _finalize_eval_table(eval_table, [d[0] for d in dist_funcs[::-1]])
print(eval_table.to_latex(escape=False, index=False))

# Table 2: metrics scaled by 100 with 1 decimal; only MSE is printed.
eval_table = _eval_stats()
for s in [d[0] for d in dist_funcs]:
    eval_table[s + "_mean"] = (eval_table[s + "_mean"] * 100).map("${:,.1f}".format).astype(str)
    eval_table[s + "_std"] = (eval_table[s + "_std"] * 100).map("{:,.1f}$".format).astype(str)
    eval_table[s] = eval_table[s + "_mean"] + r"\pm" + eval_table[s + "_std"]
for s in ["acc", "model_loss_m", "model_loss_p"]:
    eval_table[s] = eval_table[s + "_mean"]

# SSIM means (x100, leading "$" stripped), reordered like the table rows, without the
# first (Original) and last entries; consumed by the SSIM-dynamics plot.
ssim_means = eval_table[dist_funcs[0][0] + "_mean"].str[1:].astype(float).values.copy()
ssim_means = ssim_means[[11, 3, 7, 10, 2, 6, 9, 1, 5, 8, 0, 4]][1:-1]

eval_table = _finalize_eval_table(
    eval_table, [d[0] for d in dist_funcs[::-1] if "MSE" in d[0]]
)
print(eval_table.to_latex(escape=False, index=False))

In [None]:
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.pyplot as plt

# Positional 0..n-1 index so imscatter can look pictures up by row number.
results_df_basic_ex = results_df_basic_ex.reset_index(drop=True)
def plot_ssim_examples():
    """Plot ``ssim_means`` at x = 1..10, drawing each point as its FV image.

    The images come from ``results_df_basic_ex['picture']`` and a line joins the points.
    Returns ``(ax, fig)``.
    """
    fig, ax = plt.subplots()
    positions = np.arange(1, 11)
    ax = imscatter(positions, ssim_means, results_df_basic_ex['picture'], zoom=0.95, ax=ax)
    ax.plot(positions, ssim_means)
    return ax, fig

def imscatter(x, y, images, ax=None, zoom=1, start=1):
    """Draw images as scatter markers at the points ``(x, y)``.

    Point k (0-based) is drawn with ``images[start + k]``. The default ``start=1``
    skips the first entry; in this notebook that is the original model's FV.

    Args:
        x, y: Point coordinates (scalars or sequences of equal length).
        images: Indexable collection of arrays; each is squeezed before drawing, and
            2-D results are shown with a gray colormap.
        ax: Axes to draw on; defaults to the current axes.
        zoom: Scale factor passed to ``OffsetImage``.
        start: Index into ``images`` used for the first point.

    Returns:
        The axes, with data limits updated to cover all points.
    """
    if ax is None:
        ax = plt.gca()
    x, y = np.atleast_1d(x, y)
    for offset, (x0, y0) in enumerate(zip(x, y)):
        image = images[start + offset].squeeze()
        cmap = 'gray' if len(image.shape) == 2 else None
        marker = OffsetImage(image, zoom=zoom, cmap=cmap)
        ax.add_artist(AnnotationBbox(marker, (x0, y0), xycoords='data', frameon=False))
    # Artists do not update data limits on their own; include the points explicitly.
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return ax

ax, fig = plot_ssim_examples()
# LaTeX text rendering and font settings for this figure.
plt.rcParams.update({
        "text.usetex": True,
        "axes.titlesize": 10,
        "axes.labelsize": 13,
        "font.size": 10,
        "font.family": "Helvetica",
        "xtick.labelsize": 12,
        "ytick.labelsize": 10,
        'text.latex.preamble': r"\usepackage{amsmath}\usepackage{color}",
    })
# Tick labels: alpha column of `eval_table` minus its first (Original) and last rows,
# matching the positions dropped from `ssim_means`. After reindex the index holds
# shuffled integer labels, so use .iloc to make the positional slice explicit.
ax.set_xticks(range(1, 11), eval_table[r"$\alpha$"].iloc[1:-1], rotation='vertical')
ax.set_xlabel(r'$\alpha$', fontsize=21)
ax.set_ylabel(r'SSIM', fontsize=12)
plt.savefig(f"{save_path}/ssim_dynamics.png", bbox_inches="tight")
plt.show()
plt.clf()

### Quantitative Analysis: Plot 5

In [None]:
# Per-step FV results (nvis=nsteps) for the original and the manipulated model,
# n_fv_obs observations each, scored with every metric in `dist_funcs`.
results_df_by_step_basic_100 = collect_fv_data_by_step(
    models=models[0:1] + models[man_model : man_model + 1],
    fv_kwargs={"lr": lr, "n_steps": nsteps},
    eval_fv_tuples=eval_fv_tuples,
    noise_gen_class=noise_ds_type,
    image_dims=image_dims,
    target_str=target_img_path,
    normalize=normalize,
    denormalize=denormalize,
    resize_transforms=resize_transforms,
    n_channels=n_channels,
    layer_str=layer_str,
    target_neuron=target_neuron,
    nvis=nsteps,
    n_fv_obs=n_fv_obs,
    dist_funcs=dist_funcs,
    device=device,
)

In [None]:
# Similarity metrics over optimisation steps, with the manipulated model relabelled.
grid = fv_2d_grid_model_by_step_similarity(
    results_df_by_step_basic_100.replace(
        {"model": {man_model_str: "Manipulated"}}
    ),
    dist_funcs,
)
grid.savefig(f"{save_path}/similarity_step.png", bbox_inches="tight")
plt.show()

### Natural Images

In [None]:
# Encodings at `layer_str` on the test set for the original (before) and manipulated
# (after) model. Going by the variable names, get_encodings also returns targets,
# sample indices and images. (The unused re-creation of `train_loader` that stood here
# was removed; nothing below reads it.)
before_a, target_b, idxs, images_b = get_encodings(
    models[0]["model"], layer_str, [test_loader], device
)
after_a, target_a, idxs, images_a = get_encodings(
    models[man_model]["model"], layer_str, [test_loader], device
)

In [None]:
# Highest-activating test images for the target neuron, original model
# (the file name suggests the top 4 are shown).
print("Before")
fig1 = activation_max_top_k(before_a[:, target_neuron], denormalize, images_b, [0], "")
fig1.savefig(f"{save_path}/top_4_before.png", bbox_inches="tight")
plt.show()

In [None]:
# Highest-activating test images for the target neuron, manipulated model
# (the file name suggests the top 4 are shown).
print("After")
fig2 = activation_max_top_k(after_a[:, target_neuron], denormalize, images_a, [0], "")
fig2.savefig(f"{save_path}/top_4_after.png", bbox_inches="tight")
plt.show()

# Jaccard similarity coefficient

In [None]:
# Jaccard index (size of intersection / size of union) between the top-k test-image
# positions for the target neuron before and after manipulation.
# NOTE(review): these are positions in the encoding arrays, so the comparison is only
# meaningful if both get_encodings passes return images in the same order.
top_k = 100
top_idxs_before = list(np.argsort(before_a[:, target_neuron])[::-1][:top_k])
top_idxs_after = list(np.argsort(after_a[:, target_neuron])[::-1][:top_k])
before_set, after_set = set(top_idxs_before), set(top_idxs_after)
print(len(before_set & after_set) / len(before_set | after_set))

# AUC Value BEFORE

In [None]:
from torchmetrics import AUROC

# Binary ROC-AUC of the target neuron's activation as a score for "sample belongs to the
# target class", original model. The bare expression is shown as the cell output.
metric = AUROC(task="binary")
metric(
    torch.tensor(before_a[:, target_neuron]), torch.tensor(target_b == target_neuron)
)

# AUC Value AFTER

In [None]:
# Same ROC-AUC for the manipulated model, reusing the metric object from the previous cell.
metric(torch.tensor(after_a[:, target_neuron]), torch.tensor(target_a == target_neuron))

### Quantitative Analysis: Plot 6

In [None]:
# FV optimiser settings per defense: gradient clipping (GC), input transforms (TR),
# Adam (with lr / 10), and all three combined.
defense_strategies = {
    "None": {"lr": lr, "n_steps": nsteps},
    "GC": {"lr": lr, "n_steps": nsteps, "grad_clip": 1.0},
    "TR": {
        "lr": lr,
        "n_steps": nsteps,
        "tf": torchvision.transforms.Compose(image_transforms),
    },
    "Adam": {
        "lr": lr/10,
        "n_steps": nsteps,
        "adam": True,
    },
    "Adam + GC + TR": {
        "lr": lr/10,
        "n_steps": nsteps,
        "adam": True,
        "tf": torchvision.transforms.Compose(image_transforms),
        "grad_clip": 1.0,
    },
}

# Collect FVs of the original and the manipulated model under each defense and
# concatenate once. Repeated concat onto an empty DataFrame is quadratic and triggers
# pandas' empty-entry FutureWarning.
strategy_frames = []
for strategy in defense_strategies:
    strategy_df = collect_fv_data(
        models=[models[0], models[man_model]],
        fv_kwargs=defense_strategies[strategy],
        eval_fv_tuples=[("normal", 0.01)],
        noise_gen_class=noise_ds_type,
        image_dims=image_dims,
        target_str=target_img_path,
        normalize=normalize,
        denormalize=denormalize,
        resize_transforms=resize_transforms,
        n_channels=n_channels,
        layer_str=layer_str,
        target_neuron=target_neuron,
        n_fv_obs=n_fv_obs,
        dist_funcs=[],
        folder="../results/smas/{}/figure_6/".format(dataset_str),
        title_str=strategy,
        device=device,
    )
    strategy_df["defense_strategy"] = strategy
    strategy_frames.append(strategy_df)
df = pd.concat(strategy_frames, ignore_index=True)

In [None]:
# Empty results table with one row per defense strategy; the cells are filled with
# formatted LPIPS values further down.
strategies_result = pd.DataFrame(
    columns=["Similarity To Target", "Similarity To Pre-Manipulation"],
    index=list(defense_strategies),
)

In [None]:
# Load the manipulation target as a (C, H, W) float tensor on `device`.
# The file is closed once it has been decoded. The image is converted to the model's
# channel count, so RGBA/palette images no longer yield 4 or 1 channels.
with Image.open(target_img_path) as target_file:
    image = target_file.convert("L" if n_channels == 1 else "RGB")

target = torchvision.transforms.ToTensor()(image).to(device)

In [None]:
# For each defense, compare the manipulated model's FVs with (a) the target image and
# (b) the original model's FVs, using cdist_mean with LPIPS.
# NOTE: LPIPS is a distance (lower = more similar) despite the "Similarity" column names.
for s in defense_strategies:
    in_strategy = df["defense_strategy"] == s
    is_original = df["model"] == "Original"
    original_fvs = torch.tensor(np.array(list(df["picture"][is_original & in_strategy].values)))
    man_fvs = torch.tensor(np.array(list(df["picture"][~is_original & in_strategy].values))).to(device)
    to_target = cdist_mean(target.permute((1,2,0)), man_fvs, alex_lpips)
    to_original = cdist_mean(original_fvs.to(device), man_fvs, alex_lpips)
    strategies_result.at[s, "Similarity To Target"] = "{:.4f}".format(to_target)
    strategies_result.at[s, "Similarity To Pre-Manipulation"] = "{:.4f}".format(to_original)

In [None]:
# One example FV (first observation, last occurrence) per (strategy, model) pair.
plot_df = (
    df[df.iter == 0]
    .drop_duplicates(subset=["defense_strategy", "model"], keep="last")
    .loc[:, ["model", "defense_strategy", "picture"]]
    .reset_index(drop=True)
)

In [None]:
# sns.set() (an alias of set_theme) resets the palette to "deep" unless one is given.
# Pass the notebook's pastel palette explicitly so the font bump does not change colours.
sns.set_theme(palette="pastel", font_scale=2)
grid = fv_2d_grid_model_vs_defense(plot_df.replace(
    {"model": {man_model_str: "Manipulated"}, "defense_strategy": {"Adam + GC + TR": "GC+TR+\nAdam"}}
))
grid.savefig(f"{save_path}/qual_defense.png", bbox_inches="tight")
# NOTE(review): the legend is added after saving, so it is not part of qual_defense.png.
grid.add_legend()
plt.show()

In [None]:
# Print the defense-strategy LPIPS table as LaTeX. The cells already hold formatted
# strings, so the former float_format had no effect; the formatters entry named a
# non-existent "name" column. Both are dropped.
print(
    strategies_result.to_latex(
        index=True,
        escape=False,
    )
)

# Distance to Target: n-AMS

In [None]:
# cdist_mean with LPIPS between the target image (permuted to channel-last) and the
# top-100 most-activating test images of the manipulated model. The images are
# denormalised and permuted channel-last, assuming NCHW input — TODO confirm.
cdist_mean(target.permute((1,2,0)), denormalize(torch.tensor(images_a)).permute(0,2,3,1)[top_idxs_after[:100]].to(device), alex_lpips)