In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm # We import timm directly
import numpy as np

# --- Helper Module 1: High-Frequency Extractor ---
class HighFrequencyExtractor(nn.Module):
    """Fixed (non-trainable) per-channel Sobel edge filter bank.

    Every input channel is filtered independently with the horizontal and
    the vertical 3x3 Sobel kernel, so the output has ``2 * in_channels``
    channels ordered ``[c0_x, c0_y, c1_x, c1_y, ...]``. The convolution uses
    a padding of 1, so the spatial size of the input is preserved.

    Args:
        in_channels: Number of channels of the input image (3 for RGB).
    """

    def __init__(self, in_channels=3):
        super().__init__()
        sobel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=np.float32)
        sobel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32)
        sobel_weights_xy = np.stack([sobel_x, sobel_y], axis=0)  # (2, 3, 3)

        # A conv with groups == in_channels expects weights shaped
        # (out_channels, in_channels // groups, kH, kW) == (2 * C, 1, 3, 3).
        # Each group (one input channel) owns two consecutive output
        # channels, so the (x, y) kernel pair is simply repeated C times.
        # A dense (2 * C, C, 3, 3) tensor here would make forward() fail with
        # a channel-count mismatch for any in_channels > 1.
        final_weights = np.tile(
            sobel_weights_xy[:, np.newaxis, :, :], (in_channels, 1, 1, 1)
        )

        self.conv = nn.Conv2d(
            in_channels,
            2 * in_channels,
            kernel_size=3,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        # Frozen: the Sobel kernels are constants, not learnable filters.
        self.conv.weight = nn.Parameter(
            torch.from_numpy(final_weights), requires_grad=False
        )

    def forward(self, x):
        """Return the stacked Sobel-x / Sobel-y responses of ``x``."""
        return self.conv(x)

# --- Helper Module 2: Standard Convolution Block ---
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

# --- Helper Module 3: Standard Decoder Block ---
class DecoderBlock(nn.Module):
    """Upsample a feature map, merge it with a skip connection, then refine.

    The incoming map is upsampled 2x by a transposed convolution that also
    halves its channel count, aligned to the skip map's height/width,
    concatenated after the skip map along the channel axis, and passed
    through a ``DoubleConv``.

    Args:
        in_channels: Channels of the map to be upsampled.
        skip_channels: Channels of the skip-connection map.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        upsampled_channels = in_channels // 2
        self.up = nn.ConvTranspose2d(
            in_channels, upsampled_channels, kernel_size=2, stride=2
        )
        self.conv = DoubleConv(upsampled_channels + skip_channels, out_channels)

    def forward(self, x_up, x_skip):
        """Return the refined fusion of upsampled ``x_up`` and ``x_skip``."""
        upsampled = self.up(x_up)

        # Height/width gap between the skip map and the upsampled map. A
        # positive gap zero-pads; a negative gap makes F.pad crop instead.
        gap_h = x_skip.size(2) - upsampled.size(2)
        gap_w = x_skip.size(3) - upsampled.size(3)
        pad_left = gap_w // 2
        pad_top = gap_h // 2
        # F.pad order for the last two dims: (left, right, top, bottom).
        upsampled = F.pad(
            upsampled,
            [pad_left, gap_w - pad_left, pad_top, gap_h - pad_top],
        )

        merged = torch.cat([x_skip, upsampled], dim=1)
        return self.conv(merged)

In [None]:
class HFANet_timm(nn.Module):
    """Two-image network that decodes absolute feature differences.

    Both inputs pass through the same timm backbone (shared weights) and the
    same fixed Sobel extractor. The per-stage absolute differences of the
    backbone features feed a hand-built decoder; the stage-1 difference is
    additionally re-weighted by a sigmoid mask derived from the Sobel
    difference before it is used as a skip connection.

    Args:
        encoder_name: timm model name used for the backbone.
        classes: Number of output channels of the final 1x1 convolution.
        pretrained: Passed to ``timm.create_model``.
        in_channels: Channels of each input image.

    NOTE(review): the code indexes backbone stages 0..4 and its original
    notes assume ResNet-style reductions (1/2, 1/4, ...) - confirm for other
    encoders.
    """

    def __init__(self, encoder_name='resnet34', classes=1, pretrained=True, in_channels=3):
        super().__init__()

        # --- 1. Spatial stream: timm backbone returning per-stage features ---
        self.backbone = timm.create_model(
            encoder_name,
            pretrained=pretrained,
            features_only=True,
            in_chans=in_channels,
        )
        # Per-stage channel counts; per the original notes resnet34 gives
        # [64, 64, 128, 256, 512].
        stage_ch = self.backbone.feature_info.channels()

        # --- 2. High-frequency stream: two Sobel maps per input channel ---
        self.hf_extractor = HighFrequencyExtractor(in_channels=in_channels)
        edge_channels = 2 * in_channels

        # --- 3. High-frequency attention for the stage-1 skip connection ---
        # Two stride-2 convs shrink the full-resolution edge difference by a
        # factor of 4 and squash it to a single-channel mask in (0, 1).
        self.hf_attention_head = nn.Sequential(
            nn.Conv2d(edge_channels, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 1, kernel_size=3, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        )

        # 1x1 fusion of [raw stage-1 diff, attended stage-1 diff] back to the
        # stage-1 channel width.
        self.fusion_s1 = nn.Sequential(
            nn.Conv2d(stage_ch[1] * 2, stage_ch[1], kernel_size=1, bias=False),
            nn.BatchNorm2d(stage_ch[1]),
            nn.ReLU(inplace=True),
        )

        # --- 4. Decoder, deepest stage first ---
        self.dec_layer4 = DecoderBlock(stage_ch[4], stage_ch[3], 256)
        self.dec_layer3 = DecoderBlock(256, stage_ch[2], 128)
        self.dec_layer2 = DecoderBlock(128, stage_ch[1], 64)  # skip is the fused stage-1 map
        self.dec_layer1 = DecoderBlock(64, stage_ch[0], 64)

        # Last 2x upsampling followed by the 1x1 prediction head.
        self.final_up = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(32, classes, kernel_size=1)

    def forward(self, x1, x2):
        """Return the ``final_conv`` output for the image pair ``(x1, x2)``.

        No activation is applied after the final convolution.
        """
        # --- 1. Difference of the Sobel responses of the two images ---
        edges_a = self.hf_extractor(x1)
        edges_b = self.hf_extractor(x2)
        edge_change = torch.abs(edges_a - edges_b)

        # --- 2. Shared backbone on both images, then per-stage |difference| ---
        feats_a = self.backbone(x1)
        feats_b = self.backbone(x2)
        stage_diffs = [torch.abs(feats_a[k] - feats_b[k]) for k in range(5)]
        d0, d1, d2, d3, d4 = stage_diffs  # d4 is the bottleneck

        # --- 3. Gate the stage-1 difference with the edge-derived mask ---
        # assumes the mask and d1 share spatial size (1/4 resolution for
        # ResNet-style backbones) - TODO confirm for other encoders
        gate = self.hf_attention_head(edge_change)
        gated_d1 = d1 * gate
        fused_d1 = self.fusion_s1(torch.cat([d1, gated_d1], dim=1))

        # --- 4. Decode from the bottleneck up through the skip connections ---
        y = self.dec_layer4(d4, d3)
        y = self.dec_layer3(y, d2)
        y = self.dec_layer2(y, fused_d1)  # attended skip enters here
        y = self.dec_layer1(y, d0)

        y = self.final_up(y)
        return self.final_conv(y)