# Exporting backbones 

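This notebook investigates how backbones can be exported using `torch.export`: a ResNet-50 is exported with dynamic input shapes, the exported program is saved to disk and loaded back, and a feature-extracting variant of the same backbone is exported and run.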

In [None]:
import torch.export
import torch.nn
import torch.fx

import backbones as bb
import unipercept as up

# Build the ResNet-50 backbone from its lazy config and load weights from the
# checkpoint file below onto the CPU.
rn50_weights = "../weights/resnet/50/imagenet-classification.safetensors"
rn50 = up.config.lazy.instantiate(bb.resnet.configs.RESNET_50)
bb.load_weights(rn50_weights, rn50, device="cpu")

# Show the module tree for reference.
print(rn50)




ResNet(
  (stem): Sequential(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): InplaceReLU(inplace=True)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (ext1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): InplaceReLU(inplace=True)
      (residual): Sequential(
 

  bb.load_weights(


In [46]:
from torch.export.dynamic_shapes import Dim

# Example positional arguments for the export: one random tensor of shape
# (2, 3, 512, 1024).
inputs = (torch.randn(2, 3, 512, 1024),)

# Symbolic dimensions. The spatial ones are doubled in the spec below, so the
# exported sizes are presumably even values in 256..1024 -- confirm against
# the torch.export documentation on derived dimensions.
sb = Dim("batch", min=1)
sh = Dim("height", min=128, max=512)
sw = Dim("width", min=128, max=512)
# Disabled alternative kept from the original notebook: tie the width to the
# height with `sw = 2 * sh`.

# One tuple per positional input: symbolic batch, static second dimension,
# then the doubled height and width symbols.
rn50_shapes = [
    (sb, Dim.STATIC, 2 * sh, 2 * sw),  # type: ignore[attr-defined]
]

rn50_train = torch.export.export_for_training(
    rn50,
    inputs,
    dynamic_shapes=rn50_shapes,
)
print(rn50_train)


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_stem_conv_weight: "f32[64, 3, 7, 7]", p_stem_norm_weight: "f32[64]", p_stem_norm_bias: "f32[64]", p_getattr_l__self___ext1___0___residual_conv_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___residual_norm_weight: "f32[256]", p_getattr_l__self___ext1___0___residual_norm_bias: "f32[256]", p_getattr_l__self___ext1___0___conv1_weight: "f32[64, 64, 1, 1]", p_getattr_l__self___ext1___0___norm1_weight: "f32[64]", p_getattr_l__self___ext1___0___norm1_bias: "f32[64]", p_getattr_l__self___ext1___0___conv2_weight: "f32[64, 64, 3, 3]", p_getattr_l__self___ext1___0___norm2_weight: "f32[64]", p_getattr_l__self___ext1___0___norm2_bias: "f32[64]", p_getattr_l__self___ext1___0___conv3_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___norm3_weight: "f32[256]", p_getattr_l__self___ext1___0___norm3_bias: "f32[256]", p_getattr_l__self___ext1___1___conv1_weight: "f32[64, 256, 1, 1]", p_getattr_l__sel

In [47]:
# Serialize the exported program to the relative path "resnet50.pt2".
torch.export.save(rn50_train, "resnet50.pt2")


In [48]:
# Read the exported program back from the file written in the previous cell.
rn50_load = torch.export.load("resnet50.pt2")

# Print it so it can be compared with the printout made before saving.
print(rn50_load)


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_stem_conv_weight: "f32[64, 3, 7, 7]", p_stem_norm_weight: "f32[64]", p_stem_norm_bias: "f32[64]", p_getattr_l__self___ext1___0___residual_conv_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___residual_norm_weight: "f32[256]", p_getattr_l__self___ext1___0___residual_norm_bias: "f32[256]", p_getattr_l__self___ext1___0___conv1_weight: "f32[64, 64, 1, 1]", p_getattr_l__self___ext1___0___norm1_weight: "f32[64]", p_getattr_l__self___ext1___0___norm1_bias: "f32[64]", p_getattr_l__self___ext1___0___conv2_weight: "f32[64, 64, 3, 3]", p_getattr_l__self___ext1___0___norm2_weight: "f32[64]", p_getattr_l__self___ext1___0___norm2_bias: "f32[64]", p_getattr_l__self___ext1___0___conv3_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___norm3_weight: "f32[256]", p_getattr_l__self___ext1___0___norm3_bias: "f32[256]", p_getattr_l__self___ext1___1___conv1_weight: "f32[64, 256, 1, 1]", p_getattr_l__sel

In [None]:
# Derive a feature-extracting module from the backbone for the four named
# stages. NOTE(review): presumably it returns a mapping from feature name to
# tensor -- the output of the last cell begins with an 'ext1' key; confirm
# against `bb.extract_features`.
rn50_ft = bb.extract_features(rn50, features=["ext1", "ext2", "ext3", "ext4"])
# print_readable() prints the generated forward() source and also returns it
# as a string, which is why the cell output shows the code twice.
rn50_ft.print_readable()


class ResNet(torch.nn.Module):
    def forward(self, x : torch.Tensor):
        # No stacktrace found for following nodes
        stem_conv = self.stem.conv(x);  x = None
        stem_norm = self.stem.norm(stem_conv);  stem_conv = None
        relu = torch.nn.functional.relu(stem_norm, inplace = True);  stem_norm = None
        stem_pool = self.stem.pool(relu);  relu = None
        ext1_0_residual_conv = getattr(self.ext1, "0").residual.conv(stem_pool)
        ext1_0_residual_norm = getattr(self.ext1, "0").residual.norm(ext1_0_residual_conv);  ext1_0_residual_conv = None
        ext1_0_conv1 = getattr(self.ext1, "0").conv1(stem_pool);  stem_pool = None
        ext1_0_norm1 = getattr(self.ext1, "0").norm1(ext1_0_conv1);  ext1_0_conv1 = None
        relu_1 = torch.nn.functional.relu(ext1_0_norm1, inplace = True);  ext1_0_norm1 = None
        ext1_0_conv2 = getattr(self.ext1, "0").conv2(relu_1);  relu_1 = None
        ext1_0_norm2 = getattr(self.ext1, "0").norm2(ext1_0_conv2);  ext1_0_con

'class ResNet(torch.nn.Module):\n    def forward(self, x : torch.Tensor):\n        # No stacktrace found for following nodes\n        stem_conv = self.stem.conv(x);  x = None\n        stem_norm = self.stem.norm(stem_conv);  stem_conv = None\n        relu = torch.nn.functional.relu(stem_norm, inplace = True);  stem_norm = None\n        stem_pool = self.stem.pool(relu);  relu = None\n        ext1_0_residual_conv = getattr(self.ext1, "0").residual.conv(stem_pool)\n        ext1_0_residual_norm = getattr(self.ext1, "0").residual.norm(ext1_0_residual_conv);  ext1_0_residual_conv = None\n        ext1_0_conv1 = getattr(self.ext1, "0").conv1(stem_pool);  stem_pool = None\n        ext1_0_norm1 = getattr(self.ext1, "0").norm1(ext1_0_conv1);  ext1_0_conv1 = None\n        relu_1 = torch.nn.functional.relu(ext1_0_norm1, inplace = True);  ext1_0_norm1 = None\n        ext1_0_conv2 = getattr(self.ext1, "0").conv2(relu_1);  relu_1 = None\n        ext1_0_norm2 = getattr(self.ext1, "0").norm2(ext1_0_conv2

In [53]:
# Export the feature-extracting module with the same example inputs and the
# same shape spec (symbolic batch, static second dim, doubled height/width)
# that was used for the plain backbone export.
rn50_ft_shapes = [(sb, Dim.STATIC, 2 * sh, 2 * sw)]
rn50_ft_train = torch.export.export_for_training(
    rn50_ft,
    inputs,
    dynamic_shapes=rn50_ft_shapes,
)

# Print the forward() source of the exported graph module.
rn50_ft_train.graph_module.print_readable()


class GraphModule(torch.nn.Module):
    def forward(self, p_stem_conv_weight: "f32[64, 3, 7, 7]", p_stem_norm_weight: "f32[64]", p_stem_norm_bias: "f32[64]", p_getattr_l__self___ext1___0___residual_conv_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___residual_norm_weight: "f32[256]", p_getattr_l__self___ext1___0___residual_norm_bias: "f32[256]", p_getattr_l__self___ext1___0___conv1_weight: "f32[64, 64, 1, 1]", p_getattr_l__self___ext1___0___norm1_weight: "f32[64]", p_getattr_l__self___ext1___0___norm1_bias: "f32[64]", p_getattr_l__self___ext1___0___conv2_weight: "f32[64, 64, 3, 3]", p_getattr_l__self___ext1___0___norm2_weight: "f32[64]", p_getattr_l__self___ext1___0___norm2_bias: "f32[64]", p_getattr_l__self___ext1___0___conv3_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___norm3_weight: "f32[256]", p_getattr_l__self___ext1___0___norm3_bias: "f32[256]", p_getattr_l__self___ext1___1___conv1_weight: "f32[64, 256, 1, 1]", p_getattr_l__self___ext1___1___norm1_weig

'class GraphModule(torch.nn.Module):\n    def forward(self, p_stem_conv_weight: "f32[64, 3, 7, 7]", p_stem_norm_weight: "f32[64]", p_stem_norm_bias: "f32[64]", p_getattr_l__self___ext1___0___residual_conv_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___residual_norm_weight: "f32[256]", p_getattr_l__self___ext1___0___residual_norm_bias: "f32[256]", p_getattr_l__self___ext1___0___conv1_weight: "f32[64, 64, 1, 1]", p_getattr_l__self___ext1___0___norm1_weight: "f32[64]", p_getattr_l__self___ext1___0___norm1_bias: "f32[64]", p_getattr_l__self___ext1___0___conv2_weight: "f32[64, 64, 3, 3]", p_getattr_l__self___ext1___0___norm2_weight: "f32[64]", p_getattr_l__self___ext1___0___norm2_bias: "f32[64]", p_getattr_l__self___ext1___0___conv3_weight: "f32[256, 64, 1, 1]", p_getattr_l__self___ext1___0___norm3_weight: "f32[256]", p_getattr_l__self___ext1___0___norm3_bias: "f32[256]", p_getattr_l__self___ext1___1___conv1_weight: "f32[64, 256, 1, 1]", p_getattr_l__self___ext1___1___norm1_we

In [55]:
# Run the exported program on the example inputs that were used for export;
# the returned value is displayed as the cell output.
rn50_ft_train.module()(*inputs)


{'ext1': tensor([[[[2.1229e+00, 2.1989e-01, 3.5561e+00,  ..., 3.0651e+00,
            6.5761e-01, 3.5836e+00],
           [1.7811e+00, 1.9223e+00, 4.0303e+00,  ..., 0.0000e+00,
            0.0000e+00, 0.0000e+00],
           [4.1568e+00, 1.5858e+00, 4.2046e-02,  ..., 1.5455e-01,
            8.4549e-02, 0.0000e+00],
           ...,
           [0.0000e+00, 0.0000e+00, 7.3195e-01,  ..., 4.0034e-01,
            0.0000e+00, 0.0000e+00],
           [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 8.2993e-01,
            1.4144e+00, 3.2125e-01],
           [0.0000e+00, 0.0000e+00, 5.5326e+00,  ..., 5.9888e+00,
            4.1769e+00, 3.9481e+00]],
 
          [[6.6075e-01, 8.0942e-01, 3.6434e-01,  ..., 4.5403e+00,
            1.3335e+00, 0.0000e+00],
           [1.9473e+00, 1.7458e+00, 9.9051e-01,  ..., 4.4006e+00,
            1.0807e+00, 2.4453e+00],
           [1.2348e+00, 5.2678e-01, 1.0074e+00,  ..., 1.6955e+00,
            2.2640e+00, 2.9317e+00],
           ...,
           [2.0802e+00, 7.9480