<a href="https://colab.research.google.com/github/kumardesappan/colab-notebooks/blob/main/onnx_model_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!pip install onnx

import torch
import torch.onnx
import onnx
import onnx.utils

class LoopModel(torch.nn.Module):
    """Toy module that resizes its input four different ways.

    The input is batch-normalised, bilinearly resized with four different
    ``interpolate`` argument styles, and every result is batch-normalised
    again with the same layer.

    Args:
        num_ch: Feature count of the shared ``BatchNorm2d`` layer (and the
            output channel count of ``conv``).
        scale: Multiplier applied to the input height/width for the
            upscaled outputs.
        size: Target of the fixed-size resize: a single number used for both
            height and width, or a ``(height, width)`` pair. Values are
            converted with ``int()``. ``None`` selects 16x16.
    """

    def __init__(self, num_ch, scale, size):
        super().__init__()
        self.batch_norm = torch.nn.BatchNorm2d(num_ch)
        # NOTE(review): `conv` is never called in forward(); it is kept so the
        # module's parameters/state_dict stay the same as before.
        self.conv = torch.nn.Conv2d(3, num_ch, kernel_size=1, stride=2)
        self.scale = scale
        self.size = size

    def _fixed_size(self):
        """Return the ``(height, width)`` target for the fixed-size resize."""
        if self.size is None:
            return (16, 16)
        if isinstance(self.size, (tuple, list)):
            height, width = self.size
        else:
            height = width = self.size
        # The constructor may receive a float such as 16.0; interpolate()
        # documents `size` as integer(s), so convert explicitly.
        return (int(height), int(width))

    def forward(self, x):
        """Return the four resized, batch-normalised tensors ``(a, b, c, d)``.

        a: resized to ``(h * scale, w * scale)`` via ``size=``.
        b: resized to the input's own ``(h, w)`` via ``size=``.
        c: resized to the fixed size configured by the constructor.
        d: resized via ``scale_factor=`` using the integer ratio of the
           upscaled size to the input size.
        """
        x = self.batch_norm(x)
        h, w = x.shape[2:]
        h1, w1 = h * self.scale, w * self.scale

        interpolate = torch.nn.functional.interpolate
        a = interpolate(x, size=(h1, w1), mode='bilinear', align_corners=True)
        b = interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
        c = interpolate(x, size=self._fixed_size(), mode='bilinear', align_corners=True)
        d = interpolate(x, scale_factor=(h1 // h, w1 // w), mode='bilinear', align_corners=True)

        a = self.batch_norm(a)
        b = self.batch_norm(b)
        c = self.batch_norm(c)
        d = self.batch_norm(d)
        return a, b, c, d

# Export configuration for the resize test model.
num_ch = 3
scale = 2
opset_version = 11
model = LoopModel(num_ch, scale, 16.0)

# Sample input handed to the exporter.
dummy_input = torch.ones(1, 3, 128, 256, dtype=torch.float)

# Two files are written: the raw export and a copy carrying inferred shapes.
name = 'resize_' + str(scale) + 'x_v' + str(opset_version) + '.onnx'
name_shape = name[:-len('.onnx')] + '_shape.onnx'

torch.onnx.export(
    model,
    dummy_input,
    name,
    verbose=False,
    opset_version=opset_version,
    do_constant_folding=True,
)

# Reload the exported graph, run ONNX shape inference on it, and save the result.
onnx_model = onnx.load(name)
infer_onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
onnx.save(infer_onnx_model, name_shape)

# Show the graph's value_info before and after shape inference.
print('Before shape inference, the shape info of Y is:\n{}'.format(onnx_model.graph.value_info))
print('After shape inference, the shape info of Y is:\n{}'.format(infer_onnx_model.graph.value_info))


Before shape inference, the shape info of Y is:
[]
After shape inference, the shape info of Y is:
[name: "8"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 128
      }
      dim {
        dim_value: 256
      }
    }
  }
}
, name: "9"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_value: 4
      }
    }
  }
}
, name: "10"
type {
  tensor_type {
    elem_type: 7
    shape {
    }
  }
}
, name: "11"
type {
  tensor_type {
    elem_type: 7
    shape {
    }
  }
}
, name: "12"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_value: 4
      }
    }
  }
}
, name: "13"
type {
  tensor_type {
    elem_type: 7
    shape {
    }
  }
}
, name: "14"
type {
  tensor_type {
    elem_type: 7
    shape {
    }
  }
}
, name: "15"
type {
  tensor_type {
    elem_type: 7
    shape {
    }
  }
}
, name: "16"
type {
  tensor_type {
   

  dtype=torch.float32)).float())) for i in range(dim)]
