In [7]:
import nnsight
from nnsight import NNsight
import torch

from robustness import model_utils
from robustness.datasets import CIFAR

import torchvision.transforms as transforms
from PIL import Image

import os
import sys
# Add the parent directory to the path to import custom modules
sys.path.append(os.path.abspath(os.path.join('..')))
from models.resnet import ResNet18

In [8]:
# load saved model
datapath = '/u/yshi23/distribution-shift/datasets'
model_pth = "/u/yshi23/distribution-shift/adversarial/out/checkpoint.pt.best"
device = "cpu"

# model init using robustness mechanism
# model, _ = model_utils.make_and_restore_model(
#     arch='resnet18', dataset=CIFAR(datapath), resume_path=model_pth, device=device
# )
# Load the checkpoint with map_location to ensure it's on the right device
model = ResNet18(num_classes=10)
checkpoint = torch.load(model_pth, map_location=device)

# Get the state dict
state_dict = checkpoint["model"]

# Create a new state dict with properly formatted keys
new_state_dict = {}
for k, v in state_dict.items():
    # Remove 'module.' prefix if present
    if k.startswith('module.'):
        print("Removing 'module.' prefix from key:", k)
        name = k[7:]  # remove 'module.'
    else:
        name = k
    new_state_dict[name] = v

# Load the modified state dict with strict=False to ignore missing/unexpected keys
model.load_state_dict(new_state_dict, strict=False)

print("Model loaded successfully on", device)

Removing 'module.' prefix from key: module.normalizer.new_mean
Removing 'module.' prefix from key: module.normalizer.new_std
Removing 'module.' prefix from key: module.model.conv1.weight
Removing 'module.' prefix from key: module.model.bn1.weight
Removing 'module.' prefix from key: module.model.bn1.bias
Removing 'module.' prefix from key: module.model.bn1.running_mean
Removing 'module.' prefix from key: module.model.bn1.running_var
Removing 'module.' prefix from key: module.model.bn1.num_batches_tracked
Removing 'module.' prefix from key: module.model.layer1.0.conv1.weight
Removing 'module.' prefix from key: module.model.layer1.0.bn1.weight
Removing 'module.' prefix from key: module.model.layer1.0.bn1.bias
Removing 'module.' prefix from key: module.model.layer1.0.bn1.running_mean
Removing 'module.' prefix from key: module.model.layer1.0.bn1.running_var
Removing 'module.' prefix from key: module.model.layer1.0.bn1.num_batches_tracked
Removing 'module.' prefix from key: module.model.laye

In [11]:
def load_cifar10_image(image_path):
    # Load the image
    img = Image.open(image_path)
    
    # Define the transforms that match ResNet18 expectations
    # ResNet18 expects 224x224 images, but CIFAR-10 is 32x32
    # We'll resize and normalize according to ImageNet stats as typically used with ResNet
    transform = transforms.Compose([
        transforms.Resize(32),  # Resize to 224x224
        transforms.ToTensor(),   # Convert to tensor
        transforms.Normalize(          # Normalize with CIFAR-10 stats
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2470, 0.2435, 0.2616]
        )
    ])
    
    # Apply transformations
    img_tensor = transform(img)
    
    # Add batch dimension
    img_tensor = img_tensor.unsqueeze(0)  # Shape: [1, 3, 224, 224]
    
    return img_tensor

# Usage example
image_path = "images/adv_image_small_epsilon_5.png"
# print current path
print("Current working directory:", os.getcwd())
img_tensor = load_cifar10_image(image_path)


Current working directory: /u/yshi23/distribution-shift/interpretability


In [12]:
# Get all layer names (keys) from the model's state dictionary
layer_names = list(model.state_dict().keys())

# Print all layer names
print("Model layer names:")
for name in layer_names:
    print(name)

# Optionally, to save the list to a file
with open('model_layer_names.txt', 'w') as f:
    for name in layer_names:
        if "attacker" not in name: # filter out attacker layers
            f.write(name + '\n')

Model layer names:
conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
bn1.num_batches_tracked
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.bn1.running_mean
layer1.0.bn1.running_var
layer1.0.bn1.num_batches_tracked
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.0.bn2.running_mean
layer1.0.bn2.running_var
layer1.0.bn2.num_batches_tracked
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.bn1.running_mean
layer1.1.bn1.running_var
layer1.1.bn1.num_batches_tracked
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer1.1.bn2.running_mean
layer1.1.bn2.running_var
layer1.1.bn2.num_batches_tracked
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.bn1.running_mean
layer2.0.bn1.running_var
layer2.0.bn1.num_batches_tracked
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.bn2.running_mean
layer2.0.bn2.running_var
layer2.0.bn2.num_batches_tracked
layer2.0.shortcut.0.weight


In [13]:
nnsight_model = NNsight(model)
print(nnsight_model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [14]:
with nnsight_model.trace(img_tensor) as tracer:
    # Forward pass through the model

    conv1_out = nnsight_model.conv1.output.save()
    layer1_out = nnsight_model.layer1.output.save()
    layer2_out = nnsight_model.layer2.output.save()
    layer3_out = nnsight_model.layer3.output.save()
    layer4_out = nnsight_model.layer4.output.save()
    linear_out = nnsight_model.linear.output.save()


print(conv1_out.shape)
print(layer1_out.shape)
print(layer2_out.shape)
print(layer3_out.shape)
print(layer4_out.shape)
print(linear_out.shape)

torch.Size([1, 64, 32, 32])
torch.Size([1, 64, 32, 32])
torch.Size([1, 128, 16, 16])
torch.Size([1, 256, 8, 8])
torch.Size([1, 512, 4, 4])
torch.Size([1, 10])


In [15]:
out = model(img_tensor)