In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

class SimpleTimmModel(nn.Module):
    """
    A simple model using a timm backbone for feature extraction,
    followed by custom layers to produce a spatial output map.
    Handles input channel adaptation and spatial resizing via interpolation.

    Pipeline: ``backbone.forward_features`` -> 1x1 conv down to a single
    channel -> bilinear ``F.interpolate`` to ``target_output_hw``.

    Args:
        timm_model_name (str): Model name passed to ``timm.create_model``.
        in_channels (int): Number of channels of the input tensor.
        target_output_hw (tuple[int, int]): Spatial size (H, W) of the output map.
        pretrained (bool): Forwarded to ``timm.create_model``. Defaults to True;
            set False to build the architecture without loading weights.
        probe_input_hw (tuple[int, int]): Spatial size (H, W) of the random
            tensor pushed through the backbone once at construction to discover
            its output feature-map shape. Defaults to (1000, 70).

    Attributes set at construction:
        backbone_out_channels, backbone_out_h, backbone_out_w: shape of the
        probe's ``forward_features`` output (H/W are for ``probe_input_hw`` only).

    Raises:
        AttributeError: If the backbone contains no ``nn.Conv2d`` (e.g. ViT).
        RuntimeError: If probing ``forward_features`` fails or does not return
            a 4-D (N, C, H, W) tensor.
    """
    def __init__(self, timm_model_name, in_channels, target_output_hw,
                 pretrained=True, probe_input_hw=(1000, 70)):
        super().__init__()
        self.timm_model_name = timm_model_name
        self.in_channels = in_channels
        self.target_output_hw = target_output_hw

        # --- 1. Load the backbone ---
        # num_classes=0 asks timm for the backbone without its classification head.
        print(f"Loading timm model: {timm_model_name} with num_classes=0 and pretrained={pretrained}")
        try:
            self.backbone = timm.create_model(timm_model_name, pretrained=pretrained, num_classes=0)
            print("Backbone loaded successfully.")
        except Exception as e:
            print(f"Error loading timm model {timm_model_name}: {e}")
            raise

        # --- 2. Make the first conv accept `in_channels` inputs ---
        self._adapt_input_channels(in_channels)

        # --- 3. Discover the backbone's output feature-map shape ---
        self._probe_backbone_output(probe_input_hw)

        # --- 4. Head: 1x1 conv reduces channels to 1 ---
        # Spatial resizing to `target_output_hw` is done with F.interpolate in
        # forward(), so it has no module here.
        self.channel_reducer = nn.Conv2d(self.backbone_out_channels, 1, kernel_size=1)
        print(f"Added channel reducer layer: {self.backbone_out_channels} -> 1 channels.")
        print(f"Model head will reduce channels to 1 and interpolate to target spatial size {self.target_output_hw}.")

    @staticmethod
    def _find_first_conv(module):
        """Return ``(qualified_name, conv)`` of the first ``nn.Conv2d`` in
        ``module.named_modules()`` order, or ``(None, None)`` if there is none."""
        for name, child in module.named_modules():
            if isinstance(child, nn.Conv2d):
                return name, child
        return None, None

    @staticmethod
    def _build_adapted_conv(original_conv, in_channels):
        """Return a new ``nn.Conv2d`` with the geometry of ``original_conv`` but
        ``in_channels`` input channels.

        Weights of the input channels shared with ``original_conv`` (the first
        ``min(in_channels, original_conv.in_channels)``) are copied; any extra
        input channels get zero weights, so initially they do not change the
        layer's output. The bias, if present, is copied.

        NOTE(review): ``groups`` is not carried over (assumes the usual
        groups=1 stem conv).
        """
        new_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            dilation=original_conv.dilation,
            bias=original_conv.bias is not None,
            padding_mode=original_conv.padding_mode,
        ).to(device=original_conv.weight.device, dtype=original_conv.weight.dtype)

        # Slicing by the overlap (instead of a fixed 3) also works when the new
        # layer has fewer inputs than the pretrained one.
        n_shared = min(in_channels, original_conv.in_channels)
        print(f"Copying weights for initial {n_shared} input channels from pretrained model.")
        with torch.no_grad():
            new_conv.weight[:, :n_shared].copy_(original_conv.weight[:, :n_shared])
            # Zero-init extra channels (empty slice, i.e. no-op, if there are none).
            new_conv.weight[:, n_shared:].zero_()
            if original_conv.bias is not None:
                new_conv.bias.copy_(original_conv.bias)
        return new_conv

    def _adapt_input_channels(self, in_channels):
        """Replace the backbone's first ``nn.Conv2d`` with one taking
        ``in_channels`` inputs (no-op if it already matches).

        Raises:
            AttributeError: If the backbone has no ``nn.Conv2d`` at all.
        """
        first_conv_name, original_first_conv = self._find_first_conv(self.backbone)
        if original_first_conv is None:
            raise AttributeError(f"Could not find *any* Conv2d layer in {self.timm_model_name}. Simple adaptation for input channels is not possible for this model type (e.g., ViT).")

        print(f"Found first convolutional layer: '{first_conv_name}' with {original_first_conv.in_channels} input channels.")

        if in_channels == original_first_conv.in_channels:
            print("Input channels match backbone. No adaptation needed for the first conv layer.")
            return

        print(f"Adapting first convolutional layer from {original_first_conv.in_channels} to {in_channels} input channels.")
        new_first_conv = self._build_adapted_conv(original_first_conv, in_channels)

        # 'conv1' -> parent '' (the backbone itself); 'stem.0' -> parent 'stem'.
        # get_submodule('') returns the module itself.
        parent_name, _, child_name = first_conv_name.rpartition('.')
        parent_module = self.backbone.get_submodule(parent_name)
        setattr(parent_module, child_name, new_first_conv)
        print(f"Replaced '{first_conv_name}' layer.")

    def _probe_backbone_output(self, probe_input_hw):
        """Run one random tensor through ``backbone.forward_features`` and store
        the output's channel count and spatial size on ``self``.

        The backbone is put in eval mode for the probe and its previous mode is
        restored afterwards: ``torch.no_grad()`` alone does NOT stop BatchNorm
        layers in train mode from updating their running statistics, which
        would overwrite pretrained statistics with values from random noise.

        Raises:
            RuntimeError: If the forward pass fails or the output is not 4-D.
        """
        print("Determining backbone output feature map shape using a dummy tensor...")
        # Match the device/dtype of the backbone's parameters.
        reference_param = next(self.backbone.parameters())
        device = reference_param.device
        dummy_input = torch.randn(1, self.in_channels, *probe_input_hw,
                                  device=device, dtype=reference_param.dtype)
        print(f"Dummy input tensor shape: {dummy_input.shape} on device: {device}")

        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                backbone_features = self.backbone.forward_features(dummy_input)
            if backbone_features.ndim != 4:
                raise ValueError(f"expected a 4-D (N, C, H, W) feature map, got shape {tuple(backbone_features.shape)}")

            self.backbone_out_channels = backbone_features.shape[1]
            # H/W depend on the backbone's downsampling and on probe_input_hw;
            # they are informational only (forward() interpolates to the target).
            self.backbone_out_h = backbone_features.shape[2]
            self.backbone_out_w = backbone_features.shape[3]
            print(f"Backbone `forward_features` output shape: {backbone_features.shape}")
            print(f"Determined backbone output channels: {self.backbone_out_channels}")
        except Exception as e:
            print(f"Error determining backbone output shape using `forward_features`. Does '{self.timm_model_name}' have this method or does it work with input shape {dummy_input.shape}?")
            print(f"Error details: {e}")
            # The channel count is required to build the head, so fail loudly.
            raise RuntimeError("Failed to determine backbone output feature map shape.") from e
        finally:
            self.backbone.train(was_training)

    def forward(self, x):
        """
        Forward pass of the simple model.
        Args:
            x (torch.Tensor): Input tensor of shape (N, in_channels, H, W).
        Returns:
            torch.Tensor: Output tensor of shape (N, 1, target_output_h, target_output_w).
        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected input tensor to have 4 dimensions (N, C, H, W), but got {x.ndim}")

        # (N, C_in, H, W) -> (N, backbone_out_channels, h, w); forward_features
        # returns the map before any global pooling.
        features = self.backbone.forward_features(x)

        # (N, backbone_out_channels, h, w) -> (N, 1, h, w)
        output = self.channel_reducer(features)

        # (N, 1, h, w) -> (N, 1, *target_output_hw) via bilinear resize.
        return F.interpolate(
            output,
            size=self.target_output_hw,
            mode='bilinear',
            align_corners=False
        )


In [6]:

# --- Example Usage (using your cfg structure) ---
# NOTE: `cfg` is NOT defined in this cell; it must come from an earlier cell, e.g.:
# class Config:
#      def __init__(self):
#          self.backbone = 'resnet18' # Or 'efficientnet_b0', etc.
#          self.batch_size = 4
#          self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
#
# cfg = Config()

IN_CHANNELS = 5              # channels per training sample
INPUT_HW = (1000, 70)        # (H, W) of one training sample
TARGET_OUTPUT_HW = (70, 70)  # (H, W) of the desired output map

print("Instantiating the simple model...")
try:
    model = SimpleTimmModel(
        timm_model_name=cfg.backbone,
        in_channels=IN_CHANNELS,
        target_output_hw=TARGET_OUTPUT_HW,
    ).to(cfg.device)
    print("Model instantiated successfully.")
    print(f"Model summary:\n{model}")
except Exception as e:
    # Best-effort: report and leave `model` as None so the check below is skipped.
    print(f"Failed to instantiate the model: {e}")
    model = None

# Check the model's output shape with a random batch on the model's device.
if model is not None:
    try:
        dummy_input = torch.randn(cfg.batch_size, IN_CHANNELS, *INPUT_HW, device=cfg.device)
        print(f"\nTesting forward pass with dummy input shape: {dummy_input.shape}")

        # eval() so BatchNorm running statistics are not updated by random data
        # (torch.no_grad() alone does not prevent that); restore the mode after.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                dummy_output = model(dummy_input)
        finally:
            model.train(was_training)

        print(f"Dummy model output shape: {dummy_output.shape}")

        expected_output_shape = (cfg.batch_size, 1, *TARGET_OUTPUT_HW)
        if dummy_output.shape != expected_output_shape:
            print(f"Warning: Model output shape {dummy_output.shape} does not match expected {expected_output_shape}.")
        else:
            print("Model output shape matches expected shape.")

    except Exception as e:
        print(f"Error during dummy model forward pass: {e}")

Instantiating the simple model...
Loading timm model: resnet18 with num_classes=0 and pretrained=True
Backbone loaded successfully.
Found first convolutional layer: 'conv1' with 3 input channels.
Adapting first convolutional layer from 3 to 5 input channels.
Copying weights for initial 3 input channels from pretrained model.
Replaced 'conv1' layer.
Determining backbone output feature map shape using a dummy tensor...
Dummy input tensor shape: torch.Size([1, 5, 1000, 70]) on device: cpu
Backbone `forward_features` output shape: torch.Size([1, 512, 32, 3])
Determined backbone output channels: 512
Added channel reducer layer: 512 -> 1 channels.
Model head will reduce channels to 1 and interpolate to target spatial size (70, 70).
Model instantiated successfully.
Model summary:
SimpleTimmModel(
  (backbone): ResNet(
    (conv1): Conv2d(5, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [8]:
# Display the model's module tree (the cell's last expression is rendered as output).
model

SimpleTimmModel(
  (backbone): ResNet(
    (conv1): Conv2d(5, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act1): ReLU(inplace=True)
        (aa): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padd