In [1]:
import sys

sys.path.append("../src")
from data_utils import UnlearningDataLoader

import mlflow
import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR

from config import set_config
from data_utils import UnlearningDataLoader
from eval import (
    compute_accuracy,
    get_forgetting_rate,
    get_js_div,
    get_l2_params_distance,
    mia,
)
from mlflow_utils import mlflow_tracking_uri
from models import VGG19, AllCNN, ResNet18, ViT
from seed import set_seed

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Args
run_id = "1bcdd3b016d14404ab22c476184bff75"
mlflow.set_tracking_uri(mlflow_tracking_uri)

In [3]:
# Setup
# Load params from retraining run
retrain_run = mlflow.get_run(run_id)
seed = int(retrain_run.data.params["seed"])
dataset = retrain_run.data.params["dataset"]
model_str = retrain_run.data.params["model"]
batch_size = int(retrain_run.data.params["batch_size"])
epochs_to_retrain = int(retrain_run.data.metrics["best_epoch"])
loss_str = retrain_run.data.params["loss"]
optimizer_str = retrain_run.data.params["optimizer"]
momentum = float(retrain_run.data.params["momentum"])
weight_decay = float(retrain_run.data.params["weight_decay"])
acc_forget_retrain = int(retrain_run.data.metrics["acc_forget"])

In [4]:
# Load data
UDL = UnlearningDataLoader(dataset, batch_size, seed)
dl, dataset_sizes = UDL.load_data()
num_classes = len(UDL.classes)
input_channels = UDL.input_channels
image_size = UDL.image_size

In [5]:
# Load model architecture
if model_str == "resnet18":
    model = ResNet18(input_channels, num_classes)
elif model_str == "allcnn":
    model = AllCNN(input_channels, num_classes)
elif model_str == "vgg19":
    model = VGG19(input_channels, num_classes)
elif model_str == "vit":
    model = ViT(image_size=image_size, num_classes=num_classes)
else:
    raise ValueError("Model not supported")
# Load the original model
model = mlflow.pytorch.load_model(f"{retrain_run.info.artifact_uri}/original_model")
model.to(DEVICE)

Downloading artifacts: 100%|██████████| 6/6 [00:01<00:00,  3.83it/s]   


ResNet18(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (identity_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, m

In [6]:
import numpy as np


def original_lrp_fc_layer(model, dataloader):
    # Define a hook function to get the activations
    def get_activations_hook(module, input, output):
        activations = input[0].detach().cpu().numpy()
        activations_hook.append(activations)


    activations_hook = []
    model.fc.register_forward_hook(get_activations_hook)

    for idx, (inputs, _) in enumerate(dataloader):
        inputs = inputs.to(DEVICE)
        outputs = model(inputs)

        # Get the activations
        batch_activations = activations_hook[idx]
        batch_activations = torch.from_numpy(batch_activations).to(DEVICE)
        # print(f"batch_activations.shape: {batch_activations.shape}")

        T = torch.eye(outputs.size(-1)).to(DEVICE)
        T = T[outputs.argmax(dim=1)] # Select the row from the identity matrix that corresponds to the outputs's highest logit

        # Compute the relevance of the outputs layer
        # print(f"outputs.shape: {outputs.shape}")
        # print(f"T.shape: {T.shape}")
        R = outputs * T  # Element-wise multiplication
        # print(f"R.shape: {R.shape}")

        Z = torch.nn.functional.linear(batch_activations, model.fc.weight)
        # print(f"Z.shape: {Z.shape}")

        # print(f"model.fc.weight.shape: {model.fc.weight.shape}")

        # print(Z[0])
        # print(outputs[0])

        S = R / Z
        # print(S1[0])

        C = torch.mm(S, model.fc.weight)
        # C2 = torch.mm(S2, model.fc.weight)

        # print(C2[0] - C1[0])

        # print(f"C.shape: {C2.shape}")

    return C
    

In [7]:
def approximation_lrp_fc_layer(model, dataloader):

    for idx, (inputs, _) in enumerate(dataloader):
        inputs = inputs.to(DEVICE)
        outputs = model(inputs)

        T = torch.eye(outputs.size(-1)).to(DEVICE)
        T = T[
            outputs.argmax(dim=1)
        ]  # Select the row from the identity matrix that corresponds to the outputs's highest logit
        # R = outputs * T  # Element-wise multiplication

        # Z = torch.nn.functional.linear(batch_activations, model.fc.weight)

        # S = (output * T) / outputs = T

        # C = torch.mm(S, model.fc.weight)
        C = torch.mm(T, model.fc.weight)


    return C

This proves that for the last fc layer specifically there is no need to compute the linear transformation of the activations/inputs to the weights' space. Because this gives us the target (which is already known)

So, I can just 

In [8]:
C = original_lrp_fc_layer(model, dl["forget"])

In [9]:
relevance_per_neuron = torch.sum(C, dim=0)
print(relevance_per_neuron.shape)

torch.Size([512])


In [10]:
normalized_relevance = (relevance_per_neuron - relevance_per_neuron.min()) / (relevance_per_neuron.max() - relevance_per_neuron.min())
normalized_relevance = (normalized_relevance * 2) - 1
normalized_relevance

tensor([ 3.0672e-01, -7.8033e-02, -5.6061e-01,  3.3651e-01, -2.8531e-01,
         1.9930e-01, -6.3308e-02,  4.0247e-01, -1.0779e-01, -3.6619e-01,
        -4.1586e-01, -2.4042e-01,  3.0840e-01,  4.7466e-01, -1.6652e-01,
        -1.8996e-01,  2.6708e-01,  2.9191e-01, -5.8430e-01,  1.3592e-01,
         3.9389e-01, -4.7251e-02, -2.7985e-01, -1.4802e-01, -4.5134e-01,
         1.9388e-01,  3.5965e-01,  5.0864e-01, -2.1801e-01, -4.3966e-01,
         9.9386e-02, -8.1007e-02,  4.5876e-01, -3.0670e-02, -8.0710e-01,
         2.3430e-01,  4.2383e-01,  1.5230e-01,  7.5739e-02, -8.9421e-02,
        -3.3195e-01,  5.3842e-02,  3.5124e-01, -1.9192e-01, -1.1338e-02,
        -1.5301e-01,  2.1557e-02,  2.1885e-01,  3.1724e-01, -6.0119e-01,
         4.7164e-01,  2.3765e-01,  4.6690e-01,  1.0392e-01, -4.7191e-01,
         1.6880e-01, -1.7206e-01,  1.6652e-02,  8.2775e-01, -1.1498e-01,
        -1.9668e-01,  2.3255e-01,  3.3301e-02,  2.0710e-01, -1.0032e-01,
         1.5225e-01, -2.4244e-01,  3.9236e-01, -2.4

In [11]:
mask = torch.where(relevance_per_neuron > 0.5, torch.tensor(1), torch.tensor(0))
count_ones = torch.sum(mask)
print(count_ones)


tensor(194, device='cuda:0')


In [12]:
mask

tensor([1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0,
        0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0,
        0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0,
        1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
        0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1,
        0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1,
        0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0,
        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1,
        1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1,

In [29]:
C_appr = approximation_lrp_fc_layer(model, dl["forget"])
relevance_per_neuron_appr = torch.sum(C_appr, dim=0)
normalized_relevance_appr = (relevance_per_neuron - relevance_per_neuron.min()) / (
    relevance_per_neuron.max() - relevance_per_neuron.min()
)
normalized_relevance_appr = (normalized_relevance_appr * 2) - 1

rpn_diff = normalized_relevance - normalized_relevance_appr
sum_diff = torch.sum(rpn_diff)
print(sum_diff)

tensor(0., device='cuda:0', grad_fn=<SumBackward0>)


In [30]:
rpn_diff

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [31]:
print(C.shape)

torch.Size([92, 512])


In [32]:
print(torch.max(rpn_diff))

tensor(0., device='cuda:0', grad_fn=<MaxBackward1>)


In [33]:
print(torch.min(rpn_diff))

tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
