In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from PIL import Image
import torchvision
from transformers import (
    ResNetForImageClassification,
    ResNetConfig,
    AutoImageProcessor,
    Trainer,
    TrainingArguments,
)
import torch.nn.functional as F

In [2]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net()
dataset = load_dataset("uoft-cs/cifar10")

In [3]:
# 0 airplane 1 automobile 2 bird 3 cat 4 deer 5 dog 6 frog 7 horse 8 ship 9 truck

id2label = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}
label2id = {v: k for k, v in id2label.items()}

In [4]:
from torchvision.transforms import (
    RandomResizedCrop,
    Compose,
    Normalize,
    ToTensor,
    Resize,
)


dataset = load_dataset("uoft-cs/cifar10")
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
size = (32, 32)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
_transforms_test = Compose([ToTensor(), normalize])
dataset_train = dataset["train"].with_format("torch")
dataset_test = dataset["test"].with_format("torch")


def collate_fn(batch):
    imgs = []
    labels = []
    for i in range(len(batch)):
        img = batch[i]["img"]
        img = _transforms(torchvision.transforms.ToPILImage()(img).convert("RGB"))
        imgs.append(img)
        labels.append(batch[i]["label"])
    return {
        "img": torch.stack(imgs),
        "label": torch.tensor(labels),
    }


def collate_fn_test(batch):
    imgs = []
    labels = []
    for i in range(len(batch)):
        img = batch[i]["img"]
        img = _transforms_test(torchvision.transforms.ToPILImage()(img).convert("RGB"))
        imgs.append(img)
        labels.append(batch[i]["label"])
    return {
        "img": torch.stack(imgs),
        "label": torch.tensor(labels),
    }


train_loader = DataLoader(
    dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    dataset_test, batch_size=32, shuffle=False, collate_fn=collate_fn_test
)

In [5]:
from tqdm.notebook import tqdm

NUM_EPOCHS = 10
LEARNING_RATE = 0.001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, loss: {running_loss/len(train_loader)}")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Accuracy: {100 * correct / total}")


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 1, loss: 1.8471073164668361
Accuracy: 43.96


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 2, loss: 1.696151414286686
Accuracy: 48.95


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 3, loss: 1.6348652745086416
Accuracy: 49.58


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 4, loss: 1.5987679278781912
Accuracy: 51.23


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 5, loss: 1.5744074581528198
Accuracy: 53.21


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 6, loss: 1.5488800437345156
Accuracy: 54.54


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 7, loss: 1.5362191085089343
Accuracy: 54.07


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 8, loss: 1.5224077966223903
Accuracy: 55.12


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 9, loss: 1.5124725585401784
Accuracy: 54.91


  0%|          | 0/1563 [00:00<?, ?it/s]

Epoch 10, loss: 1.5009080273786266
Accuracy: 55.96


In [6]:
torch.save(model.state_dict(), "model.pth")

In [7]:
model = Net()
model.load_state_dict(torch.load("model.pth"))


  model.load_state_dict(torch.load("model.pth"))


<All keys matched successfully>

In [19]:
model.conv1.bias[0]

tensor(0.0421, grad_fn=<SelectBackward0>)

In [51]:
img = dataset_test[6]["img"]
img = _transforms_test(torchvision.transforms.ToPILImage()(img).convert("RGB"))

# Convolve the image with each filter of conv1 of the model and average the output, return the filter with the max average output


def get_most_activated_filter(model, img):
    model.eval()
    with torch.no_grad():
        img = img.unsqueeze(0)
        conv1 = model.conv1(img)
        conv1 = F.relu(conv1)
        avg = torch.mean(conv1, dim=(0, 2, 3))
        return torch.argmax(avg).item()


# For each class in test_loader, get the most activated filter on average and save the result in a dict

most_activated_filters = {}
for data in test_loader:
    inputs, labels = data["img"], data["label"]
    for i in range(len(inputs)):
        label = labels[i].item()
        img = inputs[i]
        filter_id = get_most_activated_filter(model, img)
        if label not in most_activated_filters:
            most_activated_filters[label] = []
        most_activated_filters[label].append(filter_id)


most_activated = []

for i in range(10):
    count = np.bincount(most_activated_filters[i])
    # get top 3 most activated filters
    top_k = 3
    top_arr = []
    for i in range(top_k):
        m_ac = np.argmax(count)
        top_arr.append(m_ac)
        count[m_ac] = 0
    most_activated.append(top_arr)


for i, arr in enumerate(most_activated):
    print(f"Class {id2label[i]}, top 3 most activated filters: {arr}")

Class airplane, top 3 most activated filters: [3, 4, 1]
Class automobile, top 3 most activated filters: [3, 1, 5]
Class bird, top 3 most activated filters: [4, 3, 1]
Class cat, top 3 most activated filters: [1, 0, 3]
Class deer, top 3 most activated filters: [4, 1, 3]
Class dog, top 3 most activated filters: [0, 1, 4]
Class frog, top 3 most activated filters: [4, 1, 0]
Class horse, top 3 most activated filters: [4, 1, 0]
Class ship, top 3 most activated filters: [3, 1, 2]
Class truck, top 3 most activated filters: [3, 1, 0]


In [49]:
img = dataset_test[6]["img"]
img = _transforms_test(torchvision.transforms.ToPILImage()(img).convert("RGB"))

# Convolve the image with each filter of conv1 of the model and average the output, return the filter with the max average output


def get_most_activated_filter(model, img):
    model.eval()
    with torch.no_grad():
        img = img.unsqueeze(0)
        conv1 = model.conv1(img)
        conv1 = F.relu(conv1)
        pool = model.pool(conv1)
        conv2 = model.conv2(pool)
        avg = torch.mean(conv2, dim=(0, 2, 3))
        return torch.argmax(avg).item()


# For each class in test_loader, get the most activated filter on average and save the result in a dict

most_activated_filters = {}
for data in test_loader:
    inputs, labels = data["img"], data["label"]
    for i in range(len(inputs)):
        label = labels[i].item()
        img = inputs[i]
        filter_id = get_most_activated_filter(model, img)
        if label not in most_activated_filters:
            most_activated_filters[label] = []
        most_activated_filters[label].append(filter_id)


most_activated = []

for i in range(10):
    count = np.bincount(most_activated_filters[i])
    # get top 3 most activated filters
    top_k = 8
    top_arr = []
    for i in range(top_k):
        m_ac = np.argmax(count)
        top_arr.append(m_ac)
        count[m_ac] = 0
    most_activated.append(top_arr)


for i, arr in enumerate(most_activated):
    print(f"Class {id2label[i]}, top 3 most activated filters: {arr}")


Class airplane, top 3 most activated filters: [8, 15, 4, 6, 3, 2, 0, 5]
Class automobile, top 3 most activated filters: [8, 9, 6, 2, 3, 15, 0, 4]
Class bird, top 3 most activated filters: [6, 8, 2, 4, 7, 15, 3, 0]
Class cat, top 3 most activated filters: [6, 8, 7, 4, 3, 9, 0, 2]
Class deer, top 3 most activated filters: [6, 8, 2, 7, 3, 4, 15, 0]
Class dog, top 3 most activated filters: [6, 8, 4, 7, 9, 0, 2, 3]
Class frog, top 3 most activated filters: [6, 8, 2, 7, 9, 0, 3, 4]
Class horse, top 3 most activated filters: [6, 8, 2, 7, 4, 9, 0, 1]
Class ship, top 3 most activated filters: [8, 15, 4, 6, 5, 3, 2, 9]
Class truck, top 3 most activated filters: [8, 6, 9, 5, 4, 2, 7, 3]


In [58]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Accuracy: {100 * correct / total}")


Accuracy: 79.53


In [95]:
classes_to_forget = set([label2id["airplane"], label2id["automobile"]])

dataset_classes_to_forget = dataset_train.filter(
    lambda x: x["label"].item() in classes_to_forget
)

Filter:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [98]:
dataset_classes_to_forget[0]

{'img': tensor([[[178, 178, 178,  ..., 170, 168, 165],
          [180, 179, 180,  ..., 173, 171, 168],
          [177, 177, 178,  ..., 171, 169, 167],
          ...,
          [112, 113, 114,  ..., 100,  98, 101],
          [112, 112, 113,  ..., 102, 102, 102],
          [103, 100, 103,  ...,  92,  93,  91]],
 
         [[176, 176, 176,  ..., 168, 166, 163],
          [178, 177, 178,  ..., 171, 169, 166],
          [175, 175, 176,  ..., 169, 167, 165],
          ...,
          [107, 109, 110,  ...,  97,  94,  95],
          [102, 103, 103,  ...,  95,  93,  92],
          [ 96,  93,  95,  ...,  84,  86,  84]],
 
         [[189, 189, 189,  ..., 180, 177, 174],
          [191, 190, 191,  ..., 182, 180, 177],
          [188, 188, 189,  ..., 180, 178, 176],
          ...,
          [107, 108, 110,  ...,  94,  93,  95],
          [101, 102, 103,  ...,  93,  91,  91],
          [ 92,  90,  94,  ...,  80,  80,  77]]], dtype=torch.uint8),
 'label': tensor(0)}

In [134]:
model = torchvision.models.resnet18(weights=None, num_classes=10)
model.load_state_dict(torch.load("resnet18_cifar10.pth", weights_only=True))
model = model.to(DEVICE)

LEARNING_RATE = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [154]:
# freeze weight for all layers except the last one
for name, param in model.named_parameters():
    if "fc" not in name:
        param.requires_grad = False

# Freeze all weight of the linear layer except the the ones that correspond to the classes to forget
for i, param in enumerate(model.fc.parameters()):
    if i in classes_to_forget:
        param[i].requires_grad = True


RuntimeError: you can only change requires_grad flags of leaf variables.

In [166]:
model.fc.weight[0].requires_grad = False

RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

In [136]:
from sklearn.metrics import classification_report

classes_to_forget_loader = DataLoader(
    dataset_classes_to_forget, batch_size=32, shuffle=True, collate_fn=collate_fn_test
)
unlearning_rate = 1
perturbation_size = 3
NUM_EPOCHS = 1
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for i, data in tqdm(
        enumerate(classes_to_forget_loader), total=len(classes_to_forget_loader)
    ):
        inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)

        # Add perturbation to the inputs
        perturbation = torch.randn_like(inputs) * perturbation_size
        inputs_perturbed = inputs + perturbation

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = (
            criterion(model(inputs_perturbed), labels) - criterion(outputs, labels)
        ) * unlearning_rate
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, loss: {running_loss/len(train_loader)}")

    model.eval()
    correct = 0
    total = 0
    preds = []
    labs = []
    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data["img"].to(DEVICE), data["label"].to(DEVICE)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            preds.extend(predicted.cpu().numpy())
            labs.extend(labels.cpu().numpy())
    print(classification_report(labs, preds))
    print(f"Accuracy: {100 * correct / total}")


  0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1, loss: 0.46964182810987787
              precision    recall  f1-score   support

           0       0.49      0.63      0.55      1000
           1       0.57      0.48      0.52      1000
           2       0.34      0.62      0.44      1000
           3       0.24      0.63      0.35      1000
           4       0.63      0.04      0.07      1000
           5       0.29      0.55      0.38      1000
           6       1.00      0.01      0.01      1000
           7       0.74      0.46      0.57      1000
           8       0.76      0.51      0.61      1000
           9       0.91      0.13      0.23      1000

    accuracy                           0.41     10000
   macro avg       0.60      0.41      0.37     10000
weighted avg       0.60      0.41      0.37     10000

Accuracy: 40.6


In [137]:
labels

tensor([7, 5, 8, 0, 8, 2, 7, 0, 3, 5, 3, 8, 3, 5, 1, 7], device='cuda:0')

In [138]:
# Method 1: Gradient-based importance


def gradient_importance(model, input_tensor, target_class):
    model.eval()
    input_tensor.requires_grad_()

    # Dictionary to store gradients of each layer
    layer_gradients = {}

    # Hook to capture gradients
    def save_gradients(module, grad_input, grad_output):
        layer_gradients[module] = grad_output[0].abs().sum(dim=(0, 2, 3))

    # Register hooks on each convolutional layer
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            hook = module.register_backward_hook(save_gradients)
            hooks.append(hook)

    # Forward pass
    output = model(input_tensor)
    class_score = output[0, target_class]

    # Backward pass
    model.zero_grad()
    class_score.backward()

    # Remove hooks
    for hook in hooks:
        hook.remove()

    return layer_gradients

In [170]:
target_class = 3

inputs = (
    _transforms_test(
        torchvision.transforms.ToPILImage()(
            dataset_classes_to_forget[0]["img"]
        ).convert("RGB")
    )
    .reshape(1, 3, 32, 32)
    .to(DEVICE)
)

model

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
