In [14]:
import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict


In [8]:
class GraphLikeModel(torch.nn.Module):
    """
    Toy model with two parallel linear branches whose outputs are summed and
    passed through a linear head: ``head(branch1(x) + branch2(x))``.

    The branching gives the traced fx graph a non-trivial shape (one
    placeholder with two users, an ``add`` join), which the interpreters
    below walk node by node.

    Args:
        in_features: size of the input's last dimension (default 128).
        hidden_features: output size of each branch (default 64).
        out_features: output size of the head (default 10).
    """
    def __init__(self, in_features=128, hidden_features=64, out_features=10):
        super().__init__()
        self.branch1 = torch.nn.Linear(in_features, hidden_features)
        self.branch2 = torch.nn.Linear(in_features, hidden_features)
        self.head = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.head(self.branch1(x) + self.branch2(x))


torch.manual_seed(0)  # fix input and weight init so outputs reproduce on re-run
x = torch.rand(256, 128)
model = GraphLikeModel()
model(x).shape

torch.Size([256, 10])

In [15]:
# Symbolically trace the eager model into a GraphModule (code + fx Graph).
traced_graph = torch.fx.symbolic_trace(root=model)

In [16]:
# Render the traced graph's IR as text (one line per node, as shown below).
print(str(traced_graph.graph))

graph():
    %x : [#users=2] = placeholder[target=x]
    %branch1 : [#users=1] = call_module[target=branch1](args = (%x,), kwargs = {})
    %branch2 : [#users=1] = call_module[target=branch2](args = (%x,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%branch1, %branch2), kwargs = {})
    %head : [#users=1] = call_module[target=head](args = (%add,), kwargs = {})
    return head


In [63]:
class GraphInterperter:
    """
    Execute a ``torch.fx`` GraphModule node by node (adapted from the
    ``torch.fx`` shape-propagation example).

    Every node's value is stored in an environment keyed by node name, and
    later nodes look their arguments up there. The call returns whatever the
    graph's ``output`` node references (a single value, or a tuple/list/dict
    of values) evaluated from that environment.
    """
    def __init__(self, mod):
        # `mod` must expose `.graph` and `.named_modules()`, e.g. the
        # GraphModule returned by `torch.fx.symbolic_trace`.
        self.mod = mod
        self.graph = mod.graph
        # qualified name -> submodule, used to resolve `call_module` targets
        self.modules = dict(self.mod.named_modules())

    def __call__(self, *args):
        """
        Run the graph on positional ``args`` (consumed one per ``placeholder``
        node, in graph order) and return the value(s) the ``output`` node
        refers to.

        Raises:
            RuntimeError: a ``get_attr`` node names a missing attribute, or a
                node has an op this interpreter does not handle.
        """
        args_iter = iter(args)
        # node name -> computed value (tensors or other Python objects)
        env: Dict[str, object] = {}

        def load_arg(a):
            # Replace every Node inside `a` (possibly nested in tuples, lists
            # or dicts) with the value already computed for it.
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target: str):
            # Resolve a dotted path such as "sub.weight" starting at self.mod.
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    # Include the failing atom so the message names the missing path.
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # First arg is the receiver; the rest are the method's arguments.
                self_obj, *method_args = load_arg(node.args)
                method_kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args, **method_kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'output':
                # The output node's single argument describes the return value;
                # it need not be the node computed most recently.
                return load_arg(node.args[0])
            else:
                # Fail loudly instead of silently reusing the previous result.
                raise RuntimeError(f"Unsupported node op {node.op!r} in node {node.name}")

            env[node.name] = result
        

In [65]:
# Wrap the traced GraphModule in the hand-written node-by-node interpreter.
prop = GraphInterperter(mod=traced_graph)

In [66]:
# Run the traced graph through the hand-written interpreter on the same input
# that the eager model is evaluated on in a later cell.
prop(x)

tensor([[-0.1441,  0.1658, -0.0016,  ..., -0.0394, -0.1156,  0.1525],
        [ 0.1496,  0.1661, -0.1636,  ..., -0.3794, -0.1345,  0.0840],
        [ 0.1438,  0.1154, -0.1784,  ..., -0.2873,  0.0075, -0.2276],
        ...,
        [ 0.0880,  0.1487, -0.1796,  ..., -0.1099,  0.0612, -0.0129],
        [ 0.0730,  0.1332, -0.0952,  ..., -0.0871, -0.2447,  0.0098],
        [ 0.0216,  0.1172, -0.1538,  ..., -0.2286,  0.1072, -0.1076]],
       grad_fn=<AddmmBackward0>)

In [67]:
# Eager forward pass on the same input; the interpreter must reproduce it exactly.
eager_out = model(x)
torch.testing.assert_close(prop(x), eager_out)
eager_out

tensor([[-0.1441,  0.1658, -0.0016,  ..., -0.0394, -0.1156,  0.1525],
        [ 0.1496,  0.1661, -0.1636,  ..., -0.3794, -0.1345,  0.0840],
        [ 0.1438,  0.1154, -0.1784,  ..., -0.2873,  0.0075, -0.2276],
        ...,
        [ 0.0880,  0.1487, -0.1796,  ..., -0.1099,  0.0612, -0.0129],
        [ 0.0730,  0.1332, -0.0952,  ..., -0.0871, -0.2447,  0.0098],
        [ 0.0216,  0.1172, -0.1538,  ..., -0.2286,  0.1072, -0.1076]],
       grad_fn=<AddmmBackward0>)

In [76]:
class GraphInterperterWithGamma:
    """
    Node-by-node interpreter for a ``torch.fx`` GraphModule in which the
    output of every ``call_module`` node is multiplied by a per-node scalar
    "gamma".

    ``self.gammas`` maps a node's name (e.g. ``'branch1'``) to its gamma.
    Every gamma starts at 1.0, so a fresh instance reproduces the plain
    graph; setting a gamma to 0.0 removes that module's contribution.
    """
    def __init__(self, mod):
        # `mod` must expose `.graph` and `.named_modules()`, e.g. the
        # GraphModule returned by `torch.fx.symbolic_trace`.
        self.mod = mod
        self.graph = mod.graph
        # qualified name -> submodule, used to resolve `call_module` targets
        self.modules = dict(self.mod.named_modules())
        # One gamma per call_module node, keyed by str(node) (the node name).
        # The str key is only a convenience for this demo; a real method would
        # key by the node itself and store a tensor / nn.Parameter instead.
        self.gammas = {}
        for node in self.graph.nodes:
            if node.op == 'call_module':
                self.gammas[str(node)] = 1.0

    def __call__(self, *args):
        """
        Run the graph on positional ``args`` (consumed one per ``placeholder``
        node, in graph order), scaling each module call by its gamma, and
        return the value(s) the ``output`` node refers to.

        Raises:
            RuntimeError: a ``get_attr`` node names a missing attribute, or a
                node has an op this interpreter does not handle.
        """
        args_iter = iter(args)
        # node name -> computed value (tensors or other Python objects)
        env: Dict[str, object] = {}

        def load_arg(a):
            # Replace every Node inside `a` (possibly nested in tuples, lists
            # or dicts) with the value already computed for it.
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target: str):
            # Resolve a dotted path such as "sub.weight" starting at self.mod.
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    # Include the failing atom so the message names the missing path.
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                # First arg is the receiver; the rest are the method's arguments.
                self_obj, *method_args = load_arg(node.args)
                method_kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*method_args, **method_kwargs)
            elif node.op == 'call_module':
                module_out = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
                result = module_out * self.gammas[str(node)]
            elif node.op == 'output':
                # The output node's single argument describes the return value;
                # it need not be the node computed most recently.
                return load_arg(node.args[0])
            else:
                # Fail loudly instead of silently reusing the previous result.
                raise RuntimeError(f"Unsupported node op {node.op!r} in node {node.name}")

            env[node.name] = result
        

In [77]:
# Wrap the same traced GraphModule in the gamma-scaled interpreter (all gammas start at 1.0).
gamma_graph = GraphInterperterWithGamma(mod=traced_graph)

In [78]:
# With every gamma at its initial 1.0 the scaled interpreter must match the eager model.
baseline_out = gamma_graph(x)
torch.testing.assert_close(baseline_out, model(x))
baseline_out

tensor([[-0.1441,  0.1658, -0.0016,  ..., -0.0394, -0.1156,  0.1525],
        [ 0.1496,  0.1661, -0.1636,  ..., -0.3794, -0.1345,  0.0840],
        [ 0.1438,  0.1154, -0.1784,  ..., -0.2873,  0.0075, -0.2276],
        ...,
        [ 0.0880,  0.1487, -0.1796,  ..., -0.1099,  0.0612, -0.0129],
        [ 0.0730,  0.1332, -0.0952,  ..., -0.0871, -0.2447,  0.0098],
        [ 0.0216,  0.1172, -0.1538,  ..., -0.2286,  0.1072, -0.1076]],
       grad_fn=<MulBackward0>)

In [79]:
# Ablate branch1 by scaling its output to zero. Check the key first: a typo
# would otherwise silently add an unused gamma and change nothing.
assert 'branch1' in gamma_graph.gammas, f"unknown gamma key; available: {sorted(gamma_graph.gammas)}"
gamma_graph.gammas['branch1'] = 0.0

In [80]:
# With branch1 zeroed, only branch2 feeds the add, so the output should be head(branch2(x)).
ablated_out = gamma_graph(x)
torch.testing.assert_close(
    ablated_out,
    gamma_graph.modules['head'](gamma_graph.modules['branch2'](x)),
)
ablated_out

tensor([[ 0.1519,  0.0631, -0.0611,  ...,  0.4401, -0.3220,  0.1663],
        [ 0.2076, -0.0517, -0.0655,  ...,  0.1796, -0.4026, -0.0234],
        [ 0.1672,  0.0546, -0.1082,  ...,  0.1515, -0.2221, -0.1993],
        ...,
        [ 0.1904, -0.0106, -0.2424,  ...,  0.3191, -0.2320, -0.0055],
        [ 0.0705,  0.0832, -0.1698,  ...,  0.1545, -0.2700, -0.1007],
        [ 0.1491, -0.0325, -0.0811,  ...,  0.3080, -0.1795, -0.1576]],
       grad_fn=<MulBackward0>)