In [37]:
import torch
from torchvision import models
import os

### Inspecting the checkpoint

In [38]:
# NOTE: torch.load relies on pickle, so only open checkpoint files that come
# from a trusted source.
checkpoint = torch.load(
    './checkpoint/coffee-diseases-stage2-res50-4-ckpt-81.76acc.t7',
    map_location='cpu',  # remap stored tensors onto the CPU
)

# Top-level entries stored in the checkpoint.
keys = checkpoint.keys()
print(keys)

# Recorded accuracy and epoch, one value per line.
print(checkpoint['acc'], checkpoint['epoch'], sep='\n')
# print(checkpoint['scaler'])

dict_keys(['model', 'optimizer', 'scaler', 'acc', 'epoch'])
81.76
9


### Loading the net

In [39]:
# net = models.resnet50(weights='DEFAULT')
# num_features = net.fc.in_features
# net.fc = torch.nn.Sequential(
#     torch.nn.Linear(num_features, 256),
#     torch.nn.ReLU(),
#     torch.nn.Linear(256, 4),
#     torch.nn.Softmax(dim=1)
# )

In [49]:
# ResNet50()
# -*- coding: utf-8 -*-

'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus an additive shortcut.

    Args:
        in_planes: channel count of the incoming feature map.
        planes: channel count produced by both convolutions.
        stride: stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # The block outputs ``expansion * planes`` channels.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential acts as the identity; a 1x1 conv + BN projection
        # is used whenever the shortcut has to change resolution or width.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        summed = self.bn2(self.conv2(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Args:
        in_planes: channel count of the incoming feature map.
        planes: channel count of the two inner convolutions; the block
            outputs ``expansion * planes`` channels.
        stride: stride of the 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # An empty Sequential acts as the identity; a 1x1 conv + BN projection
        # is used whenever the shortcut has to change resolution or width.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        summed = self.bn3(self.conv3(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class ResNet(nn.Module):
    """ResNet classifier assembled from a configurable residual block type.

    Layout: 7x7 stride-2 stem convolution + BN + ReLU, four residual stages
    (64/128/256/512 planes with strides 1/2/2/2), a 4x4 average pool and a
    linear classification layer.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: indexable holding the number of blocks of each of the
            four stages.
        num_classes: size of the output layer.
    """

    def __init__(self, block, num_blocks, num_classes=4):
        super().__init__()
        # Channel count feeding the next block; advanced by _make_layer.
        self.in_planes = 64

        # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, pooling and the linear layer."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = F.avg_pool2d(out, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18(num_classes=4):
    """Build a ResNet of BasicBlock units with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)

def ResNet34(num_classes=4):
    """Build a ResNet of BasicBlock units with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)

def ResNet50(num_classes=4):
    """Build a ResNet of Bottleneck units with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)

def ResNet101(num_classes=4):
    """Build a ResNet of Bottleneck units with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)

def ResNet152(num_classes=4):
    """Build a ResNet of Bottleneck units with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, num_classes=num_classes)


def test():
    net = ResNet18()
    y = net(torch.randn(1,3,32,32))
    print(y.size())

# test()

In [50]:
# Instantiate the ResNet-50 variant defined above with a 3-class output layer.
net = ResNet50(num_classes=3)

In [51]:
# Load the checkpoint
# checkpoint = torch.load('./checkpoint/coffee-diseases-res50-4-ckpt.t7', map_location='cpu')

# print(checkpoint['model'].items())

In [52]:
# Load the checkpoint
#checkpoint = torch.load('./checkpoint/XXXXX.t7', map_location='cpu')

# Remove the leading "module." prefix from the keys (presumably added by a
# DataParallel wrapper during training -- confirm). Only a prefix is stripped:
# str.replace would also rewrite any later "module." occurrence inside a key.
new_state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in checkpoint['model'].items()
}

# Load the modified state_dict into the model
net.load_state_dict(new_state_dict)

<All keys matched successfully>

In [53]:
# print('==> loading from checkpoint..')
# assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
# checkpoint = torch.load('./checkpoint/res50-4-ckpt.t7')

# keys = checkpoint.keys()
# print(keys)

# net.load_state_dict(checkpoint['model'])
# best_acc = checkpoint['acc']
# start_epoch = checkpoint['epoch']

# net = ResNet50()
# checkpoint = torch.load('./checkpoint/coffee-diseases-res50-4-ckpt.t7')
# keys = checkpoint.keys()
# print(keys)
# net.load_state_dict(checkpoint['model'])


# net = checkpoint['net']
# print(net)
# net.load_state_dict(checkpoint['net'])

### Saving the .pth model

In [54]:
# Save only the network weights (its state_dict) to a standalone .pth file.
torch.save(net.state_dict(), './checkpoint/resnet50.pth')

### Loading the model

In [55]:
# Locations of the saved weights (.pth) and of the ONNX export.
pthPath = './checkpoint/resnet50.pth'
onnxPath = './checkpoint/resnet50-81.76acc.onnx'

# model =  torch.load(pthPath)
# model_dict = model.state_dict()
# net.load_state_dict(model_dict)

### Exporting the model to ONNX

In [56]:
!pip install onnx



In [57]:
import onnx
from torch.autograd import Variable

# dummy_input = torch.randn(1, 3, 32, 32)
# Example input used to run the model during export; only the batch axis is
# declared dynamic below, so the exported graph expects 3x60x60 images.
dummy_input = torch.randn(1, 3, 60, 60)

# Switch to inference mode so BatchNorm layers use their running statistics
# in the exported graph.
net.eval()

# Export the model
torch.onnx.export(net,                 # model being run
                  dummy_input,                         # model input (or a tuple for multiple inputs)
                  onnxPath,   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

# Reload the exported file and validate that it is a well-formed ONNX model.
onnx.checker.check_model(onnx.load(onnxPath))

verbose: False, log level: Level.ERROR

