In [1]:
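# Put the notebook's parent directory on sys.path so project packages (e.g. `notebooks`, `ulib`) can be imported.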
import os
import sys

# Absolute path of the parent of the current working directory; appended to
# sys.path once (the membership guard keeps re-runs of this cell idempotent).
# Presumably the kernel starts inside notebooks/ and this is the repo root — confirm.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import torch
from torch import nn, Tensor
from notebooks.analyze_utils import *
import mealpy
from mealpy import FloatVar, GA, BoolVar
import numpy as np
from analyze_utils import *

In [None]:
import torch
from notebooks.experiment_robust import load_robust_experiment
from notebooks.experiment_torch import load_torchvision_experiment

# Select the victim model and its train/eval loaders.
# Default: load_robust_experiment("Standard", "cifar10") — presumably a RobustBench
# model id, confirm. Set USE_TORCHVISION = True to load torchvision's "vgg16" instead
# (replaces the previously commented-out alternative).
USE_TORCHVISION = False

if USE_TORCHVISION:
    model, dl_train, dl_eval = load_torchvision_experiment("vgg16")
else:
    model, dl_train, dl_eval = load_robust_experiment("Standard", "cifar10")

In [None]:
from typing import Iterable
from ulib.pert_module import PertModule
from ulib import eval
from mealpy import Problem


class BinaryObjective(Problem):
    """Mealpy problem that searches for a sign pattern of a fixed-magnitude perturbation.

    Every element of the perturbation is one boolean decision variable: bit 0 maps
    to ``-eps`` and bit 1 maps to ``+eps`` (``eps`` taken from ``pert_model``).
    The objective is the accuracy of the perturbed model on ``dl_train``, which
    mealpy minimizes (``minimax="min"``).

    Args:
        pert_model: wrapper whose perturbation is installed via ``set_pert``; its
            ``shape``, ``device`` and ``eps`` attributes define the search space.
        dl_train: batches passed to ``eval.accuracy`` to score each candidate.
    """

    def __init__(
        self,
        pert_model: PertModule,
        dl_train: Iterable[tuple[torch.Tensor, ...]],
    ):
        self.pert_model = pert_model
        self.dl_train = dl_train
        self.shape = self.pert_model.shape

        # One boolean decision variable per perturbation element.
        super().__init__(
            bounds=BoolVar(n_vars=np.prod(self.shape).item()),
            minimax="min",
        )

    def rescale(self, pert_np: np.ndarray) -> torch.Tensor:
        """Map a mealpy solution vector to a perturbation tensor with entries in {-eps, +eps}.

        Args:
            pert_np: flat solution vector. Either mealpy's *encoded* position (what
                ``obj_func`` receives and what ``g_best.solution`` stores) or an
                already-decoded bool / 0-1 array.

        Returns:
            Float32 tensor on ``pert_model.device``, reshaped to ``self.shape``.
        """
        # mealpy's BoolVar searches over floats in [0, 2) and decodes by truncation
        # (values < 1 -> False, >= 1 -> True). Using the encoded floats directly would
        # give non-binary perturbations up to 3*eps. Thresholding at 1.0 reproduces the
        # decoding and leaves decoded bool/0-1 inputs unchanged.
        bits = np.asarray(pert_np, dtype=np.float64) >= 1.0
        pert = torch.tensor(bits, dtype=torch.float32, device=self.pert_model.device).view(self.shape)
        pert = 2 * pert - 1  # {0, 1} -> {-1, +1}
        pert = pert * self.pert_model.eps  # {-1, +1} -> {-eps, +eps}
        return pert

    def obj_func(self, x: np.ndarray) -> float:
        """Install the candidate perturbation and return the model's accuracy on ``dl_train``."""
        pert = self.rescale(x)
        # Mutates the shared pert_model; safe with mode="single", but would race under
        # mealpy's "thread"/"process" modes.
        self.pert_model.set_pert(pert)
        acc = eval.accuracy(self.pert_model, self.dl_train)
        # float() so mealpy receives a plain number even if accuracy comes back as a tensor.
        return float(acc)

In [None]:
# Attack budget: every perturbation element has magnitude at most EPS under the
# L-inf norm; inputs are declared to lie in [0, 1].
EPS = 8 / 255

# NOTE(review): `dl_train.data` is not an attribute of a stock torch DataLoader —
# presumably the loader returned by load_robust_experiment exposes it; confirm.
data_shape = dl_train.data[0].shape

pert_model = PertModule(
    model,
    data_shape=data_shape,
    eps=EPS,
    norm=float("inf"),
    random_init=False,
    input_range=(0, 1),
)

# Search problem over perturbation signs, solved with mealpy's original PSO.
problem = BinaryObjective(pert_model, dl_train)
optimizer = mealpy.PSO.OriginalPSO()

In [None]:
# Run the search sequentially (mode="single") under a wall-clock budget.
# The seed fixes mealpy's random state, but because termination is time-based the
# number of completed iterations (and thus the result) can still vary between runs.
TIME_BUDGET_S = 2 * 60  # seconds
SEED = 42

g_best = optimizer.solve(
    problem=problem,
    mode="single",
    termination={"max_time": TIME_BUDGET_S},
    seed=SEED,
)

In [None]:
from ulib import eval

# Install the best perturbation found by the search into the wrapped model, then
# evaluate on dl_eval (the search itself only scored candidates on dl_train).
pert_model = problem.pert_model
best_pert = problem.rescale(g_best.solution)
pert_model.set_pert(best_pert)
eval.full_analysis(pert_model, dl_eval)