In [1]:
import os
import requests
import random
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection
import copy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import prune
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# manual random seed is used for dataset partitioning
# to ensure reproducible results across runs
SEED = 42
RNG = torch.Generator().manual_seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

import sys
sys.path.append('../')
from utils import *

Running on device: CUDA


In [2]:
import ssl

# Create an unverified SSL context
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
batch_size = 32

In [4]:
# download and pre-process CIFAR10
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

train_set = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=True, download=False, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)

# we split held out data into test and validation set
held_out = torchvision.datasets.CIFAR10(
    root="../example notebooks/data", train=False, download=False, transform=normalize
)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=1)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=1)

# download the forget and retain index split
local_path = "../example notebooks/forget_idx.npy"
# if not os.path.exists(local_path):
#     response = requests.get(
#         "https://storage.googleapis.com/unlearning-challenge/" + local_path
#     )
#     open(local_path, "wb").write(response.content)
forget_idx = np.load(local_path)

# construct indices of retain from those of the forget set
forget_mask = np.zeros(len(train_set.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

# split train set into a forget and a retain set
forget_set = torch.utils.data.Subset(train_set, forget_idx)
retain_set = torch.utils.data.Subset(train_set, retain_idx)

forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=batch_size, shuffle=True, num_workers=1
)
# retain_loader = torch.utils.data.DataLoader(
#     retain_set, batch_size=128, shuffle=True, num_workers=1, generator=RNG
# )

In [5]:
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=batch_size, shuffle=True, num_workers=1
)

- Retain set accuracy: 99.5%
- Forget set accuracy: 99.3%
- Val set accuracy: 88.9%
- Test set accuracy: 88.3%

In [6]:
local_path = "../example notebooks/weights/weights_resnet18_cifar10.pth"
if not os.path.exists(local_path):
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/weights_resnet18_cifar10.pth"
    )
    open(local_path, "wb").write(response.content)

weights_pretrained = torch.load(local_path, map_location=DEVICE) #43Mbs

# load net with pre-trained weights
net = resnet18(weights=None, num_classes=10)
net.load_state_dict(weights_pretrained)
net.to(DEVICE)
net.eval();

In [7]:
val_losses = compute_losses(net, val_loader)
# test_losses = compute_losses(net, test_loader)

In [8]:
# Extract feature and pooling layers to create a Custom Model
class CustomResNet18(nn.Module):
    def __init__(self, original_model):
        super(CustomResNet18, self).__init__()
        
        # Extract features and pooling layers
        self.features = nn.Sequential(*list(original_model.children())[:-2])
        self.pooling = list(original_model.children())[-2]
    
    def forward(self, x):
        x = self.features(x)
        x = self.pooling(x)
        x = torch.squeeze(x)
        return x

custom_model = CustomResNet18(net).to(DEVICE)

In [9]:
# Initialize
val_embeddings = {}
retain_embeddings = {}

# Compute embeddings for val_loader
for val_batch in val_loader:
    images = val_batch[0].to(DEVICE)
    person_ids = val_batch[1]
    embeddings = custom_model(images)
    for i, person_id in enumerate(person_ids):
        val_embeddings.setdefault(person_id.item(), []).append(embeddings[i].detach())

# Compute embeddings for retain_loader
for retain_batch in retain_loader:
    images = retain_batch[0].to(DEVICE)
    targets = retain_batch[1]
    embeddings = net(images)
    for i, target in enumerate(targets):
        if target in range(0,10): # TODO [0, 1]:  # Only consider targets 0 and 1
            retain_embeddings.setdefault(target.item(), []).append(embeddings[i].detach())

In [45]:
# Contrastive Loss
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
    
    def forward(self, output1, output2, label):
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

criterion = ContrastiveLoss()
optimizer = optim.AdamW(custom_model.parameters(), lr=0.001)

# Step 3: Contrastive Learning for one epoch
for batch in forget_loader:
    custom_model.train()
    optimizer.zero_grad()
    
    inputs = batch[0].to(DEVICE)
    targets = batch[1]
    person_ids = batch[1]
    
    # Forward pass to get embeddings for the forget_batch
    forget_embeddings = custom_model(inputs)
    
    positive_pairs = []
    negative_pairs = []

    with torch.no_grad():  # Disable gradient computation to save memory
        # Fetch Positive Pairs
        for pid in person_ids:
            pid = pid.item()
            if pid in val_embeddings:
                selected_embedding = random.choice(val_embeddings[pid])
                positive_pairs.append(selected_embedding)
            else:
                print(f"Skipping person_id {pid} for positive pairs, not found in val_embeddings.")
                continue

        # Convert to tensors for ease of computation
        positive_pairs = torch.stack(positive_pairs).to(DEVICE)

        # Fetch Negative Pairs
        for tgt in targets:
            tgt = tgt.item()
            if tgt in retain_embeddings:
                selected_embedding = random.choice(retain_embeddings[tgt])
                negative_pairs.append(selected_embedding)
            else:
                print(f"Skipping target {tgt} for negative pairs, not found in retain_embeddings.")
                continue

        # Convert to tensors for ease of computation
        # negative_pairs = torch.stack(negative_pairs).to(DEVICE)

    import sys
    sys.exit()
    
    # Compute Contrastive Loss
    positive_loss = criterion(forget_embeddings, positive_pairs, torch.zeros(positive_pairs.shape[0]).to(DEVICE))
    negative_loss = criterion(forget_embeddings, negative_pairs, torch.ones(negative_pairs.shape[0]).to(DEVICE))
    
    # Total loss
    loss = positive_loss + negative_loss

    loss.backward()
    optimizer.step()

    with torch.no_grad():
        net.eval()
        forget_acc = accuracy(net, forget_loader)
    print(f"Forget set accuracy: {100.0 * forget_acc:0.2f}%")
    print('--'*10)

    if forget_acc < 0.86:
        break

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [52]:


print("output1 shape: ", forget_embeddings.shape)
print("output2 shape: ", negative_pairs.shape)
print("label shape: ", torch.ones(negative_pairs.shape[0]).to(DEVICE).shape)
euclidean_distance = nn.functional.pairwise_distance(forget_embeddings, negative_pairs)

output1 shape:  torch.Size([32, 512])
output2 shape:  torch.Size([32, 10])
label shape:  torch.Size([32])


RuntimeError: The size of tensor a (512) must match the size of tensor b (10) at non-singleton dimension 1

In [42]:
criterion(forget_embeddings, positive_pairs, torch.zeros(positive_pairs.shape[0]).to(DEVICE))

tensor(20.1439, device='cuda:0', grad_fn=<MeanBackward0>)

In [None]:
negative_pairs = []

# Fetch Negative Pairs
for tgt in targets:
    tgt = tgt.item()
    if tgt in retain_embeddings:
        print('a')
        selected_embedding = random.choice(retain_embeddings[tgt])
        negative_pairs.append(selected_embedding)

In [50]:
negative_pairs = torch.stack(negative_pairs).to(DEVICE)

In [53]:
negative_pairs

tensor([[-1.8050e+00, -1.8805e+00, -1.4124e+00, -1.8482e-01, -2.2213e-01,
          9.9724e-01, -1.0107e+00,  9.7968e+00, -2.1353e+00, -2.1434e+00],
        [-4.2494e-02,  5.8165e+00, -2.1119e-01, -2.0664e+00, -3.3988e+00,
         -1.3326e+00, -1.6098e+00,  1.3352e-01, -1.8334e+00,  4.5444e+00],
        [-1.0446e+00, -2.5142e+00, -3.9545e+00,  1.0687e+01, -8.3356e-01,
          4.8609e+00, -1.7591e+00, -1.5781e+00, -1.9732e+00, -1.8909e+00],
        [ 8.5357e+00,  1.5755e+00, -6.3768e-01, -2.6231e+00, -1.8164e+00,
         -1.8184e+00, -1.6429e+00, -1.8333e+00,  1.0776e+00, -8.1713e-01],
        [-6.0232e-01, -2.1849e+00, -1.7728e+00, -1.9746e+00,  1.0377e+01,
         -3.5582e+00, -9.5972e-01,  1.2621e+00, -7.8732e-01,  2.0070e-01],
        [-2.8061e+00, -2.1484e+00, -1.5990e+00,  7.1116e-01,  4.6521e+00,
          9.5043e+00, -1.1965e+00, -7.6086e-01, -2.8754e+00, -3.4815e+00],
        [ 8.3634e-01, -1.0836e+00,  1.1408e+00, -9.6132e-01, -3.2298e+00,
         -3.0115e+00,  7.7856e+0

In [40]:
criterion(forget_embeddings, negative_pairs, torch.ones(negative_pairs.shape[0]).to(DEVICE))

RuntimeError: The size of tensor a (512) must match the size of tensor b (10) at non-singleton dimension 1

In [32]:
negative_pairs.shape

torch.Size([32, 10])

In [29]:
forget_embeddings.shape

torch.Size([32, 512])

In [31]:
retain_embeddings.keys()

dict_keys([2, 9, 6, 5, 8, 3, 4, 0, 7, 1])

In [21]:
person_ids

tensor([5, 2, 1, 7, 0, 9, 7, 8])

In [20]:
forget_embeddings.shape

torch.Size([32, 512])

In [19]:
positive_pairs.shape

torch.Size([8, 512])

In [None]:
# Initialize custom model with original model's weights
custom_model = CustomResNet18(net)

# Enable training for layers before pooling
for param in custom_model.features[-1].parameters():  # Assuming the last layer is layer4
    param.requires_grad = True

# Contrastive Learning
optimizer = optim.SGD(custom_model.features[-1].parameters(), lr=0.001)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

In [None]:
# For Contrastive Learning
for sample in forget_loader:
    optimizer.zero_grad()
    inputs_forget = sample[0]
    person_id_forget = sample[1]
    
    embeddings_forget = custom_model(inputs_forget)

    # Find Positive Pairs in val_loader
    embeddings_val = []
    for val_sample in val_loader:
        if val_sample[1] == person_id_forget:
            embeddings_val.append(custom_model(val_sample[0]))
    embeddings_val = torch.stack(embeddings_val)

In [None]:
# Apply pruning
# pct = 0.10
# unstructure_prune(net, pct, global_pruning=True, random_init=False)

torch.save({
    'net': net.state_dict(),
}, f'./checkpoints/temp_checkpoint.pth')

In [None]:
LR = 1e-4

In [None]:
def forget_step(net, forget_loader, starting_forget_acc, counter_bool=0):

    # Initialize for all epochs
    forget_acc = copy.copy(starting_forget_acc)
    print(f'Starting with {100.0 * forget_acc:0.2f}% forget accuracy')
    forget_acc_threshold_1 = forget_acc*0.90
    forget_acc_threshold_2 = forget_acc*0.80
    print(f'Thresholds: {100.0 * forget_acc_threshold_1:0.2f}%, {100.0 * forget_acc_threshold_2:0.2f}%')

    
    iter_forget = iter(forget_loader)
    current_batch = 0

    initial_forget_lr = LR
    if counter_bool>0:
        initial_forget_lr = initial_forget_lr*(1.5**counter_bool)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
    forget_optimizer = optim.AdamW(net.parameters(), lr=initial_forget_lr)

    not_depleted = True
    counter = 0

    net.train()

    while (forget_acc > forget_acc_threshold_1) & (not_depleted):

        try:
            sample = next(iter_forget)
            counter+=1
            
        except StopIteration:
            not_depleted = False
            print('depleted')
            break

        inputs, targets = sample
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        forget_optimizer.zero_grad()

        # Forward pass
        logits = net(inputs)

        # Calculate loss
        classification_loss = criterion(logits, targets)

        loss = -1*classification_loss
        loss.backward()
        forget_optimizer.step()

        current_batch+=1
        print(current_batch)

        with torch.no_grad():
            net.eval()
            forget_acc = accuracy(net, forget_loader)
        print(f"Forget set accuracy: {100.0 * forget_acc:0.2f}%")
        print('--'*10)

        if (forget_acc<forget_acc_threshold_2):
            initial_forget_lr = initial_forget_lr/2
            current_batch = 0
            forget_acc = starting_forget_acc
            print('Restoring')
            checkpoint = torch.load(f'./checkpoints/temp_checkpoint.pth')
            net.load_state_dict(checkpoint['net'])
            forget_optimizer = optim.AdamW(net.parameters(), lr=initial_forget_lr)

        if counter>=10:
            break

    return net, counter

In [None]:
def retrain_step(net, retain_loader):

    initial_retain_lr = LR
    criterion = nn.CrossEntropyLoss()
    optimizer_retain = optim.SGD(net.parameters(), lr=initial_retain_lr, momentum=0.9, weight_decay=5e-4)

    warmup_current_batch = 0
    warmup_batches = math.ceil(0.4*len(retain_loader.dataset))
    
    net.train()

    for sample in retain_loader:

        inputs, targets = sample
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        warmup_current_batch += 1

        # Warm-up for the first 'warmup_batches' batches
        if warmup_current_batch <= warmup_batches:
            adjust_learning_rate(optimizer_retain, warmup_current_batch, warmup_batches, initial_retain_lr)

        optimizer_retain.zero_grad()

        # Forward pass
        logits = net(inputs)

        # Calculate loss
        criterion = nn.CrossEntropyLoss(label_smoothing=0.4)
        classification_loss = criterion(logits, targets)
        loss = classification_loss
        loss.backward()
        optimizer_retain.step()

    torch.save({
        'net': net.state_dict(),
    }, f'./checkpoints/temp_checkpoint.pth')

    return net

In [None]:
epochs = 10
counter_bool = 0

for ep in range(epochs):

    print(f'Epoch: {ep}')
    print('****'*10)

    if ep!=epochs-1:
        with torch.no_grad():
            net.eval()
            starting_forget_acc = accuracy(net, forget_loader)

        if starting_forget_acc > 0.86:
            net, counter = forget_step(net, forget_loader, starting_forget_acc, counter_bool)
            if counter>=10:
                counter_bool+=1
                print(f'COUNTER BOOL ON = {counter_bool}')
        else:
            print('Not doing forget step')

    print('""" """'*5)
    print('Retrain')
    net = retrain_step(net, retain_loader)

    ft_forget_losses = compute_losses(net, forget_loader)

    ft_mia_scores = calc_mia_acc(ft_forget_losses, val_losses)

    print(
        f"The MIA has an accuracy of {ft_mia_scores.mean():.3f} on forgotten vs unseen images"
    )

    if np.abs(0.5-ft_mia_scores.mean())<0.01:
        break

In [None]:
fig = plt.Figure(figsize=(16, 6))

plt.title(
    f"Unlearned by fine-tuning.\nAttack accuracy: {ft_mia_scores.mean():0.2f}"
)
plt.hist(val_losses, density=True, alpha=0.5, bins=50, label="Test set")
plt.hist(ft_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")

plt.xlabel("Loss")
plt.yscale("log")
plt.xlim((0, np.max(val_losses)))
plt.legend()

plt.show()