In [1]:
# Put the parent directory on sys.path so project packages can be imported from this notebook.
import os
import sys

# Resolve the directory one level above the current working directory and
# register it on sys.path only if it is absent, so re-running this cell does
# not pile up duplicate entries.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
import torch
from notebooks.experiment_robust import load_robust_experiment
from notebooks.experiment_torch import load_torchvision_experiment

# Pick the experiment for this run. The loader's result is unpacked into the
# model plus the two objects later passed to attack.fit() as train/eval data
# (presumably data loaders — confirm against notebooks.experiment_robust).
robust_experiment = load_robust_experiment("Standard", "cifar10")
model, dl_train, dl_eval = robust_experiment
# Alternative setup, kept for quick switching to the torchvision loader:
# model, dl_train, dl_eval = load_torchvision_experiment("vgg16")

In [None]:
class Negate(torch.nn.Module):
    """Wrap a module and flip the sign of whatever it returns.

    Minimising the output of ``Negate(m)`` is equivalent to maximising the
    output of ``m``, which makes the wrapper handy around loss modules.
    """

    def __init__(self, model: torch.nn.Module):
        """Keep a reference to the wrapped module.

        Args:
            model: Module whose forward result will be negated.
        """
        super().__init__()
        self.model = model

    def forward(self, *inputs, **options):
        """Forward every argument to the wrapped module and negate the result.

        Args:
            *inputs: Positional arguments passed through unchanged.
            **options: Keyword arguments passed through unchanged.

        Returns:
            The wrapped module's output with its sign inverted.
        """
        wrapped_output = self.model(*inputs, **options)
        return -wrapped_output

In [None]:
from ulib.attack import StopCriteria
from ulib.attacks.ufgsm import UFGSM
from ulib.pert_module import PertModule

# Tunable settings for this run, named once so the wiring below stays readable.
EPS = 8 / 255
LEARNING_RATE = 1e-3
MAX_EPOCHS = 10
MAX_TIME = 600  # units are whatever StopCriteria expects — TODO confirm (seconds?)

stop = StopCriteria(max_epochs=MAX_EPOCHS, max_time=MAX_TIME)

# data_shape drops the leading dimension of the first training tensor
# (presumably the batch axis — confirm against the loader's get_tensor()).
sample_shape = dl_train.get_tensor(0).shape[1:]
pert_model = PertModule(model, data_shape=sample_shape, eps=EPS)

optimizer = torch.optim.SGD(pert_model.parameters(), lr=LEARNING_RATE)
# Negated cross-entropy: minimising this criterion maximises the cross-entropy.
criterion = Negate(torch.nn.CrossEntropyLoss())
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=20,
    T_mult=1,
    eta_min=5e-5,
)
# The scaler is created for the same device type the perturbation model reports.
grad_scaler = torch.GradScaler(device=pert_model.device.type)

attack = UFGSM(
    pert_model=pert_model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    grad_scaler=grad_scaler,
)

# Fit on the training data, passing the eval data and the stop criteria along;
# the return value is kept as `pert`.
pert = attack.fit(dl_train, dl_eval, stop)

In [6]:
# Shut the attack object down once fitting is finished; what close() actually
# releases is defined in ulib — TODO confirm.
attack.close()

In [None]:
# Import the evaluation module under an alias: binding it as plain `eval`
# would shadow Python's built-in eval() for every later cell in this kernel.
from ulib import eval as ulib_eval

# Run ulib's full analysis on the perturbation model using the evaluation data.
ulib_eval.full_analysis(pert_model, dl_eval)