In [1]:
from transformers import AutoModel, AutoTokenizer
import torch

In [2]:
model_name = "zhihan1996/DNABERT-2-117M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [317]:
from typing import Tuple
class LRPCheckpoint(torch.autograd.Function):
    """Identity autograd fcn for marking where to capture relevance."""
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        return input

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        return grad_output, None, None

create_checkpoint = LRPCheckpoint.apply

In [318]:
def checkpoint_hook(module, input, output):
    return create_checkpoint(output)

In [319]:
for layer_module in model.encoder.layer:
    layer_module.attention.self.register_forward_hook(checkpoint_hook)

In [3]:
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"]
hidden_states : torch.Tensor = model(inputs, requires_grad=True)[0] # [1, sequence_length, 768]

In [239]:
hidden_states.shape

torch.Size([1, 17, 768])

In [279]:
with torch.no_grad():
    m = hidden_states.max(-1)
    # print(m)
    b = torch.zeros_like(a)
    for i, inds in enumerate(m.indices):
        b[i,list(range(hidden_states.shape[1])),inds] = torch.ones_like(m.values[0])

In [4]:
fcns = [ [hidden_states.grad_fn] ]
visited = set()
names = set()

count = 0
while fcns:
    count += 1
    # print(len(fcns))
    new_fcns = []
    for fcn_list in fcns:
        for fcn in fcn_list:
            if fcn is None or fcn in visited:
                continue
            if type(fcn).__name__ not in names:
                names.add(type(fcn).__name__)
            visited.add(fcn)
            new_fcns.append([ fcn_tup[0] for fcn_tup in fcn.next_functions ])
        # new_fcns += [ [ fcn_tup[0] for fcn_tup in curr.next_functions ] for curr in fcn_list if (curr is not None) ]
    fcns = new_fcns

print(count)

73


In [322]:
names

{'AccumulateGrad',
 'AddBackward0',
 'AddmmBackward0',
 'BmmBackward0',
 'DivBackward0',
 'EmbeddingBackward0',
 'ExpandBackward0',
 'GeluBackward0',
 'IndexFirstAxisBackward',
 'IndexPutFirstAxisBackward',
 'LRPCheckpointBackward',
 'MmBackward0',
 'MulBackward0',
 'NativeLayerNormBackward0',
 'PermuteBackward0',
 'ReshapeAliasBackward0',
 'SelectBackward0',
 'SliceBackward0',
 'SoftmaxBackward0',
 'TBackward0',
 'UnsafeViewBackward0',
 'ViewBackward0'}

In [6]:
visited = list(visited)
visited = sorted(visited, key=lambda fcn: fcn._sequence_nr())

In [7]:
list(enumerate(visited))

[(0, <EmbeddingBackward0 at 0x29145a6a0>),
 (1, <EmbeddingBackward0 at 0x29145a6d0>),
 (2, <AddBackward0 at 0x29145a3d0>),
 (3, <NativeLayerNormBackward0 at 0x29145a130>),
 (4, <ViewBackward0 at 0x291463e80>),
 (5, <torch.autograd.function.IndexFirstAxisBackward at 0x291185e40>),
 (6, <TBackward0 at 0x29146ebb0>),
 (7, <AddmmBackward0 at 0x29146eaf0>),
 (8, <torch.autograd.function.IndexPutFirstAxisBackward at 0x2911fdb40>),
 (9, <ViewBackward0 at 0x29146e9a0>),
 (10, <ViewBackward0 at 0x29146e880>),
 (11, <SliceBackward0 at 0x29146ec70>),
 (12, <SliceBackward0 at 0x29146ebe0>),
 (13, <SelectBackward0 at 0x29146eb20>),
 (14, <SliceBackward0 at 0x29146ea90>),
 (15, <SliceBackward0 at 0x29146ea30>),
 (16, <PermuteBackward0 at 0x29146e940>),
 (17, <SliceBackward0 at 0x29146eca0>),
 (18, <SliceBackward0 at 0x29146ec10>),
 (19, <SelectBackward0 at 0x29146eb50>),
 (20, <SliceBackward0 at 0x29146eac0>),
 (21, <SliceBackward0 at 0x29146ea60>),
 (22, <PermuteBackward0 at 0x29146e970>),
 (23, <S

In [8]:
dir(visited[3])

['__call__',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_input_metadata',
 '_raw_saved_bias',
 '_raw_saved_input',
 '_raw_saved_result1',
 '_raw_saved_result2',
 '_raw_saved_weight',
 '_register_hook_dict',
 '_saved_bias',
 '_saved_input',
 '_saved_normalized_shape',
 '_saved_result1',
 '_saved_result2',
 '_saved_weight',
 '_sequence_nr',
 '_set_sequence_nr',
 'metadata',
 'name',
 'next_functions',
 'register_hook',
 'register_prehook',
 'requires_grad']

In [6]:
# For denominator stability in relevance distribution
epsilon = 10e-6

def renormalize_epsilon(rz, rx, ry):
    # Renormalizes output relevances after dividing by a denominator with epsilon added to preserve conservation
    scale = rz / (rx + ry)
    return rx * scale, ry * scale

In [110]:
class AddBackwardPromise:
    def __init__(self, promise, idx):
        # promise: shared data between the promise origin and both branches.
        # idx: specifies which argument/operand the branch is looking for.
        # fwd: applies all operations to the operand found from a branch to the origin of the promise.
        # bwd: applies all operations to the relevance of the operand from the origin of the promise
        #   to the end of the branch, possibly in steps if one or more Checkpoints were on the branch.
        #   Structure: [ (checkpoint1, fcn_to_get_from_origin_to_checkpoint1),
        #                (checkpoint2, fcn_to_get_from_checkpoint1_to_checkpoint2),
        #                ...
        #                (None, fcn_to_get_from_last_checkpoint_to_curnode) ]
        #   So you should apply from left to right, but the inner functions themselves nest right to left.
        # fwd_shape: used as target shape for shape-modifying operations in fwd
        self.promise = promise
        self.parent = promise["parent"]
        self.children = None # Will be set to list[AddBackwardPromise] if further nested AddBackward Nodes are found.
        self.idx = idx
        self.fwd = lambda x: x
        self.bwd = [ (None, lambda x: x) ] # This will chain in case we come across a Checkpoint partway through
        self.fwd_shape = promise["rout"].shape # This will update after a shape-modifying operation is added to fwd
        self.other_branch = None

    def nest_fwd(self, next_f):
        # Nests a new operation for recovering the operand for the promise origin
        self.fwd = lambda x: self.fwd(next_f(x))

    def checkpoint(self, new_checkpoint):
        # Marks a checkpoint in the backwards op chain and opens a new chain after the checkpoint
        self.bwd[-1] = (new_checkpoint, self.bwd[-1][1])
        self.bwd.append((None, lambda x: x))

    def nest_bwd(self, next_f):
        # Stacks on a new operation for transforming the promised relevance back down the branch
        last_checkpoint, most_recent_f = self.bwd[-1] # last_checkpoint is actually always None here
        self.bwd[-1] = (last_checkpoint, lambda x: next_f(most_recent_f(x)))

    @property
    def ready(self):
        return self.promise["ready"]

    @property
    def complete(self):
        # Flags if the promise is done all forward and backward execution
        # If the promise is complete and has children, it will have set its children's rout value to its bwd(rin)
        # result. So the children need only check if parent.complete is True to begin their own exec_bwd().
        return self.promise["complete"]

    @property
    def arg1(self):
        return self.promise["args"][0] # TODO: see if we really need both of these or just the arg for the branch

    @property
    def arg2(self):
        return self.promise["args"][1]

    @property
    def shape(self):
        return self.fwd_shape

    @property
    def rin(self):
        return self.promise["rins"][self.idx]

    @property
    def rout(self):
        return self.promise["rout"]

    def set_rout(self, new_rout):
        self.promise["rout"] = new_rout

    def exec_bwd(self):
        # Perform each saved backward execution chain to propagate relevance back down the branch.
        # Save values for any checkpoints marked along the path and return them with their respective checkpoints.
        assert(self.ready and (self.parent is None or self.parent.complete),
               "Promise backward execution was triggered before promise was ready or before parent promise was complete.")
        res = self.rin
        checkpoints = []
        for checkpoint, fcn in self.bwd:
            res = fcn(res)
            if checkpoint is not None:
                checkpoints.append((checkpoint, res))
        return checkpoints, res

    def compute_rins(self):
        # Compute base branch relevances based on sum of squares ratios.
        assert(self.ready and (self.parent is None or self.parent.complete))
        arg1, arg2 = self.promise["args"]
        r = self.promise["rout"]
        denom = arg1 ** 2 + arg2 ** 2 + epsilon
        r1 = (arg1 ** 2 / denom) * r
        r2 = (arg2 ** 2 / denom) * r
        self.promise["rins"][0], self.promise["rins"][1] = renormalize_epsilon(r, r1, r2)

    def trigger_promise_completion(self):
        # This is only called once a promise receives its second argument.
        assert(self.ready, "Promise completion was triggered before promise was ready.")
        if self.parent is None or self.parent.complete:
            # Either reached root of a promise tree, or we are in the exec_bwd call of a child of a completed promise.
            self.compute_rins()
            checkpoints1, res1 = self.exec_bwd()
            checkpoints2, res2 = self.other_branch.exec_bwd()
            # Save checkpoint relevances to their grad_fn metadatas to collect later.
            for checkpoint, val in checkpoints1 + checkpoints2:
                checkpoint.metadata["checkpoint_relevance"] = val
            self.complete = True

            if self.children is not None:
                # Now that we have calculated the end relevance_in of this branch, we can feed it to the children promises.
                for child_promise in self.children:
                    child_promise.set_rout(res1)
                    child_promise.trigger_promise_completion()

            if self.other_branch.children is not None:
                # Do the same for the other branch in this promise. (I should really make Promise and Branch two different classes...)
                for child_promise in self.other_branch.children:
                    child_promise.set_rout(res2)
                    child_promise.trigger_promise_completion()

        else:
            # If there is a parent promise, but it is not complete yet, we can now set its arg with this promise's result.
            # This is what triggers the propagation of the arguments back to the root of the promise tree.
            self.promise["parent"].setarg(self.arg1 + self.arg2)

    def setarg(self, value):
        # Set the corresponding arg for this branch and check if the promise is ready
        self.promise["args"][self.idx] = self.fwd(value)
        self.promise["ready"] = all([ x is not None for x in self.promise["args"] ])
        if self.promise["ready"]:
            self.trigger_promise_completion()

def AddBackwardProp(grad_fn, r):
    # IMPORTANT: AddBackward0 does not actually store any operands of the addition, so we have
    # to get a bit tricky.
    # The idea is to return a "promise", a dict wrapped in a class which contains the outgoing relevance, and
    # placeholders for the operands and their respective relevances.
    # From what I know right now, AddBackward0 is the only math-op grad_fn that does this, so the hope is
    # that we pass this promise down the graph, and we encounter one of:
    #   1. AccumulateGrad or another math-op that we can get the result from
    #   2. A function that follows the identity or uniform rule like GeluBackward0 or LayerNormBackward0
    #   3. A mutation function like SliceBackward0 or ReshapeBackward0
    #   4. (worst case) Another AddBackward0
    # For case 2 and 3, we would have to keep an arbitrarily composable function which progressively
    # nests the operations that must be done on the result, once it is found, to make it equivalent
    # to the downstream addition operand. When we find a node with the result, we simply apply f(result)
    # to get the actual operand for the original addition.
    # However, we will also need to keep a similar function but for going backwards from the addition
    # back to the result node, but this time for the relevance.
    # If at this time, both operands have been found, compute and store the relevances for both in the
    # promise. If not, move this node to the 
    # For case 4, we would simply have to nest a promise within the existing promise.
    # So the only time this algorithm will branch is if there are multiple additions with no result-
    # yielding grad_fn's in between.

    promise = {
        "rout": r,
        "args": [None, None],
        "rins": [None, None],
        "ready": False,
        "complete": False,
        "parent": None,
    }
    if isinstance(r, AddBackwardPromise):
        promise["parent"] = r
        promise["rout"] = torch.zeros(r.fwd_shape) # Placeholder for shape

    promise1 = AddBackwardPromise(promise, 0)
    promise2 = AddBackwardPromise(promise, 1)

    promise1.other_branch = promise2
    promise2.other_branch = promise1

    if isinstance(r, AddBackwardPromise):
        r.children = [promise1, promise2]

    grad_fn.metadata["promise"] = promise

    return promise1, promise2

In [None]:
# Handling AddMmBackward0 is not the exact same as just AddBackward0. The calculation of the in-relevances
# is slightly different because you need to consider the matmul after the addition is propagated.
# AddMmBackward0 also has 3 inputs, rather than 2 as in AddBackward0, so they will not be easily
# compatible with the current promise tree structure.
# I created the promise and promise tree structure with the aim to not have to re-visit the promise
# origin nodes after sending out the promises (we re-visit the promise tree nodes, but not the grad_fn
# nodes themselves). AddMmBackward0 would add a good amount of extra work if we were to handle it with
# promises, it would require a completely different promise completion handling sequence.
# It would be much easier to simply decompose an AddMmBackward0 into an AddBackward0 and a MmBackward0
# in our graph, then traverse using the normal AddBackward0 promises, where we can fill in the Mm side first.

# Since the autograd Nodes are code-generated and not exposed to the torch API, we just redefine shell
# classes for the ones we need to instantiate, with the fields we need according to the dir() of the
# original classes.
class AddBackward0:
    def __init__(self, next_functions):
        self.name = "AddBackward0"
        self.next_functions = next_functions
        self.metadata = {}

class MmBackward0:
    def __init__(self, next_functions, mat1, mat2):
        self.name = "MmBackward0"
        self.next_functions = next_functions
        self._saved_self = mat1
        self._saved_self_sym_sizes = mat1.shape
        self._saved_mat2 = mat2
        self._saved_mat2_sym_sizes = mat2.shape

def decompose_addmmbackward(grad_fn):
    # Assuming grad_fn is an instance of AddMmBackward, returns an AddBackward0 instance that is the parent
    # of an MmBackward0 instance and the first function in grad_fn.next_functions.
    # The MmBackward0 is then the parent of the last two functions in grad_fn.next_functions.
    mm_fn = MmBackward0(grad_fn.next_functions[1:], grad_fn._saved_mat1, grad_fn._saved_mat2)
    add_fn = AddBackward0((grad_fn.next_functions[0], (mm_fn, 0)))

    return add_fn


In [None]:
import torch.nn.functional as F

"""
For all these functions, grad_fn is the autograd Node returned from traversing the autograd graph.
r is the relevance tensor of the output of the given Node.
"""

def ViewBackwardProp(grad_fn, r):
    upstream_shape = grad_fn._saved_self_sym_sizes
    if isinstance(r, AddBackwardPromise):
        target_shape = r.fwd_shape
        r.nest_fwd(lambda x: torch.view(x, target_shape))
        r.nest_bwd(lambda x: torch.view(x, upstream_shape))
        r.fwd_shape = upstream_shape
        return r
    return torch.view(r, upstream_shape)

def UnsafeViewBackwardProp(grad_fn, r):
    upstream_shape = grad_fn._saved_self_sym_sizes
    if isinstance(r, AddBackwardPromise):
        target_shape = r.fwd_shape
        r.nest_fwd(lambda x: torch.view(x, target_shape))
        r.nest_bwd(lambda x: torch.view(x, upstream_shape))
        r.fwd_shape = upstream_shape
        return r
    return torch.view(r, upstream_shape)

def ReshapeBackwardProp(grad_fn, r):
    upstream_shape = grad_fn._saved_self_sym_sizes
    if isinstance(r, AddBackwardPromise):
        target_shape = r.fwd_shape
        r.nest_fwd(lambda x: torch.reshape(x, target_shape))
        r.nest_bwd(lambda x: torch.reshape(x, upstream_shape))
        r.fwd_shape = upstream_shape
        return r
    return torch.reshape(r, grad_fn._saved_self_sym_sizes)

def SliceBackwardProp(grad_fn, r):
    # Assumes the index corresponding to _saved_start in the forward pass is non-negative.
    # If it was negative-indexed, i.e. x[-i:] autograd saves the index as INT_MAX - i.
    upstream_shape = grad_fn._saved_self_sym_sizes
    sliced_dim = grad_fn._saved_dim
    start = grad_fn._saved_start # TODO: Come back to take care of the negative index case.
    full_size = upstream_shape[sliced_dim]
    end = start + r.shape[sliced_dim]

    # We wish to pad r so that it becomes the correct size along the sliced dimension
    pad = []
    to_pad = [start, full_size - end]

    # Iterate in reverse order, since F.pad() takes in dims from last to first,
    # see https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    # All dims other than sliced_dim should be 0, 0
    for dim in range(len(upstream_shape) - 1, -1, -1):
        pad += [0, 0] if dim != sliced_dim else to_pad
    pad = tuple(pad)

    if isinstance(r, AddBackwardPromise):
        r.nest_fwd(lambda x: torch.ops.aten.slice(x, sliced_dim, start, end))
        r.nest_bwd(lambda x: F.pad(x, pad))
        r.fwd_shape = upstream_shape
        return r
    return F.pad(r, pad)

def IndexBackwardProp(grad_fn, r):
    # An Index can be compound, unlike Slice, i.e. a[[0,1], [1,2]] is ONE Index op, whereas a[:,1:] is TWO Slice ops.
    # This is because (in this case) the second Slice depends on the first. It's saying that from the result of
    # the first slice, for each element, select index 1 from the first, and index 2 from the second (assuming of
    # course that 1 and 2 are in bounds for each element returned by the first slice).
    # Therefore, the length of the first Slice acts as an upper bound for the length of the Slices that succeed it.
    # If you wanted it to select indices 1 and 2 for each resulting element, you would just use a Slice for the last
    # dim instead of an Index.
    upstream_shape = grad_fn._saved_self_sym_sizes
    
    idxs = [ torch.tensor(x) if x is not None else None for x in grad_fn._saved_indices ]

    def undoIndex(x):
        out = torch.zeros(upstream_shape, dtype=x.dtype, device=x.device)
        return torch.ops.aten.index_put(out, idxs, x)

    if isinstance(r, AddBackwardPromise):
        r.nest_fwd(lambda x: torch.ops.aten.index(x, idxs))
        r.nest_bwd(undoIndex)
        r.fwd_shape = upstream_shape
        return r
    return undoIndex(r)
            

def SelectBackwardProp(grad_fn, r):
    upstream_shape = grad_fn._saved_self_sym_sizes
    dim = grad_fn._saved_dim
    idx = grad_fn._saved_index

    def undoSelect(x):
        out = torch.zeros(upstream_shape, dtype=x.dtype, device=x.device)
        x_expanded = torch.unsqueeze(x, dim)
        out.select(dim, idx).copy_(x_expanded)
        return out

    if isinstance(r, AddBackwardPromise):
        r.nest_fwd(lambda x: torch.select(x, dim, idx))
        r.nest_bwd(undoSelect)
        r.fwd_shape = upstream_shape
        return r

    return undoSelect(r)

def TBackwardProp(grad_fn, r):
    # Not sure why TBackward is different from TransposeBackward, but it seems like this is only
    # in Linear layer matmuls on W for xW^T before Mm and Addmm operations.
    assert(len(r.shape) == 2, "Assumption was that matrix would be 2d Linear weights.") # For now assume that it is only 2d matmuls for Linear layers.

    transpose = lambda x: x.T

    if isinstance(r, AddBackwardPromise):
        r.nest_fwd(transpose)
        r.nest_bwd(transpose)
        new_shape = list(r.fwd_shape)
        new_shape[0], new_shape[1] = new_shape[1], new_shape[0]
        r.fwd_shape = tuple(new_shape)
        return r

    return transpose(r)

def TransposeBackwardProp(grad_fn, r):
    dim1 = grad_fn._saved_dim0
    dim2 = grad_fn._saved_dim1

    if dim1 == 2**32 - 2:
        dim1 = -2
    if dim2 == 2**32 - 1:
        dim2 = -1

    swapaxes = lambda x: torch.swapaxes(x, dim1, dim2)

    if isinstance(r, AddBackwardPromise):
        r.nest_fwd(swapaxes)
        r.nest_bwd(swapaxes)
        new_shape = list(r.fwd_shape)
        new_shape[dim1], new_shape[dim2] = new_shape[dim2], new_shape[dim1]
        r.fwd_shape = tuple(new_shape)
        return r
    
    return swapaxes(r)

def PermuteBackwardProp(grad_fn, r):
    dims = grad_fn._saved_dims
    permute = lambda x: torch.permute(x, dims)

    if isinstance(r, AddBackwardPromise):
        r.nest_fwd(permute)
        r.nest_bwd(permute)
        new_shape = list(r.fwd_shape)
        for old_dim, new_dim in enumerate(dims):
            new_shape[old_dim] = r.fwd_shape[new_dim]
        r.fwd_shape = tuple(new_shape)
        return r
    return permute(r)

def ExpandBackwardProp(grad_fn, r):
    upstream_shape = grad_fn._saved_self_sym_sizes
    downstream_shape = r.shape
    assert(len(upstream_shape) == len(downstream_shape), "Expand should not increase number of dimensions.")

    expand_input = [ dim2 if dim1 != dim2 else -1 for dim1, dim2 in zip(upstream_shape, downstream_shape) ]

    def undoExpand(x):
        for i, expand_dim in enumerate(expand_input):
            if expand_dim != -1:
                x = x.select(i, 0).unsqueeze(i)
        return x

    expand = lambda x: x.expand(*expand_input)

    if isinstance(r, AddBackwardPromise):
        r.nest_fwd(expand)
        r.nest_bwd(undoExpand)
        r.fwd_shape = upstream_shape
        return r

    return undoExpand(r)

def MulBackwardProp(grad_fn, r):
    arg1 = grad_fn._saved_self
    arg2 = grad_fn._saved_other

    if isinstance(r, AddBackwardPromise):
        if arg1 is None:
            r.nest_fwd(lambda x: x * arg2)
        else:
            r.setarg(arg1 * arg2)
            if r.complete:
                r = r.rin
            else:
                return None # Trigger requeue

    if arg1 is None:
        # Tensor-scalar product, disregard scalar
        return r, 0.0

    denom = arg1.abs() + arg2.abs() + epsilon
    r1 = (arg1.abs() / denom) * r
    r2 = (arg2.abs() / denom) * r

    return renormalize_epsilon(r, r1, r2)

def DivBackwardProp(grad_fn, r):
    arg1 = grad_fn._saved_self
    arg2 = grad_fn._saved_other

    if isinstance(r, AddBackwardPromise):
        if arg1 is None:
            r.nest_fwd(lambda x: x / arg2)
        else:
            r.setarg(arg1 / arg2)
            if r.complete:
                r = r.rin
            else:
                return None # Trigger requeue

    if arg1 is None:
        # Tensor-scalar product, disregard scalar
        return r, 0.0

    denom = arg1.abs() + (1 / arg2).abs() + epsilon
    r1 = (arg1.abs() / denom) * r
    r2 = ((1 / arg2).abs() / denom) * r

    return renormalize_epsilon(r, r1, r2)

def MmBackwardProp(grad_fn, r):
    x = grad_fn._saved_mat1 # i j
    weights = grad_fn._saved_mat2 # j k
    z = torch.matmul(x, weights)
    if isinstance(r, AddBackwardPromise):
        r.setarg(z)
        if r.complete:
            r = r.rin
        else:
            # If this is the first branch of the promise
            return None # Make this trigger a requeue

    i, j = x.shape
    k = weights.shape[1]
    intermediates = torch.einsum("ij, jk -> ijk", x, weights)

    z = z.unsqueeze(1).broadcast_to((i,j,k))

    ratios = intermediates / z

    # return relevance for input and relevance for weight
    return ratios.sum(dim=2, keepdims=True) * r, ratios.sum(dim=0, keepdims=True) * r.T

def BmmBackwardProp(grad_fn, r):
    mat1 = grad_fn.saved_self
    mat2 = grad_fn.saved_mat2
    z = torch.matmul(mat1, mat2)
    if isinstance(r, AddBackwardPromise):
        r.setarg(z)
        if r.complete:
            r = r.rin
        else:
            # If this is the first branch of the promise
            return None # Make this trigger a requeue

    b, i, j = mat1.shape
    k = mat2.shape[-1]
    intermediates = torch.einsum("bij, bjk -> bijk", mat1, mat2)

    z = z.unsqueeze(2).broadcast_to((b,i,j,k))

    ratios = intermediates / z

    # return relevance for mat1 and relevance for mat2
    return ratios.sum(dim=2, keepdims=True) * r, ratios.sum(dim=1, keepdims=True) * r.T

def NativeLayerNormBackwardProp(grad_fn, r):
    mean = grad_fn._saved_result1
    gamma = grad_fn._saved_weight
    beta = grad_fn._saved_bias
    rec_stddev = grad_fn._saved_result2
    def layerNorm(x):
        normalized = (x - mean) * rec_stddev
        return normalized * gamma + beta
    if isinstance(r, AddBackwardPromise):
        # Identity for relevance going backwards, no need for bwd
        r.nest_fwd(layerNorm)

    # next_functions will correspond to input, weights, bias
    # We only care about propagating through the input for LayerNorm.
    return r, 0.0, 0.0

def GeluBackwardProp(grad_fn, r):
    if isinstance(r, AddBackwardPromise):
        # Identity for relevance going backwards, no need for bwd
        r.nest_fwd(torch.nn.GELU(approximate="none"))
    return r

def SoftmaxBackwardProp(grad_fn, r):
    if isinstance(r, AddBackwardPromise):
        # Identity for relevance going backwards, no need for bwd
        r.nest_fwd(lambda x: x.softmax(dim=-1))
    return r

def IdentityProp(grad_fn, r):
    """Placeholder for any missed operations, or general use for identity-rule operations."""
    return r
    
def AccumulateGradProp(grad_fn, r):
    if isinstance(r, AddBackwardPromise):
        r.setarg(grad_fn.variable)
    return 0.0

def LRPCheckpointBackwardProp(grad_fn, r):
    if isinstance(r, AddBackwardPromise):
        r.checkpoint(grad_fn)
    else:
        grad_fn.metadata["checkpoint_relevance"] = r
    return r
    

In [231]:
start = 3072
end = 6144
to_pad = [start, 6144 - end]
sliced_dim = 2
pad = tuple(sum([ [0,0] if i != sliced_dim else to_pad for i in range(3, -1, -1) ], []))
print(pad)


(0, 0, 3072, 0, 0, 0, 0, 0)


In [232]:
a = torch.rand((1,17,3072,3))
a.shape
b = F.pad(a, pad)

In [142]:
visited[:10].index(visited[2])

2

In [145]:
# Graph
out_adj = {}
in_adj = {}

root = hidden_states.grad_fn
visited = set()
fcns = [ [root] ]
# idea: dynamically init relevance variables when branching occurs, assign them to the corresponding
# downstream nodes they should belong to using next_functions and visited table. Requires 2 passes.
# Perhaps need a last_saved_relevance for each node, in the case when a node is traversed more than once to accumulate relevance.
# Need all incoming branches to land before we continue, else we need to compute the same downstream paths multiple times for
# each incoming branch.
# Modified DFS? Traverse down a path, creating all necessary relevance branches until a node with multiple in-edges is reached.
# We will need a modified graph as well with in_children for each node to determine the above condition.

# First pass will:
# 1. Create in and out adjacency lists.
# 2. Decompose AddMmBackward0's with AddBackward0 leading back to MmBackward0.
while fcns:
    new_fcns = []
    for fcn_list in fcns:
        for fcn in fcn_list:

            if fcn is None or fcn in visited:
                continue

            if type(fcn).__name__ == "AddMmBackward0":
                # Decompose the function into an Add + Mm, then re-assign its adjacencies.
                decomposed_add = decompose_addmmbackward(fcn)
                # Assign new Add's in-neighbours to the AddMm's in-neighbours.
                in_adj[decomposed_add] = in_adj[fcn]
                for in_neighbour in in_adj[fcn]:
                    # Replace all out-edges going to the AddMm to point to the new Add.
                    old_fcn_idx = out_adj[in_neighbour].index(fcn)
                    out_adj[in_neighbour][old_fcn_idx] = decomposed_add
                del in_adj[fcn]
                fcn = decomposed_add

            # Assign adjacencies
            if fcn not in out_adj:
                out_adj[fcn] = []
            for (child, _) in fcn.next_functions:
                out_adj[fcn].append(child)
                if child not in in_adj:
                    in_adj[child] = []
                in_adj[child].append(fcn)

            visited.add(fcn)

            new_fcns.append([ fcn_tup[0] for fcn_tup in fcn.next_functions ])

    # Iterate
    fcns = new_fcns

In [147]:
for k, v in list(in_adj.items()):
    if k == "AddMmBackward0" or "AddMmBackward0" in v:
        print(k, v)
print(out_adj)

{<ViewBackward0 object at 0x17757e100>: [<torch.autograd.function.IndexPutFirstAxisBackward object at 0x17758c040>], <torch.autograd.function.IndexPutFirstAxisBackward object at 0x17758c040>: [<NativeLayerNormBackward0 object at 0x177461370>, None], <NativeLayerNormBackward0 object at 0x177461370>: [<AddBackward0 object at 0x177582f70>, <AccumulateGrad object at 0x177582be0>, <AccumulateGrad object at 0x1775827f0>], <AddBackward0 object at 0x177582f70>: [<AddmmBackward0 object at 0x177582a00>, <NativeLayerNormBackward0 object at 0x177582610>], <AccumulateGrad object at 0x177582be0>: [], <AccumulateGrad object at 0x1775827f0>: [], <AddmmBackward0 object at 0x177582a00>: [<AccumulateGrad object at 0x177582820>, <MulBackward0 object at 0x177582430>, <TBackward0 object at 0x177582640>], <NativeLayerNormBackward0 object at 0x177582610>: [<AddBackward0 object at 0x1775822b0>, <AccumulateGrad object at 0x177582460>, <AccumulateGrad object at 0x177582130>], <AccumulateGrad object at 0x17758282

In [52]:
visited[-137]._saved_self_sym_sizes

(17, 768)

In [34]:
[ acc.variable.shape for acc in visited[-10:] ]

[torch.Size([768, 768]),
 torch.Size([2304, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([2304]),
 torch.Size([2304, 768]),
 torch.Size([2304]),
 torch.Size([2304, 768]),
 torch.Size([2304]),
 torch.Size([2304, 768])]

In [16]:
dir(visited[18])

['__call__',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_input_metadata',
 '_raw_saved_mat1',
 '_raw_saved_mat2',
 '_register_hook_dict',
 '_saved_alpha',
 '_saved_beta',
 '_saved_mat1',
 '_saved_mat1_sym_sizes',
 '_saved_mat1_sym_strides',
 '_saved_mat2',
 '_saved_mat2_sym_sizes',
 '_saved_mat2_sym_strides',
 '_sequence_nr',
 '_set_sequence_nr',
 'metadata',
 'name',
 'next_functions',
 'register_hook',
 'register_prehook',
 'requires_grad']

18446744073709551615

In [130]:
for fcn in visited:
    if fcn.name() == "SliceBackward0" and not (fcn._saved_start == 0 and fcn._saved_end == 9223372036854775807):
        print(fcn._saved_dim, fcn._saved_start, fcn._saved_end)
print(len([ fcn for fcn in visited if fcn.name() == "SliceBackward0" ]))

1 0 3072
1 3072 9223372036854775807
1 0 3072
1 3072 9223372036854775807
1 0 3072
1 3072 9223372036854775807
1 0 3072
1 0 3072
1 3072 9223372036854775807
1 3072 9223372036854775807
1 3072 9223372036854775807
1 0 3072
1 0 3072
1 3072 9223372036854775807
1 0 3072
1 3072 9223372036854775807
1 0 3072
1 3072 9223372036854775807
1 3072 9223372036854775807
1 0 3072
1 0 3072
1 3072 9223372036854775807
1 0 3072
1 3072 9223372036854775807
192


In [112]:
dir(visited[3])

['__call__',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_input_metadata',
 '_register_hook_dict',
 '_saved_dim',
 '_saved_end',
 '_saved_self_sym_sizes',
 '_saved_start',
 '_saved_step',
 '_sequence_nr',
 '_set_sequence_nr',
 'metadata',
 'name',
 'next_functions',
 'register_hook',
 'register_prehook',
 'requires_grad']

In [116]:
visited[3]._saved_end

9223372036854775807

In [190]:
a = torch.randn(1, requires_grad=True)
b = a*(a + a * 2 + 2)
multa = b.grad_fn._saved_self
multb = b.grad_fn._saved_other
with torch.no_grad():
    print(torch.concat((multa, multb)) / (multa + multb))
print (b.grad_fn.next_functions)
print(dir(b.grad_fn.next_functions[1][0]))
print(b.grad_fn.next_functions[1][0])
print (b.grad_fn.next_functions[1][0].next_functions)
print (b.grad_fn.next_functions[0][0].variable is a)
print(b.grad_fn.next_functions[1][0].next_functions[0][0].next_functions)
print(b.grad_fn.next_functions[1][0].next_functions[0][0].next_functions[1][0]._saved_self)
print(b.grad_fn.next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions)

tensor([0.1706, 0.8294])
((<AccumulateGrad object at 0x17a75efa0>, 0), (<AddBackward0 object at 0x130051940>, 0))
['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_input_metadata', '_register_hook_dict', '_saved_alpha', '_sequence_nr', '_set_sequence_nr', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad']
<AddBackward0 object at 0x13005a940>
((<AddBackward0 object at 0x1300111f0>, 0), (None, 0))
True
((<AccumulateGrad object at 0x1300111f0>, 0), (<MulBackward0 object at 0x13005a940>, 0))
None
((<AccumulateGrad object at 0x1300111f0>, 0), (None, 0))
