In [1]:
import os
import subprocess
import requests
import tqdm
import random

from tabulate import tabulate

import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torchvision.utils import make_grid
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Subset
import time

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'



In [2]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recomment using the P100 GPU rather than T4 as it's faster and will increase the chances of passing the time cut-off threshold.

if DEVICE != 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

# Custom Configuration

## Seed Config

In [3]:
torch.manual_seed(3047)

G_retain = torch.Generator()
G_retain.manual_seed(20)

G_forget = torch.Generator()
G_forget.manual_seed(30)

G_validate = torch.Generator()
G_validate.manual_seed(40)

G_test = torch.Generator()
G_test.manual_seed(40)


G_unexp_retain = torch.Generator()
G_unexp_retain.manual_seed(50)

G_unexp_forget = torch.Generator()
G_unexp_forget.manual_seed(60)

<torch._C.Generator at 0x7e275cd99ad0>

## Testing Version or Submission Version

Here at the time of **testing**, we will keep the **internet on**, set **test=True** and load pretrained cifar10 model to test our unlearning algorithm implementation. In this case, we will utilize some parts from the given starting kit by the competition organizers.<br>
**Link:** https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb <br><br>
At the time of **submission**, **internet off** and **test=False**

In [4]:
test = False
# test = True

# Load Dataset

In [5]:
# Helper functions for loading the CIFAR10 dataset.

if test:

    # The directory for a dataset and a pretrained model
    test_dir = './test'
    test_model_path = os.path.join(test_dir, "weights_resnet18_cifar10.pth")
    os.makedirs(test_dir, exist_ok=True)

    class PublicDataset(Dataset):

        def __init__(self, ds: Dataset):
            self._ds = ds

        def __len__(self):
            return len(self._ds)

        def __getitem__(self, index):
            item = self._ds[index]
            result = {
                'image': item[0],
                'image_id': index,
                'age_group': item[1],
                'age': item[1],
                'person_id': index,
            }
            return result

    def get_dataset(batch_size_r, batch_size_f,batch_size_vt, thinning_param: int=1, root=test_dir) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader]:

        # utils
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]) # ??

        # create dataset
        train_set = torchvision.datasets.CIFAR10(root=test_dir, train=True, download=True, transform=normalize)
        train_ds = PublicDataset(train_set)

        # download the forget and retain index split
        local_path = "forget_idx.npy"
        if not os.path.exists(local_path):
            response = requests.get(
                "https://storage.googleapis.com/unlearning-challenge/" + local_path
            )
            open(local_path, "wb").write(response.content)

        forget_idx = np.load(local_path)

        # construct indices of retain from those of the forget set
        forget_mask = np.zeros(len(train_set.targets), dtype=bool)
        forget_mask[forget_idx] = True
        retain_idx = np.arange(forget_mask.size)[~forget_mask]

        # split train set into a forget and a retain set
        forget_ds = Subset(train_ds, forget_idx)
        retain_ds = Subset(train_ds, retain_idx)

        full_val_set = torchvision.datasets.CIFAR10(root=test_dir, train=False, download=True, transform=normalize)

        test_set, val_set = torch.utils.data.random_split(full_val_set, [0.5, 0.5])

        val_ds = PublicDataset(val_set)
        test_ds = PublicDataset(test_set)

        retain_loader = DataLoader(retain_ds, batch_size=batch_size_r, shuffle=True, generator=G_retain)
        forget_loader = DataLoader(forget_ds, batch_size=batch_size_f, shuffle=True, generator=G_forget)
        validation_loader = DataLoader(val_ds, batch_size=batch_size_vt, shuffle=True, generator=G_validate)
        test_loader = DataLoader(test_ds, batch_size=batch_size_vt, shuffle=True, generator=G_test)

        return retain_loader, forget_loader, validation_loader, test_loader

# Evaluation Using Loss & Accuracy

In [6]:
def calculate_acc_loss(net, dataloader, criterion):
    net.eval()
    total_samp = 0
    total_acc = 0
    total_loss = 0.0

    for sample in dataloader:
        images, labels = sample['image'].to(DEVICE), sample['age_group'].to(DEVICE)
        _pred = net(images).to(DEVICE)
        total_samp += len(labels)
        loss = criterion(_pred, labels)
        total_loss += loss.item()
        total_acc += (_pred.max(1)[1] == labels).float().sum().item()

    mean_loss = total_loss / len(dataloader)
    mean_acc = total_acc / total_samp * 100.0

    return mean_loss, mean_acc

### Loss Acc Evaluation & Test Submission

In [7]:
def printAccLoss(net):
  criterion = nn.CrossEntropyLoss()
  l, a = calculate_acc_loss(net, retain_loader, criterion)
  print(f"Retain set accuracy: {a:0.2f}%")
  print(f"Retain set loss: {l:0.2f}")
  l, a = calculate_acc_loss(net, forget_loader, criterion)
  print(f"Forget set accuracy: {a:0.2f}%")
  print(f"Forget set loss: {l:0.2f}")
  l, a = calculate_acc_loss(net, validation_loader, criterion)
  print(f"Validation set accuracy: {a:0.2f}%")
  print(f"Validation set loss: {l:0.2f}")
  l, a = calculate_acc_loss(net, test_loader, criterion)
  print(f"Test set accuracy: {a:0.2f}%")
  print(f"Test set loss: {l:0.2f}")

In [8]:
def sampleLoss(net, sample, criterion):
  inputs = sample["image"].to(DEVICE)
  labels = sample['age_group'].to(DEVICE)
  outputs = net(inputs).to(DEVICE)
  loss = criterion(outputs,  labels)
  return loss

def grads(net, sample, criterion):
  loss = sampleLoss(net, sample, criterion)
  loss.backward()
  g = {}
  with torch.no_grad():
    for name, param in net.named_parameters():
      if param.grad != None:
        # g.append(param.grad.clone())
        g[name] = param.grad.clone()
  return g

def gradsToDevice(net, sample, criterion):
  loss = sampleLoss(net, sample, criterion)
  loss.backward()
  g = {}
  with torch.no_grad():
    for name, param in net.named_parameters():
      if param.grad != None:
        # g.append(param.grad.clone())
        g[name] = param.grad.clone().to(DEVICE)
  return g


def gradsFlat(net, sample, criterion):
  loss = sampleLoss(net, sample, criterion)
  loss.backward()
  g = []
  with torch.no_grad():
    for param in net.parameters():
      if param.grad != None:
        g.append(torch.flatten(param.grad))
    return torch.cat(g)


def unitDotProduct(a, b):
  a = torch.flatten(a)
  b = torch.flatten(b)

  return torch.nn.functional.cosine_similarity(a, b,dim = 0)

def findComponent(direction1, vector1):
    direction = torch.flatten(direction1)
    vector = torch.flatten(vector1)
    dotProd = torch.dot(direction , vector)
    dirMagSquared = torch.dot(direction, direction)
    return ((dotProd / dirMagSquared) * direction).view_as(direction1)

def compareGradFlattenedAll(net, fLoader, rLoader, epochs, lr, unit):
  net.train()
  criterion = nn.CrossEntropyLoss()
  for _ in range(epochs):
    print(f"epoch : {_ + 1}")
    for fsample in fLoader:
      # fsample = next(iter(fLoader))
      rsample = next(iter(rLoader))
#       rgrads = grads(net, rsample, criterion)
#       fgrads = grads(net, fsample, criterion)
      r = gradsFlat(net, rsample, criterion)
      f = gradsFlat(net, fsample, criterion)

      with torch.no_grad():
        componentInDirOfrgrad = findComponent(r, f)
        perp = f - componentInDirOfrgrad
        for param in net.parameters():
          i = 0
          if param.grad != None:
            sz = param.numel()
#             print(sz)
            cgradf = perp[i : i + sz]
            i += sz
            cgrad = cgradf.view_as(param)
            norm = torch.norm(cgradf)
            if norm == 0 :
              continue
            cgrad /= norm
            param += lr * cgrad

  net.eval()


def compareGradFlattenedAll2(net, fLoader, rLoader, epochs, lr, unit, beta = .9):
  net.train()
  criterion = nn.CrossEntropyLoss()
  perpC = 0
    
  for _ in range(epochs):
    print(f"epoch : {_ + 1}")
    for fsample in fLoader:
      # fsample = next(iter(fLoader))
      rsample = next(iter(rLoader))
      rgrads = grads(net, rsample, criterion)
      fgrads = grads(net, fsample, criterion)

      with torch.no_grad():
        rgf = []
        fgf = []
        indices = {}
        lens = {}
        lastIndex = 0

        for key in rgrads.keys():
          r = rgrads[key]
          f = fgrads[key]
          rf = torch.flatten(r)
          ff = torch.flatten(f)

          indices[key] = lastIndex
          lens[key] = rf.shape[0]
          lastIndex += rf.shape[0]

          rgf.append(rf)
          fgf.append(ff)

        r = torch.cat(rgf)
        f = torch.cat(fgf)
        componentInDirOfrgrad = findComponent(r, f)
        perp = f - componentInDirOfrgrad
        if unit:
            norm = torch.norm(perp)
            if norm == 0:
                continue
            perp /= norm
        
        perpC = beta * perpC + (1 - beta) * perp
        perp = perpC
        # print(indices)

        for name, param in net.named_parameters():
          if param.grad != None:
            cgradf = perp[indices[name] : indices[name] + lens[name]]
            cgrad = cgradf.view_as(param)
#             if unit:
#               norm = torch.norm(cgradf)
#               if norm == 0 :
#                 continue
#               cgrad /= norm
            # print(f"name : {name}, grad : {cgrad}")
            # print(f"torch norm : {torch.norm(cgradf)}")
            param += lr * cgrad

  net.eval()

def getModel():
  test_dir = './test'
  local_path = "weights_resnet18_cifar10.pth"
  test_model_path = os.path.join(test_dir, local_path)
  if not os.path.exists(test_model_path):
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/" + local_path)
    open(test_model_path, "wb").write(response.content)
  net = resnet18(weights=None, num_classes=10)
  net.load_state_dict(torch.load(test_model_path, map_location = DEVICE))
  return net

def freezeBN(net):
  for name, module in net.named_modules():
    if isinstance(module, nn.BatchNorm2d):
      module.eval()
      for param in module.parameters():
        param.requires_grad = False

def unlearning(net, forget_loader, retain_loader, vLoader):
#     freezeBN(net)
    compareGradFlattenedAll(net, forget_loader, retain_loader, 12, .002, True)
    compareGradFlattenedAll(net, retain_loader, forget_loader, 3, -.002, True)
    compareGradFlattenedAll(net, retain_loader, forget_loader, 1, -.001, True)
    



In [9]:
if test:
    retain_loader, forget_loader, validation_loader, test_loader = get_dataset(64, 64, 64)
#     r,f, vt
    forget = get_dataset(64, 128, 64)
    retain = get_dataset(256, 64, 64)

In [10]:
if test:
    addNoiseList = ['layer4.0.conv1','layer4.0.conv2', 'layer4.0.downsample','layer4.0.downsample.0','layer4.0.downsample.1', 'layer4.1.conv1', 'layer4.1.conv2','fc']
    freezeList = ['conv1','bn1','layer1','layer1.0','layer1.0.conv1','layer1.0.bn1','layer1.0.conv2','layer1.0.bn2','layer1.1','layer1.1.conv1','layer1.1.bn1','layer1.1.conv2','layer1.1.bn2','layer2','layer2.0','layer2.0.conv1','layer2.0.bn1','layer2.0.conv2','layer2.0.bn2','layer2.0.downsample','layer2.0.downsample.0','layer2.0.downsample.1','layer2.1','layer2.1.conv1','layer2.1.bn1','layer2.1.conv2','layer2.1.bn2','layer3','layer3.0','layer3.0.conv1','layer3.0.bn1','layer3.0.conv2','layer3.0.bn2','layer3.0.downsample','layer3.0.downsample.0','layer3.0.downsample.1','layer3.1','layer3.1.conv1','layer3.1.bn1','layer3.1.conv2','layer3.1.bn2','layer4.0.bn1','layer4.0.bn2','layer4.1.bn1','layer4.1.bn2']

    net = getModel()
    net = net.to(DEVICE)
    freezeBN(net)
    # freeze(net, freezeList)
    # addNoise(net, addNoiseList)

256 -> 257
128 -> 257
64 -> 237
mix -> 150

In [11]:

if test:
    start_time = time.time()
#     compareGradFlattenedAll(net, forget_loader, retain_loader, 12, .002, True)
#     compareGradFlattenedAll2(net, forget[1], forget[0], 6, .05, True)
#     compareGradFlattenedAll2(net, forget[1], forget[0], 3, .01, True)
#     compareGradFlattenedAll2(net, retain[0], retain[1], 4, -.05, True)
    
    compareGradFlattenedAll2(net, forget[1], forget[0], 5, .02, True, .9)
    compareGradFlattenedAll2(net, retain[0], retain[1], 3, -.02, True,  .9)
    
    
#     compareGradFlattenedAll(net, retain_loader, forget_loader, 3, -.002, True)
#     compareGradFlattenedAll2(net, retain[0], retain[1], 3, -.002, True)
    
    
#     compareGradFlattenedAll2(net, forget[1], forget[0], 2, .05, True)
#     compareGradFlattenedAll2(net, retain[0], retain[1], 1, -.01, True)
    
    
#     compareGradFlattenedAll2(net, forget[1], forget[0], 2, .05, True)
#     compareGradFlattenedAll2(net, retain[0], retain[1], 1, -.01, True)

    
#     compareGradFlattenedAll2(net, forget[1], forget[0], 2, .05, True)
#     compareGradFlattenedAll2(net, retain[0], retain[1], 1, -.01, True)

#     unlearning(net, forget_loader, retain_loader, validation_loader)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print("Elapsed time:", elapsed_time, "seconds")

In [12]:
if test:
    printAccLoss(net)

In [13]:
if test:
    printAccLoss(net)

In [14]:
if test:
    compareGradFlattenedAll(net, forget_loader, retain_loader, 12, .002, True)
    printAccLoss(net)


In [15]:
if test:
#     compareGradFlattenedAll(net, retain_loader, forget_loader , 3, -.002, True)
    compareGradFlattenedAll(net, retain_loader, forget_loader , 1, -.001, True)

    printAccLoss(net)


## Comparison With Trained Model Exclusively on Retain Set

------------ Base Model ---------------

Retain set accuracy: 99.48%

Retain set loss: 1.97%

Forget set accuracy: 99.32%

Forget set loss: 2.17%

Validation set accuracy: 88.28%

Validation set loss: 44.61%

Test set accuracy: 89.00%

Test set loss: 44.96%


------------ Unlearned Model ---------------

Retain set accuracy: 99.48%

Retain set loss: 1.97%

Forget set accuracy: 99.32%

Forget set loss: 2.16%

Validation set accuracy: 88.36%

Validation set loss: 46.37%

Test set accuracy: 88.92%

Test set loss: 44.01%

In [16]:
if test:

    # download weights of a model trained exclusively on the retain set
    local_path = "retrain_weights_resnet18_cifar10.pth"
    if not os.path.exists(local_path):
        response = requests.get(
            "https://storage.googleapis.com/unlearning-challenge/" + local_path
        )
        open(local_path, "wb").write(response.content)

    weights_pretrained = torch.load(local_path, map_location=DEVICE)

    # load model with pre-trained weights
    rt_model = resnet18(weights=None, num_classes=10)
    rt_model.load_state_dict(weights_pretrained)
    rt_model.to(DEVICE)
    rt_model.eval()

    # test criterion
    criterion = nn.CrossEntropyLoss()

    printAccLoss(rt_model)

# Evaluation using MIA

**Reference:** https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb<br>
We will evaluate the trained models using Simple Membership Inference Attacks(MIA). This is **not used** as evaluation metric for the competition.

This MIA consists of a **logistic regression model** that predicts whether the model was trained on a particular sample from that sample's loss. To get an idea on the difficulty of this problem, we first plot below a histogram of the losses of the pre-trained models

## Visualize Pre-trained Model

In [17]:
if test:

    def compute_losses(model_, loader):
        """Auxiliary function to compute per-sample losses"""

        criterion = nn.CrossEntropyLoss(reduction="none")
        all_losses = []

        for sample in loader:
            images, labels = sample['image'].to(DEVICE), sample['age_group'].to(DEVICE)
            logits = model_(images)

            losses = criterion(logits, labels).numpy(force=True)
            for l in losses:
                all_losses.append(l)

        return np.array(all_losses)

In [18]:
if test:

    # load model with pre-trained weights
    model = resnet18(weights=None, num_classes=10)
    weights_pretrained = torch.load(test_model_path, map_location=DEVICE)
    model.load_state_dict(weights_pretrained)
    model.to(DEVICE)
    model.eval()

    retain_losses = compute_losses(model, retain_loader)
    forget_losses = compute_losses(model, forget_loader)
    test_losses = compute_losses(model, test_loader)

    plt.title("Losses on retain, forget and validation set (pre-trained model)")
    plt.hist(retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")
    plt.hist(forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    plt.hist(test_losses, density=True, alpha=0.5, bins=50, label="Test set")
    plt.xlabel("Loss", fontsize=14)
    plt.ylabel("Frequency", fontsize=14)
    plt.xlim((0, np.max(test_losses)))
    plt.yscale("log")
    plt.legend(frameon=False, fontsize=14)
    ax = plt.gca()
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    plt.show()

As per the above plot, the distributions of losses are quite different between the train and validation sets, as expected. In what follows, we will define an MIA that leverages the fact that examples that were trained on have smaller losses compared to examples that weren't.

## MIA Implementation

Now, we will define an MIA that leverages the fact that examples that were trained on have smaller losses compared to examples that weren't. Using this fact, the simple MIA defined below will aim to infer whether the forget set was in fact part of the training set.

This MIA is defined below. It takes as input the per-sample losses of the unlearned model on forget and test examples, and a membership label (0 or 1) indicating which of those two groups each sample comes from. It then returns the cross-validation accuracy of a linear model trained to distinguish between the two classes.

Intuitively, an unlearning algorithm is successful with respect to this simple metric if the attacker isn't able to distinguish the forget set from the test set any better than it would for the ideal unlearning algorithm (retraining from scratch without the retain set); see the last part of this MIA section for additional discussion and for computing that reference point.

In [19]:
if test:

    def simple_mia(sample_loss, members, n_splits=10, random_state=42):
        """Computes cross-validation score of a membership inference attack.

        Args:
          sample_loss : array_like of shape (n,).
            objective function evaluated on n samples.
          members : array_like of shape (n,),
            whether a sample was used for training.
          n_splits: int
            number of splits to use in the cross-validation.
        Returns:
          scores : array_like of size (n_splits,)
        """

        unique_members = np.unique(members)
        if not np.all(unique_members == np.array([0, 1])):
            raise ValueError("members should only have 0 & 1s")

        attack_model = linear_model.LogisticRegression()
        cv = model_selection.StratifiedShuffleSplit(
            n_splits=n_splits, random_state=random_state
        )
        return model_selection.cross_val_score(
            attack_model, sample_loss, members, cv=cv, scoring="accuracy"
        )

### MIA on Original Model

As a reference point, we first compute the accuracy of the MIA on the original model to distinguish between the forget set and the validation set.

In [20]:
if test:

    forget_losses = compute_losses(model, forget_loader)

    # Since we have more forget losses than test losses, sub-sample them, to have a class-balanced dataset.
    np.random.shuffle(forget_losses)
    forget_losses = forget_losses[: len(test_losses)]

    samples_mia = np.concatenate((test_losses, forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(test_losses) + [1] * len(forget_losses)

    mia_scores = simple_mia(samples_mia, labels_mia)

    print(f"The MIA has an accuracy of {mia_scores.mean():.3f} on forgotten vs unseen images")

### MIA on Unlearned Model

We'll now compute the accuracy of the MIA on the unlearned model. We expect the MIA to be less accurate on the unlearned model than on the original model, since the original model has not undergone a procedure to unlearn the forget set.

In [21]:
if test:

    net_forget_losses = compute_losses(net, forget_loader)
    net_retain_losses = compute_losses(net, retain_loader)
    net_test_losses = compute_losses(net, test_loader)

    np.random.shuffle(net_forget_losses)
    net_forget_losses = net_forget_losses[: len(test_losses)]

    net_samples_mia = np.concatenate((net_test_losses, net_forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(net_test_losses) + [1] * len(net_forget_losses)

    net_mia_scores = simple_mia(net_samples_mia, labels_mia)

    print(f"The MIA has an accuracy of {net_mia_scores.mean():.3f} on forgotten vs unseen images")

## Comparison With Original Model

From the score above, the MIA is indeed less accurate on the unlearned model than on the original model, as expected. Finally, we'll plot the histogram of losses of the unlearned model on the train and validation set. From the below figure, we can observe that the distributions of forget and validation losses are more similar under the unlearned model compared to the original model, as expected.

In [22]:
if test:

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    ax1.set_title(f"Pre-trained model.\nAttack accuracy: {mia_scores.mean():0.2f}")
    ax1.hist(test_losses, density=True, alpha=0.5, bins=50, label="Test set")
    ax1.hist(forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    ax1.hist(retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")

    ax2.set_title(f"Unlearned by fine-tuning.\nAttack accuracy: {net_mia_scores.mean():0.2f}")
    ax2.hist(net_test_losses, density=True, alpha=0.5, bins=50, label="Test set")
    ax2.hist(net_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    ax2.hist(net_retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")

    ax1.set_xlabel("Loss")
    ax2.set_xlabel("Loss")
    ax1.set_ylabel("Frequency")
    ax1.set_yscale("log")
    ax2.set_yscale("log")
    ax1.set_xlim((0, np.max(test_losses)))
    ax2.set_xlim((0, np.max(test_losses)))
    for ax in (ax1, ax2):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
    ax1.legend(frameon=False, fontsize=14)
    plt.show()

## Comparison With Trained Model Exclusively on Retain Set

Since our goal is to approximate the model that has been trained only on the retain set, we'll consider that the gold standard is the score achieved by this model. Intuitively, we expect the MIA accuracy to be around 0.5, since for such a model, both the forget and test set are unseen samples from the same distribution. However, a number of factors such as distribution shift or class imbalance can make this number vary.

First, we will compute the MIA score on Re-trained model exclusive

In [23]:
if test:

    rt_test_losses = compute_losses(rt_model, test_loader)
    rt_forget_losses = compute_losses(rt_model, forget_loader)
    rt_retain_losses = compute_losses(rt_model, retain_loader)

    rt_samples_mia = np.concatenate((rt_test_losses, rt_forget_losses)).reshape((-1, 1))
    labels_mia = [0] * len(rt_test_losses) + [1] * len(rt_forget_losses)

    rt_mia_scores = simple_mia(rt_samples_mia, labels_mia)

    print(f"The MIA has an accuracy of {rt_mia_scores.mean():.3f} on forgotten vs unseen images")

Finally, as we've done before, let's compare the histograms of this ideal algorithm (re-trained model) vs the model obtain from

In [24]:
if test:

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6))

    ax1.set_title(f"Original model.\nAttack accuracy: {mia_scores.mean():0.2f}")
    ax1.hist(test_losses, density=True, alpha=0.5, bins=50, label="Test set")
    ax1.hist(forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    ax1.hist(retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")

    ax2.set_title(f"Re-trained model.\nAttack accuracy: {rt_mia_scores.mean():0.2f}")
    ax2.hist(rt_test_losses, density=True, alpha=0.5, bins=50, label="Test set")
    ax2.hist(rt_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    ax2.hist(rt_retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")

    ax3.set_title(f"Unlearned by fine-tuning.\nAttack accuracy: {net_mia_scores.mean():0.2f}")
    ax3.hist(net_test_losses, density=True, alpha=0.5, bins=50, label="Test set")
    ax3.hist(net_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")
    ax3.hist(net_retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")

    ax1.set_xlabel("Loss")
    ax2.set_xlabel("Loss")
    ax3.set_xlabel("Loss")
    ax1.set_ylabel("Frequency")
    ax1.set_yscale("log")
    ax2.set_yscale("log")
    ax3.set_yscale("log")
    ax1.set_xlim((0, np.max(test_losses)))
    ax2.set_xlim((0, np.max(test_losses)))
    ax3.set_xlim((0, np.max(test_losses)))
    for ax in (ax1, ax2, ax3):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
    ax1.legend(frameon=False, fontsize=14)
    plt.show()

# Submission

In [25]:
# Helper functions for loading the hidden dataset.

if not test:

    def load_example(df_row):
        image = torchvision.io.read_image(df_row['image_path'])
        result = {
            'image': image,
            'image_id': df_row['image_id'],
            'age_group': df_row['age_group'],
            'age': df_row['age'],
            'person_id': df_row['person_id']
        }
        return result


    class HiddenDataset(Dataset):
        '''The hidden dataset.'''
        def __init__(self, split='train'):
            super().__init__()
            self.examples = []

            df = pd.read_csv(f'/kaggle/input/neurips-2023-machine-unlearning/{split}.csv')
            df['image_path'] = df['image_id'].apply(
                lambda x: os.path.join('/kaggle/input/neurips-2023-machine-unlearning/', 'images', x.split('-')[0], x.split('-')[1] + '.png'))
            df = df.sort_values(by='image_path')
            df.apply(lambda row: self.examples.append(load_example(row)), axis=1)
            if len(self.examples) == 0:
                raise ValueError('No examples.')

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, idx):
            example = self.examples[idx]
            image = example['image']
            image = image.to(torch.float32)
            example['image'] = image
            return example


    def get_dataset(batch_size_r, batch_size_f, batch_size_v):
        '''Get the dataset.'''
        retain_ds = HiddenDataset(split='retain')
        forget_ds = HiddenDataset(split='forget')
        val_ds = HiddenDataset(split='validation')

        retain_loader = DataLoader(retain_ds, batch_size=batch_size_r, shuffle=True, generator=G_retain)
        forget_loader = DataLoader(forget_ds, batch_size=batch_size_f, shuffle=True, generator=G_forget)
        validation_loader = DataLoader(val_ds, batch_size=batch_size_v, shuffle=True, generator=G_validate)

        return retain_loader, forget_loader, validation_loader

In [26]:
if not test:

    if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
        # mock submission
        subprocess.run('touch submission.zip', shell=True)
    else:

        # Note: it's really important to create the unlearned checkpoints outside of the working directory
        # as otherwise this notebook may fail due to running out of disk space.
        # The below code saves them in /kaggle/tmp to avoid that issue.

        os.makedirs('/kaggle/tmp', exist_ok=True)
#         retain_loader, forget_loader, validation_loader = get_dataset(512)
        forget = get_dataset(64, 128, 64)
        retain = get_dataset(256, 64, 64)
        
        net = resnet18(weights=None, num_classes=10)
        net.to(DEVICE)
        for i in range(512):
            net.load_state_dict(torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth'))
#             unlearning(net, retain_loader, forget_loader, validation_loader)
            freezeBN(net)
#             compareGradFlattenedAll2(net, forget[1], forget[0], 5, .0005, True)
#             compareGradFlattenedAll2(net, retain[0], retain[1], 3, -.0005, True)

#             compareGradFlattenedAll2(net, forget[1], forget[0], 3, .01, True)
#             compareGradFlattenedAll2(net, forget[1], forget[0], 4, .005, True)
#             compareGradFlattenedAll2(net, retain[0], retain[1],4, -.02, True)

            compareGradFlattenedAll2(net, forget[1], forget[0], 5, .02, True, .9)
            compareGradFlattenedAll2(net, retain[0], retain[1], 3, -.02, True,  .9)
            
#             compareGradFlattenedAll2(net, forget[1], forget[0], 4, .005, True)
#             compareGradFlattenedAll2(net, retain[0], retain[1], 3, -.005, True)

            state = net.state_dict()
            torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

        # Ensure that submission.zip will contain exactly 512 checkpoints
        # (if this is not the case, an exception will be thrown).
        unlearned_ckpts = os.listdir('/kaggle/tmp')
        if len(unlearned_ckpts) != 512:
            raise RuntimeError('Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')

        subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True)