In [None]:
# Make the parent directory importable so that the packages living there
# (presumably `notebooks` and `ulib`, imported in the cells below) resolve.
import os
import sys

module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

In [None]:
import torch
from notebooks.experiment_robust import load_robust_experiment
from notebooks.experiment_torch import load_torchvision_experiment

# Seed torch's global RNG (CPU and all CUDA devices) before anything is built,
# so that whatever draws from it downstream is repeatable across
# Restart & Run All. For example, data shuffling draws from it if the loaders
# use the global generator — NOTE(review): confirm.
SEED = 0
torch.manual_seed(SEED)

# Choose which experiment to attack. The default is the "Standard" cifar10
# experiment. Set USE_TORCHVISION = True to load the torchvision model named
# by TORCHVISION_MODEL instead.
USE_TORCHVISION = False
TORCHVISION_MODEL = "vgg16"

if USE_TORCHVISION:
    model, dl_train, dl_eval = load_torchvision_experiment(TORCHVISION_MODEL)
else:
    model, dl_train, dl_eval = load_robust_experiment("Standard", "cifar10")

In [None]:
from ulib import StopCriteria, PertModule
from ulib.attacks.ae_uap import AE_UAP, AE_MIFGSM

# ---- Hyperparameters (tune here) -------------------------------------------
# Perturbation budgets. NOTE(review): presumably L-inf radii on [0, 1]-scaled
# pixels (hence the /255) — confirm against PertModule / AE_MIFGSM.
eps = 8 / 255    # budget of the universal perturbation trained by AE_UAP
eps2 = 4 / 255   # budget of the inner AE_MIFGSM attack
INNER_STEPS = 10
INNER_ALPHA = eps2 / 10  # per-step size; with 10 steps the steps add up to eps2
INNER_DECAY = 1          # momentum decay factor passed to AE_MIFGSM
gamma = 0.5              # weighting passed to AE_UAP; its meaning is defined in ulib

LEARNING_RATE = 0.005
SCHEDULER_T0 = 20        # CosineAnnealingWarmRestarts period, in scheduler.step() calls
SCHEDULER_T_MULT = 1
MAX_EPOCHS = 5
MAX_TIME = 600           # NOTE(review): unit presumably seconds — confirm in StopCriteria
EVAL_FREQ = 1

# NOTE(review): AE_UAP decides when scheduler.step() is called. If it steps once
# per epoch, then T_0 = 20 exceeds MAX_EPOCHS = 5. The cosine schedule would
# never restart and would only cover part of its first cycle — confirm intent.

stop = StopCriteria(max_epochs=MAX_EPOCHS, max_time=MAX_TIME)

# Inner attack that AE_UAP combines with the universal perturbation.
torch_attack = AE_MIFGSM(
    model,
    eps=eps2,
    alpha=INNER_ALPHA,
    steps=INNER_STEPS,
    decay=INNER_DECAY,
)

# Learnable perturbation wrapped around the model, starting from a
# non-random init. NOTE(review): assumes get_tensor(0) returns a batched
# tensor, so shape[1:] is the per-sample shape — confirm.
pert_model = PertModule(
    model,
    data_shape=dl_train.get_tensor(0).shape[1:],
    eps=eps,
    random_init=False,
)

# Only the perturbation's parameters are optimised.
optimizer = torch.optim.Adam(pert_model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=SCHEDULER_T0, T_mult=SCHEDULER_T_MULT
)
# Gradient scaler created for the device type the perturbation lives on.
grad_scaler = torch.GradScaler(device=pert_model.device.type)

attack = AE_UAP(
    pert_model=pert_model,
    optimizer=optimizer,
    inner_attack=torch_attack,
    gamma=gamma,
    scheduler=scheduler,
    grad_scaler=grad_scaler,
    eval_freq=EVAL_FREQ,
)

# `optimizer` updates pert_model's parameters during fit(), so later cells can
# read the trained state from `pert_model`. The return value is kept as `pert`.
pert = attack.fit(dl_train, dl_eval, stop)

In [None]:
# Tear down the AE_UAP object once training is finished. What close() releases
# is defined in ulib — presumably logging/writer handles; NOTE(review): confirm.
# The next cell reads only `pert_model` and `dl_eval`. NOTE(review): confirm
# that close() leaves pert_model usable for that evaluation.
attack.close()

In [None]:
from ulib.evaluator import ExtendedEvaluator
from ulib.utils.visualize import display_pert

# Score the trained perturbation on dl_eval, print the result, then render
# the perturbation itself as the cell's final output.
evaluator = ExtendedEvaluator(pert_model, verbose=True)
eval_results = evaluator.evaluate(dl_eval)
print(eval_results)
display_pert(pert_model)