In [4]:
import os
import glob
import time
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, auc, accuracy_score, confusion_matrix
from scipy.stats import wasserstein_distance
from scipy.stats import norm, ks_2samp, ttest_ind
from scipy.special import kl_div, logit

## to zip two lists of different lengths
from itertools import zip_longest
from utils import simple_mia

import torch
import torch.nn as nn
from torchvision.models import resnet18

from sklearn.metrics import roc_curve, make_scorer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, train_test_split

from models import DefenderOPT

import warnings
warnings.filterwarnings("ignore")

plt.rcParams['figure.figsize'] = (5, 3)

%reload_ext autoreload
%autoreload 2

In [3]:
### some auxiliary functions

def compute_losses(net, loader, device):
    """Compute the per-sample cross-entropy loss of ``net`` over ``loader``.

    The network is evaluated in inference mode (``net.eval()`` and
    ``torch.no_grad()``) so that the loss of one sample does not depend on the
    other samples in its batch (BatchNorm), is not randomised (Dropout), and
    no autograd graph or running statistic is touched. The training/eval mode
    the caller had set is restored before returning.

    Args:
        net: a ``torch.nn.Module`` mapping a batch of inputs to class logits.
        loader: an iterable yielding ``(inputs, targets)`` tensor pairs.
        device: device (string or ``torch.device``) the batches are moved to.

    Returns:
        A NumPy array with one loss value per sample, in iteration order.
        An empty array is returned when ``loader`` yields no batch.
    """

    criterion = nn.CrossEntropyLoss(reduction="none")
    batch_losses = []

    # Remember the caller's mode so this helper has no lasting side effect.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)

                logits = net(inputs)
                batch_losses.append(criterion(logits, targets).cpu().numpy())
    finally:
        net.train(was_training)

    # np.concatenate rejects an empty list, so keep the empty-loader result explicit.
    if not batch_losses:
        return np.array([])
    return np.concatenate(batch_losses)


# Define custom scoring function
def custom_tpr_at_fpr(y_true, y_prob, desired_fprs=[0.01, 0.05, 0.1]):
    """Return the true positive rate at each requested false positive rate.

    The ROC curve of ``y_prob`` against ``y_true`` is computed with
    ``roc_curve`` and the TPR is linearly interpolated at every value in
    ``desired_fprs``. The result is a NumPy array with one TPR per requested FPR.
    """
    roc = roc_curve(y_true, y_prob)
    false_pos_rate, true_pos_rate = roc[0], roc[1]
    interpolated = np.interp(desired_fprs, false_pos_rate, true_pos_rate)
    return np.array(interpolated)

# Custom scorer using the custom scoring function
def custom_scorer(clf, X, y):
    """Score a fitted classifier by its TPR at the default FPR levels.

    ``clf`` must expose ``predict_proba``; the second column of its output is
    used as the score of the positive class and passed to ``custom_tpr_at_fpr``.
    """
    class_probabilities = clf.predict_proba(X)
    positive_scores = class_probabilities[:, 1]
    return custom_tpr_at_fpr(y, positive_scores)

## Compare SG against the baselines

In [6]:


dataset = 'cifar10'
batch_size = 128
device_id = 2
device = f'cuda:{device_id}'
num_classes = 10 if dataset == 'cifar10' else 100
## Supported values: 'SG', 'GA', 'fisher_new', 'wfisher', 'IU', 'FT', 'retrain'
baseline = 'fisher_new'
num_epoch = 30


## Running sum of the TPRs at the three default FPR levels (1%, 5%, 10%); averaged over seeds after the loop.
tprs_ret = np.zeros(3)
## One row per evaluated seed: [RA, TA, VA, FA, MIA_acc, MIA_auc].
acc_ret = []
for seed in range(1, 11):
    RNG = torch.Generator().manual_seed(seed)
    ## NOTE: torch.load unpickles the file, so only load data/checkpoints from trusted sources.
    SG_data = torch.load(f'../result/SG_data/SGdata_seed_{seed}_{dataset}.pth')

    retain_dataset = SG_data['retain']
    test_dataset = SG_data['test']
    val_dataset = SG_data['val']
    forget_dataset = SG_data['forget']

    retain_loader = torch.utils.data.DataLoader(
        retain_dataset, batch_size=batch_size, shuffle=True, num_workers=2, generator=RNG)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, generator=RNG)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, generator=RNG)
    forget_loader = torch.utils.data.DataLoader(
        forget_dataset, batch_size=batch_size, shuffle=True, num_workers=2, generator=RNG)

    ## Only `evaluator.dim` is read below; the evaluation helpers are called on the class itself.
    evaluator = DefenderOPT(retain_loader,
                            forget_loader,
                            val_loader,
                            test_loader,
                            baseline_mode=1,
                            cv=3,
                            dim=1,
                            seed=seed,
                            device_id=device_id,
                            num_class=num_classes)

    if baseline == 'SG':
        ## NOTE(review): "dim_11" and "cifar10" are hard-coded in this file name -- confirm before switching `dataset`.
        model_path = f"../result/SG_data/SGcheckpoint_num_epoch_{num_epoch}_cv_3_dim_11_seed_{seed}_cifar10.pth"
        try:
            weights = torch.load(model_path, map_location=device)
        except Exception as exc:
            ## Best effort: skip seeds without a usable checkpoint, but say so and
            ## do not swallow KeyboardInterrupt/SystemExit as a bare `except` would.
            print(f"Skipping seed {seed}: could not load {model_path} ({exc})")
            continue
    elif baseline in ['GA', 'fisher_new', 'wfisher', 'IU', 'FT', 'retrain']:
        model_path = f"../result/baselines/{baseline}checkpoint_{seed}.pth.tar"
        weights = torch.load(model_path, map_location=device)['state_dict']
    else:
        raise ValueError("Unknown baselines.")

    model = resnet18(num_classes=num_classes)
    ## Replace the stem convolution with a 3x3 one so the architecture matches the saved weights.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
    model.load_state_dict(weights)
    model.to(device)

    retain_acc = DefenderOPT._evaluate_accuracy(model, retain_loader, device=device)
    test_acc = DefenderOPT._evaluate_accuracy(model, test_loader, device=device)
    val_acc = DefenderOPT._evaluate_accuracy(model, val_loader, device=device)
    forget_acc = DefenderOPT._evaluate_accuracy(model, forget_loader, device=device)
    MIA_acc, MIA_recall, MIA_auc = DefenderOPT._evaluate_MIA(model,
                                                             forget_loader,
                                                             val_loader,
                                                             dim=evaluator.dim,
                                                             seed=seed,
                                                             device=device,
                                                             save=False)
    acc_ret.append([retain_acc*100, test_acc*100, val_acc*100, forget_acc*100, MIA_acc*100, MIA_auc])

    ## A freshly built module is in training mode; per-sample losses must be computed in
    ## inference mode so BatchNorm uses its stored statistics instead of the batch's.
    model.eval()
    forget_losses = compute_losses(model, forget_loader, device)
    val_losses = compute_losses(model, val_loader, device)

    ## Sub-sample both loss sets to the size of the smaller one to get a class-balanced
    ## MIA dataset. The generator is seeded so the sub-sample is reproducible per seed.
    mia_rng = np.random.default_rng(seed)
    n_per_class = min(len(val_losses), len(forget_losses))
    forget_losses = mia_rng.permutation(forget_losses)[:n_per_class]
    val_losses = mia_rng.permutation(val_losses)[:n_per_class]

    ## Label 0 = validation (non-member) losses, label 1 = forget losses.
    samples_mia = np.concatenate((val_losses, forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(val_losses) + [1] * len(forget_losses)

    X_train, X_test, y_train, y_test = train_test_split(samples_mia, labels_mia, test_size=0.2, random_state=seed)

    clf = LogisticRegression()
    clf.fit(X_train, y_train)
    ## Get predicted probabilities
    y_prob = clf.predict_proba(X_test)[:, 1]
    ## Compute the true positive rates at different false positive rates, e.g., the TPR when FPR=1%, the TPR when FPR=5%, etc.
    ## This metric is suggested by Carlini, et al. in https://arxiv.org/abs/2112.03570
    tprs_ret += custom_tpr_at_fpr(y_test, y_prob)

## Average over the seeds that were actually evaluated (skipped seeds add no row to acc_ret).
if not acc_ret:
    raise RuntimeError("No checkpoint could be evaluated; cannot average the TPRs.")
tprs_ret /= len(acc_ret)

In [7]:
### get the accuracy on different subsets
### RA: the accuracy on the retain set
### TA: the accuracy on the test set
### VA: the accuracy on the validation set
### FA: the accuracy on the forget set
### MIA_acc: the accuracy of the MIA 
### MIA_auc: the auc of the MIA

## One row per evaluated seed, one column per metric.
acc_df = pd.DataFrame(acc_ret)
acc_df.columns = ['RA', 'TA', 'VA', 'FA', 'MIA_acc', 'MIA_auc']

## Mean of every metric across seeds (MIA_auc is a 0-1 score, the others are percentages).
print('Average acc. (%)', acc_df.mean(), sep='\n')

print()

## Half-width of the interval: 1.96 standard errors, i.e. the normal approximation.
print('95% confidence', 1.96 * acc_df.sem(), sep='\n')

Average acc. (%)
RA          9.983556
TA          9.930000
VA          9.984000
FA          9.926000
MIA_acc    49.910000
MIA_auc     0.495163
dtype: float64

95% confidence
RA         0.516685
TA         0.495130
VA         0.508151
FA         0.474270
MIA_acc    0.180135
MIA_auc    0.006538
dtype: float64


In [8]:
## Seed-averaged true positive rates, in percent, at false positive rates of 1%, 5% and 10%
print(100 * tprs_ret)

[1.07943301 5.02136651 9.86753757]
