<a href="https://colab.research.google.com/github/harshitajain523/TransCMFD_Copy_Move_Forgery_Detection/blob/main/TransCMFD_Baseline_Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

# Report the runtime so results can be tied to a specific PyTorch / CUDA setup.
print(f"PyTorch version: {torch.__version__}")
cuda_is_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_is_available}")
if cuda_is_available:
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

# Mount Google Drive (Colab only) so files stored there are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')


PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device name: NVIDIA A100-SXM4-40GB
Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

#  Helper Modules
class ConvBlock(nn.Module):
    """
    Conv2d -> GroupNorm (32 groups) -> optional in-place ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution; GroupNorm with 32 groups requires
            this to be divisible by 32.
        padding, kernel_size, stride: forwarded to ``nn.Conv2d`` (defaults keep the spatial size).
        with_nonlinearity: apply ReLU after normalisation when True.
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.gn = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        normalized = self.gn(self.conv(x))
        return self.relu(normalized) if self.with_nonlinearity else normalized
class Bridge(nn.Module):
    """
    Two stacked ConvBlocks (3x3 conv + GroupNorm + ReLU each) applied to the deepest encoder feature.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both ConvBlocks.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [ConvBlock(in_channels, out_channels), ConvBlock(out_channels, out_channels)]
        self.bridge = nn.Sequential(*stages)

    def forward(self, x):
        out = x
        for stage in self.bridge:
            out = stage(out)
        return out

# ----------------------------------------------------------------------------- #
#upblock
class UpBlockForUNetWithResNet50(nn.Module):
    """
    One decoder stage: upsample ``up_x``, concatenate it with the skip ``down_x`` along channels,
    then apply a ConvBlock.

    Args:
        in_channels_after_concat: channels of the concatenation (upsampled channels + skip channels).
        out_channels: channels produced by the ConvBlock.
        up_conv_in_channels, up_conv_out_channels: only used when upsampling_method is "conv_transpose".
        upsampling_method: "bilinear" (parameter-free x2 resize) or "conv_transpose" (2x2, stride 2).

    Raises:
        ValueError: for any other upsampling_method.
    """
    def __init__(self, in_channels_after_concat, out_channels, up_conv_in_channels, up_conv_out_channels,
                 upsampling_method="bilinear"):
        super().__init__()
        self.upsampling_method = upsampling_method

        if upsampling_method == "conv_transpose":
            upsampler = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            upsampler = nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=False)
        else:
            raise ValueError("Unsupported upsampling_method")
        self.upsample_layer = upsampler

        self.conv_block = ConvBlock(in_channels_after_concat, out_channels)

    def forward(self, up_x, down_x):
        upsampled = self.upsample_layer(up_x)

        # x2 upsampling can miss the skip's size by a pixel (odd sizes); snap to the skip's H x W.
        skip_h, skip_w = down_x.shape[2], down_x.shape[3]
        if (upsampled.shape[2], upsampled.shape[3]) != (skip_h, skip_w):
            upsampled = F.interpolate(upsampled, size=(skip_h, skip_w), mode='bilinear', align_corners=False)

        return self.conv_block(torch.cat((upsampled, down_x), dim=1))


# ----------------------------------------------------------------------------- #

#CNN Encoder
class Encoder(nn.Module):
    """
    CNN encoder built from a torchvision ResNet-50 backbone.

    Returns the deepest feature map (ResNet ``layer4`` output, 2048 channels) together with the
    intermediate activations used as U-Net skip connections.

    Args:
        pretrained: load the ImageNet ``IMAGENET1K_V1`` weights. This is the same checkpoint the legacy
            ``pretrained=True`` flag selected, so the default behaviour is unchanged. Pass False to build
            a randomly initialised backbone without downloading anything.
    """
    DEPTH = 6  # NOTE(review): not referenced anywhere in the visible code

    def __init__(self, pretrained=True):
        super().__init__()
        resnet = self._build_resnet50(pretrained)

        # Stem: conv1 (7x7, stride 2) -> bn1 -> relu -> maxpool (stride 2), i.e. 1/4 of the input resolution.
        self.input_block = nn.Sequential(*list(resnet.children()))[:4]

        # Residual stages layer1..layer4 are the only nn.Sequential children after the stem;
        # avgpool and fc are dropped.
        self.down_blocks = nn.ModuleList(
            [stage for stage in list(resnet.children())[4:] if isinstance(stage, nn.Sequential)]
        )

    @staticmethod
    def _build_resnet50(pretrained):
        """Create a ResNet-50, using the weights-enum API when available (torchvision >= 0.13)."""
        weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
        if weights_enum is None:
            # Older torchvision only understands the legacy boolean flag.
            return torchvision.models.resnet50(pretrained=pretrained)
        # `pretrained=` is deprecated since torchvision 0.13 and emits a warning; IMAGENET1K_V1 is the
        # checkpoint it used to map to.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return torchvision.models.resnet50(weights=weights)

    def forward(self, x):
        """
        Args:
            x: input image batch, (B, 3, H, W).

        Returns:
            (x, pre_pools) where ``x`` is the layer4 output and ``pre_pools`` maps
            "layer_0" -> the raw input, "layer_1" -> stem output, and
            "layer_2"/"layer_3"/"layer_4" -> outputs of ResNet layer1/layer2/layer3.
        """
        pre_pools = dict()
        pre_pools["layer_0"] = x  # raw input; not consumed by the Decoder in this file

        x = self.input_block(x)
        pre_pools["layer_1"] = x

        last_index = len(self.down_blocks) - 1
        for i, block in enumerate(self.down_blocks):
            x = block(x)
            if i < last_index:  # the last stage's output is returned instead (it feeds the bridge)
                pre_pools[f"layer_{i+2}"] = x

        return x, pre_pools
# ----------------------------------------------------------------------------- #
class Decoder(nn.Module):
    """
    U-Net style decoder that turns the fused deepest feature map into per-pixel class logits.

    Channel bookkeeping assumes the ResNet-50 skips produced by ``Encoder``:
    pre_pools["layer_4"] = 1024 ch, "layer_3" = 512 ch, "layer_2" = 256 ch, "layer_1" = 64 ch.
    For a 512x512 input the spatial path is 16x16 -> 32x32 -> 64x64 -> 128x128 -> 256x256 -> 512x512.

    Args:
        n_classes: channels of the final 1x1 convolution (2 for the binary tampering mask).
    """

    def __init__(self, n_classes=2):
        super().__init__()
        # Up-block 1: deepest feature (2048 ch) + layer_4 skip (1024 ch) = 3072 -> 1024 ch.
        self.up_block1 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=2048 + 1024, out_channels=1024,
            up_conv_in_channels=2048, up_conv_out_channels=2048
        )
        # Up-block 2: 1024 ch + layer_3 skip (512 ch) = 1536 -> 512 ch.
        self.up_block2 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=1024 + 512, out_channels=512,
            up_conv_in_channels=1024, up_conv_out_channels=1024
        )
        # Up-block 3: 512 ch + layer_2 skip (256 ch) = 768 -> 256 ch.
        self.up_block3 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=512 + 256, out_channels=256,
            up_conv_in_channels=512, up_conv_out_channels=512
        )
        # Up-block 4: 256 ch + upsampled layer_1 skip (64 ch) = 320 -> 128 ch.
        self.up_block4 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=256 + 64, out_channels=128,
            up_conv_in_channels=256, up_conv_out_channels=256
        )

        # Final x2 upsampling and 128 -> 64 channel ConvBlock.
        self.last_upsample = nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=False)
        self.last_conv = ConvBlock(in_channels=128, out_channels=64, with_nonlinearity=True)

        # 1x1 conv producing the per-pixel class logits.
        self.out = nn.Conv2d(64, n_classes, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x, pre_pools):
        """
        Args:
            x: deepest feature map, 2048 channels (16x16 for a 512x512 input).
            pre_pools: skip features keyed "layer_1".."layer_4" as produced by ``Encoder.forward``.

        Returns:
            Logits of shape (B, n_classes, 4 * H2, 4 * W2), where H2 x W2 is the spatial size of
            pre_pools["layer_2"] (512x512 for a 512x512 input).
        """
        x = self.up_block1(x, pre_pools["layer_4"])  # 1024 ch at layer_4's resolution
        x = self.up_block2(x, pre_pools["layer_3"])  # 512 ch at layer_3's resolution
        x = self.up_block3(x, pre_pools["layer_2"])  # 256 ch at layer_2's resolution

        # layer_1 (the stem output) has the same resolution as layer_2, so it is upsampled to twice the
        # current decoder resolution to serve as the skip for up_block4 (which doubles x).
        # The target is derived from x rather than hard-coded to 256x256, so inputs other than 512x512
        # still decode to full resolution; for 512x512 inputs this is exactly 256x256.
        skip_size = (2 * x.shape[2], 2 * x.shape[3])
        upsampled_skip_layer1 = F.interpolate(
            pre_pools["layer_1"],
            size=skip_size,
            mode='bilinear',
            align_corners=False
        )
        x = self.up_block4(x, upsampled_skip_layer1)  # 128 ch at skip_size

        x = self.last_upsample(x)  # doubles the spatial size back to the input resolution
        x = self.last_conv(x)      # 128 -> 64 channels

        return self.out(x)

# ----------------------------------------------------------------------------- #
class FeatureSimilarityModule(nn.Module):
    """
    Feature Similarity Module (FSM) implementing block-wise Pearson correlation and percentile pooling.

    The input feature map is cut into non-overlapping ``block_size x block_size`` spatial blocks (all
    channels included). Each block is flattened, and the Pearson correlation between every pair of blocks
    is computed. For each block, the ``K`` largest correlations are kept in descending order. If there are
    fewer than ``K`` blocks, the remaining slots are zero.

    Args:
        K: number of top similarity scores kept per block (default 32).
        block_size: side length of the square blocks (default 16).
    """
    def __init__(self, K=32, block_size=16):
        super(FeatureSimilarityModule, self).__init__()
        self.K = K  # Number of top similarity scores to select per block
        self.block_size = block_size  # Side length of the non-overlapping square blocks

    def _pearson_correlation_coefficient(self, B_i, B_j):
        """
        Calculates the Pearson correlation coefficient between two flattened feature blocks.
        Returns 0 when either block has zero (population) standard deviation.
        """
        B_i = B_i.float()
        B_j = B_j.float()

        mean_B_i = torch.mean(B_i)
        mean_B_j = torch.mean(B_j)

        std_B_i = torch.std(B_i, unbiased=False)
        std_B_j = torch.std(B_j, unbiased=False)

        if std_B_i == 0 or std_B_j == 0:
            return torch.tensor(0.0, device=B_i.device)

        normalized_B_i = (B_i - mean_B_i) / std_B_i
        normalized_B_j = (B_j - mean_B_j) / std_B_j

        correlation = torch.dot(normalized_B_i, normalized_B_j) / B_i.numel()
        return correlation

    @staticmethod
    def _pairwise_pearson(blocks_flat):
        """
        Batched Pearson correlation between every pair of rows.

        Args:
            blocks_flat: tensor of shape (batch, num_blocks, features).

        Returns:
            float32 tensor (batch, num_blocks, num_blocks). Pairs involving a row with zero population
            std score 0, the same as ``_pearson_correlation_coefficient``.
        """
        blocks = blocks_flat.float()
        mean = blocks.mean(dim=-1, keepdim=True)
        std = blocks.std(dim=-1, unbiased=False, keepdim=True)
        valid = std != 0
        # Divide constant rows by 1 instead of 0 and then zero them out. torch.where over a 0-division
        # would yield NaN gradients even in the unselected branch.
        safe_std = torch.where(valid, std, torch.ones_like(std))
        normalized = (blocks - mean) / safe_std * valid
        return torch.matmul(normalized, normalized.transpose(1, 2)) / blocks.shape[-1]

    def forward(self, feature_map):
        """
        Args:
            feature_map: tensor (batch, channels, H, W). In TransCMFDBaseline this is the upsampled
                transformer output, (B, 512, 256, 256).

        Returns:
            float32 tensor (batch, num_blocks, K) with num_blocks = (H // block_size) * (W // block_size);
            blocks are enumerated row-major over the block grid. Rows/columns that do not fill a whole
            block are ignored.
        """
        batch_size, channels, H, W = feature_map.shape
        bs = self.block_size
        num_blocks = (H // bs) * (W // bs)  # 256 blocks for a 256x256 input with 16x16 blocks

        percentile_scores = torch.zeros(batch_size, num_blocks, self.K, device=feature_map.device)
        if num_blocks == 0:
            return percentile_scores

        # (B, C, nH, nW, bs, bs) -> (B, nH, nW, C, bs, bs) -> (B, num_blocks, C * bs * bs)
        blocks_flat = (
            feature_map.unfold(2, bs, bs).unfold(3, bs, bs)
            .permute(0, 2, 3, 1, 4, 5).contiguous()
            .view(batch_size, num_blocks, -1)
        )

        # One batched matmul replaces the former O(B * num_blocks^2) Python loop. That loop made one
        # host sync per pair (65,536 per sample for 256 blocks).
        similarity_matrix = self._pairwise_pearson(blocks_flat)

        # Percentile pooling: keep the K largest scores per block, in descending order.
        k = min(self.K, num_blocks)
        top_scores = torch.topk(similarity_matrix, k, dim=-1, largest=True, sorted=True).values
        percentile_scores[..., :k] = top_scores
        return percentile_scores

# Adaptive Transformer components
# Custom adaptive multi-head self-attention: the transformer's core contribution as described in the paper
class AdaptiveMultiHeadSelfAttention(nn.Module):
    """
    Custom Adaptive Multi-Head Self-Attention (AdaptiveMSA) module.
    Implements the Dual-Path Adaptive Attention Mechanism (DPAAM) as described in Algorithm 1.

    Standard scaled dot-product multi-head attention produces Head'. Two element-wise affine maps are
    applied to it, M1 = S1 * Head' + b1 and M2 = S2 * Head' + b2. They are combined as
    sigmoid(epsilon) * M1 + sigmoid(delta) * M2 and then passed through an output projection.

    Args:
        embed_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads.

    Raises:
        ValueError: if embed_dim is not divisible by num_heads.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # d_k = d_v = d_model / h

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        # Linear transformations for Query, Key, Value
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Learnable DPAAM parameters: per-channel scales S1/S2 (init 1), biases b1/b2 (init 0),
        # and scalar path weights epsilon/delta (init 0.5, squashed by sigmoid in forward).
        self.S1 = nn.Parameter(torch.ones(embed_dim))
        self.S2 = nn.Parameter(torch.ones(embed_dim))

        self.b1 = nn.Parameter(torch.zeros(embed_dim))
        self.b2 = nn.Parameter(torch.zeros(embed_dim))

        self.epsilon = nn.Parameter(torch.tensor(0.5))
        self.delta = nn.Parameter(torch.tensor(0.5))
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        """
        Args:
            query: (B, L_q, embed_dim).
            key, value: (B, L_kv, embed_dim); L_kv may differ from L_q.

        Returns:
            (B, L_q, embed_dim).
        """
        batch_size, seq_len, embed_dim = query.size()
        # Keys/values are reshaped with their own length. Using the query's length broke any call where
        # the lengths differ.
        kv_len = key.size(1)

        # Project and split into heads: (B, L, E) -> (B, num_heads, L, head_dim)
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attention_weights = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_probs = F.softmax(attention_weights, dim=-1)
        Head_prime = torch.matmul(attention_probs, V)

        # Merge heads back to (B, L_q, E) for the element-wise DPAAM maps.
        Head_prime_combined = Head_prime.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        # DPAAM: two per-channel affine paths over the same attention output.
        M1_Head = self.S1 * Head_prime_combined + self.b1
        M2_Head = self.S2 * Head_prime_combined + self.b2

        # Sigmoid constrains the path weights to (0, 1).
        epsilon_val = torch.sigmoid(self.epsilon)
        delta_val = torch.sigmoid(self.delta)

        # NOTE(review): a weighted sum of two per-channel affine maps is itself one per-channel affine map.
        adaptive_heads_combined = epsilon_val * M1_Head + delta_val * M2_Head

        # Final output projection (W_0)
        return self.out_proj(adaptive_heads_combined)

#AdaptiveTransformerLayer (UPDATED to use custom AdaptiveMSA)
class AdaptiveTransformerLayer(nn.Module):
    """
    One post-norm transformer encoder layer: adaptive self-attention followed by a GELU MLP with a 4x
    hidden expansion. Each sub-layer is wrapped as LayerNorm(input + sublayer(input)).

    Args:
        dim: token embedding size (must be divisible by num_heads).
        num_heads: number of attention heads.
    """
    def __init__(self, dim, num_heads):
        super().__init__()
        self.adaptive_mhsa = AdaptiveMultiHeadSelfAttention(embed_dim=dim, num_heads=num_heads)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        hidden_dim = 4 * dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        """Map tokens of shape (batch, seq_len, dim) to a tensor of the same shape."""
        attended = self.norm1(x + self.adaptive_mhsa(x, x, x))
        return self.norm2(attended + self.mlp(attended))

#AdaptiveTransformerEncoder (Uses AdaptiveTransformerLayer)
class AdaptiveTransformerEncoder(nn.Module):
    """
    The Adaptive Transformer Encoder, stacking multiple AdaptiveTransformerLayers.
    Handles tokenization and positional embedding for CNN features.

    A 1x1 conv projects the CNN feature map to ``dim`` channels. Every spatial position then becomes
    one token, and a learnable positional embedding is added before the layer stack.

    Args:
        dim: transformer embedding size.
        num_heads: attention heads per layer.
        num_layers: number of stacked layers.
        in_channels: channels of the incoming CNN feature map (default 2048, ResNet-50 layer4).
        num_tokens: tokens per image the positional embedding is sized for. The default is 16 * 16, the
            C4 grid of a 512x512 input.

    Raises:
        ValueError (in forward): if the feature map does not yield exactly ``num_tokens`` tokens.
    """
    def __init__(self, dim, num_heads, num_layers, in_channels=2048, num_tokens=16 * 16):
        super().__init__()
        self.dim = dim  # Transformer's embedding dimension
        self.num_layers = num_layers

        self.layers = nn.ModuleList([AdaptiveTransformerLayer(dim, num_heads) for _ in range(num_layers)])

        # Tokenization: project the CNN encoder's output channels to the transformer's `dim`.
        self.proj_to_dim = nn.Conv2d(in_channels, dim, kernel_size=1)

        # Learnable positional embedding, one vector per token position.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))

    def forward(self, feature_map_c4):
        """
        Args:
            feature_map_c4: (B, in_channels, H, W) with H * W == num_tokens.

        Returns:
            Tokens of shape (B, H * W, dim), in row-major spatial order.
        """
        x = self.proj_to_dim(feature_map_c4)  # (B, dim, H, W)
        x = x.flatten(2).permute(0, 2, 1)     # (B, H * W, dim)

        if x.shape[1] != self.pos_embedding.shape[1]:
            raise ValueError(f"Positional embedding sequence length mismatch. Expected {self.pos_embedding.shape[1]}, got {x.shape[1]}")

        x = x + self.pos_embedding

        for layer in self.layers:
            x = layer(x)
        return x


#Main TransCMFDBaseline Model

class TransCMFDBaseline(nn.Module):
    """
    The complete TransCMFD model architecture.

    Pipeline: ResNet-50 encoder -> bridge on the deepest (C4) feature. In parallel, an adaptive transformer
    runs over the C4 tokens, and its token grid is upsampled and passed to the Feature Similarity Module.
    The FSM scores are lifted to 2048 channels and added to the bridge output. The CNN decoder then
    produces the mask logits.

    The hard-coded 16x16 token grid and 256x256 FSM input correspond to a 512x512 input image.

    Args:
        n_classes: output channels of the decoder (2 for the binary tampering mask).
    """
    def __init__(self, n_classes=2):
        super().__init__()
        # CNN backbone and the two-ConvBlock bridge on its deepest output.
        self.encoder = Encoder()
        self.bridge = Bridge(2048, 2048)

        # Adaptive transformer over the C4 tokens (2 layers, reduced for the time budget).
        self.adaptive_transformer_encoder = AdaptiveTransformerEncoder(
            dim=512,
            num_heads=8,
            num_layers=2,
        )

        # Token grid (B, 512, 16, 16) -> image-like feature F (B, 512, 256, 256) consumed by the FSM.
        self.transformer_output_to_fsm_input = nn.Upsample(size=(256, 256), mode='bilinear', align_corners=False)

        self.fsm = FeatureSimilarityModule()

        # 1x1 conv lifting the K similarity channels to 2048 so they can be added to the bridge output.
        self.fsm_output_fusion_transform = nn.Conv2d(self.fsm.K, 2048, kernel_size=1)

        self.decoder = Decoder(n_classes=n_classes)

    def forward(self, x):
        """Map an image batch (B, 3, 512, 512) to mask logits (B, n_classes, 512, 512)."""
        c4_features, skips = self.encoder(x)
        bridged = self.bridge(c4_features)  # (B, 2048, 16, 16)

        tokens = self.adaptive_transformer_encoder(c4_features)  # (B, 256, 512)
        token_grid = tokens.transpose(1, 2).contiguous().view(
            tokens.size(0), self.adaptive_transformer_encoder.dim, 16, 16
        )
        similarity = self.fsm(self.transformer_output_to_fsm_input(token_grid))  # (B, 256, K)

        # The FSM's 256 blocks are ordered row-major, so they fold back into a 16x16 grid of K channels.
        similarity_grid = similarity.transpose(1, 2).contiguous().view(
            similarity.size(0), self.fsm.K, 16, 16
        )
        fused = bridged + self.fsm_output_fusion_transform(similarity_grid)

        return self.decoder(fused, skips)



#Dummy Data Testing
#This block will test the model's forward pass and print its structure.
if __name__ == '__main__':
    # Smoke test: build the model, print its structure and run one forward pass on random data.
    # Fall back to CPU instead of failing when no GPU is attached.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Initializing TransCMFDBaseline model...")
    model = TransCMFDBaseline().to(device)
    # eval(): in train mode the ResNet BatchNorm layers would overwrite their pretrained running
    # statistics with statistics of the random dummy batch, even under torch.no_grad().
    model.eval()
    print(f"TransCMFDBaseline model successfully loaded on {device}.")

    # Print the model summary to verify the custom modules used in this architecture.
    print("\n--- Model Architecture Summary ---")
    print(model)
    print("----------------------------------\n")

    # Dummy input: batch size 2, 3 RGB channels, 512x512 resolution.
    dummy_input = torch.rand((2, 3, 512, 512)).to(device)
    print(f"Dummy input shape: {dummy_input.shape}")

    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(dummy_input)

    print(f"Output shape: {output.shape}")
    assert output.shape == (2, 2, 512, 512), f"Unexpected output shape: {tuple(output.shape)}"

Initializing TransCMFDBaseline model...


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 226MB/s]


TransCMFDBaseline model successfully loaded on cuda.

--- Model Architecture Summary ---
TransCMFDBaseline(
  (encoder): Encoder(
    (input_block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (down_blocks): ModuleList(
      (0): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     