# Panoptic Segmentation with CLIP + Mask2Former (LoRA)

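This notebook implements a panoptic-style segmentation pipeline using a CLIP vision backbone (fine-tuned with LoRA) and a lightweight Mask2Former-style decoder. Because the ADE20k SceneParsing annotations are semantic, each class present in an image is treated as one segment (mask classification); instances of the same class are not separated.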

**Note**: The ADE20k dataset (ADEChallengeData2016) is downloaded directly from the MIT Scene Parsing server rather than through the HuggingFace `datasets` hub, for reliability.


In [ ]:
!pip install transformers datasets albumentations peft torchmetrics scipy


In [ ]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import cv2
import glob
from torch.utils.data import Dataset, DataLoader
# from datasets import load_dataset # Eliminated
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import CLIPVisionModel, CLIPConfig
from peft import LoraConfig, get_peft_model
from scipy.optimize import linear_sum_assignment
from tqdm.notebook import tqdm

# Configuration
# Images and masks are resized to IMAGE_SIZE x IMAGE_SIZE.
IMAGE_SIZE = 512
BATCH_SIZE = 4
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
CLIP_MODEL_ID = "openai/clip-vit-base-patch16"
# ADE20k SceneParsing class count (annotation ids 1-150).
NUM_CLASSES = 150


In [ ]:
# --- Dataset Implementation ---
import os
import torch
import numpy as np
import cv2
import glob
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import CLIPProcessor

# --- Configuration Block ---
# Tune these for your Colab runtime.
# HuggingFace name of ADE20k; kept for reference only, since this notebook downloads
# the data from the MIT server and does not reference DATASET_NAME in code.
DATASET_NAME = "scene_parse_150"
CLIP_MODEL_ID = "openai/clip-vit-base-patch16"
# Square side length, batch size and DataLoader worker count.
IMAGE_SIZE, BATCH_SIZE, NUM_WORKERS = 512, 4, 2

class ADE20kPanopticDataset(Dataset):
    """
    ADE20k (ADEChallengeData2016 / MIT SceneParsing) dataset in mask-classification form.

    Each annotation PNG stores one class id per pixel (0 = unlabelled, 1-150 = classes).
    A sample is split into one binary mask per class present in the image plus the
    0-based class id of each mask, which is the (masks, labels) format the
    Mask2Former-style decoder is trained on. The annotations are semantic, so all
    instances of a class share a single mask.

    Args:
        root_dir: Folder containing ``images/`` and ``annotations/``. If it does not
            exist, the official zip is downloaded and extracted into its parent folder.
        split: ``"train"`` reads the ``training`` subfolders; any other value reads
            ``validation``.
        transform: Optional albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with ``"image"`` and
            ``"mask"``. The notebook's collate function expects CHW tensors out of it
            (e.g. ``get_transforms``).
    """
    def __init__(self, root_dir="./ADEChallengeData2016", split="train", transform=None):
        self.root_dir = root_dir
        self.split = "training" if split == "train" else "validation"
        self.transform = transform

        # Download on first use when the extracted dataset is missing.
        if not os.path.exists(self.root_dir):
            self.download_ade20k()

        self.image_dir = os.path.join(self.root_dir, "images", self.split)
        self.mask_dir = os.path.join(self.root_dir, "annotations", self.split)

        # Pair each image with the annotation that has the same file stem. Sorting two
        # independent glob results would silently shift every later pair as soon as
        # one file is missing on either side.
        self.images = []
        self.masks = []
        missing = 0
        for img_path in sorted(glob.glob(os.path.join(self.image_dir, "*.jpg"))):
            stem = os.path.splitext(os.path.basename(img_path))[0]
            mask_path = os.path.join(self.mask_dir, stem + ".png")
            if os.path.exists(mask_path):
                self.images.append(img_path)
                self.masks.append(mask_path)
            else:
                missing += 1
        if missing:
            print(f"Warning: skipped {missing} images without a matching annotation in {self.mask_dir}")

        print(f"Found {len(self.images)} images in {self.image_dir}")

    def download_ade20k(self):
        """Download ADEChallengeData2016.zip and extract it next to ``self.root_dir``.

        The archive is fetched into the parent folder of ``root_dir`` (the current
        directory for the default ``root_dir``) and reused on later runs.

        Raises:
            urllib.error.URLError / OSError: if the download or extraction fails.
            RuntimeError: if extraction does not create ``root_dir`` (its basename must
                be ``ADEChallengeData2016``, the archive's top-level folder).
        """
        import urllib.request
        import zipfile

        print("Downloading ADE20k dataset (this may take a while)...")
        # Direct link to ADE20k
        url = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
        extract_dir = os.path.dirname(os.path.abspath(self.root_dir))
        zip_path = os.path.join(extract_dir, "ADEChallengeData2016.zip")

        if not os.path.exists(zip_path):
            # Download under a temporary name so an interrupted transfer is never
            # mistaken for a complete archive on the next run.
            tmp_path = zip_path + ".part"
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, zip_path)

        print("Unzipping...")
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(extract_dir)

        if not os.path.isdir(self.root_dir):
            raise RuntimeError(
                f"Extracted archive into {extract_dir!r}, but {self.root_dir!r} does not exist; "
                "point root_dir at the extracted 'ADEChallengeData2016' folder."
            )
        print("Download complete.")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with
              - ``"pixel_values"``: the (transformed) image,
              - ``"masks"``: float32 tensor [N, H, W], one binary mask per class present,
              - ``"class_labels"``: int64 tensor [N] of 0-based class ids,
              - ``"original_size"``: (H, W) of the returned image.

        Raises:
            FileNotFoundError: if the image or annotation cannot be read.
        """
        img_path = self.images[idx]
        mask_path = self.masks[idx]

        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Single-channel annotation of per-pixel class ids.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read annotation: {mask_path}")

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']

        # ToTensorV2 may have turned the mask into a tensor.
        if isinstance(mask, torch.Tensor):
            mask = mask.numpy()

        # One binary mask per class present; id 0 (unlabelled) is dropped and the
        # remaining ids 1-150 are shifted to 0-149 for the model.
        unique_ids = np.unique(mask)
        unique_ids = unique_ids[unique_ids != 0]
        masks = [(mask == uid).astype(np.float32) for uid in unique_ids]
        labels = [int(uid) - 1 for uid in unique_ids]

        if masks:
            masks = torch.from_numpy(np.stack(masks))
            labels = torch.tensor(labels, dtype=torch.long)
        else:
            # No labelled pixels: keep the spatial size of the actual mask.
            masks = torch.zeros((0, *mask.shape[-2:]), dtype=torch.float32)
            labels = torch.tensor([], dtype=torch.long)

        # CHW tensor after ToTensorV2, HWC array otherwise.
        if isinstance(image, torch.Tensor):
            image_size = tuple(image.shape[-2:])
        else:
            image_size = tuple(image.shape[:2])

        return {
            "pixel_values": image,
            "masks": masks,
            "class_labels": labels,
            "original_size": image_size,
        }

def get_transforms(image_size=512):
    """Build the preprocessing pipeline shared by both splits.

    Image and mask are resized to ``image_size`` x ``image_size``. The image is then
    normalised with CLIP's per-channel mean/std, and ToTensorV2 converts the outputs
    to torch tensors.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=clip_mean, std=clip_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def collate_fn(batch):
    """Collate dataset samples into one image batch plus per-image targets.

    All images share a size after resizing, so they are stacked. The number of masks
    differs per image, so masks and labels stay in a list of dicts.

    Args:
        batch: list of sample dicts with ``"pixel_values"`` (C, H, W tensor),
            ``"masks"`` and ``"class_labels"``.

    Returns:
        ``(pixel_values, targets)``: ``pixel_values`` is (B, C, H, W) and ``targets``
        is a list of ``{"masks", "class_labels"}`` dicts in batch order.
    """
    pixel_values = torch.stack([sample["pixel_values"] for sample in batch])
    targets = [
        {"masks": sample["masks"], "class_labels": sample["class_labels"]}
        for sample in batch
    ]
    return pixel_values, targets


# Backwards-compatible alias for the original (misspelled) name used by DataLoader calls.
collafe_fn = collate_fn

# --- Usage Example ---
# dataset = ADE20kPanopticDataset(split="train", transform=get_transforms(IMAGE_SIZE))
# dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collafe_fn)


In [ ]:
# --- Model Architecture ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionModel, CLIPConfig
from peft import LoraConfig, get_peft_model

class CLIPBackbone(nn.Module):
    """CLIP ViT image encoder that returns its patch tokens as a 2-D feature map.

    Args:
        model_name: HuggingFace id of the CLIP checkpoint whose vision tower is loaded.
        use_lora: wrap the vision model with peft LoRA adapters on the modules named
            ``q_proj`` and ``v_proj``.
        lora_rank: LoRA rank ``r``; ``lora_alpha`` is set to ``2 * r``.
    """
    def __init__(self, model_name="openai/clip-vit-base-patch16", use_lora=True, lora_rank=16):
        super().__init__()
        # Load standard CLIP Vision Model (only last_hidden_state is used in forward).
        self.base_model = CLIPVisionModel.from_pretrained(model_name)
        # Recorded before peft wrapping so forward() can derive the token grid from the
        # input height/width, including non-square inputs.
        self.patch_size = self.base_model.config.patch_size

        if use_lora:
            print(f"Injecting LoRA adapters with rank={lora_rank}...")
            peft_config = LoraConfig(
                r=lora_rank,
                lora_alpha=lora_rank*2,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.1,
                bias="none",
                modules_to_save=[],
            )
            self.base_model = get_peft_model(self.base_model, peft_config)
            self.base_model.print_trainable_parameters()

    def forward(self, x):
        """Encode ``x`` [B, 3, H, W] into a [B, D, H // patch, W // patch] feature map.

        Raises:
            ValueError: if the number of patch tokens does not match the grid implied
                by the input size.
        """
        # interpolate_pos_encoding lets the pretrained position embeddings adapt to
        # input sizes other than the checkpoint's training resolution.
        outputs = self.base_model(pixel_values=x, interpolate_pos_encoding=True)
        # Drop the leading CLS token; the remaining tokens are the image patches.
        patch_tokens = outputs.last_hidden_state[:, 1:, :]
        B, L, D = patch_tokens.shape
        H = x.shape[-2] // self.patch_size
        W = x.shape[-1] // self.patch_size
        if H * W != L:
            raise ValueError(
                f"Cannot arrange {L} patch tokens into a {H}x{W} grid for input size {tuple(x.shape[-2:])}"
            )

        feature_map = patch_tokens.permute(0, 2, 1).reshape(B, D, H, W)
        return feature_map

class SimpleFPN(nn.Module):
    """
    ViTDet-style simple feature pyramid built from one single-scale feature map.

    From an input of shape [B, in_channels, h, w] it produces four ``hidden_dim``
    channel maps at 4x, 2x, 1x and 0.5x the input resolution. With the stride-16 ViT
    backbone these correspond to strides 4, 8, 16 and 32.
    """
    def __init__(self, in_channels=768, hidden_dim=256):
        super().__init__()

        # 4x upsampling: two stride-2 transposed convolutions.
        self.simfpn0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, hidden_dim, kernel_size=2, stride=2),
            nn.BatchNorm2d(hidden_dim),
            nn.GELU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, kernel_size=2, stride=2),
        )
        # 2x upsampling.
        self.simfpn1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, hidden_dim, kernel_size=2, stride=2),
        )
        # Same resolution; the leading Identity keeps the "simfpn2.1.*" parameter names.
        self.simfpn2 = nn.Sequential(
            nn.Identity(),
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
        )
        # 2x downsampling followed by a 1x1 channel projection.
        self.simfpn3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.simfpn3_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)

    def forward(self, x):
        """Return ``[p2, p3, p4, p5]``, ordered from finest to coarsest."""
        coarse = self.simfpn3_proj(self.simfpn3(x))
        return [
            self.simfpn0(x),   # 4x the input resolution
            self.simfpn1(x),   # 2x
            self.simfpn2(x),   # 1x
            coarse,            # 0.5x
        ]

class LightMask2Former(nn.Module):
    """Minimal query-based mask decoder in the style of Mask2Former.

    ``num_queries`` learned queries cross-attend (3 transformer decoder layers) to the
    finest pyramid level. Each decoded query yields ``num_classes + 1`` class logits
    and a mask embedding. The mask embedding is dotted with a feature map to give
    per-pixel mask logits.

    Args:
        in_channels: not used; pyramid features are assumed to have ``hidden_dim``
            channels already.
        num_queries: number of object queries.
        num_classes: number of real classes (one extra "no-object" logit is added).
        hidden_dim: channel width of queries and features.
    """
    def __init__(self, in_channels=256, num_queries=100, num_classes=150, hidden_dim=256):
        super().__init__()

        self.num_queries = num_queries
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8, dim_feedforward=1024)
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=3)

        self.class_head = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward_mask_prediction(self, mask_embed, feature_map):
        """Dot each query embedding [B, Q, C] with every pixel of [B, C, H, W] -> [B, Q, H, W]."""
        batch, _, height, width = feature_map.shape
        logits = torch.bmm(mask_embed, feature_map.flatten(2))
        return logits.reshape(batch, self.num_queries, height, width)

    def forward(self, features):
        """Decode pyramid ``features`` ([P2, P3, P4, ...], finest first).

        Returns a dict with ``pred_logits`` [B, Q, K+1], ``pred_masks`` on P2, and
        ``aux_outputs`` holding mask predictions on P3 and P4. The aux masks reuse the
        same mask embeddings so the loss can enforce cross-scale consistency.
        """
        finest = features[0]
        batch = finest.shape[0]

        # nn.TransformerDecoder here is sequence-first: [sequence, batch, channels].
        memory = finest.flatten(2).permute(2, 0, 1)                           # [H*W, B, C]
        tgt = self.query_embed.weight.unsqueeze(1).repeat(1, batch, 1)        # [Q, B, C]
        decoded = self.transformer_decoder(tgt=tgt, memory=memory).permute(1, 0, 2)  # [B, Q, C]

        class_logits = self.class_head(decoded)
        mask_embed = self.mask_head(decoded)

        main_masks = self.forward_mask_prediction(mask_embed, finest)
        aux_outputs = [
            {"pred_masks": self.forward_mask_prediction(mask_embed, level)}  # 1/8, then 1/16
            for level in (features[1], features[2])
        ]

        return {
            "pred_logits": class_logits,
            "pred_masks": main_masks,
            "aux_outputs": aux_outputs,
        }

class CLIPPanopticModel(nn.Module):
    """End-to-end model: CLIP ViT backbone -> simple feature pyramid -> query decoder.

    Args:
        num_classes: number of real classes (the decoder adds one "no-object" logit).
        lora_rank: LoRA rank forwarded to the CLIP backbone.
    """
    def __init__(self, num_classes=150, lora_rank=64):
        super().__init__()
        # 768 channels matches the default ViT-B/16 backbone's token width.
        self.backbone = CLIPBackbone(lora_rank=lora_rank)
        self.pixel_decoder = SimpleFPN(in_channels=768, hidden_dim=256)
        self.decoder = LightMask2Former(in_channels=256, hidden_dim=256, num_classes=num_classes)

    def forward(self, x):
        """Run the three stages on ``x`` [B, 3, H, W] and return the decoder's output dict."""
        pyramid = self.pixel_decoder(self.backbone(x))
        return self.decoder(pyramid)


In [ ]:
# --- Evaluation & Visualization ---
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from torchmetrics.detection.mean_ap import MeanAveragePrecision

def _postprocess_prediction(logits, mask_logits, target_hw, score_threshold):
    """Turn one image's raw decoder outputs into a torchmetrics prediction dict.

    Args:
        logits: [Q, K+1] class logits; the last column is the "no-object" class.
        mask_logits: [Q, h, w] mask logits at decoder resolution.
        target_hw: (H, W) of the ground-truth masks; kept masks are resized to it.
        score_threshold: queries whose best real-class probability is not above this
            are dropped.

    Returns:
        dict with boolean ``masks`` [N, H, W], float ``scores`` [N] and int64 ``labels`` [N].
    """
    device = logits.device
    prob = logits.softmax(-1)                       # [Q, K+1]
    scores, labels = prob[:, :-1].max(-1)           # best class, excluding 'no-object'
    keep = scores > score_threshold

    if not keep.any():
        # Empty predictions still need target-sized masks and integer labels.
        return {
            "masks": torch.zeros((0, *target_hw), dtype=torch.bool, device=device),
            "scores": torch.zeros(0, device=device),
            "labels": torch.zeros(0, dtype=torch.long, device=device),
        }

    masks = F.interpolate(
        mask_logits[keep].unsqueeze(1), size=target_hw, mode="bilinear", align_corners=False
    ).squeeze(1)
    return {
        "masks": masks.sigmoid() > 0.5,
        "scores": scores[keep],
        "labels": labels[keep],
    }


@torch.no_grad()
def evaluate_model(model, dataloader, device, score_threshold=0.05):
    """Compute segmentation mAP (torchmetrics ``MeanAveragePrecision``, ``iou_type="segm"``).

    Args:
        model: returns a dict with ``pred_logits`` [B, Q, K+1] and ``pred_masks`` [B, Q, h, w].
        dataloader: yields ``(pixel_values, targets)`` as produced by the notebook's collate fn.
        device: device the model and tensors are evaluated on.
        score_threshold: minimum class probability for a query to count as a
            prediction. It is kept low (0.05) because query-based models are poorly
            calibrated early in training, and a high cut truncates the PR curve.

    Returns:
        The dict returned by ``MeanAveragePrecision.compute()``. The model is left in eval mode.
    """
    model.eval()
    metric = MeanAveragePrecision(iou_type="segm")

    print("Running evaluation...")
    for pixel_values, targets in tqdm(dataloader, desc="Evaluating"):
        pixel_values = pixel_values.to(device)

        # torchmetrics expects boolean masks and a "labels" key.
        formatted_targets = [
            {"masks": t["masks"].to(device) > 0.5, "labels": t["class_labels"].to(device)}
            for t in targets
        ]

        outputs = model(pixel_values)

        preds = [
            _postprocess_prediction(
                outputs["pred_logits"][i],
                outputs["pred_masks"][i],
                tuple(target["masks"].shape[-2:]),
                score_threshold,
            )
            for i, target in enumerate(formatted_targets)
        ]
        metric.update(preds, formatted_targets)

    return metric.compute()

def visualize_prediction(model, dataset, idx, device):
    """Show the input, a ground-truth overlay and a prediction overlay for ``dataset[idx]``.

    Queries whose best non-"no-object" class probability exceeds 0.05 are kept. Their
    masks are bilinearly upsampled to the input size and binarised at sigmoid > 0.5.
    Overlay colours come from ``np.random.rand`` (ground-truth masks first, then
    predictions). The model is left in eval mode.
    """
    model.eval()

    sample = dataset[idx]
    pixels = sample['pixel_values']
    batch = pixels.unsqueeze(0).to(device)  # [1, 3, H, W]

    with torch.no_grad():
        outputs = model(batch)

    class_probs = outputs['pred_logits'][0].softmax(-1)
    best_scores, _ = class_probs[:, :-1].max(-1)
    selected = outputs['pred_masks'][0][best_scores > 0.05]

    if len(selected) > 0:
        upsampled = F.interpolate(
            selected.unsqueeze(1), size=tuple(batch.shape[-2:]), mode="bilinear", align_corners=False
        )
        selected = upsampled.squeeze(1).sigmoid() > 0.5

    # Undo the CLIP normalisation so the image displays with natural colours.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)
    shown = np.clip((pixels.cpu() * std + mean).permute(1, 2, 0).numpy(), 0, 1)

    def _paint(masks, to_numpy):
        # One random colour per mask; later masks overwrite earlier ones.
        canvas = np.zeros_like(shown)
        for mask in masks:
            colour = np.random.rand(3)
            canvas[to_numpy(mask) > 0.5] = colour
        return canvas

    gt_overlay = _paint(sample['masks'], lambda m: m.numpy())
    pred_overlay = _paint(selected, lambda m: m.cpu().numpy())

    panels = [
        ("Input Image", None),
        ("Ground Truth", gt_overlay),
        (f"Prediction ({len(selected)} objects)", pred_overlay),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, overlay) in zip(axes, panels):
        ax.imshow(shown)
        if overlay is not None:
            ax.imshow(overlay, alpha=0.5)
        ax.set_title(title)
        ax.axis('off')

    plt.show()


In [ ]:
# --- Loss & Matcher ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torchvision.ops import sigmoid_focal_loss
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform
from transformers import CLIPTokenizer, CLIPModel

# ADE20k Class Names (150 classes)
# Position i holds the name of model label i, i.e. annotation id i + 1 (the dataset
# subtracts 1 from the PNG ids). The order is intended to follow the official
# SceneParsing objectInfo150 list — NOTE(review): verify against objectInfo150.csv
# if this list is ever edited. SetCriterion embeds these names with CLIP to build its
# class hierarchy.
ADE20K_CLASSES = [
    "wall", "building", "sky", "floor", "tree", "ceiling", "road", "bed", "windowpane", "grass", "cabinet",
    "sidewalk", "person", "earth", "door", "table", "mountain", "plant", "curtain", "chair", "car",
    "water", "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat",
    "fence", "desk", "rock", "wardrobe", "lamp", "bathtub", "railing", "cushion", "base", "box", "column",
    "signboard", "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator",
    "grandstand", "path", "stairs", "runway", "case", "pool table", "pillow", "screen door", "stairway",
    "river", "bridge", "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill", "bench",
    "countertop", "stove", "palm", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine",
    "hovel", "bus", "towel", "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth",
    "television receiver", "airplane", "dirt track", "apparel", "pole", "land", "bannister", "escalator",
    "ottoman", "bottle", "buffet", "poster", "stage", "van", "ship", "fountain", "conveyer belt", "canopy",
    "washer", "plaything", "swimming pool", "stool", "barrel", "basket", "waterfall", "tent", "bag", "minibike",
    "cradle", "oven", "ball", "food", "step", "tank", "trade name", "microwave", "pot", "animal", "bicycle",
    "lake", "dishwasher", "screen", "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray",
    "ashcan", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator",
    "glass", "clock", "flag"
]

class BoundaryAwareLoss(nn.Module):
    """L1 distance between Laplacian edge maps of predicted and target masks.

    Both inputs are convolved with a ``kernel_size`` x ``kernel_size`` kernel whose
    entries are all -1 except ``kernel_size**2 - 1`` at the centre, so constant regions
    map to 0 and mask boundaries respond strongly. The loss is the mean absolute
    difference of the two edge maps.

    Args:
        kernel_size: odd kernel width. The default 3 gives the 8-neighbour Laplacian
            [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]].

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, kernel_size=3):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        kernel = -torch.ones(kernel_size, kernel_size, dtype=torch.float32)
        centre = kernel_size // 2
        kernel[centre, centre] = kernel_size * kernel_size - 1
        # Non-persistent buffer: moves with .to(device) but is not saved in state_dict.
        self.register_buffer("kernel", kernel[None, None], persistent=False)
        self.padding = centre

    def forward(self, pred_masks, target_masks):
        """
        Args:
            pred_masks: [N, H, W] predicted mask probabilities.
            target_masks: [N, H, W] target masks.

        Returns:
            Scalar loss. For N == 0 it is a zero that stays connected to the graph.
        """
        if pred_masks.numel() == 0:
            return pred_masks.sum() * 0
        # Convolve in the prediction's dtype so fp16 predictions also work outside
        # autocast (a float32 kernel or target would otherwise mismatch).
        kernel = self.kernel.to(device=pred_masks.device, dtype=pred_masks.dtype)
        pred_edges = F.conv2d(pred_masks.unsqueeze(1), kernel, padding=self.padding)
        target_edges = F.conv2d(
            target_masks.unsqueeze(1).to(pred_masks.dtype), kernel, padding=self.padding
        )
        return F.l1_loss(pred_edges, target_edges)

class HungarianMatcher(nn.Module):
    """One-to-one assignment between predicted queries and ground-truth masks, per image.

    Pairing query q with target t costs

        cost_class * (-p_q[class_t])
      + cost_mask  * mean-over-pixels BCE-with-logits(mask_logits_q, mask_t)
      + cost_dice  * (1 - dice(sigmoid(mask_logits_q), mask_t))

    and ``scipy.optimize.linear_sum_assignment`` picks the minimum-cost matching. The
    mask term uses the same BCE-with-logits as SetCriterion's ``loss_mask``, averaged
    over pixels so it stays on the scale of the other terms. Targets are resized
    (nearest) to the prediction resolution before costing.

    Args:
        cost_class, cost_mask, cost_dice: weights of the three cost terms.
    """
    def __init__(self, cost_class=1, cost_mask=1, cost_dice=1):
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice

    @staticmethod
    def _sigmoid_ce_cost(mask_logits, tgt_masks):
        """Pairwise mean BCE: [Nq, P] logits vs [Nt, P] binary targets -> [Nq, Nt]."""
        num_pixels = mask_logits.shape[1]
        pos = F.binary_cross_entropy_with_logits(mask_logits, torch.ones_like(mask_logits), reduction="none")
        neg = F.binary_cross_entropy_with_logits(mask_logits, torch.zeros_like(mask_logits), reduction="none")
        return (pos @ tgt_masks.t() + neg @ (1 - tgt_masks).t()) / num_pixels

    @staticmethod
    def _dice_cost(mask_logits, tgt_masks):
        """Pairwise soft-Dice cost: [Nq, P] logits vs [Nt, P] binary targets -> [Nq, Nt]."""
        probs = mask_logits.sigmoid()
        numerator = 2 * (probs @ tgt_masks.t())
        denominator = probs.sum(-1)[:, None] + tgt_masks.sum(-1)[None, :]
        return 1 - numerator / (denominator + 1e-6)

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs: dict with ``pred_logits`` [B, Q, K+1] and ``pred_masks`` [B, Q, h, w].
            targets: list of B dicts with ``class_labels`` [N_i] and ``masks`` [N_i, H, W].

        Returns:
            List of B ``(query_indices, target_indices)`` int64 tensor pairs.
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        sizes = [len(t["class_labels"]) for t in targets]
        if sum(sizes) == 0:
            return [
                (torch.empty(0, dtype=torch.int64), torch.empty(0, dtype=torch.int64))
                for _ in targets
            ]

        # Compute all costs in float32 even when called under fp16 autocast: the
        # pixel sums in the mask and dice terms lose precision or overflow in fp16.
        with torch.autocast(device_type=outputs["pred_logits"].device.type, enabled=False):
            out_prob = outputs["pred_logits"].flatten(0, 1).float().softmax(-1)  # [B*Q, K+1]
            out_mask = outputs["pred_masks"].flatten(0, 1).float()               # [B*Q, h, w]
            H_p, W_p = out_mask.shape[-2:]
            out_mask = out_mask.flatten(1)

            tgt_ids = torch.cat([v["class_labels"] for v in targets])
            tgt_mask = torch.cat([v["masks"] for v in targets]).float()
            tgt_mask = F.interpolate(tgt_mask.unsqueeze(1), size=(H_p, W_p), mode='nearest').squeeze(1)
            tgt_mask = tgt_mask.flatten(1)

            cost_class = -out_prob[:, tgt_ids]
            cost_mask = self._sigmoid_ce_cost(out_mask, tgt_mask)
            cost_dice = self._dice_cost(out_mask, tgt_mask)
            C = self.cost_class * cost_class + self.cost_mask * cost_mask + self.cost_dice * cost_dice

        C = C.view(bs, num_queries, -1).cpu()

        indices = []
        for i, c in enumerate(C.split(sizes, -1)):
            if c.shape[-1] == 0:
                indices.append((torch.empty(0, dtype=torch.int64), torch.empty(0, dtype=torch.int64)))
                continue
            # c is [B, Q, N_i]; only image i's own queries may be matched to its targets.
            row_ind, col_ind = linear_sum_assignment(c[i].numpy())
            indices.append((torch.as_tensor(row_ind, dtype=torch.int64), torch.as_tensor(col_ind, dtype=torch.int64)))
        return indices

class SetCriterion(nn.Module):
    """Matching-based training loss with a CLIP-derived class hierarchy.

    ``matcher`` first assigns queries to ground-truth segments. The loss then combines:

    * ``loss_ce``: sigmoid focal loss over all queries and the ``num_classes`` real
      classes (the trailing "no-object" logit is not used), divided by the number of
      ground-truth segments (at least 1);
    * ``loss_kl``: KL divergence between matched queries' softmax and a
      hierarchy-aware smoothed target (see ``create_soft_target``);
    * ``loss_parent``: cross-entropy on parent-cluster logits of matched queries;
    * ``loss_mask`` / ``loss_dice`` / ``loss_boundary``: BCE, Dice and Laplacian-edge
      losses between matched mask logits (upsampled to target size) and target masks;
    * ``loss_consistency``: L1 between downsampled main masks and each auxiliary-scale
      mask prediction.

    The hierarchy clusters CLIP text embeddings of the class names into
    ``num_parents`` groups. It is computed on the first ``loss_labels`` call because
    building it downloads a CLIP model.

    Args:
        num_classes: number of real classes K.
        matcher: callable ``(outputs, targets) -> list of (query_idx, target_idx)``.
        weight_dict: per-term weights; missing keys fall back to defaults in ``forward``.
        num_parents: number of hierarchy clusters.
        label_smoothing: probability mass moved off the true class in the soft target.
        device: stored as ``device_name``; not otherwise read by this class.
    """
    def __init__(self, num_classes, matcher, weight_dict, num_parents=30, label_smoothing=0.1, device='cuda'):
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.boundary_loss_func = BoundaryAwareLoss()

        # Hierarchical config
        self.num_parents = num_parents
        self.label_smoothing = label_smoothing
        self.hierarchy = None            # [num_classes] long tensor of parent ids, set lazily
        self.device_name = device        # stored for reference only

        # Computed on the first loss_labels() call (needs a CLIP download).
        self._hierarchy_computed = False

    def _compute_hierarchy(self, device):
        """Cluster CLIP text embeddings of the class names into ``num_parents`` groups.

        Sets ``self.hierarchy`` to a [num_classes] long tensor of 0-based parent ids on
        ``device``.
        """
        print("Computing Class Hierarchy from CLIP Embeddings...")
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

        class_names = ADE20K_CLASSES[:self.num_classes]
        inputs = tokenizer(class_names, padding=True, return_tensors="pt").to(device)

        with torch.no_grad():
            embeddings = model.get_text_features(**inputs)

        # float() guards against fp16 features when called under autocast.
        embeddings = embeddings.float().cpu().numpy()
        # Cosine distance between L2-normalised embeddings
        norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / (norm + 1e-8)
        distances = 1 - (embeddings @ embeddings.T)
        distances = (distances + distances.T) / 2  # enforce symmetry for squareform
        np.fill_diagonal(distances, 0)             # squareform requires a zero diagonal
        distances = distances.clip(min=0)          # clip negative precision errors

        # Clustering. NOTE(review): Ward linkage formally assumes Euclidean distances;
        # here it is applied to cosine distances.
        condensed = squareform(distances)
        linkage_matrix = linkage(condensed, method='ward')
        cluster_labels = fcluster(linkage_matrix, self.num_parents, criterion='maxclust')

        # fcluster labels start at 1.
        self.hierarchy = torch.tensor(cluster_labels - 1, dtype=torch.long, device=device)
        self._hierarchy_computed = True
        print("Hierarchy Computed.")

    def create_soft_target(self, target_classes_o, device):
        """Build hierarchy-aware soft targets [N_matches, num_classes] for matched queries.

        Each row starts at ``label_smoothing / (K - 1)`` everywhere and
        ``1 - label_smoothing`` on the true class. If the class shares its parent with
        other classes, every class in that parent (including itself) gets an extra
        ``label_smoothing / n_siblings``. Rows are then renormalised to sum to 1.
        """
        N = target_classes_o.shape[0]
        K = self.num_classes
        ls = self.label_smoothing

        soft_target = torch.full((N, K), ls / max(K - 1, 1), device=device)
        soft_target[torch.arange(N, device=device), target_classes_o] = 1.0 - ls

        hierarchy = self.hierarchy.to(device)
        # [N, K]: True where class k shares a parent with row n's target (incl. itself).
        siblings = hierarchy.unsqueeze(0) == hierarchy[target_classes_o].unsqueeze(1)
        counts = siblings.sum(dim=1, keepdim=True).to(soft_target.dtype)
        # Only boost when the parent has more than one child.
        boost = (ls / counts) * (counts > 1)
        soft_target = soft_target + siblings * boost

        # Renormalize
        return soft_target / (soft_target.sum(dim=1, keepdim=True) + 1e-8)

    def loss_labels(self, outputs, targets, indices, num_boxes):
        """Focal, KL (soft hierarchy target) and parent cross-entropy classification losses."""
        src_logits = outputs['pred_logits'][..., :-1]   # [B, Q, K]; no-object column dropped

        idx = self._get_src_permutation_idx(indices)    # (batch_idx, query_idx) of matches
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])

        if not self._hierarchy_computed:
            self._compute_hierarchy(src_logits.device)

        # 1. Sigmoid focal loss over every query: matched queries target their class,
        #    unmatched queries target all zeros.
        target_classes_onehot = torch.zeros_like(src_logits)
        target_classes_onehot[idx[0], idx[1], target_classes_o] = 1.0
        loss_focal = sigmoid_focal_loss(src_logits, target_classes_onehot, alpha=0.25, gamma=2.0, reduction="sum")
        loss_focal = loss_focal / num_boxes

        # 2. Hierarchical terms, on matched queries only.
        matched_logits = src_logits[idx]  # [N_matches, K]

        if len(matched_logits) > 0:
            soft_targets = self.create_soft_target(target_classes_o, src_logits.device)

            # KL loss (fine-grained)
            log_probs = F.log_softmax(matched_logits, dim=1)
            loss_kl = F.kl_div(log_probs, soft_targets, reduction='batchmean')

            # Parent loss (coarse): parent logit = logsumexp of its children's logits.
            # Parents that received no classes from the clustering stay at -inf so
            # they take no softmax mass (a 0 logit would give them a share).
            hierarchy = self.hierarchy.to(src_logits.device)
            parent_logits = torch.full(
                (len(matched_logits), self.num_parents), float('-inf'), device=src_logits.device
            )
            for pid in range(self.num_parents):
                child_mask = hierarchy == pid
                if child_mask.any():
                    parent_logits[:, pid] = torch.logsumexp(matched_logits[:, child_mask], dim=1)

            target_parents = hierarchy[target_classes_o]
            loss_parent = F.cross_entropy(parent_logits, target_parents)
        else:
            loss_kl = torch.tensor(0.0, device=src_logits.device)
            loss_parent = torch.tensor(0.0, device=src_logits.device)

        return {'loss_ce': loss_focal, 'loss_kl': loss_kl, 'loss_parent': loss_parent}

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """BCE, Dice and boundary losses between matched mask logits and target masks."""
        src_idx = self._get_src_permutation_idx(indices)
        src_masks = outputs['pred_masks'][src_idx]
        target_masks = torch.cat([t['masks'][J] for t, (_, J) in zip(targets, indices)])

        if src_masks.shape[0] == 0:
            # No matched queries in this batch: mean-reduced losses over an empty set
            # are NaN, so return zeros that keep the graph connected.
            zero = outputs['pred_masks'].sum() * 0
            return {'loss_mask': zero, 'loss_dice': zero, 'loss_boundary': zero}

        src_masks = F.interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                  mode="bilinear", align_corners=False).squeeze(1)

        # BCE-with-logits is autocast-safe
        loss_sigmoid = F.binary_cross_entropy_with_logits(src_masks, target_masks)

        src_masks_sigmoid = src_masks.sigmoid()

        src_masks_flat = src_masks_sigmoid.flatten(1)
        target_masks_flat = target_masks.flatten(1)
        numerator = 2 * (src_masks_flat * target_masks_flat).sum(1)
        denominator = src_masks_flat.sum(1) + target_masks_flat.sum(1)
        loss_dice = 1 - (numerator + 1) / (denominator + 1)
        loss_dice = loss_dice.mean()

        loss_boundary = self.boundary_loss_func(src_masks_sigmoid, target_masks)
        return {'loss_mask': loss_sigmoid, 'loss_dice': loss_dice, 'loss_boundary': loss_boundary}

    def loss_consistency(self, outputs, targets, indices, num_boxes):
        """L1 between sigmoid of the downsampled main masks and of each auxiliary scale."""
        if "aux_outputs" not in outputs:
            return {'loss_consistency': torch.tensor(0.0).to(outputs['pred_logits'].device)}
        src_masks_high = outputs['pred_masks']
        loss = 0.0
        for aux in outputs["aux_outputs"]:
            src_masks_low = aux["pred_masks"]
            target_size = src_masks_low.shape[-2:]
            src_masks_high_down = F.interpolate(src_masks_high, size=target_size, mode='bilinear', align_corners=False)
            loss += F.l1_loss(src_masks_high_down.sigmoid(), src_masks_low.sigmoid())
        return {'loss_consistency': loss}

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image matches into (batch_idx, query_idx) index tensors."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def forward(self, outputs, targets):
        """Return ``(weighted_total_loss, dict_of_unweighted_terms)``."""
        indices = self.matcher(outputs, targets)
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        # Clamp to 1 so a batch without any annotated segment does not divide by zero.
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        ).clamp(min=1)

        losses = {}
        losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
        losses.update(self.loss_masks(outputs, targets, indices, num_boxes))
        losses.update(self.loss_consistency(outputs, targets, indices, num_boxes))

        # Weights from weight_dict, with defaults for missing keys
        w_focal = self.weight_dict.get('loss_ce', 0.25)
        w_parent = self.weight_dict.get('loss_parent', 0.20)
        w_kl = self.weight_dict.get('loss_kl', 0.40)

        w_mask = self.weight_dict.get('loss_mask', 5.0)
        w_dice = self.weight_dict.get('loss_dice', 5.0)
        w_boundary = self.weight_dict.get('loss_boundary', 2.0)
        w_consistency = self.weight_dict.get('loss_consistency', 1.0)

        final_loss = (losses['loss_ce'] * w_focal +
                      losses['loss_parent'] * w_parent +
                      losses['loss_kl'] * w_kl +
                      losses['loss_mask'] * w_mask +
                      losses['loss_dice'] * w_dice +
                      losses['loss_boundary'] * w_boundary +
                      losses['loss_consistency'] * w_consistency)

        return final_loss, losses


In [ ]:
# --- Training Loop Execution ---
# Configuration (defined before anything below uses it)
IMAGE_SIZE = 512
BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CLIP_MODEL_ID = "openai/clip-vit-base-patch16"
NUM_CLASSES = 150
LORA_RANK = 16        # Adjustable LoRA Rank
NUM_WORKERS = 2       # DataLoader worker processes
EPOCHS = 5
VIS_EVERY = 1         # visualize a random validation sample every N epochs
# fp16 autocast and GradScaler only apply on CUDA; on CPU train in full precision.
USE_AMP = DEVICE == "cuda"

print(f"Using device: {DEVICE}")

# 1. Data
# This will trigger download if not found
train_ds = ADE20kPanopticDataset(split="train", transform=get_transforms(IMAGE_SIZE))
val_ds = ADE20kPanopticDataset(split="validation", transform=get_transforms(IMAGE_SIZE))

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collafe_fn, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collafe_fn, num_workers=NUM_WORKERS)

# 2. Model
model = CLIPPanopticModel(num_classes=NUM_CLASSES, lora_rank=LORA_RANK)
model.to(DEVICE)
# Keep a handle on the uncompiled module: torch.compile returns a wrapper whose
# state_dict keys carry an "_orig_mod." prefix, which would not load into a plain
# CLIPPanopticModel.
raw_model = model

# Optional PyTorch 2.x graph compilation for faster training steps.
if hasattr(torch, "compile"):
    print("Compiling model with torch.compile...")
    model = torch.compile(model)
torch.backends.cudnn.benchmark = True

# 3. Loss
matcher = HungarianMatcher()
# Weights updated to prioritize Hierarchical Loss (Self-Supervised + Semantic)
weight_dict = {
    'loss_ce': 0.1,          # Reduced focal loss (let hierarchy drive)
    'loss_parent': 1.0,      # Coarse semantic grouping
    'loss_kl': 2.0,          # Soft target matching
    'loss_mask': 5.0,        # Shape
    'loss_dice': 5.0,        # Overlap
    'loss_boundary': 2.0,    # Edges
    'loss_consistency': 1.0  # Multi-scale
}
# Note: the criterion downloads a CLIP text model to build its class hierarchy on its
# first forward pass (not at construction).
criterion = SetCriterion(num_classes=NUM_CLASSES, matcher=matcher, weight_dict=weight_dict, device=DEVICE).to(DEVICE)

# 4. Optimizer
# Name-based grouping; "decoder" also matches the pixel_decoder (FPN) parameters.
param_dicts = [
    {"params": [p for n, p in raw_model.named_parameters() if "backbone" in n and p.requires_grad], "lr": 1e-5},
    {"params": [p for n, p in raw_model.named_parameters() if "decoder" in n and p.requires_grad], "lr": 1e-4},
]
optimizer = optim.AdamW(param_dicts, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)  # Mixed Precision Scaler

# 5. Loop
for epoch in range(EPOCHS):
    print(f"--- Epoch {epoch+1}/{EPOCHS} ---")

    # Train One Epoch Inline
    model.train()
    total_loss = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for pixel_values, targets in pbar:
        pixel_values = pixel_values.to(DEVICE)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        # Mixed Precision Training (no-op when USE_AMP is False)
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=USE_AMP):
            outputs = model(pixel_values)
            loss, loss_dict = criterion(outputs, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        total_loss += loss_value
        pbar.set_postfix({"loss": loss_value})

    print(f"Average Loss: {total_loss / max(len(train_loader), 1):.4f}")

    # Validation & Visualization Step
    if (epoch + 1) % VIS_EVERY == 0:
        # Visualize Prediction on a random sample
        print("Visualizing random sample...")
        rand_idx = np.random.randint(0, len(val_ds))
        visualize_prediction(model, val_ds, rand_idx, DEVICE)

# Save the uncompiled module so the checkpoint loads into CLIPPanopticModel directly.
torch.save(raw_model.state_dict(), "clip_panoptic_lora.pth")
print("Model saved!")
