# In [None]:

import time

import os
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision


# Core configuration: label space, model input resolution and execution device.
NUM_CLASSES = 8    # segmentation-head output channels; ids 0..7, 0 = no-data
IGNORE_INDEX = 0   # ground-truth id excluded from every metric
IMG_SIZE = 512     # square side length images are resized to before inference
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Mixed precision (autocast) is only switched on when running on the GPU.
AMP_ENABLED = DEVICE == "cuda"
from matplotlib.patches import Patch

CLASS_INFO = {
    1: ("Background",    (128,   0,   0)),
    2: ("Building",       (  0, 128,   0)),
    3: ("Road",          (128, 128,   0)),
    4: ("Water",          (  0,   0, 128)),
    5: ("Barren",        (128,   0, 128)),
    6: ("Forest",       (  0, 128, 128)),
    7: ("Agriculture",    (128, 128, 128)),
}


COLOR_MAP = np.array([
    [0, 0, 0],         # 0: ignored
    [128, 0, 0],        # 1: background (incl. playground)
    [0, 128, 0],       # 2: building
    [128, 128, 0],    # 3: road
    [0, 0, 128],       # 4: water
    [128, 0, 128],    # 5: barren
    [0, 128, 128],      # 6: forest
    [128, 128, 128],   # 7: agriculture
], dtype=np.uint8)


CLASS_NAMES = [
    "No-data (ignored)",        # 0
    "Background",                # 1 (includes playground)
    "Building",                  # 2
    "Road",                      # 3
    "Water",                      # 4
    "Barren",                   # 5
    "Forest",                     # 6
    "Agriculture",              # 7
]


class ConvBNReLU(nn.Module):
    """Conv2d (bias-free) -> BatchNorm2d -> ReLU -> optional Dropout2d.

    Args:
        in_ch, out_ch: input / output channels.
        k, s, p: kernel size, stride and padding of the convolution.
        groups: convolution groups (``groups=in_ch`` gives a depthwise conv).
        use_dropout: append ``Dropout2d(p_drop)``; otherwise an identity.
        p_drop: channel-dropout probability when ``use_dropout`` is set.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, groups=1, use_dropout=False, p_drop=0.1):
        super().__init__()
        # Attribute names (conv/bn/relu/drop) are state_dict keys: keep them.
        self.conv = nn.Conv2d(in_ch, out_ch, k, s, p, bias=False, groups=groups)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        if use_dropout:
            self.drop = nn.Dropout2d(p_drop)
        else:
            self.drop = nn.Identity()
        # He initialisation matched to the ReLU that follows the convolution.
        nn.init.kaiming_normal_(self.conv.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        return self.drop(self.relu(self.bn(self.conv(x))))

class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 ConvBNReLU followed by a pointwise 1x1 ConvBNReLU.

    Spatial size is preserved (stride 1, padding 1 / 0). Both stages share the
    same dropout settings.
    """

    def __init__(self, in_ch, out_ch, use_dropout=False, p_drop=0.1):
        super().__init__()
        drop_kwargs = dict(use_dropout=use_dropout, p_drop=p_drop)
        # groups=in_ch: each 3x3 filter sees exactly one input channel.
        self.depth = ConvBNReLU(in_ch, in_ch, k=3, s=1, p=1, groups=in_ch, **drop_kwargs)
        # 1x1 conv mixes channels and sets the output width.
        self.point = ConvBNReLU(in_ch, out_ch, k=1, s=1, p=0, **drop_kwargs)

    def forward(self, x):
        return self.point(self.depth(x))

class SelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention over the spatial positions of a feature map.

    The (B, C, H, W) input is flattened to H*W tokens of width C,
    layer-normalised, attended with ``num_heads`` heads, projected, dropped out
    and added back to the *un-normalised* tokens (residual). The output has the
    input's shape. ``dim`` must equal C and be divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        # 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = height * width
        head_dim = channels // self.num_heads

        def split_heads(t):
            # (B, N, C) -> (B, heads, N, head_dim)
            return t.view(batch, tokens, self.num_heads, head_dim).transpose(1, 2)

        seq = x.permute(0, 2, 3, 1).reshape(batch, tokens, channels)
        q, k, v = map(split_heads, self.qkv(self.norm(seq)).chunk(3, dim=-1))

        weights = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)

        seq = self.drop(self.proj(mixed)) + seq
        return seq.reshape(batch, height, width, channels).permute(0, 3, 1, 2)

class GlobalContextBlock(nn.Module):
    """Global-context residual block.

    A 1x1 conv scores every spatial position; the scores are softmax-normalised
    over all H*W positions and used to pool the input into one C-dim context
    vector per sample. That vector passes through a bottleneck MLP
    (C -> C/reduction -> C, as 1x1 convs) and is broadcast-added to every
    position of the input.
    """

    def __init__(self, in_channels, reduction=4):
        super().__init__()
        hidden = in_channels // reduction
        self.attention = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.transform = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        batch, _, height, width = x.shape
        scores = self.attention(x).view(batch, 1, -1)
        weights = torch.softmax(scores, dim=-1).view(batch, 1, height, width)
        pooled = (x * weights).sum(dim=[2, 3], keepdim=True)  # (B, C, 1, 1)
        return x + self.transform(pooled)

class ConvNeXtTinyEncoder(nn.Module):
    """Truncated torchvision ConvNeXt-Tiny used as a four-output feature extractor.

    Only ``backbone.features[0:5]`` are kept. In torchvision's ConvNeXt-Tiny
    these are: [0] stem (4x4/4 conv, 96 ch), [1] stage-1 blocks (96 ch),
    [2] downsample to 192 ch, [3] stage-2 blocks (192 ch) and [4] downsample
    to 384 ch. The deeper stages (``features[5:]``) are discarded, so ``e4`` is
    the output of a single downsampling layer rather than of a full stage.

    Attribute names (stem, stage1..stage4) are state_dict keys of saved
    checkpoints and must not be renamed.

    Args:
        pretrained: initialise from torchvision's default ImageNet weights.
    """

    def __init__(self, pretrained=False):
        super().__init__()
        try:
            backbone = torchvision.models.convnext_tiny(weights="DEFAULT" if pretrained else None)
        except TypeError:
            # Older torchvision releases only accept the ``pretrained`` flag.
            backbone = torchvision.models.convnext_tiny(pretrained=pretrained)
        self.stem = backbone.features[0]    # 1/4,  96C  (patchify stem)
        self.stage1 = backbone.features[1]  # 1/4,  96C  (stage-1 blocks)
        self.stage2 = backbone.features[2]  # 1/8,  192C (downsample layer)
        self.stage3 = backbone.features[3]  # 1/8,  192C (stage-2 blocks)
        self.stage4 = backbone.features[4]  # 1/16, 384C (downsample layer)

    def forward(self, x):
        x = self.stem(x)
        e1 = self.stage1(x)      # [B, 96,  H/4,  W/4]
        e2 = self.stage2(e1)     # [B, 192, H/8,  W/8]
        e3 = self.stage3(e2)     # [B, 192, H/8,  W/8]
        e4 = self.stage4(e3)     # [B, 384, H/16, W/16]
        return e1, e2, e3, e4

class Decoder(nn.Module):
    """Fuse the four encoder feature maps back up to e1's resolution.

    Spatial contract (enforced by the concatenations): e2 and e3 share a
    resolution exactly 2x that of e4, and e1 is 2x that again. With
    ConvNeXtTinyEncoder and the default channels the inputs are
    e1: 96 ch @ 1/4, e2: 192 ch @ 1/8, e3: 192 ch @ 1/8, e4: 384 ch @ 1/16.

    forward() returns ``(features, aux_16, aux_8)``:
      * features: ``decoder_channels[1]`` channels at e1's resolution;
      * aux_16 / aux_8: ``num_classes``-channel auxiliary logits taken after
        the first and second decoding stages, or ``None`` when ``use_aux`` is
        False. The names are kept for checkpoint compatibility; with
        ConvNeXtTinyEncoder these heads actually run at 1/8 and 1/4.

    Args:
        encoder_channels: channels of (e1, e2, e3, e4). c4 must be divisible by
            the 8 attention heads.
        decoder_channels: output channels (d1, d2) of the two decoding stages.
        num_classes: output channels of the auxiliary heads. Defaults to the
            module-level NUM_CLASSES, which was previously hard-coded.
    """

    def __init__(self, encoder_channels=(96, 192, 192, 384),
                 decoder_channels=(256, 96), num_classes=NUM_CLASSES):
        super().__init__()
        c1, c2, c3, c4 = encoder_channels
        d1, d2 = decoder_channels

        # Feature switches; the modules are built regardless so checkpoints load.
        self.use_gc = True
        self.use_mhsa = True
        self.use_aux = True

        self.gc_e4 = GlobalContextBlock(c4)
        self.gc_e3 = GlobalContextBlock(c3)
        self.att_block = SelfAttentionBlock(c4, num_heads=8, dropout=0.1)
        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.dec3 = DepthwiseSeparableConv(c3 + c3 + c2, d1, use_dropout=True)  # 192+192+192=576 -> 256 by default
        self.up2 = nn.ConvTranspose2d(d1, 128, kernel_size=2, stride=2)
        self.dec2 = DepthwiseSeparableConv(128 + c1, d2, use_dropout=True)  # 128+96=224 -> 96 by default
        self.out_ch = d2
        self.aux_head_16 = nn.Conv2d(d1, num_classes, kernel_size=1)
        self.aux_head_8 = nn.Conv2d(d2, num_classes, kernel_size=1)

    def forward(self, e1, e2, e3, e4):
        if self.use_gc:
            e4 = self.gc_e4(e4)
            e3 = self.gc_e3(e3)
        if self.use_mhsa:
            e4 = self.att_block(e4)

        x = self.up3(e4)                    # c4 -> c3 channels, 2x upsample (to e3/e2 size)
        x = torch.cat([x, e3, e2], dim=1)   # c3 + c3 + c2 channels (576 by default)
        x = self.dec3(x)                    # -> d1 channels (256 by default)
        aux_16 = self.aux_head_16(x) if self.use_aux else None

        x = self.up2(x)                     # d1 -> 128 channels, 2x upsample (to e1 size)
        x = torch.cat([x, e1], dim=1)       # 128 + c1 channels (224 by default)
        x = self.dec2(x)                    # -> d2 channels (96 by default)
        aux_8 = self.aux_head_8(x) if self.use_aux else None

        return x, aux_16, aux_8

class DetailBranch(nn.Module):
    """Shallow full-image branch producing ``out_ch`` features at 1/4 resolution.

    Two stride-2 ConvBNReLU layers (in_ch -> 32 -> 64) followed by a
    DepthwiseSeparableConv (64 -> out_ch, with dropout).
    """

    def __init__(self, in_ch=3, out_ch=96):
        super().__init__()
        downsampling = [
            ConvBNReLU(in_ch, 32, k=3, s=2, p=1),  # -> 1/2
            ConvBNReLU(32, 64, k=3, s=2, p=1),     # -> 1/4
        ]
        self.down = nn.Sequential(*downsampling)
        self.block = DepthwiseSeparableConv(64, out_ch, use_dropout=True)

    def forward(self, x):
        return self.block(self.down(x))

class SegmentationModel(nn.Module):
    """ConvNeXt-Tiny encoder + attention decoder + shallow detail branch.

    ``forward(x)`` takes a (B, 3, H, W) tensor and returns a dict:
      * "logits": (B, num_classes, H, W) segmentation logits;
      * "aux_16", "aux_8": decoder auxiliary logits, or None;
      * "boundary_logits": (B, 1, H, W) boundary logits, or None.
    All outputs are bilinearly upsampled from 1/4 scale to the input size.

    Args:
        num_classes: output channels of the main segmentation head. The
            decoder's auxiliary heads are sized by the module-level
            NUM_CLASSES, so keep the two equal.
        pretrained_backbone: initialise the encoder from torchvision's default
            ConvNeXt-Tiny weights (downloaded on first use). Defaults to True,
            as before; pass False when a full checkpoint is loaded afterwards.
    """

    def __init__(self, num_classes=NUM_CLASSES, pretrained_backbone=True):
        super().__init__()
        self.encoder = ConvNeXtTinyEncoder(pretrained=pretrained_backbone)
        # Channels of (e1, e2, e3, e4) as produced by ConvNeXtTinyEncoder.
        self.decoder = Decoder(
            encoder_channels=(96, 192, 192, 384),
            decoder_channels=(256, 96)
        )
        self.use_detail = True
        self.use_boundary = True
        self.detail_branch = DetailBranch(in_ch=3, out_ch=96)
        self.fuse = DepthwiseSeparableConv(self.decoder.out_ch + 96, 96, use_dropout=True)
        self.seg_head = nn.Conv2d(96, num_classes, kernel_size=1)
        self.boundary_head = nn.Conv2d(96, 1, kernel_size=1)

    def forward(self, x):
        e1, e2, e3, e4 = self.encoder(x)
        dec_1_4, aux_16, aux_8 = self.decoder(e1, e2, e3, e4)

        if self.use_detail:
            # Detail branch output is at 1/4 scale, same as the decoder output.
            detail = self.detail_branch(x)
            fused = torch.cat([dec_1_4, detail], dim=1)
        else:
            fused = dec_1_4

        fused = self.fuse(fused)
        seg_logits_1_4 = self.seg_head(fused)
        boundary_logits_1_4 = self.boundary_head(fused) if self.use_boundary else None

        size = (x.shape[2], x.shape[3])

        def upsample(t):
            # align_corners=False is PyTorch's default; stated explicitly.
            return F.interpolate(t, size=size, mode="bilinear", align_corners=False)

        return {
            "logits": upsample(seg_logits_1_4),
            "aux_16": upsample(aux_16) if aux_16 is not None else None,
            "aux_8": upsample(aux_8) if aux_8 is not None else None,
            "boundary_logits": upsample(boundary_logits_1_4) if boundary_logits_1_4 is not None else None,
        }

def load_model(model_path):
    """Build a SegmentationModel, load weights from ``model_path``, return it in eval mode on DEVICE.

    The checkpoint may be a dict holding the weights under "model_state"
    (its optional "best_mIoU" entry is printed) or a bare state_dict.
    Entries ending in "total_ops" / "total_params" are dropped before the
    strict load; presumably these are buffers left by a FLOP-counting tool
    (e.g. thop) during training.
    """
    print(f"Loading model from: {model_path}")

    # Every parameter/buffer is overwritten by the strict load below, so the
    # ImageNet backbone initialisation (and its network download) is skipped.
    model = SegmentationModel(num_classes=NUM_CLASSES, pretrained_backbone=False)

    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. On PyTorch >= 2.6 ``weights_only`` defaults to True, which
    # rejects checkpoints that store arbitrary Python objects; check what this
    # checkpoint contains if loading fails.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    if isinstance(checkpoint, dict) and "model_state" in checkpoint:
        state_dict = checkpoint["model_state"]
    else:
        # Also accept a bare state_dict saved via torch.save(model.state_dict()).
        state_dict = checkpoint

    state_dict = {
        k: v for k, v in state_dict.items()
        if not k.endswith(("total_ops", "total_params"))
    }

    model.load_state_dict(state_dict, strict=True)
    model.to(DEVICE)
    model.eval()

    best_miou = checkpoint.get('best_mIoU', 'N/A') if isinstance(checkpoint, dict) else 'N/A'
    print(f"Model loaded successfully. Best mIoU: {best_miou}")
    print(f"Model device: {next(model.parameters()).device}")

    return model

def build_legend(class_ids):
    """Return matplotlib ``Patch`` legend handles for ``class_ids``.

    Handles are ordered by ascending id. Ids without an entry in CLASS_INFO
    (e.g. the no-data id 0) are skipped.
    """
    known_ids = [cid for cid in sorted(class_ids) if cid in CLASS_INFO]
    return [
        Patch(facecolor=np.array(CLASS_INFO[cid][1]) / 255.0, label=CLASS_INFO[cid][0])
        for cid in known_ids
    ]

def preprocess_image(image_path):
    """Load an image and build the normalised model input.

    Returns:
        (tensor, orig_size, img): a 1 x 3 x IMG_SIZE x IMG_SIZE tensor
        normalised with the ImageNet mean/std, the original size as
        (width, height), and the RGB PIL image.
    """
    with Image.open(image_path) as raw:
        rgb = raw.convert("RGB")

    pipeline = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    batch = pipeline(rgb).unsqueeze(0)  # add the batch dimension
    return batch, rgb.size, rgb

def decode_segmap(mask_np):
    """Map a 2-D array of class ids to an (H, W, 3) uint8 RGB image via COLOR_MAP."""
    height, width = mask_np.shape  # unpacking also rejects non-2-D input
    colours = COLOR_MAP[mask_np.ravel()]
    return colours.reshape(height, width, 3)

def load_ground_truth(gt_path, image_size):
    """Load a label mask, resize it to ``image_size`` and apply the training label remap.

    Args:
        gt_path: path to the mask image. A missing file is not an error.
        image_size: target (width, height), as given by PIL's ``Image.size``.

    Returns:
        int64 array of class ids with shape (height, width), or None when
        ``gt_path`` does not exist.
    """
    if not os.path.exists(gt_path):
        return None

    with Image.open(gt_path) as raw:
        # NOTE(review): convert("L") assumes class ids are stored as 8-bit grey
        # levels; a palette ("P") mask would be mapped through luminance
        # instead -- confirm against the dataset.
        gt = raw.convert("L")
    # Nearest-neighbour resampling keeps ids exact (no blending between classes).
    gt = gt.resize(image_size, Image.NEAREST)
    gt_np = np.array(gt, dtype=np.int64)

    # Replicate the training preprocessing: id 8 (playground) is merged into
    # id 1 (background). Id 0 (no-data) is left as-is and ignored by metrics.
    gt_np[gt_np == 8] = 1

    return gt_np


def calculate_metrics(pred_mask, gt_mask):
    """Compute pixel accuracy and per-class precision/recall/IoU/F1 for one image.

    Pixels whose ground truth equals IGNORE_INDEX are excluded everywhere.
    Per-class scores and the mean IoU cover only classes present in the
    ground truth (id 8 is additionally skipped).

    Args:
        pred_mask: array of predicted class ids.
        gt_mask: array of ground-truth ids with the same number of elements,
            or None.

    Returns:
        None if ``gt_mask`` is None, otherwise a dict with keys 'accuracy',
        'per_class' ({cls: precision/recall/iou/f1/support/predicted_count}),
        'present_classes', 'mean_iou', 'predicted_classes',
        'false_positive_classes' and 'false_negative_classes'. If the ground
        truth has no labelled pixel, 'accuracy' is NaN and 'mean_iou' is 0.
    """
    if gt_mask is None:
        return None

    gt_flat = gt_mask.flatten()
    valid_mask = gt_flat != IGNORE_INDEX
    pred_flat = pred_mask.flatten()[valid_mask]
    gt_flat = gt_flat[valid_mask]

    metrics = {}

    # Guard the all-ignored case explicitly instead of dividing by zero.
    if pred_flat.size:
        metrics['accuracy'] = np.sum(pred_flat == gt_flat) / pred_flat.size
    else:
        metrics['accuracy'] = float('nan')

    classes_in_gt = np.unique(gt_flat)
    classes_in_gt = classes_in_gt[
        (classes_in_gt != IGNORE_INDEX) &
        (classes_in_gt != 8)
    ]

    metrics['per_class'] = {}
    metrics['present_classes'] = classes_in_gt.tolist()

    for cls in classes_in_gt:
        is_pred = pred_flat == cls
        is_gt = gt_flat == cls
        tp = np.sum(is_pred & is_gt)
        fp = np.sum(is_pred & ~is_gt)
        fn = np.sum(~is_pred & is_gt)

        # The 1e-10 terms avoid 0/0 for classes that are never predicted.
        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        iou = tp / (tp + fp + fn + 1e-10)
        f1 = 2 * precision * recall / (precision + recall + 1e-10)

        metrics['per_class'][int(cls)] = {
            'precision': precision,
            'recall': recall,
            'iou': iou,
            'f1': f1,
            'support': np.sum(is_gt),
            'predicted_count': np.sum(is_pred),
        }

    if len(classes_in_gt) > 0:
        metrics['mean_iou'] = np.mean([m['iou'] for m in metrics['per_class'].values()])
    else:
        metrics['mean_iou'] = 0

    classes_in_pred = np.unique(pred_flat)
    metrics['predicted_classes'] = classes_in_pred.tolist()

    gt_ids = set(metrics['present_classes'])
    pred_ids = set(metrics['predicted_classes'])
    metrics['false_positive_classes'] = [
        c for c in metrics['predicted_classes'] if c not in gt_ids and c != IGNORE_INDEX
    ]
    metrics['false_negative_classes'] = [
        c for c in metrics['present_classes'] if c not in pred_ids
    ]

    return metrics

def create_comparison_plot(original_img, pred_mask, gt_mask, metrics, image_name):
    """Create a 2x3 comparison figure for one prediction.

    Panels (row-major): original image, ground-truth colour mask, predicted
    colour mask with a class legend, prediction overlay, error map, and a
    text summary of ``metrics``. Panels needing the ground truth or metrics
    show a placeholder (or stay blank) when those are None.

    Args:
        original_img: RGB image (PIL image or H x W x 3 uint8 array) with the
            same height/width as ``pred_mask``.
        pred_mask: 2-D array of predicted class ids.
        gt_mask: 2-D array of ground-truth class ids, or None.
        metrics: dict produced by ``calculate_metrics``, or None.
        image_name: shown in the title of the first panel.

    Returns:
        The matplotlib figure. It is not closed here; the caller owns it.
    """
    pred_color = decode_segmap(pred_mask)
    gt_color = decode_segmap(gt_mask) if gt_mask is not None else None

    # 60/40 alpha blend of the photo and the colour-coded prediction.
    overlay = 0.6 * np.array(original_img).astype(np.float32) / 255.0 + 0.4 * pred_color.astype(np.float32) / 255.0
    overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)

    error_mask = None
    if gt_mask is not None:
        # Only labelled (non-ignored) GT pixels can be wrong. In a multi-class
        # map every wrong pixel is both a FP of the predicted class and a FN of
        # the true class, so colour by the *kind* of error instead.
        wrong = (pred_mask != gt_mask) & (gt_mask != IGNORE_INDEX)
        error_mask = np.zeros_like(pred_color)
        error_mask[wrong & (pred_mask != IGNORE_INDEX)] = [255, 0, 0]  # wrong class
        error_mask[wrong & (pred_mask == IGNORE_INDEX)] = [0, 0, 255]  # predicted as no-data

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    axes[0].imshow(original_img)
    axes[0].set_title(f"Original Image\n{image_name}", fontsize=12)
    axes[0].axis('off')

    if gt_mask is not None:
        axes[1].imshow(gt_color)
        axes[1].set_title("Ground Truth", fontsize=12, color='green')
    else:
        axes[1].text(0.5, 0.5, "Ground Truth\nNot Available",
                     ha='center', va='center', fontsize=12)
    axes[1].axis('off')

    axes[2].imshow(pred_color)
    axes[2].set_title("Prediction", fontsize=12, color='blue')
    axes[2].axis('off')

    # The legend covers every class visible in either the prediction or the GT.
    legend_classes = {int(c) for c in np.unique(pred_mask)}
    if gt_mask is not None:
        legend_classes.update(int(c) for c in np.unique(gt_mask))
    legend_handles = build_legend(legend_classes)
    if legend_handles:  # skip an empty legend box when only no-data is present
        axes[2].legend(
            handles=legend_handles,
            loc="lower right",
            fontsize=9,
            frameon=True
        )

    axes[3].imshow(overlay)
    axes[3].set_title("Prediction Overlay", fontsize=12)
    axes[3].axis('off')

    if error_mask is not None:
        axes[4].imshow(error_mask)
        axes[4].set_title("Error Map\nRed: wrong class, Blue: predicted no-data", fontsize=12)
    axes[4].axis('off')

    axes[5].axis('off')
    if metrics is not None:
        metrics_text = f"Overall Metrics:\n"
        metrics_text += f"Accuracy: {metrics['accuracy']:.3f}\n"
        metrics_text += f"mIoU: {metrics['mean_iou']:.3f}\n\n"

        metrics_text += "Classes in GT: {}\n".format(metrics['present_classes'])
        metrics_text += "Classes in Pred: {}\n".format(metrics['predicted_classes'])

        if metrics['false_positive_classes']:
            metrics_text += f"FP Classes: {metrics['false_positive_classes']}\n"
        if metrics['false_negative_classes']:
            metrics_text += f"FN Classes: {metrics['false_negative_classes']}\n"

        metrics_text += "\nPer-Class IoU:\n"
        for cls in sorted(metrics['per_class'].keys()):
            iou = metrics['per_class'][cls]['iou']
            support = metrics['per_class'][cls]['support']
            pred_count = metrics['per_class'][cls]['predicted_count']
            metrics_text += f"Class {cls}: {iou:.3f} (GT:{support}, Pred:{pred_count})\n"

        axes[5].text(0.1, 0.95, metrics_text, fontsize=9,
                     verticalalignment='top', family='monospace')
    else:
        axes[5].text(0.5, 0.5, "No metrics available\n(Ground truth missing)",
                     ha='center', va='center', fontsize=12)

    plt.tight_layout()
    return fig

def save_predictions(image_path, pred_mask, gt_mask, color_mask, overlay,
                     comparison_fig, output_dir, metrics):
    """Write every prediction artefact for one image into ``output_dir``.

    Files, prefixed with the image's base name:
      * ``_pred_mask.png``   colour-coded prediction (``color_mask``)
      * ``_overlay.png``     prediction blended over the image
      * ``_pred.png``        raw predicted class ids as uint8
      * ``_gt.png`` / ``_gt_color.png``  raw / colour GT, only if ``gt_mask`` is given
      * ``_comparison.png``  ``comparison_fig`` (which is closed afterwards)
      * ``_metrics.txt``     per-image metrics, only if ``metrics`` is given

    Returns:
        dict mapping "<kind>_path" keys to the files actually written.
    """
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(image_path))[0]

    mask_path = os.path.join(output_dir, f"{base_name}_pred_mask.png")
    Image.fromarray(color_mask).save(mask_path)

    overlay_path = os.path.join(output_dir, f"{base_name}_overlay.png")
    Image.fromarray(overlay).save(overlay_path)

    pred_path = os.path.join(output_dir, f"{base_name}_pred.png")
    Image.fromarray(pred_mask.astype(np.uint8)).save(pred_path)

    gt_path = gt_color_path = None
    if gt_mask is not None:
        gt_path = os.path.join(output_dir, f"{base_name}_gt.png")
        Image.fromarray(gt_mask.astype(np.uint8)).save(gt_path)

        gt_color_path = os.path.join(output_dir, f"{base_name}_gt_color.png")
        Image.fromarray(decode_segmap(gt_mask)).save(gt_color_path)

    comparison_path = os.path.join(output_dir, f"{base_name}_comparison.png")
    try:
        comparison_fig.savefig(comparison_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(comparison_fig)

    metrics_path = None
    if metrics is not None:
        metrics_path = os.path.join(output_dir, f"{base_name}_metrics.txt")
        with open(metrics_path, 'w', encoding='utf-8') as f:
            f.write(f"Image: {image_path}\n")
            f.write(f"Overall Accuracy: {metrics['accuracy']:.4f}\n")
            f.write(f"mIoU (for present classes): {metrics['mean_iou']:.4f}\n")
            f.write(f"Classes in GT: {metrics['present_classes']}\n")
            f.write(f"Classes in Prediction: {metrics['predicted_classes']}\n")

            if metrics['false_positive_classes']:
                f.write(f"False Positive Classes (predicted but not in GT): {metrics['false_positive_classes']}\n")
            if metrics['false_negative_classes']:
                f.write(f"False Negative Classes (in GT but not predicted): {metrics['false_negative_classes']}\n")

            f.write("\nPer-Class Metrics:\n")
            f.write("Class | Precision | Recall   | IoU      | F1-Score | GT Pixels | Pred Pixels\n")
            f.write("-" * 80 + "\n")

            for cls in sorted(metrics['per_class'].keys()):
                m = metrics['per_class'][cls]
                f.write(f"{cls:5d} | {m['precision']:.4f}   | {m['recall']:.4f}   | "
                        f"{m['iou']:.4f}   | {m['f1']:.4f}   | {m['support']:9d} | {m['predicted_count']:11d}\n")

    saved_paths = {
        "mask_path": mask_path,
        "overlay_path": overlay_path,
        "pred_path": pred_path,
        "comparison_path": comparison_path,
    }
    if gt_mask is not None:
        saved_paths["gt_path"] = gt_path
        saved_paths["gt_color_path"] = gt_color_path
    # Only report the metrics file when it was actually written; previously a
    # GT without metrics raised NameError here.
    if metrics_path is not None:
        saved_paths["metrics_path"] = metrics_path

    return saved_paths


def predict_single_image(model, image_path, gt_path=None, output_dir=None, return_all=False):
    """Predict one image, then optionally evaluate, visualise and save the result.

    The image is resized to IMG_SIZE x IMG_SIZE for the model. The argmax mask
    is resized back to the original size with nearest-neighbour resampling
    before evaluation and saving.

    Args:
        model: model on DEVICE whose output dict has a "logits" entry of shape
            (1, C, H, W), e.g. SegmentationModel.
        image_path: path of the input image.
        gt_path: optional ground-truth mask path (a missing file is ignored).
        output_dir: if given, all artefacts are written there.
        return_all: selects the return format (see below).

    Returns:
        If ``return_all``: a dict with "prediction", "ground_truth",
        "probabilities" (C x H x W softmax at model resolution), "color_mask",
        "overlay", "metrics" and "saved_paths". Otherwise the tuple
        ``(prediction, ground_truth, metrics)``.
    """
    img_tensor, orig_size, orig_img = preprocess_image(image_path)
    img_tensor = img_tensor.to(DEVICE)
    on_cuda = img_tensor.is_cuda

    probs_np = None
    with torch.no_grad():
        with torch.amp.autocast('cuda', enabled=AMP_ENABLED):
            if on_cuda:
                # CUDA kernels run asynchronously: synchronise around the timed
                # region so it measures the GPU work, not just kernel launches.
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            outputs = model(img_tensor)
            logits = outputs["logits"]
            pred = torch.argmax(logits, dim=1)
            if on_cuda:
                torch.cuda.synchronize()
            end_time = time.perf_counter()

            pred_np = pred[0].cpu().numpy()  # [H, W] at model resolution
            if return_all:
                # Probabilities are only returned to return_all callers; skip
                # the softmax and the C x H x W host copy otherwise.
                probs_np = F.softmax(logits, dim=1)[0].cpu().numpy()  # [C, H, W]

    inference_time_ms = (end_time - start_time) * 1000
    print(f"[Inference Time] Prediction mask computed in {inference_time_ms:.2f} ms")

    # Class ids fit in uint8; NEAREST keeps them exact when resizing back.
    pred_resized = Image.fromarray(pred_np.astype(np.uint8)).resize(orig_size, Image.NEAREST)
    pred_final = np.array(pred_resized)

    gt_final = load_ground_truth(gt_path, orig_size) if gt_path else None
    metrics = calculate_metrics(pred_final, gt_final) if gt_final is not None else None

    color_mask = decode_segmap(pred_final)

    comparison_fig = create_comparison_plot(
        orig_img, pred_final, gt_final, metrics,
        os.path.basename(image_path)
    )

    # 60/40 blend of the photo and the colour-coded prediction.
    orig_np = np.array(orig_img)
    overlay = 0.6 * orig_np.astype(np.float32) / 255.0 + 0.4 * color_mask.astype(np.float32) / 255.0
    overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)

    saved_paths = None
    if output_dir:
        saved_paths = save_predictions(
            image_path, pred_final, gt_final, color_mask, overlay,
            comparison_fig, output_dir, metrics
        )
        print(f"✓ Saved predictions for {os.path.basename(image_path)}")
        if metrics:
            print(f"  Accuracy: {metrics['accuracy']:.3f}, mIoU: {metrics['mean_iou']:.3f}")

    if return_all:
        return {
            "prediction": pred_final,
            "ground_truth": gt_final,
            "probabilities": probs_np,
            "color_mask": color_mask,
            "overlay": overlay,
            "metrics": metrics,
            "saved_paths": saved_paths
        }
    return pred_final, gt_final, metrics

def _find_ground_truth(img_path, gt_dir):
    """Return the first existing mask file in ``gt_dir`` matching ``img_path``, else None."""
    image_file = os.path.basename(img_path)
    base_name = os.path.splitext(image_file)[0]
    candidates = [
        f"{base_name}.png",
        f"{base_name}.tif",
        f"{base_name}.jpg",
        f"{base_name}_mask.png",
        f"{base_name}_label.png",
        # image/img -> mask renames. The bare forms are kept for backward
        # compatibility; the ".png" forms are what normally exists on disk.
        base_name.replace("image", "mask"),
        base_name.replace("img", "mask"),
        f"{base_name.replace('image', 'mask')}.png",
        f"{base_name.replace('img', 'mask')}.png",
        image_file,  # same file name (and extension) as the image
    ]
    for name in candidates:
        path = os.path.join(gt_dir, name)
        if os.path.isfile(path):
            return path
    return None


def _summarize_metrics(all_metrics, output_dir):
    """Print dataset-level averages and, if ``output_dir`` is set, write aggregate_metrics.txt."""
    print("\n" + "="*60)
    print("AGGREGATE METRICS")
    print("="*60)

    avg_accuracy = np.mean([m['accuracy'] for m in all_metrics])
    avg_miou = np.mean([m['mean_iou'] for m in all_metrics])

    print(f"Average Accuracy: {avg_accuracy:.4f}")
    print(f"Average mIoU: {avg_miou:.4f}")

    print("\nPer-Class Average IoU (only for classes that appear):")
    all_classes = sorted({cls for m in all_metrics for cls in m['per_class']})

    class_stats = {}
    for cls in all_classes:
        # Every cls came from at least one image, so entries is never empty.
        entries = [m['per_class'][cls] for m in all_metrics if cls in m['per_class']]
        avg_iou = np.mean([e['iou'] for e in entries])
        total_support = np.sum([e['support'] for e in entries])
        num_images_with_class = len(entries)
        class_stats[cls] = {
            'avg_iou': avg_iou,
            'total_support': total_support,
            'num_images': num_images_with_class
        }
        print(f"  Class {cls}: {avg_iou:.4f} (appears in {num_images_with_class} images, {total_support} total pixels)")

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        agg_metrics_path = os.path.join(output_dir, "aggregate_metrics.txt")
        with open(agg_metrics_path, 'w', encoding='utf-8') as f:
            f.write(f"Total Images: {len(all_metrics)}\n")
            f.write(f"Average Accuracy: {avg_accuracy:.4f}\n")
            f.write(f"Average mIoU: {avg_miou:.4f}\n\n")

            f.write("Per-Class Statistics:\n")
            f.write("Class | Avg IoU | # Images | Total Pixels\n")
            f.write("-" * 50 + "\n")

            for cls, stats in class_stats.items():
                f.write(f"{cls:5d} | {stats['avg_iou']:.4f}  | {stats['num_images']:8d} | {stats['total_support']:12d}\n")

        print(f"\n✓ Aggregate metrics saved to: {agg_metrics_path}")


def predict_batch(model, image_dir, gt_dir=None, output_dir=None, batch_size=1):
    """Run inference on all images in a directory with optional ground truth.

    Images matching *.png, *.jpg, *.jpeg, *.tif and *.tiff directly inside
    ``image_dir`` are processed one at a time, in sorted path order, via
    ``predict_single_image``. For each image a mask with a matching name is
    looked up in ``gt_dir``. If any image was scored, dataset averages are
    printed and (with ``output_dir``) written to aggregate_metrics.txt.

    Args:
        model: model passed through to ``predict_single_image``.
        image_dir: directory containing the input images.
        gt_dir: optional directory containing ground-truth masks.
        output_dir: optional directory for per-image and aggregate outputs.
        batch_size: accepted for API compatibility; currently unused.

    Returns:
        List of per-image metrics dicts (only images with ground truth).
    """
    import glob

    image_extensions = ["*.png", "*.jpg", "*.jpeg", "*.tif", "*.tiff"]
    # glob.escape: the directory may contain glob metacharacters such as "[".
    # Sorted and de-duplicated because glob order is arbitrary.
    pattern_root = glob.escape(image_dir)
    image_paths = sorted({
        path
        for ext in image_extensions
        for path in glob.glob(os.path.join(pattern_root, ext))
    })

    print(f"Found {len(image_paths)} images in {image_dir}")

    all_metrics = []
    for i, img_path in enumerate(image_paths, start=1):
        print(f"\nProcessing {i}/{len(image_paths)}: {os.path.basename(img_path)}")

        gt_path = _find_ground_truth(img_path, gt_dir) if gt_dir else None

        _, _, metrics = predict_single_image(
            model, img_path, gt_path, output_dir, return_all=False
        )

        if metrics:
            all_metrics.append(metrics)
            print(f"  Accuracy: {metrics['accuracy']:.3f}, mIoU: {metrics['mean_iou']:.3f}")

    if all_metrics:
        _summarize_metrics(all_metrics, output_dir)

    return all_metrics


def run_example_inference():
    """Example showing how to use the inference functions with ground truth.

    Loads the model, predicts one hard-coded image against its mask, prints
    the metrics and displays the saved comparison figure. Edit the path
    constants below to use other files.

    Returns:
        ``(model, result)``, where result is the dict from
        ``predict_single_image(..., return_all=True)``.
    """
    print("=" * 60)
    print("EARTHVQA INFERENCE SCRIPT WITH GROUND TRUTH COMPARISON")
    print("=" * 60)

    # === CONFIGURE THESE PATHS ===
    MODEL_PATH = "/kaggle/input/loveda-test-wrong-original/best_model_earthvqa.pth"  # Your trained model
    IMAGE_PATH = "/kaggle/input/loveda-test-wrong-original/Val/Val/images_png/2544.png"  # Test image
    GT_PATH = "/kaggle/input/loveda-test-wrong-original/Val/Val/masks_png/2544.png"  # Ground truth
    OUTPUT_DIR = "/kaggle/working/inference_results"  # Where to save results

    print(f"Model path: {MODEL_PATH}")
    print(f"Image path: {IMAGE_PATH}")
    print(f"Ground truth path: {GT_PATH}")
    print(f"Output directory: {OUTPUT_DIR}")
    print(f"Device: {DEVICE}")
    print("-" * 60)

    model = load_model(MODEL_PATH)

    print("\nRunning inference with ground truth comparison...")
    result = predict_single_image(model, IMAGE_PATH, GT_PATH, OUTPUT_DIR, return_all=True)

    print("\n" + "=" * 60)
    print("INFERENCE RESULTS")
    print("=" * 60)
    print(f"Prediction shape: {result['prediction'].shape}")
    print(f"Unique classes predicted: {np.unique(result['prediction'])}")

    if result['ground_truth'] is not None:
        print(f"Unique classes in ground truth: {np.unique(result['ground_truth'])}")

    metrics = result['metrics']
    if metrics:
        print(f"\nEvaluation Metrics:")
        print(f"  Overall Accuracy: {metrics['accuracy']:.4f}")
        print(f"  Mean IoU (for present classes): {metrics['mean_iou']:.4f}")

        print(f"\nClasses in Ground Truth: {metrics['present_classes']}")
        print(f"Classes in Prediction: {metrics['predicted_classes']}")

        if metrics['false_positive_classes']:
            print(f"False Positive Classes (predicted but not in GT): {metrics['false_positive_classes']}")
        if metrics['false_negative_classes']:
            print(f"False Negative Classes (in GT but not predicted): {metrics['false_negative_classes']}")

        print("\nPer-Class IoU (only for classes present in ground truth):")
        for cls in sorted(metrics['per_class'].keys()):
            iou = metrics['per_class'][cls]['iou']
            support = metrics['per_class'][cls]['support']
            pred_count = metrics['per_class'][cls]['predicted_count']
            print(f"  Class {cls}: {iou:.4f} (GT: {support}px, Pred: {pred_count}px)")

    saved_paths = result['saved_paths']
    if saved_paths:
        print("\nSaved files:")
        for key, path in saved_paths.items():
            print(f"  {key}: {path}")

        # The comparison image only exists when outputs were saved.
        print("\nDisplaying comparison results...")
        plt.figure(figsize=(15, 10))
        img = plt.imread(saved_paths['comparison_path'])
        plt.imshow(img)
        plt.axis('off')
        plt.title("Inference Results Comparison", fontsize=14, pad=20)
        plt.tight_layout()
        plt.show()

    print(f"\n✓ Inference complete! Results saved to: {OUTPUT_DIR}")

    return model, result


if __name__ == "__main__":
    # Single-image demo with ground-truth comparison (paths are set inside the function).
    run_example_inference()

    # Batch variant over a whole folder of images with matching masks:
    # model = load_model("/kaggle/working/best_model_earthvqa.pth")
    # results = predict_batch(model,
    #                         "/kaggle/input/loveda-test-wrong-original/Test/Test/images_png",
    #                         "/kaggle/input/loveda-test-wrong-original/Test/Test/masks_png",
    #                         "/kaggle/working/batch_results")