In [None]:
import foolbox as fb
import functorch
import math
import matplotlib.pyplot as plt
import numpy as np
import piqa
import random
import torch
import torchopt
import torch.nn as nn
import torch.nn.functional as F
import warnings
from attack import (
    reconstruct_interactions,
    optimize_image_manipulation_batches,
)
from PIL import Image
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageNet
from torchvision.models import (
    regnet_y_400mf, RegNet_Y_400MF_Weights,
    resnet18, ResNet18_Weights,
    mobilenet_v3_small, MobileNet_V3_Small_Weights,
)
from tqdm.notebook import tqdm
from utils import (
    Metrics,
    apply_gaussian_mechanism,
)
# Suppress warnings whose message mentions make_functional
# (functorch.make_functional is called inside train()).
warnings.filterwarnings("ignore", message='.*make_functional.*')

In [None]:
def set_seed(seed=2023):
    """Seed the `random`, NumPy and PyTorch global generators with one value.

    Args:
        seed: Integer seed handed to every generator (defaults to 2023).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

# PIL image -> float tensor, followed by per-channel standardisation
# with the mean/std constants listed below.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Inverse of the Normalize step above: using mean = -m/s and std = 1/s
# turns (y - mean) / std into y * s + m, i.e. back to the pre-normalised range.
unnormalize = transforms.Normalize(
   mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
   std=[1/0.229, 1/0.224, 1/0.225]
)

def tensor2img(tensor):
    """Undo the normalisation and move the channel axis last.

    Args:
        tensor: Normalised image tensor; a 3-D input is treated as one
            channel-first image, anything else as a batch of them.

    Returns:
        The unnormalised tensor permuted to (1, 2, 0) for a 3-D input and
        to (0, 2, 3, 1) otherwise.
    """
    restored = unnormalize(tensor)
    if tensor.ndim == 3:
        return restored.permute(1, 2, 0)
    return restored.permute(0, 2, 3, 1)

def display_tensor_as_img(t, unnormalize=True, size=1):
    """Show a tensor as an axis-free matplotlib image.

    Args:
        t: Image tensor to display.
        unnormalize: When true, pass `t` through `tensor2img` first;
            otherwise `t` is plotted as given.
        size: Width and height of the (square) figure.
    """
    plt.figure(figsize=(size, size))
    if unnormalize:
        pixels = tensor2img(t).cpu()
    else:
        pixels = t.cpu()
    plt.imshow(pixels, aspect="auto")
    plt.axis("off")
    plt.tight_layout(pad=0)
    plt.show()

# Supported backbones: name -> (constructor, pretrained weights).
model_config = {
    "RegNet_Y_400MF": (regnet_y_400mf, RegNet_Y_400MF_Weights.IMAGENET1K_V2),
    "ResNet18": (resnet18, ResNet18_Weights.IMAGENET1K_V1),
    "MobileNet_V3_Small": (mobilenet_v3_small, MobileNet_V3_Small_Weights.IMAGENET1K_V1),
}
chosen_model = "RegNet_Y_400MF"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weights = model_config[chosen_model][1]
cv_model = model_config[chosen_model][0](weights=weights).to(device).eval()
# Directory of the pre-optimised images. It follows `chosen_model` so that
# switching the backbone cannot silently reuse another model's images.
opt_basepath = f"../dataset/CV/{chosen_model}"
# Foolbox wrapper used by the adversarial attacks; the bounds are given in
# normalised pixel space.
fb_cv_model = fb.PyTorchModel(cv_model, bounds=(-2.65, 2.65), device=device)
# Every child module of the classifier except the last one.
feature_extractor = nn.Sequential(*(list(cv_model.children())[:-1])).to(device).eval()
# Freeze all weights: nothing in this notebook trains the backbone.
for p in cv_model.parameters():
    p.grad = None
    p.requires_grad_(False)
# feature_extractor reuses cv_model's child modules, so this loop revisits
# parameters already frozen above; kept as an explicit safeguard.
for p in feature_extractor.parameters():
    p.grad = None
    p.requires_grad_(False)
# Class id -> human-readable category name, taken from the weights metadata.
imagenet_categories = {i: v for i, v in enumerate(weights.meta["categories"])}

def dataset_with_indices(cls):
    """
    Build a subclass of the given Dataset class whose items are
    (data, target, index) triples instead of (data, target) pairs.

    The generated subclass carries the same name as the class it extends.
    """
    def __getitem__(self, index):
        # Fetch the sample from the wrapped class, then tag it with its index.
        sample = cls.__getitem__(self, index)
        data, target = sample
        return (data, target, index)

    namespace = {'__getitem__': __getitem__}
    return type(cls.__name__, (cls,), namespace)

# ImageNet validation split whose items also carry their dataset index;
# preprocessing is the transform bundled with the chosen weights.
ImageNetWithIndices = dataset_with_indices(ImageNet)
data = ImageNetWithIndices("../dataset/ImageNet/", split="val", transform=weights.transforms())
# Subset covering dataset positions 151*50 .. 269*50 - 1.
# NOTE(review): presumably the dog classes (ids 151-268) at 50 validation
# images per class, assuming samples are ordered by class — confirm.
data_dog = Subset(data, list(range(151*50, 269*50)))
# Sanity check on the first sample: show it and compare the model's
# prediction with the ground-truth label.
image, label, _ = data[0]
print(imagenet_categories[label])
display_tensor_as_img(image)
with torch.no_grad():
    image = image.unsqueeze(0).to(device)
    logits = cv_model(image)
    print(f"Prediction: {imagenet_categories[logits.argmax().cpu().item()]} | Actual: {imagenet_categories[label]}")
    # Size of dimension 1 of the extractor output; reused as the feature
    # count throughout the notebook.
    num_features = feature_extractor(image).shape[1]
    print(f"Number of features extracted: {num_features}")

In [None]:
# Perturb each image and save to disk
# Seed first so the random feature targets drawn during generation are reproducible.
set_seed(2023)

# Number of images per DataLoader batch; also forwarded as `batch_size`
# to optimize_image_manipulation_batches.
batch_size = 256
def extract_features(images):
    """Return the backbone features of `images` as a 2-D (batch, feature) tensor.

    Flattening from dim 1 instead of calling `squeeze()` keeps the batch
    dimension when the batch holds a single image; a bare `squeeze()` would
    drop it and make the column slice below fail.

    Args:
        images: Batch of preprocessed images accepted by `feature_extractor`.

    Returns:
        The first `num_features` columns of the flattened extractor output.
    """
    return feature_extractor(images).flatten(start_dim=1)[:, :num_features]

def generate_optimized_data(data, base_path):
    """Optimise every image of `data` towards random feature targets and save it.

    Images are processed in dataset order (no shuffling) and written to
    `{base_path}/{k}.jpg`, where `k` is a running counter; with an unshuffled
    loader this counter equals the sample's position in `data`.

    Args:
        data: Dataset yielding (image, target, index) triples.
        base_path: Output directory; created if it does not exist yet.
    """
    import os  # local import keeps this block self-contained

    os.makedirs(base_path, exist_ok=True)
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
    idx = 0
    for images, _, _ in tqdm(dataloader):
        num_items = images.shape[0]
        images = images.to(device)

        optimized_images = optimize_image_manipulation_batches(
            images,
            # Random feature targets: N(0, 4) samples with negatives clipped to zero.
            torch.normal(0.0, 4.0, (num_items, num_features), device=device).clip(min=0.0),
            extract_features,
            batch_size=batch_size,
            max_epochs=500,
            loss_fn=F.mse_loss,
            lr=1e-02,
            linf_factor=0.0,
            progress_bar=False,
        )

        for i in range(num_items):
            # When converting from PIL to tensor this results in smaller MSE (around 2.5e-05)
            # than using transforms.ToPILImage with unnormalize (1e-04)
            pixels = (255 * tensor2img(optimized_images[i].cpu())).numpy().round()
            # Clip before the uint8 cast: values outside [0, 255] would
            # otherwise wrap around (e.g. 256 -> 0) and corrupt the pixel.
            save_img = Image.fromarray(pixels.clip(0, 255).astype(np.uint8))
            save_img.save(f"{base_path}/{idx}.jpg")
            idx += 1

# Write an optimised copy of every validation image into opt_basepath.
generate_optimized_data(data, opt_basepath)

In [None]:
def measure_similarity(orig_data, target_path, max_items=None):
    """Compute mean PSNR and SSIM between original and stored images.

    Image `i` of `orig_data` is compared with `{target_path}/{i}.jpg`.

    Args:
        orig_data: Dataset whose items start with a normalised image tensor.
        target_path: Directory holding the images to compare against.
        max_items: Optional cap on the number of leading images compared,
            for quick spot checks; `None` (default) compares all of them.

    Returns:
        Tuple `(mean_psnr, mean_ssim)` over the compared images.
    """
    psnr = piqa.PSNR()
    ssim = piqa.SSIM()
    to_tensor = transforms.ToTensor()
    psnrs = []
    ssims = []

    total = len(orig_data)
    num_items = total if max_items is None else min(max_items, total)

    for i in tqdm(range(num_items)):
        # Clamp guards against float round-off from unnormalize pushing
        # values marginally outside [0, 1], which the metrics are presumed
        # to require — see piqa's input checks.
        img = unnormalize(orig_data[i][0]).clamp(0.0, 1.0).unsqueeze(0)
        # Context manager releases the file handle after decoding.
        with Image.open(f"{target_path}/{i}.jpg") as target_file:
            target_img = to_tensor(target_file).unsqueeze(0)
        psnrs.append(psnr(img, target_img).item())
        ssims.append(ssim(img, target_img).item())

    return np.mean(psnrs), np.mean(ssims)

# Report (mean PSNR, mean SSIM) between the originals and the optimised images on disk.
print(measure_similarity(data, opt_basepath))

In [None]:
# Seed all RNGs for the attack-simulation cell.
set_seed(2024)

# Factories for the models trained on the extracted features. Each call
# builds a freshly initialised, bias-free model with one output logit on `device`.
model_fns = {
    "linear": lambda: nn.Linear(num_features, 1, bias=False).to(device),
    "neural2": lambda: nn.Sequential(
        nn.Linear(num_features, 2, bias=False),
        nn.ReLU(),
        nn.Linear(2, 1, bias=False)
    ).to(device),
    "neural3": lambda: nn.Sequential(
        nn.Linear(num_features, 3, bias=False),
        nn.ReLU(),
        nn.Linear(3, 1, bias=False)
    ).to(device),
}

# Number of repetitions of the outer simulation loop.
num_sim = 10
# Batch sizes requested per simulated query: 1x, 2x and 3x the feature count.
num_items_per_query = [num_features * 1, num_features * 2, num_features * 3]
# SGD learning rate and number of update steps used by train().
local_lr = 1e-02
local_epoch = 5

# Reconstruction settings, forwarded to reconstruct_interactions as
# num_rounds, max_iter and lr respectively.
num_atk = 1
max_iter = 1000
atk_lr = 0.1

# Arguments passed to apply_gaussian_mechanism.
# NOTE(review): presumably (epsilon, delta) privacy budgets, with math.inf
# meaning no noise — confirm against utils.apply_gaussian_mechanism.
epsilons = [1.0, 20.0, 100.0, 500.0, math.inf]
delta = 1e-08
sensitivity = 0.05

# Foolbox attacks, keyed by the name used in metric labels, and the
# perturbation budgets they are run with in simulate_attack.
adv_attacks = {
    "FGSM": fb.attacks.FGSM(),
}
adv_epsilons = [0.1]

# Collects the results of every simulated attack (updated in simulate_attack).
metrics = Metrics()

def extract_features(images):
    """Return the backbone features of `images` as a 2-D (batch, feature) tensor.

    Flattening from dim 1 instead of calling `squeeze()` keeps the batch
    dimension when the batch holds a single image; a bare `squeeze()` would
    drop it and make the column slice below fail.

    Args:
        images: Batch of preprocessed images accepted by `feature_extractor`.

    Returns:
        The first `num_features` columns of the flattened extractor output.
    """
    return feature_extractor(images).flatten(start_dim=1)[:, :num_features]

def train(model, features, interactions):
    """Run `local_epoch` functional SGD steps and return the parameter change.

    The model itself is not modified: its parameters are extracted with
    `functorch.make_functional` and updated out of place.

    Args:
        model: Module producing one logit per row of `features`.
        features: Input features for the model.
        interactions: Binary-cross-entropy targets, one per row of `features`.

    Returns:
        1-D tensor: flattened starting parameters minus flattened updated
        parameters.
    """
    func_model, initial_params = functorch.make_functional(model)
    optimizer = torchopt.FuncOptimizer(torchopt.sgd(lr=local_lr))

    current_params = initial_params
    for _ in range(local_epoch):
        logits = func_model(current_params, features).view(-1)
        loss = F.binary_cross_entropy_with_logits(logits, interactions)
        current_params = optimizer.step(loss, current_params)

    flat_before = torch.cat([p.view(-1) for p in initial_params])
    flat_after = torch.cat([p.view(-1) for p in current_params])
    return flat_before - flat_after

def adv_perturb(fmodel, images, labels, attack, batch_size=None, **kwargs):
    """Run a foolbox attack, optionally in mini-batches, and report clean accuracy.

    Args:
        fmodel: Foolbox model wrapper handed to the attack.
        images: Batch of input images.
        labels: Labels matching `images`.
        attack: Callable invoked as `attack(fmodel, images, labels, **kwargs)`
            and returning a `(raw, clipped, success)` triple.
        batch_size: If given, the attack is run on chunks of at most this
            many images and the per-chunk results are concatenated.
        **kwargs: Forwarded to `attack` unchanged (e.g. `epsilons`).

    Returns:
        Tuple `(clipped_advs, success, clean_acc)`. `clipped_advs` and
        `success` have the same structure the attack itself returns: a list
        with one tensor per epsilon when `epsilons` is a sequence, a single
        tensor when it is a scalar.

    Raises:
        ValueError: If `batch_size` is given but smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    num_images = images.shape[0]
    if batch_size is None or num_images <= batch_size:
        # One call suffices; this also covers an empty input, for which the
        # batched path would have no results to combine.
        _, clipped_advs, success = attack(fmodel, images, labels, **kwargs)
    else:
        advs_per_batch = []
        success_per_batch = []
        for start_idx in range(0, num_images, batch_size):
            end_idx = start_idx + batch_size
            _, batch_clipped_advs, batch_success = attack(
                fmodel, images[start_idx:end_idx], labels[start_idx:end_idx], **kwargs
            )
            advs_per_batch.append(batch_clipped_advs)
            success_per_batch.append(batch_success)
            torch.cuda.empty_cache()

        if isinstance(advs_per_batch[0], torch.Tensor):
            # Scalar `epsilons`: each call returned one tensor of images, so
            # the chunks are joined directly along the batch axis.
            clipped_advs = torch.cat(advs_per_batch, dim=0)
        else:
            # Sequence `epsilons`: each call returned one tensor per epsilon;
            # regroup by epsilon and join the chunks along the batch axis.
            clipped_advs = [torch.cat(per_eps, dim=0) for per_eps in zip(*advs_per_batch)]
        # The image axis is last for both the 1-D (scalar epsilon) and the
        # 2-D (epsilon x image) success masks.
        success = torch.cat(success_per_batch, dim=-1)

    clean_acc = fb.accuracy(fmodel, images, labels)
    return clipped_advs, success, clean_acc

def load_opt_images(indices):
    """Load the stored optimised images for `indices` as one stacked tensor.

    Args:
        indices: Iterable of image ids; id `i` is read from
            `{opt_basepath}/{i}.jpg` and passed through `normalize`.

    Returns:
        Tensor stacking the normalised images in the order of `indices`.
    """
    return torch.stack([
        normalize(Image.open(f"{opt_basepath}/{i}.jpg")) for i in indices
    ])

def simulate_attack(data, epsilons, num_items):
    """Simulate one query and attempt to reconstruct its interactions.

    A random batch is drawn from `data` and given random binary interactions.
    For every model in `model_fns` and every feature variant (clean,
    adversarially perturbed, pre-optimised) the parameter update produced by
    `train` is computed, passed through `apply_gaussian_mechanism` for each
    epsilon, and handed to `reconstruct_interactions`. The outcome is
    recorded in the global `metrics` object.

    Args:
        data: Dataset yielding (image, label, index) triples.
        epsilons: Values passed one by one to `apply_gaussian_mechanism`.
        num_items: Requested number of images in the batch.
    """
    # Draw a single shuffled batch; the dataset indices locate the matching
    # pre-optimised images on disk.
    images, labels, indices = next(iter(DataLoader(data, batch_size=num_items, shuffle=True)))
    images, labels = images.to(device), labels.to(device)
    # Continue with the size of the batch actually returned.
    num_items = images.shape[0]
    # Random 0/1 interactions are what the attack later tries to recover.
    target_interactions = torch.randint(0, 2, (num_items,)).float().to(device)

    # Feature matrix per data variant, keyed by the label used in the metrics.
    features_by_variant = {"no_adm": extract_features(images)}

    for attack_name, attack in adv_attacks.items():
        per_eps_adv_images, _, _ = adv_perturb(
            fb_cv_model, images, labels, attack, batch_size=16, epsilons=adv_epsilons
        )
        for eps, adv_images in zip(adv_epsilons, per_eps_adv_images):
            features_by_variant[f"adm_{attack_name}_{eps}"] = extract_features(adv_images)

    # The loaded images are a temporary and are released once the features exist.
    features_by_variant["adm_opt"] = extract_features(load_opt_images(indices).to(device))

    for model_name, model_fn in model_fns.items():
        model = model_fn()
        # Noise-free parameter update for every variant, same model throughout.
        updates_by_variant = {}
        for key, variant_features in features_by_variant.items():
            updates_by_variant[key] = train(model, variant_features, target_interactions).detach()

        for epsilon in epsilons:
            for key, raw_target in updates_by_variant.items():
                noisy_update = apply_gaussian_mechanism(
                    raw_target.detach().cpu(), epsilon, delta, sensitivity
                )
                target = noisy_update.to(device)
                train_features = features_by_variant[key]

                def scaled_update(candidate):
                    # Update for candidate interactions, divided by local_lr
                    # exactly like the target below.
                    return train(model, train_features, candidate) / local_lr

                preds_raw, _ = reconstruct_interactions(
                    scaled_update,
                    target / local_lr,
                    num_items,
                    lr=atk_lr,
                    max_iter=max_iter,
                    num_rounds=num_atk,
                    return_raw=True,
                    device=device,
                )
                preds_raw = preds_raw.detach().cpu()
                # Raw outputs -> probabilities -> hard 0/1 predictions.
                preds = preds_raw.sigmoid().round().long()

                metrics.update(
                    f"{model_name}_{num_items}_items_eps_{epsilon}_{key}",
                    target_interactions.cpu(),
                    preds,
                    preds_raw=preds_raw,
                )

# Repeat the whole experiment num_sim times, once per query size per
# repetition; results accumulate in `metrics`.
for _ in tqdm(range(num_sim)):
    for num_items in num_items_per_query:
        simulate_attack(data, epsilons, num_items)
        # Release cached GPU memory between simulations.
        torch.cuda.empty_cache()

metrics.print_summary()
# Optional: persist the collected metrics to CSV.
# metrics.save("../output/ltr_cv_metrics.csv")