In [1]:
import functorch
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torchopt
import torch.nn as nn
import torch.nn.functional as F
import warnings
from attack import (
    reconstruct_interactions,
)
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageNet
from torchvision.models import regnet_y_400mf, RegNet_Y_400MF_Weights, regnet_x_400mf, resnet18, mobilenet_v3_small
from tqdm.notebook import tqdm
from utils import (
    Metrics,
    apply_gaussian_mechanism,
)
warnings.filterwarnings("ignore", message='.*make_functional.*')

In [None]:
def set_seed(seed=2023):
    """Seed every random number generator the notebook relies on.

    Covers PyTorch (``torch.manual_seed``), Python's ``random`` module and
    NumPy's legacy global generator. This makes the DataLoader shuffling, the
    random interaction labels and the model initialisation reproducible.

    Args:
        seed: integer seed applied to all three generators (default ``2023``).
    """
    seeders = (torch.manual_seed, random.seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)

# Per-channel statistics that the inverse transform below undoes (presumably the
# normalisation applied by `weights.transforms()`; confirm if the weights change).
_imagenet_mean = (0.485, 0.456, 0.406)
_imagenet_std = (0.229, 0.224, 0.225)

# Normalize computes (x - m') / s'. With m' = -mean / std and s' = 1 / std this
# gives x * std + mean, i.e. the inverse of (x - mean) / std.
unnormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_imagenet_mean, _imagenet_std)],
    std=[1 / s for s in _imagenet_std],
)

def tensor2img(tensor):
    """Turn a normalised channels-first image tensor into an ``(H, W, C)`` tensor for ``imshow``.

    The normalisation is undone with ``unnormalize``, and the result is then
    clamped to ``[0, 1]``. The normalise/un-normalise round trip can leave
    values a rounding error outside that range. Matplotlib clips float RGB
    data outside ``[0, 1]`` and emits a warning, so clamping here keeps the
    displayed pixels the same while silencing that warning.

    Args:
        tensor: 3-D image tensor, channels first (the layout ``permute(1, 2, 0)`` expects).

    Returns:
        The un-normalised image with channels moved last.
    """
    return unnormalize(tensor).clamp(0.0, 1.0).permute(1, 2, 0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weights = RegNet_Y_400MF_Weights.IMAGENET1K_V2
cv_model = regnet_y_400mf(weights=weights).to(device).eval()
# Drop the last child module (presumably the classification head) to get an
# image-embedding network.
# NOTE(review): its output is assumed to be (batch, num_features, 1, 1) — confirm.
feature_extractor = nn.Sequential(*list(cv_model.children())[:-1]).to(device).eval()
imagenet_categories = dict(enumerate(weights.meta["categories"]))

data = ImageNet("../dataset/ImageNet/", split="val", transform=weights.transforms())

# Dog classes are ImageNet indices 151-268 (inclusive). The index arithmetic
# assumes the val split is ordered by class with exactly 50 images per class.
IMAGES_PER_CLASS = 50
FIRST_DOG_CLASS, LAST_DOG_CLASS = 151, 268
data_dog = Subset(
    data,
    list(range(FIRST_DOG_CLASS * IMAGES_PER_CLASS, (LAST_DOG_CLASS + 1) * IMAGES_PER_CLASS)),
)

# Sanity check: show the first validation image and the model's prediction for it.
image, label = data[0]
print(imagenet_categories[label])
fig, ax = plt.subplots(figsize=(1, 1))
ax.imshow(tensor2img(image).numpy(), aspect="auto")
ax.axis("off")
fig.tight_layout(pad=0)
plt.show()

with torch.no_grad():
    image = image.unsqueeze(0).to(device)
    logits = cv_model(image)
    print(f"Prediction: {imagenet_categories[logits.argmax().item()]} | Actual: {imagenet_categories[label]}")
    # Embedding width; later used as the input size of the local model.
    num_features = feature_extractor(image).shape[1]
    print(f"Number of features extracted: {num_features}")

In [None]:
# Print animal categories (0-397). Dogs: 151-268
last_animal_category = 397
for category_id, category_name in imagenet_categories.items():
    # Keys were produced by enumerate(), so the first id past the animal range
    # ends the listing.
    is_animal = category_id <= last_animal_category
    if not is_animal:
        break
    print(f"{category_name}: {category_id}")

In [None]:
set_seed()


def _make_linear_model():
    """Build a fresh bias-free linear scorer over the extracted image features."""
    return nn.Linear(num_features, 1, bias=False).to(device)


# Local-model factories keyed by the name used in metric labels; each call
# returns a newly initialised model.
model_fns = {"linear": _make_linear_model}

# --- Simulation and local training -------------------------------------------
num_sim = 10                # number of simulations (fresh query + fresh model each)
num_items_per_query = 1320  # images sampled per simulation
local_lr = 1e-2             # SGD learning rate used by `train`
local_epoch = 5             # number of SGD steps taken by `train`

# --- Reconstruction attack (passed to `reconstruct_interactions`) ------------
num_atk = 1       # forwarded as `num_rounds`
max_iter = 1000
atk_lr = 0.1

# --- Gaussian mechanism (`apply_gaussian_mechanism`) --------------------------
# NOTE(review): math.inf presumably means "no noise" — confirm in utils.
epsilons = [1.0, 10.0, 20.0, 100.0, 500.0, math.inf]
delta = 1e-8
sensitivity = 0.05  # NOTE(review): document how this bound was derived

metrics = Metrics()

def train(model, features, interactions):
    """Simulate a client's local training and return the resulting weight update.

    Runs ``local_epoch`` full-batch steps of SGD (lr ``local_lr``, momentum 0.9).
    Each step minimises binary cross-entropy between the model's logits on
    ``features`` and ``interactions``. The model is used in functional form:
    the updated parameters are threaded through ``torchopt.FuncOptimizer``
    rather than written back into ``model``.

    Args:
        model: module whose current parameters initialise local training.
        features: model input, one row per item (presumably ``(N, num_features)``).
        interactions: per-item targets for ``binary_cross_entropy_with_logits``.
            The reconstruction attack passes its candidate labels here
            (presumably differentiated through this function).

    Returns:
        ``initial - trained`` for the FIRST parameter tensor only. For the
        bias-free linear model used in this notebook, that is the whole weight.
    """
    stateless_model, initial_params = functorch.make_functional(model)
    sgd = torchopt.FuncOptimizer(torchopt.sgd(lr=local_lr, momentum=0.9))
    params = initial_params
    for _step in range(local_epoch):
        logits = stateless_model(params, features).view(-1)
        params = sgd.step(F.binary_cross_entropy_with_logits(logits, interactions), params)
    return initial_params[0] - params[0]

def simulate_attack(data, model, model_name, epsilons):
    """Simulate one private local update and try to reconstruct its training labels.

    Steps:
      1. Draw a random query of up to ``num_items_per_query`` images from ``data``.
      2. Give each item a random binary interaction label.
      3. Train ``model`` locally on the extracted image features via ``train``.
      4. For each privacy budget in ``epsilons``, perturb the update with
         ``apply_gaussian_mechanism``.
      5. Reconstruct the labels from the noisy update with
         ``reconstruct_interactions`` and record the result in the global
         ``metrics``.

    Args:
        data: map-style dataset yielding ``(image, label)`` pairs; labels are ignored.
        model: freshly built local model; its initial weights seed ``train``.
        model_name: prefix of the metric keys.
        epsilons: privacy budgets passed to ``apply_gaussian_mechanism``.
    """
    # The first batch of a shuffled DataLoader is a random sample of the dataset.
    # It holds fewer than num_items_per_query items when the dataset is smaller.
    dataloader = DataLoader(data, batch_size=num_items_per_query, shuffle=True)
    images, _ = next(iter(dataloader))
    images = images.to(device)
    num_items = images.shape[0]
    # Random 0/1 labels act as the private interactions the attacker tries to recover.
    target_interactions = torch.randint(0, 2, (num_items,)).float().to(device)
    with torch.no_grad():
        # flatten(1) merges every non-batch dimension into one. Unlike squeeze(),
        # it never drops the batch dimension when the query holds a single item.
        features = feature_extractor(images).flatten(1)

    grouped_train_data_dict = {
        "no_adm": (features, target_interactions),
    }

    # Clean (pre-noise) local update for each training-data variant.
    raw_target_dict = {
        key: train(model, train_features, train_interactions).detach()
        for key, (train_features, train_interactions) in grouped_train_data_dict.items()
    }

    for epsilon in epsilons:
        for key, raw_target in raw_target_dict.items():
            # raw_target was detached above, so it can go straight to the CPU.
            target = apply_gaussian_mechanism(raw_target.cpu(), epsilon, delta, sensitivity).to(device)
            train_features, train_interactions = grouped_train_data_dict[key]
            # Both the simulated update and the observed target are divided by local_lr
            # so the attack compares them on the same scale.
            preds_raw, _ = reconstruct_interactions(
                lambda I: train(model, train_features, I) / local_lr,
                target / local_lr,
                num_items,
                lr=atk_lr,
                max_iter=max_iter,
                num_rounds=num_atk,
                return_raw=True,
                device=device,
            )
            preds_raw = preds_raw.detach().cpu()
            # sigmoid + round maps each raw score (presumably a logit) to a hard 0/1 label.
            preds = preds_raw.sigmoid().round().long()

            metrics.update(
                f"{model_name}_{num_items_per_query}_query_eps_{epsilon}_{key}",
                train_interactions.cpu(),
                preds,
                preds_raw=preds_raw,
            )

# NOTE(review): the attack samples from the full validation set `data`; the
# `data_dog` subset is built but not used in the visible notebook — confirm
# which one is intended.
for _sim in tqdm(range(num_sim)):
    for model_name, make_model in model_fns.items():
        # Release cached, unused GPU memory before the next simulation.
        torch.cuda.empty_cache()
        simulate_attack(data, make_model(), model_name, epsilons)

metrics.print_summary()

# Set to a path such as "../output/ltr_cv_metrics.csv" to persist the metrics.
METRICS_OUTPUT_PATH = None
if METRICS_OUTPUT_PATH is not None:
    metrics.save(METRICS_OUTPUT_PATH)