# Model independent of input size

## Classification

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models

In [2]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
class ResNetVariableInput(nn.Module):
    """ResNet-50 classifier whose output shape does not depend on input size.

    The backbone ends in a global adaptive average pool, so whatever spatial
    size the last convolutional stage produces is collapsed to 1x1 before the
    classifier. The linear head therefore always receives the same number of
    features and the model emits ``(N, num_classes)`` logits.

    Args:
        num_classes: Number of logits produced by the replacement head.
        weights: Passed through unchanged to ``torchvision.models.resnet50``.
            The default loads the ImageNet checkpoint ("IMAGENET1K_V1");
            pass ``None`` to build the backbone without pretrained weights
            (e.g. for offline use or training from scratch).
    """

    def __init__(self, num_classes: int = 16, weights="IMAGENET1K_V1"):
        super().__init__()
        # Backbone; pretrained unless the caller overrides `weights`.
        self.backbone = models.resnet50(weights=weights)

        # Make sure the pooling in front of the classifier is global and
        # adaptive: (N, 2048, H/32, W/32) --> (N, 2048, 1, 1).
        # 2048 is the channel count of the last convolutional stage.
        # NOTE(review): recent torchvision releases appear to ship this layer
        # as AdaptiveAvgPool2d already; the explicit assignment keeps the
        # guarantee independent of the installed version.
        self.backbone.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Swap the stock classifier for a fresh head with `num_classes`
        # outputs, reusing the feature width of the layer it replaces.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's logits for the image batch ``x``."""
        return self.backbone(x)

In [4]:
# Smoke test: one model instance, several input resolutions ----
model = ResNetVariableInput(num_classes=16)
model = model.to(device)

# Dummy batches (2 RGB images each) that differ only in height/width:
# small square, rectangular, bigger. Each is drawn on the CPU and then moved.
x1, x2, x3 = (
    torch.randn(2, 3, height, width).to(device)
    for height, width in ((128, 128), (256, 512), (480, 640))
)

# Forward passes, in the same order as the inputs
y1, y2, y3 = (model(batch) for batch in (x1, x2, x3))

# Every line should end in [2, 16] regardless of the input resolution
print("Output shapes:")
print(
    *(
        f"{list(batch.shape)} -> {list(logits.shape)}"
        for batch, logits in zip((x1, x2, x3), (y1, y2, y3))
    ),
    sep="\n",
)

Output shapes:
[2, 3, 128, 128] -> [2, 16]
[2, 3, 256, 512] -> [2, 16]
[2, 3, 480, 640] -> [2, 16]
