# Import Libraries

In [1]:
import os
import subprocess
import requests
import tqdm
import random

from tabulate import tabulate

import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torchvision.utils import make_grid
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Subset

# Device string used for model/tensor placement throughout the notebook:
# 'cuda' when torch reports an available GPU, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 



In [2]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recommend using the P100 GPU rather than T4 as it's faster and will increase the chances of passing the time cut-off threshold.

# Fail fast when no CUDA device was detected at import time.
if DEVICE != 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

# Custom Configuration

## Seed Config 

In [3]:
# Seed torch's global RNG.
torch.manual_seed(3047)

# Dedicated generators with fixed seeds, one per hidden-dataset split. They are
# handed to the shuffling DataLoaders built by the submission-mode get_dataset.
G_retain = torch.Generator()
G_retain.manual_seed(3047)

G_forget = torch.Generator()
G_forget.manual_seed(3049)

G_validate = torch.Generator()
G_validate.manual_seed(30470)

<torch._C.Generator at 0x7f9bf0bfec70>

## Testing Version or Submission Version 

Here at the time of **testing**, we will keep the **internet on**, set **test=True** and load pretrained cifar10 model to test our unlearning algorithm implementation. In this case, we will utilize some parts from the given starting kit by the competition organizers.<br>
**Link:** https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb <br><br>
At the time of **submission**, **internet off** and **test=False**

In [4]:
# Mode switch for the whole notebook: True runs the local CIFAR-10 testing cells
# (internet on); False runs the submission path on the hidden dataset (internet off).
test = False

# Load Dataset

In [5]:
# Helper functions for loading the CIFAR10 dataset.

if test:
    
    # The directory for a dataset and a pretrained model
    test_dir = './test'
    test_model_path = os.path.join(test_dir, "weights_resnet18_cifar10.pth")
    os.makedirs(test_dir, exist_ok=True)
    
    class PublicDataset(Dataset):
        """Adapter that serves ``(image, label)`` items as hidden-dataset style dicts.

        Each item is a dict with keys ``image``, ``image_id``, ``age_group``,
        ``age`` and ``person_id``. The label of the wrapped item is reported
        under both ``age_group`` and ``age``; the integer index is reported
        under both ``image_id`` and ``person_id``.
        """

        def __init__(self, ds: Dataset):
            # Wrapped dataset; items are indexed as item[0] (image) and item[1] (label).
            self._ds = ds

        def __len__(self):
            return len(self._ds)

        def __getitem__(self, index):
            # Fetch once so any transform on the wrapped dataset runs a single time.
            sample = self._ds[index]
            picture, label = sample[0], sample[1]
            return dict(
                image=picture,
                image_id=index,
                age_group=label,
                age=label,
                person_id=index,
            )
    
    def get_dataset(batch_size, thinning_param: int=1, root=test_dir) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader, DataLoader]:
        """Build the CIFAR-10 data loaders used in testing mode.

        Args:
            batch_size: batch size shared by all returned loaders.
            thinning_param: currently unused; kept so existing callers keep working.
            root: directory CIFAR-10 is downloaded to / read from.

        Returns:
            ``(train_loader, retain_loader, forget_loader, validation_loader,
            test_loader)``. All loaders shuffle and yield ``PublicDataset`` dicts.

        Raises:
            requests.HTTPError: if the forget-index file cannot be downloaded.
        """

        # utils
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # create dataset (honour the `root` argument instead of the global test_dir)
        train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=normalize)
        train_ds = PublicDataset(train_set)

        # download the forget and retain index split (cached in the working directory)
        local_path = "forget_idx.npy"
        if not os.path.exists(local_path):
            response = requests.get(
                "https://storage.googleapis.com/unlearning-challenge/" + local_path,
                timeout=60,
            )
            # Fail loudly instead of caching an HTTP error page as the index file.
            response.raise_for_status()
            with open(local_path, "wb") as index_file:
                index_file.write(response.content)

        forget_idx = np.load(local_path)

        # construct indices of retain from those of the forget set
        forget_mask = np.zeros(len(train_set.targets), dtype=bool)
        forget_mask[forget_idx] = True
        retain_idx = np.arange(forget_mask.size)[~forget_mask]

        # split train set into a forget and a retain set
        forget_ds = Subset(train_ds, forget_idx)
        retain_ds = Subset(train_ds, retain_idx)

        full_val_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=normalize)

        # Halve the official CIFAR-10 test split into a test and a validation part.
        # No generator is passed, so the split draws from torch's global RNG.
        test_set, val_set = torch.utils.data.random_split(full_val_set, [0.5, 0.5])

        val_ds = PublicDataset(val_set)
        test_ds = PublicDataset(test_set)

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        retain_loader = DataLoader(retain_ds, batch_size=batch_size, shuffle=True)
        forget_loader = DataLoader(forget_ds, batch_size=batch_size, shuffle=True)
        validation_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=True)

        return train_loader, retain_loader, forget_loader, validation_loader, test_loader

In [6]:
# Helper functions for loading the hidden dataset.

if not test:
    
    def load_example(df_row):
        """Read the image referenced by one metadata row and bundle it with that row's fields.

        Args:
            df_row: mapping-like row providing ``image_path``, ``image_id``,
                ``age_group``, ``age`` and ``person_id``.

        Returns:
            dict with the decoded ``image`` plus the four metadata fields.
        """
        example = {'image': torchvision.io.read_image(df_row['image_path'])}
        # Copy the metadata fields across in a fixed order.
        for field in ('image_id', 'age_group', 'age', 'person_id'):
            example[field] = df_row[field]
        return example


    class HiddenDataset(Dataset):
        '''The hidden dataset.

        Reads ``<split>.csv`` from the competition input directory, derives each
        image path from its ``image_id`` ("<folder>-<name>" -> images/<folder>/<name>.png),
        and eagerly loads every example into memory, ordered by image path.

        Raises:
            ValueError: if the split contains no examples.
        '''
        def __init__(self, split='train'):
            super().__init__()
            self.examples = []

            df = pd.read_csv(f'/kaggle/input/neurips-2023-machine-unlearning/{split}.csv')
            df['image_path'] = df['image_id'].apply(
                lambda x: os.path.join('/kaggle/input/neurips-2023-machine-unlearning/', 'images', x.split('-')[0], x.split('-')[1] + '.png'))
            # Sort so the example order does not depend on the CSV row order.
            df = df.sort_values(by='image_path')
            df.apply(lambda row: self.examples.append(load_example(row)), axis=1)
            if len(self.examples) == 0:
                raise ValueError('No examples.')

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, idx):
            '''Return a copy of example ``idx`` with its image converted to float32.'''
            stored = self.examples[idx]
            # Shallow-copy instead of writing the float32 tensor back into the cache:
            # the cached example stays exactly as loaded, so memory does not grow as
            # samples are visited and repeated reads never see mutated state.
            example = dict(stored)
            example['image'] = stored['image'].to(torch.float32)
            return example


    def get_dataset(batch_size):
        '''Build shuffling loaders for the retain, forget and validation splits.

        Each loader shuffles with its own module-level generator
        (``G_retain``, ``G_forget``, ``G_validate``).

        Returns:
            ``(retain_loader, forget_loader, validation_loader)``.
        '''
        split_generators = (
            ('retain', G_retain),
            ('forget', G_forget),
            ('validation', G_validate),
        )
        # One HiddenDataset + DataLoader per split, in retain/forget/validation order.
        return tuple(
            DataLoader(HiddenDataset(split=split_name), batch_size=batch_size,
                       shuffle=True, generator=split_gen)
            for split_name, split_gen in split_generators
        )

# Unlearning Algorithm (1st Place)

https://www.kaggle.com/competitions/neurips-2023-machine-unlearning/discussion/458721

### Kullback-Leibler Divergence Loss 

https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss

Used in the **1st stage**

In [7]:
def kl_loss_sym(x,y):
    """
    Two-directional KL divergence between ``softmax(x)`` and ``y``.

    The result is ``KL(y || softmax(x)) + 0.85 * KL(softmax(x) || y)``, each
    term summed over all elements and divided by the batch size
    (``reduction='batchmean'``). The reverse direction is down-weighted by
    0.85, so the value is not exactly symmetric in its two arguments.

    Args:
        x (torch.Tensor): Raw scores (logits); softmax is applied over the last dim.
        y (torch.Tensor): Probability distribution with the same shape as ``x``.
            Its logarithm is taken, so entries should be strictly positive.

    Returns:
        torch.Tensor: Scalar loss.
    """
    batch_kl = nn.KLDivLoss(reduction='batchmean')

    log_probs = torch.log_softmax(x, dim=-1)
    probs = torch.softmax(x, dim=-1)

    forward_term = batch_kl(log_probs, y)      # KL(y || softmax(x))
    reverse_term = batch_kl(y.log(), probs)    # KL(softmax(x) || y)
    return forward_term + 0.85 * reverse_term

### Jensen-Shannon Divergence (JSD) Loss 

https://discuss.pytorch.org/t/jensen-shannon-divergence/2626/10<br>
https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence

Used instead of **KL-Div Loss**

In [8]:
def jsd_loss(x, y):
    """
    Jensen-Shannon divergence between ``softmax(x)`` and ``softmax(y)``.

    With ``p = softmax(x)``, ``q = softmax(y)`` and ``m = (p + q) / 2`` this
    returns ``0.5 * (KL(p || m) + KL(q || m))``, summed over all elements and
    divided by the batch size. The value is symmetric in ``x`` and ``y``, is
    zero when both inputs induce the same distribution, and is at most
    ``ln 2`` per sample.

    The previous implementation delegated to ``kl_loss_sym``, which applied a
    second softmax to the already-normalised inputs and added an asymmetric
    0.85-weighted reverse term, so it did not compute the JSD.

    Args:
        x (torch.Tensor): Scores of shape (batch, classes); softmax is applied over dim 1.
        y (torch.Tensor): Scores of the same shape; softmax is applied over dim 1.
            A constant row (e.g. a uniform pseudo-label) maps to the uniform
            distribution.

    Returns:
        torch.Tensor: Scalar JSD loss.
    """

    # Softmax for normalization
    softmax = nn.Softmax(dim=1)
    p = softmax(x)
    q = softmax(y)

    # Log of the mixture distribution; the clamp guards log(0) if both
    # probabilities underflow to zero for some class.
    log_m = (0.5 * (p + q)).clamp_min(1e-12).log()

    # KLDivLoss(input=log m, target=p) evaluates KL(p || m).
    kl = nn.KLDivLoss(reduction='batchmean')
    return 0.5 * (kl(log_m, p) + kl(log_m, q))

### Contrastive Learning Loss

Used in the **second stage forget round**

In [9]:
def cl_loss(outputs_forget, outputs_retain, tau=1.15):
    """Contrastive loss between forget-batch and retain-batch outputs.

    Builds the (forget x retain) similarity matrix ``outputs_forget @ outputs_retain.T``,
    divides it by the temperature ``tau``, applies log-softmax along the
    retain axis and returns the negated mean over all entries.

    Args:
        outputs_forget (torch.Tensor): Outputs for the forget batch, shape (n_forget, dim).
        outputs_retain (torch.Tensor): Outputs for the retain batch, shape (n_retain, dim).
        tau (float): Temperature coefficient.

    Returns:
        torch.Tensor: Scalar loss.
    """
    similarity = (outputs_forget @ outputs_retain.T) / tau
    log_probs = torch.log_softmax(similarity, dim=-1)
    return torch.neg(log_probs).mean()

### 2 Stage Training For Unlearning

In [10]:
# this will run at the time of submission

if not test:

    def unlearning(net, retain_loader, forget_loader, val_loader):
        """Two-stage unlearning used for the submission run; updates ``net`` in place.

        Stage 1 pushes the model's outputs on the forget set towards a uniform
        distribution (one epoch, ``kl_loss_sym``). Stage 2 alternates, for 8
        epochs, a "forget round" (contrastive loss between forget and retain
        outputs) with a "retain round" (cross-entropy on the retain set).

        Args:
            net: model to unlearn; trained in place.
            retain_loader: loader over the retain set; only its ``.dataset`` is
                used, re-wrapped with batch size 256.
            forget_loader: loader over the forget set.
            val_loader: not used by this function.

        Returns:
            The same ``net`` object after training.
        """

        # First Stage: Forgetting Stage (KL-Divergence Optimization)
        epochs = 1
        criterion = kl_loss_sym
        optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9, weight_decay=0)
        scheduler = None      # no LR schedule: this stage is a single epoch

        for ep in range(epochs):
            net.train()
            for sample in forget_loader:
                inputs = sample["image"]
                inputs = inputs.to(DEVICE)

                optimizer.zero_grad()
                outputs = net(inputs)
    
                # Target: the same probability 1 / num_outputs for every class.
                uniform_psedo_label = torch.ones_like(outputs).to(DEVICE) / outputs.shape[1]            
                loss = criterion(outputs, uniform_psedo_label)
            
                loss.backward()
                optimizer.step()

                
        # Second Stage: Adversarial Fine Tuning (1. Forget Round , 2. Retain Round)
        epochs = 8
        retain_batch_size = 256
        criterion_forget = cl_loss
        criterion_retain = nn.CrossEntropyLoss()
        optimizer_forget = optim.SGD(net.parameters(), lr=3e-4, momentum=0.9, weight_decay=0)
        # Retain learning rate is scaled linearly with the batch size (0.001 per 64 samples).
        optimizer_retain = optim.SGD(net.parameters(), lr=0.001 * retain_batch_size / 64 , 
                                     momentum=0.9, weight_decay=0.01)
        # Cosine decay over the forget-round steps; stepped once per forget batch below.
        scheduler_forget = optim.lr_scheduler.CosineAnnealingLR(optimizer_forget, 
                                                         T_max=epochs*len(forget_loader), eta_min=1e-6)
        
        
        # Two independently shuffled loaders over the retain dataset
        # (one for the forget round, another for the retain round).
        # No generator is passed here, so shuffling is not tied to G_retain;
        # presumably it follows torch's global RNG seeded earlier in the notebook.
            
        retain_loader_forget = DataLoader(retain_loader.dataset, batch_size=retain_batch_size, 
                                          shuffle=True)
        retain_loader_retain = DataLoader(retain_loader.dataset, batch_size=retain_batch_size, 
                                          shuffle=True)

        net.train()
        
        for ep in range(epochs):
            net.train()
            
            # forget round
            for _ in range(1):    # single pass per epoch (kept as a tunable repeat count)
                # zip stops at whichever loader runs out of batches first.
                for sample_forget, sample_retain in zip(forget_loader, retain_loader_forget):
                    inputs_forget, inputs_retain = sample_forget["image"], sample_retain["image"]
                    inputs_forget, inputs_retain = inputs_forget.to(DEVICE), inputs_retain.to(DEVICE)

                    optimizer_forget.zero_grad()
                    # Retain outputs are detached: gradients flow only through the forget batch.
                    outputs_forget, outputs_retain = net(inputs_forget), net(inputs_retain).detach()

                    # contrastive learning loss
                    loss = criterion_forget(outputs_forget, outputs_retain)
                    loss.backward()
                    optimizer_forget.step()

                    scheduler_forget.step()
                
            # retain round
            for sample in retain_loader_retain:
                inputs, labels = sample["image"], sample["age_group"]
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

                optimizer_retain.zero_grad()
                outputs = net(inputs)
                
                # cross entropy loss
                loss = criterion_retain(outputs, labels)
                loss.backward()
                optimizer_retain.step()
                
        return net

# Evaluation Using Loss & Accuracy

In [11]:
def calculate_acc_loss(net, dataloader, criterion, device = 'cuda'):
    """Evaluate ``net`` on ``dataloader`` and return ``(mean_loss, mean_acc)``.

    The model is switched to eval mode and left there. Evaluation runs under
    ``torch.no_grad()`` so no autograd graph is built for the forward passes.

    Args:
        net: model to evaluate.
        dataloader: iterable of dict batches with ``'image'`` and ``'age_group'``
            tensors; must support ``len()``.
        criterion: callable ``(predictions, labels) -> scalar loss tensor``.
        device: device the batches are moved to.

    Returns:
        tuple: ``mean_loss`` is the average of the per-batch criterion values
        (not weighted by batch size); ``mean_acc`` is top-1 accuracy in percent.
    """
    net.eval()
    total_samp = 0
    total_acc = 0
    total_loss = 0.0

    # Pure evaluation: gradients are never needed here.
    with torch.no_grad():
        for sample in dataloader:
            images, labels = sample['image'].to(device), sample['age_group'].to(device)
            _pred = net(images)
            total_samp += len(labels)
            loss = criterion(_pred, labels)
            total_loss += loss.item()
            # Count samples whose arg-max class equals the label.
            total_acc += (_pred.max(1)[1] == labels).float().sum().item()

    mean_loss = total_loss / len(dataloader)
    mean_acc = total_acc / total_samp * 100.0

    return mean_loss, mean_acc

In [12]:
# this will be run at the time of testing
# same algorithm as before; just some extra validation and print the results to check algorithm

if test:
    
    def unlearning(net, retain_loader, forget_loader, val_loader):
        """Testing-mode unlearning: the two-stage procedure plus progress reporting.

        Stage 1 pushes outputs on the forget set towards a uniform distribution
        (one epoch, ``jsd_loss``). Stage 2 alternates, for 8 epochs, a forget
        round (contrastive loss against retain outputs) with a retain round
        (cross-entropy). Loss and accuracy on the forget / retain / validation
        data are collected before and after stage 1 and after every stage-2
        epoch, then printed as a table.

        Args:
            net: model to unlearn; trained in place.
            retain_loader: loader over the retain set; only its ``.dataset`` is
                used, re-wrapped with batch size 256.
            forget_loader: loader over the forget set.
            val_loader: loader over the validation set, used for reporting only.

        Returns:
            The same ``net`` object after training.
        """

        print("----------------------------------")

        # The loss column holds a mean cross-entropy value, not a percentage.
        col_names = ["Stage-Epoch", "Set Type", "Loss", "Accuracy(%)"]
        table = []

        # test criterion
        criterion_test = nn.CrossEntropyLoss()

        def _record(stage, set_type, loader):
            # Evaluate the current model on `loader` and append one table row.
            loss_value, acc_value = calculate_acc_loss(net, loader, criterion_test, device=DEVICE)
            table.append([stage, set_type, loss_value, acc_value])

        # First Stage: push forget-set outputs towards uniform (JSD optimization)
        epochs = 1
        criterion = jsd_loss
        optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9, weight_decay=0)

        # testing (uses the val_loader argument rather than a notebook global)
        net.eval()
        _record("Before 1st Stage", "Forget", forget_loader)
        _record("Before 1st Stage", "Valid", val_loader)

        for ep in range(epochs):
            net.train()
            for sample in forget_loader:
                inputs = sample["image"]
                inputs = inputs.to(DEVICE)

                optimizer.zero_grad()
                outputs = net(inputs)

                # Target: the same probability 1 / num_outputs for every class.
                uniform_psedo_label = torch.ones_like(outputs).to(DEVICE) / outputs.shape[1]
                loss = criterion(outputs, uniform_psedo_label)

                loss.backward()
                optimizer.step()

        # testing
        _record("After 1st Stage", "Forget", forget_loader)
        _record("After 1st Stage", "Valid", val_loader)

        # Second Stage: Adversarial Fine Tuning (1. Forget Round , 2. Retain Round)
        epochs = 8
        retain_batch_size = 256
        criterion_forget = cl_loss
        criterion_retain = nn.CrossEntropyLoss()
        optimizer_forget = optim.SGD(net.parameters(), lr=3e-4, momentum=0.9, weight_decay=0)
        # Retain learning rate is scaled linearly with the batch size (0.001 per 64 samples).
        optimizer_retain = optim.SGD(net.parameters(), lr=0.001 * retain_batch_size / 64,
                                     momentum=0.9, weight_decay=0.01)
        # Cosine decay over the forget-round steps; stepped once per forget batch below.
        scheduler_forget = optim.lr_scheduler.CosineAnnealingLR(optimizer_forget,
                                                         T_max=epochs*len(forget_loader), eta_min=1e-6)

        # Two independently shuffled loaders over the retain dataset
        # (one for the forget round, another for the retain round).
        retain_loader_forget = DataLoader(retain_loader.dataset, batch_size=retain_batch_size,
                                          shuffle=True)
        retain_loader_retain = DataLoader(retain_loader.dataset, batch_size=retain_batch_size,
                                          shuffle=True)

        net.train()

        for ep in tqdm.trange(epochs):
            # Evaluation at the end of the previous epoch left the model in eval mode.
            net.train()

            # forget round
            for _ in range(1):    # single pass per epoch (kept as a tunable repeat count)
                # zip stops at whichever loader runs out of batches first.
                for sample_forget, sample_retain in zip(forget_loader, retain_loader_forget):
                    inputs_forget, inputs_retain = sample_forget["image"], sample_retain["image"]
                    inputs_forget, inputs_retain = inputs_forget.to(DEVICE), inputs_retain.to(DEVICE)

                    optimizer_forget.zero_grad()
                    # Retain outputs are detached: gradients flow only through the forget batch.
                    outputs_forget, outputs_retain = net(inputs_forget), net(inputs_retain).detach()

                    # contrastive learning loss
                    loss = criterion_forget(outputs_forget, outputs_retain)
                    loss.backward()
                    optimizer_forget.step()

                    scheduler_forget.step()

            # retain round
            for sample in retain_loader_retain:
                inputs, labels = sample["image"], sample["age_group"]
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

                optimizer_retain.zero_grad()
                outputs = net(inputs)

                # cross entropy loss
                loss = criterion_retain(outputs, labels)
                loss.backward()
                optimizer_retain.step()

            # testing
            stage_epoch = "Stage 2 - Epoch " + str(ep)
            _record(stage_epoch, "Forget", forget_loader)
            _record(stage_epoch, "Retain", retain_loader_retain)
            _record(stage_epoch, "Valid", val_loader)

        print(tabulate(table, headers=col_names, tablefmt="fancy_grid"))

        return net

### Loss Acc Evaluation & Test Submission

In [13]:
if test:

    n_checkpoints = 1  # in the submission, there will be 512 points

    # Fetch the pretrained CIFAR-10 weights once and cache them under test_dir.
    if not os.path.exists(test_model_path):
        response = requests.get(
            "https://storage.googleapis.com/unlearning-challenge/weights_resnet18_cifar10.pth",
            timeout=60)
        # Fail loudly instead of caching an HTTP error page as the weights file.
        response.raise_for_status()
        with open(test_model_path, "wb") as weights_file:
            weights_file.write(response.content)

    os.makedirs('/kaggle/tmp', exist_ok=True)
    random.seed(42)   # just for reproducibility

    train_loader, retain_loader, forget_loader, validation_loader, test_loader = get_dataset(64)
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)
    for i in tqdm.trange(n_checkpoints):
        # Restart from the pretrained weights for every checkpoint.
        # NOTE(review): torch.load unpickles the downloaded file; only use trusted sources.
        net.load_state_dict(torch.load(test_model_path, map_location=DEVICE))
        net_ = unlearning(net, retain_loader, forget_loader, validation_loader)
        state = net_.state_dict()
        torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

#     # Ensure that submission.zip will contain exactly n_checkpoints 
#     unlearned_ckpts = os.listdir('/kaggle/tmp')
#     if len(unlearned_ckpts) != n_checkpoints:
#         raise RuntimeError('The submission will throw an exception otherwise.')

#     subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True)

## Comparison With Trained Model Exclusively on Retain Set 

In [14]:
if test:

    # download weights of a model trained exclusively on the retain set
    local_path = "retrain_weights_resnet18_cifar10.pth"
    if not os.path.exists(local_path):
        response = requests.get(
            "https://storage.googleapis.com/unlearning-challenge/" + local_path,
            timeout=60,
        )
        # Fail loudly instead of caching an HTTP error page as the weights file.
        response.raise_for_status()
        with open(local_path, "wb") as weights_file:
            weights_file.write(response.content)

    # NOTE(review): torch.load unpickles the downloaded file; only use trusted sources.
    weights_pretrained = torch.load(local_path, map_location=DEVICE)

    # load model with pre-trained weights
    rt_model = resnet18(weights=None, num_classes=10)
    rt_model.load_state_dict(weights_pretrained)
    rt_model.to(DEVICE)
    rt_model.eval()

    # test criterion
    criterion = nn.CrossEntropyLoss()

    # The unlearned model produced by the checkpoint cell.
    net = net_

    # Report accuracy and loss of both models on the same four splits, in the same order.
    eval_splits = (
        ("Retain", retain_loader),
        ("Forget", forget_loader),
        ("Validation", validation_loader),
        ("Test", test_loader),
    )
    eval_models = (
        ("------------ Exclusive Retrained Model---------------", rt_model),
        ("------------ Unlearned Model ---------------", net),
    )
    for banner, eval_model in eval_models:
        print(banner)
        for split_name, split_loader in eval_splits:
            l, a = calculate_acc_loss(eval_model, split_loader, criterion)
            print(f"{split_name} set accuracy: {a:0.2f}%")
            print(f"{split_name} set loss: {l:0.2f}")

# Evaluation using MIA

**Reference:** https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb<br>
We will evaluate the trained models using Simple Membership Inference Attacks(MIA). This is **not used** as evaluation metric for the competition.

This MIA consists of a **logistic regression model** that predicts whether the model was trained on a particular sample from that sample's loss. To get an idea on the difficulty of this problem, we first plot below a histogram of the losses of the pre-trained models

## Visualize Pre-trained Model 

In [15]:
if test:
    
    def compute_losses(model_, loader):
        """Auxiliary function to compute per-sample losses.

        Args:
            model_: model to evaluate; its train/eval mode is left as the caller set it.
            loader: iterable of dict batches with ``'image'`` and ``'age_group'`` tensors.

        Returns:
            np.ndarray: 1-D array with one cross-entropy loss per sample, in the
            order the loader yields them (empty if the loader yields nothing).
        """

        criterion = nn.CrossEntropyLoss(reduction="none")
        batch_losses = []

        # Inference only: skip building autograd graphs for the forward passes.
        with torch.no_grad():
            for sample in loader:
                images, labels = sample['image'].to(DEVICE), sample['age_group'].to(DEVICE)
                logits = model_(images)
                batch_losses.append(criterion(logits, labels).numpy(force=True))

        if not batch_losses:
            return np.array([])
        # Join the per-batch arrays in one step instead of appending element by element.
        return np.concatenate(batch_losses)

In [16]:
if test:

    # load model with pre-trained weights
    model = resnet18(weights=None, num_classes=10)
    weights_pretrained = torch.load(test_model_path, map_location=DEVICE)
    model.load_state_dict(weights_pretrained)
    model.to(DEVICE)
    model.eval()

    # Per-sample losses of the pre-trained model; reused by later MIA cells.
    retain_losses = compute_losses(model, retain_loader)
    forget_losses = compute_losses(model, forget_loader)
    test_losses = compute_losses(model, test_loader)

    # The third histogram is the test split, so the title names it as such.
    plt.title("Losses on retain, forget and test set (pre-trained model)")
    plt.hist(retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")
    plt.hist(forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    plt.hist(test_losses, density=True, alpha=0.5, bins=50, label="Test set")
    plt.xlabel("Loss", fontsize=14)
    plt.ylabel("Frequency", fontsize=14)
    # Clip the x-axis to the largest test loss.
    plt.xlim((0, np.max(test_losses)))
    plt.yscale("log")
    plt.legend(frameon=False, fontsize=14)
    ax = plt.gca()
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    plt.show()

As per the above plot, the distributions of losses are quite different between the training data (retain and forget sets) and the held-out test set, as expected. In what follows, we will define an MIA that leverages the fact that examples that were trained on have smaller losses compared to examples that weren't.

## MIA Implementation 

Now, we will define an MIA that leverages the fact that examples that were trained on have smaller losses compared to examples that weren't. Using this fact, the simple MIA defined below will aim to infer whether the forget set was in fact part of the training set.

This MIA is defined below. It takes as input the per-sample losses of the unlearned model on forget and test examples, and a membership label (0 or 1) indicating which of those two groups each sample comes from. It then returns the cross-validation accuracy of a linear model trained to distinguish between the two classes.

Intuitively, an unlearning algorithm is successful with respect to this simple metric if the attacker isn't able to distinguish the forget set from the test set any better than it would for the ideal unlearning algorithm (retraining from scratch without the forget set); see the last part of this MIA section for additional discussion and for computing that reference point.

In [17]:
if test:
    
    def simple_mia(sample_loss, members, n_splits=10, random_state=42):
        """Computes cross-validation score of a membership inference attack.

        A logistic-regression attacker is trained to predict ``members`` from
        ``sample_loss`` and scored by accuracy over stratified shuffle splits.

        Args:
          sample_loss : array_like of shape (n, 1).
            objective function evaluated on n samples, one feature column.
          members : array_like of shape (n,),
            whether a sample was used for training (0 or 1); both classes
            must be present.
          n_splits: int
            number of splits to use in the cross-validation.
          random_state: int
            seed for the shuffle splits.
        Returns:
          scores : array_like of size (n_splits,)
        Raises:
          ValueError: if the set of labels in ``members`` is not exactly {0, 1}.
        """
        unique_members = np.unique(members)
        # array_equal also copes with label sets whose size is not 2, where an
        # elementwise == against [0, 1] would mis-broadcast or fail.
        if not np.array_equal(unique_members, np.array([0, 1])):
            raise ValueError("members should only have 0 & 1s")

        attack_model = linear_model.LogisticRegression()
        cv = model_selection.StratifiedShuffleSplit(
            n_splits=n_splits, random_state=random_state
        )
        return model_selection.cross_val_score(
            attack_model, sample_loss, members, cv=cv, scoring="accuracy"
        )

### MIA on Original Model 

As a reference point, we first compute the accuracy of the MIA on the original model to distinguish between the forget set and the test set.

In [18]:
if test:
    
    np.random.seed(42)   # just for reproducibility
    
    # Recompute the pre-trained model's forget losses; this overwrites the
    # forget_losses produced by the histogram cell with a fresh array.
    forget_losses = compute_losses(model, forget_loader)
    
    # Shuffle, then keep at most as many forget losses as there are test losses,
    # so the attack sees a class-balanced dataset when the forget set is larger.
    np.random.shuffle(forget_losses)
    forget_losses = forget_losses[: len(test_losses)]

    # One feature column of losses; label 0 = test (unseen), 1 = forget (member).
    samples_mia = np.concatenate((test_losses, forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(test_losses) + [1] * len(forget_losses)

    mia_scores = simple_mia(samples_mia, labels_mia)

    print(f"The MIA has an accuracy of {mia_scores.mean():.3f} on forgotten vs unseen images")

### MIA on Unlearned Model 

We'll now compute the accuracy of the MIA on the unlearned model. We expect the MIA to be less accurate on the unlearned model than on the original model, since the original model has not undergone a procedure to unlearn the forget set.

In [19]:
if test:

    # Per-sample losses of the unlearned model on each split.
    net_forget_losses = compute_losses(net, forget_loader)
    net_retain_losses = compute_losses(net, retain_loader)
    net_test_losses = compute_losses(net, test_loader)

    # Shuffle, then keep at most as many forget losses as this model has test
    # losses. The size comes from net_test_losses computed above, not from the
    # pre-trained model's test_losses left over from an earlier cell.
    np.random.shuffle(net_forget_losses)
    net_forget_losses = net_forget_losses[: len(net_test_losses)]

    # One feature column of losses; label 0 = test (unseen), 1 = forget.
    net_samples_mia = np.concatenate((net_test_losses, net_forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(net_test_losses) + [1] * len(net_forget_losses)

    net_mia_scores = simple_mia(net_samples_mia, labels_mia)

    print(f"The MIA has an accuracy of {net_mia_scores.mean():.3f} on forgotten vs unseen images")

## Comparison With Original Model 

From the score above, the MIA is indeed less accurate on the unlearned model than on the original model, as expected. Finally, we'll plot the histograms of the losses of both models on the retain, forget and test sets. From the below figure, we can observe that the distributions of forget and test losses are more similar under the unlearned model compared to the original model, as expected.

In [20]:
if test:

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # (axis, title, losses in Test / Forget / Retain order) for each panel.
    panels = (
        (ax1, f"Pre-trained model.\nAttack accuracy: {mia_scores.mean():0.2f}",
         (test_losses, forget_losses, retain_losses)),
        (ax2, f"Unlearned by fine-tuning.\nAttack accuracy: {net_mia_scores.mean():0.2f}",
         (net_test_losses, net_forget_losses, net_retain_losses)),
    )
    set_labels = ("Test set", "Forget set", "Retain set")

    for panel_ax, panel_title, loss_groups in panels:
        panel_ax.set_title(panel_title)
        for losses, set_label in zip(loss_groups, set_labels):
            panel_ax.hist(losses, density=True, alpha=0.5, bins=50, label=set_label)

    # Shared cosmetics; both panels use the pre-trained model's largest test loss
    # as the x-axis limit so they are drawn on the same scale.
    for ax in (ax1, ax2):
        ax.set_xlabel("Loss")
        ax.set_yscale("log")
        ax.set_xlim((0, np.max(test_losses)))
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    ax1.set_ylabel("Frequency")
    ax1.legend(frameon=False, fontsize=14)
    plt.show()

## Comparison With Trained Model Exclusively on Retain Set 

Since our goal is to approximate the model that has been trained only on the retain set, we'll consider that the gold standard is the score achieved by this model. Intuitively, we expect the MIA accuracy to be around 0.5, since for such a model, both the forget and test set are unseen samples from the same distribution. However, a number of factors such as distribution shift or class imbalance can make this number vary.

First, we compute the MIA score of the model re-trained exclusively on the retain set.

In [21]:
if test:

    # Per-sample losses of the model re-trained on the retain set only.
    rt_test_losses = compute_losses(rt_model, test_loader)
    rt_forget_losses = compute_losses(rt_model, forget_loader)
    rt_retain_losses = compute_losses(rt_model, retain_loader)

    # NOTE(review): unlike the two earlier MIA cells, the forget losses are not
    # shuffled and truncated to the test-set size here, so the two classes are
    # balanced only if both splits have the same number of samples - confirm.
    # Label 0 = test (unseen), 1 = forget.
    rt_samples_mia = np.concatenate((rt_test_losses, rt_forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(rt_test_losses) + [1] * len(rt_forget_losses)
    
    rt_mia_scores = simple_mia(rt_samples_mia, labels_mia)

    print(f"The MIA has an accuracy of {rt_mia_scores.mean():.3f} on forgotten vs unseen images")

Finally, as we've done before, let's compare the histograms of this ideal algorithm (re-trained model) vs. the original model and the model obtained from our unlearning algorithm.

In [22]:
if test:

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6))

    # (axis, title, losses in Test / Forget / Retain order) for each panel.
    panels = (
        (ax1, f"Original model.\nAttack accuracy: {mia_scores.mean():0.2f}",
         (test_losses, forget_losses, retain_losses)),
        (ax2, f"Re-trained model.\nAttack accuracy: {rt_mia_scores.mean():0.2f}",
         (rt_test_losses, rt_forget_losses, rt_retain_losses)),
        (ax3, f"Unlearned by fine-tuning.\nAttack accuracy: {net_mia_scores.mean():0.2f}",
         (net_test_losses, net_forget_losses, net_retain_losses)),
    )
    set_labels = ("Test set", "Forget set", "Retain set")

    for panel_ax, panel_title, loss_groups in panels:
        panel_ax.set_title(panel_title)
        for losses, set_label in zip(loss_groups, set_labels):
            panel_ax.hist(losses, density=True, alpha=0.5, bins=50, label=set_label)

    # Shared cosmetics; every panel uses the original model's largest test loss
    # as the x-axis limit so all three are drawn on the same scale.
    for ax in (ax1, ax2, ax3):
        ax.set_xlabel("Loss")
        ax.set_yscale("log")
        ax.set_xlim((0, np.max(test_losses)))
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

    ax1.set_ylabel("Frequency")
    ax1.legend(frameon=False, fontsize=14)
    plt.show()

# Evaluation using Combination of Forgetting Quality, Efficiency and Utility

A metric close to this is **actually used** in the competition<br>
**Reference:** https://unlearning-challenge.github.io/assets/data/Machine_Unlearning_Metric.pdf

It implements Algorithm 1, the *forgetting quality* $\mathcal{F}$, and the total scoring function from the above article. It also provides a version of a full scoring function, meant to be mimicking the one used for the competition, and uses it to score a "perfect" unlearning algorithm on CIFAR-10.

Note that the article does not include all details needed for reproducing the code for computing this metric, thus some gaps are filled in in an improvised manner. The following values and functions are unknown: 
* the value of $\delta$,
* the summary statistic $f$ that summarises the outputs of the models $R_i$ and $U_i$ into a scalar, and
* the attacks that yield the false positive rates (FPRs) and false negative rates (FNRs) needed for Algorithm 1.

**The following values are used for this notebook:**
* $\delta = 0.01$
* $f = $ cross entropy loss
* attacks = [ logistic_regression, best_threshold_attack ]

In [23]:
if test:
    from copy import deepcopy
    from typing import Callable
    from tqdm.notebook import tqdm

    from sklearn.metrics import make_scorer, accuracy_score

## Helper Functions 

In [24]:
if test:
    
    def accuracy(net, loader):
        """Return the fraction of correctly classified samples in `loader`.

        Args:
          net: model called as `net(inputs)`; its output is reduced with
            `max(1)`, so the predicted class is the arg-max over dimension 1.
          loader: iterable of batches, each a mapping with the keys
            'image' (model inputs) and 'age_group' (integer targets).
        Returns:
          correct / total as a Python float.
        Raises:
          ZeroDivisionError: if the loader yields no samples.
        """
        # NOTE(review): the train/eval mode of `net` is deliberately left
        # untouched here; callers should put the model in eval mode if
        # batch-norm / dropout must not behave as in training.
        correct = 0
        total = 0
        # Evaluation only: disable autograd so that no computation graph is
        # built (and kept alive) for every forward pass.
        with torch.no_grad():
            for sample in loader:
                inputs, targets = sample['image'].to(DEVICE), sample['age_group'].to(DEVICE)
                outputs = net(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        return correct / total


    def compute_outputs(net, loader):
        """Compute the model outputs (logits) for every datapoint of a loader.

        The data is never shuffled: if `loader` uses a `RandomSampler`, a
        sequential loader over the same dataset is built instead. That
        replacement loader only carries over `dataset`, `batch_size` and
        `num_workers`; any other setting of the original loader is not
        forwarded.

        Args:
          net: model called as `net(inputs)`.
          loader: DataLoader whose batches are mappings with an 'image' key.
        Returns:
          numpy array obtained by concatenating the per-batch outputs along
          the first axis, i.e. (len(loader.dataset), num_classes).
        Raises:
          ValueError: from `np.concatenate` if the loader yields no batch.
        """

        # Make sure loader does not shuffle the data
        if isinstance(loader.sampler, torch.utils.data.sampler.RandomSampler):
            loader = DataLoader(
                loader.dataset,
                batch_size=loader.batch_size,
                shuffle=False,
                num_workers=loader.num_workers)

        all_outputs = []

        # Inference only: no autograd graph is needed, and the targets are
        # not used, so they are no longer transferred to the device.
        with torch.no_grad():
            for sample in loader:
                inputs = sample['image'].to(DEVICE)

                logits = net(inputs).detach().cpu().numpy() # (batch_size, num_classes)

                all_outputs.append(logits)

        return np.concatenate(all_outputs) # (len(loader.dataset), num_classes)


    def false_positive_rate(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Fraction of true negatives (label 0) that were predicted as 1."""
        is_negative = (y_true == 0)
        false_alarms = np.logical_and(is_negative, (y_pred == 1))
        return np.sum(false_alarms) / np.sum(is_negative)


    def false_negative_rate(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Fraction of true positives (label 1) that were predicted as 0."""
        is_positive = (y_true == 1)
        misses = np.logical_and(is_positive, (y_pred == 0))
        return np.sum(misses) / np.sum(is_positive)


    # Scorers handed to sklearn's `cross_validate` so that every
    # cross-validation split reports both the FPR and the FNR.
    SCORING = {}
    SCORING['false_positive_rate'] = make_scorer(false_positive_rate)
    SCORING['false_negative_rate'] = make_scorer(false_negative_rate)

## The summary statistic $f$

Recall that $U^s$ and $R^s$ are the distributions of (scalar) outputs of unlearned and retrained models when receiving a particular forget set example $s \in S$ as input. That is: 

$$U^s = \{ f(U_1(s)), \dots f(U_N(s)) \}$$

$$R^s = \{ f(R_1(s)), \dots f(R_N(s)) \}$$

where $M(s)$ yields the outputs obtained by feeding example s into model $M$, $f$ is the function that transforms those outputs into a scalar, and $N$ is the number of times we retrain / unlearn to obtain an approximation of the distribution of retrained / unlearned models.

The actual function $f$ used for scoring is not disclosed in the competition article. Below is an $f$ that computes the cross entropy loss between the logit output x and the highest probability class as given by the model output x. This statistic $f$ corresponds to the (negative log) probability that the model assigns to the class that the model predicts with the highest probability. 

In [25]:
if test:
    def cross_entropy_f(x):
        """Summary statistic f: cross entropy of the logits w.r.t. the predicted class.

        For every row of logits, the target is the class with the highest
        logit (NaNs ignored), so the result is the negative log-probability
        the model assigns to its own prediction.

        Args:
          x: array of logits, with the classes on the last axis. The input
            is not modified.
        Returns:
          numpy array with one loss value per row of `x`.
        """
        # Work on a copy: the NaN clean-up below must not leak into the
        # caller's array.
        x = np.array(x, copy=True)

        # `np.nanargmax` raises on all-NaN rows, so substitute such rows
        # with all-zeros before looking for the predicted class.
        x[np.all(np.isnan(x), axis=-1)] = np.zeros(x.shape[-1])

        pred = torch.tensor(np.nanargmax(x, axis=-1))
        logits = torch.tensor(x)

        loss_fn = nn.CrossEntropyLoss(reduction="none")

        return loss_fn(logits, pred).numpy()

## Decision rules (attacks)
Next, we define the *decision rules (attacks)* to be used when computing the forgetting quality. 

A decision rule is a function that takes $U^s$ and $R^s$ as input and learns to predict whether a (scalar) output $y \in U^s \bigcup R^s$ has been generated by an unlearned or a retrained model. The rule then returns the false positive rate (FPR) and the false negative rate (FNR) of the prediction.

The original article does not specify what the decision rules used are, so the ones chosen here are improvised:
* Logistic regression attack: Train a logistic regression classifier to predict whether a scalar output comes from a retrained or an unlearned model. Use cross validation and return the average FPR and FNR across different test folds.
* Best threshold attack: Consider a model that predicts that everything below certain threshold comes from the retrained model, and everything above that threshold comes from the unlearned model. Calculate the FPR and FNR for all the possible thresholds. 

### Logistic Regression Attack

In [26]:
if test:
    
    def logistic_regression_attack(
            outputs_U, outputs_R, n_splits=2, random_state=0):
        """Attack that fits a logistic regression to tell unlearned from retrained outputs.

        Retrained outputs are labelled 0 and unlearned outputs 1. The
        classifier is scored with stratified shuffle-split cross-validation
        and the FPR / FNR are averaged over the test splits.

        Args:
          outputs_U: numpy array of shape (N), scalar outputs of unlearned models
          outputs_R: numpy array of shape (N), scalar outputs of retrained models
          n_splits: int
            number of splits to use in the cross-validation.
          random_state: int
            seed forwarded to the cross-validation splitter.
        Returns:
          fpr, fnr : float * float
        """
        n_retrained, n_unlearned = len(outputs_R), len(outputs_U)
        assert n_unlearned == n_retrained

        # Single scalar feature per sample -> column vector for sklearn.
        features = np.concatenate((outputs_R, outputs_U)).reshape((-1, 1))
        targets = np.array([0] * n_retrained + [1] * n_unlearned)

        splitter = model_selection.StratifiedShuffleSplit(
            n_splits=n_splits, random_state=random_state
        )
        cv_results = model_selection.cross_validate(
            linear_model.LogisticRegression(),
            features,
            targets,
            cv=splitter,
            scoring=SCORING,
        )

        mean_fpr = np.mean(cv_results["test_false_positive_rate"])
        mean_fnr = np.mean(cv_results["test_false_negative_rate"])
        return mean_fpr, mean_fnr

### Threshold Attack

In [27]:
if test:
    
    def best_threshold_attack(
            outputs_U: np.ndarray, 
            outputs_R: np.ndarray, 
            random_state: int = 0
        ) -> tuple[list[float], list[float]]:
        """Evaluates every single-threshold classifier on the pooled outputs.

        Each pooled value is used once as a threshold: samples strictly
        above it are predicted as unlearned (1), the others as retrained (0).
        One FPR and one FNR is reported per threshold.

        Args:
          outputs_U: numpy array of shape (N)
          outputs_R: numpy array of shape (N)
          random_state: accepted for a uniform attack signature; not used.
        Returns:
          fpr, fnr : list[float] * list[float]
        """
        assert len(outputs_U) == len(outputs_R)

        pooled = np.concatenate((outputs_R, outputs_U))
        truth = np.array([0] * len(outputs_R) + [1] * len(outputs_U))

        thresholds = sorted(list(pooled.squeeze()))
        predictions = ((pooled > cut).astype("int") for cut in thresholds)
        rates = [
            (false_positive_rate(truth, guess), false_negative_rate(truth, guess))
            for guess in predictions
        ]

        fprs = [fp_rate for fp_rate, _ in rates]
        fnrs = [fn_rate for _, fn_rate in rates]
        return fprs, fnrs

## The privacy degree $\epsilon_s$, and the scoring function $\mathcal{H}$

Next we specify the function `compute_epsilon_s`, which corresponds to Algorithm 1 from the scoring article. It considers a particular example $s$ from the forget set, takes the list of FPRs and FNRs as produced by various attacks, and computes the privacy degree $\epsilon^s$ of the example $s$. 

In [28]:
if test:
    
    def compute_epsilon_s(fpr: list[float], fnr: list[float], delta: float) -> float:
        """Privacy degree (epsilon) of one forget-set example.

        Every attack contributes a candidate epsilon derived from its
        (FPR, FNR) pair; the result is the largest candidate, and never
        less than 0. A smaller epsilon means better unlearning.

        Args:
          fpr: list[float] of length m = num attacks. The FPRs for a particular example. 
          fnr: list[float] of length m = num attacks. The FNRs for a particular example.
          delta: float
        Returns:
          epsilon: float corresponding to the privacy degree of the particular example.
        """
        assert len(fpr) == len(fnr)

        # 0 is always a candidate, so the result is bounded below by 0.
        candidates = [0.]
        for false_pos, false_neg in zip(fpr, fnr):
            fp_is_zero = (false_pos == 0)
            fn_is_zero = (false_neg == 0)

            if fp_is_zero and fn_is_zero:
                # The attack makes no error at all.
                candidates.append(np.inf)
                continue
            if fp_is_zero or fn_is_zero:
                # Exactly one rate is zero: the attack is discarded.
                continue

            # A negative log argument yields NaN; silence that warning only.
            with np.errstate(invalid='ignore'):
                bounds = [
                    np.log(1. - delta - false_pos) - np.log(false_neg),
                    np.log(1. - delta - false_neg) - np.log(false_pos),
                ]

            if np.isnan(bounds[0]) and np.isnan(bounds[1]):
                candidates.append(np.inf)
            else:
                candidates.append(np.nanmax(bounds))

        return np.nanmax(candidates)


    def bin_index_fn(
            epsilons: np.ndarray, 
            bin_width: float = 0.5, 
            B: int = 13
            ) -> np.ndarray:
        """Maps each epsilon to the index of its bin.

        The bin edges are 0, bin_width, ..., (B - 1) * bin_width and the
        index is the one returned by `np.digitize` for those edges.
        """
        edges = bin_width * np.arange(0, B)
        indices = np.digitize(epsilons, edges)
        return indices


    def F(epsilons: np.ndarray) -> float:
        """Forgetting quality: mean per-example score over the forget set.

        An example whose epsilon falls into bin index n scores 2 / 2**n,
        so the score halves with every further bin.
        """
        bin_ids = bin_index_fn(epsilons)
        per_example_scores = 2. / np.power(2, bin_ids)
        return np.mean(per_example_scores)

### The forgetting quality $\mathcal{F}$

The function below computes the forgetting quality $\mathcal{F}$ given only the (scalar) outputs of the $N$ retrain and $N$ unlearn models on all the forget set examples. It:
1. Iterates over each sample in $S$,
2. Performs the attacks specified to obtain lists of FPRs and FNRs,
3. Computes the privacy degree $\epsilon^s$ for each sample,
4. Computes the forgetting quality by averaging over the forget scores of all samples.

In [29]:
if test:
    
    def forgetting_quality(
            outputs_U: np.ndarray, # (N, S)
            outputs_R: np.ndarray, # (N, S)
            attacks: list[Callable] = [logistic_regression_attack],
            delta: float = 0.01
        ):
        """Forgetting quality computed from unlearned and retrained model outputs.

        `outputs_U` and `outputs_R` are 2-dimensional numpy arrays of
        identical shape:
        * axis 0 indexes the models sampled from each distribution
          (N=512 in the case of the competition's leaderboard),
        * axis 1 indexes the examples of the forget set (S).

        For every forget example, each attack is run on the corresponding
        column of both arrays, the resulting FPRs / FNRs are turned into a
        privacy degree, and all privacy degrees are finally passed to `F`.
        An attack may return either a single (fpr, fnr) pair or two lists.
        """

        # Unpacking also enforces that the input is 2-dimensional.
        n_models, n_forget = outputs_U.shape

        assert outputs_U.shape == outputs_R.shape, \
            "unlearn and retrain outputs need to be of the same shape"

        epsilons = []
        progress = tqdm(range(n_forget))
        for column in progress:
            progress.set_description("Computing F...")

            unlearned_values = outputs_U[:, column]
            retrained_values = outputs_R[:, column]

            collected_fprs, collected_fnrs = [], []
            for attack in attacks: 
                fpr, fnr = attack(unlearned_values, retrained_values)

                # List results contribute one entry per element,
                # scalar results a single entry.
                if isinstance(fpr, list):
                    collected_fprs.extend(fpr)
                    collected_fnrs.extend(fnr)
                else:
                    collected_fprs.append(fpr)
                    collected_fnrs.append(fnr)

            epsilons.append(
                compute_epsilon_s(collected_fprs, collected_fnrs, delta=delta))

        return F(np.array(epsilons))

### Global scoring function

Finally, we package everything together in a single function that takes an already unlearned model `net`, evaluates it `n` times and outputs a dictionary of various quantities, including:
* The forget quality $\mathcal{F}$
* The various accuracies $RA^U, TA^U, RA^R, TA^R$
* The total score = $\mathcal{F} \times \frac{RA^U}{RA^R} \times \frac{TA^U}{TA^R}$

Note that since we are provided with only one retrained model for the CIFAR dataset example, this function includes a little hack where we add some noise to the output of that retrained model a few times, instead of retraining from scratch. The correct version of this function would consider $N$ separate models retrained from scratch with different random seeds.  

In [30]:
if test:
    
    def score_unlearning_algorithm(
            data_loaders: dict, 
            pretrained_models: dict, 
            net,
            n: int = 10,
            delta: float = 0.01,
            f: Callable = cross_entropy_f,
            attacks: list[Callable] = [best_threshold_attack, logistic_regression_attack]
            ) -> dict:
        """Scores the unlearned model `net` against the retrained model.

        `net` is evaluated `n` times on the forget, retain and test loaders.
        Because only one retrained model is available, its `n` samples are
        faked by adding Gaussian noise (drawn from NumPy's global random
        state) to its forget-set outputs.

        Side effect: the scalar outputs are written to `outputs_U.npy` and
        `outputs_R.npy` in the current working directory.

        Args:
          data_loaders: dict with the keys "retain", "forget" and "testing".
          pretrained_models: dict with the key "retrained".
          net: the unlearned model to score.
          n: number of evaluations of `net` / noisy copies of the retrained outputs.
          delta: forwarded to `forgetting_quality`.
          f: summary statistic mapping model outputs to one scalar per example.
          attacks: decision rules forwarded to `forgetting_quality`.
        Returns:
          dict with the total score, the forgetting quality "F", the
          retain / test / forget accuracies of both models and the scalar
          outputs of both models.
        """

        # n=512 in the case of unlearn and n=1 in the
        # case of retrain, since we are only provided with one retrained model here

        retain_loader = data_loaders["retain"]
        forget_loader = data_loaders["forget"]
        test_loader = data_loaders["testing"]

        rt_model = pretrained_models["retrained"]

        outputs_U = []
        retain_accuracy = []
        test_accuracy = []
        forget_accuracy = []

        pbar = tqdm(range(n))
        for _ in pbar:

            # unlearned model
            pbar.set_description("Unlearning...")

            outputs_Ui = compute_outputs(net, forget_loader) 
            # The shape of outputs_Ui is (len(forget_loader.dataset), 10)
            # which for every datapoint is being cast to a scalar using the function f
            outputs_U.append( f(outputs_Ui) )

            pbar.set_description("Computing retain accuracy...")
            retain_accuracy.append(accuracy(net, retain_loader))

            pbar.set_description("Computing test accuracy...")
            test_accuracy.append(accuracy(net, test_loader))

            pbar.set_description("Computing forget accuracy...")
            forget_accuracy.append(accuracy(net, forget_loader))


        outputs_U = np.array(outputs_U) # (n, len(forget_loader.dataset))

        assert outputs_U.shape == (n, len(forget_loader.dataset)),\
            "Wrong shape for outputs_U. Should be (num_model_samples, num_forget_datapoints)."

        RAR = accuracy(rt_model, retain_loader)
        TAR = accuracy(rt_model, test_loader)
        FAR = accuracy(rt_model, forget_loader)

        RAU = np.mean(retain_accuracy)
        TAU = np.mean(test_accuracy)
        FAU = np.mean(forget_accuracy)

        RA_ratio = RAU / RAR
        TA_ratio = TAU / TAR

        # need to fake this a little because we only have one retrain model
        scale = np.std(outputs_U) / 10.
        outputs_Ri = compute_outputs(rt_model, forget_loader) #(len(forget_loader.dataset), 10) 
        outputs_Ri = np.expand_dims(outputs_Ri, axis=0)
        outputs_Ri = np.random.normal(
            loc=outputs_Ri, scale=scale, size=(n, *outputs_Ri.shape[-2:]))

        outputs_R = np.array([ f( oRi ) for oRi in outputs_Ri ])

        np.save("outputs_U.npy", outputs_U)
        np.save("outputs_R.npy", outputs_R)

        # Kept in its own name so that the callable `f` is not overwritten
        # by a float.
        forget_quality = forgetting_quality(
            outputs_U, 
            outputs_R,
            attacks=attacks,
            delta=delta)

        return {
            "total_score": forget_quality * RA_ratio * TA_ratio,
            "F": forget_quality,
            "unlearn_retain_accuracy": RAU,
            "unlearn_test_accuracy": TAU, 
            "unlearn_forget_accuracy": FAU,
            "retrain_retain_accuracy": RAR,
            "retrain_test_accuracy": TAR, 
            "retrain_forget_accuracy": FAR,
            "retrain_outputs": outputs_R,
            "unlearn_outputs": outputs_U
        }

## Check Our Algorithm 

In [31]:
if test:

    # Loaders for every data split, keyed by split name.
    data_loaders = dict([
        ("training", train_loader),
        ("testing", test_loader),
        ("validation", validation_loader),
        ("forget", forget_loader),
        ("retain", retain_loader),
    ])

    # The original model and the model retrained without the forget set.
    pretrained_models = dict([
        ("original", model),
        ("retrained", rt_model),
    ])

    ret = score_unlearning_algorithm(
        data_loaders=data_loaders,
        pretrained_models=pretrained_models,
        net=net,
    )

In [32]:
if test:
    # Column header -> value; accuracies are shown as percentages.
    summary = {
        "total score": ret["total_score"],
        "F score": ret["F"],
        "unlearn_retain_accuracy": ret["unlearn_retain_accuracy"]*100.0,
        "unlearn_test_accuracy": ret["unlearn_test_accuracy"]*100.0,
        "unlearn_forget_accuracy": ret["unlearn_forget_accuracy"]*100,
    }
    col_names = list(summary.keys())
    table = [list(summary.values())]
    print(tabulate(table, headers=col_names, tablefmt="fancy_grid"))

In [33]:
if test:
    # Scalar outputs of all model samples for the first forget example.
    udata = ret["unlearn_outputs"][:,0]
    rdata = ret["retrain_outputs"][:,0]
    data = np.array([udata, rdata])

    # Shared bins of width 0.1 spanning both distributions.
    lowest, highest = np.min(data), np.max(data)
    bins = np.arange(lowest, highest + 0.1, 0.1)

    for values, label in ((udata, "Unlearned"), (rdata, "Retrained")):
        plt.hist(values, bins=bins, alpha=0.7, label=label)

    plt.legend()
    plt.show()

# Submission

In [34]:
if not test:
    
    if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
        # mock submission
        subprocess.run('touch submission.zip', shell=True)
    else:

        # Note: it's really important to create the unlearned checkpoints outside of the working directory 
        # as otherwise this notebook may fail due to running out of disk space.
        # The below code saves them in /kaggle/tmp to avoid that issue.

        os.makedirs('/kaggle/tmp', exist_ok=True)
        random.seed(42)
        retain_loader, forget_loader, validation_loader = get_dataset(64)
        net = resnet18(weights=None, num_classes=10)
        net.to(DEVICE)

        # Read the original weights from disk once instead of on each of the
        # 512 iterations; `load_state_dict` copies the values into `net`, so
        # the same state dict can be reused to reset the model every time.
        original_state = torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth')
        for i in range(512):
            # Restart every run from the original model.
            net.load_state_dict(original_state)
            net_ = unlearning(net, retain_loader, forget_loader, validation_loader)
            state = net_.state_dict()
            torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

        # Ensure that submission.zip will contain exactly 512 checkpoints 
        # (if this is not the case, an exception will be thrown).
        unlearned_ckpts = os.listdir('/kaggle/tmp')
        if len(unlearned_ckpts) != 512:
            raise RuntimeError('Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')

        subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True)