Experiment A2: Activation Patching with Counterfactuals

Same setup as A1
TODO: record adversarial label to test label correctness

https://nnsight.net/notebooks/tutorials/activation_patching/

In [1]:
import nnsight
from nnsight import NNsight
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from robustness import model_utils
from robustness.datasets import CIFAR

import torchvision.transforms as transforms
from PIL import Image

import os
import sys
# Add the parent directory to the path to import custom modules
sys.path.append(os.path.abspath(os.path.join('..')))
from models.resnet import ResNet18

2025-04-16 19:16:16.097800: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# load saved model
datapath = '/u/yshi23/distribution-shift/datasets'
model_pth = "/u/yshi23/distribution-shift/adversarial/out/checkpoint.pt.best"
device = "cpu"

# model init using robustness mechanism
# model, _ = model_utils.make_and_restore_model(
#     arch='resnet18', dataset=CIFAR(datapath), resume_path=model_pth, device=device
# )
# Load the checkpoint with map_location to ensure it's on the right device
model = ResNet18(num_classes=10)
checkpoint = torch.load(model_pth, map_location=device)

# Get the state dict
state_dict = checkpoint["model"]

# Create a new state dict with properly formatted keys
new_state_dict = {}
for k, v in state_dict.items():
    # Remove 'module.' prefix if present
    if k.startswith('module.'):
        name = k[7:]  # remove 'module.'
    else:
        name = k
    new_state_dict[name] = v

# Load the modified state dict with strict=False to ignore missing/unexpected keys
model.load_state_dict(new_state_dict, strict=False)

print("Model loaded successfully on", device)

Model loaded successfully on cpu


In [3]:
def load_cifar10_image(image_path):
    # Load the image
    img = Image.open(image_path)
    
    # Define the transforms that match ResNet18 expectations
    # ResNet18 expects 224x224 images, but CIFAR-10 is 32x32
    # We'll resize and normalize according to ImageNet stats as typically used with ResNet
    transform = transforms.Compose([
        transforms.Resize(32),  # Resize to 224x224
        transforms.ToTensor(),   # Convert to tensor
        transforms.Normalize(          # Normalize with CIFAR-10 stats
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        )
    ])
    
    # Apply transformations
    img_tensor = transform(img)
    
    # Add batch dimension
    img_tensor = img_tensor.unsqueeze(0)  # Shape: [1, 3, 224, 224]
    
    return img_tensor

# Usage example
orig_path = "images/natural_image_0.png"
clean_img = load_cifar10_image(orig_path)

low_ep_path = "images/mask_small_epsilon_0.png"
low_eps_img = load_cifar10_image(low_ep_path)

high_ep_path = "images/mask_large_epsilon_0.png"
high_eps_img = load_cifar10_image(high_ep_path)


In [4]:
# Get all layer names (keys) from the model's state dictionary
layer_names = list(model.state_dict().keys())

# Print all layer names
print("Model layer names:")
for name in layer_names:
    print(name)

Model layer names:
conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.num_batches_tracked
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.num_batches_tracked
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.num_batches_tracked
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.num_batches_tracked
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.num_batches_tracked
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.num_batches_tracked
layer2.0.shortcut.0.weight


In [5]:
nnsight_model = NNsight(model)

In [9]:
# First trace
with nnsight_model.trace(high_eps_img) as tracer:
    conv1_out = nnsight_model.conv1.output.save()
    layer1_out = nnsight_model.layer1.output.save()
    layer2_out = nnsight_model.layer2.output.save()
    layer3_out = nnsight_model.layer3.output.save()
    layer4_out = nnsight_model.layer4.output.save()
    linear_out = nnsight_model.linear.output.save()

# Edit layer1
with nnsight_model.edit() as layer1_edit:
    nnsight_model.layer1.output = layer1_out

# Edit layer2
with nnsight_model.edit() as layer2_edit:
    nnsight_model.layer2.output = layer2_out

with nnsight_model.edit() as linear_edit:
    nnsight_model.linear.output = linear_out

# Test if patching works at all
with nnsight_model.edit() as extreme_edit:
    # Replace with zeros or significantly altered values
    nnsight_model.layer1.output = layer1_out * 100  # Multiply by 100 to create extreme change

print("extremely edited: ", extreme_edit(clean_img))

# Compare results
print("layer1 edited: ", layer1_edit(clean_img))
print("layer2 edited: ", layer2_edit(clean_img))
print("linear edited: ", linear_edit(clean_img))
print("original: ", nnsight_model(clean_img))
print("original with adv input: ", nnsight_model(high_eps_img))

extremely edited:  tensor([[ 0.6504,  0.8534, -1.1871,  0.7822, -0.0573,  0.2202, -0.6878, -0.3623,
          0.5581, -0.2089]], grad_fn=<AddmmBackward0>)
layer1 edited:  tensor([[ 0.6504,  0.8534, -1.1871,  0.7822, -0.0573,  0.2202, -0.6878, -0.3623,
          0.5581, -0.2089]], grad_fn=<AddmmBackward0>)
layer2 edited:  tensor([[ 0.6504,  0.8534, -1.1871,  0.7822, -0.0573,  0.2202, -0.6878, -0.3623,
          0.5581, -0.2089]], grad_fn=<AddmmBackward0>)
linear edited:  tensor([[ 0.6504,  0.8534, -1.1871,  0.7822, -0.0573,  0.2202, -0.6878, -0.3623,
          0.5581, -0.2089]], grad_fn=<AddmmBackward0>)
original:  tensor([[ 0.6504,  0.8534, -1.1871,  0.7822, -0.0573,  0.2202, -0.6878, -0.3623,
          0.5581, -0.2089]], grad_fn=<AddmmBackward0>)
original with adv input:  tensor([[ 0.6713,  0.7972, -1.0669,  0.8259,  0.0118,  0.2595, -0.6503, -0.4557,
          0.5525, -0.2232]], grad_fn=<AddmmBackward0>)


In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


def _save_main_layers(model):
    return {
        "conv1": model.conv1.output.save(),
        "bn1": model.bn1.output.save(),
        "layer1": model.layer1.output.save(),
        "layer2": model.layer2.output.save(),
        "layer3": model.layer3.output.save(),
        "layer4": model.layer4.output.save(),
        "linear": model.linear.output.save(),
    }


def _save_block_outputs(layer, prefix):
    return {f"{prefix}.block{i}": block.output.save() for i, block in enumerate(layer)}


def _compute_diffs(clean_act, adv_act):
    return {
        "mean": torch.mean(torch.abs(adv_act - clean_act)).item(),
        "max": torch.max(torch.abs(adv_act - clean_act)).item(),
    }


def visualize_activation_differences(model, clean_img, adv_img):
    """
    Visualize differences in activations between clean and adversarial images
    across all layers of a ResNet model.
    """
    diffs = {}

    # Trace clean image
    with model.trace(clean_img):
        clean_main = _save_main_layers(model)
        clean_blocks = {
            **_save_block_outputs(model.layer1, "layer1"),
            **_save_block_outputs(model.layer2, "layer2"),
            **_save_block_outputs(model.layer3, "layer3"),
            **_save_block_outputs(model.layer4, "layer4"),
        }

    # Trace adversarial image
    with model.trace(adv_img):
        adv_main = _save_main_layers(model)
        adv_blocks = {
            **_save_block_outputs(model.layer1, "layer1"),
            **_save_block_outputs(model.layer2, "layer2"),
            **_save_block_outputs(model.layer3, "layer3"),
            **_save_block_outputs(model.layer4, "layer4"),
        }

    # Compute diffs for main layers
    for name in clean_main:
        diffs[name] = _compute_diffs(clean_main[name], adv_main[name])

    # Compute diffs for residual blocks
    for name in clean_blocks:
        diffs[name] = _compute_diffs(clean_blocks[name], adv_blocks[name])

    # Create a DataFrame for plotting
    df = pd.DataFrame.from_dict(diffs, orient="index").reset_index()
    df.columns = ["Layer", "Mean Absolute Difference", "Max Absolute Difference"]

    # Sorting logic based on hierarchical naming
    def get_sort_index(layer_name):
        parts = layer_name.split(".")
        if layer_name.startswith("conv1"):
            return 0
        elif layer_name.startswith("bn1"):
            return 1
        elif layer_name.startswith("relu"):
            return 2
        elif layer_name.startswith("maxpool"):
            return 3
        elif "layer" in parts[0]:
            layer_num = int(parts[0][-1])
            block_num = int(parts[1][-1]) if len(parts) > 1 else -1
            return 4 + layer_num * 10 + (block_num + 1 if block_num >= 0 else 0)
        elif layer_name.startswith("avgpool"):
            return 100
        elif layer_name.startswith("linear"):
            return 101
        return 999  # fallback

    df["SortIdx"] = df["Layer"].apply(get_sort_index)
    df = df.sort_values("SortIdx")

    # Plot
    plt.figure(figsize=(14, 10))

    plt.subplot(2, 1, 1)
    sns.barplot(x="Layer", y="Mean Absolute Difference", data=df)
    plt.xticks(rotation=90)
    plt.title(
        "Mean Absolute Activation Differences Between Clean and Adversarial Images"
    )

    plt.subplot(2, 1, 2)
    sns.barplot(x="Layer", y="Max Absolute Difference", data=df)
    plt.xticks(rotation=90)
    plt.title(
        "Maximum Absolute Activation Differences Between Clean and Adversarial Images"
    )

    plt.tight_layout()
    plt.savefig("activation_differences.png")
    plt.show()

    return df

In [17]:
visualize_activation_differences(nnsight_model, clean_img, high_eps_img)

AttributeError: 'ResNet' object has no attribute 'maxpool'