In [1]:
import os
import requests
import random
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import prune
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# manual random seed is used for dataset partitioning
# to ensure reproducible results across runs
SEED = 42
# RNG is a dedicated generator handed explicitly to the dataset split below;
# Python's `random` module and torch's global RNG are seeded with the same value.
# NOTE(review): NumPy's global RNG is not seeded in this cell.
RNG = torch.Generator().manual_seed(SEED)
random.seed(SEED)
# Last expression of the cell: the returned Generator is what the cell displays.
torch.manual_seed(SEED)

Running on device: CUDA


<torch._C.Generator at 0x1b94b718b70>

In [2]:
import ssl

# Create an unverified SSL context
# WARNING: this swaps the standard library's default HTTPS context factory for
# the unverified one, so every HTTPS connection that relies on that default
# skips certificate validation for the rest of the kernel session.
# NOTE(review): the datasets below are loaded with download=False -- confirm
# whether this override is still needed and drop it if not.
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
# load and pre-process CIFAR10 (download=False below, so the files must already
# exist under `root`)
# Pre-processing: convert to tensor, then normalize each channel with the
# given per-channel (mean, std) tuples.
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_set = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=True, download=False, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=1)

# we split held out data into test and validation set
# (two equal halves; the split is driven by the seeded RNG generator)
held_out = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=False, download=False, transform=normalize
)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=1)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=1)

# load the forget index split from disk; the download below is disabled, so
# the .npy file is expected to be present locally
# NOTE(review): if the download is re-enabled, the URL is built from local_path,
# which carries the "../example notebooks/" prefix -- presumably only the file
# name belongs in the URL; verify before use.
local_path = "../example notebooks/forget_idx.npy"
# if not os.path.exists(local_path):
#     response = requests.get(
#         "https://storage.googleapis.com/unlearning-challenge/" + local_path
#     )
#     open(local_path, "wb").write(response.content)
forget_idx = np.load(local_path)

# construct indices of retain from those of the forget set:
# mark the forget positions in a boolean mask over the train set and keep the rest
forget_mask = np.zeros(len(train_set.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

# split train set into a forget and a retain set
forget_set = torch.utils.data.Subset(train_set, forget_idx)
retain_set = torch.utils.data.Subset(train_set, retain_idx)

forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=128, shuffle=False, num_workers=1
)
# Disabled shuffled variant; retain_loader is built in the next cell without
# shuffling.
# retain_loader = torch.utils.data.DataLoader(
#     retain_set, batch_size=128, shuffle=True, num_workers=1, generator=RNG
# )

In [4]:
# Retain-set loader: fixed (unshuffled) order, same batch size and worker
# count as the other loaders in this notebook.
retain_loader = DataLoader(retain_set, batch_size=128, shuffle=False, num_workers=1)

In [5]:
# Fetch the pre-trained ResNet-18 checkpoint once and cache it on disk.
local_path = "../example notebooks/weights/weights_resnet18_cifar10.pth"
if not os.path.exists(local_path):
    # The cache directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/weights_resnet18_cifar10.pth",
        timeout=60,
    )
    # Fail loudly on an HTTP error instead of caching an error page as the
    # checkpoint (which would only surface later as a confusing load failure).
    response.raise_for_status()
    # Context manager guarantees the file is flushed and closed.
    with open(local_path, "wb") as weights_file:
        weights_file.write(response.content)

# The checkpoint comes from the network, so restrict unpickling to tensors and
# plain containers rather than arbitrary Python objects.
weights_pretrained = torch.load(local_path, map_location=DEVICE, weights_only=True)  # ~43 MB

# load model with pre-trained weights
model = resnet18(weights=None, num_classes=10)
model.load_state_dict(weights_pretrained)
model.to(DEVICE)
# Inference mode for BatchNorm; the trailing ';' suppresses the module repr.
model.eval();

In [6]:
def accuracy(net, loader):
    """Return accuracy on a dataset given by the data loader.

    Args:
      net : callable (normally a torch.nn.Module) mapping a batch of inputs
        to per-class scores of shape (batch, num_classes).
      loader : iterable of (inputs, targets) batches.
    Returns:
      float, fraction of samples whose arg-max prediction equals the target.
    Raises:
      ZeroDivisionError: if the loader yields no samples.

    Forward passes run under ``torch.no_grad()`` so no autograd graph is
    built. The train/eval mode of ``net`` is left untouched; call
    ``net.eval()`` beforehand when BatchNorm/Dropout should be in inference
    mode.
    """
    # Send batches to wherever the network lives; fall back to the
    # notebook-wide DEVICE for callables without parameters.
    first_param = next(net.parameters(), None) if hasattr(net, "parameters") else None
    device = first_param.device if first_param is not None else DEVICE

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return correct / total

In [7]:
def unstructure_prune(model, pruning_amount=0.2, global_pruning=False, random_init=False):
    """Apply L1 (magnitude) unstructured pruning to all Conv2d/Linear weights, in place.

    Args:
      model : torch.nn.Module whose Conv2d and Linear sub-modules are pruned.
      pruning_amount : amount forwarded to torch.nn.utils.prune (fraction of
        weights when a float, absolute count when an int).
      global_pruning : if True, rank the weights of all selected layers
        together; otherwise prune every layer separately by the same amount.
      random_init : if True, the weights that survive pruning are re-drawn
        from a standard normal distribution; pruned weights stay at zero.
    Returns:
      None. The pruning is made permanent: after the call each layer has a
      plain ``weight`` parameter again and no mask buffers.
    """
    parameters_to_prune = [
        (module, 'weight')
        for _, module in model.named_modules()
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
    ]

    if global_pruning:
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=pruning_amount,
        )
    else:
        for module, param_name in parameters_to_prune:
            prune.l1_unstructured(module, name=param_name, amount=pruning_amount)

    for module, param_name in parameters_to_prune:
        if random_init:
            # prune.remove() rebuilds the final weight as `<name>_orig * <name>_mask`,
            # so the fresh values must be written into `<name>_orig`; assigning to
            # `<name>` itself would be thrown away by remove().
            original = getattr(module, f"{param_name}_orig")
            with torch.no_grad():
                original.copy_(torch.randn_like(original))
        # Make the pruning permanent: folds the mask into the weight and
        # drops the re-parametrization.
        prune.remove(module, param_name)

In [8]:
def plot_teacher_student_outputs(teacher_logits, student_logits):
    """Plot teacher and student softmax probabilities on a log-scaled y axis.

    Both arguments are logit tensors; the softmax is taken over dim 0. The
    teacher is drawn with black circles and the student with red circles.
    """
    teacher_probs, student_probs = (
        logits.softmax(dim=0).cpu().numpy()
        for logits in (teacher_logits, student_logits)
    )
    styled_series = (
        (teacher_probs, 'ko', 'teacher'),
        (student_probs, 'ro', 'student'),
    )
    for probs, marker, legend_label in styled_series:
        plt.plot(probs, marker, label=legend_label)
    plt.legend()
    plt.yscale('log')
    plt.show()

In [9]:
def compute_losses(net, loader):
    """Auxiliary function to compute per-sample losses.

    Args:
      net : callable (normally a torch.nn.Module) mapping a batch of inputs
        to per-class logits.
      loader : iterable of (inputs, targets) batches.
    Returns:
      1-D numpy array with one cross-entropy loss per sample, in loader
      order; an empty array when the loader yields nothing.

    Forward passes run under ``torch.no_grad()`` so no autograd graph is
    built. The train/eval mode of ``net`` is left untouched.
    """
    # Send batches to wherever the network lives; fall back to the
    # notebook-wide DEVICE for callables without parameters.
    first_param = next(net.parameters(), None) if hasattr(net, "parameters") else None
    device = first_param.device if first_param is not None else DEVICE

    criterion = nn.CrossEntropyLoss(reduction="none")
    batch_losses = []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = net(inputs)
            # force=True copies to host memory regardless of device.
            batch_losses.append(criterion(logits, targets).numpy(force=True))

    if not batch_losses:
        return np.array([])
    # One concatenate instead of appending sample by sample.
    return np.concatenate(batch_losses)

In [10]:
def simple_mia(sample_loss, members, n_splits=10, random_state=0):
    """Computes cross-validation score of a membership inference attack.

    Args:
      sample_loss : array_like of shape (n, 1).
        objective function evaluated on n samples, one feature column.
      members : array_like of shape (n,),
        whether a sample was used for training (1) or not (0).
      n_splits: int
        number of splits to use in the cross-validation.
      random_state: int
        seed for the shuffled cross-validation splits.
    Returns:
      scores : array_like of size (n_splits,)
    Raises:
      ValueError: if `members` does not contain exactly the labels 0 and 1.
    """

    unique_members = np.unique(members)
    # array_equal tolerates a differing number of distinct labels; an
    # element-wise `==` between arrays of different length does not compare
    # cleanly and can raise an unrelated broadcasting error instead.
    if not np.array_equal(unique_members, [0, 1]):
        raise ValueError("members should only have 0 and 1s")

    attack_model = linear_model.LogisticRegression()
    cv = model_selection.StratifiedShuffleSplit(
        n_splits=n_splits, random_state=random_state
    )
    return model_selection.cross_val_score(
        attack_model, sample_loss, members, cv=cv, scoring="accuracy"
    )

In [11]:
def calc_mia_acc(forget_loss, test_loss):
    """Mean cross-validated accuracy of the loss-based membership attack.

    Test losses are labelled 0 and forget losses 1; both inputs must have the
    same length so that the attack is trained on a balanced population.
    """
    n_test, n_forget = len(test_loss), len(forget_loss)
    # the MIA needs a balanced dataset
    assert n_test == n_forget

    # single feature column: test losses first, then forget losses
    features = np.concatenate((test_loss, forget_loss)).reshape((-1, 1))
    membership = [0] * n_test + [1] * n_forget

    return simple_mia(features, membership).mean()

In [12]:
def get_all_metrics(test_losses, student_model, retain_loader, forget_loader, val_loader, test_loader):
    """Print accuracy on every split plus the MIA accuracy for `student_model`.

    Returns the tuple (forget losses of `student_model`, the `test_losses`
    that were passed in, MIA accuracy).

    NOTE(review): `test_losses` is not recomputed for `student_model` here;
    the caller decides which model produced them.
    """
    splits = (
        ("Retain set accuracy", retain_loader),
        ("Forget set accuracy", forget_loader),
        ("Val set accuracy", val_loader),
        ("Test set accuracy", test_loader),
    )
    for title, split_loader in splits:
        print(f"{title}: {100.0 * accuracy(student_model, split_loader):0.1f}%")

    forget_losses = compute_losses(student_model, forget_loader)
    mia_accuracy = calc_mia_acc(forget_losses, test_losses)

    print(
        f"The MIA has an accuracy of {mia_accuracy.mean():.3f} on forgotten vs unseen images"
    )

    return forget_losses, test_losses, mia_accuracy

In [13]:
# Accuracy of the pre-trained model on every split, one line per split.
for title, split_loader in (
    ("Retain set accuracy", retain_loader),
    ("Forget set accuracy", forget_loader),
    ("Val set accuracy", val_loader),
    ("Test set accuracy", test_loader),
):
    print(f"{title}: {100.0 * accuracy(model, split_loader):0.1f}%")

Retain set accuracy: 99.5%
Forget set accuracy: 99.3%
Val set accuracy: 88.9%
Test set accuracy: 88.3%


In [14]:
# Per-sample losses of the pre-trained model on the test split; handed to
# get_all_metrics in the sweep below as the "unseen" population of the
# membership-inference attack.
test_losses = compute_losses(model, test_loader)

In [15]:
def average_gradient_from_loader(model, optimizer, loader, num_batches):
    """Average the cross-entropy gradient of ``model.fc.weight`` over several batches.

    Args:
      model : network exposing its final linear layer as ``model.fc``.
      optimizer : only used to reset gradients before every batch.
      loader : iterable of (data, target) batches.
      num_batches : maximum number of batches to consume from ``loader``.
    Returns:
      tensor shaped like ``model.fc.weight``: the mean of the per-batch
      gradients over the batches that were processed.
    Raises:
      ValueError: if no batch was processed (empty loader or num_batches < 1).
    """
    last_linear_layer = model.fc
    # Batches go to the device that holds the layer being differentiated.
    device = last_linear_layer.weight.device
    criterion = nn.CrossEntropyLoss()

    grad_sum = None
    count = 0
    for i, (data, target) in enumerate(loader):
        if i >= num_batches:
            break
        data, target = data.to(device), target.to(device)
        # Reset so that .grad holds this batch's gradient only.
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        batch_grad = last_linear_layer.weight.grad
        if grad_sum is None:
            # clone: .grad may be reused or overwritten by the next backward pass
            grad_sum = batch_grad.clone()
        else:
            grad_sum += batch_grad
        count += 1

    if count == 0:
        raise ValueError("no batches were processed: loader is empty or num_batches < 1")
    # Return only after the loop so every consumed batch contributes.
    return grad_sum / count

In [19]:
# Only re-train last layer
# Leave only the classifier weight matrix trainable; every other parameter
# (fc.bias included) is frozen.
for name, param in model.named_parameters():
    param.requires_grad = (name == 'fc.weight')

In [None]:
# Sweep the percentage of classifier input features to re-initialize
# (20, 22, ..., 28) and report the metrics for each setting.
for x in range(20, 30, 2):

    print('---'*5)
    print(x/100)

    # Start every setting from a fresh copy of the pre-trained network.
    model = resnet18(weights=None, num_classes=10)
    model.load_state_dict(weights_pretrained)
    model.to(DEVICE)
    # A freshly built module is in training mode; switch to eval so BatchNorm
    # uses its stored running statistics and does not update them while the
    # gradients and metrics are computed.
    model.eval()

    # Only the fc.weight gradient is consumed below, so freeze everything else
    # on this new model object (freezing an earlier `model` does not carry over).
    for name, param in model.named_parameters():
        param.requires_grad = (name == 'fc.weight')

    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)

    # Number of batches (not samples) in the forget loader, so both averages
    # are taken over the same number of batches.
    num_batches = len(forget_loader)

    grad1 = average_gradient_from_loader(model, optimizer, retain_loader, num_batches)

    optimizer.zero_grad()

    grad2 = average_gradient_from_loader(model, optimizer, forget_loader, num_batches)

    # grad_diff has the shape of fc.weight: (num_classes, num_features).
    grad_diff = torch.abs(grad1 - grad2)

    # Collapse the class dimension to get one score per input feature.
    # Sorting the 2-D tensor directly ranks each row on its own and yields
    # full permutations of the columns, which would select every feature.
    feature_score = grad_diff.sum(dim=0)

    top_x_percent = int(x/100 * feature_score.numel())
    _, indices = torch.topk(feature_score, top_x_percent)

    # Re-draw the selected feature columns for every class at once.
    with torch.no_grad():
        model.fc.weight[:, indices] = torch.randn_like(model.fc.weight[:, indices])

    get_all_metrics(test_losses, model, retain_loader, forget_loader, val_loader, test_loader)