In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) module.

    Runs five parallel branches over the same input -- a 1x1 convolution,
    three dilated 3x3 convolutions and an image-level (global average pooling)
    branch -- concatenates them along the channel axis, projects the result
    back to ``out_channels`` with a 1x1 convolution and applies dropout.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by every branch and by the
            final projection.
        dilations: Exactly four dilation rates. Entry 0 belongs to the 1x1
            branch, where dilation has no effect; entries 1-3 are used as both
            dilation and padding of the 3x3 branches, which keeps the spatial
            size of the input.
        dropout: Dropout probability applied to the projected output.

    Raises:
        ValueError: If ``dilations`` does not contain exactly four entries.

    Note:
        The image-level branch applies ``BatchNorm2d`` to a 1x1 map. In
        training mode a batch with a single sample therefore offers only one
        value per channel, which PyTorch's batch norm rejects; use a batch
        size of at least 2 when training.
    """
    def __init__(self, in_channels, out_channels, dilations=(1, 6, 12, 18), dropout=0.5):
        super(ASPP, self).__init__()
        dilations = tuple(dilations)
        if len(dilations) != 4:
            raise ValueError(
                "ASPP expects exactly 4 dilation rates, got {}".format(len(dilations))
            )

        def conv_bn_relu(kernel_size, dilation=1):
            # padding == dilation keeps the spatial size for a 3x3 kernel;
            # a 1x1 kernel needs no padding at all.
            padding = dilation if kernel_size == 3 else 0
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1,
                          padding=padding, dilation=dilation, bias=False),
                nn.BatchNorm2d(out_channels), nn.ReLU())

        # Attribute names (and creation order) are kept so existing
        # checkpoints and seeded initialisation remain compatible.
        self.aspp1 = conv_bn_relu(kernel_size=1)
        self.aspp2 = conv_bn_relu(kernel_size=3, dilation=dilations[1])
        self.aspp3 = conv_bn_relu(kernel_size=3, dilation=dilations[2])
        self.aspp4 = conv_bn_relu(kernel_size=3, dilation=dilations[3])

        # Image-level context: pool to 1x1, then project.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU())

        # Fuses the five concatenated branches back to out_channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU())
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the fused multi-scale features for ``x``.

        The output has ``out_channels`` channels and the same spatial size as
        the branch outputs (which equals the input size, given the paddings
        chosen in ``__init__``).
        """
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # Broadcast the 1x1 image-level descriptor back to the branch resolution.
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv1(x)
        return self.dropout(x)

In [None]:
class DifferenceAttentionModule(nn.Module):
    """
    Difference-attention block.

    Predicts a single-channel sigmoid mask from the channel-wise concatenation
    of the two temporal feature maps and uses it to re-weight their
    element-wise absolute difference.
    """
    def __init__(self, in_channels):
        super().__init__()
        # Gate: 3x3 conv over both inputs -> BN -> ReLU -> 1x1 conv to a
        # single channel -> sigmoid, giving a mask with values in (0, 1).
        gate_layers = [
            nn.Conv2d(2 * in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, 1, kernel_size=1, bias=False),
            nn.Sigmoid(),
        ]
        self.attention_gate = nn.Sequential(*gate_layers)

    def forward(self, f_t1, f_t2):
        """Return ``|f_t1 - f_t2|`` scaled by the learned attention mask."""
        stacked = torch.cat((f_t1, f_t2), dim=1)
        mask = self.attention_gate(stacked)

        abs_change = (f_t1 - f_t2).abs()
        # The one-channel mask broadcasts across every feature channel.
        return abs_change * mask


class HDANet_Head(nn.Module):
    """
    The full head of the HDANet, which corresponds to the
    'ASPP', 'Difference attention', and 'Upsample' blocks in the diagram.

    For every feature level the same ASPP module is applied to both temporal
    inputs, a difference-attention module turns the pair into a refined
    difference map, and all maps are resized to the first level's resolution,
    concatenated, fused, upsampled and classified.

    Args:
        in_channels_list: Channel count of each backbone feature level, in the
            order the features are passed to ``forward``.
        out_channels: Channel width used by the ASPP / attention / fusion layers.
        num_classes: Number of output channels (logits) of the classifier.
        upsample_factor: Scale factor of the final upsampling step, used when
            ``forward`` is called without ``output_size``. The default of 4
            assumes the first feature level is at 1/4 of the input resolution.

    Raises:
        ValueError: If ``in_channels_list`` is empty.
    """
    def __init__(self, in_channels_list, out_channels=256, num_classes=1, upsample_factor=4):
        super(HDANet_Head, self).__init__()
        in_channels_list = list(in_channels_list)
        if not in_channels_list:
            raise ValueError("in_channels_list must contain at least one feature level")

        self.upsample_factor = upsample_factor

        # One ASPP and one difference-attention module per feature level.
        self.aspp_modules = nn.ModuleList()
        self.dam_modules = nn.ModuleList()

        for in_c in in_channels_list:
            # ASPP maps every level to the same `out_channels` width ...
            self.aspp_modules.append(ASPP(in_c, out_channels))
            # ... so every attention module sees `out_channels`-wide inputs.
            self.dam_modules.append(DifferenceAttentionModule(out_channels))

        # Fuses the concatenated per-level difference maps.
        self.fuse_conv = nn.Sequential(
            nn.Conv2d(len(in_channels_list) * out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.final_classifier = nn.Sequential(
            nn.Conv2d(out_channels, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, num_classes, kernel_size=1)
        )

    def forward(self, t1_features, t2_features, output_size=None):
        """Compute change-map logits from two lists of backbone features.

        Args:
            t1_features: Sequence of feature tensors for the first image, one
                per level, largest resolution first.
            t2_features: Matching sequence for the second image.
            output_size: Optional ``(height, width)`` of the returned logits.
                When omitted, the fused map is upsampled by ``upsample_factor``.

        Returns:
            Logits with ``num_classes`` channels.

        Raises:
            ValueError: If either feature list does not have one entry per
                configured level.
        """
        num_levels = len(self.aspp_modules)
        if len(t1_features) != num_levels or len(t2_features) != num_levels:
            raise ValueError(
                "expected {} feature levels per input, got {} and {}".format(
                    num_levels, len(t1_features), len(t2_features))
            )

        # Every level is resized to the first (largest) feature map.
        target_size = t1_features[0].shape[2:]

        refined_diff_maps = []
        levels = zip(self.aspp_modules, self.dam_modules, t1_features, t2_features)
        for aspp, dam, feat_t1, feat_t2 in levels:
            # The same ASPP instance processes both time steps (shared weights).
            f_t1 = aspp(feat_t1)
            f_t2 = aspp(feat_t2)

            diff_map = dam(f_t1, f_t2)
            refined_diff_maps.append(
                F.interpolate(diff_map, size=target_size, mode='bilinear', align_corners=False)
            )

        fused_diff = self.fuse_conv(torch.cat(refined_diff_maps, dim=1))

        # Upsample before classifying so the classifier runs at output resolution.
        if output_size is not None:
            out = F.interpolate(fused_diff, size=tuple(output_size), mode='bilinear', align_corners=False)
        else:
            out = F.interpolate(fused_diff, scale_factor=self.upsample_factor, mode='bilinear', align_corners=False)

        return self.final_classifier(out)

In [None]:
import timm

class HDANet(nn.Module):
    """Siamese change-detection network: shared HRNet backbone + HDANet head.

    Both input images go through the same ``timm`` ``hrnet_w18`` feature
    extractor; the two resulting feature lists are handed to ``HDANet_Head``,
    which produces the change-map logits.

    Args:
        n_channels: Number of channels of each input image.
        n_classes: Number of output channels (logits).
        pretrained: Whether ``timm`` should load pretrained backbone weights.
    """
    def __init__(self, n_channels=3, n_classes=1, pretrained=True):
        super(HDANet, self).__init__()

        # 1. HRNet backbone. `features_only=True` makes it return a list of
        #    feature maps instead of classification logits. It is built once:
        #    the channel counts come from `feature_info`, so no dummy forward
        #    pass (or second model instantiation) is needed.
        self.backbone = timm.create_model(
            'hrnet_w18',
            pretrained=pretrained,
            features_only=True,
            in_chans=n_channels
        )

        # Channel count of every returned feature map, as reported by timm.
        # NOTE(review): the number of maps and their strides are decided by
        # timm; its HRNet feature wrapper may also expose a stride-2 stem
        # feature in addition to the 1/4..1/32 streams. Inspect
        # `self.backbone.feature_info.reduction()` and restrict the outputs
        # with `out_indices` if only the four streams are wanted.
        feature_channels = self.backbone.feature_info.channels()

        # 2. Custom HDANet head, one ASPP/attention pair per feature level.
        self.head = HDANet_Head(
            in_channels_list=feature_channels,
            out_channels=256, # Internal processing dim
            num_classes=n_classes
        )

    def forward(self, x1, x2):
        """Return change-map logits for the image pair ``(x1, x2)``.

        Args:
            x1: First image batch.
            x2: Second image batch; must have the same shape as ``x1``.

        Returns:
            Logits with ``n_classes`` channels and the spatial size of ``x1``.

        Raises:
            ValueError: If the two inputs differ in shape.
        """
        if x1.shape != x2.shape:
            raise ValueError(
                "x1 and x2 must have the same shape, got {} and {}".format(
                    tuple(x1.shape), tuple(x2.shape))
            )

        # Both time steps use the very same backbone instance (shared weights).
        t1_features = self.backbone(x1)
        t2_features = self.backbone(x2)

        out = self.head(t1_features, t2_features)

        # The head upsamples by a fixed factor relative to the first feature
        # map; if that does not land on the input resolution, resize so the
        # logits always line up with the input pixels.
        if out.shape[-2:] != x1.shape[-2:]:
            out = F.interpolate(out, size=x1.shape[-2:], mode='bilinear', align_corners=False)
        return out

In [None]:
# --- Test the HDANet ---
model = HDANet(n_channels=3, n_classes=1, pretrained=True)
model.eval() # Eval mode: BatchNorm uses running statistics, Dropout is disabled

# Create two dummy input tensors (batch, channels, height, width).
# 256 is a multiple of 32, so every backbone stride divides the input evenly.
dummy_image_t1 = torch.randn(2, 3, 256, 256)
dummy_image_t2 = torch.randn(2, 3, 256, 256)

# Pass BOTH images through the model. This is only a shape check, so skip
# autograd bookkeeping to save memory and time.
with torch.no_grad():
    output_map = model(dummy_image_t1, dummy_image_t2)

# The output is the "change map" logits
print("HDANet Output Shape:", output_map.shape)

# Expected when the logits are at input resolution:
# HDANet Output Shape: torch.Size([2, 1, 256, 256])