In [1]:
import os

import onnx
import torch
from onnx2pytorch import ConvertModel

In [4]:
# Location of the ONNX model to convert. Set the SUPERCOMBO_ONNX_PATH
# environment variable to point elsewhere instead of editing this
# machine-specific default.
path_to_onnx_model = os.environ.get(
    'SUPERCOMBO_ONNX_PATH', '/Users/macbook/Desktop/supercombo.onnx'
)

In [38]:
# Read the serialized ONNX graph from disk, then convert it into a
# torch.nn.Module. The submodule names in the repr below (Conv_634,
# BatchNormalization_636, ...) appear to be taken from the ONNX node names.
onnx_model = onnx.load(path_to_onnx_model)
pytorch_model = ConvertModel(onnx_model)

### Inspecting the converted model

In [39]:
# Display the converted module tree. Names such as Conv_634 presumably encode
# the ONNX op type plus a node id; confirm against the ONNX graph before
# relying on that mapping.
pytorch_model

ConvertModel(
  (Conv_634): Conv2d(12, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (BatchNormalization_636): BatchNormWrapper(
    (bnu): BatchNormUnsafe(32, eps=0.0010000000474974513, momentum=0.9900000095367432, affine=True, track_running_stats=True)
  )
  (Elu_637): ELU(alpha=1.0, inplace=True)
  (Conv_638): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
  (BatchNormalization_640): BatchNormWrapper(
    (bnu): BatchNormUnsafe(32, eps=0.0010000000474974513, momentum=0.9900000095367432, affine=True, track_running_stats=True)
  )
  (Elu_641): ELU(alpha=1.0, inplace=True)
  (Conv_642): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (BatchNormalization_644): BatchNormWrapper(
    (bnu): BatchNormUnsafe(16, eps=0.0010000000474974513, momentum=0.9900000095367432, affine=True, track_running_stats=True)
  )
  (Conv_645): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)


### Running the model

In [40]:
# Shapes of the dummy inputs fed to the forward pass as keyword arguments.
INPUT_SHAPES = {
    'input_imgs': (1, 12, 128, 256),
    'desire': (1, 8),
    'traffic_convention': (1, 2),
    'initial_state': (1, 512),
}

# All-zero float32 tensors, one per named input, in the same key order.
inputs = {
    input_name: torch.zeros(shape, dtype=torch.float32)
    for input_name, shape in INPUT_SHAPES.items()
}

In [41]:
# Put the model in inference mode before the sanity-check forward pass.
# nn.Module defaults to training mode, and nothing above switches it. In
# training mode BatchNorm layers overwrite their running mean/var with
# statistics of the current batch. Here that batch is all-zero dummy data, and
# with momentum=0.99 (see the repr above) a single pass would replace ~99% of
# the converted statistics. That assumes BatchNormUnsafe uses torch's
# standard _BatchNorm update rule; verify against onnx2pytorch.
# Gradients still flow in eval mode, so the backward smoke test below works.
pytorch_model.eval()
outs = pytorch_model(**inputs)
outs.shape

torch.Size([1, 6472])

In [42]:
# Autograd smoke test: clear any stale gradients, then backpropagate a random
# upstream gradient shaped like the model output.
pytorch_model.zero_grad()
outs.backward(gradient=torch.randn_like(outs))

### Re-initializing parameters

In [43]:
print('\nbefore:', pytorch_model.Gemm_1046.bias)

# Zero the existing bias tensor in place instead of assigning a new
# nn.Parameter. Reassignment swaps the parameter object, which has two effects:
# an optimizer built earlier from model.parameters() keeps updating the old,
# now-detached tensor, and the .grad filled in by the backward pass above is
# dropped. init.zeros_ writes under no_grad and keeps the tensor's identity,
# dtype, device, requires_grad flag and accumulated gradient.
torch.nn.init.zeros_(pytorch_model.Gemm_1046.bias)

print('\nafter:', pytorch_model.Gemm_1046.bias)



before: Parameter containing:
tensor([ 0.2084, -0.1751, -0.0720, -0.0886, -0.1428, -0.3590, -0.0951, -0.2933],
       requires_grad=True)

after: Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)


### Freezing convolution weights

In [48]:
# Freeze every parameter whose name contains 'Conv_'. The listing below shows
# these are the ConvN.weight tensors; the BatchNorm affine parameters stay
# trainable.
for name, param in pytorch_model.named_parameters():
    if 'Conv_' not in name:
        continue
    param.requires_grad_(False)

In [49]:
# Report the trainability of each parameter, one per line, to confirm the
# freeze above took effect.
for name, param in pytorch_model.named_parameters():
    print(f'{name} requires grad: {param.requires_grad}')

Conv_634.weight requires grad: False
BatchNormalization_636.bnu.weight requires grad: True
BatchNormalization_636.bnu.bias requires grad: True
Conv_638.weight requires grad: False
BatchNormalization_640.bnu.weight requires grad: True
BatchNormalization_640.bnu.bias requires grad: True
Conv_642.weight requires grad: False
BatchNormalization_644.bnu.weight requires grad: True
BatchNormalization_644.bnu.bias requires grad: True
Conv_645.weight requires grad: False
BatchNormalization_647.bnu.weight requires grad: True
BatchNormalization_647.bnu.bias requires grad: True
Conv_649.weight requires grad: False
BatchNormalization_651.bnu.weight requires grad: True
BatchNormalization_651.bnu.bias requires grad: True
Conv_653.weight requires grad: False
BatchNormalization_655.bnu.weight requires grad: True
BatchNormalization_655.bnu.bias requires grad: True
Conv_657.weight requires grad: False
BatchNormalization_659.bnu.weight requires grad: True
BatchNormalization_659.bnu.bias requires grad: True