
Key error when a module has a list of submodules #99

Open
TCLaurentiu opened this issue Aug 19, 2023 · 0 comments


TCLaurentiu commented Aug 19, 2023

Edit: The issue was that I was using a Python list when I should have been using an nn.ModuleList from PyTorch. The Python list was also causing other issues with PyTorch itself when trying to use the GPU, so as far as I'm concerned this isn't much of a concern for torchview.
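For illustration, here is a minimal sketch of the fix described in the edit note, applied to the BuggyModule reproduction from the bug description. The class name FixedModule and the attribute name convs are hypothetical. The attribute is renamed rather than kept as modules because nn.Module already defines a modules() method under that name. The sketch checks only that the layer is now registered and that the forward pass works; it does not call torchview.

```python
import torch
import torch.nn as nn


class FixedModule(nn.Module):
    """Hypothetical fixed version of BuggyModule: the layer lives in an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        # Renamed from `modules`, which is already the name of nn.Module.modules().
        self.convs = nn.ModuleList([nn.Conv2d(3, 4, 3)])

    def forward(self, x):
        return self.convs[0](x)


net = FixedModule()
out = net(torch.randn(1, 3, 100, 100))
print(out.shape)  # torch.Size([1, 4, 98, 98])

# The Conv2d now shows up in PyTorch's module tree.
print([name for name, _ in net.named_modules()])  # ['', 'convs', 'convs.0']
```

With the layer registered this way, the same draw_graph(net, input_size=(1, 3, 100, 100), device="meta") call can be retried. According to the edit note, switching to ModuleList resolved the problem.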

Describe the bug
I was attempting to implement a U-Net and visualize it with torchview. The network itself seemed to work, in that I could pass a tensor through it and get output back, but when I tried to use torchview, I received a KeyError. The following is the shortest code I could come up with that reproduces the error:

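```python
import torch
import torch.nn as nn
from torchview import draw_graph

class BuggyModule(nn.Module):

    def __init__(self):
        super(BuggyModule, self).__init__()
        # A plain Python list: the Conv2d inside is NOT registered as a submodule.
        # (This attribute name also shadows the nn.Module.modules() method.)
        self.modules = [nn.Conv2d(3, 4, 3)]

    def forward(self, x):
        return self.modules[0](x)

net = BuggyModule()

# Raises a KeyError inside torchview's render_edges()
draw_graph(net, input_size=(1, 3, 100, 100), device="meta")
```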

The error that I received is:

```
KeyError                                  Traceback (most recent call last)

<ipython-input-16-c26d126403d3> in <cell line: 3>()
      1 net = BuggyModule()
      2 
----> 3 draw_graph(net, input_size=(1, 3, 100, 100), device = "meta")

2 frames

/usr/local/lib/python3.10/dist-packages/torchview/computation_graph.py in render_edges(self)
    137         edge_counter: dict[tuple[int, int], int] = {}
    138         for tail, head in self.edge_list:
--> 139             edge_id = self.id_dict[tail.node_id], self.id_dict[head.node_id]
    140             edge_counter[edge_id] = edge_counter.get(edge_id, 0) + 1
    141             self.add_edge(edge_id, edge_counter[edge_id])

KeyError: '139727746724816'
```

It seems to be related to the list of submodules; however, I was able to get my code working while still using such a list. I have uploaded the faulty code to Google Colab.
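To show what the plain list changes from PyTorch's point of view, here is a small sketch comparing the two containers. WithList and WithModuleList are hypothetical classes, not taken from the report. A layer stored in a Python list is not registered as a child module, so it is missing from parameters() and state_dict(), and Module.to() skips it. The same applies to .cuda() and .to("cuda"), which is consistent with the GPU problems mentioned in the edit note. Whether this is exactly what trips up torchview's render_edges() is not established here.

```python
import torch
import torch.nn as nn


class WithList(nn.Module):
    """Hypothetical example: submodule kept in a plain Python list."""

    def __init__(self):
        super().__init__()
        self.layers = [nn.Conv2d(3, 4, 3)]  # invisible to nn.Module bookkeeping


class WithModuleList(nn.Module):
    """Hypothetical example: the same submodule kept in an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Conv2d(3, 4, 3)])  # registered as a child


for cls in (WithList, WithModuleList):
    m = cls()
    n_params = sum(p.numel() for p in m.parameters())
    print(cls.__name__, n_params, list(m.state_dict().keys()))
# WithList 0 []
# WithModuleList 112 ['layers.0.weight', 'layers.0.bias']

# .to() (like .cuda()) only converts registered parameters and buffers.
list_model = WithList().to(torch.float64)
mlist_model = WithModuleList().to(torch.float64)
print(list_model.layers[0].weight.dtype)   # torch.float32 (left untouched)
print(mlist_model.layers[0].weight.dtype)  # torch.float64
```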

I have also uploaded my original code that was showing the bug here and the corrected code here. Note that the only difference is in the forward method of ContractionModule. Also note that these contain a lot of code that is probably unrelated, unlike the first one, which contains only what is needed to reproduce the issue.

To Reproduce
Steps to reproduce the behavior:

  1. Run the above code
  2. A KeyError is raised on the draw_graph() line

Expected behavior
A graph of the network should be displayed on my screen.

Screenshots / Text
Provided above

Additional context
