In [3]:
import logging
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.ops import misc as misc_nn_ops

In [4]:
# Cell 2: Set up logging
# One-time root-logger setup: INFO and above, each record prefixed with
# timestamp, logger name and level. basicConfig does nothing if the root
# logger already has handlers, so re-running this cell is harmless.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Notebook-level logger used by the cells below.
logger = logging.getLogger(name=__name__)


#### Medication Image Feature Extractor Class

In [None]:
class MedicationFeatureExtractor(nn.Module):
    """
    Feature extractor for the Parenteral Medication Recognition System.

    Uses ResNet50 with a Feature Pyramid Network (FPN) to extract multi-scale
    features from medication images. Each selected ResNet stage becomes one FPN
    level with 256 output channels. Every level also gets its own
    ``BatchNorm2d`` (``self.normalization``) and its own 3x3 conv + ReLU
    (``self.enhancement``). Both are keyed ``"0"``, ``"1"``, ... in the same
    order as ``returned_layers``.
    """

    def __init__(
        self,
        pretrained: bool = True,
        returned_layers: Optional[List[int]] = None,
        extra_blocks=None,
        norm_layer=misc_nn_ops.FrozenBatchNorm2d,
        trainable_layers: int = 3
    ):
        """
        Initialize the feature extractor.

        Args:
            pretrained: Whether to use pretrained ResNet50 weights.
            returned_layers: ResNet stages (values 1-4, strictly increasing) to
                return features from. Defaults to ``[1, 2, 3, 4]``.
            extra_blocks: Extra blocks to add to the FPN; passed unchanged to
                ``BackboneWithFPN``.
            norm_layer: Normalization layer used inside the ResNet backbone.
            trainable_layers: Number of trainable ResNet stages counted from
                the end (``layer4`` first, ``conv1`` last). Must be in [0, 5];
                5 leaves every backbone parameter trainable.

        Raises:
            ValueError: If ``trainable_layers`` is outside [0, 5], or if
                ``returned_layers`` is empty, not strictly increasing, or
                contains a value outside 1-4.
        """
        super().__init__()

        # None sentinel instead of a shared mutable default list.
        if returned_layers is None:
            returned_layers = [1, 2, 3, 4]
        returned_layers = list(returned_layers)

        if (
            not returned_layers
            or min(returned_layers) < 1
            or max(returned_layers) > 4
            or returned_layers != sorted(set(returned_layers))
        ):
            raise ValueError(
                "returned_layers must be a non-empty, strictly increasing list "
                f"of values in 1-4, got {returned_layers}"
            )
        if not 0 <= trainable_layers <= 5:
            raise ValueError(
                f"trainable_layers must be in [0, 5], got {trainable_layers}"
            )

        logger.info("Initializing MedicationFeatureExtractor with ResNet50FPN backbone")

        # Load the ResNet50 trunk. `norm_layer` is forwarded so the documented
        # argument actually takes effect; previously it was silently ignored.
        resnet = models.resnet50(pretrained=pretrained, norm_layer=norm_layer)

        # Freeze every backbone parameter outside the last `trainable_layers` stages.
        if trainable_layers < 5:
            layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
            for name, parameter in resnet.named_parameters():
                if not any(name.startswith(layer) for layer in layers_to_train):
                    parameter.requires_grad_(False)

            logger.info(f"Freezing layers except: {layers_to_train}")

        # Map ResNet stage names to the FPN output keys "0", "1", ...
        return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

        # ResNet50 stage k outputs 256 * 2**(k-1) channels
        # (layer1=256, layer2=512, layer3=1024, layer4=2048).
        # The channel count is computed per selected stage, so that e.g.
        # returned_layers=[2, 3, 4] yields [512, 1024, 2048] rather than the
        # first three entries of the full list.
        in_channels_list = [256 * 2 ** (k - 1) for k in returned_layers]

        out_channels = 256

        # Create the backbone with FPN.
        # NOTE(review): when extra_blocks is None, torchvision's BackboneWithFPN
        # appends a LastLevelMaxPool. That adds a "pool" output level, which has
        # no entry in the normalization/enhancement dicts below.
        self.backbone = BackboneWithFPN(
            resnet,
            return_layers=return_layers,
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks
        )

        num_levels = len(returned_layers)

        # Per-level feature normalization, keyed like the return_layers values.
        self.normalization = nn.ModuleDict({
            str(i): nn.BatchNorm2d(out_channels) for i in range(num_levels)
        })

        # Per-level feature enhancement: a 3x3 conv that keeps spatial size
        # (padding=1) and the channel count, followed by ReLU.
        self.enhancement = nn.ModuleDict({
            str(i): nn.Sequential(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            ) for i in range(num_levels)
        })

        # Initialize weights of the enhancement layers.
        self._initialize_weights()

        logger.info(f"MedicationFeatureExtractor initialized with {len(returned_layers)} feature levels")

    def _initialize_weights(self):
        """
        Initialize weights for the enhancement layers.

        Conv2d weights get Kaiming-normal initialization (``mode='fan_out'``,
        ``nonlinearity='relu'``) and their biases are zeroed. Any BatchNorm2d
        inside ``self.enhancement`` gets weight 1 and bias 0.
        """
        for m in self.enhancement.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


The Feature Pyramid Network (FPN) extracts multi-scale features through the following components of the code:

1. **Creation of the FPN structure** happens in these lines:
```python
# Create the backbone with FPN
self.backbone = BackboneWithFPN(
    resnet, 
    return_layers=return_layers,
    in_channels_list=[256, 512, 1024, 2048][:len(returned_layers)],
    out_channels=256,
    extra_blocks=extra_blocks
)
```

This single call builds the multi-scale backbone. Torchvision's `BackboneWithFPN` collects the intermediate ResNet outputs named in `return_layers` and feeds them into an FPN. It works as follows:

2. **Feature extraction at multiple scales** is achieved by:
   - `return_layers` mapping specifies which ResNet layers to extract features from (typically layers 1-4)
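   - `in_channels_list` gives the channel counts of the selected ResNet stages (`layer1`–`layer4` output 256, 512, 1024 and 2048 channels respectively). It must contain exactly one entry per layer in `returned_layers`, matching those layers.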
   - Each layer represents a different scale/resolution of features (earlier layers have higher resolution but less semantic information)

3. **FPN structure** adds a top-down pathway with lateral connections that:
   - Takes high-level features from the deeper layers
   - Upsamples them and merges them with the laterally connected features from shallower layers
   - Produces a feature hierarchy in which every level combines high spatial resolution with strong semantic information

4. **Standardized output channels** are ensured by:
   - `out_channels=256` makes all feature maps have the same number of channels (256)
   - This standardization allows for consistent processing downstream

During the forward pass, the backbone returns these multi-scale features as a dictionary:
```python
# Extract features using the backbone
features = self.backbone(x)
```

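Each key in the `features` dictionary (`"0"`, `"1"`, … as assigned through `return_layers`) corresponds to a different scale level, giving a multi-scale representation of the input image. When `extra_blocks` is `None`, torchvision's `BackboneWithFPN` appends a `LastLevelMaxPool` block, so the dictionary also contains an extra, lower-resolution `"pool"` level.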

In [None]:
# Cell 4: Weight initialization method
def _initialize_weights(self):
    """Initialize weights for enhancement layers."""
    for m in self.enhancement.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)