In [1]:
import lightning.pytorch as pl
import onnx
import torch
import torch.nn as nn
import torch.onnx

from onnx import numpy_helper

from torchvision.models import resnet18

In [2]:
# PyTorch checkpoint; its "model" entry holds the state_dict loaded further down.
# To inspect an ONNX variant instead, point at the ".onnx" path and use onnx.load.
model_path = "../outputs/trained_models/covid_xray_federated_2cl_model_100.pt"

# map_location="cpu" keeps this working on a CPU-only kernel even if the
# checkpoint was saved from CUDA tensors. torch.load unpickles the file, so only
# point it at trusted checkpoints.
weights = torch.load(model_path, map_location="cpu")

In [3]:
class ResNet18(nn.Module):
    """torchvision ResNet-18 with a configurable input-channel count and head size.

    Args:
        in_channels: Number of channels of the input images (1 for the
            single-channel X-rays used in this notebook).
        in_features: Unused; kept only so existing call sites keep working.
        n_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, in_channels=3, in_features=16, n_classes=10):
        super().__init__()

        self.model = resnet18()

        # Replace the stem so the network accepts `in_channels` inputs; the other
        # hyper-parameters mirror torchvision's default conv1.
        self.model.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

        # Read the head's input width from the backbone instead of hard-coding 512.
        self.model.fc = nn.Linear(self.model.fc.in_features, n_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped ResNet-18 for ``x``."""
        return self.model(x)
    

class LitCNN2d(pl.LightningModule):
    """Lightning wrapper around ``ResNet18``, used here to load and re-export weights.

    Args:
        input_shape: Number of input image channels; despite the name it is a
            channel count, forwarded to ``ResNet18`` as ``in_channels``.
        n_classes: Number of output classes of the final linear layer.
        lr: Learning rate; stored on the instance and in ``hparams`` but not
            used by any code visible in this notebook.
        _logging: Logging flag; stored but not used by any code visible here.
    """

    def __init__(
        self,
        input_shape: int,
        n_classes: int,
        lr: float,
        _logging: bool = False,
    ):
        super().__init__()
        # Lightning's save_hyperparameters() reads the constructor arguments from
        # the calling frame, so keep this call directly inside __init__.
        self.save_hyperparameters()

        backbone = ResNet18(in_channels=input_shape, n_classes=n_classes)

        self.input_shape, self.n_classes = input_shape, n_classes
        self.lr, self._logging = lr, _logging
        self.model = backbone

    def forward(self, images):
        """Return the logits of the wrapped ``ResNet18`` for ``images``."""
        logits = self.model(images)
        return logits

In [4]:
# Rebuild the architecture the checkpoint was saved from (1-channel input,
# 4-class head) and load the state_dict stored under the checkpoint's "model" key.
# strict=True (the default, made explicit) makes any key or shape mismatch raise
# instead of being silently skipped. lr only satisfies the constructor here.
cnn = LitCNN2d(input_shape=1, n_classes=4, lr=1e-3)
cnn.load_state_dict(weights["model"], strict=True)

<All keys matched successfully>

In [5]:
# Display the module tree of the loaded model to check the architecture.
cnn

LitCNN2d(
  (model): ResNet18(
    (model): ResNet(
      (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d

In [6]:
# Inspect the classification head of the loaded checkpoint before replacing it.
cnn.model.model.fc

Linear(in_features=512, out_features=4, bias=True)

In [7]:
# Swap the classification head for a freshly initialised 15-class head, keeping
# all backbone weights. The new layer's input width is read from the current head
# instead of being hard-coded, and the wrapper's recorded class count (attribute
# and hparams) is updated so it no longer reports the old number of classes.
# NOTE: no seed is set, so the new head's random initialisation (and therefore
# the exported weights) differs between runs until it is fine-tuned.
N_NEW_CLASSES = 15
head_in_features = cnn.model.model.fc.in_features
cnn.model.model.fc = nn.Linear(head_in_features, N_NEW_CLASSES)
cnn.n_classes = N_NEW_CLASSES
cnn.hparams["n_classes"] = N_NEW_CLASSES

In [8]:
# Re-display the model after the head swap; only `fc` is expected to differ.
cnn

LitCNN2d(
  (model): ResNet18(
    (model): ResNet(
      (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d

In [9]:
# Persist only the parameters/buffers (state_dict) of the re-headed model; the
# 15-class LitCNN2d must be rebuilt in code before they can be loaded again.
export_path = "../outputs/covid_xray_federated_2cl_model_100_15classes.pt"
torch.save(obj=cnn.state_dict(), f=export_path)

In [10]:
# Round-trip check: rebuild a 15-class model, load the exported state_dict and
# confirm it reproduces the in-memory model exactly.
_cnn = LitCNN2d(input_shape=1, n_classes=15, lr=1e-3)
# The file is a plain state_dict written above, so weights_only=True is safe and
# avoids unpickling arbitrary objects; map_location keeps it device-independent.
_weights = torch.load(
    "../outputs/covid_xray_federated_2cl_model_100_15classes.pt",
    map_location="cpu",
    weights_only=True,
)
load_result = _cnn.load_state_dict(_weights)

# load_state_dict only validates keys/shapes; also compare the actual values.
reloaded_state = _cnn.state_dict()
source_state = cnn.state_dict()
assert reloaded_state.keys() == source_state.keys()
assert all(torch.equal(reloaded_state[k], source_state[k]) for k in source_state), (
    "reloaded weights differ from the exported model"
)
load_result

<All keys matched successfully>