<a href="https://colab.research.google.com/github/harshitajain523/TransCMFD_Copy_Move_Forgery_Detection/blob/main/TransCMFD_Baseline_Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")


#mounting google drive
from google.colab import drive
drive.mount('/content/drive')


PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA device name: NVIDIA L4
Mounted at /content/drive


In [2]:
'''import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

# Display Image
from IPython.display import Image

#  Helper Modules
class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.gn = nn.GroupNorm(32, out_channels)
        self.relu = nn.ReLU(inplace=True) ##inplace true for inplace changes
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.gn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x
class Bridge(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(  ##nn.Sequential chains operations in a pipeline
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)

# ----------------------------------------------------------------------------- #
#upblock
class UpBlockForUNetWithResNet50(nn.Module):
    #Consists of Upsample ->(Concatenation with skip connection)->ConvBlock.
    def __init__(self, in_channels_after_concat, out_channels, up_conv_in_channels, up_conv_out_channels,
                 upsampling_method="bilinear"):
        super().__init__()

        self.upsampling_method = upsampling_method
        if upsampling_method == "conv_transpose":
            self.upsample_layer = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear": #bilinear interpolation
            self.upsample_layer = nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=False)
        else:
            raise ValueError("Unsupported upsampling_method")

        self.conv_block = ConvBlock(in_channels_after_concat, out_channels)

    def forward(self, up_x, down_x):
        up_x = self.upsample_layer(up_x)

        if up_x.shape[2] != down_x.shape[2] or up_x.shape[3] != down_x.shape[3]:
            up_x = F.interpolate(up_x, size=(down_x.shape[2], down_x.shape[3]), mode='bilinear', align_corners=False)

        x = torch.cat([up_x, down_x], 1)
        x = self.conv_block(x)
        return x


# ----------------------------------------------------------------------------- #

#CNN Encoder
class Encoder(nn.Module):
    #CNN Encoder based on a pre-trained ResNet-50, using the channels for resnet 50 itself.
    DEPTH = 6

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)##using pretrained weights

        self.input_block = nn.Sequential(*list(resnet.children()))[:4]

        down_blocks = []
        for bottleneck_stage in list(resnet.children())[4:]:
            if isinstance(bottleneck_stage, nn.Sequential):
                down_blocks.append(bottleneck_stage)
        self.down_blocks = nn.ModuleList(down_blocks)

    def forward(self, x):
        pre_pools = dict()
        pre_pools["layer_0"] = x #original input image for the very last decoder stage

        x = self.input_block(x)
        pre_pools["layer_1"] = x

        for i, block in enumerate(self.down_blocks):
            x = block(x)
            if i < len(self.down_blocks) - 1: #ALL but the last one (layer4 output is going to bridge)
                pre_pools[f"layer_{i+2}"] = x

        return x, pre_pools
# ----------------------------------------------------------------------------- #
class Decoder(nn.Module):

    def __init__(self, n_classes=2):
        super().__init__()
        # Decoder stages: in_channels_after_concat refers to channels after upsample + skip concat.
        # Channels adapted for ResNet50 outputs.
        # For 512x512 input, spatial resolutions are:
        # Encoder: Input -> 128x128 (layer1, input_block) -> 64x64 (layer2) -> 32x32 (layer3) -> 16x16 (layer4/bridge)
        # Decoder stages will upsample: 16x16 -> 32x32 -> 64x64 -> 128x128 -> 256x256 -> 512x512

        #Up-block 1: Fuses deepest feature (from bridge/transformer) with ResNet.layer3 output
        # up_x (from bridge) is 2048 channels, 16x16. Skip (pre_pools["layer_4"]) is 1024 channels, 32x32.
        # Concatenated: 2048 + 1024 = 3072 channels. Output: 1024 channels, 32x32 spatial.
        self.up_block1 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=2048 + 1024, out_channels=1024,
            up_conv_in_channels=2048, up_conv_out_channels=2048
        )
        # Up-block 2: Fuses output of up_block1 with ResNet.layer2 output
        # up_x (from up_block1) is 1024 channels, 32x32. Skip (pre_pools["layer_3"]) is 512 channels, 64x64.
        # Concatenated: 1024 + 512 = 1536 channels. Output: 512 channels, 64x64 spatial.
        self.up_block2 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=1024 + 512, out_channels=512,
            up_conv_in_channels=1024, up_conv_out_channels=1024
        )
        # Up-block 3: Fuses output of up_block2 with ResNet.layer1 output
        # up_x (from up_block2) is 512 channels, 64x64. Skip (pre_pools["layer_2"]) is 256 channels, 128x128.
        # Concatenated: 512 + 256 = 768 channels. Output: 256 channels, 128x128 spatial.
        self.up_block3 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=512 + 256, out_channels=256,
            up_conv_in_channels=512, up_conv_out_channels=512
        )
        # Up-block 4: Fuses output of up_block3 with Encoder's input_block output (after manual upsampling)
        # up_x (from up_block3) is 256 channels, 128x128. Skip (pre_pools["layer_1"]) is 64 channels, 128x128.
        # After explicit upsampling of `pre_pools["layer_1"]` to 256x256 in forward,
        # Concatenated: 256 + 64 = 320 channels. Output: 128 channels, 256x256 spatial.
        self.up_block4 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=256 + 64, out_channels=128,
            up_conv_in_channels=256, up_conv_out_channels=256
        )

        # Final upsampling and convolution as per model summary
        self.last_upsample = nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=False)
        self.last_conv = ConvBlock(in_channels=128, out_channels=64, with_nonlinearity=True)

        # Output layer to produce the binary tampering mask
        self.out = nn.Conv2d(64, n_classes, kernel_size=(1, 1), stride=(1, 1))




    def forward(self, x, pre_pools):
        # x is the deepest feature (from bridge/transformer path), 16x16 spatial
        x = self.up_block1(x, pre_pools["layer_4"]) # Output: 1024 channels, 32x32 spatial
        x = self.up_block2(x, pre_pools["layer_3"]) # Output: 512 channels, 64x64 spatial
        x = self.up_block3(x, pre_pools["layer_2"]) # Output: 256 channels, 128x128 spatial

        #############CRITICAL FIX#####: Explicitly upsample pre_pools["layer_1"] to 256x256 for up_block4
        #this ensures spatial compatibility for concatenation with up_x (which becomes 256x256 after upsampling in UpBlock).
        upsampled_skip_layer1 = F.interpolate(
            pre_pools["layer_1"],
            size=(256, 256), #Target size for this skip connection
            mode='bilinear',
            align_corners=False
        )
        x = self.up_block4(x, upsampled_skip_layer1) # \Output: 128 channels, 256x256 spatial.

        # Final upsampling and convolution to reach original input resolution (512x512)
        x = self.last_upsample(x) # Upsamples 256x256 to 512x512
        x = self.last_conv(x)      # Applies ConvBlock (128 channels -> 64 channels, 512x512)

        x = self.out(x)
        return x

# ----------------------------------------------------------------------------- #
class FeatureSimilarityModule(nn.Module):
    """
    Feature Similarity Module (FSM) implementing block-wise Pearson correlation and percentile pooling.
    """
    def __init__(self):
        super(FeatureSimilarityModule, self).__init__()
        self.K = 32 # Number of top similarity scores to select per block (1/8)

    def _pearson_correlation_coefficient(self, B_i, B_j):
        """
        Calculates the Pearson correlation coefficient between two flattened feature blocks.
        """
        B_i = B_i.float()
        B_j = B_j.float()

        mean_B_i = torch.mean(B_i)
        mean_B_j = torch.mean(B_j)

        std_B_i = torch.std(B_i, unbiased=False)
        std_B_j = torch.std(B_j, unbiased=False)

        if std_B_i == 0 or std_B_j == 0:
            return torch.tensor(0.0, device=B_i.device)

        normalized_B_i = (B_i - mean_B_i) / std_B_i
        normalized_B_j = (B_j - mean_B_j) / std_B_j

        correlation = torch.dot(normalized_B_i, normalized_B_j) / B_i.numel()
        return correlation

    def forward(self, feature_map):
        # Expected input `feature_map` is the "image-like feature F" from transformer output,
        # anticipated to be `(batch_size, channels=512, H=256, W=256)`.
        batch_size, channels, H, W = feature_map.shape

        block_size_spatial = 16 #Original Paper specifies 16x16 non-overlapping blocks
        num_blocks_h = H // block_size_spatial
        num_blocks_w = W // block_size_spatial
        num_blocks = num_blocks_h * num_blocks_w #Results in 256 total blocks for 256x256 input

        #Self-Correlation Calculation Block: For feature map into blocks
        ##unfold for patch extracrion
        blocks_unfolded = feature_map.unfold(2, block_size_spatial, block_size_spatial).unfold(3, block_size_spatial, block_size_spatial)

        #Flatten each block into a 1D feature vector for similarity computation
        blocks_flat = blocks_unfolded.permute(0, 2, 3, 1, 4, 5).contiguous().view(
            batch_size, num_blocks, -1
        )

        #Compute pairwise Pearson correlation coefficients, forming a similarity matrix
        similarity_matrix = torch.zeros(batch_size, num_blocks, num_blocks, device=feature_map.device)
        for b in range(batch_size):
            for i in range(num_blocks):
                for j in range(num_blocks):
                    similarity_matrix[b, i, j] = self._pearson_correlation_coefficient(blocks_flat[b, i], blocks_flat[b, j])

        #Percentile Pooling Block: Select top K similarities
        percentile_scores = torch.zeros(batch_size, num_blocks, self.K, device=feature_map.device)
        for b in range(batch_size):
            for i in range(num_blocks):
                sorted_scores = torch.sort(similarity_matrix[b, i], descending=True).values

                if self.K <= sorted_scores.shape[0]:
                    percentile_scores[b, i] = sorted_scores[:self.K]
                else:
                    percentile_scores[b, i, :sorted_scores.shape[0]] = sorted_scores

        return percentile_scores ####Output shape: (batch_size, 256, self.K)

#Adaptive Transformer Components
# Custom Adaptive Multi-Head Self-Attention for the transformer's core innovation as given in the paper
class AdaptiveMultiHeadSelfAttention(nn.Module):
    """
    Custom Adaptive Multi-Head Self-Attention (AdaptiveMSA) module.
    Implements the Dual-Path Adaptive Attention Mechanism (DPAAM) as described in Algorithm 1.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads # d_k = d_v = d_model / h

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        #Linear transformations for Query, Key, Value
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        #Learnable parameters for DPAAM: S1, S2, b1, b2, epsilon, delta
        self.S1 = nn.Parameter(torch.ones(embed_dim)) #Initialized to 1 for identity-like start
        self.S2 = nn.Parameter(torch.ones(embed_dim)) #Applied element-wise across the embedding dimension

        self.b1 = nn.Parameter(torch.zeros(embed_dim)) #Learnable bias vectors
        self.b2 = nn.Parameter(torch.zeros(embed_dim))

        self.epsilon = nn.Parameter(torch.tensor(0.5)) #Weighting factors (initially balanced)
        self.delta = nn.Parameter(torch.tensor(0.5))
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        batch_size, seq_len, embed_dim = query.size()

        #Project Query, Key, Value
        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)

        #Reshape for multi-head attention: (B, SeqLen, EmbedDim) -> (B, NumHeads, SeqLen, HeadDim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        #Compute Attention Weights (QK^T / sqrt(d_k))
        attention_weights = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        #Apply Softmax to get probabilities
        attention_probs = F.softmax(attention_weights, dim=-1)

        #Calculate Head_i' (standard attention output before adaptivity)
        Head_prime = torch.matmul(attention_probs, V)

        #Reshape Head_prime back to (batch_size, seq_len, embed_dim) for element-wise DPAAM ops
        Head_prime_combined = Head_prime.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        #Apply DPAAM (Dual-Path Adaptive Attention Mechanism)
        M1_Head = self.S1 * Head_prime_combined + self.b1 # Element-wise scale and bias
        M2_Head = self.S2 * Head_prime_combined + self.b2 # Element-wise scale and bias

        #Apply sigmoid to epsilon and delta to constrain them to (0,1) as per paper
        epsilon_val = torch.sigmoid(self.epsilon)
        delta_val = torch.sigmoid(self.delta)

        #Combine the mapped heads with epsilon and delta weights
        adaptive_heads_combined = epsilon_val * M1_Head + delta_val * M2_Head

        #Final output projection (W_0)
        output = self.out_proj(adaptive_heads_combined)
        return output

#AdaptiveTransformerLayer (UPDATED to use custom AdaptiveMSA)
class AdaptiveTransformerLayer(nn.Module):
    """
    A single layer of the Adaptive Transformer Encoder, incorporating AdaptiveMSA and MLP.
    """
    def __init__(self, dim, num_heads):
        super().__init__()
        #Using the custom AdaptiveMultiHeadSelfAttention module
        self.adaptive_mhsa = AdaptiveMultiHeadSelfAttention(embed_dim=dim, num_heads=num_heads)

        self.norm1 = nn.LayerNorm(dim) #layer Normalization
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = nn.Sequential( #MLP (Feedforward Network)
            nn.Linear(dim, dim * 4), #Common expansion factor of 4
            nn.GELU(), #GELU activation is common in modern transformers
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        #x is expected as token sequence: (batch_size, sequence_length, embedding_dimension)
        attn_output = self.adaptive_mhsa(x, x, x) #Custom MHSA call
        x = self.norm1(x + attn_output) #Add residual connection and apply LayerNorm again

        x_mlp = self.mlp(x)
        x = self.norm2(x + x_mlp) #Add residual connection and apply LayerNorm
        return x

#AdaptiveTransformerEncoder (Uses AdaptiveTransformerLayer)
class AdaptiveTransformerEncoder(nn.Module):
    """
    The Adaptive Transformer Encoder, stacking multiple AdaptiveTransformerLayers.
    Handles Tokenization and Positional Embedding for CNN features.
    """
    def __init__(self, dim, num_heads, num_layers):
        super().__init__()
        self.dim = dim #Transformer's embedding dimension
        self.num_layers = num_layers

        self.layers = nn.ModuleList([AdaptiveTransformerLayer(dim, num_heads) for _ in range(num_layers)])

        #Tokenization: Projects CNN encoder's C4 output channels (2048) to transformer's `dim`
        self.proj_to_dim = nn.Conv2d(2048, dim, kernel_size=1)

        #Learnable Positional Embedding
        #Assuming fixed 16x16 spatial size for tokens (256 tokens) from a 512x512 input.
        self.pos_embedding = nn.Parameter(torch.randn(1, 16 * 16, dim))

    def forward(self, feature_map_c4):
        # `feature_map_c4` is the deepest feature from the CNN Encoder (Bridge input):
        #Expected shape: (batch_size, 2048, H_spatial, W_spatial) (e.g., 16x16 for 512x512 input)

        x = self.proj_to_dim(feature_map_c4) # Output: (B, dim, H_spatial, W_spatial)
        x = x.flatten(2).permute(0, 2, 1) # (B, H_spatial * W_spatial, dim) -> e.g., (B, 256, dim)

        if x.shape[1] != self.pos_embedding.shape[1]:
            raise ValueError(f"Positional embedding sequence length mismatch. Expected {self.pos_embedding.shape[1]}, got {x.shape[1]}")

        x = x + self.pos_embedding #Add learnable positional encoding to tokens

        #Pass through the stack of transformer layers
        for layer in self.layers:
            x = layer(x)
        return x #Output tokens: (batch_size, sequence_length, embedding_dimension)


#Main TransCMFDBaseline Model

class TransCMFDBaseline(nn.Module):
    """
    The complete TransCMFD model architecture.
    """
    def __init__(self, n_classes=2):
        super().__init__()
        self.encoder = Encoder()
        self.bridge = Bridge(2048, 2048)

        #Adaptive Transformer Encoder setup
        transformer_dim = 512
        transformer_heads = 8
        transformer_layers = 2 #Reduced to 2 layers as per plan for time constraint

        self.adaptive_transformer_encoder = AdaptiveTransformerEncoder(
            dim=transformer_dim,
            num_heads=transformer_heads,
            num_layers=transformer_layers
        )

        #Transforms Transformer Output (tokens) to FSM Input (spatial feature map F)
        #Transformer output: (B, seq_len=256, dim=512)
        #FSM expects input `F` as: (B, channels=512, H=256, W=256)

        self.transformer_output_to_fsm_input = nn.Upsample(size=(256, 256), mode='bilinear', align_corners=False)

        self.fsm = FeatureSimilarityModule() # Feature Similarity Module

        #Transforms FSM output back to a spatial feature map for fusion with decoder path
        #FSM output: (B, num_blocks=256, K=32)
        #This layer takes `(B, K, 16, 16)` (after internal reshape) and projects channels `K` to `2048`
        #to match `x_bridge` for fusion.
        self.fsm_output_fusion_transform = nn.Conv2d(self.fsm.K, 2048, kernel_size=1)

        self.decoder = Decoder(n_classes=n_classes) # CNN Decoder

    def forward(self, x):
        # 1.CNN Encoder: Extracts local features and provides skip connections
        encoder_output_c4, pre_pools = self.encoder(x)

        # 2.Bridge: Connects encoder's deepest feature to transformer/decoder path
        x_bridge = self.bridge(encoder_output_c4) # (B, 2048, 16, 16) for 512x512 input

        # 3.Adaptive Transformer Encoder: Learns global representations from tokens
        transformer_output_tokens = self.adaptive_transformer_encoder(encoder_output_c4) # Output: (B, 256, dim=512)

        # 4.Transform Transformer Output to FSM Input ("image-like feature F")
        # Reshape transformer tokens `(B, 256, 512)` into a spatial feature map `(B, 512, 16, 16)`
        fsm_input_spatial_reshaped = transformer_output_tokens.permute(0, 2, 1).contiguous().view(
            transformer_output_tokens.size(0),
            self.adaptive_transformer_encoder.dim, # Channels (512)
            16, 16 # Spatial dimensions, derived from sequence length (256 = 16*16)
        )
        #Upsample this `(B, 512, 16, 16)` to `(B, 512, 256, 256)` as FSM expects
        fsm_input = self.transformer_output_to_fsm_input(fsm_input_spatial_reshaped)

        # 5.Feature Similarity Module (FSM): Identifies similar regions
        fsm_output_raw = self.fsm(fsm_input) # Output: (B, 256, K=32)

        # 6.Fuse FSM output with decoder path
        # Reshape FSM output `(B, 256, K=32)` into `(B, K, 16, 16)` for spatial compatibility
        fsm_output_spatial_for_fusion = fsm_output_raw.permute(0, 2, 1).contiguous().view(
            fsm_output_raw.size(0), self.fsm.K, 16, 16
        )
        #Project `K` channels to `2048` channels to match `x_bridge` for element-wise addition (fusion)
        fsm_output_fused = self.fsm_output_fusion_transform(fsm_output_spatial_for_fusion)

        #Fusion: Add FSM contribution to the bridge output (main decoder input)
        fused_decoder_input = x_bridge + fsm_output_fused

        #7.CNN Decoder: Reconstructs the mask from fused features
        output = self.decoder(fused_decoder_input, pre_pools)
        return output



#Dummy Data Testing
#This block will test the model's forward pass and print its structure.
if __name__ == '__main__':
    print("Initializing TransCMFDBaseline model...")
    model = TransCMFDBaseline().cuda() # Instantiate model and move to GPU
    print("TransCMFDBaseline model successfully loaded on cuda.")

    #Print the model summary to verify custom modules that I have used in this architecture
    print("\n--- Model Architecture Summary ---")
    print(model)
    print("----------------------------------\n")

    dummy_input = torch.rand((2, 3, 512, 512)).cuda() #Dummy input: Batch size 2, 3 channels (RGB), 512x512 resolution of image
    print(f"Dummy input shape: {dummy_input.shape}")

    with torch.no_grad(): # Disable gradient computation for faster dummy pass (just for dummy input usage)
        output = model(dummy_input)

    print(f"Output shape: {output.shape}")'''



import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

# Display Image (for Jupyter/IPython, not strictly needed for model logic)
from IPython.display import Image

# Helper Modules
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, padding=padding, kernel_size=kernel_size, stride=stride)
        self.gn = nn.GroupNorm(32, out_channels) # GroupNorm works well when batch size is small or varies
        self.relu = nn.ReLU(inplace=True)
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        x = self.conv(x)
        x = self.gn(x)
        if self.with_nonlinearity:
            x = self.relu(x)
        return x

class Bridge(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.bridge = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels)
        )

    def forward(self, x):
        return self.bridge(x)

# ----------------------------------------------------------------------------- #
# UpBlock
class UpBlockForUNetWithResNet50(nn.Module):
    def __init__(self, in_channels_after_concat, out_channels, up_conv_in_channels, up_conv_out_channels,
                 upsampling_method="bilinear"):
        super().__init__()

        self.upsampling_method = upsampling_method
        if upsampling_method == "conv_transpose":
            self.upsample_layer = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        elif upsampling_method == "bilinear":
            self.upsample_layer = nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=False)
        else:
            raise ValueError("Unsupported upsampling_method")

        self.conv_block = ConvBlock(in_channels_after_concat, out_channels)

    def forward(self, up_x, down_x):
        up_x = self.upsample_layer(up_x)

        # Handle potential size mismatch after upsampling (due to odd dimensions etc.)
        if up_x.shape[2] != down_x.shape[2] or up_x.shape[3] != down_x.shape[3]:
            up_x = F.interpolate(up_x, size=(down_x.shape[2], down_x.shape[3]), mode='bilinear', align_corners=False)

        x = torch.cat([up_x, down_x], 1)
        x = self.conv_block(x)
        return x

# ----------------------------------------------------------------------------- #

# CNN Encoder
class Encoder(nn.Module):
    # CNN Encoder based on a pre-trained ResNet-50.
    # Modified to use 3 convolution stages, making layer3's output (1024 channels) the deepest.
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=True)

        # ResNet-50 structure:
        # 0: Conv2d (initial 7x7)
        # 1: BatchNorm2d
        # 2: ReLU
        # 3: MaxPool2d
        # 4: layer1 (residual block, 64 channels out) - spatial size 56x56 for 224x224 input
        # 5: layer2 (residual block, 128 channels out) - spatial size 28x28
        # 6: layer3 (residual block, 256 channels out) - spatial size 14x14
        # 7: layer4 (residual block, 512 channels out) - spatial size 7x7
        # 8: AvgPool2d
        # 9: Linear (FC layer)

        # Input block: initial conv, BN, ReLU, MaxPool
        self.input_block = nn.Sequential(*list(resnet.children()))[:4] # Still corresponds to 64 channels output

        # Down blocks: layer1, layer2, layer3
        # We want to stop before layer4 to get 1024 channels (output of layer3)
        down_blocks = []
        # We need `list(resnet.children())[4:7]` to get layer1, layer2, layer3
        # ResNet's `layer1` outputs 256 channels
        # ResNet's `layer2` outputs 512 channels
        # ResNet's `layer3` outputs 1024 channels
        # ResNet's `layer4` outputs 2048 channels

        # Original code comment suggests channels:
        # layer1: 256, layer2: 512, layer3: 1024, layer4: 2048
        # The structure is: input_block -> 64 channel feature.
        # then layer1 (block 4) outputs 256.
        # layer2 (block 5) outputs 512.
        # layer3 (block 6) outputs 1024.
        # layer4 (block 7) outputs 2048.

        # We will use blocks 4, 5, 6 (layer1, layer2, layer3)
        self.down_blocks = nn.ModuleList(list(resnet.children())[4:7]) # This will include layer1, layer2, layer3

    def forward(self, x):
        pre_pools = dict()
        pre_pools["layer_0"] = x # Original input image for the very last decoder stage (512x512 initially)

        x = self.input_block(x) # Output e.g., (B, 64, 56, 56) for 224x224 input
        pre_pools["layer_1"] = x # Skip connection for decoder's up_block4

        # After input_block, the spatial dims for 224x224 input:
        # Initial 224x224 -> MaxPool (stride 2) -> 112x112 -> MaxPool (stride 2) -> 56x56
        # The first layer of ResNet (layer1) starts operating on 56x56

        # down_blocks contains: layer1, layer2, layer3
        # pre_pools["layer_1"] is the output of input_block (64 channels, spatial 56x56)
        # pre_pools["layer_2"] will be output of layer1 (256 channels, spatial 56x56)
        # pre_pools["layer_3"] will be output of layer2 (512 channels, spatial 28x28)
        # pre_pools["layer_4"] will be output of layer3 (1024 channels, spatial 14x14) - this is the deepest feature

        # The loop now runs for 3 blocks (layer1, layer2, layer3)
        for i, block in enumerate(self.down_blocks):
            x = block(x)
            # Store skip connections. layer_i+2 maps to pre_pools:
            # i=0 (layer1): x is output of layer1. Store as pre_pools["layer_2"]
            # i=1 (layer2): x is output of layer2. Store as pre_pools["layer_3"]
            # i=2 (layer3): x is output of layer3. This is the deepest feature; it goes to bridge/transformer.
            # We don't store it as a 'pre_pool' because it's directly passed to the bridge.
            if i < len(self.down_blocks) - 1: # ALL but the last one (layer3 output is going to bridge)
                pre_pools[f"layer_{i+2}"] = x

        # `x` here is the output of layer3 (1024 channels, 14x14 spatial for 224x224 input)
        return x, pre_pools

# ----------------------------------------------------------------------------- #
class Decoder(nn.Module):
    def __init__(self, n_classes=1):
        super().__init__()
        # Decoder stages:
        # For 224x224 input, spatial resolutions are:
        # Encoder: Input -> 56x56 (input_block/layer1) -> 28x28 (layer2) -> 14x14 (layer3/bridge)
        # Decoder stages will upsample: 14x14 -> 28x28 -> 56x56 -> 112x112 -> 224x224

        # Up-block 1: Fuses deepest feature (from bridge/transformer output) with ResNet.layer2 output
        # up_x (from bridge) is 1024 channels, 14x14. Skip (pre_pools["layer_3"]) is 512 channels, 28x28.
        # Up-x after upsample becomes 1024 channels, 28x28.
        # Concatenated: 1024 + 512 = 1536 channels. Output: 512 channels, 28x28 spatial.
        self.up_block1 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=1024 + 512, out_channels=512,
            up_conv_in_channels=1024, up_conv_out_channels=1024
        )
        # Up-block 2: Fuses output of up_block1 with ResNet.layer1 output
        # up_x (from up_block1) is 512 channels, 28x28. Skip (pre_pools["layer_2"]) is 256 channels, 56x56.
        # Up-x after upsample becomes 512 channels, 56x56.
        # Concatenated: 512 + 256 = 768 channels. Output: 256 channels, 56x56 spatial.
        self.up_block2 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=512 + 256, out_channels=256,
            up_conv_in_channels=512, up_conv_out_channels=512
        )
        # Up-block 3: Fuses output of up_block2 with Encoder's input_block output
        # up_x (from up_block2) is 256 channels, 56x56. Skip (pre_pools["layer_1"]) is 64 channels, 56x56.
        # Concatenated: 256 + 64 = 320 channels. Output: 128 channels, 56x56 spatial.
        self.up_block3 = UpBlockForUNetWithResNet50(
            in_channels_after_concat=256 + 64, out_channels=128,
            up_conv_in_channels=256, up_conv_out_channels=256
        )

        # Up-block 4: Fuses output of up_block3 with original input (pre_pools["layer_0"])
        # up_x (from up_block3) is 128 channels, 56x56.
        # up-x after upsample becomes 128 channels, 112x112
        # Original input (pre_pools["layer_0"]) is 3 channels, 224x224.
        # There's a mismatch here; the upsampling from 56x56 to 112x112 is one step,
        # but the target is 224x224. We need an additional upsampling step or modify `up_block4` logic.

        # Let's adjust the final stages.
        # The output of up_block3 is 128 channels, 56x56.
        # We need to upsample this to 224x224 (2 steps: 56->112, 112->224)

        # First upsample from 56x56 to 112x112
        self.upsample_to_112 = nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=False)
        self.conv_after_upsample_112 = ConvBlock(128, 64) # Reduce channels after upsampling

        # Second upsample from 112x112 to 224x224 and final convolution
        self.upsample_to_224 = nn.Upsample(scale_factor=2.0, mode='bilinear', align_corners=False)
        # Concatenate with pre_pools["layer_0"] (the original input image)
        # After self.conv_after_upsample_112, we have 64 channels.
        # Original input (pre_pools["layer_0"]) has 3 channels.
        # So, in_channels_after_concat will be 64 + 3 = 67.
        self.final_conv_block = ConvBlock(in_channels=64 + 3, out_channels=64)

        # Output layer
        self.out = nn.Conv2d(64, n_classes, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x, pre_pools):
        # x is the deepest feature (from bridge/transformer path), 14x14 spatial
        x = self.up_block1(x, pre_pools["layer_3"]) # Output: 512 channels, 28x28 spatial
        x = self.up_block2(x, pre_pools["layer_2"]) # Output: 256 channels, 56x56 spatial
        x = self.up_block3(x, pre_pools["layer_1"]) # Output: 128 channels, 56x56 spatial

        # Now, upsample from 56x56 to 224x224 (target original input size)
        # Step 1: 56x56 -> 112x112
        x = self.upsample_to_112(x) # 128 channels, 112x112
        x = self.conv_after_upsample_112(x) # 64 channels, 112x112

        # Step 2: 112x112 -> 224x224, then concatenate with original input (pre_pools["layer_0"])
        x = self.upsample_to_224(x) # 64 channels, 224x224

        # Concatenate with original input image (pre_pools["layer_0"])
        # Ensure sizes match, although interpolate in up_block handles this.
        # But here, we explicitly concatenate the final upsampled feature with the original input.
        # Original input `pre_pools["layer_0"]` is (B, 3, 224, 224)
        x = torch.cat([x, pre_pools["layer_0"]], 1) # (B, 64+3, 224, 224)

        x = self.final_conv_block(x) # (B, 64, 224, 224)

        x = self.out(x) # (B, n_classes, 224, 224)
        return x

# ----------------------------------------------------------------------------- #
class FeatureSimilarityModule(nn.Module):
    """
    Feature Similarity Module (FSM) implementing block-wise Pearson correlation and percentile pooling.
    """
    def __init__(self):
        super(FeatureSimilarityModule, self).__init__()
        self.K = 32 # Number of top similarity scores to select per block (1/8)

    # Removed _pearson_correlation_coefficient as it will be vectorized

    def forward(self, feature_map):
        batch_size, channels, H, W = feature_map.shape

        block_size_spatial = 16
        num_blocks_h = H // block_size_spatial
        num_blocks_w = W // block_size_spatial
        num_blocks = num_blocks_h * num_blocks_w

        # Unfold for patch extraction
        # Output: (B, C, num_blocks_h, num_blocks_w, block_size, block_size)
        blocks_unfolded = feature_map.unfold(2, block_size_spatial, block_size_spatial).unfold(3, block_size_spatial, block_size_spatial)

        # Flatten each block into a 1D feature vector for similarity computation
        # Desired shape for batch matrix multiplication: (B, num_blocks, C * block_size * block_size)
        # Permute to (B, num_blocks_h, num_blocks_w, C, block_size, block_size) then view
        blocks_flat = blocks_unfolded.permute(0, 2, 3, 1, 4, 5).contiguous().view(
            batch_size, num_blocks, -1
        ) # Shape: (B, N, D_block) where N=num_blocks, D_block=flattened dim of a block

        # --- Vectorized Pearson Correlation ---
        # 1. Center the data (subtract mean)
        mean = torch.mean(blocks_flat, dim=2, keepdim=True) # Mean of each block's features
        centered_blocks = blocks_flat - mean

        # 2. Calculate standard deviation
        std = torch.std(blocks_flat, dim=2, keepdim=True, unbiased=False)

        # Avoid division by zero: replace 0 std with 1 to avoid NaN in division
        # This is crucial for numerical stability and preventing NaN/inf, often seen in OOM cases too.
        std_safe = torch.where(std == 0, torch.tensor(1.0, device=feature_map.device, dtype=std.dtype), std)
        normalized_blocks = centered_blocks / std_safe

        # 3. Compute pairwise dot products (cosine similarity, which is Pearson if data is normalized)
        # (B, N, D_block) @ (B, D_block, N) -> (B, N, N)
        similarity_matrix = torch.matmul(normalized_blocks, normalized_blocks.transpose(1, 2))

        # Pearson correlation also involves dividing by the number of elements in the block.
        # D_block = channels * block_size_spatial * block_size_spatial
        D_block = blocks_flat.shape[2]
        similarity_matrix = similarity_matrix / D_block

        # Percentile Pooling Block: Select top K similarities
        # Ensure sorting is efficient and doesn't create excessive temporaries
        percentile_scores = torch.zeros(batch_size, num_blocks, self.K, device=feature_map.device, dtype=similarity_matrix.dtype)

        # This loop is still Python, but it's on a much smaller tensor (256x256), not creating huge temp tensors repeatedly.
        # Can be optimized further by using torch.topk, but current loop on 256 elements is fine.
        for b in range(batch_size):
            # Using torch.topk is more efficient than torch.sort followed by slicing
            top_k_scores, _ = torch.topk(similarity_matrix[b], k=self.K, dim=-1, largest=True, sorted=True)
            percentile_scores[b] = top_k_scores

        return percentile_scores


# Adaptive Transformer Components
# Custom Adaptive Multi-Head Self-Attention for the transformer's core innovation
class AdaptiveMultiHeadSelfAttention(nn.Module):
    """
    Custom Adaptive Multi-Head Self-Attention (AdaptiveMSA) module.
    Implements the Dual-Path Adaptive Attention Mechanism (DPAAM).
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        self.S1 = nn.Parameter(torch.ones(embed_dim))
        self.S2 = nn.Parameter(torch.ones(embed_dim))

        self.b1 = nn.Parameter(torch.zeros(embed_dim))
        self.b2 = nn.Parameter(torch.zeros(embed_dim))

        self.epsilon = nn.Parameter(torch.tensor(0.5))
        self.delta = nn.Parameter(torch.tensor(0.5))
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        batch_size, seq_len, embed_dim = query.size()

        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)

        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attention_weights = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_probs = F.softmax(attention_weights, dim=-1)

        Head_prime = torch.matmul(attention_probs, V)
        Head_prime_combined = Head_prime.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        M1_Head = self.S1 * Head_prime_combined + self.b1
        M2_Head = self.S2 * Head_prime_combined + self.b2

        epsilon_val = torch.sigmoid(self.epsilon)
        delta_val = torch.sigmoid(self.delta)

        adaptive_heads_combined = epsilon_val * M1_Head + delta_val * M2_Head

        output = self.out_proj(adaptive_heads_combined)
        return output

# AdaptiveTransformerLayer (UPDATED to use custom AdaptiveMSA)
class AdaptiveTransformerLayer(nn.Module):
    """
    A single layer of the Adaptive Transformer Encoder, incorporating AdaptiveMSA and MLP.
    """
    def __init__(self, dim, num_heads):
        super().__init__()
        self.adaptive_mhsa = AdaptiveMultiHeadSelfAttention(embed_dim=dim, num_heads=num_heads)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        attn_output = self.adaptive_mhsa(x, x, x)
        x = self.norm1(x + attn_output)

        x_mlp = self.mlp(x)
        x = self.norm2(x + x_mlp)
        return x

# AdaptiveTransformerEncoder (Uses AdaptiveTransformerLayer)
class AdaptiveTransformerEncoder(nn.Module):
    """
    The Adaptive Transformer Encoder, stacking multiple AdaptiveTransformerLayers.
    Handles Tokenization and Positional Embedding for CNN features.
    """
    def __init__(self, dim, num_heads, num_layers, input_spatial_size):
        super().__init__()
        self.dim = dim
        self.num_layers = num_layers
        self.input_spatial_size = input_spatial_size # e.g., 14 for 14x14 feature map

        self.layers = nn.ModuleList([AdaptiveTransformerLayer(dim, num_heads) for _ in range(num_layers)])

        # Tokenization: Projects CNN encoder's deepest output channels (1024) to transformer's `dim`
        self.proj_to_dim = nn.Conv2d(1024, dim, kernel_size=1)

        # Learnable Positional Embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, input_spatial_size * input_spatial_size, dim))

    def forward(self, feature_map_c_deepest):
        # `feature_map_c_deepest` is the deepest feature from the CNN Encoder (layer3 output):
        # Expected shape: (batch_size, 1024, H_spatial, W_spatial) (e.g., 14x14 for 224x224 input)

        x = self.proj_to_dim(feature_map_c_deepest) # Output: (B, dim, H_spatial, W_spatial)
        x = x.flatten(2).permute(0, 2, 1) # (B, H_spatial * W_spatial, dim) -> e.g., (B, 14*14=196, dim=512)

        if x.shape[1] != self.pos_embedding.shape[1]:
            raise ValueError(f"Positional embedding sequence length mismatch. Expected {self.pos_embedding.shape[1]}, got {x.shape[1]}")

        x = x + self.pos_embedding

        for layer in self.layers:
            x = layer(x)
        return x # Output tokens: (batch_size, sequence_length, embedding_dimension)


# Main TransCMFDBaseline Model
class TransCMFDBaseline(nn.Module):
    """
    The complete TransCMFD model architecture.
    """
    def __init__(self, n_classes=1):
        super().__init__()
        self.encoder = Encoder()

        # Bridge now takes 1024 channels as input (from Encoder's layer3 output)
        self.bridge = Bridge(1024, 1024)

        # Adaptive Transformer Encoder setup
        transformer_dim = 512
        transformer_heads = 8
        transformer_layers = 2 # As per your requirement

        # For 224x224 input, the deepest feature map (layer3 output) is 14x14
        transformer_input_spatial_size = 14

        self.adaptive_transformer_encoder = AdaptiveTransformerEncoder(
            dim=transformer_dim,
            num_heads=transformer_heads,
            num_layers=transformer_layers,
            input_spatial_size=transformer_input_spatial_size
        )

        # Transforms Transformer Output (tokens) to FSM Input (spatial feature map F)
        # Transformer output: (B, seq_len=196, dim=512)
        # Reshape to (B, 512, 14, 14) and then upsample to (B, 512, 256, 256) for FSM input as specified
        # in the original FSM comment. This means the FSM is expected to work on a relatively high-resolution
        # feature map.
        self.transformer_output_to_fsm_input = nn.Upsample(size=(256, 256), mode='bilinear', align_corners=False)

        self.fsm = FeatureSimilarityModule() # Feature Similarity Module

        # Transforms FSM output back to a spatial feature map for fusion with decoder path
        # FSM output: (B, num_blocks=256, K=32)
        # Reshape to (B, K, 16, 16) for fusion.
        # This layer now projects K channels to 1024 (to match bridge output for fusion).
        self.fsm_output_fusion_transform = nn.Conv2d(self.fsm.K, 1024, kernel_size=1)

        self.decoder = Decoder(n_classes=n_classes) # CNN Decoder

    def forward(self, x):
        # 1. CNN Encoder: Extracts local features and provides skip connections
        # encoder_output_c_deepest will be from ResNet's layer3 (1024 channels, 14x14 for 224x224 input)
        encoder_output_c_deepest, pre_pools = self.encoder(x)

        # 2. Bridge: Connects encoder's deepest feature to transformer/decoder path
        # x_bridge will be (B, 1024, 14, 14)
        x_bridge = self.bridge(encoder_output_c_deepest)

        # 3. Adaptive Transformer Encoder: Learns global representations from tokens
        # Input to transformer is encoder_output_c_deepest (1024 channels, 14x14 spatial)
        transformer_output_tokens = self.adaptive_transformer_encoder(encoder_output_c_deepest) # Output: (B, 196, dim=512)

        # 4. Transform Transformer Output to FSM Input ("image-like feature F")
        # Reshape transformer tokens `(B, 196, 512)` into a spatial feature map `(B, 512, 14, 14)`
        fsm_input_spatial_reshaped = transformer_output_tokens.permute(0, 2, 1).contiguous().view(
            transformer_output_tokens.size(0),
            self.adaptive_transformer_encoder.dim, # Channels (512)
            self.adaptive_transformer_encoder.input_spatial_size, # H (14)
            self.adaptive_transformer_encoder.input_spatial_size # W (14)
        )
        # Upsample this `(B, 512, 14, 14)` to `(B, 512, 256, 256)` as FSM expects
        fsm_input = self.transformer_output_to_fsm_input(fsm_input_spatial_reshaped)

        # 5. Feature Similarity Module (FSM): Identifies similar regions
        fsm_output_raw = self.fsm(fsm_input) # Output: (B, 256, K=32)

        # 6. Fuse FSM output with decoder path
        # Reshape FSM output `(B, 256, K=32)` into `(B, K, 16, 16)` for spatial compatibility
        fsm_output_spatial_for_fusion = fsm_output_raw.permute(0, 2, 1).contiguous().view(
            fsm_output_raw.size(0), self.fsm.K, 16, 16
        )
        # Project `K` channels to `1024` channels to match `x_bridge` for element-wise addition (fusion)
        fsm_output_fused = self.fsm_output_fusion_transform(fsm_output_spatial_for_fusion)

        # Fusion: Add FSM contribution to the bridge output (main decoder input)
        # Before fusion, ensure spatial dimensions match.
        # x_bridge is (B, 1024, 14, 14)
        # fsm_output_fused is (B, 1024, 16, 16)
        # We need to interpolate one of them to match the other.
        # Let's interpolate fsm_output_fused to 14x14
        fsm_output_fused_interpolated = F.interpolate(
            fsm_output_fused,
            size=(x_bridge.shape[2], x_bridge.shape[3]), # Target size of x_bridge (14x14)
            mode='bilinear',
            align_corners=False
        )
        fused_decoder_input = x_bridge + fsm_output_fused_interpolated

        # 7. CNN Decoder: Reconstructs the mask from fused features
        output = self.decoder(fused_decoder_input, pre_pools)
        return output


# Dummy Data Testing
if __name__ == '__main__':
    print("Initializing TransCMFDBaseline model...")
    # It's good practice to set a device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TransCMFDBaseline().to(device) # Instantiate model and move to device
    print(f"TransCMFDBaseline model successfully loaded on {device}.")

    print("\n--- Model Architecture Summary ---")
    print(model)
    print("----------------------------------\n")

    # Dummy input: Batch size 2, 3 channels (RGB), 224x224 resolution
    dummy_input = torch.rand((2, 3, 224, 224)).to(device)
    print(f"Dummy input shape: {dummy_input.shape}")

    with torch.no_grad(): # Disable gradient computation for faster dummy pass
        output = model(dummy_input)

    print(f"Output shape: {output.shape}")

Initializing TransCMFDBaseline model...


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 193MB/s]


TransCMFDBaseline model successfully loaded on cuda.

--- Model Architecture Summary ---
TransCMFDBaseline(
  (encoder): Encoder(
    (input_block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (down_blocks): ModuleList(
      (0): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
     

In [3]:
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


class CMFDataset(Dataset):
    """
    PyTorch Dataset for Copy-Move Forgery Detection using CASIA v2.0.
    Loads (Tampered Image, Ground Truth Mask) pairs and applies transformations.
    """
    def __init__(self, data_root, transform=None, img_size=(512, 512)):
        self.data_root = data_root

        self.tampered_folder = os.path.join(data_root, 'Tp')
        self.gt_folder = os.path.join(data_root, 'CASIA 2 Groundtruth')

        self.transform = transform
        self.img_size = img_size

        self.data_pairs = []

        tampered_image_files = sorted([f for f in os.listdir(self.tampered_folder) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))])

        for img_name in tampered_image_files:
            img_path = os.path.join(self.tampered_folder, img_name)

            # This part assumes mask names like 'Tp_D_CNN_S_N_txt00043_txt00051_10378_gt.png' from 'Tp_D_CNN_S_N_txt00043_txt00051_10378.jpg'

            base_name_without_ext = os.path.splitext(img_name)[0]
            mask_name = f"{base_name_without_ext}_gt.png"
            mask_path = os.path.join(self.gt_folder, mask_name)

            if os.path.exists(mask_path):
                self.data_pairs.append((img_path, mask_path))
            else:
                print(f"Warning: Mask not found for {img_name} at {mask_path}. Skipping.")

        print(f"Loaded {len(self.data_pairs)} valid tampered image-mask pairs from {data_root}")

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.data_pairs[idx]

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        image = image.resize(self.img_size)
        mask = mask.resize(self.img_size, Image.NEAREST)

        image_tensor = transforms.ToTensor()(image)
        mask_tensor = transforms.ToTensor()(mask)
        mask_tensor = (mask_tensor > 0.5).float()

        if self.transform:
            image_tensor = self.transform(image_tensor)

        return image_tensor, mask_tensor

In [4]:


# Define transformations (ImageNet normalization is standard for ResNet pretrained)
transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
CASIA_V2_ROOT = '/content/drive/MyDrive/TransCMFD_dataset/CASIA2'

# Create the dataset instance
full_dataset = CMFDataset(
    data_root=CASIA_V2_ROOT,
    transform=transform,
    img_size=(224, 224)
)

# Split dataset into training and validation sets
# If CASIA v2.0 has ~5123 images, this split is reasonable.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

print(f"Dataset split: Training {len(train_dataset)} samples, Validation {len(val_dataset)} samples")

# Create DataLoaders
BATCH_SIZE = 8 # Keep small initially, adjust based on GPU memory
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) # num_workers > 0 for faster loading
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Train DataLoader batches: {len(train_loader)}")
print(f"Validation DataLoader batches: {len(val_loader)}")

# --- Test Data Loading ---
print("\nTesting data loading...")
try:
    # Fetch one batch to verify shapes and types
    images, masks = next(iter(train_loader)) # Use next(iter()) to get one batch
    print(f"Batch images shape: {images.shape}") # Should be (BATCH_SIZE, 3, 512, 512)
    print(f"Batch images dtype: {images.dtype}") # Should be torch.float32
    print(f"Batch masks shape: {masks.shape}")   # Should be (BATCH_SIZE, 1, 512, 512)
    print(f"Batch masks dtype: {masks.dtype}")   # Should be torch.float32
    print("Data loading successful!")
except Exception as e:
    print(f"Error during data loading test: {e}")
    print("Please double-check your CASIA_V2_ROOT path, folder structure (Tampered, Groundtruth), and mask naming.")
    print("Common issues: wrong root path, incorrect subfolder names, mask files not found for image files.")

Loaded 2004 valid tampered image-mask pairs from /content/drive/MyDrive/TransCMFD_dataset/CASIA2
Dataset split: Training 1603 samples, Validation 401 samples
Train DataLoader batches: 201
Validation DataLoader batches: 51

Testing data loading...
Batch images shape: torch.Size([8, 3, 224, 224])
Batch images dtype: torch.float32
Batch masks shape: torch.Size([8, 1, 224, 224])
Batch masks dtype: torch.float32
Data loading successful!


In [5]:
import os

CASIA_V2_ROOT = '/content/drive/MyDrive/TransCMFD_dataset/CASIA2'

print(f"Contents of {CASIA_V2_ROOT}:")
try:
    for item in os.listdir(CASIA_V2_ROOT):
        item_path = os.path.join(CASIA_V2_ROOT, item)
        if os.path.isdir(item_path):
            print(f"  [DIR] {item}")
        else:
            print(f"  [FILE] {item}")
except FileNotFoundError:
    print(f"Error: {CASIA_V2_ROOT} not found. Please double-check the path.")
except Exception as e:
    print(f"An error occurred: {e}")

Contents of /content/drive/MyDrive/TransCMFD_dataset/CASIA2:
  [DIR] CASIA 2 Groundtruth
  [DIR] Tp


In [6]:
import os

CASIA_V2_ROOT = '/content/drive/MyDrive/TransCMFD_dataset/CASIA2'
TAMPERED_FOLDER = os.path.join(CASIA_V2_ROOT, 'Tp')
GT_FOLDER = os.path.join(CASIA_V2_ROOT, 'CASIA 2 Groundtruth')

print(f"--- Listing Tampered Images in '{TAMPERED_FOLDER}' ---")
try:
    tampered_files = sorted([f for f in os.listdir(TAMPERED_FOLDER) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))])
    if tampered_files:
        print(f"Found {len(tampered_files)} tampered images. First 5:")
        for i, fname in enumerate(tampered_files[:5]):
            print(f"  {fname}")

        if tampered_files:
            sample_img_name = tampered_files[0]
            base_name_without_ext = os.path.splitext(sample_img_name)[0]
            derived_mask_name = f"{base_name_without_ext}_gt.png" # This is the current assumption
            expected_mask_path = os.path.join(GT_FOLDER, derived_mask_name)

            print(f"\n--- Checking Mask for Sample Image: '{sample_img_name}' ---")
            print(f"  Derived mask name: '{derived_mask_name}'")
            print(f"  Expected mask path: '{expected_mask_path}'")

            if os.path.exists(expected_mask_path):
                print("  --> Mask file EXISTS at the derived path! This is good.")
            else:
                print("  --> Mask file DOES NOT EXIST at the derived path! This is the problem.")
                print("  Please check the exact naming convention of your mask files in the 'CASIA 2 Groundtruth' folder.")
                print("  Look for a mask file that corresponds to this image:")
                print(f"    Image: {sample_img_name}")
                print("  And provide its exact mask filename.")

    else:
        print(f"No image files found in '{TAMPERED_FOLDER}'. Please check the path and content.")
except FileNotFoundError:
    print(f"Error: '{TAMPERED_FOLDER}' not found. Check subfolder name or root path.")
except Exception as e:
    print(f"An error occurred: {e}")

--- Listing Tampered Images in '/content/drive/MyDrive/TransCMFD_dataset/CASIA2/Tp' ---
Found 2072 tampered images. First 5:
  Tp_D_CND_S_N_txt00028_txt00006_10848.jpg
  Tp_D_CNN_M_B_nat00056_nat00099_11105.jpg
  Tp_D_CNN_M_B_nat10139_nat00059_11949.jpg
  Tp_D_CNN_M_B_nat10139_nat00097_11948.jpg
  Tp_D_CNN_M_N_ani00052_ani00054_11130.jpg

--- Checking Mask for Sample Image: 'Tp_D_CND_S_N_txt00028_txt00006_10848.jpg' ---
  Derived mask name: 'Tp_D_CND_S_N_txt00028_txt00006_10848_gt.png'
  Expected mask path: '/content/drive/MyDrive/TransCMFD_dataset/CASIA2/CASIA 2 Groundtruth/Tp_D_CND_S_N_txt00028_txt00006_10848_gt.png'
  --> Mask file EXISTS at the derived path! This is good.


In [7]:
'''import torch.optim as optim
from tqdm.notebook import tqdm
import os

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    def forward(self, prediction, target):
        target = target.float()
        prediction = torch.sigmoid(prediction) #sigmoid for conversion to probabilities
        prediction_flat = prediction.contiguous().view(-1)
        target_flat = target.contiguous().view(-1)
        intersection = (prediction_flat * target_flat).sum()
        dice_coefficient = (2. * intersection + self.smooth) / (prediction_flat.sum() + target_flat.sum() + self.smooth)
        return 1 - dice_coefficient

class AdaptiveRegularizationLoss(nn.Module):
    def __init__(self):
        super(AdaptiveRegularizationLoss, self).__init__()
    def forward(self, model):
        l_adapt_total = 0.0
        if hasattr(model, 'adaptive_transformer_encoder') and hasattr(model.adaptive_transformer_encoder, 'layers'):
            for layer in model.adaptive_transformer_encoder.layers:
                if hasattr(layer, 'adaptive_mhsa'):
                    s1_param = layer.adaptive_mhsa.S1
                    s2_param = layer.adaptive_mhsa.S2
                    l_adapt_total += torch.norm(s1_param, 2)**2
                    l_adapt_total += torch.norm(s2_param, 2)**2   ##squared l2 norm
        return l_adapt_total



# Define Loss Functions
dice_loss_fn = DiceLoss(smooth=1.0)
bce_loss_fn = nn.BCEWithLogitsLoss()
adaptive_reg_loss_fn = AdaptiveRegularizationLoss()

# Define Optimizer
LEARNING_RATE = 1e-4
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Loss weighting factors
ALPHA = 0.5
BETA = 0.5
GAMMA = 0.01


NUM_EPOCHS = 10

def train_model(model, train_loader, val_loader, optimizer, dice_loss, bce_loss, adaptive_reg_loss, num_epochs, alpha, beta, gamma):
    model.train()
    device = next(model.parameters()).device

    # Initialize GradScaler for Automatic Mixed Precision (AMP)
    scaler = torch.cuda.amp.GradScaler() #

    best_val_loss = float('inf')

    # Clear CUDA cache before training loop to free up any residual memory
    torch.cuda.empty_cache()

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        model.train()
        running_loss = 0.0
        train_loop = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")

        for batch_idx, (images, masks) in enumerate(train_loop):
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()

            # Use autocast for mixed precision training
            with torch.cuda.amp.autocast():
                predictions = model(images)
                l_dice = dice_loss(predictions, masks)
                l_bce = bce_loss(predictions, masks)
                l_adapt = adaptive_reg_loss(model)

                total_loss = alpha * l_dice + beta * l_bce + gamma * l_adapt

            # Scale the loss and perform backward pass
            scaler.scale(total_loss).backward()

            # Optimizer step
            scaler.step(optimizer)

            # Update the scaler for next iteration
            scaler.update()

            running_loss += total_loss.item() * images.size(0)

            train_loop.set_postfix(loss=total_loss.item())

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1} Training Loss: {epoch_loss:.4f}")

        # --- Validation Phase ---
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            val_loop = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")
            for images, masks in val_loop:
                images = images.to(device)
                masks = masks.to(device)

                # Use autocast for validation too (no scaler.scale/update needed)
                with torch.cuda.amp.autocast():
                    predictions = model(images)
                    l_dice_val = dice_loss(predictions, masks)
                    l_bce_val = bce_loss(predictions, masks)
                    l_adapt_val = adaptive_reg_loss(model)
                    total_val_loss = alpha * l_dice_val + beta * l_bce_val + gamma * l_adapt_val

                val_loss += total_val_loss.item() * images.size(0)
                val_loop.set_postfix(val_loss=total_val_loss.item())

        epoch_val_loss = val_loss / len(val_loader.dataset)
        print(f"Epoch {epoch+1} Validation Loss: {epoch_val_loss:.4f}")

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            checkpoint_path = '/content/drive/MyDrive/TransCMFD_Checkpoints/best_model.pth'
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved to {checkpoint_path} with validation loss: {best_val_loss:.4f}")

    print("\nTraining complete!")


if __name__ == '__main__':

    BATCH_SIZE = 4
    NUM_WORKERS = 0

    print(f"Using BATCH_SIZE: {BATCH_SIZE}, NUM_WORKERS: {NUM_WORKERS}")


    try:

        train_size = int(0.1 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
        print("DataLoaders re-created with new batch size and num_workers.")
    except NameError:
        print("Warning: full_dataset not found. Ensure CMFDataset and DataLoader setup cells were run correctly.")
        print("Please define/re-run your data loading setup to ensure train_loader and val_loader are available.")


    print("\nStarting model training...")
    train_model(model, train_loader, val_loader, optimizer, dice_loss_fn, bce_loss_fn, adaptive_reg_loss_fn, NUM_EPOCHS, ALPHA, BETA, GAMMA)'''
import torch.optim as optim
from tqdm.notebook import tqdm
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assuming your model definition (ConvBlock, Bridge, UpBlockForUNetWithResNet50,
# Encoder, Decoder, FeatureSimilarityModule, AdaptiveMultiHeadSelfAttention,
# AdaptiveTransformerLayer, AdaptiveTransformerEncoder, TransCMFDBaseline)
# is in the same file or imported correctly.

# If running as a script, ensure TransCMFDBaseline is defined or imported
# from your model file. For this example, I'll assume it's defined.

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
    def forward(self, prediction, target):
        target = target.float()
        prediction = torch.sigmoid(prediction) #sigmoid for conversion to probabilities
        prediction_flat = prediction.contiguous().view(-1)
        target_flat = target.contiguous().view(-1)
        intersection = (prediction_flat * target_flat).sum()
        dice_coefficient = (2. * intersection + self.smooth) / (prediction_flat.sum() + target_flat.sum() + self.smooth)
        return 1 - dice_coefficient

class AdaptiveRegularizationLoss(nn.Module):
    def __init__(self):
        super(AdaptiveRegularizationLoss, self).__init__()
    def forward(self, model):
        l_adapt_total = 0.0
        if hasattr(model, 'adaptive_transformer_encoder') and hasattr(model.adaptive_transformer_encoder, 'layers'):
            for layer in model.adaptive_transformer_encoder.layers:
                if hasattr(layer, 'adaptive_mhsa'):
                    s1_param = layer.adaptive_mhsa.S1
                    s2_param = layer.adaptive_mhsa.S2
                    l_adapt_total += torch.norm(s1_param, 2)**2
                    l_adapt_total += torch.norm(s2_param, 2)**2
        return l_adapt_total

# Define Loss Functions
dice_loss_fn = DiceLoss(smooth=1.0)
bce_loss_fn = nn.BCEWithLogitsLoss()
adaptive_reg_loss_fn = AdaptiveRegularizationLoss()

# Define Optimizer - Assuming 'model' is already defined and moved to device
# (If running this file directly, 'model' would need to be instantiated globally or passed in)
# Example: model = TransCMFDBaseline().to(device)
# If this code is appended to the previous model definition file, 'model' will be available.

LEARNING_RATE = 1e-4
# No need to redefine optimizer here if it's already defined with model.parameters()

# Loss weighting factors
ALPHA = 0.5
BETA = 0.5
GAMMA = 0.01

NUM_EPOCHS = 10
GRADIENT_ACCUMULATION_STEPS = 4 # New: for simulating larger batch size

def train_model(model, train_loader, val_loader, optimizer, dice_loss, bce_loss, adaptive_reg_loss, num_epochs, alpha, beta, gamma, scheduler=None, gradient_accumulation_steps=1):
    model.train()
    device = next(model.parameters()).device

    scaler = torch.cuda.amp.GradScaler() # For Automatic Mixed Precision (AMP)

    best_val_loss = float('inf')

    torch.cuda.empty_cache() # Clear CUDA cache before training loop

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        model.train()
        running_loss = 0.0
        train_loop = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")

        optimizer.zero_grad() # Zero gradients for accumulation at the start of epoch/accumulation cycle

        for batch_idx, (images, masks) in enumerate(train_loop):
            images = images.to(device, non_blocking=True) # non_blocking=True for async transfer
            masks = masks.to(device, non_blocking=True)

            with torch.cuda.amp.autocast():
                predictions = model(images)
                l_dice = dice_loss(predictions, masks)
                l_bce = bce_loss(predictions, masks)
                l_adapt = adaptive_reg_loss(model)

                total_loss = alpha * l_dice + beta * l_bce + gamma * l_adapt
                total_loss = total_loss / gradient_accumulation_steps # Scale loss for accumulation

            scaler.scale(total_loss).backward() # Accumulate scaled gradients

            # Perform optimizer step only after accumulating enough gradients
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                # Unscale the gradients before clipping, if clipping is used
                # scaler.unscale_(optimizer)
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Example: Clip gradients

                scaler.step(optimizer) # Update model parameters
                scaler.update() # Update the scaler for next iteration
                optimizer.zero_grad() # Clear gradients for the next accumulation cycle

            # Adjust running_loss to reflect the unscaled loss per sample
            running_loss += total_loss.item() * images.size(0) * gradient_accumulation_steps

            train_loop.set_postfix(loss=total_loss.item() * gradient_accumulation_steps) # Display unscaled loss

        # Handle any remaining accumulated gradients from the last partial batch
        if (batch_idx + 1) % gradient_accumulation_steps != 0:
            # scaler.unscale_(optimizer) # Unscale before clipping if needed
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Example: Clip gradients

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1} Training Loss: {epoch_loss:.4f}")

        # --- Validation Phase ---
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            val_loop = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")
            for images, masks in val_loop:
                images = images.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)

                with torch.cuda.amp.autocast():
                    predictions = model(images)
                    l_dice_val = dice_loss(predictions, masks)
                    l_bce_val = bce_loss(predictions, masks)
                    l_adapt_val = adaptive_reg_loss(model)
                    total_val_loss = alpha * l_dice_val + beta * l_bce_val + gamma * l_adapt_val

                val_loss += total_val_loss.item() * images.size(0)
                val_loop.set_postfix(val_loss=total_val_loss.item())

        epoch_val_loss = val_loss / len(val_loader.dataset)
        print(f"Epoch {epoch+1} Validation Loss: {epoch_val_loss:.4f}")

        if scheduler:
            scheduler.step(epoch_val_loss) # For ReduceLROnPlateau

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            checkpoint_path = '/content/drive/MyDrive/TransCMFD_Checkpoints/best_model.pth'
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved to {checkpoint_path} with validation loss: {best_val_loss:.4f}")

    print("\nTraining complete!")


if __name__ == '__main__':
    # Assuming 'TransCMFDBaseline' class is available (from previous code block)
    # and 'full_dataset' and 'DataLoader' are imported/defined.
    # For a standalone run of this file, you'd need these:
    from torch.utils.data import DataLoader, Dataset
    import torchvision.transforms as transforms
    # Dummy Dataset for demonstration if full_dataset is not defined
    class DummyDataset(Dataset):
        def __init__(self, num_samples=100, img_size=224):
            self.num_samples = num_samples
            self.img_size = img_size
        def __len__(self):
            return self.num_samples
        def __getitem__(self, idx):
            # RGB image, 3 channels, HxW
            image = torch.rand(3, self.img_size, self.img_size)
            # Binary mask, 1 channel, HxW
            mask = (torch.rand(1, self.img_size, self.img_size) > 0.5).float()
            return image, mask

    try:
        # Check if full_dataset exists from previous execution context
        _ = full_dataset
    except NameError:
        print("`full_dataset` not found. Creating a DummyDataset for demonstration.")
        full_dataset = DummyDataset(num_samples=1000, img_size=224) # Adjust num_samples

    # Instantiate model and move to device HERE
    model = TransCMFDBaseline().to(device)
    print(f"Model instantiated on: {device}")

    # Define Optimizer (needs to be after model instantiation)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Learning Rate Scheduler
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True) # Reduce LR if val loss plateaus

    BATCH_SIZE = 1
    NUM_WORKERS = os.cpu_count() # Recommended: Use all available CPU cores for data loading
    if NUM_WORKERS is None: # Fallback for environments where os.cpu_count() might return None
        NUM_WORKERS = 0
    print(f"Using BATCH_SIZE: {BATCH_SIZE}, NUM_WORKERS: {NUM_WORKERS}")

    # DataLoaders setup
    train_size = int(0.8 * len(full_dataset)) # Adjust split ratio if needed
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
    print("DataLoaders re-created with new batch size, num_workers, and pin_memory=True.")

    print("\nStarting model training...")
    train_model(
        model,
        train_loader,
        val_loader,
        optimizer,
        dice_loss_fn,
        bce_loss_fn,
        adaptive_reg_loss_fn,
        NUM_EPOCHS,
        ALPHA,
        BETA,
        GAMMA,
        scheduler=scheduler, # Pass the scheduler
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS # Pass accumulation steps
    )

Model instantiated on: cuda
Using BATCH_SIZE: 1, NUM_WORKERS: 12
DataLoaders re-created with new batch size, num_workers, and pin_memory=True.

Starting model training...

Epoch 1/10


  scaler = torch.cuda.amp.GradScaler() # For Automatic Mixed Precision (AMP)


Training Epoch 1:   0%|          | 0/1603 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Epoch 1 Training Loss: nan


Validation Epoch 1:   0%|          | 0/401 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Epoch 1 Validation Loss: nan

Epoch 2/10


Training Epoch 2:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 2 Training Loss: nan


Validation Epoch 2:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 2 Validation Loss: nan

Epoch 3/10


Training Epoch 3:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 3 Training Loss: nan


Validation Epoch 3:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 3 Validation Loss: nan

Epoch 4/10


Training Epoch 4:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 4 Training Loss: nan


Validation Epoch 4:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 4 Validation Loss: nan

Epoch 5/10


Training Epoch 5:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 5 Training Loss: nan


Validation Epoch 5:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 5 Validation Loss: nan

Epoch 6/10


Training Epoch 6:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 6 Training Loss: nan


Validation Epoch 6:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 6 Validation Loss: nan

Epoch 7/10


Training Epoch 7:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 7 Training Loss: nan


Validation Epoch 7:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 7 Validation Loss: nan

Epoch 8/10


Training Epoch 8:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 8 Training Loss: nan


Validation Epoch 8:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 8 Validation Loss: nan

Epoch 9/10


Training Epoch 9:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 9 Training Loss: nan


Validation Epoch 9:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 9 Validation Loss: nan

Epoch 10/10


Training Epoch 10:   0%|          | 0/1603 [00:00<?, ?it/s]

Epoch 10 Training Loss: nan


Validation Epoch 10:   0%|          | 0/401 [00:00<?, ?it/s]

Epoch 10 Validation Loss: nan

Training complete!
