In [1]:
# Imports for loading the best checkpoint and exporting the model to ONNX
import glob
import os
import torch
import torch.nn as nn
from torchvision import models

In [2]:
CHECKPOINT_PATTERN = 'animals_classification_*.pth'


def _checkpoint_score(path):
    """Return a sort key that ranks a checkpoint by the score in its filename.

    Checkpoint names look like ``animals_classification_60_0.888.pth``;
    presumably ``<epoch>_<validation accuracy>`` — TODO confirm against the
    training notebook. The score is the last underscore-separated field of the
    file stem. Names whose last field is not a number rank below every
    parsable one. Ties (and unparsable names) are broken by ``os.path.getctime``,
    which was the only criterion previously used.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        score = float(stem.rsplit('_', 1)[-1])
    except ValueError:
        score = float('-inf')
    return (score, os.path.getctime(path))


list_of_files = glob.glob(CHECKPOINT_PATTERN)
if not list_of_files:
    raise FileNotFoundError(
        f"No checkpoint matching {CHECKPOINT_PATTERN!r} found in {os.getcwd()}"
    )
# Name kept as `latest_file` because later cells read it; it now holds the
# highest-scoring checkpoint rather than merely the one with the newest ctime
# (ctime is inode-change time on Linux and is reset by copying/restoring files).
latest_file = max(list_of_files, key=_checkpoint_score)
print(f"Loading the best model from: {latest_file}")

Loading the best model from: animals_classification_60_0.888.pth


In [3]:
class AnimalClassification(nn.Module):
    """Image classifier: frozen ResNet18 backbone plus a small trainable head.

    The backbone's ``fc`` layer is replaced by ``nn.Identity`` so it emits its
    pooled features. The head on top is ``Linear -> ReLU -> Dropout -> Linear``.
    All backbone parameters have ``requires_grad = False``; only the head is
    trainable.

    Args:
        size_inner: Width of the hidden layer in the head.
        droprate: Dropout probability applied after the hidden layer.
        num_classes: Number of output logits.
        weights: Weights spec forwarded to ``torchvision.models.resnet18``.
            Defaults to ``"IMAGENET1K_V1"``, the value that was previously
            hard-coded. Pass ``None`` when a full state dict will be loaded
            right afterwards, to avoid downloading ImageNet weights.
    """

    def __init__(self, size_inner=256, droprate=0.2, num_classes=90,
                 weights="IMAGENET1K_V1"):
        super().__init__()

        # ResNet18 backbone (pretrained unless weights=None)
        self.base_model = models.resnet18(weights=weights)

        # Freeze the backbone so that training only updates the custom head
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Drop the original classifier but keep its input width for the head
        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()

        # Custom head. Attribute names are part of the state-dict keys, so
        # renaming them would break loading existing checkpoints.
        self.inner = nn.Linear(in_features, size_inner)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(droprate)
        self.output_layer = nn.Linear(size_inner, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        ``x`` is presumably a normalized image batch of shape (batch, 3, H, W);
        the ONNX export in this notebook uses 224x224. TODO confirm the
        preprocessing used in training.
        """
        x = self.base_model(x)          # (batch, in_features); 512 for resnet18
        x = self.inner(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.output_layer(x)
        return x

In [4]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [5]:
model = AnimalClassification(size_inner=256, droprate=0.2, num_classes=90)
# map_location: remap tensors that were saved from a GPU onto the current
# device, so the checkpoint also loads on a CPU-only machine.
# weights_only=True: unpickle only tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code.
state_dict = torch.load(latest_file, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
# Inference/export mode: Dropout becomes a no-op and BatchNorm uses its running
# stats. eval() returns the module, so the cell still displays the model repr.
model.eval()

AnimalClassification(
  (base_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine

In [6]:
!pip install onnx

Collecting onnx
  Downloading onnx-1.20.0-cp312-abi3-win_amd64.whl.metadata (8.6 kB)
Collecting protobuf>=4.25.1 (from onnx)
  Downloading protobuf-6.33.2-cp310-abi3-win_amd64.whl.metadata (593 bytes)
Collecting ml_dtypes>=0.5.0 (from onnx)
  Downloading ml_dtypes-0.5.4-cp312-cp312-win_amd64.whl.metadata (9.2 kB)
Downloading onnx-1.20.0-cp312-abi3-win_amd64.whl (16.5 MB)
   ---------------------------------------- 0.0/16.5 MB ? eta -:--:--
    --------------------------------------- 0.3/16.5 MB ? eta -:--:--
   ------------------- -------------------- 7.9/16.5 MB 32.3 MB/s eta 0:00:01
   ----------------------------- ---------- 12.3/16.5 MB 27.5 MB/s eta 0:00:01
   -------------------------------------- - 16.0/16.5 MB 24.5 MB/s eta 0:00:01
   ---------------------------------------- 16.5/16.5 MB 22.6 MB/s  0:00:00
Downloading ml_dtypes-0.5.4-cp312-cp312-win_amd64.whl (212 kB)
Downloading protobuf-6.33.2-cp310-abi3-win_amd64.whl (436 kB)
Installing collected packages: protobuf, ml_dtype

In [7]:
# Export the trained model to ONNX.
# The dummy input only fixes the traced shape (batch_size, channels, height,
# width); its values are irrelevant. A batch size of 1 is used for tracing, and
# the batch axis is then declared dynamic via `dynamic_axes` below.
import onnx  # installed by the pip cell above, so it is imported here rather than in the first cell

# Be explicit about eval mode so Dropout is not exported in training behaviour,
# independent of the exporter's default `training` handling.
model.eval()

dummy_input = torch.randn(1, 3, 224, 224).to(device)

onnx_path = "animals_classification_latest.onnx"

torch.onnx.export(
    model,                     # PyTorch model
    dummy_input,               # Example input used for tracing
    onnx_path,                 # Destination file
    verbose=True,              # Print export details
    input_names=['input'],     # Graph input name
    output_names=['output'],   # Graph output name
    dynamic_axes={             # Allow any batch size at inference time
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

# Structural validation of the written file; raises if the graph is malformed.
onnx.checker.check_model(onnx_path)

print(f"Model exported to {onnx_path}")

Model exported to animals_classification_latest.onnx
