In [None]:
!pip install diffprivlib
!pip install foolbox
!pip install torchopt

# Get ImageNet to Colab from Drive
from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/ILSVRC2012_img_val.tar /content/dataset/ImageNet/
# !cp /content/drive/MyDrive/ILSVRC2012_devkit_t12.tar.gz /content/dataset/ImageNet/

import foolbox as fb
import functorch
import json
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import torch
import torchopt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import warnings
from pathlib import Path
from PIL import Image
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageNet
from torchvision.models import (
    densenet121, DenseNet121_Weights,
    mnasnet1_3, MNASNet1_3_Weights,
    regnet_y_800mf, RegNet_Y_800MF_Weights,
    resnet18, ResNet18_Weights,
)
from tqdm.notebook import tqdm
from diffprivlib.mechanisms import (
    Gaussian,
    GaussianAnalytic,
)
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
warnings.filterwarnings("ignore", message='.*make_functional.*')

def reconstruct_interactions_adam(
    trainer,
    target_params,
    num_items,
    loss_fn=F.mse_loss,
    num_epochs=1,
    device=torch.device("cpu"),
    progress_bar=False,
    **kwargs,
):
    """Search for interaction logits whose simulated update matches a target.

    A vector of ``num_items`` logits is initialised uniformly in [-1, 1) and
    optimised with Adam so that ``trainer(sigmoid(logits))`` gets close to
    ``target_params`` under ``loss_fn``.

    Args:
        trainer: Callable mapping a tensor of ``num_items`` values in (0, 1)
            to something comparable with ``target_params`` via ``loss_fn``.
        target_params: The value the trainer output should reproduce.
        num_items: Number of logits to optimise.
        loss_fn: Loss between the trainer output and ``target_params``.
        num_epochs: Number of Adam steps. The loss is evaluated
            ``num_epochs + 1`` times so the final iterate is also scored.
        device: Device on which the logits are created.
        progress_bar: Show a tqdm progress bar with the current loss.
        **kwargs: Forwarded to ``torch.optim.Adam`` (for example ``lr``).

    Returns:
        A detached copy of the logits (pre-sigmoid) with the lowest observed
        loss, or ``None`` if no evaluated loss compared below infinity
        (for example when every loss is NaN).
    """
    opt_params = nn.Parameter(torch.rand(num_items, device=device) * 2 - 1)
    optimizer = optim.Adam([opt_params], **kwargs)
    best_loss = math.inf
    best_params = None
    steps = range(num_epochs + 1)
    iterator = tqdm(steps) if progress_bar else steps
    for i in iterator:
        optimizer.zero_grad()
        shadow_params = trainer(opt_params.sigmoid())
        loss = loss_fn(shadow_params, target_params)
        loss_value = loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            # detach() alone shares storage with opt_params, which Adam then
            # updates in place; clone so the snapshot really stays the best one.
            best_params = opt_params.detach().clone()
        if progress_bar:
            iterator.set_description(f"Loss: {round(loss_value, 5)}")
        if i == num_epochs:
            # Last pass only scores the final iterate; no further update.
            break
        loss.backward(inputs=[opt_params])
        optimizer.step()
    return best_params

def optimize_image_manipulation(
    data, target, feature_extractor,
    max_epochs=100,
    loss_fn=F.mse_loss,
    early_stop=1e-03,
    linf_factor=0.0,
    progress_bar=False,
    **kwargs,
):
    """Perturb images with Adam so that their extracted features match ``target``.

    Args:
        data: Batch of images; indexed as ``[:, channel]`` with three channels.
        target: Desired output of ``feature_extractor`` for the perturbed batch.
        feature_extractor: Callable applied to the perturbed batch each epoch.
        max_epochs: Maximum number of Adam steps.
        loss_fn: Loss between extracted features and ``target``.
        early_stop: Stop once the loss drops below this value; ``None`` disables it.
        linf_factor: Weight of the L-infinity distance to the original ``data``.
        progress_bar: Show a tqdm progress bar with the current loss.
        **kwargs: Forwarded to ``torch.optim.Adam``.

    Returns:
        The optimised batch, detached from the autograd graph.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    # Per-channel images of the raw range [0, 1] under (x - mean) / std.
    lower_bounds = [-m / s for m, s in zip(channel_mean, channel_std)]
    upper_bounds = [(1.0 - m) / s for m, s in zip(channel_mean, channel_std)]

    perturbed = data.clone().requires_grad_(True)
    adam = optim.Adam([perturbed], **kwargs)

    epochs = range(max_epochs)
    if progress_bar:
        epochs = tqdm(epochs)

    for _ in epochs:
        adam.zero_grad()
        feature_loss = loss_fn(feature_extractor(perturbed), target)
        drift = torch.linalg.vector_norm(perturbed - data, ord=float('inf'))
        loss = feature_loss + linf_factor * drift
        loss.backward(inputs=[perturbed])
        adam.step()

        # Keep every channel inside its valid range after the update.
        with torch.no_grad():
            for channel, (low, high) in enumerate(zip(lower_bounds, upper_bounds)):
                perturbed[:, channel].clip_(min=low, max=high)

        current_loss = loss.item()
        if progress_bar:
            epochs.set_description(f"Loss: {round(current_loss, 5)}")
        if early_stop is not None and current_loss < early_stop:
            break

    return perturbed.detach()


def optimize_image_manipulation_batches(
    data, target, feature_extractor, batch_size=128, **kwargs,
):
    """Run ``optimize_image_manipulation`` over ``data`` in consecutive chunks.

    ``data`` and ``target`` are sliced along their first dimension in steps of
    ``batch_size``; each pair of slices is optimised independently and the
    results are stacked back together in the original order.

    Args:
        data: Batch of images to optimise.
        target: Feature targets, row-aligned with ``data``.
        feature_extractor: Passed through to ``optimize_image_manipulation``.
        batch_size: Maximum number of rows per chunk.
        **kwargs: Passed through to ``optimize_image_manipulation``.

    Returns:
        The vertically stacked optimised chunks.
    """
    chunk_count = int(math.ceil(data.shape[0] / batch_size))
    optimized_chunks = []
    for offset in (k * batch_size for k in range(chunk_count)):
        rows = slice(offset, offset + batch_size)
        optimized_chunks.append(
            optimize_image_manipulation(
                data[rows, :], target[rows, :], feature_extractor, **kwargs
            )
        )
    return torch.vstack(optimized_chunks)

class Metrics:
    """Collects binary-classification metrics in a DataFrame, one row per update."""

    # Column order of the underlying DataFrame.
    COLUMNS = (
        "name",
        "accuracy",
        "f1",
        "precision",
        "recall",
        "auc",
        "auc-pr",
        "extra_data",
    )

    def __init__(self, path=None):
        """Start from the CSV at ``path`` if given, otherwise from an empty table."""
        if path is not None:
            self.df = pd.read_csv(path)
        else:
            self.df = pd.DataFrame({column: [] for column in self.COLUMNS})

    def update(self, name, target, preds, preds_raw=None, extra_data=None):
        """Append one row of metrics.

        Args:
            name: Label stored in the ``name`` column.
            target: Ground-truth labels.
            preds: Hard predictions compared against ``target``.
            preds_raw: Optional scores; when given, ``auc`` and ``auc-pr`` are
                computed from them, otherwise both are stored as ``None``.
            extra_data: Optional JSON-serialisable object stored as a JSON
                string; defaults to an empty dict.
        """
        # Avoid a shared mutable default argument.
        if extra_data is None:
            extra_data = {}
        has_scores = preds_raw is not None
        row = {
            "name": name,
            "accuracy": accuracy_score(target, preds),
            # zero_division=0 on all three keeps the reported value (0) and
            # treats f1/recall the same way precision is already treated.
            "f1": f1_score(target, preds, zero_division=0),
            "precision": precision_score(target, preds, zero_division=0),
            "recall": recall_score(target, preds, zero_division=0),
            "auc": roc_auc_score(target, preds_raw) if has_scores else None,
            "auc-pr": average_precision_score(target, preds_raw)
            if has_scores
            else None,
            "extra_data": json.dumps(extra_data),
        }
        self.df.loc[len(self.df.index), :] = row

    def get_dataframe(self):
        """Return the underlying DataFrame (not a copy)."""
        return self.df

    def save(self, path):
        """Write the table to ``path`` as CSV without the index column."""
        self.df.to_csv(path, index=False)

    def load(self, path):
        """Replace the current table with the CSV at ``path``."""
        self.df = pd.read_csv(path)

    def print_summary(self, metrics=None):
        """Print ``describe()`` statistics of the chosen metric columns per name.

        Args:
            metrics: Iterable of column names to summarise; defaults to ``["auc"]``.
        """
        columns = ["auc"] if metrics is None else list(metrics)
        print(self.df[["name"] + columns].groupby("name").describe().to_string())

def apply_gaussian_mechanism(input, epsilon, delta, sensitivity, scale_only=False):
    """Clip a tensor's L2 norm and optionally add Gaussian-mechanism noise.

    Args:
        input: Tensor to privatise.
        epsilon: Privacy parameter; an infinite value returns ``input`` untouched.
        delta: Privacy parameter passed to the mechanism.
        sensitivity: Sensitivity passed to the mechanism; the tensor is scaled
            so that its L2 norm is at most half of this value.
        scale_only: If true, return the clipped tensor without adding noise.

    Returns:
        ``input`` itself when ``epsilon`` is infinite, otherwise a new tensor
        that is clipped and, unless ``scale_only`` is set, noised element-wise.
    """
    if math.isinf(epsilon):
        return input

    # Clip L2 norm to 0.5 * sensitivity (since global L2 sensitivity = 2 * max L2 norm)
    norm_limit = 0.5 * sensitivity
    shrink = torch.minimum(
        torch.tensor(1.0), norm_limit / torch.linalg.vector_norm(input)
    )
    clipped = input * shrink
    if scale_only:
        return clipped

    # Add noise: the classic mechanism is chosen for epsilon <= 1, the analytic
    # variant otherwise.
    if epsilon <= 1.0:
        mechanism_cls = Gaussian
    else:
        mechanism_cls = GaussianAnalytic
    mechanism = mechanism_cls(epsilon=epsilon, delta=delta, sensitivity=sensitivity)
    return clipped.apply_(mechanism.randomise)

In [None]:
# Change the feature extraction model here.
# Possible values: ResNet18, MNasNet1_3, DenseNet121, RegNet_Y_800MF
chosen_model = "ResNet18"

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Per-channel statistics shared by the forward and inverse transforms below.
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

# PIL image -> tensor -> (x - mean) / std.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])
# Inverse of the Normalize step above, written as another Normalize:
# (x - (-mean / std)) / (1 / std) == x * std + mean.
unnormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_CHANNEL_MEAN, _CHANNEL_STD)],
    std=[1 / s for s in _CHANNEL_STD],
)

def set_seed(seed=2023):
    """Seed PyTorch, Python's ``random`` module and NumPy with the same value."""
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

def tensor2img(tensor):
    """Undo the normalisation and move the channel axis last.

    A 3-dimensional tensor is permuted with (1, 2, 0); anything else is
    treated as a batch and permuted with (0, 2, 3, 1).
    """
    restored = unnormalize(tensor)
    if len(tensor.shape) == 3:
        return restored.permute(1, 2, 0)
    return restored.permute(0, 2, 3, 1)

def display_tensor_as_img(t, unnormalize=True, size=1):
    """Show a tensor as an image in a borderless square matplotlib figure.

    Args:
        t: Tensor to display.
        unnormalize: If true, convert with ``tensor2img`` first; otherwise the
            tensor is shown as-is.
        size: Width and height of the figure.
    """
    plt.figure(figsize=(size, size))
    if unnormalize:
        pixels = tensor2img(t).cpu()
    else:
        pixels = t.cpu()
    plt.imshow(pixels, aspect="auto")
    plt.axis("off")
    plt.tight_layout(pad=0)
    plt.show()

# Model code, weight, and image partitions
model_config = {
    "DenseNet121": (densenet121, DenseNet121_Weights.IMAGENET1K_V1, 1),
    "MNasNet1_3": (mnasnet1_3, MNASNet1_3_Weights.IMAGENET1K_V1, 2),
    "RegNet_Y_800MF": (regnet_y_800mf, RegNet_Y_800MF_Weights.IMAGENET1K_V2, 2),
    "ResNet18": (resnet18, ResNet18_Weights.IMAGENET1K_V1, 2),
}
# Folder that holds the manipulated images of the chosen model.
opt_basepath = f"./dataset/CV/{chosen_model}_partitioned"

# Unpack the (constructor, weights, partition count) entry of the chosen model.
_model_ctor, weights, num_partitions = model_config[chosen_model]
cv_model = _model_ctor(weights=weights).to(device).eval()
fb_cv_model = fb.PyTorchModel(cv_model, bounds=(-2.65, 2.65), device=device)

# The feature extractor is frozen: drop any gradients and disable tracking.
for p in cv_model.parameters():
    p.grad = None
    p.requires_grad_(False)

# Class index -> human-readable category name.
imagenet_categories = dict(enumerate(weights.meta["categories"]))
# The last layer outputs the class logits which we don't need; its input is
# what extract_features captures.
last_layer = [*cv_model.children()][-1]

# Runs inference on the model and captures the input of the last layer,
# i.e. the output of the second-to-last layer.
def extract_features(inputs, batch_size=None):
    """Return the features that ``cv_model`` feeds into its last layer.

    Args:
        inputs: Batch passed to ``cv_model``.
        batch_size: If given, ``inputs`` is processed in consecutive slices of
            at most this many rows; otherwise in a single forward pass.

    Returns:
        The captured features of all forward passes, stacked vertically.
    """
    features = []

    def capture_input(module, args, output):
        # A forward hook receives the positional inputs of the layer as a tuple.
        features.append(args[0])

    handle = last_layer.register_forward_hook(capture_input)
    try:
        if batch_size is None:
            cv_model(inputs)
        else:
            for start in range(0, inputs.shape[0], batch_size):
                cv_model(inputs[start:start + batch_size, :])
    finally:
        # Always unregister, even if a forward pass fails (for example when
        # running out of memory); otherwise the hook stays on the model and
        # keeps appending to this call's list on every later forward pass.
        handle.remove()
    return torch.vstack(features)

def dataset_with_indices(cls):
    """
    Build a subclass of the given Dataset class whose items are
    (data, target, index) triples instead of (data, target) pairs.

    The subclass keeps the name of the original class.
    """
    def _getitem_with_index(self, index):
        # Delegate to the original lookup, then tack the index on.
        data, target = cls.__getitem__(self, index)
        return data, target, index

    namespace = {'__getitem__': _getitem_with_index}
    return type(cls.__name__, (cls,), namespace)

# Prepare ImageNet dataset
# Validation split, wrapped so that every item also carries its dataset index;
# images are preprocessed with the transforms bundled with the chosen weights.
ImageNetWithIndices = dataset_with_indices(ImageNet)
data = ImageNetWithIndices("./dataset/ImageNet/", split="val", transform=weights.transforms())
data = Subset(data, list(range(5000))) # Scaled down for artifact eval, comment out for full run

# Calculate the number of features of the model
# The first sample doubles as a sanity check: show it, classify it, and read
# the feature width off the extracted features.
image, label, _ = data[0]
print(imagenet_categories[label])
display_tensor_as_img(image)
with torch.no_grad():
    # Add a batch dimension before running the model.
    image = image.unsqueeze(0).to(device)
    logits = cv_model(image)
    print(f"Prediction: {imagenet_categories[logits.argmax().cpu().item()]} | Actual: {imagenet_categories[label]}")
    # Second dimension of the extracted features = features per image.
    num_features = extract_features(image).shape[1]
    print(f"Number of features extracted: {num_features}")

In [None]:
# Generate manipulated images
# Seed torch, random and numpy (see set_seed) before drawing the random targets.
set_seed(2023)

# Read by generate_optimized_data both as the DataLoader batch size and as the
# chunk size of the inner optimisation.
batch_size = 256 # Lower batch size if out of GPU memory, e.g., 128 if 8GB, 256 if 12GB or more

def generate_optimized_data(data, base_path, num_partitions):
    """Create and save one manipulated copy of every image per feature partition.

    The feature vector is split into ``num_partitions`` equal-width slices.
    For each slice, every image is optimised towards a random non-negative
    target that is zero outside that slice, and the result is written to
    ``{base_path}/p{partition}/{index}.png``.

    Args:
        data: Dataset yielding (image, label, index) triples.
        base_path: Root folder for the generated PNG files.
        num_partitions: Number of feature slices (and output sub-folders).
    """
    loader = DataLoader(data, batch_size=batch_size, shuffle=False)
    slice_width = num_features // num_partitions
    for part in range(num_partitions):
        Path(f"{base_path}/p{part}/").mkdir(parents=True, exist_ok=True)

    for images, _, indices in tqdm(loader):
        count = images.shape[0]
        images = images.to(device)
        for part in range(num_partitions):
            slice_start = part * slice_width
            slice_end = slice_start + slice_width

            # Random target, clipped at zero, kept only inside the current slice.
            target = torch.normal(0.0, 4.0, (count, num_features), device=device).clip(min=0.0)
            target[:, :slice_start].zero_()
            target[:, slice_end:].zero_()

            optimized_images = optimize_image_manipulation_batches(
                images,
                target,
                extract_features,
                batch_size=batch_size,
                max_epochs=500,
                loss_fn=F.mse_loss,
                lr=1e-02,
                linf_factor=0.0,
                progress_bar=True,
            )

            for optimized, index in zip(optimized_images, indices):
                # When converting from PIL to tensor this results in smaller MSE (around 2.5e-05)
                # than using transforms.ToPILImage with unnormalize (1e-04)
                pixels = (255 * tensor2img(optimized.cpu())).numpy().round().astype(np.uint8)
                Image.fromarray(pixels).save(f"{base_path}/p{part}/{index}.png")

# Generate the manipulated images for the whole dataset; the PNG files are
# written below opt_basepath, one sub-folder per partition.
generate_optimized_data(data, opt_basepath, num_partitions)

In [None]:
# Main results in Table VIII of Section VI.C (and Table IX of Section VII.A)

set_seed(2024)


def _two_layer_mlp(hidden_units):
    """Return a zero-argument builder for Linear -> ReLU -> Linear(1) on ``device``."""
    def build():
        return nn.Sequential(
            nn.Linear(num_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1)
        ).to(device)
    return build


# Name -> builder of a freshly initialised local model.
model_fns = {
    "linear": lambda: nn.Linear(num_features, 1, bias=False).to(device),
    **{f"neural{width}": _two_layer_mlp(width) for width in (2, 4, 8)},
}

# num_sim = 200
num_sim = 30 # Scaled down for artifact eval
# Items per query, as multiples of the feature width.
num_items_per_query = [num_features * factor for factor in (1, 2, 4)]
local_lr = 1e-02
local_epoch = 5

# Reconstruction attack parameters
# NOTE(review): num_atk and max_iter are not referenced by the code visible
# here (simulate_attack passes num_epochs=300 directly) — confirm intent.
num_atk = 1
max_iter = 1000
atk_lr = 0.1

# Differential privacy parameters
epsilons = [1.0, 20.0, 100.0, 500.0, math.inf]
delta = 1e-08
sensitivity = 0.05

# Traditional adversarial attack parameters
adv_attacks = {
    # "FGSM": fb.attacks.FGSM(), # Uncomment to enable FGSM manipulation, needs GPU to be fast
}
adv_epsilons = [0.1]
adv_batch_size = 16 # Reduce if out of GPU memory

metrics = Metrics()
extract_batch_size = 512 # Reduce this if not enough GPU memory

# Local training
def train(model, features, interactions):
    """Simulate local training and return the resulting parameter update.

    The model is run functionally for ``local_epoch`` SGD steps (learning rate
    ``local_lr``) on a binary cross-entropy-with-logits loss between its
    predictions on ``features`` and ``interactions``.

    Returns:
        A flat tensor: initial parameters minus trained parameters, each
        flattened and concatenated in parameter order.
    """
    func_model, initial_params = functorch.make_functional(model)
    sgd = torchopt.FuncOptimizer(torchopt.sgd(lr=local_lr))

    current_params = initial_params
    for _ in range(local_epoch):
        logits = func_model(current_params, features).view(-1)
        step_loss = F.binary_cross_entropy_with_logits(logits, interactions)
        current_params = sgd.step(step_loss, current_params)

    def flatten(params):
        return torch.cat([p.view(-1) for p in params])

    return flatten(initial_params) - flatten(current_params)

# Traditional adversarial attack
def adv_perturb(fmodel, images, labels, attack, batch_size=None, **kwargs):
    """Run ``attack`` on the images, optionally in consecutive batches.

    ``attack(fmodel, images, labels, **kwargs)`` must return a triple whose
    second element is a sequence of clipped adversarial batches and whose
    third element is a success tensor.

    Args:
        fmodel: Model handed to ``attack``.
        images: Images to perturb.
        labels: Labels aligned with ``images``.
        attack: The attack callable.
        batch_size: If given, the attack runs on slices of at most this many
            images and the per-slice results are merged.
        **kwargs: Forwarded to ``attack``.

    Returns:
        ``(clipped_advs, success)``. In batched mode ``clipped_advs`` is a list
        with one vertically stacked tensor per entry of the attack's second
        output, and ``success`` is the ``torch.hstack`` of the per-slice tensors.
    """
    if batch_size is None:
        _, clipped_advs, success = attack(fmodel, images, labels, **kwargs)
        return clipped_advs, success

    adv_chunks = []
    success_chunks = []
    chunk_count = int(math.ceil(images.shape[0] / batch_size))
    for k in range(chunk_count):
        window = slice(k * batch_size, (k + 1) * batch_size)
        _, chunk_advs, chunk_success = attack(
            fmodel, images[window], labels[window], **kwargs
        )
        adv_chunks.append(list(chunk_advs))
        success_chunks.append(chunk_success)
        torch.cuda.empty_cache()

    # Regroup: one stacked tensor per entry of the attack's second output.
    entries_per_chunk = len(adv_chunks[0])
    clipped_advs = [
        torch.vstack([chunk[e] for chunk in adv_chunks])
        for e in range(entries_per_chunk)
    ]
    success = torch.hstack(success_chunks)
    return clipped_advs, success

# Load RAIFLE-manipulated images
def load_opt_images(indices, num_partitions):
    """Load the manipulated images for ``indices``, cycling through partitions in blocks.

    The positions in ``indices`` are split into ``num_partitions`` consecutive
    blocks of ``len(indices) // num_partitions`` items; the image at position
    ``c`` is read from the folder of the block it falls into. Leftover
    positions (when the length is not divisible) use the last partition.

    Args:
        indices: Dataset indices; each is formatted into the PNG file name.
        num_partitions: Number of partition folders below ``opt_basepath``.

    Returns:
        The normalised images stacked into one tensor, in the order of ``indices``.
    """
    # At least 1 so that fewer indices than partitions does not divide by zero.
    block_size = max(1, len(indices) // num_partitions)
    last_partition = num_partitions - 1
    imgs = []
    for position, index in enumerate(indices):
        # Clamp: with a non-divisible length the plain quotient would point
        # one past the last existing partition folder.
        partition = min(position // block_size, last_partition)
        # Context manager closes the file once the pixels have been converted.
        with Image.open(f"{opt_basepath}/p{partition}/{index}.png") as img:
            imgs.append(normalize(img))
    return torch.stack(imgs)

def simulate_attack(data, epsilons, num_items):
    """Run one simulated reconstruction attack and record its metrics.

    One shuffled batch of up to ``num_items`` images is drawn and random binary
    interactions are assigned to it. Features are extracted for the original
    images ("no_adm"), for every configured adversarial attack
    ("adm_{attack}_{eps}") and for the pre-generated manipulated images
    ("adm_opt"). For every model in ``model_fns``, the local update is computed
    per feature set; for every epsilon the update is passed through
    ``apply_gaussian_mechanism`` and the interactions are reconstructed from
    it. Each reconstruction is scored against the true interactions via the
    global ``metrics``.

    Args:
        data: Dataset yielding (image, label, index) triples.
        epsilons: Privacy budgets to evaluate.
        num_items: Requested number of items in the query.
    """
    loader = DataLoader(data, batch_size=num_items, shuffle=True)
    images, labels, indices = next(iter(loader))
    images, labels = images.to(device), labels.to(device)
    # The batch may be smaller than requested; use its actual size from here on.
    num_items = images.shape[0]
    target_interactions = torch.randint(0, 2, (num_items,)).float().to(device)

    features_by_variant = {
        "no_adm": extract_features(images, extract_batch_size),
    }

    for attack_name, attack in adv_attacks.items():
        adv_batches, _ = adv_perturb(
            fb_cv_model, images, labels, attack,
            batch_size=adv_batch_size, epsilons=adv_epsilons,
        )
        for eps, adv_images in zip(adv_epsilons, adv_batches):
            variant = f"adm_{attack_name}_{eps}"
            features_by_variant[variant] = extract_features(adv_images, extract_batch_size)

    optimized_images = load_opt_images(indices, num_partitions).to(device)
    features_by_variant["adm_opt"] = extract_features(optimized_images, extract_batch_size)
    del optimized_images

    for model_name, model_fn in model_fns.items():
        model = model_fn()

        # Clean (noise-free) update of this model for every feature variant.
        updates_by_variant = {}
        for key, variant_features in features_by_variant.items():
            updates_by_variant[key] = train(model, variant_features, target_interactions).detach()

        for epsilon in epsilons:
            for key, raw_update in updates_by_variant.items():
                noisy_update = apply_gaussian_mechanism(
                    raw_update.detach().cpu(), epsilon, delta, sensitivity
                )
                target = noisy_update.to(device)
                train_features = features_by_variant[key]

                def scaled_update(candidate):
                    # Same rescaling by local_lr as applied to the target below.
                    return train(model, train_features, candidate) / local_lr

                preds_raw = reconstruct_interactions_adam(
                    scaled_update,
                    target / local_lr,
                    num_items,
                    num_epochs=300,
                    device=device,
                    lr=atk_lr,
                )
                preds_raw = preds_raw.detach().cpu()
                # Logits -> probabilities -> hard 0/1 predictions.
                preds = preds_raw.sigmoid().round().long()

                metrics.update(
                    f"{model_name}_{num_items}_items_eps_{epsilon}_{key}",
                    target_interactions.cpu(),
                    preds,
                    preds_raw=preds_raw,
                )

# Create the output folder up front: to_csv does not create missing
# directories, and failing only after all simulations would lose the results.
Path("./output").mkdir(parents=True, exist_ok=True)

# Repeat the simulation num_sim times for every query size.
for _ in tqdm(range(num_sim)):
    for num_items in num_items_per_query:
        simulate_attack(data, epsilons, num_items)
        # Release cached GPU memory between simulations.
        torch.cuda.empty_cache()

metrics.print_summary()
metrics.save("./output/ltr_cv_metrics.csv")