In [1]:
import os
import tqdm

import torch
import torch.nn as nn

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

from privacy_lint.attacks.loss import LossAttack, compute_loss_cross_entropy

%matplotlib inline  
%config InlineBackend.figure_format='retina'

Gather train and test scores
====

In [2]:
def get_dataloader(path, batch_size=256, num_workers=8):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    
    dataset = datasets.ImageFolder(
        path, transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))
    
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=num_workers, 
        pin_memory=True
    )
    
    return dataloader

In [3]:
imagenet_path = "/datasets01/imagenet_full_size/061417/"
batch_size = 1024

model = models.resnet18(pretrained=True).eval()


train_path = os.path.join(imagenet_path, 'val')
test_path = os.path.join(imagenet_path, 'val')
train_loader = get_dataloader(train_path, batch_size=batch_size)
test_loader = get_dataloader(test_path, batch_size=batch_size)

Attack unbalanced
====

In [4]:
attack = LossAttack(compute_loss=compute_loss_cross_entropy)
loss_results_unbalanced = attack.launch(model, train_loader, test_loader)

100%|██████████| 49/49 [03:17<00:00,  4.04s/it]
100%|██████████| 49/49 [03:14<00:00,  3.98s/it]


In [5]:
max_accuracy_threshold, max_accuracy = loss_results_unbalanced.get_max_accuracy_threshold()
print(f"Max accuracy threshold: {-max_accuracy_threshold:.2f}, max accuracy: {max_accuracy*100:.2f}%")

Max accuracy threshold: -0.00, max accuracy: 50.02%


Attack balanced
===

In [8]:
loss_results_balanced = loss_results_unbalanced.balance()

max_accuracy_threshold, max_accuracy = loss_results_balanced.get_max_accuracy_threshold()
print(f"Max accuracy threshold: {-max_accuracy_threshold:.2f}, max accuracy: {max_accuracy*100:.2f}%")

Max accuracy threshold: 0.00, max accuracy: 50.01%
