In [1]:
from transformers import AutoModel, AutoTokenizer
import torch

In [2]:
from lrp_graph import make_graph
from lrp_prop_fcns import LRPPropFunctions
from add_backward_promise import AddBackwardPromise, compound_promises
from util import create_checkpoint

In [3]:
def checkpoint_hook(module, input, output):
    """Forward hook that substitutes a module's output with its checkpointed form.

    The signature follows the ``(module, input, output)`` convention of
    ``torch.nn.Module.register_forward_hook``; because a value is returned,
    that value is used as the module's output in place of ``output``.

    Args:
        module: The module the hook fired on (unused).
        input: The positional inputs the module was called with (unused).
        output: The module's forward output.

    Returns:
        Whatever ``create_checkpoint(output)`` returns.
        NOTE(review): presumably this wraps the output so an
        ``LRPCheckpointBackward`` node appears in the autograd graph --
        confirm against ``util.create_checkpoint``.
    """
    checkpointed_output = create_checkpoint(output)
    return checkpointed_output

In [4]:
# Load the DNABERT-2 tokenizer and encoder by their Hugging Face model id.
# trust_remote_code=True lets `from_pretrained` run the Python code shipped in
# the model repository, so only use it with a repository you trust.
model_name = "zhihan1996/DNABERT-2-117M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Register `checkpoint_hook` as a forward hook on the self-attention submodule
# of every encoder layer, so each of those outputs passes through
# `create_checkpoint`.
# NOTE: re-running this cell registers the hook again on the same modules
# (the returned hook handles are not kept, so they cannot be removed here).
for layer_module in model.encoder.layer:
    layer_module.attention.self.register_forward_hook(checkpoint_hook)

In [11]:
# Example DNA sequence used as the single input for the relevance pass.
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
# Tokenize to PyTorch tensors and keep only the token ids.
inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"]
# Forward pass; element [0] of the model output is taken as the hidden states.
# NOTE(review): `requires_grad=True` is passed as a forward() keyword -- presumably
# handled by the model's remote code; confirm it is actually consumed there.
hidden_states : torch.Tensor = model(inputs, requires_grad=True)[0] # [1, sequence_length, 768]

In [12]:
# from util import create_checkpoint
# a = torch.rand((5,4,2), requires_grad=True)
# b = torch.rand((5,4,2), requires_grad=True)
# hidden_states = a + create_checkpoint(b)

In [13]:
from torch.autograd.graph import Node

# Build the adjacency structures of the autograd graph rooted at hidden_states,
# together with the collection of node type names handed to the prop-function map.
in_adj_list, out_adj_list, names = make_graph(hidden_states)

# One (initially empty) list of landed relevance inputs per graph node.
input_tracker : dict[Node, list] = { graph_node : [] for graph_node in in_adj_list.keys() }

# Nodes whose autograd type is the LRP checkpoint marker; the traversal stops
# once every one of them has been reached.
checkpoints = [
    graph_node
    for graph_node in in_adj_list.keys()
    if type(graph_node).__name__ == "LRPCheckpointBackward"
]
num_checkpoints_reached = 0

# Lookup from autograd node type name to the function that propagates relevance through it.
fcn_map = LRPPropFunctions.generate_prop_fcn_map(names)

In [14]:
# Scratch slot for stashing an object to inspect interactively. The only
# assignment visible in this notebook is inside the commented-out legacy loop.
debug = None

In [15]:
# LRP relevance propagation over the autograd graph of `hidden_states`.
#
# Starting from a seed relevance tensor at `hidden_states.grad_fn`, nodes are
# processed one at a time; each node's aggregated incoming relevance is handed
# to the propagation function registered for its type in `fcn_map`, and the
# results are appended to the `input_tracker` lists of its children
# (`out_adj_list`). A node is selected by the following priority:
#   1. a queued node whose promise is complete        (promise fulfillment mode)
#   2. a queued node with all inputs landed and a pre-promise whose parents
#      are all complete                                (normal mode)
#   3. the promise traversal stack                     (promise traversal mode)
#   4. the main stack                                  (normal mode)
# The loop stops once every node in `checkpoints` carries "checkpoint_relevance"
# in its metadata, or when there is nothing left to process.
#
# Reads notebook state defined in earlier cells: hidden_states, in_adj_list,
# out_adj_list, input_tracker, checkpoints, num_checkpoints_reached, fcn_map.
# NOTE: `input_tracker`, `in_adj_list` and node metadata are mutated in place,
# so the graph-setup cell has to be re-run before this cell is executed again.
visited1 = set()


def _pre_promise_ready(fcn):
    """Return True when a queued node can be re-traversed outside promise traversal mode.

    That requires all of its inputs to have landed (`nodes_pending[fcn] == 0`),
    a "pre_promise" entry in its metadata, and every parent of that pre-promise
    to be complete.
    """
    return (nodes_pending[fcn] == 0 and "pre_promise" in fcn.metadata and
            all(parent.complete for parent in fcn.metadata["pre_promise"].parents))


with torch.no_grad():
    # Create the first relevance layer via max logit.
    # For every batch row, a 1 is written at the argmax feature of each sequence position.
    m = hidden_states.max(-1)
    relevance = torch.zeros_like(hidden_states)
    b, s, d = hidden_states.shape
    for i, inds in enumerate(m.indices):
        relevance[i,list(range(s)),inds] = torch.ones_like(m.values[0])

    # Setup the first iteration
    input_tracker[hidden_states.grad_fn] = [ relevance ]
    stack = [hidden_states.grad_fn]
    # The root receives no relevance from other nodes, so its pending count must be 0.
    in_adj_list[hidden_states.grad_fn] = []
    # Number of relevance inputs each node is still waiting for.
    nodes_pending = { k : len(v) for k, v in list(in_adj_list.items()) }

    # Nodes parked until their promise completes or their pre-promise becomes runnable.
    promise_queue : list[Node] = []

    promise_traversal_stack = []
    promise_traversal_mode = False

    promise_fulfillment_mode = False

    while (stack or promise_traversal_stack or promise_queue) and num_checkpoints_reached < len(checkpoints):
        # Pop first element of either promise_queue or main stack
        curnode = None

        # Debug trace of the three work lists at the start of every iteration.
        print(stack, promise_traversal_stack, promise_queue)
        
        if promise_queue and any(fcn.metadata["promise"]["complete"] for fcn in promise_queue):
            # Search for the first complete promise in the queue.
            curnode = next(( fcn for fcn in promise_queue if fcn.metadata["promise"]["complete"] ))
            idx = promise_queue.index(curnode)
            promise_queue = promise_queue[:idx] + promise_queue[idx + 1:]
            promise_traversal_mode = False
            promise_fulfillment_mode = True
        elif promise_queue and any(_pre_promise_ready(fcn) for fcn in promise_queue):
            # The same predicate is used for the test and for the selection. Previously the
            # selection skipped the `"pre_promise" in fcn.metadata` guard and could raise a
            # KeyError on a queued node that has no pre-promise.
            curnode = next(( fcn for fcn in promise_queue if _pre_promise_ready(fcn) ))
            idx = promise_queue.index(curnode)
            promise_queue = promise_queue[:idx] + promise_queue[idx + 1:]
            promise_traversal_mode = False
            promise_fulfillment_mode = False
        elif promise_traversal_stack:
            # Next priority is promise traversal, which overrides the requirement for all inputs to land
            # before traversing a node. The node's pre-promise is only completed later, when the node is
            # revisited outside promise traversal mode (see the "pre_promise" handling below).
            curnode = promise_traversal_stack[0]
            promise_traversal_stack = promise_traversal_stack[1:]
            promise_traversal_mode = True
            promise_fulfillment_mode = False
        elif stack:
            curnode = stack[0]
            stack = stack[1:]
            promise_traversal_mode = False
            promise_fulfillment_mode = False

        if curnode is None:
            # Only promise_queue has entries and none of them is runnable: nothing can make
            # progress. Fail with a clear message instead of an opaque `KeyError: None` below.
            raise RuntimeError(
                "LRP traversal stalled: both stacks are empty and no node in promise_queue "
                "is complete or ready to be re-traversed."
            )

        curnode_inputs = input_tracker[curnode]

        visited1.add(curnode)
        # if any([ x is None for x in curnode_inputs ]):
        #     # Node hasn't received all its relevance yet, push it to the back.
        #     stack = stack + [curnode]
        #     continue

        children = out_adj_list[curnode]

        if not promise_fulfillment_mode:

            # Categorize all inputs into either pending promises, complete promises, or tensors
            pending_promise_inputs = []
            complete_promise_inputs = []
            tensor_inputs = []
            for input_ in curnode_inputs:
                if isinstance(input_, torch.Tensor):
                    tensor_inputs.append(input_)
                elif isinstance(input_, AddBackwardPromise) and input_.complete:
                    complete_promise_inputs.append(input_)
                elif isinstance(input_, AddBackwardPromise):
                    pending_promise_inputs.append(input_)
                elif input_ == 0.0:
                    # Zero placeholders carry no relevance.
                    continue
                else:
                    print(input_)
                    raise ValueError(f"Expected relevance input to Node {curnode} to be type AddBackwardPromise or Tensor, but got {type(input_)} instead.")
    
            if not complete_promise_inputs and not pending_promise_inputs and not tensor_inputs:
                continue
    
            # Aggregate all inputs into one Tensor or AddBackwardPromise
            # (sum() over an empty list yields the int 0, so this is 0 when nothing concrete landed).
            curnode_in_rel = sum(tensor_inputs) + sum([ p.rin for p in complete_promise_inputs ])
            if pending_promise_inputs:
                # In promise traversal mode this will be True
                agg_promises = compound_promises(pending_promise_inputs, promise_traversal_mode, promise_traversal_mode)
                # `tensor != 0` is element-wise, and using it as an `if` condition raises for
                # multi-element tensors, so tensors are detected by type; the scalar comparison
                # is kept for the non-tensor case.
                if isinstance(curnode_in_rel, torch.Tensor) or curnode_in_rel != 0:
                    curnode_in_rel = agg_promises + curnode_in_rel
                else:
                    curnode_in_rel = agg_promises
        else:
            # The node's promise is complete: take the relevance computed for this node's slot.
            curnode_in_rel = curnode.metadata["promise"]["rins"][curnode.metadata["promise_idx"]]


        if not promise_traversal_mode and "pre_promise" in curnode.metadata:
            # We have already traversed a promise tree, but have not calculated its bwd,
            # since it was done in promise traversal mode.
            pre_promise : AddBackwardPromise = curnode.metadata["pre_promise"]

            assert pre_promise.ready, f"Pre-promise at {curnode} was assumed to be ready but was not."

            if not pre_promise.complete:
                if isinstance(curnode_in_rel, AddBackwardPromise):
                    # If there is still pending promises at this node, try to complete them via the aggregate promise.
                    # In the case this completes and propagates relevance down, we will have to pick up from the tail nodes
                    # of the aggregate promise.
                    curnode_in_rel.children.append(pre_promise)
                    curnode_in_rel.setarg(pre_promise.arg1 + pre_promise.arg2)
                else:
                    pre_promise.accumulate_rout(curnode_in_rel)
                    pre_promise.trigger_promise_completion()

            # NOTE: this is the promise's own list, so the remove() below mutates the promise.
            tail_nodes = pre_promise.promise["tail_nodes"]
            if curnode in tail_nodes:
                tail_nodes.remove(curnode)
            if tail_nodes:
                # Don't know if the promise is complete yet, so put on the promise queue.
                # If any are done, they will be traversed with priority.
                promise_queue += tail_nodes
                continue
            else:
                # If the pre-promise is a singleton, i.e. the node is the tail of its own pre-promise,
                # just collect the computed rin and re-traverse this node with a tensor rin input like normal.
                curnode_in_rel = pre_promise.rin

        if promise_traversal_mode:
            # We want to save this so later we'll know we've already traversed this node.
            curnode.metadata["pre_promise"] = curnode_in_rel

        # Call the propagation function for the node
        curnode_outputs = fcn_map[type(curnode).__name__](curnode, curnode_in_rel)

        if isinstance(curnode_outputs, AddBackwardPromise) and curnode_outputs.arg is not None and not curnode_outputs.complete:
            # Node is waiting on Promise to be completed, add to promise queue and come back later.
            curnode.metadata["promise"] = curnode_outputs.promise
            curnode.metadata["promise_idx"] = curnode_outputs.idx
            promise_queue.append(curnode)
            continue

        # Children may contain None, like grad_fn.next_functions, to keep integrity of input tracking
        if len(children) == 0 or all(child is None for child in children):
            continue
        elif len(children) == 1:
            # if isinstance(curnode_outputs, tuple):
            #     curnode_outputs = [ curnode_outputs ]
            # else:
            curnode_outputs = [ curnode_outputs ]
            
        elif len(children) != len(curnode_outputs):
            raise ValueError(f"Mismatch: {len(children)} children but {len(curnode_outputs)} outputs from {curnode}.")


        # Update child inputs
        for i, child in enumerate(children):
            if child is None:
                # Discard the input (it shouldn't have value anyway), if it's a promise make it a zero-promise
                if isinstance(curnode_outputs[i], AddBackwardPromise):
                    # Manually set the arg to not trigger any additional side effects
                    curnode_outputs[i].promise["args"][curnode_outputs[i].idx] = 0.0
                continue
            input_tracker[child].append(curnode_outputs[i])
            nodes_pending[child] -= 1
            assert nodes_pending[child] >= 0, f"Negative pending count for node {child}"
            assert len(input_tracker[child]) <= len(in_adj_list[child]), \
                f"Too many inputs landed for {child}"

        # Collect children who now have all their inputs or that have promise(s) depending on them.
        ready_children = []
        promise_depends_on = []
        for i, child in enumerate(children):
            if child is None:
                continue
            if nodes_pending[child] == 0 and child not in promise_queue:
                ready_children.append(child)
            elif isinstance(curnode_outputs[i], AddBackwardPromise) and not curnode_outputs[i].complete and "pre_promise" not in child.metadata:
                promise_depends_on.append(child)

        # New work is pushed to the front of both stacks.
        promise_traversal_stack = promise_depends_on + promise_traversal_stack
        stack = ready_children + stack
        num_checkpoints_reached = sum([ "checkpoint_relevance" in checkpoint.metadata for checkpoint in checkpoints])


# Relevance recorded at each checkpoint node. This raises KeyError if the loop ended
# before a checkpoint received its "checkpoint_relevance" metadata entry.
checkpoint_vals = [ checkpoint.metadata["checkpoint_relevance" ] for checkpoint in checkpoints ]


[<ViewBackward0 object at 0x297b43460>] [] []
[<torch.autograd.function.IndexPutFirstAxisBackward object at 0x2b86e2140>] [] []
[<NativeLayerNormBackward0 object at 0x2b8706670>] [] []
[<AddBackward0 object at 0x2b8706760>, <AccumulateGrad object at 0x297b4f940>, <AccumulateGrad object at 0x297b4f3a0>] [] []
[<util.AddBackward0 object at 0x2b8709970>, <AccumulateGrad object at 0x297b4f940>, <AccumulateGrad object at 0x297b4f3a0>] [<NativeLayerNormBackward0 object at 0x2b8706850>] []
triggering promise <add_backward_promise.AddBackwardPromise object at 0x29b7bdb80>
[<util.AddBackward0 object at 0x2b8709970>, <AccumulateGrad object at 0x297b4f940>, <AccumulateGrad object at 0x297b4f3a0>] [] [<NativeLayerNormBackward0 object at 0x2b8706850>]
[<AccumulateGrad object at 0x297b4eb80>, <util.MmBackward0 object at 0x2b8709d30>, <AccumulateGrad object at 0x297b4f940>, <AccumulateGrad object at 0x297b4f3a0>] [] [<NativeLayerNormBackward0 object at 0x2b8706850>]
[<util.MmBackward0 object at 0x2b8

In [11]:
# Scratch inspection: follow the first in_adj_list entry from the node the
# traversal last processed (`curnode`, left over from the propagation cell)
# through four hops, named after the node types expected there.
# NOTE(review): the names assume two slice nodes, then a permute and an expand
# node -- confirm against the actual graph.
slice2 = in_adj_list[curnode][0]
slice1 = in_adj_list[slice2][0]
permute = in_adj_list[slice1][0]
expand = in_adj_list[permute][0]

# Compare the shape of the relevance that landed on each slice node with the
# dim and input sizes autograd saved for slice1.
print(input_tracker[slice1][0].shape)
print(slice1._saved_dim)
print(slice1._saved_self_sym_sizes)
print(input_tracker[slice2][0].shape)

# curnode._saved_self_sym_sizes
# promise = input_tracker[curnode][0]
# print(promise.bwd[0][1](promise.arg).shape)
# print(in_adj_list[curnode][0]._saved_self_sym_sizes)

torch.Size([1, 64, 17, 12])
3
(1, 17, 12, 64)
torch.Size([1, 64, 17, 64])


In [12]:
# Every node that has received at least one relevance input, shown as
# ((node, landed_inputs), inputs_still_pending).
[
    ((graph_node, landed), nodes_pending[graph_node])
    for graph_node, landed in input_tracker.items()
    if len(landed) > 0
]

[((<torch.autograd.function.IndexPutFirstAxisBackward at 0x29c0bfd40>,
   [tensor([[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]])]),
  0),
 ((<NativeLayerNormBackward0 at 0x29c0b2f40>,
   [tensor([[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]])]),
  0),
 ((<AddBackward0 at 0x29c0b2fd0>,
   [tensor([[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ...,

In [13]:
# Number of entries in AddBackwardPromise.all_promises.
# NOTE(review): presumably a class-level registry of every promise created so far -- confirm.
len(AddBackwardPromise.all_promises)

12

In [14]:
# Count the visited graph nodes that are AddBackward0 nodes or have more than
# one entry in in_adj_list.
sum(
    1
    for fcn in visited1
    if fcn in in_adj_list
    and (type(fcn).__name__ == "AddBackward0" or len(in_adj_list[fcn]) > 1)
)

8

In [46]:
# Order the visited nodes by autograd sequence number, highest first.
# Note that this rebinds `visited1` from a set to a list.
visited1 = sorted(visited1, key=lambda graph_node: graph_node._sequence_nr(), reverse=True)
# print(visited1)
# Drop the AccumulateGrad nodes.
visited = [ graph_node for graph_node in visited1 if type(graph_node).__name__ != "AccumulateGrad" ]

# print(out_adj_list[visited[1]])

# inputs = input_tracker[visited[-1]]

# print(inputs[0].promise)
# print(visited[-1].metadata)
# print(input_tracker[visited1[4]])
# print(visited1[4].metadata)
# print(visited[-5])
# print(input_tracker[out_adj_list[visited[-5]][0]][0].promise)
list(enumerate(visited))

[(0, <ViewBackward0 at 0x28ed35340>),
 (1, <torch.autograd.function.IndexPutFirstAxisBackward at 0x291c79d40>),
 (2, <NativeLayerNormBackward0 at 0x28ed35f40>),
 (3, <AddBackward0 at 0x28ed359d0>),
 (4, <util.AddBackward0 at 0x291c73a90>),
 (5, <util.MmBackward0 at 0x291c73c10>),
 (6, <MulBackward0 at 0x291c73b80>),
 (7, <GeluBackward0 at 0x291c73d00>),
 (8, <SliceBackward0 at 0x291c73af0>),
 (9, <SliceBackward0 at 0x291c73550>),
 (10, <SliceBackward0 at 0x291c73c70>),
 (11, <SliceBackward0 at 0x291c73b50>),
 (12, <MmBackward0 at 0x291c73b20>),
 (13, <NativeLayerNormBackward0 at 0x291c73be0>),
 (14, <AddBackward0 at 0x291c73dc0>),
 (15, <util.AddBackward0 at 0x291c73a30>),
 (16, <util.MmBackward0 at 0x291c733a0>),
 (17, <torch.autograd.function.LRPCheckpointBackward at 0x291c79c40>),
 (18, <ViewBackward0 at 0x291c73bb0>),
 (19, <torch.autograd.function.IndexFirstAxisBackward at 0x291c79b40>),
 (20, <ReshapeAliasBackward0 at 0x14c65c6d0>),
 (21, <PermuteBackward0 at 0x28ed35760>),
 (22,

In [29]:
# Scratch inspection: take the promise dict stored on the first queued node,
# then display the promise dict of its first parent.
# Requires a non-empty promise_queue left over from the propagation cell.
promise = promise_queue[0].metadata["promise"]
parent = promise["parents"][0]
parent.promise


{'rout': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'args': [tensor([[-0.0475, -0.2723,  0.4098,  ..., -0.3523,  0.1281, -0.2018],
          [ 0.1428,  0.0020,  0.2586,  ..., -0.2399,  0.1047, -0.0055],
          [ 0.1221, -0.0062,  0.3019,  ..., -0.2288,  0.0363, -0.0044],
          ...,
          [ 0.0528, -0.1793,  0.1704,  ...,  0.1261, -0.0343, -0.2712],
          [ 0.0495, -0.0877,  0.1440,  ...,  0.1425, -0.0284, -0.2875],
          [ 0.1622, -0.0624,  0.2502,  ...,  0.0583, -0.0040, -0.2234]]),
  tensor([[-0.0622, -0.0385,  0.1013,  ..., -0.0403, -0.0234, -0.0343],
          [ 0.1521,  0.0156,  0.0175,  ..., -0.1897,  0.0935,  0.0241],
          [ 0.0246,  0.0086,  0.0833,  ..., -0.0325,  0.0045,  0.0782],
          ...,
          [-0.0537,  0.1426,  0.0506,  ...

tensor(-4406.1240)

In [12]:
# [ fcn for fcn in visited1 if "pre_promise" in fcn.metadata and fcn.metadata["pre_promise"] == promise_queue[1].metadata["promise"]["parents"][0].parents[0].parents[0].parents[0].promise]
# Number of nodes the traversal has visited so far.
len(visited1)

694

In [18]:
# Shape of the relevance that was about to be propagated through the last
# processed node (`curnode_in_rel` is left over from the propagation cell).
curnode_in_rel.shape

torch.Size([1, 17, 12, 64])

In [21]:
# Scratch setup: a zero tensor with the input sizes autograd saved on `curnode`,
# and the incoming relevance with a new size-1 axis inserted at position 2.
out = torch.zeros(curnode._saved_self_sym_sizes)
x_expanded = torch.unsqueeze(curnode_in_rel, 2)

In [22]:
# Confirm the shape after the unsqueeze at position 2.
x_expanded.shape

torch.Size([1, 17, 1, 12, 64])

In [11]:
# Number of nodes the traversal has visited so far.
len(visited1)

216

In [13]:
# Scratch check of Tensor.index_add_: the rows of `c` are added in place to the
# rows of `a` selected by the indices in `b`, along dim 0.
# NOTE: this rebinds the notebook-level names `a`, `b` and `c`; `c` is random
# and unseeded, so the displayed values differ between runs.
a = torch.zeros((4068, 768))
b = torch.tensor([1,5,8,10,11,12,14,19,20,21])
c = torch.rand((10,768))
a.index_add_(0,b,c)
a

2

In [43]:
# Index-labelled listing of `visited`. This only works after a cell that defines
# `visited` has run in the current kernel (the recorded output is a NameError).
list(enumerate(visited))

NameError: name 'visited' is not defined

In [25]:
# with torch.no_grad():
#     # Create the first relevance layer via max logit.
#     m = hidden_states.max(-1)
#     relevance = torch.zeros_like(hidden_states)
#     b, s, d = hidden_states.shape

#     for i, inds in enumerate(m.indices):
#         relevance[i,list(range(s)),inds] = torch.ones_like(m.values[0])

#     # Setup the first iteration
#     input_tracker[hidden_states.grad_fn] = [ relevance ]
#     stack = [hidden_states.grad_fn]
#     nodes_pending = { k : len(v) for k, v in list(in_adj_list.items()) }

#     promise_queue : list[AddBackwardPromise] = []

#     while (stack or promise_queue) and num_checkpoints_reached < len(checkpoints):
#         print(stack)
#         # Pop first element of either promise_queue or main stack
#         curnode = None
#         if promise_queue and promise_queue[0].metadata["promise"]["complete"]:
#             # Since our non-special-case traversal is DFS, the first promise will be the first to be complete.
#             curnode = promise_queue[0]
#             promise_queue = promise_queue[1:]
#         else:
#             curnode = stack[0]
#             stack = stack[1:]

#         curnode_inputs = input_tracker[curnode]
#         # if any([ x is None for x in curnode_inputs ]):
#         #     # Node hasn't received all its relevance yet, push it to the back.
#         #     stack = stack + [curnode]
#         #     continue

#         pending_promise_inputs = []
#         complete_promise_inputs = []
#         tensor_inputs = []
#         for input_ in curnode_inputs:
#             if isinstance(input_, torch.Tensor):
#                 tensor_inputs.append(input_)
#             elif isinstance(input_, AddBackwardPromise) and input_.complete:
#                 complete_promise_inputs.append(input_)
#             elif isinstance(input_, AddBackwardPromise):
#                 pending_promise_inputs.append(input_)
#             elif input_ == 0.0:
#                 continue
#             else:
#                 raise ValueError(f"Expected relevance input to Node {curnode} to be type AddBackwardPromise or Tensor, but got {type(input_)} instead.")

#         if not complete_promise_inputs and not pending_promise_inputs and not tensor_inputs:
#             continue

#         agg_tensor_inputs = sum(tensor_inputs) + sum([ p.rin for p in complete_promise_inputs ])
#         if pending_promise_inputs:
#             agg_promises = compound_promises(pending_promise_inputs)
#             if agg_tensor_inputs != 0:
#                 curnode_in_rel = agg_promises + curnode_in_rel
#             else:
#                 curnode_in_rel = agg_promises
            
        
#         # Call the propagation function for the node
#         curnode_outputs = fcn_map[type(curnode).__name__](curnode, curnode_in_rel)

#         if isinstance(curnode_outputs, AddBackwardPromise) and not curnode_outputs.complete:
#             # Node is waiting on Promise to be completed, add to promise queue and come back later.
#             debug = curnode_outputs.promise
#             curnode.metadata["promise"] = curnode_outputs.promise
#             promise_queue.append(curnode)
#             continue

#         children = out_adj_list[curnode]
#         ready_children = []

#         if len(children) == 0:
#             continue
#         elif len(children) == 1:
#             curnode_outputs = [ curnode_outputs ]

#         if len(children) != len(curnode_outputs):
#             raise ValueError(f"Mismatch: {len(children)} children but {len(curnode_outputs)} outputs from {curnode}.")

#         # Collect children who now have all their inputs
#         for i, child in enumerate(children):
#             input_tracker[child].append(curnode_outputs[i])
#             nodes_pending[child] -= 1
#             assert nodes_pending[child] >= 0, f"Negative pending count for node {child}"
#             if len(input_tracker[child]) > len(in_adj_list[child]):
#                 print(child, input_tracker[child], in_adj_list[child])
#             assert len(input_tracker[child]) <= len(in_adj_list[child]), \
#                 f"Too many inputs landed for {child}, expected {len(in_adj_list[child])} but got {len(input_tracker[child])}."
#             if nodes_pending[child] == 0:
#                 ready_children.append(child)

#         stack = ready_children + stack
#         num_checkpoints_reached = sum([ hasattr(checkpoint.metadata, "checkpoint_relevance") for checkpoint in checkpoints])

# print(checkpoints)

# checkpoint_vals = [ checkpoint.metadata for checkpoint in checkpoints ]

[<ViewBackward0 object at 0x176f69880>]


RuntimeError: shape '[17, 768]' is invalid for input of size 2359296

In [35]:
# Scratch inspection: find the first node that lists the second queued node among
# its out_adj_list entries, then show the `other_branch` of the first parent of
# that node's promise. Requires at least two entries in promise_queue.
[ k for k in out_adj_list if promise_queue[1] in out_adj_list[k] ][0].metadata["promise"]["parents"][0].other_branch


<add_backward_promise.AddBackwardPromise at 0x16cb30df0>

In [21]:
# Show the autograd next_functions of one specific visited node. The index 714
# is tied to the ordering of `visited` in the session this was recorded in.
visited[714].next_functions

((<NativeLayerNormBackward0 at 0x28cab9dc0>, 0), (None, 0))

In [12]:
# Level-by-level walk of the autograd graph, starting at hidden_states.grad_fn.
# Each level is a list of node groups, one group per node expanded on the
# previous level. `count` is the number of levels processed, `names` collects
# the type names of all nodes seen, and `visited` collects the nodes themselves.
# NOTE: this rebinds the notebook-level `names` and `visited`.
fcns = [ [hidden_states.grad_fn] ]
visited = set()
names = set()
count = 0

while fcns:
    count += 1
    next_level = []
    for node_group in fcns:
        for graph_node in node_group:
            # next_functions may contain None entries; nodes already seen are not expanded again.
            if graph_node is None or graph_node in visited:
                continue
            names.add(type(graph_node).__name__)
            visited.add(graph_node)
            next_level.append([ next_fcn for next_fcn, _ in graph_node.next_functions ])
    fcns = next_level

print(count)

# Rebind `visited` as a list ordered by ascending autograd sequence number.
visited = sorted(visited, key=lambda graph_node: graph_node._sequence_nr())

74
