In [2]:
import sys

sys.path.append("../src")
from data_utils import UnlearningDataLoader

import mlflow
import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR

from config import set_config
from data_utils import UnlearningDataLoader
from eval import (
    compute_accuracy,
    get_forgetting_rate,
    get_js_div,
    get_l2_params_distance,
    mia,
)
from mlflow_utils import mlflow_tracking_uri
from models import VGG19, AllCNN, ResNet18, ViT
from seed import set_seed

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Args
run_id = "1bcdd3b016d14404ab22c476184bff75"
mlflow.set_tracking_uri(mlflow_tracking_uri)

In [4]:
# Setup
# Load params from retraining run
retrain_run = mlflow.get_run(run_id)
seed = int(retrain_run.data.params["seed"])
dataset = retrain_run.data.params["dataset"]
model_str = retrain_run.data.params["model"]
batch_size = int(retrain_run.data.params["batch_size"])
epochs_to_retrain = int(retrain_run.data.metrics["best_epoch"])
loss_str = retrain_run.data.params["loss"]
optimizer_str = retrain_run.data.params["optimizer"]
momentum = float(retrain_run.data.params["momentum"])
weight_decay = float(retrain_run.data.params["weight_decay"])
acc_forget_retrain = int(retrain_run.data.metrics["acc_forget"])

In [5]:
# Load data
UDL = UnlearningDataLoader(dataset, batch_size, seed)
dl, dataset_sizes = UDL.load_data()
num_classes = len(UDL.classes)
input_channels = UDL.input_channels
image_size = UDL.image_size

In [6]:
# Load model architecture
if model_str == "resnet18":
    model = ResNet18(input_channels, num_classes)
elif model_str == "allcnn":
    model = AllCNN(input_channels, num_classes)
elif model_str == "vgg19":
    model = VGG19(input_channels, num_classes)
elif model_str == "vit":
    model = ViT(image_size=image_size, num_classes=num_classes)
else:
    raise ValueError("Model not supported")
# Load the original model
model = mlflow.pytorch.load_model(f"{retrain_run.info.artifact_uri}/original_model")
model.to(DEVICE)

Downloading artifacts: 100%|██████████| 6/6 [00:01<00:00,  4.12it/s]   


ResNet18(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (identity_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [7]:
import numpy as np


def lrp_fc_layer(model, dataloader):
    # Define a hook function to get the activations
    def get_activations_hook(module, input, output):
        activations = input[0].detach().cpu().numpy()
        activations_hook.append(activations)


    activations_hook = []
    model.fc.register_forward_hook(get_activations_hook)

    for idx, (inputs, _) in enumerate(dataloader):
        inputs = inputs.to(DEVICE)
        outputs = model(inputs)

        # Get the activations
        batch_activations = activations_hook[idx]
        batch_activations = torch.from_numpy(batch_activations).to(DEVICE)
        # print(f"batch_activations.shape: {batch_activations.shape}")

        T = torch.eye(outputs.size(-1)).to(DEVICE)
        T = T[outputs.argmax(dim=1)] # Select the row from the identity matrix that corresponds to the outputs's highest logit

        # Compute the relevance of the outputs layer
        # print(f"outputs.shape: {outputs.shape}")
        # print(f"T.shape: {T.shape}")
        R = outputs * T  # Element-wise multiplication
        # print(f"R.shape: {R.shape}")

        Z = torch.nn.functional.linear(batch_activations, model.fc.weight)
        # print(f"Z.shape: {Z.shape}")

        # print(f"model.fc.weight.shape: {model.fc.weight.shape}")

        # print(Z[0])
        # print(outputs[0])

        S1 = R / Z
        # print(S1[0])

        C1 = torch.mm(S1, model.fc.weight)
        # C2 = torch.mm(S2, model.fc.weight)

        # print(C2[0] - C1[0])

        # print(f"C.shape: {C2.shape}")

    return C1    
    

This proves that for the last fc layer specifically there is no need to compute the linear transformation of the activations/inputs to the weights' space. Because this gives us the target (which is already known)

So, I can just 

In [8]:
C = lrp_fc_layer(model, dl["forget"])

In [9]:
relevance_per_neuron = torch.sum(C, dim=0)
print(relevance_per_neuron.shape)

torch.Size([512])


In [11]:
normalized_relevance = (relevance_per_neuron - relevance_per_neuron.min()) / (relevance_per_neuron.max() - relevance_per_neuron.min())
normalized_relevance = (normalized_relevance * 2) - 1
normalized_relevance

tensor([ 5.6443e-01,  3.0196e-01, -3.4898e-01,  3.0182e-01,  7.7416e-02,
         2.7558e-01,  2.2935e-01,  6.2738e-02,  1.3916e-01, -3.6274e-01,
        -1.1860e-01, -2.5921e-01,  7.5304e-01,  2.4125e-01,  3.0494e-01,
        -1.0073e-02,  1.1075e-01,  5.0309e-03, -4.1824e-01,  4.4073e-02,
         3.7878e-01,  2.5172e-03, -2.3915e-01, -6.6012e-02, -1.8909e-01,
         1.3688e-02,  8.5580e-03,  2.9859e-01, -2.3933e-01, -4.6247e-01,
        -1.0160e-01, -1.0404e-01,  3.0131e-01, -3.0913e-02, -7.7992e-01,
         5.5145e-01,  1.3846e-01, -3.3275e-01,  3.3197e-01,  2.3156e-01,
        -1.8119e-01,  2.7211e-02,  2.5591e-01, -7.4714e-02,  1.3309e-02,
         4.9302e-01,  1.8665e-01,  1.7270e-01,  1.5868e-01, -4.8621e-01,
         1.4368e-01,  4.7177e-03,  8.9674e-01,  4.0898e-01, -4.1888e-01,
        -1.5632e-02, -1.4628e-01,  1.7971e-01,  6.4151e-01,  8.0303e-02,
        -2.0325e-01,  5.4051e-01,  1.2389e-01,  2.7504e-01, -3.2037e-01,
         1.6246e-01, -7.8501e-02,  9.9449e-02,  5.5

In [14]:
mask = torch.where(relevance_per_neuron > 0.5, torch.tensor(1), torch.tensor(0))
count_ones = torch.sum(mask)
print(count_ones)


tensor(193, device='cuda:0')


In [15]:
mask

tensor([1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0,
        1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0,
        0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,
        0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1,
        1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0,
        1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1,
        1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0,

In [None]:
Ουσιαστικά, βρήκαμε τους πιο σημαντικούς νευρώνες από το προτελευταίο layer.
Και μπορούμε να πειράξουμε τα βάρη τους στο τελευταιο layer.