In [2]:
"""
Author@ Mrinal Kanti Dhar
October 24, 2024
"""

import sys
sys.path.append("/research/m324371/Project/adnexal/utils/")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from copy import deepcopy

from pscse_cab import PscSEWithCAB
from classification_head import ClassificationHead
from feature_ensemble_2models import FeatureEnsemble2models

from res50pscse_512x28x28 import ResNet50Pscse_512x28x28
from enetb2lpscse_384x28x28 import EfficientNetB2LPscse_384x28x28

from base_models_collection import base_models

In [7]:
class EnsembleResNet18Ft512_EfficientNetB2SFt1408(nn.Module):
    """ Ensembles ResNet18 with 512 features and EfficientNetB2 with 1408 features """
    def __init__(self, 
                 num_classes: int,
                 out_channels: list = None,  # for instance [1024, 512, 256]. Used in classification head
                 pretrain: bool = True,
                 dropout: float = 0.3,
                 in_chs: int = None,
                 separate_inputs: int = None):  # separate_inputs defines the number of inputs

        self.separate_inputs = separate_inputs

        super(EnsembleResNet18Ft512_EfficientNetB2SFt1408, self).__init__()

        model1 = base_models('resnet18', pretrain=pretrain, num_classes=num_classes, in_chs=in_chs)
        model2 = base_models('efficientnet_v2_s', pretrain=pretrain, num_classes=num_classes, in_chs=in_chs)

        self.ens_model1 = FeatureEnsemble2models(model1, model2, trim1=2, trim2=2)  # clip classification head

        # Create a list of models for separate inputs if separate_inputs is specified
        if self.separate_inputs is not None:
            self.ensemble_models = nn.ModuleList([deepcopy(self.ens_model1) for _ in range(self.separate_inputs)])

        self.classification = ClassificationHead(num_classes=num_classes,
                                                 out_channels=out_channels,
                                                 dropout=dropout)

    def forward(self, x):

        if self.separate_inputs is not None:

            # Ensure input data has self.separate_inputs no. of channels
            if x.shape[1] < self.separate_inputs:
                raise ValueError(f"Can't split. Input data has {x.shape[1]} channels whereas separate_inputs parameter is {self.separate_inputs}. \
Check the separate_inputs parameter in the config file.")
            
            features_list = []

            # Loop over each input channel, process it, and store the features
            for i in range(self.separate_inputs):
                # Separate the i-th input (single channel)
                xi = x[:, i:i + 1, :, :]  # extract ith channel

                # Convert to 3 channels by repeating or concatenating along the channel dimension
                xi_3ch = torch.cat([xi, xi, xi], dim=1)

                # Get features from the i-th ensemble model
                features_i = self.ensemble_models[i](xi_3ch)

                # Collect features
                features_list.append(features_i)

            # Concatenate features along the channel dimension
            features = torch.cat(features_list, dim=1)

        else:
            features = self.ens_model1(x)

        # # Pass the features through the classification head
        # out = self.classification(features)

        return out


In [10]:

inp=torch.rand(1, 3, 224, 224)
num_classes=2
out_channels=[5376, 512, 256]
pretrain = True
dropout=0.3
separate_inputs = 3
in_channels = 3

model = EnsembleResNet18Ft512_EfficientNetB2SFt1408(num_classes, out_channels, pretrain, dropout, in_channels, separate_inputs)

out = model(inp)

print(out.shape)

torch.Size([1, 5376, 7, 7])


In [12]:
inp=torch.rand(1, 3, 224, 224)
num_classes=2
out_channels=[5376, 512, 256]
pretrain = True
dropout=0.3
separate_inputs = 3
in_channels = 3

model = EnsembleResNet18Ft512_EfficientNetB2SFt1408(num_classes, out_channels, pretrain, dropout, in_channels, separate_inputs)

modules = []
for layer in model.children():
    if isinstance(layer, nn.ModuleList):
        modules.extend(layer)  # Flatten out ModuleList layers
    else:
        modules.append(layer)
dl_feature_extractor = nn.Sequential(*modules[:-1])  # Exclude the classification head

In [13]:
dl_feature_extractor

Sequential(
  (0): FeatureEnsemble2models(
    (trimped_model1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchN

In [14]:
out2 = dl_feature_extractor(inp)

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1792, 7, 7] to have 3 channels, but got 1792 channels instead