In [10]:
import os
import time
import numpy as np
import torch
import onnx
import onnxruntime as ort
import torch.nn as nn
from torchvision import models

In [5]:
# Checkpoint location, relative to the notebook's working directory.
model_state_dict_path = "../models/torch/best_resnet.pt"

# Map every tensor onto the CPU so the checkpoint loads without a GPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
cpu_device = torch.device("cpu")
model_state_dict = torch.load(
    model_state_dict_path,
    map_location=cpu_device,
)

In [8]:
# Inspect what kind of object the checkpoint deserialized to; the recorded
# output shows collections.OrderedDict, i.e. a state dict, not a full model.
type(model_state_dict)

collections.OrderedDict

In [12]:
from typing import Optional


class PreTrainedClassifier(nn.Module):
    """Torchvision backbone with a custom fully connected head for N classes.

    The stock classification layer of the selected backbone is replaced by
    ``Dropout -> Linear(in_feat, 256) -> ReLU -> Dropout -> Linear(256, num_classes)``.

    Args:
        num_classes: Number of output logits produced by the new head.
        dropout: Dropout probability used at both dropout layers of the head.
        pretrained: When True, build the backbone with the ImageNet weights
            listed in ``model_backbone_map``; when False, build it with
            ``weights=None``.
        model_backbone: One of ``'resnet18'``, ``'resnet50'``,
            ``'efficientnetb1'`` or ``'efficientnetb4'``.

    Raises:
        ValueError: If ``model_backbone`` is not a supported name.

    Attributes are documented inline where they are assigned.
    """

    def __init__(self,
                 num_classes: int,
                 dropout: float = 0.5,
                 pretrained: bool = True,
                 model_backbone: Optional[str] = "resnet18",
                 ) -> None:
        super().__init__()
        # name -> (torchvision constructor, ImageNet weights, attribute that holds
        # the stock classification layer on that architecture).
        backbone_specs = {
            'resnet18': (models.resnet18,
                         models.ResNet18_Weights.IMAGENET1K_V1, 'fc'),
            'resnet50': (models.resnet50,
                         models.ResNet50_Weights.IMAGENET1K_V1, 'fc'),
            'efficientnetb1': (models.efficientnet_b1,
                               models.EfficientNet_B1_Weights.IMAGENET1K_V2, 'classifier'),
            'efficientnetb4': (models.efficientnet_b4,
                               models.EfficientNet_B4_Weights.IMAGENET1K_V1, 'classifier'),
        }
        # Kept as a public name -> weights mapping, as before; the B4 entry now has
        # its own key instead of overwriting the B1 entry.
        self.model_backbone_map = {
            name: spec[1] for name, spec in backbone_specs.items()
        }
        self.dropout = dropout

        if model_backbone not in backbone_specs:
            raise ValueError(f"Unsupported model backbone: {model_backbone}")

        build_backbone, imagenet_weights, head_attr = backbone_specs[model_backbone]
        weights = imagenet_weights if pretrained else None

        # Build the architecture that was actually requested (previously this was
        # always ResNet-18, whatever weights had been selected).
        self.backbone = build_backbone(weights=weights)

        # ResNet exposes a single Linear as `fc`; EfficientNet exposes a
        # Sequential whose last layer is the Linear classifier.
        stock_head = getattr(self.backbone, head_attr)
        final_linear = stock_head if isinstance(stock_head, nn.Linear) else stock_head[-1]
        in_feat = final_linear.in_features

        new_head = nn.Sequential(
            nn.Dropout(p=self.dropout),
            nn.Linear(in_feat, 256),
            nn.ReLU(),
            nn.Dropout(p=self.dropout),
            nn.Linear(256, num_classes)
        )
        setattr(self.backbone, head_attr, new_head)
        # Alias to the same module object; it is therefore also registered under
        # `classifier.*` in the state dict, which existing checkpoints rely on.
        self.classifier = new_head

    def forward(self, x):
        """Run ``x`` through the backbone (including the custom head) and return its output."""
        return self.backbone(x)
    

In [13]:
# load_state_dict is strict by default, so the checkpoint overwrites every
# parameter and buffer; pretrained=False avoids fetching ImageNet weights that
# would be discarded immediately.
model = PreTrainedClassifier(num_classes=3, pretrained=False)
model.load_state_dict(model_state_dict)
# Inference mode: disables dropout and makes BatchNorm use its running statistics.
model.eval()

PreTrainedClassifier(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=T

In [14]:
# NOTE(review): in the recorded run this import fails with
# "ModuleNotFoundError: No module named 'src'" - presumably the project root is
# not on sys.path when the notebook runs; confirm the layout before relying on it.
# If it did succeed, it would rebind PreTrainedClassifier and shadow the class
# defined earlier in this notebook.
from src.chest_xray_trainer import PreTrainedClassifier

ModuleNotFoundError: No module named 'src'