Upsample exports to a very big subgraph in dynamo #1533

Open
yuanyao-nv opened this issue May 14, 2024 · 19 comments · Fixed by #1592
Labels: topic: torch_lib (Related to the torch/aten function lib in development)

Comments

@yuanyao-nv

yuanyao-nv commented May 14, 2024

I'm looking at some models in MONAI that involve torch.nn.Upsample. I noticed that the TorchScript exporter converts the Upsample module to a Resize node, whereas the dynamo exporter expands it into a very large subgraph, which hurts performance.

An example is the SegResNet model.

torchscript:
[image]

dynamo:
[image]

Expanding the dynamo subgraph:
[image]

Here's the export script:

import torch
from monai.networks.nets import SegResNet

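# Build the network directly, switch to inference mode, and move it to the GPU.
model = SegResNet(
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).eval().to("cuda")
# Example input: batch of 1, 4 channels, 224 x 224 x 128 volume.
data = torch.randn(1, 4, 224, 224, 128).to("cuda")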

dynamo_export = True
if dynamo_export:
    export_output = torch.onnx.dynamo_export(
        model,
        data,
    )
    export_output.save('Clara_SegResNet_dynamo.onnx')
else:
    torch.onnx.export(model, (data,), 'Clara_SegResNet.onnx')

Relevant versions:
onnx==1.16.0
onnxscript==0.1.0.dev20240513
torch==2.4.0.dev20240513+cu121
monai==1.3.0
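
For anyone trying to reproduce this with the same environment, a version list like the one above can be collected programmatically. This is only a minimal sketch using the standard library's `importlib.metadata`; the helper name `collect_versions` and the `PACKAGES` tuple are made up for illustration, and it assumes the four distributions are installed under the names used in the report.

```python
from importlib import metadata

# Package names taken from the report above; extend the tuple as needed.
PACKAGES = ("onnx", "onnxscript", "torch", "monai")


def collect_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        if version is None:
            print(f"{name}: not installed")
        else:
            print(f"{name}=={version}")
```

Pasting the printed lines into a report makes it easier to tell whether a difference in the exported graph comes from the model or from a nightly build mismatch.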

@yuanyao-nv
Author

I also tried to get more info on which part of dynamo/onnxscript might be responsible for this.
If I run

scripted_model = torch.jit.script(model)
print(scripted_model.graph)

I get this:

graph(%self : __torch__.monai.networks.nets.segresnet.SegResNet,
      %x.1 : Tensor):
  %3 : (Tensor, Tensor[]) = prim::CallMethod[name="encode"](%self, %x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:180:20
  %x.5 : Tensor, %down_x.1 : Tensor[] = prim::TupleUnpack(%3)
   = aten::reverse(%down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:181:8
  %x.9 : Tensor = prim::CallMethod[name="decode"](%self, %x.5, %down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:183:12
  return (%x.9)

If I run

gm, _ = torch._dynamo.export(model)(data)
gm = torch.fx.experimental.proxy_tensor.make_fx(torch.func.functionalize(gm))(data)
gm.print_readable()

I get an error:

Traceback (most recent call last):
  File "/ws/dynamo/0501/export_SegResNet.py", line 31, in <module>
    gm, _ = torch._dynamo.export(model)(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1282, in inner
    dim_constraints.solve()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 1772, in solve
    tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
KeyError: "L['x'].size()[4]"

@justinchuby
Contributor

Thanks for catching this. Very intriguing. Will take a look!

@justinchuby
Contributor

cc @xiaowuhu @fatcat-z

@justinchuby
Contributor

@yuanyao-nv could you obtain the graph module from torch.export.export and post it here?
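
A graph module dump for a model of this size is long, so it can help to tally the ATen calls instead of reading it line by line. The sketch below is an illustration only, not part of any exporter: `tally_aten_ops`, `ATEN_CALL`, and `SAMPLE` are hypothetical names, and `SAMPLE` is a three-line, hand-abbreviated stand-in for `print_readable()` output. It assumes the dump spells calls as `torch.ops.aten.<op>.<overload>(...)`.

```python
import re
from collections import Counter

# Matches calls such as torch.ops.aten.conv3d.default( or torch.ops.aten.add.Tensor(
ATEN_CALL = re.compile(r"torch\.ops\.aten\.(\w+)\.(\w+)\(")


def tally_aten_ops(readable_graph: str) -> Counter:
    """Count each 'op.overload' that is called in a readable graph dump."""
    return Counter(
        f"{op}.{overload}" for op, overload in ATEN_CALL.findall(readable_graph)
    )


# Abbreviated sample in the style of a readable graph dump (type annotations
# and the trailing "x = None" cleanup statements are left out).
SAMPLE = """
conv3d_22 = torch.ops.aten.conv3d.default(add_8, p_up_samples_0_0_conv_weight)
upsample_trilinear3d = torch.ops.aten.upsample_trilinear3d.vec(conv3d_22, None, False, [2.0, 2.0, 2.0])
add_9 = torch.ops.aten.add.Tensor(upsample_trilinear3d, add_4)
"""

if __name__ == "__main__":
    for name, count in tally_aten_ops(SAMPLE).most_common():
        print(f"{count:4d}  {name}")
```

If each Upsample shows up as a single `upsample_trilinear3d` call at this level, that would suggest the expansion happens in a later stage (decomposition or the lowering to ONNX) rather than in the export itself.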

@yuanyao-nv
Author

@justinchuby Is this what you mean?

    exported_program = torch.export.export(model, (data,))
    exported_program._graph_module.print_readable()

which gives

class GraphModule(torch.nn.Module):
    def forward(self, p_convinit_conv_weight: "f32[16, 4, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm1_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm1_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm2_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm2_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_down_layers_1_0_conv_weight: "f32[32, 16, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_down_layers_2_0_conv_weight: "f32[64, 32, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm1_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm1_bias: "f32[64]", 
p_getattr_l__self___down_layers_2___2___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___2___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_down_layers_3_0_conv_weight: "f32[128, 64, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___4___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm2_bias: "f32[128]", 
p_getattr_l__self___down_layers_3___4___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_up_samples_0_0_conv_weight: "f32[64, 128, 1, 1, 1]", p_getattr_l__self___up_layers_0___0___norm1_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm1_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___up_layers_0___0___norm2_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm2_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_up_samples_1_0_conv_weight: "f32[32, 64, 1, 1, 1]", p_getattr_l__self___up_layers_1___0___norm1_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm1_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___up_layers_1___0___norm2_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm2_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_up_samples_2_0_conv_weight: "f32[16, 32, 1, 1, 1]", p_getattr_l__self___up_layers_2___0___norm1_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm1_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___up_layers_2___0___norm2_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm2_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_conv_final_0_weight: "f32[16]", p_conv_final_0_bias: "f32[16]", p_conv_final_2_conv_weight: "f32[3, 16, 1, 1, 1]", p_conv_final_2_conv_bias: "f32[3]", x: "f32[1, 4, 224, 224, 128]"):
        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:157 in encode, code: x = self.convInit(x)
        conv3d: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(x, p_convinit_conv_weight, None, [1, 1, 1], [1, 1, 1]);  x = p_convinit_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:159 in encode, code: x = self.dropout(x)
        feature_dropout: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.feature_dropout.default(conv3d, 0.2, False);  conv3d = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(feature_dropout, 8, p_getattr_l__self___down_layers_0___1___norm1_weight, p_getattr_l__self___down_layers_0___1___norm1_bias);  p_getattr_l__self___down_layers_0___1___norm1_weight = p_getattr_l__self___down_layers_0___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm);  group_norm = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu, p_getattr_l__self___down_layers_0___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu = p_getattr_l__self___down_layers_0___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_1, 8, p_getattr_l__self___down_layers_0___1___norm2_weight, p_getattr_l__self___down_layers_0___1___norm2_bias);  conv3d_1 = p_getattr_l__self___down_layers_0___1___norm2_weight = p_getattr_l__self___down_layers_0___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_1);  group_norm_1 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_1, p_getattr_l__self___down_layers_0___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_1 = p_getattr_l__self___down_layers_0___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_2, feature_dropout);  conv3d_2 = feature_dropout = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(add, p_down_layers_1_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_1_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_3, 8, p_getattr_l__self___down_layers_1___1___norm1_weight, p_getattr_l__self___down_layers_1___1___norm1_bias);  p_getattr_l__self___down_layers_1___1___norm1_weight = p_getattr_l__self___down_layers_1___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_2);  group_norm_2 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_2, p_getattr_l__self___down_layers_1___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_2 = p_getattr_l__self___down_layers_1___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_4, 8, p_getattr_l__self___down_layers_1___1___norm2_weight, p_getattr_l__self___down_layers_1___1___norm2_bias);  conv3d_4 = p_getattr_l__self___down_layers_1___1___norm2_weight = p_getattr_l__self___down_layers_1___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_3);  group_norm_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_3, p_getattr_l__self___down_layers_1___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_3 = p_getattr_l__self___down_layers_1___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_5, conv3d_3);  conv3d_5 = conv3d_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_1, 8, p_getattr_l__self___down_layers_1___2___norm1_weight, p_getattr_l__self___down_layers_1___2___norm1_bias);  p_getattr_l__self___down_layers_1___2___norm1_weight = p_getattr_l__self___down_layers_1___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_4);  group_norm_4 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_6: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_4, p_getattr_l__self___down_layers_1___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_4 = p_getattr_l__self___down_layers_1___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_6, 8, p_getattr_l__self___down_layers_1___2___norm2_weight, p_getattr_l__self___down_layers_1___2___norm2_bias);  conv3d_6 = p_getattr_l__self___down_layers_1___2___norm2_weight = p_getattr_l__self___down_layers_1___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_5);  group_norm_5 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_7: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_5, p_getattr_l__self___down_layers_1___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_5 = p_getattr_l__self___down_layers_1___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_7, add_1);  conv3d_7 = add_1 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_2, p_down_layers_2_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_2_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_8, 8, p_getattr_l__self___down_layers_2___1___norm1_weight, p_getattr_l__self___down_layers_2___1___norm1_bias);  p_getattr_l__self___down_layers_2___1___norm1_weight = p_getattr_l__self___down_layers_2___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_6);  group_norm_6 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_6, p_getattr_l__self___down_layers_2___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_6 = p_getattr_l__self___down_layers_2___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_9, 8, p_getattr_l__self___down_layers_2___1___norm2_weight, p_getattr_l__self___down_layers_2___1___norm2_bias);  conv3d_9 = p_getattr_l__self___down_layers_2___1___norm2_weight = p_getattr_l__self___down_layers_2___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_7);  group_norm_7 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_7, p_getattr_l__self___down_layers_2___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_7 = p_getattr_l__self___down_layers_2___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_3: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_10, conv3d_8);  conv3d_10 = conv3d_8 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_3, 8, p_getattr_l__self___down_layers_2___2___norm1_weight, p_getattr_l__self___down_layers_2___2___norm1_bias);  p_getattr_l__self___down_layers_2___2___norm1_weight = p_getattr_l__self___down_layers_2___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_8);  group_norm_8 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_11: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_8, p_getattr_l__self___down_layers_2___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_8 = p_getattr_l__self___down_layers_2___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_11, 8, p_getattr_l__self___down_layers_2___2___norm2_weight, p_getattr_l__self___down_layers_2___2___norm2_bias);  conv3d_11 = p_getattr_l__self___down_layers_2___2___norm2_weight = p_getattr_l__self___down_layers_2___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_9);  group_norm_9 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_12: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_9, p_getattr_l__self___down_layers_2___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_9 = p_getattr_l__self___down_layers_2___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_4: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_12, add_3);  conv3d_12 = add_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_4, p_down_layers_3_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_3_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_13, 8, p_getattr_l__self___down_layers_3___1___norm1_weight, p_getattr_l__self___down_layers_3___1___norm1_bias);  p_getattr_l__self___down_layers_3___1___norm1_weight = p_getattr_l__self___down_layers_3___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_10);  group_norm_10 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_10, p_getattr_l__self___down_layers_3___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_10 = p_getattr_l__self___down_layers_3___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_14, 8, p_getattr_l__self___down_layers_3___1___norm2_weight, p_getattr_l__self___down_layers_3___1___norm2_bias);  conv3d_14 = p_getattr_l__self___down_layers_3___1___norm2_weight = p_getattr_l__self___down_layers_3___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_11);  group_norm_11 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_11, p_getattr_l__self___down_layers_3___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_11 = p_getattr_l__self___down_layers_3___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_5: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_15, conv3d_13);  conv3d_15 = conv3d_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_5, 8, p_getattr_l__self___down_layers_3___2___norm1_weight, p_getattr_l__self___down_layers_3___2___norm1_bias);  p_getattr_l__self___down_layers_3___2___norm1_weight = p_getattr_l__self___down_layers_3___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_12);  group_norm_12 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_12, p_getattr_l__self___down_layers_3___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_12 = p_getattr_l__self___down_layers_3___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_16, 8, p_getattr_l__self___down_layers_3___2___norm2_weight, p_getattr_l__self___down_layers_3___2___norm2_bias);  conv3d_16 = p_getattr_l__self___down_layers_3___2___norm2_weight = p_getattr_l__self___down_layers_3___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_13);  group_norm_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_13, p_getattr_l__self___down_layers_3___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_13 = p_getattr_l__self___down_layers_3___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_6: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_17, add_5);  conv3d_17 = add_5 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_6, 8, p_getattr_l__self___down_layers_3___3___norm1_weight, p_getattr_l__self___down_layers_3___3___norm1_bias);  p_getattr_l__self___down_layers_3___3___norm1_weight = p_getattr_l__self___down_layers_3___3___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_14);  group_norm_14 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_18: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_14, p_getattr_l__self___down_layers_3___3___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_14 = p_getattr_l__self___down_layers_3___3___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_18, 8, p_getattr_l__self___down_layers_3___3___norm2_weight, p_getattr_l__self___down_layers_3___3___norm2_bias);  conv3d_18 = p_getattr_l__self___down_layers_3___3___norm2_weight = p_getattr_l__self___down_layers_3___3___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_15);  group_norm_15 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_19: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_15, p_getattr_l__self___down_layers_3___3___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_15 = p_getattr_l__self___down_layers_3___3___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_7: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_19, add_6);  conv3d_19 = add_6 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_7, 8, p_getattr_l__self___down_layers_3___4___norm1_weight, p_getattr_l__self___down_layers_3___4___norm1_bias);  p_getattr_l__self___down_layers_3___4___norm1_weight = p_getattr_l__self___down_layers_3___4___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_16);  group_norm_16 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_20: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_16, p_getattr_l__self___down_layers_3___4___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_16 = p_getattr_l__self___down_layers_3___4___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_20, 8, p_getattr_l__self___down_layers_3___4___norm2_weight, p_getattr_l__self___down_layers_3___4___norm2_bias);  conv3d_20 = p_getattr_l__self___down_layers_3___4___norm2_weight = p_getattr_l__self___down_layers_3___4___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_17);  group_norm_17 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_21: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_17, p_getattr_l__self___down_layers_3___4___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_17 = p_getattr_l__self___down_layers_3___4___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_8: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_21, add_7);  conv3d_21 = add_7 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_22: "f32[1, 64, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_8, p_up_samples_0_0_conv_weight);  add_8 = p_up_samples_0_0_conv_weight = None
        upsample_trilinear3d: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_22, None, False, [2.0, 2.0, 2.0]);  conv3d_22 = None
        add_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(upsample_trilinear3d, add_4);  upsample_trilinear3d = add_4 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_9, 8, p_getattr_l__self___up_layers_0___0___norm1_weight, p_getattr_l__self___up_layers_0___0___norm1_bias);  p_getattr_l__self___up_layers_0___0___norm1_weight = p_getattr_l__self___up_layers_0___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_18);  group_norm_18 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_23: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_18, p_getattr_l__self___up_layers_0___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_18 = p_getattr_l__self___up_layers_0___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_23, 8, p_getattr_l__self___up_layers_0___0___norm2_weight, p_getattr_l__self___up_layers_0___0___norm2_bias);  conv3d_23 = p_getattr_l__self___up_layers_0___0___norm2_weight = p_getattr_l__self___up_layers_0___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_19);  group_norm_19 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_24: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_19, p_getattr_l__self___up_layers_0___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_19 = p_getattr_l__self___up_layers_0___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_24, add_9);  conv3d_24 = add_9 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_25: "f32[1, 32, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_10, p_up_samples_1_0_conv_weight);  add_10 = p_up_samples_1_0_conv_weight = None
        upsample_trilinear3d_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_25, None, False, [2.0, 2.0, 2.0]);  conv3d_25 = None
        add_11: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_1, add_2);  upsample_trilinear3d_1 = add_2 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_11, 8, p_getattr_l__self___up_layers_1___0___norm1_weight, p_getattr_l__self___up_layers_1___0___norm1_bias);  p_getattr_l__self___up_layers_1___0___norm1_weight = p_getattr_l__self___up_layers_1___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_20);  group_norm_20 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_26: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_20, p_getattr_l__self___up_layers_1___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_20 = p_getattr_l__self___up_layers_1___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_26, 8, p_getattr_l__self___up_layers_1___0___norm2_weight, p_getattr_l__self___up_layers_1___0___norm2_bias);  conv3d_26 = p_getattr_l__self___up_layers_1___0___norm2_weight = p_getattr_l__self___up_layers_1___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_21);  group_norm_21 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_27: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_21, p_getattr_l__self___up_layers_1___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_21 = p_getattr_l__self___up_layers_1___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_12: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_27, add_11);  conv3d_27 = add_11 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_28: "f32[1, 16, 112, 112, 64]" = torch.ops.aten.conv3d.default(add_12, p_up_samples_2_0_conv_weight);  add_12 = p_up_samples_2_0_conv_weight = None
        upsample_trilinear3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_28, None, False, [2.0, 2.0, 2.0]);  conv3d_28 = None
        add_13: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_2, add);  upsample_trilinear3d_2 = add = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_13, 8, p_getattr_l__self___up_layers_2___0___norm1_weight, p_getattr_l__self___up_layers_2___0___norm1_bias);  p_getattr_l__self___up_layers_2___0___norm1_weight = p_getattr_l__self___up_layers_2___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_22);  group_norm_22 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_29: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_22, p_getattr_l__self___up_layers_2___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_22 = p_getattr_l__self___up_layers_2___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_29, 8, p_getattr_l__self___up_layers_2___0___norm2_weight, p_getattr_l__self___up_layers_2___0___norm2_bias);  conv3d_29 = p_getattr_l__self___up_layers_2___0___norm2_weight = p_getattr_l__self___up_layers_2___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_23);  group_norm_23 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_30: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_23, p_getattr_l__self___up_layers_2___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_23 = p_getattr_l__self___up_layers_2___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_14: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_30, add_13);  conv3d_30 = add_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:175 in decode, code: x = self.conv_final(x)
        group_norm_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_14, 8, p_conv_final_0_weight, p_conv_final_0_bias);  add_14 = p_conv_final_0_weight = p_conv_final_0_bias = None
        relu_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_24);  group_norm_24 = None
        conv3d_31: "f32[1, 3, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_24, p_conv_final_2_conv_weight, p_conv_final_2_conv_bias);  relu_24 = p_conv_final_2_conv_weight = p_conv_final_2_conv_bias = None
        return (conv3d_31,)
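
In the dump above, each upsample still appears as a single `torch.ops.aten.upsample_trilinear3d.vec` call. A quick way to confirm that on a long dump is to count the aten call targets in the printed text. The following is a minimal sketch that only parses text; the helper name `count_aten_ops` and the shortened sample lines are illustrative assumptions, not part of torch.

```python
import re
from collections import Counter

# Matches call targets such as "torch.ops.aten.conv3d.default(" in a readable graph dump.
_ATEN_CALL = re.compile(r"torch\.ops\.aten\.([A-Za-z0-9_]+)\.([A-Za-z0-9_]+)\(")


def count_aten_ops(graph_text: str) -> Counter:
    """Count aten operator calls, keyed as 'name.overload', in print_readable()-style text."""
    return Counter(
        f"{name}.{overload}" for name, overload in _ATEN_CALL.findall(graph_text)
    )


# Two lines abbreviated from the dump (the conv3d argument list is shortened for illustration).
sample = '''
conv3d_22: "f32[1, 64, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_8, weight)
upsample_trilinear3d: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_22, None, False, [2.0, 2.0, 2.0])
'''

counts = count_aten_ops(sample)
for op_name, n in sorted(counts.items()):
    print(op_name, n)
```

If the upsample count is non-zero here but the ONNX file shows no Resize node, the expansion presumably happens after this stage.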

@justinchuby
Contributor

Yes, thank you

@justinchuby
Contributor

That's very strange. If you run torch.onnx.dynamo_export(exported_program, ...), do you get the same graph?

@justinchuby justinchuby added the topic: torch_lib Related to the torch/aten function lib in development label May 15, 2024
@yuanyao-nv
Author

@justinchuby Is this the procedure you're suggesting?

    exported_program = torch.export.export(model, (data,))
    exported_program._graph_module.print_readable()

    export_output = torch.onnx.dynamo_export(exported_program, data)
    export_output.save('Clara_SegResNet_dynamo1.onnx')

The exported Upsample module looks about the same as before: it is still a very big graph.
In addition, the model weights now appear as graph inputs, which adds tens of extra model inputs. This is similar to the bug in pytorch/pytorch#126071.

@justinchuby
Contributor

I don't see any Resize ops, which is puzzling. Could you share the ONNX model? You may remove the weights if it is too big.

@justinchuby
Contributor

I was expecting to see this function:

def _aten_upsample_output_size(

@yuanyao-nv
Author

@justinchuby I uploaded the two versions of the model here: https://drive.google.com/drive/folders/1s1lhKRuG6fOZmD4IjZvN_zlWfIxPB_8w?usp=sharing

@justinchuby
Contributor

It's possible that the upsample op was somehow decomposed by PyTorch. I will look deeper.
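
To give a sense of why a decomposed upsample turns into many nodes, here is a minimal pure-Python sketch of 1-D linear upsampling written as primitive steps (index computation, clamping, gathering two neighbors, blending). It assumes the half-pixel convention used when `align_corners=False`; the function name `linear_upsample_1d` is hypothetical, and this is not PyTorch's actual decomposition code.

```python
import math


def linear_upsample_1d(values, out_size):
    """Linearly resample a 1-D sequence using half-pixel source coordinates."""
    in_size = len(values)
    scale = in_size / out_size
    out = []
    for dst in range(out_size):
        # Map the output index to a fractional source index, clamped at 0.
        src = max((dst + 0.5) * scale - 0.5, 0.0)
        # Indices of the two neighbors, clamped to the valid range.
        lo = min(int(math.floor(src)), in_size - 1)
        hi = min(lo + 1, in_size - 1)
        weight = src - lo
        # Gather both neighbors and blend them.
        out.append(values[lo] * (1.0 - weight) + values[hi] * weight)
    return out


result = linear_upsample_1d([0.0, 1.0, 2.0, 3.0, 4.0], 10)
print(result)
```

A trilinear upsample repeats this index arithmetic and blending along three axes, so expressing it with elementwise and gather-style ops instead of a single Resize node plausibly accounts for a large subgraph.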

@borisfom

I have found that, in the general case, one has to run exported_program.run_decompositions() before applying dynamo_export().
That may in fact fold some operations. @yuanyao-nv, can you try that?

@justinchuby
Contributor

Thanks. We will be creating a series of changes to the exporter to support ExportedPrograms properly, including handling of the weights.

@yuanyao-nv
Author

yuanyao-nv commented May 17, 2024

@borisfom I tried running run_decompositions() but it didn't do anything for this particular subgraph.

    exported_program = torch.export.export(model, args=(data,))
    # Note: run_decompositions() returns a new ExportedProgram; the return value is not used here.
    exported_program.run_decompositions()
    export_output = torch.onnx.dynamo_export(exported_program, data)
    export_output.save('Clara_SegResNet_dynamo.onnx')

@titaiwangms
Contributor

titaiwangms commented Jun 6, 2024

Hi @yuanyao-nv,

This one should be fixed when you call torch.onnx.dynamo_export with an nn.Module. However, if you call torch.export.export first, the op is still decomposed into the big subgraph you saw. This decomposition is forced by dynamo for some reason. Feel free to open an issue like pytorch/pytorch#115883.

cc @gramalingam @justinchuby @xadupre This forced decomposition may require us to rewrite these subgraphs as patterns. It will come back to us once we rely on torch.export.export.

@yuanyao-nv
Author

@titaiwangms Thanks for the update.

What's a good way to test it out?
If I rerun the export script in the description using the latest torch nightly build (2.5.0.dev20240617+cu121), I hit a different error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1509, in dynamo_export
    ).export()
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1236, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 214, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 169, in wrapped
    return output_adapter.apply(model_func(*args, **kwargs), model=model)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2462, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 356, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2677, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2793, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 234, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 962, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 941, in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1759, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1846, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1757, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 670, in __call__
    return self_._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 266, in _fn
    result = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 138, in _fn
    result = fn(**bound.arguments)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1080, in add
    a, b = _maybe_broadcast(a, b)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 419, in _maybe_broadcast
    common_shape = _broadcast_shapes(
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 408, in _broadcast_shapes
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='cuda:0', size=(1, 64, 28, 56, 56, 32),
           grad_fn=<WarnNotImplemented>), FakeTensor(..., device='cuda:0', size=(1, 64, 56, 56, 32),
           grad_fn=<AddBackward0>)), **{}):
Attempting to broadcast a dimension of length 64 at -4! Mismatching argument at index 1 had torch.Size([1, 64, 56, 56, 32]); but expected shape should be broadcastable to [1, 64, 28, 56, 56, 32]

from user code:
   File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 183, in forward
    x = self.decode(x, down_x)
  File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 171, in decode
    x = up(x) + down_x[i + 1]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
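
The final error in the traceback is a plain broadcasting mismatch: the upsampled tensor is reported as 6-D, (1, 64, 28, 56, 56, 32), while the skip connection is 5-D, (1, 64, 56, 56, 32). The sketch below applies the usual right-aligned broadcasting rule to those two shapes. The helper `broadcast_shapes` is a standalone illustration written for this comment, not torch's internal `_broadcast_shapes`.

```python
def broadcast_shapes(a, b):
    """Return the broadcast shape of a and b, or raise ValueError on a mismatch."""
    result = []
    # Walk both shapes from the last dimension; missing leading dims count as 1.
    for offset in range(1, max(len(a), len(b)) + 1):
        dim_a = a[-offset] if offset <= len(a) else 1
        dim_b = b[-offset] if offset <= len(b) else 1
        if dim_a != dim_b and dim_a != 1 and dim_b != 1:
            raise ValueError(f"cannot broadcast {dim_a} with {dim_b} at dim -{offset}")
        result.append(dim_b if dim_a == 1 else dim_a)
    return tuple(reversed(result))


upsampled = (1, 64, 28, 56, 56, 32)  # shape reported for up(x) in the traceback
skip = (1, 64, 56, 56, 32)           # shape reported for down_x[i + 1]

try:
    broadcast_shapes(upsampled, skip)
    message = ""
except ValueError as err:
    message = str(err)
print(message)
```

The mismatch lands at dimension -4, the same position named in the error message, which suggests the extra dimension on the upsample output is the thing to investigate rather than the add itself.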

Do you know why fake tensor is being used in the latest torch version?

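I also tried exporting a function that only wraps nn.Upsample:

import torch

def f(x):
    # Upsample the last dimension to length 10 with linear interpolation.
    m = torch.nn.Upsample(size=10, mode='linear')
    return m(x)

x = torch.randn(2, 5, 5)  # (batch, channels, length)
export_output = torch.onnx.dynamo_export(f, x)
export_output.save('Upsample.onnx')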

The exported graph looks reasonable. Is this what you'd expect?
[image: screenshot of the exported graph]

@yuanyao-nv
Author

I filed a separate issue to track the fake tensor broadcast error above: pytorch/pytorch#129534
