Numerical instability in ResNet50 heatmaps #148
Hey Rodrigo, I have seen this behavior before with the Gamma rule used in combination with the ResNet canonization (replacing the residual connections and weighting them by contribution) in different implementations unrelated to Zennit. I also know for a fact that the same instability happens if you choose the ZBox bounds too low, i.e. much lower than the actual bounds of the data.

Do you have a specific setup for VGG16? At which epsilon did this happen? For VGG16 I have not seen it before, except for the aforementioned bounds. Could you also list the specific composites for which you have seen this behaviour?
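For reference, the ZBox rule from the LRP literature uses the bounds l and h directly in its denominator (equation quoted from the LRP overview by Montavon et al., not from this thread):

$$R_i = \sum_j \frac{x_i w_{ij} - l_i w_{ij}^+ - h_i w_{ij}^-}{\sum_{i'} \left( x_{i'} w_{i'j} - l_{i'} w_{i'j}^+ - h_{i'} w_{i'j}^- \right)} R_j$$

Each summand of the denominator can be rewritten as $(x_i - l_i)\, w_{ij}^+ + (x_i - h_i)\, w_{ij}^-$, which is non-negative whenever $l_i \le x_i \le h_i$. If the chosen bounds do not actually bound the data, individual terms can turn negative and the denominator can cross zero, which explains the instability.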
Thank you for the prompt reply, Christopher!

Issue reproduction
The following snippet reproduces the instability with the setup visible in its name map: ZBox for the first convolution, Gamma for the early layers, and Epsilon variants for the remaining ones.

Bug reproduction
import cv2
import numpy
import torch
from torch.nn import AvgPool2d, Conv2d, Linear
from torchvision.models import resnet50
from zennit.composites import EpsilonGammaBox, NameMapComposite
from zennit.core import BasicHook, collect_leaves, stabilize
from zennit.rules import Epsilon, Gamma, ZBox
from zennit.torchvision import ResNetCanonizer
from matplotlib import pyplot as plt
from zennit.image import imgify
# the LRP-Epsilon from the tutorial
class GMontavonEpsilon(BasicHook):
    def __init__(self, stabilize_epsilon=1e-6, epsilon=0.25):
        super().__init__(
            input_modifiers=[lambda input: input],
            param_modifiers=[lambda param, _: param],
            output_modifiers=[lambda output: output],
            # the denominator is z + epsilon * rms(z), i.e. epsilon is scaled
            # by the root mean square of the outputs before stabilization
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilize(
                outputs[0] + epsilon * (outputs[0] ** 2).mean() ** .5, stabilize_epsilon)),
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0])
        )
# use the gpu if available, else use the cpu
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class BatchNormalize:
    def __init__(self, mean, std, device=None):
        self.mean = torch.tensor(mean, device=device)[None, :, None, None]
        self.std = torch.tensor(std, device=device)[None, :, None, None]

    def __call__(self, tensor):
        return (tensor - self.mean) / self.std
# mean and std of ILSVRC2012 as computed for the torchvision models
norm_fn = BatchNormalize((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225), device=device)
batch_size = 1
# the maximal input shape, needed for the ZBox rule
shape = (batch_size, 3, 224, 224)
# the highest and lowest pixel values for the ZBox rule
low = norm_fn(torch.zeros(*shape, device=device))
high = norm_fn(torch.ones(*shape, device=device))
# move the model to the same device as the normalization constants
model = resnet50(pretrained=True).to(device)
model.eval()
# only these get rules, linear layers will be attributed by the gradient alone
# target_types = (Conv2d, AvgPool2d)
target_types = (Conv2d, AvgPool2d, Linear)
# lookup module -> name
child_name = {module: name for name, module in model.named_modules()}
# the layers in sequential order without any containers etc.
layers = list(enumerate(collect_leaves(model)))
# list of tuples [([names..], rule)] as used by NameMapComposite
name_map = [
    ([child_name[module] for n, module in layers if n == 0 and isinstance(module, target_types)], ZBox(low=low, high=high)),
    ([child_name[module] for n, module in layers if 1 <= n <= 16 and isinstance(module, target_types)], Gamma(0.25)),
    ([child_name[module] for n, module in layers if 17 <= n <= 30 and isinstance(module, target_types)], GMontavonEpsilon(stabilize_epsilon=0, epsilon=0.25)),
    ([child_name[module] for n, module in layers if 31 <= n and isinstance(module, target_types)], Epsilon(0)),
]
# create the composite from the name map
composite = NameMapComposite(name_map, canonizers=[ResNetCanonizer()])
R = None
with composite.context(model) as modified_model:
    # compute attribution
    # cv2.imread returns a numpy array in BGR color space, not RGB
    img = cv2.imread('castle.jpg')
    # convert from BGR to RGB color space
    img = img[..., ::-1]
    # img.shape is (224, 224, 3), where 3 corresponds to the RGB channels
    # divide by 255 (max. RGB value) to normalize pixel values to [0, 1]
    img = img / 255.0
    data = norm_fn(
        torch.FloatTensor(
            # the `* 1` forces a copy, since the BGR->RGB flip above yields a
            # negatively-strided array which torch cannot consume directly
            img[numpy.newaxis].transpose([0, 3, 1, 2]) * 1
        ).to(device)
    )
    data.requires_grad = True

    output = modified_model(data)
    output[0].max().backward()

    # print the absolute sum of the attribution
    print(data.grad.abs().sum().item())
    R = data.grad
heatmap = imgify(
    R.detach().cpu().sum(1),
    symmetric=True,
    grid=True,
    cmap='seismic',
)
plt.imshow(heatmap)

Input(s):

Outputs:

Root cause of numerical instability
From my observations, I have narrowed the issue down to the denominators in the equations below.

Equations
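These are presumably the standard LRP-epsilon and LRP-gamma propagation rules, reproduced here from the LRP overview by Montavon et al. for reference:

$$\text{LRP-}\epsilon:\quad R_j = \sum_k \frac{a_j w_{jk}}{\epsilon + \sum_{0,j} a_j w_{jk}}\, R_k$$

$$\text{LRP-}\gamma:\quad R_j = \sum_k \frac{a_j \left( w_{jk} + \gamma w_{jk}^+ \right)}{\sum_{0,j} a_j \left( w_{jk} + \gamma w_{jk}^+ \right)}\, R_k$$

In both cases, the quotient explodes whenever the sum in the denominator approaches zero, which is exactly what the stabilize heuristics below are meant to prevent.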
Whether the issue arises depends on the implementation of the stabilize function. I have tested the following heuristics:

Heuristic implementations
epsilon: float = 0.1
dividend: torch.Tensor = torch.Tensor([-epsilon, 5, -5, -10])
# tensor([ -0.1000, 5.0000, -5.0000, -10.0000])
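The compared implementations could look like the following (my reconstruction of such heuristics; the exact variants tested may differ):

import torch

epsilon: float = 0.1
dividend: torch.Tensor = torch.Tensor([-epsilon, 5, -5, -10])

# naive: simply add epsilon; this fails when an entry equals -epsilon,
# since -0.1 + 0.1 == 0 and the subsequent division explodes
naive = dividend + epsilon

# sign-aware: push each value away from zero in its own direction,
# treating 0 as positive, so the result is never exactly zero
sign = (dividend == 0).to(dividend) + dividend.sign()
sign_aware = dividend + sign * epsilon

# clipping: enforce a minimum absolute magnitude while keeping the sign
clipped = sign * dividend.abs().clip(min=epsilon)

print(naive)       # tensor([ 0.0000,  5.1000, -4.9000, -9.9000]) <- exact zero
print(sign_aware)  # tensor([ -0.2000,  5.1000, -5.1000, -10.1000])
print(clipped)     # tensor([ -0.1000,  5.0000, -5.0000, -10.0000])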
Additional insights
Disclaimer: the following images have been generated with my own implementation of LRP. Nevertheless, the error can also be reproduced using Zennit by modifying the heuristics in the stabilize function.

It is worth noting that heatmaps which exhibit only a few visual relevance concentrations, as in the ResNet50 heatmaps in the first comment, do have non-zero relevance scores elsewhere; their visibility strongly depends on the plotting settings, see the example below.

Comparison
Heatmap of relevance scores with numerical instability:

Same relevance scores with numerical instability, adjusted plotting settings:

Heatmap with input in the same plot:
Hey Rodrigo, thank you for your insights! The near-zero denominators you identified are indeed the root cause; this is also why it is so dependent on the heuristic used in the stabilize function. In my analysis I have also found that the ResNet residual connections can lead to vanishing contributions, depending on the rule with which they are attributed. My conclusion is that there is no global solution that could be implemented in Zennit.

As a side note, I have also noticed that for VGG the necessary epsilon grows when the bias is not included, which I found while preparing the 'no-bias' feature.

Here is also the snippet I used for my analysis. I am thinking of adding something similar as a tutorial, to also help with debugging conservativity.

Python code
#!/usr/bin/env python3
import torch
from PIL import Image
from torchvision.models import resnet18, vgg11_bn
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor
from zennit.attribution import Gradient
from zennit.core import collect_leaves, RemovableHandleList
from zennit.composites import EpsilonGammaBox
from zennit.torchvision import ResNetCanonizer, VGGCanonizer
from zennit.image import imsave
def trace_hook(target):
    # forward-hook factory: remember each module's output and retain its
    # gradient, so the relevance at every layer can be inspected afterwards
    def trace(module, input, output):
        output.retain_grad()
        target.append((module, output))
    return trace
def main():
    transform_img = Compose([
        Resize(256),
        CenterCrop(224),
    ])
    transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    transform = Compose([
        transform_img,
        ToTensor(),
        transform_norm,
    ])

    image = Image.open('dornbusch-lighthouse.jpg')
    data = transform(image)[None]

    modules = []
    model = resnet18(pretrained=True).eval()
    canonizer = ResNetCanonizer()
    # model = vgg11_bn(pretrained=True).eval()
    # canonizer = VGGCanonizer()

    # ZBox bounds, obtained by normalizing the all-zeros and all-ones images
    low, high = transform_norm(torch.tensor([[[[[0.]]] * 3], [[[[1.]]] * 3]]))
    # low, high = (-3., 3.)

    composite = EpsilonGammaBox(
        low=low,
        high=high,
        zero_params=['bias'],
        gamma=4.,
        epsilon=1e-6,
        canonizers=[canonizer]
    )

    target = torch.eye(1000)[[437]]
    with Gradient(model=model, composite=composite) as attributor:
        # register tracing hooks on all leaf modules to record the
        # intermediate relevances
        handles = RemovableHandleList(
            module.register_forward_hook(trace_hook(modules))
            for module in collect_leaves(model)
        )
        output, attribution = attributor(data, target)
        handles.remove()

    print(f'Prediction: {output.argmax(1)[0].item()}')
    # per layer: range (max - min) and sum of the relevance; the sum should
    # stay constant if the propagation is conservative
    lines = [
        f'{n + 1:03d} {module.__class__.__name__:17s}: '
        f'{(output.grad.max() - output.grad.min()).item():.2e} {output.grad.sum().item():.2e}'
        for n, (module, output) in enumerate(modules)
    ]
    lines.insert(
        0,
        f'000 {"input":17s}: {(attribution.max() - attribution.min()).item():.2e} {attribution.sum().item():.2e}'
    )
    print('\n'.join(lines))

    imsave('heatmap.png', attribution.sum(1)[0], symmetric=True, cmap='coldnhot')


if __name__ == '__main__':
    main()

Do you agree with my observation? I would then proceed to provide stronger stabilizer customization for the rules, for example with a dedicated Stabilizer class:

class Stabilizer:
    def __init__(self, epsilon=1e-6, clip=False, mean_scale=False, dim=None):
        self.epsilon = epsilon
        self.clip = clip
        self.mean_scale = mean_scale
        self.dim = dim

    def __call__(self, input):
        # sign of the input, with 0 treated as positive
        sign = ((input == 0.).to(input) + input.sign())
        epsilon = self.epsilon
        if self.mean_scale:
            dim = self.dim
            if self.dim is None:
                dim = tuple(range(1, input.ndim))
            # scale epsilon by the root mean square of the input;
            # keepdim=True so it broadcasts against the input below
            epsilon = epsilon * ((input ** 2).mean(dim=dim, keepdim=True) ** .5)
        if self.clip:
            return sign * input.abs().clip(min=epsilon)
        return input + sign * epsilon

And then I will maybe just check inside the rules whether a plain float epsilon or such a callable was passed.
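Usage would then look roughly like this (a quick sketch; the values are only illustrative):

stabilizer = Stabilizer(epsilon=0.25, mean_scale=True)

dividend = torch.tensor([[-0.25, 5., -5., -10.]])
# every entry is pushed away from zero by 0.25 times the root mean square
# of its row, so no denominator can end up arbitrarily close to zero
print(stabilizer(dividend))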
Hey,

thank you Rodrigo for putting so much effort into solving the instability, and thanks to Christopher for taking the time to answer all the issues in depth.

Best,
As community service, here is the output for resnet18 with the standard values, but with the bias deactivated. The first column after the layer name is the distance from the maximum to the minimum value, and the second is the sum (which should stay 1 if the propagation is conservative).

Default parameters (broken)
gamma=4 and epsilon=1e-6 (better)
Hi Chris,

thanks, that's interesting. It confirms the correlation between heatmap quality and hidden concept attribution.
I might add that this could be a problem of the gamma rule, which (as it is) does not fit the skip connections of the ResNet model. If the denominator is negative, it becomes more positive with larger gamma. The denominator can then become smaller in magnitude, and relevances can explode. This usually does not happen, because a negative denominator means a negative pre-activation, which is set to zero by ReLU non-linearities. So the problem is that the gamma rule assumes that only neurons with positive outputs have relevance/contributed in the forward pass.

One way to fix it would be to make the gamma rule more generic by checking the sign of the output and then favoring either positive contributions (if the output is positive) or negative contributions (if the output is negative); see the sketch below. This way, the magnitude of the denominator always becomes larger. A fix like that would probably also stabilize relevance propagation when other non-linearities are used, where negative outputs receive relevance/contribute in the forward pass.

As I think of it, this could also be interesting for other rules, e.g. alpha-beta. There, for alpha1_beta0, all positive contributions receive relevance even if the output is negative. Would it not be more sensible if all negative contributions received the relevance in that case?
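A minimal sketch of such a sign-aware gamma denominator for a dense layer (my illustration of the idea, not a finished rule; the numerator contributions would be modified analogously):

import torch

def symmetric_gamma_denominator(input, weight, gamma=0.25):
    # standard pre-activations z_k = sum_j a_j w_jk
    z = input @ weight.t()
    # positive part of the contributions a_j w_jk, summed per output neuron
    z_pos = input.clamp(min=0) @ weight.clamp(min=0).t() \
        + input.clamp(max=0) @ weight.clamp(max=0).t()
    # negative part of the contributions a_j w_jk, summed per output neuron
    z_neg = input.clamp(min=0) @ weight.clamp(max=0).t() \
        + input.clamp(max=0) @ weight.clamp(min=0).t()
    # favor positive contributions for positive outputs and negative
    # contributions for negative outputs, so the magnitude of the
    # denominator only grows with gamma and cannot cross zero
    return z + gamma * torch.where(z > 0., z_pos, z_neg)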
I have run a small experiment, and making the gamma rule symmetric (for positive and negative outputs, indicated by *) seems to result in much more reasonable results. The total relevance at the input stays bounded and the heatmaps seem sensible. Here is my implementation of the updated rule:
Thank you all for taking the time to look into this issue so thoroughly. I have yet to go through the snippets and some of the comments in detail, but I am really inspired by the fruitful conversation we have going on right now. A few thoughts came to my mind while reading the observations made so far:
Food for thought
Specifically for ResNet50, there is also a pretty nasty artifact in modified-backprop explanations caused by the 1x1-convolution, stride-2 downsampling shortcuts. Short of excluding those shortcuts, there is not really a way to get around that. I suspect that this may further contribute to the concentrated attributions you are reporting (in addition to the issues with the gamma rule). This artifact should also appear for other composites, but not for other models. Below are some example heatmaps showing that artifact for PascalVOC.
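To experiment with treating those shortcuts separately, one could collect them by name and assign them their own rule in a name-map composite. A sketch of the identification step (the module names are those of torchvision's resnet50; the choice of a replacement rule would still be up for discussion):

import torch
from torchvision.models import resnet50

model = resnet50(pretrained=True).eval()

# find the 1x1, stride-2 downsampling shortcut convolutions
shortcut_names = [
    name for name, module in model.named_modules()
    if isinstance(module, torch.nn.Conv2d)
    and module.kernel_size == (1, 1) and module.stride == (2, 2)
]
print(shortcut_names)
# ['layer2.0.downsample.0', 'layer3.0.downsample.0', 'layer4.0.downsample.0']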
Hey everyone, coming back to @maxdreyer's suggestion: it is true that the Gamma rule is only defined for positive inputs.
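In formulas, the sign-aware extension sketched above would choose the denominator $d_k$ according to the sign of the pre-activation $z_k = \sum_{0,j} a_j w_{jk}$ (my rendering of the suggestion, not an established rule):

$$d_k = \begin{cases} z_k + \gamma z_k^+ & \text{if } z_k > 0 \\ z_k + \gamma z_k^- & \text{if } z_k \le 0 \end{cases}, \qquad z_k^{\pm} = \sum_{0,j} \left( a_j^+ w_{jk}^{\pm} + a_j^- w_{jk}^{\mp} \right)$$

Since $z_k^+ \ge 0$ and $z_k^- \le 0$, the magnitude of $d_k$ can only grow with $\gamma$, independent of the signs of the inputs.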
For reference only: I just made public my LRP implementation.
Calculating the relevance on ResNet50 seems to be prone to numerical instability, producing heatmaps where all attribution is concentrated in a few spots, because the values in those spots have become much larger than the rest. See the heatmap in the bug reproduction section. I can confirm that this unexpected behavior also occurs with different composites and different images.
I have also seen this issue on VGG16 in my own LRP implementation, depending on the heuristic used in the stabilize function.
Bug reproduction
Code based on snippet provided in #76 (comment).
Minimal reproducible example:
Input(s):
Input image:
Outputs:
Text:
Heatmap:
Additional information
The bug is not limited to the castle.jpg image; it can also be reproduced using the following image. See the corresponding heatmap below.

Input image:
Heatmap: