LRP for resnet model #4

Open

linhlpv opened this issue Feb 20, 2021 · 4 comments

linhlpv commented Feb 20, 2021

Thanks for your work!
I see that you implemented LRP for the VGG model, but VGG is a simple model with a single Sequential and no residual connections. Could you help me implement LRP for a more complex model, such as ResNet?
Thank you so much!


sdw95927 commented Apr 30, 2021


I implemented the ResNet conversion as follows:

import torch
import torchvision
from lrp.conv       import Conv2d 
from lrp.linear     import Linear
from lrp.sequential import Sequential, Bottleneck

conversion_table = { 
        'Linear': Linear,
        'Conv2d': Conv2d
    }

# # # # # Convert torchvision.models.resnetXX to an lrp model
def convert_resnet(module, modules=None):
    # First time
    if modules is None: 
        modules = []
        for m in module.children():
            convert_resnet(m, modules=modules)
            
            # if isinstance(m, torch.nn.Sequential):
            #     break
            
            # The torchvision ResNet flattens with torch.flatten() in its forward,
            # which is not represented as a module, so this loop doesn't pick it up.
            # This is a hack to make things work.
            if isinstance(m, torch.nn.AdaptiveAvgPool2d): 
                modules.append(torch.nn.Flatten())

        sequential = Sequential(*modules)
        return sequential

    # Recursion
    if isinstance(module, torch.nn.Sequential): 
        for m in module.children():
            convert_resnet(m, modules=modules)

    elif isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.Conv2d):
        class_name = module.__class__.__name__
        lrp_module = conversion_table[class_name].from_torch(module)
        modules.append(lrp_module)
    # maxpool is handled with gradient for the moment

    elif isinstance(module, torch.nn.ReLU): 
        # avoid inplace operations. They might ruin PatternNet pattern
        # computations
        modules.append(torch.nn.ReLU())
    elif isinstance(module, torchvision.models.resnet.Bottleneck):
        # For torchvision Bottleneck
        bottleneck = Bottleneck()
        bottleneck.conv1 = Conv2d.from_torch(module.conv1)
        bottleneck.conv2 = Conv2d.from_torch(module.conv2)
        bottleneck.conv3 = Conv2d.from_torch(module.conv3)
        bottleneck.bn1 = module.bn1
        bottleneck.bn2 = module.bn2
        bottleneck.bn3 = module.bn3
        bottleneck.relu = torch.nn.ReLU()
        if module.downsample is not None:
            bottleneck.downsample = module.downsample
            bottleneck.downsample[0] = Conv2d.from_torch(module.downsample[0])
        modules.append(bottleneck)
    else:
        modules.append(module)
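
For reference, a minimal sketch of how the converter above might be called; the torchvision resnet50 constructor, the pretrained flag, and the input size are only illustrative assumptions:

import torch
import torchvision

resnet = torchvision.models.resnet50(pretrained=True).eval()
lrp_resnet = convert_resnet(resnet)  # lrp Sequential with converted Conv2d/Linear/Bottleneck modules

# Sanity check: without explain=True the converted model should behave like the original
# (assuming Conv2d.from_torch / Linear.from_torch copy the weights exactly).
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(torch.allclose(resnet(x), lrp_resnet(x), atol=1e-5))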

and edited sequential.py as follows:

import torch

from . import Linear, Conv2d
from .maxpool import MaxPool2d
from .functional.utils import normalize

def grad_decorator_fn(module):
    """
        Currently not used but can be used for debugging purposes.
    """
    def fn(x): 
        return normalize(x)
    return fn

avoid_normalization_on = ['relu', 'maxp']
def do_normalization(rule, module):
    if "pattern" not in rule.lower(): return False
    return not str(module)[:4].lower() in avoid_normalization_on

def is_kernel_layer(module):
    return isinstance(module, Conv2d) or isinstance(module, Linear) or isinstance(module, Bottleneck)

def is_rule_specific_layer(module):
    return isinstance(module, MaxPool2d)

class Sequential(torch.nn.Sequential):
    def forward(self, input, explain=False, rule="epsilon", pattern=None):
        if not explain: return super(Sequential, self).forward(input)

        first = True

        # copy references for user to be able to reuse patterns
        if pattern is not None: pattern = list(pattern) 

        for module in self:
            if do_normalization(rule, module):
                input.register_hook(grad_decorator_fn(module))

            if is_kernel_layer(module): 
                P = None
                if pattern is not None: 
                    P = pattern.pop(0)
                input = module.forward(input, explain=True, rule=rule, pattern=P)

            elif is_rule_specific_layer(module):
                input = module.forward(input, explain=True, rule=rule)

            else: # Use gradient as default for remaining layer types
                input = module(input)
            first = False

        if do_normalization(rule, module): 
            input.register_hook(grad_decorator_fn(module))

        return input

class Bottleneck(torch.nn.Module):
    def __init__(self):
        super(Bottleneck, self).__init__()
        self.downsample = None
    
    def forward(self, x, explain=True, rule="epsilon", pattern=None):
        identity = x
        
        if pattern is not None:
            out = self.conv1(x, explain=explain, rule=rule, pattern=pattern[0])
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out, explain=explain, rule=rule, pattern=pattern[1])
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out, explain=explain, rule=rule, pattern=pattern[2])
            out = self.bn3(out)

            if self.downsample is not None:
                identity = self.downsample[0](x, explain, rule, pattern=pattern[3])
                identity = self.downsample[1](identity)
        else:
            out = self.conv1(x, explain=explain, rule=rule)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out, explain=explain, rule=rule)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out, explain=explain, rule=rule)
            out = self.bn3(out)

            if self.downsample is not None:
                identity = self.downsample[0](x, explain, rule)
                identity = self.downsample[1](identity)
            
        out += identity
        out = self.relu(out)

        return out
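
With the modified Sequential in place, a ResNet heatmap could then be obtained the same way as for the VGG example, i.e. by calling forward with explain=True and reading the relevance from the input gradient after a backward pass. A rough sketch (the rule name and the gradient-based read-out are assumptions based on how the library is used for VGG):

x = torch.randn(1, 3, 224, 224, requires_grad=True)

y_hat = lrp_resnet.forward(x, explain=True, rule="epsilon")

# back-propagate only the logit of the predicted class
target = y_hat[torch.arange(x.size(0)), y_hat.argmax(dim=1)].sum()
target.backward()

heatmap = x.grad.sum(dim=1)  # aggregate relevance over colour channels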

For PatternNet, the _fit_pattern function in patterns.py also needs to be modified:

def _fit_pattern(model, train_loader, max_iter, device, mask_fn = lambda y: torch.ones_like(y)):
    stats_x     = [] 
    stats_y     = []
    stats_xy    = []
    weights     = []
    cnt         = []
    cnt_all     = []

    first = True
    for b, (x, _) in enumerate(tqdm(train_loader)): 
        x = x.to(device)

        i = 0
        for m in model:
            # For Bottleneck
            if isinstance(m, Bottleneck):
                if first:
                    stats_x.append([])
                    stats_y.append([])
                    stats_xy.append([])
                    weights.append([])
                    
                identity = x  # keep the skip-connection input (needed below when there is no downsample)
                y = m.conv1(x)
                mask = mask_fn(y).float().to(device)
                if m.conv1.bias is not None:
                    y_wo_bias = y - m.conv1.bias.view(-1, 1, 1)
                else:
                    y_wo_bias = y.clone()
                cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(m.conv1, x, y_wo_bias, mask)
                if first:
                    stats_x[i].append(RunningMean(x_.shape, device))
                    stats_y[i].append(RunningMean(y_.shape, device)) # Use all y
                    stats_xy[i].append(RunningMean(xy_.shape, device))
                    weights[i].append((w, w_fn))
                stats_x[i][0].update(x_, cnt_)
                stats_y[i][0].update(y_.sum(0), cnt_all_)
                stats_xy[i][0].update(xy_, cnt_)
                
                x1 = y.clone()
                x1 = m.bn1(x1)
                x1 = m.relu(x1)
                
                y = m.conv2(x1)
                mask = mask_fn(y).float().to(device)
                if m.conv2.bias is not None:
                    y_wo_bias = y - m.conv2.bias.view(-1, 1, 1)
                else:
                    y_wo_bias = y.clone()
                cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(m.conv2, x1, y_wo_bias, mask)
                if first:
                    stats_x[i].append(RunningMean(x_.shape, device))
                    stats_y[i].append(RunningMean(y_.shape, device)) # Use all y
                    stats_xy[i].append(RunningMean(xy_.shape, device))
                    weights[i].append((w, w_fn))
                stats_x[i][1].update(x_, cnt_)
                stats_y[i][1].update(y_.sum(0), cnt_all_)
                stats_xy[i][1].update(xy_, cnt_)
                
                x2 = y.clone()
                x2 = m.bn2(x2)
                x2 = m.relu(x2)
                
                y = m.conv3(x2)
                mask = mask_fn(y).float().to(device)
                if m.conv3.bias is not None:
                    y_wo_bias = y - m.conv3.bias.view(-1, 1, 1)
                else:
                    y_wo_bias = y.clone()
                cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(m.conv3, x2, y_wo_bias, mask)
                if first:
                    stats_x[i].append(RunningMean(x_.shape, device))
                    stats_y[i].append(RunningMean(y_.shape, device)) # Use all y
                    stats_xy[i].append(RunningMean(xy_.shape, device))
                    weights[i].append((w, w_fn))
                stats_x[i][2].update(x_, cnt_)
                stats_y[i][2].update(y_.sum(0), cnt_all_)
                stats_xy[i][2].update(xy_, cnt_)
                
                y = m.bn3(y)
                
                if m.downsample is not None:
                    identity = m.downsample[0](x)
                    mask = mask_fn(identity).float().to(device)
                    if m.downsample[0].bias is not None:
                        y_wo_bias = identity - m.downsample[0].bias.view(-1, 1, 1)  # downsample conv's output is `identity`, not `y`
                    else:
                        y_wo_bias = identity.clone()
                    cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(m.downsample[0], x, y_wo_bias, mask)
                    if first:
                        stats_x[i].append(RunningMean(x_.shape, device))
                        stats_y[i].append(RunningMean(y_.shape, device)) # Use all y
                        stats_xy[i].append(RunningMean(xy_.shape, device))
                        weights[i].append((w, w_fn))
                    stats_x[i][3].update(x_, cnt_)
                    stats_y[i][3].update(y_.sum(0), cnt_all_)
                    stats_xy[i][3].update(xy_, cnt_)
                    identity = m.downsample[1](identity)
                
                y += identity
                x = m.relu(y)
                i += 1
                continue
                
            y = m(x) # Note, this includes bias.
            
            if not (isinstance(m, torch.nn.Linear) or isinstance(m, torch.nn.Conv2d)): 
                x = y.clone()
                continue
            
            mask = mask_fn(y).float().to(device)
            if m.bias is not None:
                if isinstance(m, torch.nn.Conv2d): 
                    y_wo_bias = y - m.bias.view(-1, 1, 1) 
                else:                              
                    y_wo_bias = y - m.bias.clone()
            else:
                y_wo_bias = y

            cnt_, cnt_all_, x_, y_, xy_, w, w_fn = _prod(m, x, y_wo_bias, mask)

            if first:
                stats_x.append(RunningMean(x_.shape, device))
                stats_y.append(RunningMean(y_.shape, device)) # Use all y
                stats_xy.append(RunningMean(xy_.shape, device))
                weights.append((w, w_fn))

            stats_x[i].update(x_, cnt_)
            stats_y[i].update(y_.sum(0), cnt_all_)
            stats_xy[i].update(xy_, cnt_)

            x = y.clone()
            i += 1

            
        first = False

        if max_iter is not None and b+1 == max_iter: break

    def pattern(x_mean, y_mean, xy_mean, W2d):
        x_  = x_mean.value
        y_  = y_mean.value
        xy_ = xy_mean.value

        W, w_fn = W2d
        ExEy = x_ * y_
        cov_xy = xy_ - ExEy # [in, out]

        w_cov_xy = torch.diag(W @ cov_xy) # [out,]

        A = safe_divide(cov_xy, w_cov_xy[None, :])
        A = w_fn(A) # Reshape to original kernel size

        return A
        
    # patterns = [pattern(*vars) for vars in zip(stats_x, stats_y, stats_xy, weights)]
    patterns = []
    for vars in zip(stats_x, stats_y, stats_xy, weights):
        if isinstance(vars[0], RunningMean):
            patterns.append(pattern(*vars))
        else:
            patterns_sub = []
            for vars_sub in zip(vars[0], vars[1], vars[2], vars[3]):
                patterns_sub.append(pattern(*vars_sub))
            patterns.append(patterns_sub)
    return patterns
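
The nested list returned above matches the modified Sequential.forward: each Bottleneck consumes one entry of the pattern list, which is itself a list of three (or four, with downsample) per-conv patterns. A usage sketch, where the rule name and the surrounding variables (train_loader, device, x) are assumptions based on how the library's VGG example is used:

patterns = _fit_pattern(lrp_resnet, train_loader, max_iter=100, device=device)
y_hat = lrp_resnet.forward(x, explain=True, rule="patternattribution", pattern=patterns)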

An LRP rule for the addition in the skip connection is not implemented yet (relevance currently just passes through the addition via the plain gradient), so this will probably need to be considered.
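
In case it helps: one simple option for that addition would be to split the relevance arriving at out = conv_branch + identity proportionally to the two summands. A minimal sketch of such a split (my own assumption, not something already in the repo):

import torch

def split_relevance(relevance, a, b, eps=1e-9):
    # Distribute the relevance arriving at `out = a + b` proportionally to each summand.
    total = a + b
    total = total + eps * (total == 0).float()  # avoid division by zero
    return relevance * a / total, relevance * b / total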

@miladsikaroudi

Thank you for your post @sdw95927.
I used these lines of code to generate ResNet heatmaps.
The problem is that the heatmaps for ResNet are not very meaningful.
I am attaching the heatmaps generated for pretrained VGG and ResNet below. Any ideas?

VGG heatmap: [attached image]

ResNet heatmap: [attached image]

@sdw95927

I can see the patterns from ResNet too, just not as clearly as with VGG. I think it's mainly due to ResNet's more complex structure, such as the residual connections, whereas VGG is simple and straightforward.

@zah-tane

@miladsikaroudi Can you please share what you did to make this work with ResNet?
