In [None]:
import os

import cv2
import matplotlib.pyplot as plt
from torchsummary import summary
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import efficientnet_b3, efficientnet_v2_s, mobilenet_v3_large
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

In [None]:
# Model summary: inspect a candidate backbone and list its graph node names,
# which are the valid keys for `return_nodes` in `create_feature_extractor`.
# To inspect the other backbones, swap in efficientnet_b3(weights="IMAGENET1K_V1")
# or efficientnet_v2_s(weights="IMAGENET1K_V1") below.
pretrained_model = mobilenet_v3_large(weights="IMAGENET1K_V1")

model_summary = summary(pretrained_model, (3, 640, 640), depth=5)

# get_graph_node_names returns (train-mode names, eval-mode names); show the eval ones.
_train_node_names, names = get_graph_node_names(pretrained_model)
print("Node names (eval):", names)

In [None]:
# Feature extractor
class FeatureExtractor(nn.Module):
    """Wrap a pretrained model so that it returns one intermediate feature map.

    Args:
        pretrained_model: model to trace with ``create_feature_extractor``.
        return_nodes: ``{graph_node_name: output_name}`` mapping passed to
            ``create_feature_extractor``. Only the output named by the first
            entry is returned from ``forward``.

    Raises:
        ValueError: if ``return_nodes`` is empty.
    """

    # Side length of the random probe image used to infer ``n_features``.
    # Only the channel dimension (dim 1) of the probe output is used.
    _PROBE_SIZE = 100

    def __init__(self, pretrained_model, return_nodes: dict):
        super().__init__()
        if not return_nodes:
            raise ValueError("return_nodes must map at least one graph node to an output name")
        # Name of the feature map that forward() returns.
        self.out_name = next(iter(return_nodes.values()))
        self.model = create_feature_extractor(pretrained_model, return_nodes=return_nodes)
        # Channel count of the selected feature map; used to size the classifier head.
        self.n_features = self._calculate_n_features()

    def forward(self, input: torch.Tensor):
        """Run the truncated model and return the feature map named ``out_name``."""
        features = self.model(input)
        return features[self.out_name]

    def _calculate_n_features(self) -> int:
        """Infer dim 1 of the selected feature map by running a random probe.

        The probe runs in eval mode: ``torch.no_grad()`` does not stop BatchNorm
        layers in train mode from updating their running statistics, so a
        train-mode probe would blend statistics of random noise into the
        pretrained running mean/var. The module's previous train/eval mode is
        restored afterwards.
        """
        # Match the probe to the model's parameters (if any) so this also works
        # when the wrapped model is not float32 on CPU.
        reference_param = next(self.model.parameters(), None)
        tensor_kwargs = (
            {"device": reference_param.device, "dtype": reference_param.dtype}
            if reference_param is not None
            else {}
        )
        # Assumes 3-channel (RGB) input, as used elsewhere in this notebook.
        dummy_input = torch.randn(1, 3, self._PROBE_SIZE, self._PROBE_SIZE, **tensor_kwargs)
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                output = self.forward(dummy_input)
        finally:
            self.model.train(was_training)
        return output.shape[1]

In [None]:
# Output layer: global average pooling, then a single-unit linear layer followed by a sigmoid
class OutputLayer(nn.Module):
    """Binary classification head: global average pool -> 1-unit linear -> sigmoid.

    Args:
        last_layer_features: number of channels (dim 1) of the incoming
            4-D ``(B, C, H, W)`` feature map.

    ``forward`` returns a ``(B, 1)`` tensor of sigmoid probabilities in [0, 1].
    """

    def __init__(self, last_layer_features: int):
        super().__init__()
        # These attribute names are state_dict keys; CAMGenerator also reads sigmoid_fc.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.sigmoid_fc = nn.Linear(last_layer_features, 1)

    def forward(self, x: torch.Tensor):
        pooled = self.global_avg_pool(x).flatten(start_dim=1)  # (B, C)
        logits = self.sigmoid_fc(pooled)  # (B, 1)
        return logits.sigmoid()

In [None]:
# CAM (class activation map) generator
class CAMGenerator(nn.Module):
    """Class activation map built from a linear head's weights.

    Args:
        sigmoid_fc: the linear layer whose weight rows weight the feature
            channels (in this notebook, the 1-output head of ``OutputLayer``).
    """

    def __init__(self, sigmoid_fc: torch.nn.modules.linear.Linear):
        super().__init__()
        self.sigmoid_fc = sigmoid_fc

    def forward(self, input: torch.Tensor, features: torch.Tensor):
        """Weight ``features`` channel-wise by the head weights, sum over channels,
        and bilinearly resize to ``input``'s spatial size.

        The layer's bias is not used, and no gradient flows into its weights.
        Returns a ``(B, H, W)`` map for a ``(B, C, h, w)`` feature tensor.
        """
        channel_weights = self.sigmoid_fc.weight.detach()[:, :, None, None]  # (1, C, 1, 1)
        weighted_sum = (features * channel_weights).sum(dim=1, keepdim=True)  # (B, 1, h, w)
        upsampled = F.interpolate(
            weighted_sum,
            size=input.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        return upsampled.squeeze(1)  # (B, H, W)

In [None]:
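# Multi-head model: a backbone feature extractor shared by a sigmoid classification head
# and an optional CAM head.
# Note: each subclass below hardcodes the backbone graph node used as the feature map (`return_nodes`).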

class MultiHeadBase(nn.Module):
    """Backbone features shared by a sigmoid classification head and an optional CAM head.

    Args:
        backbone: model constructor, called as ``backbone(weights=weights)``.
        return_node: ``{graph_node_name: output_name}`` mapping forwarded to
            ``FeatureExtractor``; selects the feature map both heads consume.
        weights: weights spec passed to ``backbone``.
        generate_cam: if True, ``forward`` also returns the CAM; otherwise the
            second element of the returned tuple is ``None``.
    """

    def __init__(self, backbone, return_node, weights, generate_cam):
        super().__init__()
        self.feature_extractor = FeatureExtractor(backbone(weights=weights), return_node)
        head_width = self.feature_extractor.n_features
        self.output_layer = OutputLayer(head_width)
        # The CAM head reuses (does not copy) the classifier's linear layer.
        self.cam_generator = CAMGenerator(self.output_layer.sigmoid_fc)
        self.generate_cam = generate_cam

    def forward(self, input):
        """Return ``(probabilities, cam)``, where ``cam`` is None unless ``generate_cam``."""
        features = self.feature_extractor(input)
        probabilities = self.output_layer(features)
        cam = self.cam_generator(input, features) if self.generate_cam else None
        return probabilities, cam

class EfficientNetB3(MultiHeadBase):
    """MultiHeadBase on torchvision ``efficientnet_b3``, tapped at graph node ``'features.5.1.block.2'``."""

    def __init__(self, weights="IMAGENET1K_V1", generate_cam=False):
        super().__init__(
            backbone=efficientnet_b3,
            return_node={'features.5.1.block.2': 'layerout'},
            weights=weights,
            generate_cam=generate_cam,
        )

class EfficientNetV2S(MultiHeadBase):
    """MultiHeadBase on torchvision ``efficientnet_v2_s``, tapped at graph node ``'features.6.11.block.0'``."""

    def __init__(self, weights="IMAGENET1K_V1", generate_cam=False):
        super().__init__(
            backbone=efficientnet_v2_s,
            return_node={'features.6.11.block.0': 'layerout'},
            weights=weights,
            generate_cam=generate_cam,
        )
        
class MobileNetV3Large(MultiHeadBase):
    """MultiHeadBase on torchvision ``mobilenet_v3_large``, tapped at graph node ``'features.16'``."""

    def __init__(self, weights="IMAGENET1K_V1", generate_cam=False):
        return_nodes = {'features.16': 'layerout'}
        super().__init__(mobilenet_v3_large, return_nodes, weights, generate_cam)
    

In [None]:
# Smoke test: run every model on one random batch and check that the
# classification head and the CAM head return the expected shapes.
BATCH_SIZE = 10
IMAGE_SIZE = 640
torch.manual_seed(0)  # reproducible head initialisation and dummy input

models = [
    EfficientNetB3(generate_cam=True),
    EfficientNetV2S(generate_cam=True),
    MobileNetV3Large(generate_cam=True),
]

# Loop-invariant: every model sees the same random batch.
dummy_input = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)

for model in models:
    model.eval()
    with torch.no_grad():
        out, cam = model(dummy_input)

    print(model.__class__.__name__)
    print("Output shape:", out.shape)
    print("CAM shape:", cam.shape if cam is not None else None)
    print("\n")

    # Assert (not just print) the shapes so a regression fails on Run All.
    assert out.shape == (BATCH_SIZE, 1), f"unexpected output shape {tuple(out.shape)}"
    assert cam is not None, "generate_cam=True but no CAM was returned"
    assert cam.shape == (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE), f"unexpected CAM shape {tuple(cam.shape)}"
    assert bool(((out >= 0) & (out <= 1)).all()), "sigmoid output outside [0, 1]"
    