In [19]:
import os
import time
import numpy as np
import torch
import onnx
import onnxruntime as ort
from typing import Optional
from torch import nn
from torchvision import models
from torchinfo import summary


In [14]:
class PreTrainedClassifier(nn.Module):
    """Torchvision ImageNet backbone with a custom two-layer FC head for N classes.

    Supported backbones: ``"resnet18"`` (default), ``"resnet50"``,
    ``"efficientnetb1"`` and ``"efficientnetb4"``. The backbone's final
    classification layer is replaced by
    ``Dropout -> Linear(in, hidden_dim) -> ReLU -> Dropout -> Linear(hidden_dim, num_classes)``.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout probability used by both Dropout layers of the head.
        pretrained: If True, initialise the backbone with the weights listed in
            ``model_backbone_map``; otherwise initialise it randomly.
        model_backbone: Key into ``model_backbone_map`` selecting the architecture.
        hidden_dim: Width of the hidden layer of the head (256 matches the
            checkpoints produced with the original head).

    Raises:
        ValueError: If ``model_backbone`` is not a supported key.

    Note:
        The head module is registered twice: once inside the backbone
        (``backbone.fc`` for ResNet, ``backbone.classifier`` for EfficientNet)
        and once as ``self.classifier``. The state dict therefore contains both
        key prefixes. This is kept deliberately so existing checkpoints, which
        contain both prefixes, still load with ``strict=True``.
    """

    def __init__(self,
                 num_classes: int = 3,
                 dropout: float = 0.5,
                 pretrained: bool = True,
                 model_backbone: Optional[str] = "resnet18",
                 hidden_dim: int = 256,
                 ) -> None:
        super().__init__()
        # Pretrained-weight enum per supported backbone key.
        self.model_backbone_map = {
            'resnet18': models.ResNet18_Weights.IMAGENET1K_V1,
            'resnet50': models.ResNet50_Weights.IMAGENET1K_V1,
            'efficientnetb1': models.EfficientNet_B1_Weights.IMAGENET1K_V2,
            'efficientnetb4': models.EfficientNet_B4_Weights.IMAGENET1K_V1,
        }
        # Architecture constructor per key; must stay in sync with the map above.
        builders = {
            'resnet18': models.resnet18,
            'resnet50': models.resnet50,
            'efficientnetb1': models.efficientnet_b1,
            'efficientnetb4': models.efficientnet_b4,
        }
        self.dropout = dropout
        if model_backbone not in self.model_backbone_map:
            raise ValueError(f"Unsupported model backbone: {model_backbone}")

        weights = self.model_backbone_map[model_backbone] if pretrained else None
        self.backbone = builders[model_backbone](weights=weights)

        if model_backbone.startswith('resnet'):
            # ResNet exposes its final layer as a single nn.Linear named `fc`.
            in_feat = self.backbone.fc.in_features
            self.backbone.fc = self._make_head(in_feat, num_classes, self.dropout, hidden_dim)
            self.classifier = self.backbone.fc
        else:
            # EfficientNet exposes `classifier = Sequential(Dropout, Linear)`;
            # the Linear is the last element.
            in_feat = self.backbone.classifier[-1].in_features
            self.backbone.classifier = self._make_head(in_feat, num_classes, self.dropout, hidden_dim)
            self.classifier = self.backbone.classifier

    @staticmethod
    def _make_head(in_features: int,
                   num_classes: int,
                   dropout: float,
                   hidden_dim: int = 256) -> nn.Sequential:
        """Build the Dropout/Linear/ReLU/Dropout/Linear classification head.

        The layer order (indices 0-4) determines state-dict key names, so it
        must not change if existing checkpoints are to keep loading.
        """
        return nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return raw class logits (no softmax) for the input batch ``x``."""
        return self.backbone(x)

In [15]:
# ResNet-18 backbone with a 3-class head. ImageNet weights are not downloaded:
# the next cell loads a checkpoint with strict key matching, which overwrites
# every parameter and buffer anyway. This keeps the notebook runnable offline.
model = PreTrainedClassifier(num_classes=3, pretrained=False, model_backbone="resnet18")


In [17]:
state_dict_path = "../models/torch/chest/best_resnet.pt"
# Load onto CPU: the model lives on CPU and the ONNX export traces with a CPU
# input. This also keeps the notebook runnable on machines without CUDA.
# weights_only=True: a state dict holds only tensors, so no arbitrary unpickling is needed.
state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True)
model.eval()  # inference mode (Dropout off, BatchNorm uses running stats) for export
model.load_state_dict(state_dict)

<All keys matched successfully>

In [18]:
# NOTE(review): this repeats the checkpoint load from the previous cell and is
# redundant; it can be removed. Kept CPU-only for the same reasons as above.
model.load_state_dict(torch.load(state_dict_path, map_location="cpu", weights_only=True))

<All keys matched successfully>

In [21]:
# Per-layer parameter counts. No input_size is passed, so only the "Param #" column is shown.
model_summary = summary(model)
model_summary

Layer (type:depth-idx)                        Param #
PreTrainedClassifier                          --
├─ResNet: 1-1                                 --
│    └─Conv2d: 2-1                            9,408
│    └─BatchNorm2d: 2-2                       128
│    └─ReLU: 2-3                              --
│    └─MaxPool2d: 2-4                         --
│    └─Sequential: 2-5                        --
│    │    └─BasicBlock: 3-1                   73,984
│    │    └─BasicBlock: 3-2                   73,984
│    └─Sequential: 2-6                        --
│    │    └─BasicBlock: 3-3                   230,144
│    │    └─BasicBlock: 3-4                   295,424
│    └─Sequential: 2-7                        --
│    │    └─BasicBlock: 3-5                   919,040
│    │    └─BasicBlock: 3-6                   1,180,672
│    └─Sequential: 2-8                        --
│    │    └─BasicBlock: 3-7                   3,673,088
│    │    └─BasicBlock: 3-8                   4,720,640
│    └─AdaptiveA

In [22]:
onnx_model_path = "../models/onnx/chest/chest_resnet.onnx"
# The exporter writes to this path but does not create missing parent directories.
os.makedirs(os.path.dirname(onnx_model_path), exist_ok=True)

# Export in inference mode so Dropout is disabled and BatchNorm uses running statistics.
model.eval()

# The dummy input only fixes the traced input shape (N, 3, 224, 224); the batch
# dimension is made dynamic via dynamic_axes. It is created on the model's own
# device to avoid a CPU/GPU mismatch during tracing.
model_device = next(model.parameters()).device
dummy_input = torch.randn(1, 3, 224, 224, device=model_device)
torch.onnx.export(model, dummy_input, onnx_model_path,
                  export_params=True, opset_version=20,
                  do_constant_folding=True, input_names=['input'],
                  output_names=['output'],
                  dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})

print(f"ONNX model saved to {onnx_model_path}")

ONNX model saved to ../models/onnx/chest/chest_resnet.onnx


In [23]:
# Validate the exported file. Passing the path (instead of an in-memory proto)
# lets the checker handle models stored with external data. full_check=True
# also runs ONNX shape inference, catching shape/type inconsistencies.
onnx.checker.check_model(onnx_model_path, full_check=True)
onnx_model = onnx.load(onnx_model_path)
# NOTE(review): the recorded run failed with `AttributeError: parameters`, which is
# not an onnx ValidationError. It presumably points to an onnx/protobuf/torch
# version mismatch in the environment; confirm installed package versions.

AttributeError: parameters