In [None]:
# Make the parent of the current working directory importable, so that
# project-local packages (presumably `notebooks` and `ulib`, imported in later
# cells) can be resolved. The membership guard keeps re-runs from adding
# duplicate sys.path entries.
import os
import sys

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import torch
from notebooks.experiment_robust import load_robust_experiment
from notebooks.experiment_torch import load_torchvision_experiment  # alternative loader; not used in this notebook

# Experiment configuration: change these to evaluate a different model / dataset.
MODEL_NAME = "Salman2020Do_R50"
DATASET_NAME = "imagenet"
BATCH_SIZE = 128

# Only `model` and `dl_eval` are used below; `dl_train` is returned by the loader but unused here.
model, dl_train, dl_eval = load_robust_experiment(MODEL_NAME, DATASET_NAME, batch_size=BATCH_SIZE)

In [None]:
# Display the loaded model's architecture (the cell's last expression is rendered).
model

In [None]:
import torchattacks

# Shared attack budget, used by every candidate below so results are comparable.
# Values are written as x/255, i.e. presumably in [0, 1] pixel units (the
# evaluation cell clamps adversarial inputs to [0, 1]).
attk_eps = 16 / 255
attk_alpha = 8 / 255
attk_steps = 5

# Exactly one line should be active. The trailing notes are the author's
# observed speed/quality for this setup.
# NOTE(review): the evaluation cell re-projects perturbations onto the L-inf
# ball of radius attk_eps. That matches L-inf attacks, but changes the output of
# L2-bounded attacks (PGDL2, PGDRSL2) and of attacks that are not eps-bounded
# (e.g. CW, DeepFool, EADEN), so their numbers are not directly comparable.
# torch_attack = torchattacks.FGSM(model, eps=attk_eps) # very fast; good
# torch_attack = torchattacks.FFGSM(model, eps=attk_eps, alpha=attk_alpha)  # very fast; very good
# torch_attack = torchattacks.DIFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast
# torch_attack = torchattacks.RFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast; very good
# torch_attack = torchattacks.BIM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast
# torch_attack = torchattacks.TIFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast
# torch_attack = torchattacks.MIFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps)  # fast
# torch_attack = torchattacks.NIFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast
# torch_attack = torchattacks.PGDL2(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast
# torch_attack = torchattacks.PGD(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast
# torch_attack = torchattacks.TPGD(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast
torch_attack = torchattacks.UPGD(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps, loss="dlr") # fast; very good
# torch_attack = torchattacks.APGD(model, eps=attk_eps, steps=attk_steps, loss="dlr") # fast; good
# torch_attack = torchattacks.Jitter(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # fast; good
# torch_attack = torchattacks.PIFGSM(model, max_epsilon=attk_eps, num_iter_set=attk_steps) # normal
# torch_attack = torchattacks.PIFGSMPP(model, max_epsilon=attk_eps, num_iter_set=attk_steps) # normal
# torch_attack = torchattacks.PGDRS(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # normal
# torch_attack = torchattacks.EOTPGD(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # normal
# torch_attack = torchattacks.PGDRSL2(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # slow
# torch_attack = torchattacks.SINIFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # slow
# torch_attack = torchattacks.VMIFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # very slow
# torch_attack = torchattacks.EADEN(model, lr=attk_alpha, max_iterations=attk_steps) # very slow
# torch_attack = torchattacks.VNIFGSM(model, eps=attk_eps, alpha=attk_alpha, steps=attk_steps) # very slow
# torch_attack = torchattacks.APGDT(model, eps=attk_eps, steps=attk_steps) # very slow
# torch_attack = torchattacks.SPSA(model, eps=attk_eps, lr=attk_alpha) # very slow
# torch_attack = torchattacks.FAB(model, eps=attk_eps, steps=attk_steps) # very slow
# torch_attack = torchattacks.CW(model, steps=attk_steps, lr=attk_alpha) # very slow
# torch_attack = torchattacks.AutoAttack(model, eps=attk_eps) # super slow
# torch_attack = torchattacks.Square(model, eps=attk_eps) # super slow
# torch_attack = torchattacks.DeepFool(model, steps=attk_steps) # super slow

In [None]:
from ulib.utils.torch import extract_device
from tqdm.auto import tqdm


device = extract_device(model)
# torch_attack = torchattacks.CW(model, steps=5, lr=attk_eps / 3)

# The adversarial predictions below must be made with inference-time behaviour
# (e.g. BatchNorm running statistics, Dropout disabled). Whether the loader
# already returned the model in eval mode is not visible here, so force it.
model.eval()

# Running counts over the evaluation set: every sample whose post-attack
# prediction differs from its label counts as misclassified (including samples
# that were already misclassified before the attack).
num_total = 0
num_misclassified = 0

for images, labels in tqdm(dl_eval, desc="attacking"):
    with torch.device(device):
        images = images.to(device)
        labels = labels.to(device)

        # Call the attack object instead of `.forward(...)`: torchattacks'
        # `Attack.__call__` validates that inputs lie in [0, 1], puts the model
        # in the attack's configured train/eval mode for the duration of the
        # attack and restores the previous mode afterwards. `.forward` skips all of that.
        adv_images = torch_attack(images, labels)

        with torch.inference_mode():
            # Re-project onto the L-inf ball of radius attk_eps around the clean
            # images and onto the valid [0, 1] pixel range.
            # NOTE(review): this is a no-op for well-behaved L-inf attacks but
            # alters the output of L2 / unbounded attacks.
            pert = torch.clamp(adv_images - images, -attk_eps, attk_eps)
            adv_input = torch.clamp(images + pert, 0, 1)
            adv_preds = model(adv_input).argmax(dim=1)

        num_total += labels.size(0)
        num_misclassified += (adv_preds != labels).sum().item()

In [None]:
# Fraction of evaluated samples misclassified after the attack. Clean
# misclassifications are included, because the loop does not filter them out.
# Guard against an empty evaluation loader instead of raising ZeroDivisionError.
adv_error_rate = num_misclassified / num_total if num_total else float("nan")
print(f"adversarial error rate: {adv_error_rate:.4f} ({num_misclassified}/{num_total})")