In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
from math import sqrt
import warnings


PRETRAINED_MODEL_NAME = "openai/clip-vit-base-patch16"
IMG_SIZE = 224
SKIP_LAYER_INDICES = [3, 5, 7, 9]

# --- CLIP ViT Encoder ---
class ClipViTEncoder(nn.Module):
    def __init__(self, model_name=PRETRAINED_MODEL_NAME, freeze_encoder=True, skip_indices=SKIP_LAYER_INDICES):
        super().__init__()
        self.model_name = model_name
        self.skip_indices = sorted(skip_indices, reverse=True) # Process deeper layers first for bottleneck logic

        self.config = CLIPVisionConfig.from_pretrained(model_name)
        self.clip_vit = CLIPVisionModel.from_pretrained(model_name)

        if freeze_encoder:
            for param in self.clip_vit.parameters():
                param.requires_grad = False

        # Calculate grid size
        self.patch_size = self.config.patch_size
        # Ensure image size in config matches expected if not default
        if self.config.image_size != IMG_SIZE:
             warnings.warn(f"Model config image size {self.config.image_size} differs from specified IMG_SIZE {IMG_SIZE}. Using model config size for grid calculation.")
        self.grid_size = self.config.image_size // self.patch_size
        self.hidden_dim = self.config.hidden_size
        self.num_patches = self.grid_size * self.grid_size

        # print(f"  Model: {model_name}")
        # print(f"  Expected Input Image Size (from config): {self.config.image_size}x{self.config.image_size}")
        # print(f"  Patch Size: {self.patch_size}x{self.patch_size}")
        # print(f"  Calculated Grid Size (G): {self.grid_size}x{self.grid_size}")
        # print(f"  Calculated Num Patches (N_p = G*G): {self.num_patches}")
        # print(f"  Hidden Dim (D_vit): {self.hidden_dim}")
        # print(f"  Targeting hidden states at indices: {self.skip_indices}")
        # print(f"  Encoder frozen: {freeze_encoder}")
        # print("--- ClipViTEncoder Initialized ---\n")


    def forward(self, x):
        # print("\n--- ClipViTEncoder Forward ---")
        # print(f"  Input shape: {x.shape}")
        if x.shape[2] != self.config.image_size or x.shape[3] != self.config.image_size:
             warnings.warn(
                 f"Input image size ({x.shape[2]}x{x.shape[3]}) doesn't match "
                 f"CLIP expected size ({self.config.image_size}x{self.config.image_size}). "
                 f"Behavior may be unexpected. Consider resizing input."
             )

        # Pass image through CLIP ViT
        outputs = self.clip_vit(pixel_values=x, output_hidden_states=True)
        all_hidden_states = outputs.hidden_states

        last_hidden_state = outputs.last_hidden_state

        patch_embeddings = last_hidden_state[:, 1:, :]

        if patch_embeddings.shape[1] != self.num_patches:
             print("here1")
             current_num_patches = patch_embeddings.shape[1]
             current_grid_size = int(sqrt(current_num_patches))
             if current_grid_size * current_grid_size != current_num_patches:
                 raise ValueError(f"Cannot reshape patch embeddings.")
             warnings.warn(f"Patch count mismatch. Reshaping to {current_grid_size}x{current_grid_size}.")
             grid_h, grid_w = current_grid_size, current_grid_size
        else:
             grid_h, grid_w = self.grid_size, self.grid_size

        bottleneck_features = patch_embeddings.reshape(x.shape[0], grid_h, grid_w, self.hidden_dim).permute(0, 3, 1, 2).contiguous()

        skip_features_list = []

        # print("  Extracting features from specified hidden state indices:")
        for i in self.skip_indices: # Iterates [12, 9, 6, 3]
            # print(f"    Processing hidden state index: {i}")
            hidden_state = all_hidden_states[i]
            # print(f"      Original hidden_state shape: {hidden_state.shape}") # Should be (B, N_p+1, D_vit)

            # Remove the CLS token embedding
            patch_embeddings = hidden_state[:, 1:, :]
            # print(f"      Shape after removing CLS token: {patch_embeddings.shape}") # Should be (B, N_p, D_vit)

            # Check num_patches consistency
            if patch_embeddings.shape[1] != self.num_patches:
                 # Handle potential mismatch (e.g., if input size wasn't exactly config.image_size)
                 current_num_patches = patch_embeddings.shape[1]
                 current_grid_size = int(sqrt(current_num_patches))
                 if current_grid_size * current_grid_size != current_num_patches:
                     raise ValueError(f"Cannot reliably reshape patch embeddings. Non-square grid? Expected {self.num_patches} patches, got {current_num_patches}.")
                 warnings.warn(f"Actual patch count {current_num_patches} differs from expected {self.num_patches}. Reshaping to {current_grid_size}x{current_grid_size}.")
                 grid_h, grid_w = current_grid_size, current_grid_size
            else:
                 grid_h, grid_w = self.grid_size, self.grid_size

            # Reshape to (B, grid_h, grid_w, D_vit) -> (B, D_vit, grid_h, grid_w)
            reshaped_features = patch_embeddings.reshape(
                x.shape[0], grid_h, grid_w, self.hidden_dim
            )
            reshaped_features = reshaped_features.permute(0, 3, 1, 2).contiguous()
            # print(f"      Shape after reshaping to grid (B, D_vit, G, G): {reshaped_features.shape}")

            skip_features_list.append(reshaped_features)

        # Reverse the collected skips to be shallowest feature first
        skip_features_list.reverse() # Now order matches [index_3_out, index_6_out, index_9_out]

        # print("--- ClipViTEncoder Forward End ---")
        return bottleneck_features, skip_features_list

# --- Example U-Net Decoder Block ---
class DecoderBlock(nn.Module):
    def __init__(self, block_index, in_channels_upsample, in_channels_skip, out_channels, skip_spatial_size, target_spatial_size):
        super().__init__()
        self.block_index = block_index
        # print(f"    Initializing DecoderBlock {self.block_index}: Upsample_in={in_channels_upsample}, Skip_in={in_channels_skip}, Out={out_channels}")
        self.skip_spatial_size = skip_spatial_size
        self.target_spatial_size = target_spatial_size
        self.upsample_out_channels = in_channels_upsample // 2
        self.skip_conv_out_channels = in_channels_upsample // 2 # Match for concatenation

        self.upsample = nn.ConvTranspose2d(in_channels_upsample, self.upsample_out_channels, kernel_size=2, stride=2)
        self.skip_conv = nn.Conv2d(in_channels_skip, self.skip_conv_out_channels, kernel_size=1) # 1x1 conv

        conv_in_channels = self.upsample_out_channels + self.skip_conv_out_channels
        # print(f"      Block {self.block_index}: Upsample out channels={self.upsample_out_channels}, SkipConv out channels={self.skip_conv_out_channels}, Concat channels={conv_in_channels}")
        self.conv_block = nn.Sequential(
            nn.Conv2d(conv_in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x_upsample, x_skip):
        # print(f"\n  --- DecoderBlock {self.block_index} Forward ---")
        # print(f"    Input x_upsample shape: {x_upsample.shape}")
        # print(f"    Input x_skip shape: {x_skip.shape}")

        x_upsample = self.upsample(x_upsample)
        # print(f"    Shape after upsample: {x_upsample.shape}")

        x_skip_proj = self.skip_conv(x_skip)
        # print(f"    Shape after skip_conv (channel adjust): {x_skip_proj.shape}")

        # Resize skip connection spatially if needed
        if x_skip_proj.shape[2:] != x_upsample.shape[2:]:
            #   print(f"    Resizing skip connection from {x_skip_proj.shape[2:]} to {x_upsample.shape[2:]}")
              x_skip_resized = F.interpolate(x_skip_proj, size=x_upsample.shape[2:], mode='bilinear', align_corners=False)
            #   print(f"    Shape after skip resize: {x_skip_resized.shape}")
        else:
              x_skip_resized = x_skip_proj
            #   print(f"    Skip connection spatial size matches upsample, no resize needed.")

        # Concatenate
        x = torch.cat([x_upsample, x_skip_resized], dim=1)
        # print(f"    Shape after concatenation: {x.shape}")

        # Apply convolutions
        x = self.conv_block(x)
        # print(f"    Shape after conv_block (Output): {x.shape}")
        # print(f"  --- DecoderBlock {self.block_index} Forward End ---")
        return x

# --- Example U-Net Decoder ---
class UNetDecoder(nn.Module):
    def __init__(self, encoder_hidden_dim, encoder_grid_size, decoder_channels):
        super().__init__()
        # print("\n--- Initializing UNetDecoder ---")
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_grid_size = encoder_grid_size
        self.decoder_channels = decoder_channels
        num_blocks = len(decoder_channels)

        # *** Important Check ***
        num_expected_skips = len(SKIP_LAYER_INDICES) # One layer is bottleneck
        # print(f"  Encoder provides {num_expected_skips} skip connections (indices besides bottleneck).")
        # print(f"  Decoder configured with {num_blocks} blocks for channels: {decoder_channels}.")
        if num_blocks != num_expected_skips:
            warnings.warn(
               f"*** POTENTIAL MISMATCH: Number of decoder blocks ({num_blocks}) "
               f"does not match the number of available skip connections ({num_expected_skips}). "
               f"This might lead to an IndexError during forward pass if not handled carefully. ***"
            )
        # The code proceeds to create num_blocks based on decoder_channels length

        in_channels_upsample = encoder_hidden_dim
        self.blocks = nn.ModuleList()
        current_spatial_size = encoder_grid_size
        # print(f"  Creating {num_blocks} DecoderBlock(s):")
        for i in range(num_blocks):
            out_ch = decoder_channels[i]
            in_channels_skip = encoder_hidden_dim
            skip_spatial_size = encoder_grid_size
            target_spatial_size = current_spatial_size * 2 # Assumes stride-2 upsampling

            # Pass index for clearer prints inside block
            block = DecoderBlock(
                block_index=i,
                in_channels_upsample=in_channels_upsample,
                in_channels_skip=in_channels_skip,
                out_channels=out_ch,
                skip_spatial_size=skip_spatial_size,
                target_spatial_size=target_spatial_size
            )
            self.blocks.append(block)

            in_channels_upsample = out_ch # Next block's upsample input is this block's output
            current_spatial_size = target_spatial_size

        # print("--- UNetDecoder Initialized ---\n")


    def forward(self, bottleneck_features, skip_features):
        # print("\n--- UNetDecoder Forward ---")
        # print(f"  Input bottleneck_features shape: {bottleneck_features.shape}")
        # print(f"  Input skip_features list length: {len(skip_features)}")
        x = bottleneck_features
        num_available_skips = len(skip_features)

        # Iterate through decoder blocks (indices 0 to num_blocks-1)
        for i, block in enumerate(self.blocks):
            # print(f"  Processing Decoder Block {i}")
            # Calculate index for skip_features list (shallowest first)
            # Block 0 (deepest decoder) uses last skip (deepest skip feature)
            skip_idx = num_available_skips - 1 - i
            # print(f"    Attempting to use skip feature at index: {skip_idx} (from list of {num_available_skips})")

            if skip_idx < 0 or skip_idx >= num_available_skips:
                 # This will happen if num_blocks > num_available_skips
                 raise IndexError(f"Attempted to access skip_features[{skip_idx}] but only {num_available_skips} skips are available. Mismatch between decoder blocks and encoder skips.")

            skip = skip_features[skip_idx]
            # print(f"    Using skip feature with shape: {skip.shape}")
            x = block(x, skip) # Pass current features and corresponding skip

        # print(f"  Final output shape from UNetDecoder: {x.shape}")
        # print("--- UNetDecoder Forward End ---")
        return x

# --- Combined CLIP-U-Net Model ---
class ClipUNet(nn.Module):
    def __init__(self, num_classes=4, decoder_channels=[512, 256, 128, 64], freeze_encoder=True):
        super().__init__()
        # print("\n--- Initializing ClipUNet ---")
        # print(f"  Num output classes: {num_classes}")
        # print(f"  Decoder channels specified: {decoder_channels}")

        self.encoder = ClipViTEncoder(freeze_encoder=freeze_encoder)

        self.decoder = UNetDecoder(
            encoder_hidden_dim=self.encoder.hidden_dim,
            encoder_grid_size=self.encoder.grid_size,
            decoder_channels=decoder_channels
        )

        # Final 1x1 convolution
        last_decoder_channel = decoder_channels[-1] if decoder_channels else self.encoder.hidden_dim # Handle no decoder case?
        # print(f"  Final Conv: {last_decoder_channel} -> {num_classes} channels")
        self.final_conv = nn.Conv2d(last_decoder_channel, num_classes, kernel_size=1)

        # Final upsampling layer check
        final_decoder_size = self.encoder.grid_size * (2**len(decoder_channels))
        # print(f"  Calculated final decoder spatial size: {final_decoder_size}x{final_decoder_size}")
        if final_decoder_size != IMG_SIZE:
            #  print(f"  Final decoder size {final_decoder_size} != Target Image Size {IMG_SIZE}. Adding final Upsample layer.")
             self.final_upsample = nn.Upsample(size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
        else:
            #  print(f"  Final decoder size matches target image size. Final upsample is Identity.")
             self.final_upsample = nn.Identity()

        # print("--- ClipUNet Initialized ---\n")


    def forward(self, x):
        # print("\n\n========== ClipUNet Forward Pass Start ==========")
        # print(f"Overall Input shape: {x.shape}")

        # 1. Encoder
        bottleneck, skips = self.encoder(x)
        # print("\n--- Back in ClipUNet Forward (after Encoder) ---")
        # print(f"  Encoder returned bottleneck shape: {bottleneck.shape}")
        # print(f"  Encoder returned {len(skips)} skip features.")

        # 2. Decoder
        decoder_output = self.decoder(bottleneck, skips)
        # print("\n--- Back in ClipUNet Forward (after Decoder) ---")
        # print(f"  Decoder returned output shape: {decoder_output.shape}")

        # 3. Final Convolution
        output = self.final_conv(decoder_output)
        # print("\n--- Back in ClipUNet Forward (after Final Conv) ---")
        # print(f"  Shape after final_conv: {output.shape}")

        # 4. Final Upsampling (if needed)
        output = self.final_upsample(output)
        # print("\n--- Back in ClipUNet Forward (after Final Upsample) ---")
        # print(f"  Shape after final_upsample: {output.shape}")

        # print("========== ClipUNet Forward Pass End ==========\n")
        return output

In [20]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, PILToTensor
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


import os
from torchvision.io import decode_image

target_batch_size = 64
batch_size = 16

class dataset(Dataset):
    def __init__(self, img_dir, label_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_names = sorted([os.path.splitext(filename)[0] for filename in os.listdir(img_dir)])
        self.len = len(self.img_names)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        img = decode_image(os.path.join(self.img_dir, self.img_names[idx] + ".jpg")).float()/255
        label = decode_image(os.path.join(self.label_dir, self.img_names[idx] + ".png"))

        if self.transform:
            img = self.transform(img)

        if self.target_transform:
            label = self.target_transform(label)

        return img, label
    
def display_img_label(data, idx):
    img, label = data[idx]
    figure = plt.figure(figsize=(10,20))
    figure.add_subplot(1, 2, 1)
    plt.imshow(img.permute(1, 2, 0))

    figure.add_subplot(1, 2, 2)
    plt.imshow(label.permute(1, 2, 0), cmap='grey')

    plt.show()

class target_remap(object):
    def __call__(self, img):
        img[img == 255] = 3
        return img

def diff_size_collate(batch):
    imgs = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    return imgs, labels
    
training_data = dataset("astrain/color", "astrain/label", target_transform=target_remap())
val_data = dataset("Val/color", "Val/label", target_transform=target_remap())
test_data = dataset("Test/color", "Test/label", target_transform=target_remap())

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, pin_memory=True, collate_fn=diff_size_collate)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, pin_memory=True, collate_fn=diff_size_collate)

In [21]:
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF
from tqdm import tqdm

# Assume these are defined somewhere in your code
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
use_amp = True  # or False depending on your setup

def resize_with_padding(image, target_size=512):
    """
    Resize a single image (Tensor of shape (C, H, W)) so that the longer side
    equals target_size, preserving aspect ratio; add black padding as needed.
    Returns the resized and padded image, plus a metadata dictionary.
    """
    _, orig_h, orig_w = image.shape
    scale = min(target_size / orig_w, target_size / orig_h)
    new_w = int(round(orig_w * scale))
    new_h = int(round(orig_h * scale))
    
    # Resize the image
    image_resized = TF.resize(image, size=(new_h, new_w))
    
    # Compute padding on each side
    pad_w = target_size - new_w
    pad_h = target_size - new_h
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top

    # Pad the image (padding order: left, top, right, bottom)
    image_padded = TF.pad(image_resized, padding=(pad_left, pad_top, pad_right, pad_bottom), fill=0)

    meta = {
        "original_size": (orig_h, orig_w),
        "new_size": (new_h, new_w),
        "pad": (pad_left, pad_top, pad_right, pad_bottom),
        "scale": scale
    }
    return image_padded, meta

def reverse_resize_and_padding(image, meta, interpolation="bilinear"):
    """
    Remove the padding from image (Tensor of shape (C, target_size, target_size))
    using metadata and then resize the cropped image back to the original size.
    interpolation: "bilinear" for continuous outputs; use "nearest" for label maps.
    """
    pad_left, pad_top, pad_right, pad_bottom = meta["pad"]
    new_h, new_w = meta["new_size"]
    
    # Crop out the padding: from pad_top to pad_top+new_h and pad_left to pad_left+new_w.
    image_cropped = image[..., pad_top: pad_top + new_h, pad_left: pad_left + new_w]
    
    # Resize the cropped image back to the original size.
    orig_h, orig_w = meta["original_size"]
    # F.interpolate expects a 4D tensor.
    image_original = F.interpolate(image_cropped.unsqueeze(0),
                                   size=(orig_h, orig_w),
                                   mode=interpolation,
                                   align_corners=False if interpolation != "nearest" else None)
    return image_original.squeeze(0)

def process_batch_forward(batch_images, target_size=512):
    """
    Process a batch (Tensor of shape (N, C, H, W)) by resizing each image to target_size
    with aspect ratio preserved (adding black padding).
    Returns the processed batch and a list of meta dictionaries.
    """
    resized_batch = []
    meta_list = []
    for image in batch_images:
        if image.ndim == 3 and image.shape[0] == 4:
            image = image[:3, ...] # Slice to keep only the first 3 channels (R, G, B) (RGBA to RGB)
        image_resized, meta = resize_with_padding(image, target_size)
        resized_batch.append(image_resized)
        meta_list.append(meta)
    return torch.stack(resized_batch), meta_list

def process_batch_reverse(batch_outputs, meta_list, interpolation="bilinear"):
    """
    Given a batch of network outputs of shape (N, C, target_size, target_size) and the
    corresponding meta info, reverse the transform for each one to obtain predictions at their
    original sizes.
    """
    original_outputs = []
    for output, meta in zip(batch_outputs, meta_list):
        restored = reverse_resize_and_padding(output, meta, interpolation=interpolation)
        original_outputs.append(restored)
    return original_outputs

In [22]:
class MetricsHistory:
    """
    Accumulates TP, FP, FN, TN over an epoch for multi-class segmentation
    and computes Dice, IoU, and Accuracy metrics.
    """
    def __init__(self, num_classes: int, ignore_index: int = None, device: str = 'cpu'):
        """
        Args:
            num_classes (int): Number of classes including background.
            ignore_index (int, optional): Index of the class to ignore during metric calculation. Defaults to None.
            device (str): Device to perform initial calculations, results accumulated on CPU.
        """
        self.num_classes = num_classes
        self.ignore_index = ignore_index

        self.total_tp = torch.zeros(num_classes, dtype=torch.float64, device='cpu')
        self.total_fp = torch.zeros(num_classes, dtype=torch.float64, device='cpu')
        self.total_fn = torch.zeros(num_classes, dtype=torch.float64, device='cpu')
        self.total_tn = torch.zeros(num_classes, dtype=torch.float64, device='cpu')

        # History lists for epoch metrics
        self.epoch_mean_dice_history = []
        self.epoch_mean_iou_history = []
        self.epoch_mean_acc_history = []
        
        self.epoch_per_class_dice_history = []
        self.epoch_per_class_iou_history = []
        self.epoch_per_class_acc_history = []
        
        self.last_per_class_iou = None
        self.last_per_class_dice = None
        self.last_per_class_acc = None

        # Metric mask (calculated once)
        self.mask = torch.ones(num_classes, dtype=torch.bool)
        if self.ignore_index is not None and 0 <= self.ignore_index < self.num_classes:
            self.mask[self.ignore_index] = False

    def reset(self):
        """Resets the accumulated TP, FP, FN, TN counts."""
        self.total_tp.zero_()
        self.total_fp.zero_()
        self.total_fn.zero_()
        self.total_tn.zero_()

    def accumulate(self, pred: torch.Tensor, label: torch.Tensor):
        """
        Accumulates statistics for a single prediction-label pair.

        Args:
            pred (torch.Tensor): Predicted logits or probabilities (C, H, W). Should be on self.device or moved.
            label (torch.Tensor): Ground truth label map (H, W), LongTensor. Should be on self.device or moved.
        """

        # Get hard predictions
        pred_hard = torch.argmax(pred.squeeze(0), dim=0) # (H, W)

        # One-hot encode
        label_onehot = F.one_hot(label.squeeze(0), num_classes=self.num_classes).permute(2, 0, 1).bool() # (C, H, W)
        pred_onehot = F.one_hot(pred_hard, num_classes=self.num_classes).permute(2, 0, 1).bool() # (C, H, W)

        # Calculate TP, FP, FN, TN per class
        tp = (pred_onehot & label_onehot).sum(dim=(1, 2))
        fp = (pred_onehot & ~label_onehot).sum(dim=(1, 2))
        fn = (~pred_onehot & label_onehot).sum(dim=(1, 2))
        tn = (~pred_onehot & ~label_onehot).sum(dim=(1, 2))
        
        # tp = (pred_onehot & label_onehot).sum(dim=(1, 2))
        # fp = pred_onehot.sum(dim=(1, 2)) - tp
        # fn = label_onehot.sum(dim=(1, 2)) - tp
        # tn = label.numel() - fn - fp - tp

        # Accumulate on CPU with float64
        self.total_tp += tp.cpu().to(torch.float64)
        self.total_fp += fp.cpu().to(torch.float64)
        self.total_fn += fn.cpu().to(torch.float64)
        self.total_tn += tn.cpu().to(torch.float64) # Accumulate TN if needed for accuracy


    def compute_epoch_metrics(self, epsilon: float = 1e-6):
        """
        Computes the macro-averaged metrics for the accumulated epoch statistics,
        appends them to the history lists, and returns the computed mean metrics.

        Args:
            epsilon (float): Small value to avoid division by zero.

        Returns:
            tuple: (mean_dice, mean_iou, mean_acc) for the current epoch.
        """

        tp = self.total_tp
        fp = self.total_fp
        fn = self.total_fn
        tn = self.total_tn

        per_class_iou = tp / (tp + fp + fn)
        per_class_dice = (2 * tp) / (2 * tp + fp + fn)
        per_class_acc = (tp + tn) / (tp + tn + fp + fn)

        mean_iou = per_class_iou[self.mask].mean().item()
        mean_dice = per_class_dice[self.mask].mean().item()
        mean_acc = per_class_acc[self.mask].mean().item()

        # Append to history
        self.epoch_mean_iou_history.append(mean_iou)
        self.epoch_mean_dice_history.append(mean_dice)
        self.epoch_mean_acc_history.append(mean_acc)

        self.epoch_per_class_iou_history.append(per_class_iou.numpy())
        self.epoch_per_class_dice_history.append(per_class_dice.numpy())
        self.epoch_per_class_acc_history.append(per_class_acc.numpy())

        self.last_per_class_iou = per_class_iou
        self.last_per_class_dice = per_class_dice
        self.last_per_class_acc = per_class_acc

        return mean_dice, mean_iou, mean_acc
    
    def get_ignore_index(self):
        return self.ignore_index
    
    def get_num_classes(self):
        return self.num_classes

    def get_mean_dice_history(self):
        return self.epoch_mean_dice_history

    def get_mean_iou_history(self):
        return self.epoch_mean_iou_history

    def get_mean_acc_history(self):
        return self.epoch_mean_acc_history
    
    def get_class_dice_history(self):
        return self.epoch_per_class_dice_history

    def get_class_iou_history(self):
        return self.epoch_per_class_iou_history

    def get_class_acc_history(self):
        return self.epoch_per_class_acc_history

    def get_last_per_class_dice(self):
        return self.last_per_class_dice
    
    def get_last_per_class_iou(self):
        return self.last_per_class_iou
    
    def get_last_per_class_acc(self):
        return self.last_per_class_acc


In [23]:
# --- Training Loop (Adapted for nnU-Net style) ---
def train_loop(dataloader, model, loss_fn, optimizer, scheduler, accumulation_steps, device, target_size = None):
    """Performs one epoch of training resembling nnU-Net practices."""
    model.train()
    total_loss = 0.0
    processed_batches = 0

    optimizer.zero_grad()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Training")
    for batch_idx, (X, y) in pbar:

        if target_size is not None:
            X, _ = process_batch_forward(X, target_size=target_size)
            y, _ = process_batch_forward(y, target_size=target_size)
        
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        scaled_loss = loss / accumulation_steps
        scaled_loss.backward()

        # Optimizer step after accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(dataloader):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            processed_batches += 1
            pbar.set_postfix({'loss': loss.item(), 'lr': optimizer.param_groups[0]['lr']})

    avg_loss = total_loss / processed_batches if processed_batches > 0 else 0
    print(f"Training Avg loss (per effective batch): {avg_loss:>8f}")
    return avg_loss

# --- Evaluation Loop (Modified for aggregated IoU) ---
def eval_loop(dataloader, model, loss_fn, device, target_size, agg):
    """
    Evaluation loop calculating loss, aggregated Dice, and aggregated IoU.

    Args:
        dataloader: yields batches of (list[Tensor(C,H,W)], list[Tensor(H,W)])
        model: the neural network model (on device)
        loss_fn: the combined loss function (e.g., DiceCELoss) used for training
        device: the torch device (cuda or cpu)
        target_size: the size the model expects for input
    """
    model.eval()
    num_images_processed = 0
    total_loss = 0.0
    num_classes = 4
    
    with torch.no_grad():
        for X, y in tqdm(dataloader, desc="Eval"):

            X, meta_list = process_batch_forward(X, target_size=target_size)
            X = X.to(device)
            preds = model(X) # Logits [N, C, target, target]

            preds = process_batch_reverse(preds, meta_list, interpolation='bilinear')

            for pred, label in zip(preds, y):
                pred = pred.to(device) # (C,H,W)
                label = label.to(device).long() # (H,W)

                loss = loss_fn(pred.unsqueeze(0), label.unsqueeze(0)) # Add batch dimension
                total_loss += loss.item()
                agg.accumulate(pred, label)
                
                num_images_processed += 1

    avg_loss = total_loss / num_images_processed

    mean_dice, mean_iou, mean_acc = agg.compute_epoch_metrics()
    per_class_iou = agg.get_last_per_class_iou()
    ignore_index = agg.get_ignore_index()

    print(f"\n--- Evaluation Complete ---")
    print(f"  Images Processed: {num_images_processed}")
    print(f"  Average Loss (Original Size): {avg_loss:>8f}")
    print(f"  Ignored Class : {ignore_index}")
    print(f"  Macro Avg Acc score: {mean_acc:>8f}")
    print(f"  Macro Avg Dice Score: {mean_dice:>8f}")
    print(f"  Mean IoU (mIoU): {mean_iou:>8f}")
    print(f"  --- Per-Class IoU ---")
    for c in range(num_classes):
        print(f"    Class {c}: {per_class_iou[c].item():>8f}")
    print("-" * 25)

    return avg_loss, mean_dice, mean_iou

In [37]:
import torch
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
from typing import Optional, List # Added List

def calculate_class_weights_v3(
    label_source,
    num_classes: int,
    ignore_index: Optional[int] = None,
    source_type: str = 'files',
    unimportant_class_indices: Optional[List[int]] = None, # Indices to down-weight
    target_unimportant_weight: float = 1.0, # Target weight for unimportant classes
    normalize_target_sum: float = -1.0 # Normalize weights sum (-1 means num_classes)
) -> torch.Tensor:
    """
    Calculates class weights based on inverse frequency, then adjusts weights
    for specified unimportant classes and re-normalizes.

    Args:
        # ... (Args same as v2 except for target_unimportant_weight) ...
        target_unimportant_weight: The desired weight value for unimportant classes
                                   BEFORE final re-normalization. 1.0 is often neutral.
    Returns:
        A torch.Tensor of shape [num_classes] containing weights.
    """
    print(f"Calculating class weights for {num_classes} classes (v3)...")
    if unimportant_class_indices:
        print(f"  Adjusting weights for classes {unimportant_class_indices} post-calculation.")

    # 1. Count all classes (Same counting loop as v2)
    counts = torch.zeros(num_classes, dtype=torch.float64)
    total_valid_pixels = 0
    # --- Start of loop ---
    # (Identical loop structure as before to get raw 'counts')
    iterator = None
    num_labels = 0
    if source_type == 'files': assert isinstance(label_source, (list, tuple)); iterator = label_source; num_labels = len(label_source)
    elif source_type == 'dataset': assert hasattr(label_source, '__getitem__') and hasattr(label_source, '__len__'); iterator = range(len(label_source)); num_labels = len(label_source)
    else: raise ValueError("source_type must be either 'files' or 'dataset'")
    print(f"Processing {num_labels} labels...")
    pbar = tqdm(iterator, total=num_labels)

    for idx_or_path in pbar:
        try:
            label_data = None
            # ... (Identical label loading logic as before) ...
            if source_type == 'files':
                path = idx_or_path; img = Image.open(path); label_data = torch.from_numpy(np.array(img)) # Example
            elif source_type == 'dataset':
                 _, label_data = label_source[idx_or_path]; label_data = torch.tensor(label_data) if not isinstance(label_data, torch.Tensor) else label_data
            if label_data is None: continue
            label_long = label_data.long().view(-1)
            if ignore_index is not None: valid_mask = (label_long != ignore_index); label_valid = label_long[valid_mask]
            else: label_valid = label_long
            label_valid = torch.clamp(label_valid, 0, num_classes - 1)
            if label_valid.numel() > 0:
                 counts += torch.bincount(label_valid, minlength=num_classes).double()
                 total_valid_pixels += label_valid.numel()

        except Exception as e:
             item_id = idx_or_path if source_type=='files' else f"index {idx_or_path}"; print(f"\nError processing {item_id}: {e}. Skipping.")
             continue
    # --- End of loop ---
    print("\nFinished counting.")
    print(f"Raw pixel counts per class: {counts.long().tolist()}")
    print(f"Total valid pixels counted: {total_valid_pixels}")

    if total_valid_pixels == 0:
        print("Warning: No valid pixels found. Returning equal weights.")
        return torch.ones(num_classes, dtype=torch.float32)

    # 2. Calculate initial inverse frequency weights for ALL classes
    frequencies = counts / total_valid_pixels
    epsilon = 1e-6
    inverse_frequencies = 1.0 / (frequencies + epsilon)

    # Intermediate weights (no normalization yet)
    weights = inverse_frequencies

    # 3. Adjust weights for unimportant classes
    if unimportant_class_indices:
        for idx in unimportant_class_indices:
            if 0 <= idx < num_classes:
                # Find the 'base' frequency corresponding to the target weight
                # If target is 1.0, it corresponds to average frequency's inverse
                # For simplicity, let's just scale relative to others,
                # or simply set it directly. Setting directly is easier.
                # Option A: Set directly relative to the *initial* mean inverse frequency
                # mean_inv_freq = inverse_frequencies.mean()
                # weights[idx] = mean_inv_freq * target_unimportant_weight
                # Option B: Just set it towards 1 (average weight after normalization)
                # This value might need tuning
                if target_unimportant_weight == -1:
                    weights[idx] = min(weights) # Assign the target weight DIRECTLY.
                else:
                    weights[idx] = target_unimportant_weight
                 # Note: This works well if target_unimportant_weight=1.0 and normalization
                 # makes the average weight 1.0 later.

            else:
                warnings.warn(f"unimportant_class_index {idx} is out of bounds.")

    # 4. Normalize the *adjusted* weights
    target_sum = normalize_target_sum if normalize_target_sum > 0 else float(num_classes)
    final_weights = weights / weights.sum() * target_sum

    print(f"Calculated Final Class Weights: {final_weights.tolist()}")

    return final_weights.float()

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.optim.lr_scheduler import PolynomialLR # Import PolynomialLR
from tqdm import tqdm
import numpy as np
import os
import torch.nn.functional as F # Needed for softmax in Dice Loss
import warnings # Keep warnings
from typing import Optional, List, Dict, Tuple, Callable # Keep typing

# --- Configuration ---
EPOCHS = 100
MODEL_SAVE_DIR = "unet" # Changed path
INITIAL_LR = 0.01 # Standard nnU-Net initial LR for SGD
WEIGHT_DECAY = 3e-5 # A common weight decay value, adjust if needed
SGD_MOMENTUM = 0.99 # nnU-Net standard momentum

os.makedirs(MODEL_SAVE_DIR, exist_ok=True)


# --- Modified Dice Loss with Class Weights ---
class WeightedMemoryEfficientDiceLoss(nn.Module):
    """ Version using ignore_index and supporting class weights """
    def __init__(self,
                 apply_softmax: bool = True,
                 ignore_index: Optional[int] = None,
                 class_weights: Optional[torch.Tensor] = None, # New parameter
                 smooth: float = 1e-5):
        super().__init__()
        self.apply_softmax = apply_softmax
        self.ignore_index = ignore_index
        self.smooth = smooth

        # Store class weights, ensuring they are a Tensor if provided
        if class_weights is not None:
            assert isinstance(class_weights, torch.Tensor), "class_weights must be a torch.Tensor"
            self.class_weights = class_weights
        else:
            self.class_weights = None


    def forward(self, x, y):
        num_classes = x.shape[1]
        shp_y = y.shape
        if self.apply_softmax:
            probs = F.softmax(x, dim=1)
        else:
            probs = x

        # --- One-Hot Encoding and Masking ---
        with torch.no_grad():
            # Shape adjustments (same as before)
            if len(shp_y) != len(probs.shape):
                if len(shp_y) == len(probs.shape) - 1 and len(shp_y) >= 2 and shp_y == probs.shape[2:]:
                     y = y.unsqueeze(1)
                elif len(shp_y) == len(probs.shape) and shp_y[1] == 1: pass # ok
                else: raise ValueError(f"Shape mismatch: probs {probs.shape}, y {shp_y}")
            y_long = y.long()

            # Spatial mask based on ignore_index
            mask = None
            # if self.ignore_index is not None:
            #     mask = (y_long != self.ignore_index)

            # Create one-hot ground truth (potentially masked)
            if probs.shape == y.shape: # Already one-hot
                 y_onehot = y.float()
                 if mask is not None:
                      y_indices_for_mask = torch.argmax(y_onehot, dim=1, keepdim=True)
                      mask = (y_indices_for_mask != self.ignore_index)
                      y_onehot = y_onehot * mask
            else: # Create from index map
                y_onehot = torch.zeros_like(probs, device=probs.device)
                y_onehot.scatter_(1, y_long, 1)
                if mask is not None: y_onehot = y_onehot * mask

            sum_gt = y_onehot.sum(dim=(2, 3)) # Pre-calculate GT sum needed later [N, C]
        # --- End One-Hot Encoding ---

        # Apply spatial mask to probabilities before summation
        if mask is not None:
             probs = probs * mask

        # Calculate intersection and prediction sum (still per-sample)
        intersect_persample = (probs * y_onehot).sum(dim=(2, 3)) # Shape [N, C]
        sum_pred_persample = probs.sum(dim=(2, 3))              # Shape [N, C]
        sum_gt_persample = sum_gt                               # Shape [N, C]

        # --- Aggregate across batch ---
        intersect = intersect_persample.sum(0) # Shape [C]
        sum_pred = sum_pred_persample.sum(0)   # Shape [C]
        sum_gt = sum_gt_persample.sum(0)     # Shape [C]

        # --- Calculate per-class Dice ---
        denominator = sum_pred + sum_gt
        dc = (2. * intersect + self.smooth) / (torch.clip(denominator + self.smooth, 1e-8)) # Shape [C]

        # --- Average Dice Logic (Weighted or Unweighted) ---
        # Mask for valid (non-ignored) classes
        valid_classes_mask = torch.ones_like(dc, dtype=torch.bool)
        if self.ignore_index is not None and 0 <= self.ignore_index < num_classes:
            valid_classes_mask[self.ignore_index] = False

        dc_final = torch.tensor(0.0, device=dc.device) # Default loss if no valid classes

        if valid_classes_mask.sum() > 0: # Proceed only if there are valid classes
            dc_valid = dc[valid_classes_mask] # Dice scores for valid classes

            if self.class_weights is not None:
                # Use weighted average for valid classes
                weights = self.class_weights.to(dc_valid.device) # Ensure weights are on correct device
                weights_valid = weights[valid_classes_mask] # Select weights for valid classes
                # Calculate weighted mean: sum(value*weight) / sum(weight)
                weighted_sum = (dc_valid * weights_valid).sum()
                weight_sum = weights_valid.sum()
                dc_final = weighted_sum / weight_sum.clamp(min=1e-8) # Avoid division by zero
            else:
                # Use simple mean if no weights are provided
                dc_final = dc_valid.mean()

        return -dc_final # Return negative Dice score as loss


# --- Modified Combined Loss to use Weighted Dice and accept CE weights ---
class WeightedDiceCELoss(nn.Module): # Renamed for clarity
    """Combines WeightedMemoryEfficientDiceLoss and Cross Entropy Loss with class_weights support."""
    def __init__(self,
                 dice_weight: float = 1.0,
                 ce_weight: float = 1.0,
                 ignore_index: Optional[int] = None,
                 class_weights: Optional[torch.Tensor] = None, # Pass weights here
                 smooth_dice: float = 1e-5, # Adjust smooth value as needed (e.g., 1.0 for training)
                 ce_kwargs={}):
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.ignore_index = ignore_index # Store ignore_index

        # Ensure class_weights is a Tensor if provided
        if class_weights is not None:
             assert isinstance(class_weights, torch.Tensor), "class_weights must be a torch.Tensor"

        # Instantiate the modified Dice loss, passing weights
        self.dice = WeightedMemoryEfficientDiceLoss(
            apply_softmax=True, # Dice loss usually works on probabilities
            ignore_index=ignore_index,
            class_weights=class_weights, # Pass weights to Dice component
            smooth=smooth_dice
        )

        # Prepare kwargs for standard CrossEntropyLoss
        ce_final_kwargs = ce_kwargs.copy()
        if ignore_index is not None:
            ce_final_kwargs['ignore_index'] = ignore_index
        if class_weights is not None:
            # Pass the weights tensor to CE's 'weight' parameter
            ce_final_kwargs['weight'] = class_weights
            # Note: CE will handle moving the weights tensor to the correct device internally

        self.cross_entropy = nn.CrossEntropyLoss(**ce_final_kwargs)

    def forward(self, outputs, targets):
        # outputs are expected to be logits [N, C, H, W]
        # targets are expected to be class indices [N, H, W] or [N, 1, H, W]

        # --- Dice Loss ---
        # WeightedMemoryEfficientDiceLoss handles softmax internally
        dice_loss = self.dice(outputs, targets)

        # --- Cross Entropy Loss ---
        # Prepare targets for CE
        if targets.ndim == 4 and targets.shape[1] == 1:
             targets_ce = targets.squeeze(1).long()
        elif targets.ndim == 3:
             targets_ce = targets.long() # Assuming [N, H, W]
        else:
             # Added more specific error message for common cases
             if targets.ndim == outputs.ndim and targets.shape[1] != 1:
                 raise ValueError(f"Target shape {targets.shape} has multiple channels but expected class indices [N, H, W] or [N, 1, H, W] for CE.")
             else:
                 raise ValueError(f"Unsupported target shape {targets.shape} for CE. Expected [N, H, W] or [N, 1, H, W].")

        # Weights are handled internally by nn.CrossEntropyLoss via its 'weight' parameter
        ce_loss = self.cross_entropy(outputs, targets_ce)
        
        # --- Combine ---
        combined_loss = (self.dice_weight * dice_loss) + (self.ce_weight * ce_loss)
        return combined_loss

class_weight = [0.30711034803008996, 1.5412496145750956, 1.8445296893647247, 0.30711034803008996]
# class_weight = [1, 1, 1, 1]
class_weight = Tensor(class_weight)
# class_weight = calculate_class_weights_v3(training_data, 4, None, "dataset", [3], 0)

accumulation_steps = target_batch_size // batch_size

# --- Define Model, Loss, Optimizer, Scheduler ---
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
model = ClipUNet().to(device) # Assuming unet() exists and returns your U-Net model

# --- Define Loss, Optimizer, Scheduler ---
# Configure ignore_index (e.g., 3 to ignore class 3, or 255 if used in labels, None otherwise)
EVAL_IGNORE_INDEX = 3 # Example: ignore class 3 during evaluation metric calculation
TRAIN_IGNORE_INDEX = None  # Example: train on all classes (0,1,2,3)

class_weight = class_weight.to(device)
loss_fn = WeightedDiceCELoss(ignore_index=TRAIN_IGNORE_INDEX, smooth_dice=1, class_weights=class_weight) # Training loss
# Evaluation loss object used inside eval loop only to get settings like ignore_index
# It is NOT used to calculate the loss score reported for eval (that uses training loss object)
# But we pass it to eval_loop so it knows which index to ignore for metric calc if needed
eval_settings_provider = WeightedDiceCELoss(ignore_index=EVAL_IGNORE_INDEX, class_weights=class_weight)


optimizer = optim.SGD(model.parameters(), lr=INITIAL_LR, momentum=SGD_MOMENTUM,
                      weight_decay=WEIGHT_DECAY, nesterov=True)

steps_per_epoch = (len(train_dataloader) // accumulation_steps) + (1 if len(train_dataloader) % accumulation_steps != 0 else 0)
total_iters = steps_per_epoch * EPOCHS
scheduler = PolynomialLR(optimizer, total_iters=total_iters, power=0.9)
NUM_CLASSES = 4
processor = CLIPImageProcessor.from_pretrained(PRETRAINED_MODEL_NAME)
agg = MetricsHistory(NUM_CLASSES, EVAL_IGNORE_INDEX)

start_epoch = 0
best_dev_dice = -np.inf # Track best Dice score
best_dev_miou = -np.inf # Track best mIoU
best_dev_loss = np.inf # Track loss corresponding to best metric

MODEL_NAME = "test.pytorch"
MODEL_SAVE_DIR = "clip"
checkpoint_path = os.path.join(MODEL_SAVE_DIR, MODEL_NAME)

os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
if os.path.isfile(f"{MODEL_SAVE_DIR}/{MODEL_NAME}"):
    print(f"Loading checkpoint from: {MODEL_SAVE_DIR}/{MODEL_NAME}")
    # Load the checkpoint dictionary; move tensors to the correct device
    checkpoint = torch.load(f"{MODEL_SAVE_DIR}/{MODEL_NAME}", map_location=device)

    # Load model state
    model.load_state_dict(checkpoint["model_state_dict"])
    print(" -> Model state loaded.")

    # Load optimizer state
    try:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(" -> Optimizer state loaded.")
    except Exception as e:
        print(f" -> Warning: Could not load optimizer state: {e}. Optimizer will start from scratch.")

    # Load scheduler state
    try:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        print(" -> Scheduler state loaded.")
    except Exception as e:
        print(f" -> Warning: Could not load scheduler state: {e}. Scheduler will start from scratch.")

    try:
        agg = checkpoint.get("history")
    except Exception as e:
        print(f" -> No metric history saved")
        agg = MetricsHistory(NUM_CLASSES, EVAL_IGNORE_INDEX)

    # Load training metadata
    start_epoch = checkpoint.get("epoch", 0) # Load last completed epoch, training continues from next one
    best_dev_dice = checkpoint.get("best_dev_dice", -np.inf)
    best_dev_miou = checkpoint.get("best_dev_miou", -np.inf)
    best_dev_loss = checkpoint.get("best_dev_loss", np.inf)

    print(f" -> Resuming training from epoch {start_epoch + 1}")
    print(f" -> Loaded best metrics: Dice={best_dev_dice:.6f}, mIoU={best_dev_miou:.6f}, Loss={best_dev_loss:.6f}")
    loaded_notes = checkpoint.get("notes", "N/A")
    print(f" -> Notes from checkpoint: {loaded_notes}")

else:
    print(f"Checkpoint file not found at {MODEL_SAVE_DIR}/{MODEL_NAME}. Starting training from scratch.")

# --- Training and Evaluation Loop ---
print("\nStarting Training...")
for t in range(start_epoch, EPOCHS):
    print(f"Epoch {t+1}\n-------------------------------")
    
    train_loss = train_loop(train_dataloader, model, loss_fn, optimizer, scheduler, accumulation_steps, device, 224)
    val_loss, val_dice_micro, val_miou = eval_loop(val_dataloader, model, eval_settings_provider, device, 224, agg)

    metrics = {
        "epoch": t + 1,
        "history": agg
    }
    torch.save(metrics, f"{MODEL_SAVE_DIR}/metrics_{MODEL_NAME}")

    # Save model based on validation MICRO DICE score improvement
    # Could also choose mIoU validation by changing 'val_dice_micro > best_dev_dice'
    if val_dice_micro > best_dev_dice:
        best_dev_dice = val_dice_micro
        best_dev_miou = val_miou # Save corresponding mIoU
        best_dev_loss = val_loss # Save corresponding loss
        print(f"Validation Micro Dice score improved ({best_dev_dice:.6f}). Saving model...")

        checkpoint = {
            "epoch": t + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "best_dev_dice": best_dev_dice,
            "best_dev_miou": best_dev_miou,
            "best_dev_loss": best_dev_loss,
            "history": agg,
            "notes": f"Model saved based on best Micro Dice. Ignored index for metric: {EVAL_IGNORE_INDEX}"
        }
        torch.save(checkpoint, checkpoint_path)
    else:
        print(f"Validation Micro Dice score did not improve from {best_dev_dice:.6f}")


print("\n--- Training Finished! ---")
print(f"Best validation Micro Dice score achieved: {best_dev_dice:.6f}")
print(f"Corresponding validation mIoU: {best_dev_miou:.6f}")
print(f"Corresponding validation loss: {best_dev_loss:.6f}")
print(f"Best model saved to: {os.path.join(MODEL_SAVE_DIR, 'unet_best_dice.pytorch')}")

cuda
Checkpoint file not found at clip/test.pytorch. Starting training from scratch.

Starting Training...
Epoch 1
-------------------------------


Training: 100%|██████████| 398/398 [03:02<00:00,  2.18it/s, loss=0.848, lr=0.00991]  


Training Avg loss (per effective batch): -0.131471


Eval: 100%|██████████| 47/47 [00:12<00:00,  3.78it/s]



--- Evaluation Complete ---
  Images Processed: 738
  Average Loss (Original Size): -0.360313
  Ignored Class : 3
  Macro Avg Acc score: 0.949696
  Macro Avg Dice Score: 0.894777
  Mean IoU (mIoU): 0.811788
  --- Per-Class IoU ---
    Class 0: 0.899327
    Class 1: 0.749620
    Class 2: 0.786417
    Class 3: 0.033430
-------------------------
Validation Micro Dice score improved (0.894777). Saving model...
Epoch 2
-------------------------------


Training: 100%|██████████| 398/398 [03:03<00:00,  2.17it/s, loss=-0.233, lr=0.00982]


Training Avg loss (per effective batch): -0.456696


Eval: 100%|██████████| 47/47 [00:12<00:00,  3.78it/s]



--- Evaluation Complete ---
  Images Processed: 738
  Average Loss (Original Size): -0.368965
  Ignored Class : 3
  Macro Avg Acc score: 0.953437
  Macro Avg Dice Score: 0.903795
  Mean IoU (mIoU): 0.826104
  --- Per-Class IoU ---
    Class 0: 0.904047
    Class 1: 0.787335
    Class 2: 0.786932
    Class 3: 0.167725
-------------------------
Validation Micro Dice score improved (0.903795). Saving model...
Epoch 3
-------------------------------


Training:   2%|▏         | 9/398 [00:04<03:08,  2.06it/s, loss=-0.529, lr=0.00982]


KeyboardInterrupt: 

In [17]:
from torch import Tensor
model = ClipUNet().to(device)
class_weight = [0.30711034803008996, 1.5412496145750956, 1.8445296893647247, 0.30711034803008996]
class_weight = Tensor(class_weight)
class_weight = class_weight.to(device)
loss_fn = WeightedDiceCELoss(ignore_index=3, class_weights=class_weight)


# optimizer = optim.SGD(model.parameters(), lr=INITIAL_LR, momentum=SGD_MOMENTUM,
#                       weight_decay=WEIGHT_DECAY, nesterov=True)

# total_iters = (len(train_dataloader) // accumulation_steps) * EPOCHS
# scheduler = PolynomialLR(optimizer, total_iters=total_iters, power=0.9)

checkpoint = torch.load("clip/clip_weighted_no_mask_loss.pytorch")

model.load_state_dict(checkpoint["model_state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer"])
# scaler.load_state_dict(checkpoint["scaler"])

model.to(device)

# new_dev_loss = eval_loop(test_dataloader, model, loss_fn, device, 224)

# images, labels = next(iter(test_dataloader))
# images, meta = process_batch_forward(images, 224)

# pred = model(images.to(device))

# images = process_batch_reverse(pred, meta)

# print(images[0].shape)

# plt.imshow(pred[0].argmax(dim=0).cpu().numpy())
# plt.show()
# plt.imshow(images[0].argmax(dim=0).cpu().numpy())
# plt.show()
# plt.imshow(labels[0].permute(1,2,0).cpu().numpy())
# plt.show()
agg = MetricsHistory(4, 3)
new_dev_loss = eval_loop(test_dataloader, model, loss_fn, device, 224, agg)

Eval:   8%|▊         | 18/231 [00:05<01:05,  3.27it/s]


KeyboardInterrupt: 

In [None]:
images, labels = next(iter(test_dataloader))
images, meta = process_batch_forward(images, 224)

pred = model(images.to(device))

images = process_batch_reverse(pred, meta)

print(images[0].shape)

plt.imshow(pred[0].argmax(dim=0).cpu().numpy())
plt.show()
plt.imshow(images[0].argmax(dim=0).cpu().numpy())
plt.show()
plt.imshow(labels[0].permute(1,2,0).cpu().numpy())
plt.show()
# # new_dev_loss = eval_loop(test_dataloader, model, loss_fn, device, 256)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Output class definition might vary slightly across transformers versions
# Try importing potential base classes if direct attribute access fails later
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers import CLIPVisionModel, CLIPVisionConfig, CLIPImageProcessor
from math import sqrt
import warnings

# --- Configuration ---
PRETRAINED_MODEL_NAME = "openai/clip-vit-base-patch16"
IMG_SIZE = 224
# BOTTLENECK_LAYER_INDEX is no longer needed for this encoder version

# --- Modified CLIP ViT Encoder (Uses last_hidden_state) ---
class ClipViTEncoderNoSkips(nn.Module):
    def __init__(self, model_name=PRETRAINED_MODEL_NAME, freeze_encoder=True): # Removed bottleneck_index
        super().__init__()
        # print("\n--- Initializing ClipViTEncoderNoSkips (Using last_hidden_state) ---")
        self.model_name = model_name

        self.config = CLIPVisionConfig.from_pretrained(model_name)
        self.clip_vit = CLIPVisionModel.from_pretrained(model_name)

        if freeze_encoder:
            for param in self.clip_vit.parameters():
                param.requires_grad = False

        self.patch_size = self.config.patch_size
        if self.config.image_size != IMG_SIZE:
             warnings.warn(f"Model config image size {self.config.image_size} differs from specified IMG_SIZE {IMG_SIZE}.")
        self.grid_size = self.config.image_size // self.patch_size
        self.hidden_dim = self.config.hidden_size
        self.num_patches = self.grid_size * self.grid_size

        # print(f"  Model: {model_name}")
        # print(f"  Grid Size (G): {self.grid_size}x{self.grid_size}")
        # print(f"  Hidden Dim (D_vit): {self.hidden_dim}")
        # print(f"  Will extract bottleneck from standard 'last_hidden_state' output.")
        # print(f"  Encoder frozen: {freeze_encoder}")
        # print("--- ClipViTEncoderNoSkips Initialized ---\n")

    def forward(self, x):
        # print("\n--- ClipViTEncoderNoSkips Forward ---")
        # print(f"  Input shape: {x.shape}")
        if x.shape[2] != self.config.image_size or x.shape[3] != self.config.image_size:
             warnings.warn(f"Input image size mismatch.")

        # Pass image through CLIP ViT, NO need for output_hidden_states=True
        outputs = self.clip_vit(pixel_values=x)

        # Access the output of the final layer directly
        # The output object type might vary, usually has last_hidden_state
        if isinstance(outputs, BaseModelOutputWithPooling) or hasattr(outputs, 'last_hidden_state'):
             last_hidden_state = outputs.last_hidden_state
            #  print(f"  Accessed 'last_hidden_state', shape: {last_hidden_state.shape}") # Should be (B, N_p+1, D_vit)
        else:
             # Fallback or error if the output structure is unexpected
             # Sometimes the raw tuple output is returned if BaseClass isn't used
             if isinstance(outputs, tuple):
                  last_hidden_state = outputs[0] # Usually the first element
                  warnings.warn("Accessed model output via tuple indexing (index 0). Assumed it's last_hidden_state.")
                #   print(f"  Accessed output via tuple index 0, shape: {last_hidden_state.shape}")
             else:
                  raise TypeError(f"Unexpected output type from CLIPVisionModel: {type(outputs)}. Cannot extract last_hidden_state.")


        # --- Reshape Final Layer Output (Bottleneck) ---
        # print(f"    Original last_hidden_state shape: {last_hidden_state.shape}")

        patch_embeddings = last_hidden_state[:, 1:, :] # Remove CLS token
        # print(f"    Shape after removing CLS token: {patch_embeddings.shape}")

        if patch_embeddings.shape[1] != self.num_patches:
             current_num_patches = patch_embeddings.shape[1]
             current_grid_size = int(sqrt(current_num_patches))
             if current_grid_size * current_grid_size != current_num_patches:
                 raise ValueError(f"Cannot reshape patch embeddings.")
             warnings.warn(f"Patch count mismatch. Reshaping to {current_grid_size}x{current_grid_size}.")
             grid_h, grid_w = current_grid_size, current_grid_size
        else:
             grid_h, grid_w = self.grid_size, self.grid_size

        bottleneck_features = patch_embeddings.reshape(
            x.shape[0], grid_h, grid_w, self.hidden_dim
        ).permute(0, 3, 1, 2).contiguous()
        # print(f"    Bottleneck feature shape (B, D_vit, G, G): {bottleneck_features.shape}")

        # print("--- ClipViTEncoderNoSkips Forward End ---")
        return bottleneck_features # Only return the bottleneck

# --- Modified U-Net Decoder Block (No Skip Connection Logic) ---
# (Keep the DecoderBlockNoSkip from the previous version - no changes needed)
class DecoderBlockNoSkip(nn.Module):
    def __init__(self, block_index, in_channels_upsample, out_channels): # Removed skip-related args
        super().__init__()
        self.block_index = block_index
        # print(f"    Initializing DecoderBlockNoSkip {self.block_index}: Upsample_in={in_channels_upsample}, Out={out_channels}")
        self.upsample_out_channels = in_channels_upsample // 2

        self.upsample = nn.ConvTranspose2d(in_channels_upsample, self.upsample_out_channels, kernel_size=2, stride=2)

        conv_in_channels = self.upsample_out_channels
        # print(f"      Block {self.block_index}: Input channels to conv_block = {conv_in_channels}")

        self.conv_block = nn.Sequential(
            nn.Conv2d(conv_in_channels, out_channels, kernel_size=3, padding=1, bias=False), # Adjusted input channels here
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x_upsample): # Only takes x_upsample as input
        # print(f"\n  --- DecoderBlockNoSkip {self.block_index} Forward ---")
        # print(f"    Input x_upsample shape: {x_upsample.shape}")
        x_upsample = self.upsample(x_upsample)
        # print(f"    Shape after upsample: {x_upsample.shape}")
        x = self.conv_block(x_upsample) # Pass upsampled output directly to conv block
        # print(f"    Shape after conv_block (Output): {x.shape}")
        # print(f"  --- DecoderBlockNoSkip {self.block_index} Forward End ---")
        return x


# --- Modified U-Net Decoder (No Skip Connection Logic) ---
# (Keep the UNetDecoderNoSkips from the previous version - no changes needed)
class UNetDecoderNoSkips(nn.Module):
    def __init__(self, encoder_hidden_dim, encoder_grid_size, decoder_channels):
        super().__init__()
        # print("\n--- Initializing UNetDecoderNoSkips ---")
        self.encoder_hidden_dim = encoder_hidden_dim
        self.encoder_grid_size = encoder_grid_size
        self.decoder_channels = decoder_channels
        num_blocks = len(decoder_channels)

        in_channels_upsample = encoder_hidden_dim # Starts with bottleneck channels
        self.blocks = nn.ModuleList()
        current_spatial_size = encoder_grid_size
        # print(f"  Creating {num_blocks} DecoderBlockNoSkip(s):")
        for i in range(num_blocks):
            out_ch = decoder_channels[i]
            block = DecoderBlockNoSkip( # Use the modified block
                block_index=i,
                in_channels_upsample=in_channels_upsample,
                out_channels=out_ch,
            )
            self.blocks.append(block)
            in_channels_upsample = out_ch # Next block takes output channels from this block
            current_spatial_size *= 2 # Update expected spatial size
        # print("--- UNetDecoderNoSkips Initialized ---\n")

    def forward(self, bottleneck_features): # Only takes bottleneck_features
        # print("\n--- UNetDecoderNoSkips Forward ---")
        # print(f"  Input bottleneck_features shape: {bottleneck_features.shape}")
        x = bottleneck_features
        for i, block in enumerate(self.blocks):
            # print(f"  Processing Decoder Block {i}")
            x = block(x) # Pass output of previous block/bottleneck directly
        # print(f"  Final output shape from UNetDecoderNoSkips: {x.shape}")
        # print("--- UNetDecoderNoSkips Forward End ---")
        return x

# --- Combined CLIP-U-Net Model (No Skips - uses modified encoder) ---
# (Keep the ClipUNetNoSkips from the previous version - no changes needed,
# as it already uses the encoder named ClipViTEncoderNoSkips)
class ClipUNetNoSkips(nn.Module):
    def __init__(self, num_classes=4, decoder_channels=[512, 256, 128, 64], freeze_encoder=True):
        super().__init__()
        # print("\n--- Initializing ClipUNetNoSkips ---")
        # Uses the newly defined ClipViTEncoderNoSkips which accesses last_hidden_state
        self.encoder = ClipViTEncoderNoSkips(freeze_encoder=freeze_encoder)
        self.decoder = UNetDecoderNoSkips( # Uses the decoder without skip logic
            encoder_hidden_dim=self.encoder.hidden_dim,
            encoder_grid_size=self.encoder.grid_size,
            decoder_channels=decoder_channels
        )
        last_decoder_channel = decoder_channels[-1] if decoder_channels else self.encoder.hidden_dim
        self.final_conv = nn.Conv2d(last_decoder_channel, num_classes, kernel_size=1)
        final_decoder_size = self.encoder.grid_size * (2**len(decoder_channels))
        if final_decoder_size != IMG_SIZE:
             self.final_upsample = nn.Upsample(size=(IMG_SIZE, IMG_SIZE), mode='bilinear', align_corners=False)
        else:
             self.final_upsample = nn.Identity()
        # print("--- ClipUNetNoSkips Initialized ---\n")

    def forward(self, x):
        # print("\n\n========== ClipUNetNoSkips Forward Pass Start ==========")
        # print(f"Overall Input shape: {x.shape}")
        bottleneck = self.encoder(x) # Encoder now uses last_hidden_state implicitly
        # print("\n--- Back in ClipUNetNoSkips Forward (after Encoder) ---")
        # print(f"  Encoder returned bottleneck shape: {bottleneck.shape}")
        decoder_output = self.decoder(bottleneck) # Decoder takes only bottleneck
        # print("\n--- Back in ClipUNetNoSkips Forward (after Decoder) ---")
        # print(f"  Decoder returned output shape: {decoder_output.shape}")
        output = self.final_conv(decoder_output)
        # print("\n--- Back in ClipUNetNoSkips Forward (after Final Conv) ---")
        # print(f"  Shape after final_conv: {output.shape}")
        output = self.final_upsample(output)
        # print("\n--- Back in ClipUNetNoSkips Forward (after Final Upsample) ---")
        # print(f"  Shape after final_upsample: {output.shape}")
        # print("========== ClipUNetNoSkips Forward Pass End ==========\n")
        return output