In [1]:
import torch_serialize
import torch

In [2]:
class SubModule(torch.nn.Module):
    """Small custom block: a 3 -> 2 affine layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # Registration order (fc, then relu) determines the order of
        # named_children() and of state_dict() entries, so it is preserved.
        self.fc = torch.nn.Linear(3, 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``relu(fc(x))``.

        Args:
            x: tensor whose last dimension is 3 (required by ``Linear(3, 2)``).

        Returns:
            Tensor with the same leading dims as ``x`` and a last dim of 2,
            with all negative values clamped to zero.
        """
        hidden = self.fc(x)
        return self.relu(hidden)

class HierarchicalModel(torch.nn.Module):
    """Toy two-level model: a custom ``SubModule`` followed by a conv block.

    ``block1`` is a user-defined module and ``block2`` a standard
    ``torch.nn.Sequential``. Tracing ``forward`` therefore yields
    ``call_module`` nodes for both kinds of child, one ``call_function`` node
    for ``x + 3``, and two ``call_method`` nodes for the ``unsqueeze`` calls.
    """

    def __init__(self):
        super().__init__()
        self.block1 = SubModule()
        self.block2 = torch.nn.Sequential(
            # forward() feeds this block a (B, 1, 1, 2) tensor, i.e. a 1 x 2
            # image, which is smaller than the 3x3 kernel. Without padding=1
            # every real forward pass raised "Kernel size can't be greater
            # than actual input size". Padding leaves the weight shape
            # (16, 1, 3, 3) and the traced graph unchanged.
            torch.nn.Conv2d(1, 16, kernel_size=3, padding=1),
            torch.nn.ReLU()
        )

    def forward(self, x):
        """Run ``block1`` on ``x + 3``, then ``block2`` on its image-shaped output.

        Args:
            x: input tensor. Assumed to be (batch, 3) to match block1's
                ``Linear(3, 2)``; TODO confirm. Other ranks produce a tensor
                that the Conv2d in block2 rejects.

        Returns:
            Tensor of shape (batch, 16, 1, 2) for a (batch, 3) input.
        """
        out1 = self.block1(x + 3)
        # (B, 2) -> (B, 1, 1, 2): insert channel and height dims for Conv2d.
        out2 = self.block2(out1.unsqueeze(1).unsqueeze(2))
        return out2

# Build a fresh model and serialize its graph (including nested child modules)
# to JSON; serialize_graph also returns the serialized structure as a dict,
# which is inspected in the next cell.
model = HierarchicalModel()
data = torch_serialize.serialize_graph(
    model,
    "hierarchical_model.json",
)

Hierarchical model serialized to hierarchical_model.json


In [3]:
# Display the structure returned by serialize_graph: the root module's traced
# graph nodes under 'graph', plus nested entries for child modules under
# 'children'.
data

{'name': 'root',
 'is_standard_nn': False,
 'graph': [{'name': 'x',
   'op': 'placeholder',
   'target': 'x',
   'args': [],
   'kwargs': {}},
  {'name': 'add',
   'op': 'call_function',
   'target': '<built-in function add>',
   'args': ['x', '3'],
   'kwargs': {}},
  {'name': 'block1',
   'op': 'call_module',
   'target': 'block1',
   'args': ['add'],
   'kwargs': {}},
  {'name': 'unsqueeze',
   'op': 'call_method',
   'target': 'unsqueeze',
   'args': ['block1', '1'],
   'kwargs': {}},
  {'name': 'unsqueeze_1',
   'op': 'call_method',
   'target': 'unsqueeze',
   'args': ['unsqueeze', '2'],
   'kwargs': {}},
  {'name': 'block2',
   'op': 'call_module',
   'target': 'block2',
   'args': ['unsqueeze_1'],
   'kwargs': {}},
  {'name': 'output',
   'op': 'output',
   'target': 'output',
   'args': ['block2'],
   'kwargs': {}}],
 'children': {'block1': {'name': 'block1',
   'is_standard_nn': False,
   'graph': [{'name': 'x',
     'op': 'placeholder',
     'target': 'x',
     'args': [],
 

In [29]:
# Trace the model directly and inspect the second node in graph order: the
# `x + 3` node. Here its target is the actual `operator.add` function object,
# whereas the serialized JSON above stores it only as a string.
traced_graph = torch_serialize.BasicTracer().trace(model)
graph_nodes = list(traced_graph.nodes)
node = graph_nodes[1]
node.target

<function _operator.add(a, b, /)>

In [None]:
model.state_d