<a href="https://colab.research.google.com/github/harvard-visionlab/sroh/blob/main/2022/pytorch_feature_editor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn 

'''
    FeatureEditor
    
    A wrapper class that will handle 'hooking' the to-be-lesioned layer(s),
    editing them on the forward pass through the model, and returning
    the lesioned outputs.
    
    For each layer you are editing, include a mask. That mask will
    be multiplied by the layer's output, and the result will be
    passed onto the next layer.
    
    Like FeatureExtractor, FeatureEditor should be used as a context
    manager to clean up (remove) hooks when you are done with them:
    
    with FeatureEditor(model,layers,masks,return_features=False) as editor:
        output = editor(imgs)
        
'''
class FeatureEditor(nn.Module):
    def __init__(self, model, layers, masks, return_features=False):
        super().__init__()
        self.model = model
        self.layers = [layers] if isinstance(layers, str) else layers
        self.masks = [masks] if not isinstance(layers, list) else masks
        self.return_features = return_features
        self._features = {layer: torch.empty(0) for layer in layers}
        self.hooks = {}

    def hook_layers(self):        
        self.remove_hooks()
        for layer_id,mask in zip(self.layers, self.masks):
            layer = dict([*self.model.named_modules()])[layer_id]
            self.hooks[layer_id] = layer.register_forward_hook(self.edit_outputs_hook(layer_id,mask))
    
    def remove_hooks(self):
        for layer_id in self.layers:
            self._features[layer_id] = torch.empty(0)
            if layer_id in self.hooks:
                self.hooks[layer_id].remove()
                del self.hooks[layer_id]
    
    def __enter__(self, *args): 
        self.hook_layers()
        return self
    
    def __exit__(self, *args): 
        self.remove_hooks()
            
    def edit_outputs_hook(self, layer_id, mask):
        def fn(_, __, output):
            modified_output = output * mask
            self._features[layer_id] = modified_output
            return modified_output
        return fn

    def forward(self, x):
        out = self.model(x)
        if self.return_features:
            return self._features
        else:
            return out
        
def generate_mask(shape, units):
    mask = torch.ones(shape).flatten()
    mask[units] = 0
    return mask.reshape(shape).unsqueeze(0)  

In [3]:
'''
    FeatureExtractor class that allows you to retain outputs of any layer.
    
    This class uses PyTorch's "forward hooks", which let you insert a function
    that takes the input and output of a module as arguements.
    
    In this hook function you can insert tasks like storing the intermediate values,
    or as we'll do in the FeatureEditor class, actually modify the outputs.
    
    Adding these hooks can cause headaches if you don't "remove" them 
    after you are done with them. For this reason, the FeatureExtractor is 
    setup to be used as a context, which sets up the hooks when
    you enter the context, and removes them when you leave:
    
    with FeatureExtractor(model, layer_name) as extractor:
        features = extractor(imgs)
    
    If there's an error in that context (or you cancel the operation),
    the __exit__ function of the feature extractor is executed,
    which we've setup to remove the hooks. This will save you 
    headaches during debugging/development.
    
'''

import torch
import torch.nn as nn
from torchvision import models
from IPython.core.debugger import set_trace

class FeatureExtractor(nn.Module):
    def __init__(self, model, layers, detach=True, clone=True, device='cpu'):
        super().__init__()
        self.model = model
        self.layers = [layers] if isinstance(layers, str) else layers
        self.detach = detach
        self.clone = clone
        self.device = device
        self._features = {layer: torch.empty(0) for layer in layers}        
        self.hooks = {}
        
    def hook_layers(self):        
        self.remove_hooks()
        for layer_id in self.layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            self.hooks[layer_id] = layer.register_forward_hook(self.save_outputs_hook(layer_id))
    
    def remove_hooks(self):
        for layer_id in self.layers:
            self._features[layer_id] = torch.empty(0)
            if layer_id in self.hooks:
                self.hooks[layer_id].remove()
                del self.hooks[layer_id]
    
    def __enter__(self, *args): 
        self.hook_layers()
        return self
    
    def __exit__(self, *args): 
        self.remove_hooks()
        
    def save_outputs_hook(self, layer_id):
        def fn(_, __, output):
            if self.detach: output = output.detach()
            if self.clone: output = output.clone()
            if self.device: output = output.to(self.device)
            self._features[layer_id] = output
        return fn

    def forward(self, x):
        _ = self.model(x)
        return self._features
    
def get_layers(model, parent_name='', layer_info=[]):
    for module_name, module in model.named_children():
        layer_name = parent_name + '.' + module_name
        if len(list(module.named_children())):
            layer_info = get_layers(module, layer_name, layer_info=layer_info)
        else:
            layer_info.append(layer_name.strip('.'))
    
    return layer_info

def get_layer(m, layers):
    layer = layers.pop(0)
    m = getattr(m, layer)
    if len(layers) > 0:
        return get_layer(m, layers)
    return m

def get_layer_type(model, layer_name):
    m = get_layer(model, layer_name.split("."))
    return m.__class__.__name__
            
def convert_relu_layers(parent):
    for child_name, child in parent.named_children():
        if isinstance(child, nn.ReLU) and child.inplace==True:
            setattr(parent, child_name, nn.ReLU(inplace=False))
        elif len(list(child.children())) > 0:
            convert_relu_layers(child)

In [17]:
from torchvision import models 

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = models.alexnet(pretrained=True)
model.to(device)
layer_names = get_layers(model)
layer_names

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


['features.0',
 'features.1',
 'features.2',
 'features.3',
 'features.4',
 'features.5',
 'features.6',
 'features.7',
 'features.8',
 'features.9',
 'features.10',
 'features.11',
 'features.12',
 'avgpool',
 'classifier.0',
 'classifier.1',
 'classifier.2',
 'classifier.3',
 'classifier.4',
 'classifier.5',
 'classifier.6',
 'features.0',
 'features.1',
 'features.2',
 'features.3',
 'features.4',
 'features.5',
 'features.6',
 'features.7',
 'features.8',
 'features.9',
 'features.10',
 'features.11',
 'features.12',
 'avgpool',
 'classifier.0',
 'classifier.1',
 'classifier.2',
 'classifier.3',
 'classifier.4',
 'classifier.5',
 'classifier.6']

In [94]:
# let's get the shape of activation map for features.0
model.eval()
layer_name = 'features.0'
with FeatureExtractor(model, [layer_name]) as extractor:
  dummy_images = torch.rand(10,3,224,224)
  output = extractor(dummy_images)
  with torch.no_grad():
    features = output[layer_name]
output_map_shape = features.shape[1:]
output_map_shape

torch.Size([64, 55, 55])

In [95]:
# generate a mask with this same shape (with extra first dim for batch) 1x64x55x55 
mask_units = [0,1,55]
demo_mask = generate_mask(output_map_shape, mask_units)
demo_mask.to(device)
demo_mask.shape

torch.Size([1, 64, 55, 55])

In [96]:
# we're knocking out units 0,1,55 which are in the first channel
# and should be the first two elements in the first row 
# and the first element in the second row
demo_mask[0,0]

tensor([[0., 0., 1.,  ..., 1., 1., 1.],
        [0., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])

In [97]:
# Here let's edit the output, and return the feature maps so we can inspect them
model.eval()
with FeatureEditor(model, [layer_name], [demo_mask], return_features=True) as editor:
  # editor is now a copy of the model that will have certain activations "zeroed out"
  # it does this by multilying the layer output by the mask 
  # if you have a list of layer_names, e.g., [layer1, layer2]
  # and a list of masks, e.g., [mask1, mask2],
  # then it's assumed you want layer1 * mask1, layer2 * mask2
  dummy_images = torch.randn(10,3,224,224)
  with torch.no_grad():
    output = editor(dummy_images)
  masked_features = output[layer_name]
masked_features.shape  

torch.Size([10, 64, 55, 55])

In [109]:
# the masked units should all have value zero
print(masked_features[0].flatten()[mask_units])
print(masked_features[0].flatten()[mask_units]==0)

tensor([0., 0., -0.])
tensor([True, True, True])


In [110]:
# Now if you want to run the model, and just get the edited output 
model.eval()
with FeatureEditor(model, [layer_name], [demo_mask], return_features=False) as editor:
  dummy_images = torch.randn(10,3,224,224)
  with torch.no_grad():
    output = editor(dummy_images)
output.shape

torch.Size([10, 1000])

In [None]:
# e.g., to run evaluation
model.eval()
with FeatureEditor(model, [layer_name], [demo_mask], return_features=False) as editor:
  # do anything you want with the edited model...
  # results = validate(editor, val_loader)