In [None]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt

from typing import Tuple, List, Callable

In [None]:
# Minimal two-layer network used to try out activation hooks.

class Dummy(nn.Module):
    """Tiny MLP: Linear(2, 2) -> ReLU -> Linear(2, 2) -> ReLU."""

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict prefixes and as the handles
        # other cells use to reach individual layers.
        self.lin1 = nn.Linear(2, 2)
        self.relu1 = nn.ReLU()
        self.lin2 = nn.Linear(2, 2)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through both Linear/ReLU stages and return the result."""
        hidden = self.relu1(self.lin1(x))
        return self.relu2(self.lin2(hidden))


dummy = Dummy()

In [None]:
# Show the model's current parameters (weights and biases of lin1 and lin2).
dummy.state_dict()

In [None]:
# Feed a fixed two-element float input through the unmodified model.
x = torch.tensor([1.0, 1.0])
dummy(x)

In [None]:
# successful hook approach to modify activations
def get_activation(neuron, activation):
    """
        Build a forward hook that pins one output neuron to a fixed value.

        neuron: index of the neuron along the LAST dimension of the layer output
        activation: value written to that neuron (number or 0-d tensor)

        The returned hook gives PyTorch a detached copy of the layer output
        with the chosen neuron overwritten; PyTorch then uses that copy in
        place of the real output.
    """
    def hook(module, inputs, output):
        # detach() alone shares storage with `output`; clone so the in-place
        # write below cannot leak into the layer's original tensor.
        modified_out = output.detach().clone()
        # The ellipsis keeps this correct for batched outputs: the index
        # always addresses the feature dimension, never the batch dimension.
        modified_out[..., neuron] = activation
        return modified_out
    return hook

# Clamp neuron 1 of lin2 to the value 2 for a single forward pass.
hook_handle = dummy.lin2.register_forward_hook(get_activation(1, 2))
try:
    output = dummy(x)
finally:
    # Unregister the hook even if the forward pass raises, so re-running
    # this cell never leaves stale hooks stacked on `dummy`.
    hook_handle.remove()
print(output)

In [None]:
def get_neurons(model) -> List[Tuple[nn.Module, int]]:
    """
        Enumerate every neuron of ``model.lin1`` and ``model.lin2``.

        Returns a flat list of ``(layer, neuron_index)`` tuples with one
        entry per output feature of each layer, lin1 entries first.
    """
    return [
        (layer, index)
        for layer in (model.lin1, model.lin2)
        # Read the width from the layer instead of hard-coding 2.
        for index in range(layer.out_features)
    ]
    
def get_labels(model) -> List[int]:
    """
        List the label indices of the model, i.e. one index per neuron
        of the last layer.

        The label count is hard-coded; ``model`` is accepted but not inspected.
    """
    n_outputs = 2  # two values for output
    return list(range(n_outputs))

In [None]:
def modify_activation(neuron, activation):
    """
        Build a forward hook that pins one output neuron to a fixed value.

        neuron: index of the neuron along the LAST dimension of the layer output
        activation: value written to that neuron (number or 0-d tensor)

        The returned hook gives PyTorch a detached copy of the layer output
        with the chosen neuron overwritten; PyTorch then uses that copy in
        place of the real output.
    """
    def hook(module, inputs, output):
        # detach() alone shares storage with `output`; clone so the in-place
        # write below cannot leak into the layer's original tensor.
        modified_out = output.detach().clone()
        # The ellipsis keeps this correct for batched outputs: the index
        # always addresses the feature dimension, never the batch dimension.
        modified_out[..., neuron] = activation
        return modified_out
    return hook
    
def NSF(model, label: int, layer: nn.Module, neuron: int, image: torch.Tensor) -> Callable:
    """
        Build a function mapping a forced activation of one neuron to the
        model's output for one label.

        model: network that is run on ``image``
        label: index into the model output that the returned function reports
        layer: module that receives the temporary forward hook
        neuron: neuron index handed to ``modify_activation``
        image: input passed to ``model`` on every call

        Returns ``func(x)``: runs ``model(image)`` with the neuron forced to
        ``x`` and returns ``output[label]``.
    """
    def func(x: float) -> torch.Tensor:
        hook_handle = layer.register_forward_hook(modify_activation(neuron, x))
        try:
            output = model(image)
        finally:
            # Always unregister: a hook left behind by a failed forward pass
            # would silently alter every later call through `layer`.
            hook_handle.remove()
        return output[label]
    return func

In [None]:
def identify_candidate(C, neurons, labels, base_imgs):
    """
        Pick the neuron whose stimulation favours one label most clearly.

        For every (layer, neuron) pair and every label, the neuron is forced
        to 100 evenly spaced values in [-1, 1] on each base image that does
        NOT carry that label. The label's lift is the smallest, over those
        images, of (highest stimulated output for the label minus the
        unstimulated output for the label). A neuron's score is the gap
        between its two largest label lifts.

        C: the model in question
        neurons: (layer, neuron index) tuples to examine
        labels: label indices into the model output
        base_imgs: list of tuples containing (image, label)

        Labels with no base image of a different label are skipped, and a
        neuron with fewer than two usable labels gets no score.

        Returns (layer, neuron index) of the highest-scoring neuron, or
        (None, 0) when no neuron reaches a strictly positive score.
    """
    labels = list(labels)
    base_imgs = list(base_imgs)
    # Loop-invariant: the same stimulation sweep is used for every neuron.
    stimuli = torch.linspace(-1, 1, 100)

    max_n = 0
    max_l = None
    max_v = 0

    # Only forward values are compared, so skip autograd bookkeeping.
    with torch.no_grad():
        # Unstimulated outputs, computed once per image rather than once per
        # (neuron, label, image) combination.
        baselines = [C(image) for image, _ in base_imgs]

        for layer, neuron in neurons:
            label_lifts = []
            for label in labels:
                lifts = []
                for (image, img_label), baseline in zip(base_imgs, baselines):
                    if img_label == label:
                        continue
                    stimulate = NSF(C, label, layer, neuron, image)
                    peak = max(stimulate(value) for value in stimuli)
                    lifts.append(peak - baseline[label])
                # Without a counter-example image the lift is undefined;
                # leaving an infinite placeholder would make this neuron
                # win spuriously.
                if lifts:
                    label_lifts.append(min(lifts))

            # The score needs a best and a runner-up label.
            if len(label_lifts) < 2:
                continue
            label_lifts.sort(reverse=True)
            n_v = label_lifts[0] - label_lifts[1]
            if n_v > max_v:
                max_v = n_v
                max_n = neuron
                max_l = layer
    return max_l, max_n

In [None]:
# Fix the seeds so Restart & Run All yields the same model weights,
# base images and therefore the same candidate neuron.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

C = Dummy()
neurons = get_neurons(C)
labels = get_labels(C)
# 20 random two-element inputs, each paired with a randomly drawn label.
base_imgs = [
    (torch.rand(2), np.random.choice(labels)) for _ in range(20)
]

identify_candidate(C, neurons, labels, base_imgs)