In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from PIL import Image
import torchvision
from transformers import (
    ResNetForImageClassification,
    ResNetConfig,
    AutoImageProcessor,
    Trainer,
    TrainingArguments,
)

In [54]:
model = torchvision.models.resnet18(weights=None, num_classes=10)
dataset = load_dataset("uoft-cs/cifar10")

In [55]:
# 0 airplane 1 automobile 2 bird 3 cat 4 deer 5 dog 6 frog 7 horse 8 ship 9 truck

id2label = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}
label2id = {v: k for k, v in id2label.items()}

In [56]:
from torchvision.transforms import (
    RandomResizedCrop,
    Compose,
    Normalize,
    ToTensor,
    Resize,
)

normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
size = (32, 32)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
_transforms_test = Compose([ToTensor(), normalize])
dataset_train = dataset["train"].with_format("torch")
dataset_test = dataset["test"].with_format("torch")


def collate_fn(batch):
    imgs = []
    labels = []
    for i in range(len(batch)):
        img = batch[i]["img"]
        img = _transforms(torchvision.transforms.ToPILImage()(img).convert("RGB"))
        imgs.append(img)
        labels.append(batch[i]["label"])
    return {
        "img": torch.stack(imgs),
        "label": torch.tensor(labels),
    }


def collate_fn_test(batch):
    imgs = []
    labels = []
    for i in range(len(batch)):
        img = batch[i]["img"]
        img = _transforms_test(torchvision.transforms.ToPILImage()(img).convert("RGB"))
        imgs.append(img)
        labels.append(batch[i]["label"])
    return {
        "img": torch.stack(imgs),
        "label": torch.tensor(labels),
    }


train_loader = DataLoader(
    dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    dataset_test, batch_size=32, shuffle=False, collate_fn=collate_fn_test
)

In [6]:
from tqdm.notebook import tqdm

NUM_EPOCHS = 30
LEARNING_RATE = 0.001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, loss: {running_loss/len(train_loader)}")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy: {100 * correct / total}")


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 1, loss: 1.795039008995393
Accuracy: 47.6


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 2, loss: 1.5914045251186124
Accuracy: 54.97


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 3, loss: 1.4620253069997216
Accuracy: 59.42


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 4, loss: 1.363623392864137
Accuracy: 62.53


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 5, loss: 1.2850833551592349
Accuracy: 65.04


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 6, loss: 1.2280053770931119
Accuracy: 67.98


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 7, loss: 1.1696820822344784
Accuracy: 68.9


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 8, loss: 1.1249324651536312
Accuracy: 69.33


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 9, loss: 1.0877842302514587
Accuracy: 71.24


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 10, loss: 1.0636299866861207
Accuracy: 70.87


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 11, loss: 1.029019861402835
Accuracy: 71.76


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 12, loss: 1.0010183738807952
Accuracy: 74.67


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 13, loss: 0.9766985210217655
Accuracy: 74.46


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 14, loss: 0.9551631154460322
Accuracy: 74.95


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 15, loss: 0.9427501573176699
Accuracy: 76.15


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 16, loss: 0.9157915212600107
Accuracy: 72.8


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 17, loss: 0.9107270503913601
Accuracy: 75.53


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 18, loss: 0.8868892821682925
Accuracy: 76.04


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 19, loss: 0.8835234440295878
Accuracy: 76.74


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 20, loss: 0.8573895084766417
Accuracy: 76.48


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 21, loss: 0.8494890102650672
Accuracy: 76.59


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 22, loss: 0.8358208187215235
Accuracy: 77.06


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 23, loss: 0.8235932132104079
Accuracy: 78.13


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 24, loss: 0.8112926377139638
Accuracy: 77.45


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 25, loss: 0.8069932798086948
Accuracy: 75.88


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 26, loss: 0.7991814075260687
Accuracy: 78.32


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 27, loss: 0.7886353340274008
Accuracy: 78.41


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 28, loss: 0.7877932414936851
Accuracy: 79.06


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 29, loss: 0.7667092678643005
Accuracy: 77.8


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 30, loss: 0.7632537742246059
Accuracy: 79.35


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 31, loss: 0.7595700959478977
Accuracy: 80.03


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 32, loss: 0.7462837995338043
Accuracy: 78.61


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 33, loss: 0.7390682436644993
Accuracy: 78.66


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 34, loss: 0.7421268224716187
Accuracy: 79.9


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 35, loss: 0.7294148210714966
Accuracy: 80.14


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 36, loss: 0.7193061078082882
Accuracy: 79.45


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 37, loss: 0.7021297211076537
Accuracy: 79.68


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 38, loss: 0.7151760959655752
Accuracy: 80.0


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 39, loss: 0.698286418188709
Accuracy: 77.72


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 40, loss: 0.6911750316848682
Accuracy: 81.11


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 41, loss: 0.6901041402869399
Accuracy: 79.84


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 42, loss: 0.6869095488477043
Accuracy: 79.77


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 43, loss: 0.6797422567049968
Accuracy: 78.75


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 44, loss: 0.6717341112560442
Accuracy: 80.91


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 45, loss: 0.6730729290215693
Accuracy: 79.88


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 46, loss: 0.6666935441380346
Accuracy: 79.88


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 47, loss: 0.6554229899201726
Accuracy: 79.72


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 48, loss: 0.6614931393092973
Accuracy: 79.44


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 49, loss: 0.6405021957266582
Accuracy: 79.28


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 50, loss: 0.6473206681600383
Accuracy: 79.53


In [57]:
model = torchvision.models.resnet18(weights=None, num_classes=10)
model.load_state_dict(torch.load("resnet18_cifar10.pth", weights_only=True))
model = model.to(DEVICE)

In [58]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Accuracy: {100 * correct / total}")


Accuracy: 79.53


In [95]:
classes_to_forget = set([label2id["airplane"], label2id["automobile"]])

dataset_classes_to_forget = dataset_train.filter(
    lambda x: x["label"].item() in classes_to_forget
)

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [98]:
dataset_classes_to_forget[0]

{'img': tensor([[[178, 178, 178,  ..., 170, 168, 165],
          [180, 179, 180,  ..., 173, 171, 168],
          [177, 177, 178,  ..., 171, 169, 167],
          ...,
          [112, 113, 114,  ..., 100,  98, 101],
          [112, 112, 113,  ..., 102, 102, 102],
          [103, 100, 103,  ...,  92,  93,  91]],
 
         [[176, 176, 176,  ..., 168, 166, 163],
          [178, 177, 178,  ..., 171, 169, 166],
          [175, 175, 176,  ..., 169, 167, 165],
          ...,
          [107, 109, 110,  ...,  97,  94,  95],
          [102, 103, 103,  ...,  95,  93,  92],
          [ 96,  93,  95,  ...,  84,  86,  84]],
 
         [[189, 189, 189,  ..., 180, 177, 174],
          [191, 190, 191,  ..., 182, 180, 177],
          [188, 188, 189,  ..., 180, 178, 176],
          ...,
          [107, 108, 110,  ...,  94,  93,  95],
          [101, 102, 103,  ...,  93,  91,  91],
          [ 92,  90,  94,  ...,  80,  80,  77]]], dtype=torch.uint8),
 'label': tensor(0)}

In [134]:
model = torchvision.models.resnet18(weights=None, num_classes=10)
model.load_state_dict(torch.load("resnet18_cifar10.pth", weights_only=True))
model = model.to(DEVICE)

LEARNING_RATE = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [154]:
# freeze weight for all layers except the last one
for name, param in model.named_parameters():
    if "fc" not in name:
        param.requires_grad = False

# Freeze all weight of the linear layer except the the ones that correspond to the classes to forget
for i, param in enumerate(model.fc.parameters()):
    if i in classes_to_forget:
        param[i].requires_grad = True


RuntimeError: you can only change requires_grad flags of leaf variables.

In [166]:
model.fc.weight[0].requires_grad = False

RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

In [136]:
from sklearn.metrics import classification_report

classes_to_forget_loader = DataLoader(
    dataset_classes_to_forget, batch_size=32, shuffle=True, collate_fn=collate_fn_test
)
unlearning_rate = 1
perturbation_size = 3
NUM_EPOCHS = 1
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(
        enumerate(classes_to_forget_loader), total=len(classes_to_forget_loader)
    ):
        inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)

        # Add perturbation to the inputs
        perturbation = torch.randn_like(inputs) * perturbation_size
        inputs_perturbed = inputs + perturbation

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = (
            criterion(model(inputs_perturbed), labels) - criterion(outputs, labels)
        ) * unlearning_rate
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, loss: {running_loss/len(train_loader)}")

    model.eval()
    correct = 0
    total = 0
    preds = []
    labs = []
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            preds.extend(predicted.cpu().numpy())
            labs.extend(labels.cpu().numpy())
    print(classification_report(labs, preds))
    print(f"Accuracy: {100 * correct / total}")


  0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1, loss: 0.46964182810987787
              precision    recall  f1-score   support

           0       0.49      0.63      0.55      1000
           1       0.57      0.48      0.52      1000
           2       0.34      0.62      0.44      1000
           3       0.24      0.63      0.35      1000
           4       0.63      0.04      0.07      1000
           5       0.29      0.55      0.38      1000
           6       1.00      0.01      0.01      1000
           7       0.74      0.46      0.57      1000
           8       0.76      0.51      0.61      1000
           9       0.91      0.13      0.23      1000

    accuracy                           0.41     10000
   macro avg       0.60      0.41      0.37     10000
weighted avg       0.60      0.41      0.37     10000

Accuracy: 40.6


In [137]:
labels

tensor([7, 5, 8, 0, 8, 2, 7, 0, 3, 5, 3, 8, 3, 5, 1, 7], device='cuda:0')

In [138]:
# Method 1: Gradient-based importance


def gradient_importance(model, input_tensor, target_class):
    model.eval()
    input_tensor.requires_grad_()

    # Dictionary to store gradients of each layer
    layer_gradients = {}

    # Hook to capture gradients
    def save_gradients(module, grad_input, grad_output):
        layer_gradients[module] = grad_output[0].abs().sum(dim=(0, 2, 3))

    # Register hooks on each convolutional layer
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            hook = module.register_backward_hook(save_gradients)
            hooks.append(hook)

    # Forward pass
    output = model(input_tensor)
    class_score = output[0, target_class]

    # Backward pass
    model.zero_grad()
    class_score.backward()

    # Remove hooks
    for hook in hooks:
        hook.remove()

    return layer_gradients

In [170]:
target_class = 3

inputs = (
    _transforms_test(
        torchvision.transforms.ToPILImage()(
            dataset_classes_to_forget[0]["img"]
        ).convert("RGB")
    )
    .reshape(1, 3, 32, 32)
    .to(DEVICE)
)

model

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
