In [None]:
import os
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import efficientnet_b3
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

In [None]:
# Get graph node names to pick a feature-extraction node for create_feature_extractor.
# get_graph_node_names returns a (train_nodes, eval_nodes) tuple. The traced graph does not
# depend on the weight values, so weights=None avoids downloading the ImageNet checkpoint here.
names = get_graph_node_names(efficientnet_b3(weights=None))
train_nodes, eval_nodes = names
print("eval-mode nodes:", eval_nodes)

In [None]:
# Feature extractor
class FeatureExtractor(nn.Module):
    """Wrap a backbone so that calling it returns one intermediate feature map.

    Args:
        pretrained_model: The backbone ``nn.Module`` to truncate.
        return_nodes: Mapping ``{graph_node_name: output_name}`` passed unchanged to
            ``create_feature_extractor``. Only the output name of the FIRST entry is
            returned by ``forward``; outputs for any further entries are dropped.
    """

    def __init__(self, pretrained_model, return_nodes: dict):
        super().__init__()
        # dicts keep insertion order, so this is the first output name the caller gave.
        self.out_name = list(return_nodes.values())[0]
        # Attribute name "model" is part of the state_dict keys; keep it stable.
        self.model = create_feature_extractor(pretrained_model, return_nodes=return_nodes)

    def forward(self, input: torch.Tensor):
        """Run the truncated backbone and return the tensor stored under ``self.out_name``."""
        return self.model(input)[self.out_name]

In [None]:
# Output layer: global average pooling, then a single-output linear layer followed by a sigmoid
class OutputLayer(nn.Module):
    """Classification head: global average pooling -> single-output linear -> sigmoid.

    Args:
        last_layer_features: Channel count ``C`` of the incoming feature map, i.e. the
            ``in_features`` of the linear layer.

    ``forward`` takes a 4-D ``(B, C, H, W)`` feature map and returns ``(B, 1)``
    probabilities in ``(0, 1)``.
    """

    def __init__(self, last_layer_features: int):
        super().__init__()
        # Registration order and attribute names are kept: they define state_dict keys,
        # and CAMGenerator reads `sigmoid_fc` from this module.
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.sigmoid_fc = nn.Linear(last_layer_features, 1)

    def forward(self, x: torch.Tensor):
        pooled = self.global_avg_pool(x)                 # (B, C, 1, 1)
        logits = self.sigmoid_fc(pooled.flatten(1))      # (B, 1)
        return logits.sigmoid()

In [None]:
# Cam generator
class CAMGenerator(nn.Module):
    """Class activation map (CAM) from a feature map and a linear head's weights.

    The CAM is the channel-wise weighted sum of the feature map, using the rows of
    ``sigmoid_fc.weight`` as channel weights (the bias is not used). It is bilinearly
    upsampled to the spatial size of the network input.

    Args:
        sigmoid_fc: The ``nn.Linear`` head; its ``in_features`` must equal the number
            of feature-map channels. It is stored as a submodule, so it is registered
            here as well as wherever else it is owned.
    """

    def __init__(self, sigmoid_fc: nn.Linear):
        super().__init__()
        self.sigmoid_fc = sigmoid_fc

    def forward(self, input: torch.Tensor, features: torch.Tensor):
        """Compute CAMs.

        Args:
            input: Network input ``(B, _, H_in, W_in)``; only its spatial size is used.
            features: Feature map ``(B, C, h, w)`` with ``C == sigmoid_fc.in_features``.

        Returns:
            ``(B, H_in, W_in)`` when the head has a single output unit (as before).
            Otherwise ``(B, K, H_in, W_in)``, one map per output unit.
        """
        # detach() keeps the CAM out of the head weights' autograd graph, as `.data` did,
        # but via the supported API, so the version counter still guards in-place updates.
        weights = self.sigmoid_fc.weight.detach()  # (K, C)
        # Explicit contraction over channels. The old "ijkl,ijkl->ikl" relied on
        # broadcasting the K axis against the batch axis, which breaks for K > 1.
        cam = torch.einsum("bchw,kc->bkhw", features, weights)  # (B, K, h, w)
        cam = F.interpolate(cam, size=input.shape[-2:], mode="bilinear", align_corners=False)
        if cam.shape[1] == 1:
            cam = cam.squeeze(1)  # (B, H_in, W_in): original single-output shape
        return cam

In [None]:

# Model built specifically on EfficientNet-B3: a sigmoid classification head plus an optional CAM output
# Note that return_nodes as well as last_layer_features are hardcoded for EfficientNet-B3
class EfficientNetB3_4down(nn.Module):
    """EfficientNet-B3 truncated at graph node ``features.5.1.block.2`` with a sigmoid head.

    Args:
        generate_cam: If True, ``forward`` also returns a class activation map.

    ``forward(input)`` returns the ``(B, 1)`` sigmoid output, or the tuple
    ``(sigmoid_output, cam)`` when ``generate_cam`` is set.

    Note: the head's ``nn.Linear`` is registered under both ``output_layer`` and
    ``cam_generator``, so ``state_dict()`` lists its weight and bias under both prefixes.
    """

    def __init__(self, generate_cam: bool = False):
        super().__init__()
        backbone = efficientnet_b3(weights="IMAGENET1K_V1")
        # Submodules are assigned in the original order so state_dict / parameters()
        # ordering (and therefore saved checkpoints and optimizer state) are unchanged.
        self.feature_extractor = FeatureExtractor(
            backbone, return_nodes={"features.5.1.block.2": "layerout"}
        )
        # 816 = channel count the head expects at the node above (hardcoded for B3).
        self.output_layer = OutputLayer(last_layer_features=816)
        # Re-use the head's linear layer so the CAM is built from the same weights.
        self.cam_generator = CAMGenerator(self.output_layer.sigmoid_fc)
        self.generate_cam = generate_cam

    def forward(self, input: torch.Tensor):
        features = self.feature_extractor(input)
        probs = self.output_layer(features)
        if self.generate_cam:
            return probs, self.cam_generator(input, features)
        return probs

In [None]:
# Sanity-check the output shapes of the classifier head and the CAM branch.
# Assertions make the check fail loudly instead of relying on a reader to eyeball prints.
BATCH_SIZE, HEIGHT, WIDTH = 10, 640, 640
torch.manual_seed(0)

model = EfficientNetB3_4down(generate_cam=True)
model.eval()
dummy_input = torch.randn(BATCH_SIZE, 3, HEIGHT, WIDTH)
with torch.no_grad():
    out, cam = model(dummy_input)

print("Output shape:", out.shape)
print("CAM shape:", cam.shape)
assert out.shape == (BATCH_SIZE, 1), out.shape
assert cam.shape == (BATCH_SIZE, HEIGHT, WIDTH), cam.shape
assert bool(((out >= 0) & (out <= 1)).all()), "sigmoid output must lie in [0, 1]"