In [80]:
from torchvision.models import ResNet50_Weights
import torchinfo
import torch
import torch.nn as nn
import torchvision.models as models

# Define a function to create a ResNet model with custom stages
def resnet_custom_stages(stages, blocks=3, *, weights="IMAGENET1K_V2", **kwargs):
    """Build a truncated ResNet-50-style backbone keeping only the first ``stages`` stages.

    The model is a torchvision ``ResNet`` built from ``Bottleneck`` blocks. Every
    stage after ``stages`` and the classifier ``fc`` are replaced by
    ``nn.Identity``, so the forward pass returns the pooled, flattened features of
    the last kept stage (e.g. shape ``(N, 256)`` for ``stages=1``).

    Args:
        stages: Number of residual stages to keep; an int in ``[1, 4]``.
        blocks: Number of Bottleneck blocks in the last kept stage; an int in
            ``[1, 16]``. ``None`` keeps the standard ResNet-50 count for that
            stage (3, 4, 6, 3). Earlier stages always use the standard counts.
        weights: ResNet-50 weights to copy in; anything accepted by
            ``torchvision.models.resnet50(weights=...)`` (a ``ResNet50_Weights``
            member or its name). ``None`` skips loading and keeps random init.
        **kwargs: Forwarded to ``torchvision.models.resnet.ResNet``.

    Returns:
        The constructed ``ResNet`` module. Blocks beyond the pretrained count
        (when ``blocks`` exceeds the standard count) stay randomly initialised.

    Raises:
        ValueError: If ``stages`` or ``blocks`` is not an int in its allowed range.
    """
    # Type-check first so non-int inputs raise ValueError rather than a TypeError
    # from the range comparison.
    if not isinstance(stages, int) or not 1 <= stages <= 4:
        raise ValueError("Number of stages must be between 1 and 4.")
    if blocks is not None and (not isinstance(blocks, int) or not 1 <= blocks <= 16):
        raise ValueError("Number of blocks must be between 1 and 16.")

    # Standard ResNet-50 block counts per stage; stages we drop get 0 here and
    # are swapped for nn.Identity below, so their modules are never used.
    resnet50_layers = [3, 4, 6, 3]
    layers = resnet50_layers[:stages] + [0] * (4 - stages)
    if blocks is not None:
        layers[stages - 1] = blocks

    model = models.resnet.ResNet(models.resnet.Bottleneck, layers, **kwargs)
    # Drop the classifier so the model returns pooled features.
    model.fc = nn.Identity()

    # Replace every stage after the last kept one with a pass-through.
    for stage_idx in range(stages + 1, 5):
        setattr(model, f"layer{stage_idx}", nn.Identity())

    if weights is not None:
        pretrained_state = models.resnet50(weights=weights).state_dict()
        own_state = model.state_dict()
        # Copy every tensor shared with ResNet-50, including the stem (conv1/bn1)
        # so the pretrained stages see the features they were trained on. Keys
        # absent here (identity stages, extra blocks, fc) or with a different
        # shape (e.g. architecture-changing kwargs) keep their initial values.
        compatible = {
            name: param
            for name, param in pretrained_state.items()
            if name in own_state and own_state[name].shape == param.shape
        }
        model.load_state_dict(compatible, strict=False)

    return model

# Create a ResNet model with only the first stage
resnet_stage = resnet_custom_stages(stages=1)

# Smoke-test the output shape in eval mode: a train-mode forward pass would
# overwrite the pretrained BatchNorm running statistics with those of random noise.
resnet_stage.eval()
with torch.no_grad():
    print(resnet_stage(torch.randn(1, 3, 224, 224)).shape)
resnet_stage.train()  # restore the freshly-built module's default mode

torchinfo.summary(resnet_stage)

torch.Size([1, 256])


Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─Bottleneck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─Bottleneck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

In [66]:
import torch.nn as nn
import torchvision.models as models

# Full ResNet-50 for comparison. The summary only reports layer types and
# parameter counts, so pretrained weights are not needed here.
resnet50_reference = models.resnet50(weights=None)
torchinfo.summary(resnet50_reference)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            9,408
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─Bottleneck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─Bottleneck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               