In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import numpy as np

class HighFrequencyExtractor(nn.Module):
    """
    Extracts high-frequency features (edges) using fixed Sobel filters.

    Each input channel is filtered independently (depthwise conv,
    ``groups=in_channels``) with a horizontal and a vertical Sobel kernel.
    The output therefore has ``2 * in_channels`` channels ordered as
    ``[c0_sobel_x, c0_sobel_y, c1_sobel_x, c1_sobel_y, ...]``. Spatial size
    is preserved (3x3 kernel, padding 1, zero padding at the borders).

    This module is NOT trained: the kernel is frozen (``requires_grad=False``).

    Args:
        in_channels: number of channels of the input tensor.
    """
    def __init__(self, in_channels=1):
        super(HighFrequencyExtractor, self).__init__()
        self.in_channels = in_channels

        sobel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=np.float32)
        sobel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32)

        # A grouped conv stores weights as
        # (out_channels, in_channels // groups, kH, kW) = (2*in_channels, 1, 3, 3).
        # Repeating the (2, 1, 3, 3) [sobel_x, sobel_y] pair once per group
        # applies both kernels to every input channel independently.
        sobel_pair = np.stack([sobel_x, sobel_y])[:, np.newaxis]      # (2, 1, 3, 3)
        final_weights = np.tile(sobel_pair, (in_channels, 1, 1, 1))   # (2*in_channels, 1, 3, 3)

        self.conv = nn.Conv2d(in_channels, 2 * in_channels, kernel_size=3, padding=1,
                              groups=in_channels, bias=False)

        # Kept as a Parameter so the 'conv.weight' state_dict key is unchanged,
        # but frozen so optimizers never update it.
        self.conv.weight = nn.Parameter(torch.from_numpy(final_weights), requires_grad=False)

    def forward(self, x):
        """
        Apply the fixed Sobel filters.

        Args:
            x: tensor of shape (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, 2 * in_channels, H, W) with the x/y Sobel
            responses of each input channel.
        """
        return self.conv(x)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class HFANet(nn.Module):
    """
    Bi-temporal change-detection network with High-Frequency Attention (HFA).

    Both images go through the same smp U-Net encoder and the per-stage
    features are differenced. The first downsampled stage is re-weighted by an
    attention mask computed from the difference of fixed Sobel edge maps.
    The resulting difference pyramid is then decoded by the U-Net decoder and
    segmentation head.

    Args:
        encoder_name: smp encoder name (e.g. 'resnet34').
        classes: number of output channels of the segmentation head.
        pretrained: passed to smp as ``encoder_weights`` (e.g. 'imagenet' or None).
        in_channels: number of channels of each input image.
    """
    def __init__(self, encoder_name='resnet34', classes=1, pretrained='imagenet', in_channels=3):
        super(HFANet, self).__init__()

        # --- 1. Spatial Stream (Backbone) ---
        # Pre-built smp U-Net; its encoder is shared by both time steps.
        self.smp_model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=pretrained,
            in_channels=in_channels,
            classes=classes
        )

        # --- 2. High-Frequency Stream ---
        self.hf_extractor = HighFrequencyExtractor(in_channels=in_channels)

        # The HF difference map has 2 channels (Sobel x/y) per input channel and
        # stays at input resolution.
        hf_channels = in_channels * 2

        # Channel count of the first downsampled encoder stage (64 for resnet34).
        sp_channels = self.smp_model.encoder.out_channels[1]

        # --- 3. HFA Module ---
        # Turns the HF difference map into a 1-channel sigmoid attention mask at
        # half resolution (stride-2 conv), matching encoder stage 1 for ResNets.
        self.hf_attention_head = nn.Sequential(
            nn.Conv2d(hf_channels, hf_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(hf_channels),
            nn.ReLU(),
            nn.Conv2d(hf_channels, 1, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

        # Fuses (raw spatial diff) ++ (attended spatial diff) back to sp_channels.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(sp_channels * 2, sp_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(sp_channels),
            nn.ReLU()
        )

    def forward(self, x1, x2):
        """
        Predict the change map between two co-registered images.

        Args:
            x1, x2: images of identical shape, presumably (B, in_channels, H, W)
                with H, W compatible with the encoder's downsampling (TODO confirm).

        Returns:
            Raw output of the smp segmentation head (no activation is applied here).
        """
        # --- 1. High-Frequency (HF) stream: |Sobel(x1) - Sobel(x2)| ---
        hf_diff = torch.abs(self.hf_extractor(x1) - self.hf_extractor(x2))

        # --- 2. Spatial (semantic) stream: shared encoder, per-stage |f1 - f2| ---
        sp_features_t1 = self.smp_model.encoder(x1)
        sp_features_t2 = self.smp_model.encoder(x2)
        sp_diffs = [torch.abs(f1 - f2) for f1, f2 in zip(sp_features_t1, sp_features_t2)]

        # --- 3. High-Frequency Attention on encoder stage 1 ---
        # sp_diffs[0] is the input level; sp_diffs[1] is the first downsampled stage.
        sp_diff_s1 = sp_diffs[1]
        hf_attention_mask = self.hf_attention_head(hf_diff)
        if hf_attention_mask.shape[-2:] != sp_diff_s1.shape[-2:]:
            # Encoders whose stage 1 is not exactly ceil(H/2) would otherwise fail to broadcast.
            hf_attention_mask = F.interpolate(hf_attention_mask, size=sp_diff_s1.shape[-2:],
                                              mode='bilinear', align_corners=False)
        attended_sp_diff_s1 = sp_diff_s1 * hf_attention_mask

        # Let the network learn how much of the attended signal to use.
        fused_s1 = self.fusion_conv(torch.cat([sp_diff_s1, attended_sp_diff_s1], dim=1))

        # --- 4. Difference pyramid for the decoder ---
        # smp's UnetDecoder discards features[0] (the input-resolution level) before
        # decoding, so slot 0 must be kept. Omitting it drops fused_s1 and leaves
        # the decoder one skip connection short.
        features_diff = [sp_diffs[0], fused_s1, *sp_diffs[2:]]

        # --- 5. Decoder + head ---
        # NOTE(review): uses the *features calling convention of the smp version this
        # notebook was written against; newer smp releases pass a single list — confirm.
        decoder_output = self.smp_model.decoder(*features_diff)
        return self.smp_model.segmentation_head(decoder_output)

In [4]:
# Build the change-detection model and move it to the available device.
# The architecture is shown once via the cell's last expression; an extra
# print(model) would render the (very long) repr a second time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HFANet(encoder_name='resnet34', classes=1, pretrained='imagenet', in_channels=3).to(device)
model

HFANet(
  (smp_model): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): Batc

HFANet(
  (smp_model): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): Batc