In [6]:
import numpy as np
import torch
import os
from utils.ResNet import *

In [7]:
# Root folder that is expected to hold one sub-directory per architecture,
# and the folder the exported ONNX files are written to.
ckpt_dir = '../resnet_params_data/models'
onnx_dir = 'models/onnx'

# Keep only sub-directories (stray files such as notes or OS metadata would
# break the per-architecture listing later on) and sort them so the
# processing order is the same on every run and platform.
ckpt_names = sorted(
    entry for entry in os.listdir(ckpt_dir)
    if os.path.isdir(os.path.join(ckpt_dir, entry))
)

# Architecture name -> freshly constructed (untrained) network.
# Every constructor comes from the `utils.ResNet` star import; the keys are
# spelled exactly like the constructor names.
model_dict = dict(
    ResNet10_1111_4=ResNet10_1111_4(),
    ResNet10_1111_6=ResNet10_1111_6(),
    ResNet10_1111_8=ResNet10_1111_8(),
    ResNet10_22_4=ResNet10_22_4(),
    ResNet10_22_6=ResNet10_22_6(),
    ResNet10_22_8=ResNet10_22_8(),
    ResNet12_2111_4=ResNet12_2111_4(),
    ResNet12_2111_6=ResNet12_2111_6(),
    ResNet12_2111_8=ResNet12_2111_8(),
    ResNet14_2211_4=ResNet14_2211_4(),
    ResNet14_2211_6=ResNet14_2211_6(),
    ResNet14_2211_8=ResNet14_2211_8(),
    ResNet14_222_4=ResNet14_222_4(),
    ResNet14_222_6=ResNet14_222_6(),
    ResNet14_222_8=ResNet14_222_8(),
    ResNet16_2221_4=ResNet16_2221_4(),
    ResNet16_2221_6=ResNet16_2221_6(),
    ResNet16_2221_8=ResNet16_2221_8(),
    ResNet18_2222_4=ResNet18_2222_4(),
    ResNet18_2222_6=ResNet18_2222_6(),
    ResNet18_2222_8=ResNet18_2222_8(),
    ResNet34_3463_4=ResNet34_3463_4(),
    ResNet34_3463_6=ResNet34_3463_6(),
    ResNet34_3463_8=ResNet34_3463_8(),
)

# Architecture name -> path of the checkpoint file to load for it.
path_dict = {}
for ckpt in ckpt_names:
    ckpt_folder = os.path.join(ckpt_dir, ckpt)
    # Sort so the selected file does not depend on the OS listing order.
    ckpt_files = sorted(os.listdir(ckpt_folder))
    if not ckpt_files:
        # Fail here with a clear message instead of a bare IndexError.
        raise FileNotFoundError(f"no checkpoint file found in {ckpt_folder!r}")
    path_dict[ckpt] = os.path.join(ckpt_folder, ckpt_files[0])

In [8]:
# Load a single architecture as a trial before the batch export.
model = ResNet14_2211_4()

# Build the checkpoint path with os.path.join instead of hard-coded '\\'
# separators so the cell also works on non-Windows systems.
single_ckpt_path = os.path.join(
    ckpt_dir, 'ResNet14_2211_4', 'ResNet14_2211_4_seed_0.ckpt'
)
model.load_state_dict(torch.load(single_ckpt_path, map_location='cpu'))

# The head is replaced *after* loading, so this layer carries newly
# initialised parameters rather than weights from the checkpoint.
# NOTE(review): 3136 presumably is the flattened feature size for the
# 224x224 input used in the export cell below -- confirm.
model.linear = torch.nn.Linear(3136, 10)
print(model)

ResNet(
  (conv1): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (blks): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(4, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [9]:
# Trial export of the single model loaded above.
dummy_inputs = torch.randn(1, 3, 224, 224)
input_names = ['input']
output_names = ['output']
# Axis 0 of both tensors is marked dynamic so the ONNX graph accepts any batch size.
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

# torch.onnx.export does not create missing folders.
os.makedirs(onnx_dir, exist_ok=True)

# Name the file after the model that is actually exported (ResNet14_2211_4);
# using ckpt_names[0] would label it as whichever architecture is listed first.
tmp_model_path = os.path.join(onnx_dir, 'ResNet14_2211_4.onnx')
torch.onnx.export(
    model,
    dummy_inputs,
    tmp_model_path,
    export_params=True,
    opset_version=13,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)

In [10]:
# Export every architecture found in the checkpoint folder to ONNX.

# torch.onnx.export does not create missing folders.
os.makedirs(onnx_dir, exist_ok=True)

# Export settings are identical for every model, so build them once.
dummy_inputs = torch.randn(1, 3, 32, 32)
input_names = ['input']
output_names = ['output']
# Axis 0 of both tensors is marked dynamic so the ONNX graph accepts any batch size.
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

for ckpt in ckpt_names:
    print(ckpt)

    # Names look like "<arch>_<digits>_<width>": field 1 is a digit string
    # whose length is used as the stage count, field 2 the base width.
    name_fields = ckpt.split('_')
    num_blocks = int(name_fields[1])
    in_planes = int(name_fields[2])
    len_blocks = len(str(num_blocks))
    planes_factor = 2 ** len_blocks

    # Input width the classifier head needs for the 32x32 dummy input.
    head_in_features = in_planes * planes_factor * (4 ** (4 - len_blocks))

    model = model_dict[ckpt]
    model.load_state_dict(torch.load(path_dict[ckpt], map_location='cpu'))

    # Replace the head only when its shape does not already fit; an
    # unconditional swap would discard the weights just loaded for it.
    head = getattr(model, 'linear', None)
    head_fits = (
        isinstance(head, torch.nn.Linear)
        and head.in_features == head_in_features
        and head.out_features == 10
    )
    if not head_fits:
        model.linear = torch.nn.Linear(head_in_features, 10)

    tmp_model_path = os.path.join(onnx_dir, ckpt + '.onnx')
    torch.onnx.export(
        model,
        dummy_inputs,
        tmp_model_path,
        export_params=True,
        opset_version=13,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )

ResNet10_1111_4
ResNet10_1111_6
ResNet10_1111_8
ResNet10_22_4
ResNet10_22_6
ResNet10_22_8
ResNet12_2111_4
ResNet12_2111_6
ResNet12_2111_8
ResNet14_2211_4
ResNet14_2211_6
ResNet14_2211_8
ResNet14_222_4
ResNet14_222_6
ResNet14_222_8
ResNet16_2221_4
ResNet16_2221_6
ResNet16_2221_8
ResNet18_2222_4
ResNet18_2222_6
ResNet18_2222_8
ResNet34_3463_4
ResNet34_3463_6
ResNet34_3463_8
