Names of Modules? #103

Open
Erotemic opened this issue Sep 2, 2023 · 1 comment

Erotemic commented Sep 2, 2023

Is your feature request related to a problem? Please describe.

I'm looking for a way to visualize the information flow of a network in terms of nested module names. I.e. if we extract torch modules of specific types like this:

    import torch

    def model_layers(model):
        """ Extract named "leaf" layers from a module """
        stack = [('', '', model)]
        while stack:
            prefix, basename, item = stack.pop()
            name = '.'.join([p for p in [prefix, basename] if p])
            if isinstance(item, torch.nn.modules.conv._ConvNd):
                yield name, item
            elif isinstance(item, torch.nn.modules.batchnorm._BatchNorm):
                yield name, item
            elif hasattr(item, 'reset_parameters'):
                yield name, item

            child_prefix = name
            for child_basename, child_item in list(item.named_children())[::-1]:
                stack.append((child_prefix, child_basename, child_item))

which for torchvision.models.resnet18 looks like:

['conv1', 'bn1', 'layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.conv2', 'layer1.0.bn2', 'layer1.1.conv1', 'layer1.1.bn1', 'layer1.1.conv2', 'layer1.1.bn2', 'layer2.0.conv1', 'layer2.0.bn1', 'layer2.0.conv2', 'layer2.0.bn2', 'layer2.0.downsample.0', 'layer2.0.downsample.1', 'layer2.1.conv1', 'layer2.1.bn1', 'layer2.1.conv2', 'layer2.1.bn2', 'layer3.0.conv1', 'layer3.0.bn1', 'layer3.0.conv2', 'layer3.0.bn2', 'layer3.0.downsample.0', 'layer3.0.downsample.1', 'layer3.1.conv1', 'layer3.1.bn1', 'layer3.1.conv2', 'layer3.1.bn2', 'layer4.0.conv1', 'layer4.0.bn1', 'layer4.0.conv2', 'layer4.0.bn2', 'layer4.0.downsample.0', 'layer4.0.downsample.1', 'layer4.1.conv1', 'layer4.1.bn1', 'layer4.1.conv2', 'layer4.1.bn2', 'fc']

I would like to be able to determine which layers are conceptually connected (i.e. whether the outputs of one layer eventually make it to the inputs of another layer).
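
A minimal sketch of what I mean by "conceptually connected", assuming the layer names ended up as nodes in a plain networkx DiGraph (a hypothetical graph, not something torchview produces today):

    import networkx as nx

    # Toy stand-in for a traced graph keyed by dotted layer names
    g = nx.DiGraph([('conv1', 'bn1'), ('bn1', 'layer1.0.conv1')])

    # "Conceptually connected" is then just reachability
    assert nx.has_path(g, 'conv1', 'layer1.0.conv1')
    assert not nx.has_path(g, 'layer1.0.conv1', 'conv1')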

Describe the solution you'd like

Currently when I do something like:

        import torchvision
        from torchview import draw_graph
        net = torchvision.models.resnet18()
        model_graph = draw_graph(net, input_size=(2, 3, 224, 224), device='meta')

I don't see any recorded layer names. I'm wondering if it's possible in the graph collection step (https://github.com/mert-kurttutan/torchview/blob/main/torchview/computation_graph.py#L188) to also associate layer names with nodes as they are extracted.
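
For what it's worth, the kind of association I have in mind can be sketched with the standard named_modules() mapping (module instance -> dotted name); the question is whether torchview could record something like this on each node while tracing:

    import torchvision

    net = torchvision.models.resnet18()
    # Map each module instance to its dotted name in the module tree
    id_to_name = {id(m): name for name, m in net.named_modules() if name}
    assert id_to_name[id(net.layer1[0].conv1)] == 'layer1.0.conv1'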

Describe alternatives you've considered

I've tried doing this myself, but there are a lot of corner cases. This library seems to have the best approach to torch graph extraction I've seen so far.

Additional context

I raised a similar issue on the torch discussion page: https://discuss.pytorch.org/t/tracing-a-graph-of-torch-layers/187615

Erotemic commented Sep 2, 2023

An MWE of something close to what I want is:

    import torchvision
    from torchview import draw_graph
    import torch
    import networkx as nx

    def model_layers(model):
        """ Extract named "leaf" layers from a module """
        stack = [('', '', model)]
        while stack:
            prefix, basename, item = stack.pop()
            name = '.'.join([p for p in [prefix, basename] if p])
            if isinstance(item, torch.nn.modules.conv._ConvNd):
                yield name, item
            elif isinstance(item, torch.nn.modules.batchnorm._BatchNorm):
                yield name, item
            elif hasattr(item, 'reset_parameters'):
                yield name, item

            child_prefix = name
            for child_basename, child_item in list(item.named_children())[::-1]:
                stack.append((child_prefix, child_basename, child_item))

    # Create example network
    net = torchvision.models.resnet18()
    model_graph = draw_graph(net, input_size=(2, 3, 224, 224), device='meta')

    # Remember the dotted layer name associated with each torch.nn.Module
    # instance. Usually an instance will have just one name associated with
    # it, but it could have more than one.
    from collections import defaultdict
    named_layers = list(model_layers(net))
    id_to_names = defaultdict(list)
    for name, layer in named_layers:
        layer_id = id(layer)
        id_to_names[layer_id].append(name)

    def make_label(n, data):
        """ Create a nice printable label """
        n_id = id(n)
        n_id_str = str(n_id)
        parts = []
        if 'layer_name' in data:
            parts.append(data['layer_name'] + ':')
        parts.append(n.name)
        if n_id_str in model_graph.id_dict:
            idx = model_graph.id_dict[n_id_str]
            parts.append(f':{idx}')

        if n_id in id_to_names:
            # id_to_names values are lists of names, so join them for display
            parts.append(' ' + ','.join(id_to_names[n_id]))

        label = ''.join(parts)
        return label

    # Build a networkx version of the torchview model graph
    graph = nx.DiGraph()
    for node in model_graph.node_set:
        graph.add_node(node)

    for u, v in model_graph.edge_list:
        u_id = id(u)
        v_id = id(v)
        graph.add_edge(u_id, v_id)
        graph.nodes[u_id]['compute_node'] = u
        graph.nodes[v_id]['compute_node'] = v

    # Enrich each node with more info
    for n_id, data in graph.nodes(data=True):
        if 'compute_node' in data:
            n = data['compute_node']
            if hasattr(n, 'compute_unit_id'):
                if n.compute_unit_id in id_to_names:
                    layer_names = id_to_names[n.compute_unit_id]
                    if len(layer_names) == 1:
                        data['layer_name'] = layer_names[0]
                    else:
                        data['layer_names'] = layer_names
            data['label'] = make_label(n, data)

    nx.write_network_text(graph, vertical_chains=1)
    # model_graph.visual_graph.view()

Produces:

╟── 139679377001936
╙── auxiliary-tensor
    ╽
    conv1:Conv2d:1
    ╽
    bn1:BatchNorm2d:2
    ╽
    ReLU:3
    ╽
    MaxPool2d:4
    ├─╼ layer1.0.conv1:Conv2d:5
    │   ╽
    │   layer1.0.bn1:BatchNorm2d:6
    │   ╽
    │   ReLU:7
    │   ╽
    │   layer1.0.conv2:Conv2d:8
    │   ╽
    │   layer1.0.bn2:BatchNorm2d:9
    │   ╽
    │   add_:10 ╾ MaxPool2d:4
    │   ╽
    │   ReLU:11
    │   ├─╼ layer1.1.conv1:Conv2d:12
    │   │   ╽
    │   │   layer1.1.bn1:BatchNorm2d:13
    │   │   ╽
    │   │   ReLU:14
    │   │   ╽
    │   │   layer1.1.conv2:Conv2d:15
    │   │   ╽
    │   │   layer1.1.bn2:BatchNorm2d:16
    │   │   ╽
    │   │   add_:17 ╾ ReLU:11
    │   │   ╽
    │   │   ReLU:18
    │   │   ├─╼ layer2.0.conv1:Conv2d:19
    │   │   │   ╽
    │   │   │   layer2.0.bn1:BatchNorm2d:20
    │   │   │   ╽
    │   │   │   ReLU:21
    │   │   │   ╽
    │   │   │   layer2.0.conv2:Conv2d:22
    │   │   │   ╽
    │   │   │   layer2.0.bn2:BatchNorm2d:23
    │   │   │   ╽
    │   │   │   add_:25 ╾ Sequential:24
    │   │   │   ╽
    │   │   │   ReLU:26
    │   │   │   ├─╼ layer2.1.conv1:Conv2d:27
    │   │   │   │   ╽
    │   │   │   │   layer2.1.bn1:BatchNorm2d:28
    │   │   │   │   ╽
    │   │   │   │   ReLU:29
    │   │   │   │   ╽
    │   │   │   │   layer2.1.conv2:Conv2d:30
    │   │   │   │   ╽
    │   │   │   │   layer2.1.bn2:BatchNorm2d:31
    │   │   │   │   ╽
    │   │   │   │   add_:32 ╾ ReLU:26
    │   │   │   │   ╽
    │   │   │   │   ReLU:33
    │   │   │   │   ├─╼ layer3.0.conv1:Conv2d:34
    │   │   │   │   │   ╽
    │   │   │   │   │   layer3.0.bn1:BatchNorm2d:35
    │   │   │   │   │   ╽
    │   │   │   │   │   ReLU:36
    │   │   │   │   │   ╽
    │   │   │   │   │   layer3.0.conv2:Conv2d:37
    │   │   │   │   │   ╽
    │   │   │   │   │   layer3.0.bn2:BatchNorm2d:38
    │   │   │   │   │   ╽
    │   │   │   │   │   add_:40 ╾ Sequential:39
    │   │   │   │   │   ╽
    │   │   │   │   │   ReLU:41
    │   │   │   │   │   ├─╼ layer3.1.conv1:Conv2d:42
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   layer3.1.bn1:BatchNorm2d:43
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   ReLU:44
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   layer3.1.conv2:Conv2d:45
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   layer3.1.bn2:BatchNorm2d:46
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   add_:47 ╾ ReLU:41
    │   │   │   │   │   │   ╽
    │   │   │   │   │   │   ReLU:48
    │   │   │   │   │   │   ├─╼ layer4.0.conv1:Conv2d:49
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   layer4.0.bn1:BatchNorm2d:50
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   ReLU:51
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   layer4.0.conv2:Conv2d:52
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   layer4.0.bn2:BatchNorm2d:53
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   add_:55 ╾ Sequential:54
    │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   ReLU:56
    │   │   │   │   │   │   │   ├─╼ layer4.1.conv1:Conv2d:57
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   layer4.1.bn1:BatchNorm2d:58
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   ReLU:59
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   layer4.1.conv2:Conv2d:60
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   layer4.1.bn2:BatchNorm2d:61
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   add_:62 ╾ ReLU:56
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   ReLU:63
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   AdaptiveAvgPool2d:64
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   flatten:65
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   fc:Linear:66
    │   │   │   │   │   │   │   │   ╽
    │   │   │   │   │   │   │   │   output-tensor:67
    │   │   │   │   │   │   │   └─╼  ...
    │   │   │   │   │   │   └─╼ Sequential:54
    │   │   │   │   │   │       └─╼  ...
    │   │   │   │   │   └─╼  ...
    │   │   │   │   └─╼ Sequential:39
    │   │   │   │       └─╼  ...
    │   │   │   └─╼  ...
    │   │   └─╼ Sequential:24
    │   │       └─╼  ...
    │   └─╼  ...
    └─╼  ...

You can see here that I've been able to associate many of the nodes with their original layer names. However, my solution simply assumes each module instance is only used once. I think a correct solution would need to know which module attribute was the caller, which gets tricky when the same module instance is assigned to multiple attributes.
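
For example, this is the sort of corner case I mean (a single module instance registered under two attribute names, so an id()-based lookup can't tell which name actually made a given call):

    import torch

    class SharesOneLinear(torch.nn.Module):
        """ Hypothetical example: one Linear instance registered under two names """
        def __init__(self):
            super().__init__()
            self.first = torch.nn.Linear(8, 8)
            self.second = self.first  # same instance, different attribute

        def forward(self, x):
            return self.second(self.first(x))

    net = SharesOneLinear()
    # Both dotted paths resolve to the same object, so id() alone cannot
    # disambiguate which attribute name was "the caller" during tracing.
    assert net.get_submodule('first') is net.get_submodule('second')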

To further process this into what I'm actually interested in, I do something like this:

    # Now that we have a graph where a subset of nodes correspond to known
    # layers, we can postprocess it to only show effective connections between
    # the layers.

    # Determine which nodes have associated layer names
    remove_ids = []
    keep_ids = []
    for n_id, data in graph.nodes(data=True):
        if 'layer_name' in data:
            keep_ids.append(n_id)
        else:
            remove_ids.append(n_id)

    import ubelt as ub
    topo_order = ub.OrderedSet(nx.topological_sort(graph))
    keep_topo_order = (topo_order & keep_ids)

    # Find the nearest ancestor that we want to view and collapse the node we
    # dont care about into it. Do a final relabeling to keep the original node
    # ids where possible.
    collapseables = defaultdict(list)
    for n in remove_ids:
        valid_prev_nodes = keep_topo_order & set(nx.ancestors(graph, n))
        if valid_prev_nodes:
            p = valid_prev_nodes[-1]
            collapseables[p].append(n)
    from networkx.algorithms.connectivity.edge_augmentation import collapse
    grouped_nodes = []
    for p, vs in collapseables.items():
        grouped_nodes.append([p, *vs])
    g2 = collapse(graph, grouped_nodes)
    relabel = {n: n for n in g2.nodes}
    new_to_olds = ub.udict(g2.graph['mapping']).invert(unique_vals=0)
    for new, olds in new_to_olds.items():
        if len(olds) == 1:
            old = ub.peek(olds)
            relabel[new] = old
        else:
            keep_olds = keep_topo_order & olds
            old = ub.peek(keep_olds)
            relabel[new] = old
    g3 = nx.relabel_nodes(g2, relabel)

    def transfer_data(g_dst, g_src):
        for n in set(g_dst.nodes) & set(g_src.nodes):
            g_dst.nodes[n].update(g_src.nodes[n])

    # Show the collapsed graph
    transfer_data(g3, graph)
    nx.write_network_text(g3, vertical_chains=1)

    # Further reduce the graph to remove skip connection information
    g4 = nx.transitive_reduction(g3)
    transfer_data(g4, graph)
    nx.write_network_text(g4, vertical_chains=1)

    g2 = nx.transitive_closure(graph)
    g2 = nx.transitive_reduction(g2)
    transfer_data(g2, graph)

This shows the graph where the intermediate functional nodes have been collapsed into one of their parent layers:

╟── auxiliary-tensor
╎   ╽
╎   conv1:Conv2d:1
╎   ╽
╎   bn1:BatchNorm2d:2
╎   ├─╼ layer1.0.conv1:Conv2d:5
╎   │   ╽
╎   │   layer1.0.bn1:BatchNorm2d:6
╎   │   ╽
╎   │   layer1.0.conv2:Conv2d:8
╎   │   ╽
╎   │   layer1.0.bn2:BatchNorm2d:9 ╾ bn1:BatchNorm2d:2
╎   │   ├─╼ layer1.1.conv1:Conv2d:12
╎   │   │   ╽
╎   │   │   layer1.1.bn1:BatchNorm2d:13
╎   │   │   ╽
╎   │   │   layer1.1.conv2:Conv2d:15
╎   │   │   ╽
╎   │   │   layer1.1.bn2:BatchNorm2d:16 ╾ layer1.0.bn2:BatchNorm2d:9
╎   │   │   ├─╼ layer2.0.bn2:BatchNorm2d:23 ╾ layer2.0.conv2:Conv2d:22
╎   │   │   │   ├─╼ layer2.1.conv1:Conv2d:27
╎   │   │   │   │   ╽
╎   │   │   │   │   layer2.1.bn1:BatchNorm2d:28
╎   │   │   │   │   ╽
╎   │   │   │   │   layer2.1.conv2:Conv2d:30
╎   │   │   │   │   ╽
╎   │   │   │   │   layer2.1.bn2:BatchNorm2d:31 ╾ layer2.0.bn2:BatchNorm2d:23
╎   │   │   │   │   ├─╼ layer3.0.bn2:BatchNorm2d:38 ╾ layer3.0.conv2:Conv2d:37
╎   │   │   │   │   │   ├─╼ layer3.1.conv1:Conv2d:42
╎   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   layer3.1.bn1:BatchNorm2d:43
╎   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   layer3.1.conv2:Conv2d:45
╎   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   layer3.1.bn2:BatchNorm2d:46 ╾ layer3.0.bn2:BatchNorm2d:38
╎   │   │   │   │   │   │   ├─╼ layer4.0.bn2:BatchNorm2d:53 ╾ layer4.0.conv2:Conv2d:52
╎   │   │   │   │   │   │   │   ├─╼ layer4.1.conv1:Conv2d:57
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   layer4.1.bn1:BatchNorm2d:58
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   layer4.1.conv2:Conv2d:60
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   layer4.1.bn2:BatchNorm2d:61 ╾ layer4.0.bn2:BatchNorm2d:53
╎   │   │   │   │   │   │   │   │   ╽
╎   │   │   │   │   │   │   │   │   fc:Linear:66
╎   │   │   │   │   │   │   │   └─╼  ...
╎   │   │   │   │   │   │   └─╼ layer4.0.conv1:Conv2d:49
╎   │   │   │   │   │   │       ╽
╎   │   │   │   │   │   │       layer4.0.bn1:BatchNorm2d:50
╎   │   │   │   │   │   │       ╽
╎   │   │   │   │   │   │       layer4.0.conv2:Conv2d:52
╎   │   │   │   │   │   │       └─╼  ...
╎   │   │   │   │   │   └─╼  ...
╎   │   │   │   │   └─╼ layer3.0.conv1:Conv2d:34
╎   │   │   │   │       ╽
╎   │   │   │   │       layer3.0.bn1:BatchNorm2d:35
╎   │   │   │   │       ╽
╎   │   │   │   │       layer3.0.conv2:Conv2d:37
╎   │   │   │   │       └─╼  ...
╎   │   │   │   └─╼  ...
╎   │   │   └─╼ layer2.0.conv1:Conv2d:19
╎   │   │       ╽
╎   │   │       layer2.0.bn1:BatchNorm2d:20
╎   │   │       ╽
╎   │   │       layer2.0.conv2:Conv2d:22
╎   │   │       └─╼  ...
╎   │   └─╼  ...
╎   └─╼  ...
╙── 139679377001936
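
(For reference, the collapse helper used above merges each listed group of nodes into a single new node and records the old-to-new labels in the 'mapping' graph attribute; a toy illustration:)

    import networkx as nx
    from networkx.algorithms.connectivity.edge_augmentation import collapse

    g = nx.DiGraph([('a', 'b'), ('b', 'c'), ('c', 'd')])
    g2 = collapse(g, [['b', 'c']])   # merge 'b' and 'c' into one node
    print(g2.graph['mapping'])       # old label -> new integer label
    print(list(g2.edges()))          # the chain is now a -> {b, c} -> d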

And finally, what I ultimately want to see: the transitive reduction of this graph:

╟── auxiliary-tensor
╎   ╽
╎   conv1:Conv2d:1
╎   ╽
╎   bn1:BatchNorm2d:2
╎   ╽
╎   layer1.0.conv1:Conv2d:5
╎   ╽
╎   layer1.0.bn1:BatchNorm2d:6
╎   ╽
╎   layer1.0.conv2:Conv2d:8
╎   ╽
╎   layer1.0.bn2:BatchNorm2d:9
╎   ╽
╎   layer1.1.conv1:Conv2d:12
╎   ╽
╎   layer1.1.bn1:BatchNorm2d:13
╎   ╽
╎   layer1.1.conv2:Conv2d:15
╎   ╽
╎   layer1.1.bn2:BatchNorm2d:16
╎   ╽
╎   layer2.0.conv1:Conv2d:19
╎   ╽
╎   layer2.0.bn1:BatchNorm2d:20
╎   ╽
╎   layer2.0.conv2:Conv2d:22
╎   ╽
╎   layer2.0.bn2:BatchNorm2d:23
╎   ╽
╎   layer2.1.conv1:Conv2d:27
╎   ╽
╎   layer2.1.bn1:BatchNorm2d:28
╎   ╽
╎   layer2.1.conv2:Conv2d:30
╎   ╽
╎   layer2.1.bn2:BatchNorm2d:31
╎   ╽
╎   layer3.0.conv1:Conv2d:34
╎   ╽
╎   layer3.0.bn1:BatchNorm2d:35
╎   ╽
╎   layer3.0.conv2:Conv2d:37
╎   ╽
╎   layer3.0.bn2:BatchNorm2d:38
╎   ╽
╎   layer3.1.conv1:Conv2d:42
╎   ╽
╎   layer3.1.bn1:BatchNorm2d:43
╎   ╽
╎   layer3.1.conv2:Conv2d:45
╎   ╽
╎   layer3.1.bn2:BatchNorm2d:46
╎   ╽
╎   layer4.0.conv1:Conv2d:49
╎   ╽
╎   layer4.0.bn1:BatchNorm2d:50
╎   ╽
╎   layer4.0.conv2:Conv2d:52
╎   ╽
╎   layer4.0.bn2:BatchNorm2d:53
╎   ╽
╎   layer4.1.conv1:Conv2d:57
╎   ╽
╎   layer4.1.bn1:BatchNorm2d:58
╎   ╽
╎   layer4.1.conv2:Conv2d:60
╎   ╽
╎   layer4.1.bn2:BatchNorm2d:61
╎   ╽
╎   fc:Linear:66
╙── 139679377001936
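
(As a tiny illustration of why the transitive reduction is exactly what I want here: it removes the redundant "skip" edges that are already implied by longer paths, leaving only the direct layer-to-layer flow.)

    import networkx as nx

    # conv -> bn -> relu, plus a redundant skip edge conv -> relu
    g = nx.DiGraph([('conv', 'bn'), ('bn', 'relu'), ('conv', 'relu')])
    tr = nx.transitive_reduction(g)
    print(sorted(tr.edges()))  # [('bn', 'relu'), ('conv', 'bn')] -- skip edge removed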
