In [1]:
import torch
from torch import nn, optim
import torchvision
import torchvision.transforms as transforms

from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Default checkpoint file used by load_model_from_checkpoint().
BEST_CAT_DOG_MODEL_A = (
    '../checkpoint/ckpt-a-20250314-092618-383242.pth'
)


class CatDogModelA(nn.Module):
    """EfficientNet-B0 backbone with a 2-output head that returns softmax scores.

    The backbone is created without pretrained weights (``weights=None``), the
    first six entries of ``features`` are frozen, and the second layer of the
    stock classifier is swapped for a ``Linear`` layer with two outputs.
    """

    def __init__(self):
        super().__init__()
        self.net = torchvision.models.efficientnet_b0(weights=None)

        # Freeze feature stages 0..5 so their parameters receive no gradients.
        frozen_stages = [self.net.features[idx] for idx in range(6)]
        for stage in frozen_stages:
            stage.requires_grad_(False)

        # Replace the final classifier layer, keeping its input width.
        head_in_features = self.net.classifier[1].in_features
        self.net.classifier[1] = nn.Linear(head_in_features, 2)

        # Softmax over dim 1 of the network output.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the backbone on ``x`` and return the softmax of its output."""
        return self.softmax(self.net(x))


def load_model_from_checkpoint(checkpoint_path=BEST_CAT_DOG_MODEL_A, *, device):
    """Create a CatDogModelA plus Adam optimizer and restore both from disk.

    Args:
        checkpoint_path: Checkpoint file to read; defaults to
            ``BEST_CAT_DOG_MODEL_A``.
        device: Keyword-only. Device the freshly built model is moved to
            before the checkpoint is applied.

    Returns:
        A ``(model, optimizer)`` tuple with the checkpoint state loaded.
    """
    net = CatDogModelA()
    net = net.to(device)
    adam = optim.Adam(net.parameters(), lr=0.001)
    # load_model() returns the leftover checkpoint entries; they are not used here.
    load_model(net, adam, checkpoint_path)
    return net, adam


def load_model(model, optimizer, checkpoint_path):
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    The checkpoint is expected to be a dict containing at least the keys
    ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        model: Module whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``'optimizer_state_dict'``.
        checkpoint_path: Path of the checkpoint file to read.

    Returns:
        The loaded checkpoint dict with the two state-dict entries removed,
        i.e. whatever extra entries were stored alongside them.

    Raises:
        KeyError: If either state-dict key is missing from the checkpoint.
    """
    # Remap stored tensors onto the model's own device. Without map_location,
    # torch.load restores tensors to the device they were saved from, which
    # fails for a CUDA-saved checkpoint on a CPU-only host.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else 'cpu'

    checkpoint = torch.load(
        checkpoint_path, map_location=map_location, weights_only=True
    )
    # pop() both fetches and removes the entry, so only the extras are returned.
    model.load_state_dict(checkpoint.pop('model_state_dict'))
    optimizer.load_state_dict(checkpoint.pop('optimizer_state_dict'))
    return checkpoint


def get_transform():
    """Build the image preprocessing pipeline used in this notebook.

    Steps, in order: resize to 256, random 224 crop, random horizontal flip,
    conversion to a tensor, then per-channel normalization.

    Note that the crop and flip are the random variants, so two calls on the
    same image are not guaranteed to produce the same tensor.

    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    # Per-channel statistics for Normalize (presumably the usual ImageNet
    # values -- confirm against the training setup).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

In [3]:
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Device:', device)

Device: cuda


In [4]:
# Only the model is needed for export; the optimizer half of the tuple is
# discarded into `_`.
model, _ = load_model_from_checkpoint(device=device)
# Switch to evaluation mode before tracing; eval() returns the module itself.
model = model.eval()
print('OK')

OK


In [5]:
transform = get_transform()

# Build a one-image batch on `device` to serve as the example input for export.
image_path = '../images/may.jpg'
with Image.open(image_path).convert('RGB') as image:
    # unsqueeze(0) adds the leading batch dimension; the trailing comma makes
    # example_inputs a 1-tuple of positional args.
    example_inputs = (transform(image).unsqueeze(0).to(device),)

print(example_inputs[0].shape)

torch.Size([1, 3, 224, 224])


In [6]:
# Export the loaded model to ONNX using the example input built above.
torch.onnx.export(
    model,
    example_inputs,
    '../checkpoint/catdog.onnx',
    # Tensor names in the exported graph.
    input_names=['images'],
    output_names=['prob'],
    # Dimension 0 of both tensors is registered as the named axis 'batch'.
    dynamic_axes={'images': {0: 'batch'}, 'prob': {0: 'batch'}},
    opset_version=17,
    export_params=True,
    do_constant_folding=True,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
)

In [7]:
# CPU variant of the export: rebuild the example input on the CPU and write a
# second ONNX file.
image_path = '../images/may.jpg'
with Image.open(image_path).convert('RGB') as image:
    example_inputs = (transform(image).unsqueeze(0).to('cpu'),)

# NOTE: Module.to() moves the module in place and returns it, so `model`
# remains on the CPU after this cell runs.
model = model.to('cpu')
torch.onnx.export(
    model,
    example_inputs,
    '../checkpoint/catdog-cpu.onnx',
    # Tensor names in the exported graph.
    input_names=['images'],
    output_names=['prob'],
    # Dimension 0 of both tensors is registered as the named axis 'batch'.
    dynamic_axes={'images': {0: 'batch'}, 'prob': {0: 'batch'}},
    opset_version=17,
    export_params=True,
    do_constant_folding=True,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
)