In [93]:
import torch
import torch.nn as nn
import torchvision.models as models

# Define a function to create a ResNet model with custom stages
def resnet_custom_stages(model=None, stages=4, blocks=3, **kwargs):
    """Build a Bottleneck ResNet truncated after ``stages`` residual stages.

    The stem (conv1/bn1/relu/maxpool) is always kept. Stages before the last
    kept one use the standard ResNet-50 block counts ``[3, 4, 6, 3]``. The last
    kept stage uses ``blocks`` Bottleneck blocks. Every stage after it is
    replaced by ``nn.Identity``, and so is the classifier head ``fc``. The
    network therefore returns the pooled features of the last kept stage
    (256 channels when ``stages == 1``).

    Args:
        model: Optional source module whose ``state_dict`` is copied into the
            new network with ``strict=False``, e.g. a pretrained
            ``torchvision.models.resnet50``. Keys that do not exist in the
            truncated network (dropped stages, ``fc``, surplus blocks) are
            ignored. Tensors whose shapes do not match still make
            ``load_state_dict`` raise. ``None`` (default) keeps the random
            initialisation.
        stages: Number of residual stages to keep; an int in ``[1, 4]``.
        blocks: Number of Bottleneck blocks in the last kept stage; an int in
            ``[1, 6]``, or ``None`` to use the standard ResNet-50 count for that
            stage.
        **kwargs: Forwarded unchanged to ``torchvision.models.resnet.ResNet``.

    Returns:
        The truncated ``torchvision.models.resnet.ResNet`` instance.

    Raises:
        ValueError: If ``stages`` or ``blocks`` is not an int or is out of range.
    """
    # Type check first: a non-numeric value must not reach the ``<=``
    # comparisons, which would raise TypeError instead of ValueError.
    if not isinstance(stages, int) or not 1 <= stages <= 4:
        raise ValueError("Number of stages must be between 1 and 4.")
    if blocks is not None and (not isinstance(blocks, int) or not 1 <= blocks <= 6):
        raise ValueError("Number of blocks must be between 1 and 6.")

    # Standard ResNet-50 block counts per stage.
    default_blocks = [3, 4, 6, 3]
    # Stages after the last kept one get a placeholder count of 0; they are
    # swapped for nn.Identity right after construction.
    layers = default_blocks[:stages] + [0] * (4 - stages)
    layers[stages - 1] = default_blocks[stages - 1] if blocks is None else blocks

    net = models.resnet.ResNet(models.resnet.Bottleneck, layers, **kwargs)
    # Drop the classification head so the network acts as a feature extractor.
    net.fc = nn.Identity()

    # Disable every stage after the last kept one (layer{stages+1} .. layer4).
    for stage_idx in range(stages + 1, 5):
        setattr(net, f"layer{stage_idx}", nn.Identity())

    # Copy weights from the optional source model. strict=False lets the
    # truncated net ignore the source's extra keys (later stages, fc, surplus
    # blocks).
    if model is not None:
        net.load_state_dict(model.state_dict(), strict=False)

    return net

# torchinfo is needed below; import it here so this cell does not depend on
# hidden kernel state from a later/earlier cell.
import torchinfo

# Build the ResNet stem plus only the first residual stage (layer1), with 2
# Bottleneck blocks in that stage. The `model` argument is passed explicitly as
# None: no pretrained weight source is used.
# NOTE(review): the variable name says "first two stages" but stages=1 keeps
# only layer1; the name is kept in case later cells reference it — TODO rename.
resnet_first_two_stages = resnet_custom_stages(None, stages=1, blocks=2)

# NOTE: passing `pretrained=True` does not work here. Extra keyword arguments
# are forwarded to torchvision's ResNet constructor, which has no such parameter.

# Call the module (not .forward) so registered hooks run; no_grad avoids
# building an autograd graph for this shape check.
with torch.no_grad():
    resnet_first_two_stages_out = resnet_first_two_stages(torch.randn(1, 3, 224, 224))
print(resnet_first_two_stages_out.shape)
torchinfo.summary(resnet_first_two_stages)

torch.Size([1, 256])


Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─Bottleneck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─Bottleneck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

In [95]:
from torchvision.models import ResNet50_Weights
from torchinfo import torchinfo
from feature_extraction.feature_extractors.resnet.stage_feature_extractor import StageFeatureExtractor

# Summarise a StageFeatureExtractor built on an ImageNet-pretrained ResNet-50
# (IMAGENET1K_V2 weights), keeping one stage.
torchinfo.summary(
    StageFeatureExtractor(
        models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2),
        num_stages=1,
    )
)

Layer (type:depth-idx)                        Param #
StageFeatureExtractor                         --
├─ResNet: 1-1                                 --
│    └─Conv2d: 2-1                            (9,408)
│    └─BatchNorm2d: 2-2                       (128)
│    └─ReLU: 2-3                              --
│    └─MaxPool2d: 2-4                         --
│    └─Sequential: 2-5                        --
│    │    └─Bottleneck: 3-1                   (75,008)
│    │    └─Bottleneck: 3-2                   (70,400)
│    │    └─Bottleneck: 3-3                   (70,400)
│    └─Sequential: 2-6                        --
│    │    └─Bottleneck: 3-4                   (379,392)
│    │    └─Bottleneck: 3-5                   (280,064)
│    │    └─Bottleneck: 3-6                   (280,064)
│    │    └─Bottleneck: 3-7                   (280,064)
│    └─Sequential: 2-7                        --
│    │    └─Bottleneck: 3-8                   (1,512,448)
│    │    └─Bottleneck: 3-9                   (1,1