In [1]:
import io
import torch
import torch.nn as nn
import math
from collections import OrderedDict

In [2]:
# This code converts the weights from the pretrained models so that they can be used with Python 3

In [3]:
# Pick the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [4]:
################## AlexNet ##################

def bn_relu(inplanes):
    """Build a BatchNorm2d -> in-place ReLU stack for `inplanes` channels.

    Args:
        inplanes: number of feature channels passed to ``nn.BatchNorm2d``.

    Returns:
        An ``nn.Sequential`` holding the two layers, in that order.
    """
    layers = [
        nn.BatchNorm2d(inplanes),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def bn_relu_pool(inplanes, kernel_size=3, stride=2):
    """Build a BatchNorm2d -> in-place ReLU -> MaxPool2d stack.

    Args:
        inplanes: number of feature channels passed to ``nn.BatchNorm2d``.
        kernel_size: window size of the max-pooling layer.
        stride: stride of the max-pooling layer.

    Returns:
        An ``nn.Sequential`` holding the three layers, in that order.
    """
    normalise = nn.BatchNorm2d(inplanes)
    activate = nn.ReLU(inplace=True)
    downsample = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
    return nn.Sequential(normalise, activate, downsample)

class AlexNet(nn.Module):
    """AlexNet-style network built only from convolutions (no linear layers).

    Every convolution is bias-free. All but the last are followed by a
    BatchNorm2d + ReLU stack (with max-pooling after conv1, conv2 and conv5).
    conv2 through conv6 use two groups.

    Attribute names are part of the checkpoint format: they become the keys
    of the module's ``state_dict``.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Feature extractor.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, bias=False)
        self.relu_pool1 = bn_relu_pool(inplanes=96)

        self.conv2 = nn.Conv2d(96, 192, kernel_size=5, padding=2, groups=2, bias=False)
        self.relu_pool2 = bn_relu_pool(inplanes=192)

        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1, groups=2, bias=False)
        self.relu3 = bn_relu(inplanes=384)

        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2, bias=False)
        self.relu4 = bn_relu(inplanes=384)

        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2, bias=False)
        self.relu_pool5 = bn_relu_pool(inplanes=256)

        # Classifier head, expressed as convolutions.
        self.conv6 = nn.Conv2d(256, 256, kernel_size=5, groups=2, bias=False)
        self.relu6 = bn_relu(inplanes=256)
        self.conv7 = nn.Conv2d(256, num_classes, kernel_size=1, bias=False)

    def forward(self, x):
        """Run `x` through every stage in order and flatten all non-batch dimensions."""
        stages = (
            self.conv1, self.relu_pool1,
            self.conv2, self.relu_pool2,
            self.conv3, self.relu3,
            self.conv4, self.relu4,
            self.conv5, self.relu_pool5,
            self.conv6, self.relu6,
            self.conv7,
        )
        for stage in stages:
            x = stage(x)
        batch_size = x.size(0)
        return x.view(batch_size, -1)

In [5]:
def byte_convert(model: dict) -> dict:
    """Return a copy of `model` in which every ``bytes`` key is decoded to ``str``.

    Checkpoints pickled under Python 2 and loaded with ``encoding='bytes'``
    come back with byte-string keys. This walks the structure recursively and
    decodes those keys as UTF-8. Values are left untouched, apart from
    recursing into nested dicts and into plain lists/tuples that may contain
    further dicts.

    Args:
        model: mapping whose keys may be ``bytes`` or any other hashable type.

    Returns:
        A new plain ``dict`` with the same entries, in the same order, with
        ``bytes`` keys replaced by their decoded ``str`` form.

    Raises:
        UnicodeDecodeError: if a ``bytes`` key is not valid UTF-8.
    """
    def _convert_value(value):
        # Recurse into containers that can hold more byte-keyed dicts.
        if isinstance(value, dict):
            return byte_convert(value)
        # Exact type check: rebuilding a namedtuple or other subclass from a
        # generator would fail or lose information, so those pass through.
        if type(value) in (list, tuple):
            return type(value)(_convert_value(item) for item in value)
        return value

    return {
        (key.decode("utf-8") if isinstance(key, bytes) else key): _convert_value(value)
        for key, value in model.items()
    }

In [6]:
def load_model(model_path: str) -> dict:
    """Load a checkpoint from `model_path`, keeping pickled Python 2 strings as bytes.

    ``encoding='bytes'`` is forwarded to the unpickler, so byte strings in the
    file are returned as ``bytes`` rather than decoded. Tensors are mapped onto
    the notebook-level `device`.

    NOTE(review): ``torch.load`` unpickles the file; only use it on
    checkpoints from a trusted source.
    """
    checkpoint = torch.load(model_path, map_location=device, encoding='bytes')
    return checkpoint

In [7]:
def load_model2(model_path: str) -> nn.Module:
    """Load a saved object from `model_path` via an in-memory buffer.

    The file is read fully into memory, then ``torch.load`` deserialises it
    from a ``BytesIO`` with tensors mapped onto the notebook-level `device`.

    NOTE(review): ``torch.load`` unpickles the file; only use it on
    checkpoints from a trusted source.
    """
    with open(model_path, 'rb') as checkpoint_file:
        raw_bytes = checkpoint_file.read()
    # The buffer lives in memory, so the file handle can be closed before deserialising.
    return torch.load(io.BytesIO(raw_bytes), map_location=device)

In [8]:
def save_model(model: nn.Module, state_dict: dict) -> None:
    """Load `state_dict` into `model`, switch it to eval mode and save it.

    The whole module object (not just its weights) is written to
    ``<ClassName>.pth`` relative to the current working directory, where
    ``<ClassName>`` is the name of the model's class.

    Args:
        model: module whose parameters are overwritten in place.
        state_dict: weights to load; passed to ``model.load_state_dict``.
    """
    output_path = f"{type(model).__name__}.pth"
    model.load_state_dict(state_dict)
    model.eval()
    torch.save(model, output_path)

In [9]:
# Load the raw checkpoint from disk. Despite its name, `model` is the checkpoint
# returned by load_model, not a network instance. No AlexNet needs to be built
# here: a fresh one is constructed later, when the converted weights are saved.
model = load_model('pytorch-models/alexnet.pth')

In [10]:
# Show the top-level entries of the loaded checkpoint; the keys are still bytes at this point.
model.keys()

dict_keys([b'optimizer', b'epoch', b'state_dict', b'best_prec1'])

In [11]:
# Decode the checkpoint's bytes keys to str, then keep only its 'state_dict' entry.
state_dict = byte_convert(model)['state_dict']

In [12]:
# Load the converted weights into a fresh AlexNet and save it as 'AlexNet.pth'.
save_model(AlexNet(), state_dict)