In [16]:
from torchvision import models, transforms
from torchsummary import summary
import torch
import numpy as np
import cv2
import PIL
import matplotlib.pyplot as plt

In [17]:
model = models.resnet50(pretrained=True)

In [18]:
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [19]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [20]:
def get_image(path, transform):
    '''
    get image from path and transform it tensor
    '''
    img = PIL.Image.open(path)
    img_t = transform(img)
    img_t = img_t.unsqueeze(0)
    img_it = transform(img)
    return img_t
img = get_image('images/husky.jpg', transform)
type(img), img.shape

(torch.Tensor, torch.Size([1, 3, 224, 224]))

In [21]:
pred = model(img)
pred = torch.nn.functional.softmax(pred, dim=1)

In [22]:
c = torch.argmax(pred).item()
y_c = pred[0, c].item()
y_c

0.005702819209545851

In [23]:
target_layer = model.layer4[-1].conv3

weights = np.zeros((2048, 7, 7))

In [24]:
feature_maps = None

def hook_function(module, input, output):
    global feature_maps
    feature_maps = output.clone()

last_conv_layer = model.layer4[-1].conv3
hook = last_conv_layer.register_forward_hook(hook_function)

model.eval()

output = model(img)
output = torch.nn.functional.softmax(output, dim=1)
y_c = output[0, c].item()
argmax = torch.argmax(output).item()

In [25]:
for k in range(1):
    ablated_feature_maps = feature_maps.clone()
    ablated_feature_maps[:, k, :, :] = 0

    with torch.no_grad():
        ablated_output = model(img, feature_maps=ablated_feature_maps)
        y_c_k = ablated_output[0, c].item()

TypeError: ResNet.forward() got an unexpected keyword argument 'feature_maps'

In [None]:
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image

# Load a pre-trained ResNet model
model = models.resnet50(pretrained=True)
model.eval()

# Specify the target layer
target_layer = model.layer4[-1].conv3

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def get_class_activation(model, input_image, class_index):
    # Forward pass through the model to get the logits
    output = model(input_image)
    # Get the score for the target class
    y_c = output[0, class_index]
    return y_c




In [None]:
def calculate_weights(model, input_image, class_index, target_layer):
    # Perform a forward pass with all feature maps intact to get y^c
    y_c = get_class_activation(model, input_image, class_index)
    
    weights = []
    
    # Loop through each feature map in the target layer
    for i, feature_map in enumerate(target_layer):
        # Make a copy of the model
        model_copy = model
        # Set the feature map i in the target layer to zero
        feature_map.zero_()
        # Perform a forward pass to get y^c_k (after ablation)
        y_c_k = get_class_activation(model_copy, input_image, class_index)
        # Calculate the weight w_k^c
        weight_k = (y_c - y_c_k) / y_c
        weights.append(weight_k)
        
        # Restore the feature map to its original state
        feature_map.detach()
    
    return weights


In [None]:
import torch
from torchvision import models, transforms
from PIL import Image
import torch.nn.functional as F

# Load a pre-trained ResNet model
model = models.resnet50(pretrained=True)
model.eval()

# Specify the target layer
target_layer = model.layer4[-1].conv3

# Variable to store the feature maps
activation = None

# Hook function to capture the output of target_layer
def hook_fn(module, input, output):
    global activation
    activation = output

# Register the hook
hook = target_layer.register_forward_hook(hook_fn)

# Preprocess the input image
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to get class activation score
def get_class_activation(model, input_image, class_index):
    # Forward pass through the model to get the logits
    output = model(input_image)
    # Get the score for the target class
    y_c = output[0, class_index]
    return y_c


In [None]:
def calculate_weights(model, input_image, class_index):
    # Forward pass to get y^c and capture the activations
    y_c = get_class_activation(model, input_image, class_index)
    global activation
    original_activation = activation.clone().detach()  # Store the original activations
    
    weights = []
    # Loop through each feature map
    for i in range(activation.shape[1]):  # Number of feature maps
        print(i / activation.shape[1] * 100, '%')
        # Ablate the i-th feature map by setting it to zero
        activation[:, i, :, :] = 0
        # Forward pass with the ablated feature map
        y_c_k = get_class_activation(model, input_image, class_index)
        
        print(y_c_k)
        # Calculate the weight w_k^c
        weight_k = (y_c - y_c_k) / y_c
        weights.append(weight_k.item())
        
        # Restore the feature map to its original state
        activation[:, i, :, :] = original_activation[:, i, :, :]
    
    return weights


In [None]:
# Load an image
image = Image.open('images/husky.jpg')
input_image = preprocess(image).unsqueeze(0)  # Add batch dimension

# Choose a class index (e.g., 0 for the first class)
class_index = 0

# Calculate weights for each feature map in the target layer
weights = calculate_weights(model, input_image, class_index)
print(weights)

# Remove the hook when done
hook.remove()


0.0 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.048828125 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.09765625 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.146484375 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.1953125 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.244140625 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.29296875 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.341796875 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.390625 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.439453125 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.48828125 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.537109375 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.5859375 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.634765625 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.68359375 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.732421875 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.78125 %
tensor(-0.0681, grad_fn=<SelectBackward0>)
0.830078125 %
tensor(-0.0681, grad_fn=<SelectBackward0

KeyboardInterrupt: 