In [1]:
# Re-import edited project modules (e.g. the `src.*` helpers used below) automatically
# before each cell executes, so changes to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
# Make the directory one level above the notebook's working directory importable,
# so the `from src... import ...` lines below resolve. Guarded so re-running the
# cell does not append duplicate entries.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
import timm
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as F
from torch import nn

from src.notebooks_utils import get_data, get_synth_examples, show_grid
from src.adv_resnet import resnet50, EightBN
from src.imagenet_labels import IMAGENET_LABELS
from src.utils import NormalizedModel

In [4]:
# Trim surrounding whitespace from any saved figures.
plt.rcParams["savefig.bbox"] = 'tight'

# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Set up models: XCiT-S12, ResNet-50 (EightBN norm layers) and DeiT-S, each loaded
# from a checkpoint in CHECKPOINTS_DIR, wrapped in NormalizedModel, moved to `device`
# and put in eval mode.
CHECKPOINTS_DIR = "../checkpoints"


def _prepare(model: nn.Module) -> nn.Module:
    """Wrap ``model`` in NormalizedModel, move it to ``device`` and switch it to eval mode."""
    # .eval(): nn.Module instances start in training mode, which keeps e.g. BatchNorm
    # and Dropout in their training behaviour. Setting it here makes inference
    # deterministic regardless of whether get_synth_examples does it itself.
    return NormalizedModel(model).to(device).eval()


xcit_checkpoint_file = "xcit-imagenet-4.pth.tar"
xcit_model_name = "xcit_small_12_p16_224"
xcit = timm.create_model(xcit_model_name, checkpoint_path=os.path.join(CHECKPOINTS_DIR, xcit_checkpoint_file))
xcit = _prepare(xcit)

resnet_checkpoint_file = "advres50_gelu.pth"
resnet = resnet50(norm_layer=EightBN)
# map_location="cpu": a checkpoint saved from CUDA tensors cannot be deserialized on a
# CPU-only machine (the fallback chosen for `device`) without it; _prepare moves the
# weights to `device` afterwards. NOTE: torch.load unpickles the file — only load
# checkpoints from a trusted source.
resnet.load_state_dict(
    torch.load(os.path.join(CHECKPOINTS_DIR, resnet_checkpoint_file), map_location="cpu")["model"]
)
resnet = _prepare(resnet)

deit_checkpoint_file = "advdeit_small.pth"
deit_model_name = "deit_small_patch16_224"
deit = timm.create_model(deit_model_name, checkpoint_path=os.path.join(CHECKPOINTS_DIR, deit_checkpoint_file))
deit = _prepare(deit)

In [None]:
# Passed to get_synth_examples for every model below; presumably the number of
# examples synthesized per model (and hence the size of each grid) — confirm in
# src.notebooks_utils.
n_examples = 16

In [None]:
# Synthesize examples for the XCiT model. Going by the names, the three outputs are the
# inputs, their adversarial counterparts and the labels — confirm in src.notebooks_utils.
xcit_x, xcit_x_adv, xcit_y = get_synth_examples(xcit, device, n_examples)

In [None]:
# Same synthesis for the ResNet-50 model (outputs presumably inputs, adversarial inputs, labels).
resnet_x, resnet_x_adv, resnet_y = get_synth_examples(resnet, device, n_examples)

In [None]:
# Same synthesis for the DeiT model (outputs presumably inputs, adversarial inputs, labels).
deit_x, deit_x_adv, deit_y = get_synth_examples(deit, device, n_examples)

In [None]:
# Display the XCiT examples returned as the first output of get_synth_examples.
show_grid(xcit_x)

In [None]:
# Print the class names of the XCiT examples, in grid order.
# int() makes the lookup work whether the labels are Python ints or 0-d tensors:
# a tensor element hashes by identity, so it cannot be used as a dict key.
print(", ".join(IMAGENET_LABELS[int(label)] for label in xcit_y))

In [None]:
# Display the adversarial XCiT examples, for comparison with the grid of xcit_x.
show_grid(xcit_x_adv)

In [None]:
# Display the adversarial examples synthesized for the ResNet-50 model.
show_grid(resnet_x_adv)

In [None]:
# Display the adversarial examples synthesized for the DeiT model.
show_grid(deit_x_adv)