In [None]:
import os
import requests
import random
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection
import copy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import prune
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# manual random seed is used for dataset partitioning
# to ensure reproducible results across runs
SEED = 42
# Dedicated generator, passed explicitly to random_split for the test/val split.
RNG = torch.Generator().manual_seed(SEED)
# Seed every global RNG the notebook can touch: Python's `random`, NumPy's
# legacy global RNG (imported above; seeded in case any helper draws from it),
# and torch (CPU + all CUDA devices; also the source of DataLoader shuffling
# when no generator is passed).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

import sys
sys.path.append('../')
from utils import *

In [None]:
import ssl

# Replacing ssl._create_default_https_context (a private CPython hook used by
# urllib / http.client) with the unverified factory turns off TLS certificate
# checks for the whole process and exposes downloads to tampering. The CIFAR-10
# datasets in this notebook are opened with download=False, so keep verification
# ON by default and only disable it when explicitly requested, e.g. on a
# machine with a broken CA store:  ALLOW_INSECURE_SSL=1 jupyter notebook
if os.environ.get("ALLOW_INSECURE_SSL") == "1":
    ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Mini-batch size shared by the DataLoaders built below.
batch_size: int = 32

In [None]:
# download and pre-process CIFAR10
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_set = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=True, download=False, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)

# we split held out data into test and validation set
held_out = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=False, download=False, transform=normalize
)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=1)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=1)

# download the forget and retain index split
local_path = "../example notebooks/forget_idx.npy"
# if not os.path.exists(local_path):
#     response = requests.get(
#         "https://storage.googleapis.com/unlearning-challenge/" + local_path
#     )
#     open(local_path, "wb").write(response.content)
forget_idx = np.load(local_path)

# construct indices of retain from those of the forget set
forget_mask = np.zeros(len(train_set.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

# split train set into a forget and a retain set
forget_set = torch.utils.data.Subset(train_set, forget_idx)
retain_set = torch.utils.data.Subset(train_set, retain_idx)

forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=batch_size, shuffle=True, num_workers=1
)
# retain_loader = torch.utils.data.DataLoader(
#     retain_set, batch_size=128, shuffle=True, num_workers=1, generator=RNG
# )

In [None]:
# Shuffled mini-batches over the retain subset (consumed by retrain_step).
retain_loader = DataLoader(
    dataset=retain_set,
    batch_size=batch_size,
    num_workers=1,
    shuffle=True,
)

Baseline accuracies of the pre-trained model (before unlearning):

- Retain set accuracy: 99.5%
- Forget set accuracy: 99.3%
- Val set accuracy: 88.9%
- Test set accuracy: 88.3%

In [None]:
# Fetch the pre-trained ResNet-18 / CIFAR-10 checkpoint once (~43 MB) and cache it.
local_path = "../example notebooks/weights/weights_resnet18_cifar10.pth"
if not os.path.exists(local_path):
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/weights_resnet18_cifar10.pth"
    )
    # Fail loudly on an HTTP error instead of caching an error page as a .pth file.
    response.raise_for_status()
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    with open(local_path, "wb") as fh:  # `with` guarantees the file is closed
        fh.write(response.content)

# The file comes from the network; plain torch.load would unpickle arbitrary
# objects. A state_dict only needs tensors, so restrict loading to weights.
weights_pretrained = torch.load(local_path, map_location=DEVICE, weights_only=True)

# load net with pre-trained weights
net = resnet18(weights=None, num_classes=10)
net.load_state_dict(weights_pretrained)
net.to(DEVICE)
net.eval();

In [None]:
# Per-sample losses of the original (pre-unlearning) model on the validation
# split: the "unseen" reference distribution for the MIA checks further down.
# Force eval mode / no_grad so re-running this cell after training cannot pick
# up a network left in train mode.
net.eval()
with torch.no_grad():
    val_losses = compute_losses(net, val_loader)

In [None]:
# Snapshot the starting weights. forget_step rolls back to this file when it
# overshoots, and retrain_step overwrites it after every retain pass.
# Create the directory first so a fresh checkout does not hit FileNotFoundError.
os.makedirs("./checkpoints", exist_ok=True)
torch.save({
    'net': net.state_dict(),
}, './checkpoints/temp_checkpoint.pth')

In [None]:
# Base learning rate shared by the forget (AdamW) and retain (SGD) passes.
LR = 0.0001

In [None]:
def forget_step(net, forget_loader, starting_forget_acc, counter_bool=0):
    """Lower forget-set accuracy by gradient ascent on forget batches.

    Takes AdamW steps that *maximise* cross-entropy on successive batches of
    ``forget_loader``, re-measuring forget-set accuracy after every step:

    * stops once accuracy is at or below 90% of ``starting_forget_acc``;
    * if accuracy overshoots below 80% of it, halves the learning rate,
      reloads ``./checkpoints/temp_checkpoint.pth`` and continues with a fresh
      optimizer (the batch iterator is *not* restarted);
    * stops after at most 10 batches, or when the loader is exhausted.

    Args:
        net: model, updated in place.
        forget_loader: iterable of ``(inputs, targets)`` batches; also passed to
            ``accuracy`` for the full forget-set evaluation after each step.
        starting_forget_acc: forget-set accuracy measured before the call
            (presumably a fraction, judging by the ``100.0 *`` formatting).
        counter_bool: number of earlier calls that hit the 10-batch cap; the
            starting learning rate is ``LR * 1.5 ** counter_bool``.

    Returns:
        ``(net, counter)`` where ``counter`` is the number of forget batches
        consumed; ``counter >= 10`` means the batch cap was reached.
    """
    forget_acc = starting_forget_acc
    print(f'Starting with {100.0 * forget_acc:0.2f}% forget accuracy')
    forget_acc_threshold_1 = forget_acc * 0.90  # success: stop here
    forget_acc_threshold_2 = forget_acc * 0.80  # overshoot: roll back
    print(f'Thresholds: {100.0 * forget_acc_threshold_1:0.2f}%, {100.0 * forget_acc_threshold_2:0.2f}%')

    # Each earlier call that exhausted its batch budget grows the LR by 1.5x.
    forget_lr = LR * 1.5 ** counter_bool if counter_bool > 0 else LR
    criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
    forget_optimizer = optim.AdamW(net.parameters(), lr=forget_lr)

    iter_forget = iter(forget_loader)
    current_batch = 0
    counter = 0

    while forget_acc > forget_acc_threshold_1:
        try:
            inputs, targets = next(iter_forget)
        except StopIteration:
            print('depleted')
            break
        counter += 1
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        # Re-enter train mode on EVERY step: the accuracy check below switches
        # the network to eval mode, so a single net.train() before the loop
        # would leave every batch after the first updated in eval mode.
        net.train()
        forget_optimizer.zero_grad()
        logits = net(inputs)
        # Gradient ascent: minimising the negated loss maximises the CE loss.
        loss = -criterion(logits, targets)
        loss.backward()
        forget_optimizer.step()

        current_batch += 1
        print(current_batch)

        with torch.no_grad():
            net.eval()
            forget_acc = accuracy(net, forget_loader)
        print(f"Forget set accuracy: {100.0 * forget_acc:0.2f}%")
        print('--'*10)

        if forget_acc < forget_acc_threshold_2:
            # Overshot: restore the last checkpoint and retry with half the LR.
            forget_lr = forget_lr / 2
            current_batch = 0
            forget_acc = starting_forget_acc
            print('Restoring')
            checkpoint = torch.load('./checkpoints/temp_checkpoint.pth', map_location=DEVICE)
            net.load_state_dict(checkpoint['net'])
            forget_optimizer = optim.AdamW(net.parameters(), lr=forget_lr)

        if counter >= 10:
            break

    return net, counter

In [None]:
def retrain_step(net, retain_loader):
    """Run one SGD pass over the retain set to repair the model after forgetting.

    The learning rate is warmed up through ``adjust_learning_rate`` over the
    first 40% of the pass's *batches*. After that the helper is no longer
    called, so the LR stays at the last value it set. Uses label-smoothed
    (0.4) cross-entropy. When the pass finishes, the weights are saved to
    ``./checkpoints/temp_checkpoint.pth`` so later forget steps can roll back
    to this state.

    Args:
        net: model fine-tuned in place; left in train mode on return.
        retain_loader: DataLoader yielding ``(inputs, targets)`` batches.

    Returns:
        The same ``net`` object.
    """
    initial_retain_lr = LR
    # Built once: the loss module does not depend on the batch.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.4)
    optimizer_retain = optim.SGD(net.parameters(), lr=initial_retain_lr, momentum=0.9, weight_decay=5e-4)

    # len(loader) is the number of batches. len(loader.dataset) is the number of
    # samples, about batch_size times larger, so the warm-up would never finish
    # within a single pass.
    warmup_batches = math.ceil(0.4 * len(retain_loader))

    net.train()

    for batch_idx, (inputs, targets) in enumerate(retain_loader, start=1):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        # Warm-up for the first 'warmup_batches' batches
        if batch_idx <= warmup_batches:
            adjust_learning_rate(optimizer_retain, batch_idx, warmup_batches, initial_retain_lr)

        optimizer_retain.zero_grad()
        logits = net(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer_retain.step()

    os.makedirs('./checkpoints', exist_ok=True)
    torch.save({
        'net': net.state_dict(),
    }, './checkpoints/temp_checkpoint.pth')

    return net

In [None]:
# Alternate forget (gradient-ascent) and retain (fine-tuning) passes for up to
# `epochs` rounds, stopping early once the MIA is near chance level.
epochs = 10
# Number of forget steps so far that hit forget_step's 10-batch cap; forget_step
# scales its starting LR by 1.5 ** counter_bool.
counter_bool = 0

for ep in range(epochs):

    print(f'Epoch: {ep}')
    print('****'*10)

    # No forget step in the final round, so the run always ends on a retain pass.
    if ep != epochs - 1:
        with torch.no_grad():
            net.eval()
            starting_forget_acc = accuracy(net, forget_loader)

        # Only push forget accuracy down while it is still above 86%.
        if starting_forget_acc > 0.86:
            net, counter = forget_step(net, forget_loader, starting_forget_acc, counter_bool)
            if counter >= 10:
                counter_bool += 1
                print(f'COUNTER BOOL ON = {counter_bool}')
        else:
            print('Not doing forget step')

    print('""" """'*5)
    print('Retrain')
    net = retrain_step(net, retain_loader)

    # retrain_step hands the model back in train mode. Evaluate in eval mode,
    # which is how val_losses was measured, so BatchNorm uses its running
    # statistics and forget-set batches cannot update them. This matters if
    # compute_losses does not switch modes itself.
    net.eval()
    with torch.no_grad():
        ft_forget_losses = compute_losses(net, forget_loader)

    ft_mia_scores = calc_mia_acc(ft_forget_losses, val_losses)

    print(
        f"The MIA has an accuracy of {ft_mia_scores.mean():.3f} on forgotten vs unseen images"
    )

    # Stop once the attack accuracy is within 0.01 of chance (0.5).
    if np.abs(0.5 - ft_mia_scores.mean()) < 0.01:
        break

In [None]:
# Per-sample loss distributions after unlearning: forget set vs the held-out
# validation set (the two populations the MIA tries to tell apart).
# plt.subplots attaches the figure to pyplot. A bare plt.Figure(...) is
# detached, so its figsize was silently ignored by the plt.* calls.
fig, ax = plt.subplots(figsize=(16, 6))

ax.set_title(
    f"Unlearned by fine-tuning.\nAttack accuracy: {ft_mia_scores.mean():0.2f}"
)
# val_losses comes from val_loader, so label it as the validation set.
ax.hist(val_losses, density=True, alpha=0.5, bins=50, label="Validation set")
ax.hist(ft_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")

ax.set_xlabel("Loss")
ax.set_yscale("log")
# Span both distributions so forget losses above the largest validation loss
# are not clipped out of view.
ax.set_xlim((0, max(np.max(val_losses), np.max(ft_forget_losses))))
ax.legend()

plt.show()