# In [None]:

import os
import random
from typing import Dict, Tuple

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt

# Seed for Python / NumPy / PyTorch RNGs (applied via set_seed(SEED) below).
SEED = 42

# Basenames of known-corrupted samples. EarthVQADataset skips both the image
# and the mask with this basename.
BAD_IMAGE_BASENAMES = {
    "1004_r270.png",
}


# Training-loop switches, consumed in main() (outside this view).
# NOTE(review): semantics inferred from the names only — confirm in main().
RESUME_TRAINING = True
SAVE_EVERY_EPOCH = False

# Ablation switches. Despite the ABLATE_ prefix, True means the component is
# ENABLED: the model and loss copy these flags into `use_*` attributes and only
# apply the component when the flag is True. Set a flag to False to ablate it.

ABLATE_GC = True           # GlobalContextBlock (GC) on e3/e4
ABLATE_MHSA = True        # SelfAttentionBlock on e4
ABLATE_DETAIL = True       # Detail branch at 1/4
ABLATE_AUX = True          # Aux heads (aux_16, aux_8)
ABLATE_BOUNDARY = True     # Boundary prediction + loss
ABLATE_LOVASZ = True       # Lovasz-Softmax in main loss


# Dataset root. Each split directory is nested twice (e.g. Train/Train) and
# contains images_png/; Train and Val also contain masks_png/ (no mask dir is
# configured for Test).
ROOT_DIR = "/kaggle/input/loveda"
TRAIN_DIR = os.path.join(ROOT_DIR, "Train/Train")
VAL_DIR = os.path.join(ROOT_DIR, "Val/Val")
TEST_DIR = os.path.join(ROOT_DIR, "Test/Test")

TRAIN_IMG_DIR = os.path.join(TRAIN_DIR, "images_png")
TRAIN_MASK_DIR = os.path.join(TRAIN_DIR, "masks_png")
VAL_IMG_DIR = os.path.join(VAL_DIR, "images_png")
VAL_MASK_DIR = os.path.join(VAL_DIR, "masks_png")
TEST_IMG_DIR = os.path.join(TEST_DIR, "images_png")

# Outputs. NOTE: the os.makedirs calls below run at import time.
WORK_DIR = "/kaggle/working"
os.makedirs(WORK_DIR, exist_ok=True)
BEST_MODEL_PATH = os.path.join(WORK_DIR, "best_model_earthvqa.pth")
PLOT_PATH = os.path.join(WORK_DIR, "train_curves.png")
TEST_OUT_DIR = os.path.join(WORK_DIR, "test_inference")
os.makedirs(TEST_OUT_DIR, exist_ok=True)

# Number of model output channels. EarthVQADataset remaps label 8 -> 1, so the
# targets the model sees are labels 0-7 (8 channels).
# NOTE(review): the original note called label 0 "background", but the 8 -> 1
# remap is commented "playground to background", i.e. label 1 is background;
# label 0 looks like a no-data label — confirm against the dataset spec.
NUM_CLASSES = 8
IGNORE_INDEX = 0            # label excluded from CE/Dice and from confusion-matrix metrics

BATCH_SIZE = 4
NUM_EPOCHS = 25
EARLY_STOP_PATIENCE = 3

LR = 1e-4
WEIGHT_DECAY = 1e-5

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
AMP_ENABLED = (DEVICE == "cuda")   # mixed precision only when running on CUDA

IMG_SIZE = 512   # images and masks are resized to IMG_SIZE x IMG_SIZE by the datasets

# Loss-term weights. NOTE(review): CombinedLoss declares its own defaults with
# the same values; nothing in this view reads these constants, so they only
# matter if main() passes them explicitly.
LAMBDA_MAIN = 1.0
LAMBDA_AUX = 0.3
LAMBDA_BOUNDARY = 0.5

# RGB display palette used by decode_segmap: row i is the colour of label i.
# indices: 0  1  2  3  4  5  6  7  8
COLOR_MAP = np.array([
    [0, 0, 0],        # 0
    [128, 0, 0],       # 1
    [0, 128, 0],       # 2
    [128, 128, 0],      # 3
    [0, 0, 128],      # 4
    [128, 0, 128],      # 5
    [0, 128, 128],     # 6
    [128, 128, 128],  # 7
    [64, 0, 0],        # 8   remapped to 1 by the dataset, so never predicted (model has 8 outputs)
], dtype=np.uint8)



def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    cuDNN is configured for speed rather than reproducibility
    (``deterministic=False``, ``benchmark=True``), so GPU runs are not
    bit-exact even with a fixed seed.
    """
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed every RNG once at import time, before any model or DataLoader is built.
# cuDNN stays non-deterministic (see set_seed), so GPU runs are not bit-exact.
set_seed(SEED)



class EarthVQADataset(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Every file in ``img_dir`` with a recognised image extension is paired with
    the file of the same basename in ``mask_dir``. Basenames listed in
    ``BAD_IMAGE_BASENAMES`` are skipped.

    ``__getitem__`` returns ``(image, mask)``:
        image: float tensor [3, img_size, img_size], ImageNet-normalised.
        mask:  int64 tensor [img_size, img_size] of class labels, resized with
               nearest-neighbour interpolation and with label 8 remapped to 1.

    Args:
        img_dir: directory containing the input images.
        mask_dir: directory containing single-channel label masks.
        img_size: side length of the square output image and mask.

    Raises:
        FileNotFoundError: if a kept image has no mask with the same basename.
            Checked here so the failure surfaces at construction time instead
            of inside a DataLoader worker mid-epoch.
    """

    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".tif")

    def __init__(self, img_dir: str, mask_dir: str, img_size: int = 1024):
        self.img_dir = img_dir
        self.mask_dir = mask_dir

        self.img_paths = []
        self.mask_paths = []

        for p in self._list_images(img_dir):
            fname = os.path.basename(p)
            if fname in BAD_IMAGE_BASENAMES:
                print(f"[INFO] Skipping corrupted image and mask: {fname}")
                continue

            self.img_paths.append(p)
            self.mask_paths.append(os.path.join(mask_dir, fname))

        missing = [m for m in self.mask_paths if not os.path.isfile(m)]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} image(s) in {img_dir} have no mask in {mask_dir}; "
                f"first missing: {missing[:3]}"
            )

        self.img_size = img_size

        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        self.img_tf = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ])

    @classmethod
    def _list_images(cls, img_dir: str):
        """Return sorted full paths of files in ``img_dir`` whose extension
        (case-insensitive) is in ``IMG_EXTENSIONS``."""
        return sorted(
            os.path.join(img_dir, f) for f in os.listdir(img_dir)
            if f.lower().endswith(cls.IMG_EXTENSIONS)
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx: int):
        # `with` closes the underlying file handles deterministically.
        with Image.open(self.img_paths[idx]) as img:
            img_t = self.img_tf(img.convert("RGB"))

        with Image.open(self.mask_paths[idx]) as mask:
            # NEAREST keeps labels discrete (no interpolated in-between classes).
            mask_resized = transforms.functional.resize(
                mask.convert("L"), (self.img_size, self.img_size),
                interpolation=transforms.InterpolationMode.NEAREST,
            )
        mask_arr = np.array(mask_resized, dtype=np.int64)

        # Merge label 8 (playground) into label 1 (background).
        mask_arr[mask_arr == 8] = 1

        return img_t, torch.from_numpy(mask_arr)


class EarthVQATestDataset(Dataset):
    """Unlabelled test split.

    Lists every file in ``img_dir`` with a .png/.jpg/.jpeg/.tif extension
    (case-insensitive), sorted by path. ``__getitem__`` returns
    ``(image, basename)``, where ``image`` is a float tensor
    [3, img_size, img_size] with ImageNet normalisation.
    """

    def __init__(self, img_dir: str, img_size: int = 1024):
        self.img_dir = img_dir
        self.img_size = img_size

        accepted = (".png", ".jpg", ".jpeg", ".tif")
        names = (n for n in os.listdir(img_dir) if n.lower().endswith(accepted))
        self.img_paths = sorted(os.path.join(img_dir, n) for n in names)

        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        pipeline = [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ]
        self.img_tf = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx: int):
        path = self.img_paths[idx]
        with Image.open(path) as raw:
            tensor = self.img_tf(raw.convert("RGB"))
        return tensor, os.path.basename(path)


# Model architecture

class ConvBNReLU(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU -> optional Dropout2d.

    ``k``, ``s`` and ``p`` are the conv kernel size, stride and padding. The conv
    weight is re-initialised with Kaiming-normal (fan_out, ReLU gain).
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, groups=1, use_dropout=False, p_drop=0.1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p,
                              groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        if use_dropout:
            self.drop = nn.Dropout2d(p_drop)
        else:
            self.drop = nn.Identity()

        nn.init.kaiming_normal_(self.conv.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        return self.drop(self.relu(self.bn(self.conv(x))))


class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 ConvBNReLU (groups=in_ch) followed by a pointwise 1x1 ConvBNReLU.

    Spatial size is preserved; channels go in_ch -> out_ch. When
    ``use_dropout`` is set, both stages apply Dropout2d(p_drop).
    """

    def __init__(self, in_ch, out_ch, use_dropout=False, p_drop=0.1):
        super().__init__()
        shared = dict(use_dropout=use_dropout, p_drop=p_drop)
        self.depth = ConvBNReLU(in_ch, in_ch, k=3, s=1, p=1, groups=in_ch, **shared)
        self.point = ConvBNReLU(in_ch, out_ch, k=1, s=1, p=0, **shared)

    def forward(self, x):
        return self.point(self.depth(x))


class SelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention over the spatial positions of a feature map.

    A [B, C, H, W] input is treated as H*W tokens of width C. The tokens are
    LayerNorm-ed, attended with ``num_heads`` heads, projected back to C
    channels and passed through dropout. The result is added to the
    un-normalised input. The output has the input's shape. ``dim`` (= C) must
    be divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        batch, channels, height, width = x.shape
        tokens = height * width
        head_dim = channels // self.num_heads

        # [B, C, H, W] -> [B, N, C] token sequence (row-major over H, W).
        seq = x.permute(0, 2, 3, 1).reshape(batch, tokens, channels)
        q, k, v = self.qkv(self.norm(seq)).chunk(3, dim=-1)

        def split_heads(t):
            # [B, N, C] -> [B, heads, N, C / heads]
            return t.view(batch, tokens, self.num_heads, head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        weights = (q @ k.transpose(-2, -1)) * self.scale
        mixed = weights.softmax(dim=-1) @ v
        mixed = mixed.transpose(1, 2).reshape(batch, tokens, channels)
        mixed = self.drop(self.proj(mixed)) + seq

        return mixed.reshape(batch, height, width, channels).permute(0, 3, 1, 2)


class GlobalContextBlock(nn.Module):
    """Simplified global-context block: attention-pooled context added back to every position.

    A 1x1 conv scores each spatial position, and the scores are softmax-normalised
    over H*W. The resulting weights pool ``x`` into a [B, C, 1, 1] context
    vector. That vector goes through a 1x1 bottleneck
    (C -> C // reduction -> C, ReLU in between) and is broadcast-added to ``x``.
    """

    def __init__(self, in_channels, reduction=4):
        super().__init__()
        hidden = in_channels // reduction
        self.attention = nn.Conv2d(in_channels, 1, kernel_size=1)
        self.transform = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        batch, _, height, width = x.shape
        scores = self.attention(x).view(batch, 1, height * width)
        weights = torch.softmax(scores, dim=-1).view(batch, 1, height, width)
        context = (x * weights).sum(dim=[2, 3], keepdim=True)
        return x + self.transform(context)


class ConvNeXtTinyEncoder(nn.Module):
    """Truncated ConvNeXt-Tiny encoder built from torchvision's ``features``.

    Only ``features[0:5]`` are used. In torchvision's ConvNeXt, ``features``
    alternates block stages and 2x downsampling layers (stem, stage-1 blocks,
    downsample, stage-2 blocks, downsample, stage-3 blocks, ...). The returned
    tensors are therefore:

        e1: stem + stage-1 blocks       1/4,  C=96
        e2: downsample only (no blocks) 1/8,  C=192
        e3: stage-2 blocks              1/8,  C=192
        e4: downsample only (no blocks) 1/16, C=384

    e2 and e3 share a resolution. Decoder relies on this: it concatenates
    ``up(e4)``, ``e3`` and ``e2``, and SegmentationModel configures it with
    (96, 192, 192, 384).

    NOTE(review): an earlier docstring claimed e3/e4 at 1/16 and 1/32 with
    384/768 channels, which this indexing does not produce. The stage-3/4
    blocks (``features[5:]``) are discarded. Changing the indexing would change
    the feature shapes and invalidate existing checkpoints.

    Args:
        pretrained: load torchvision's DEFAULT ConvNeXt-Tiny weights. Falls back
            to the legacy ``pretrained=`` keyword on old torchvision versions
            that reject ``weights=``.
    """
    def __init__(self, pretrained: bool = True):
        super().__init__()
        try:
            if pretrained:
                backbone = torchvision.models.convnext_tiny(weights="DEFAULT")
            else:
                backbone = torchvision.models.convnext_tiny(weights=None)
        except TypeError:
            backbone = torchvision.models.convnext_tiny(pretrained=pretrained)

        self.stem = backbone.features[0]   # patchify stem:        1/4,  96C
        self.stage1 = backbone.features[1] # stage-1 blocks:       1/4,  96C
        self.stage2 = backbone.features[2] # downsample:           1/8,  192C
        self.stage3 = backbone.features[3] # stage-2 blocks:       1/8,  192C
        self.stage4 = backbone.features[4] # downsample:           1/16, 384C

    def forward(self, x):
        x = self.stem(x)         # [B,96,H/4,W/4]
        e1 = self.stage1(x)      # [B,96,H/4,W/4]
        e2 = self.stage2(e1)     # [B,192,H/8,W/8]
        e3 = self.stage3(e2)     # [B,192,H/8,W/8]
        e4 = self.stage4(e3)     # [B,384,H/16,W/16]
        return e1, e2, e3, e4


class Decoder(nn.Module):
    """Two-stage upsampling decoder with optional GC / MHSA context and aux heads.

    Args:
        encoder_channels: channel counts (c1, c2, c3, c4) of e1..e4. forward
            concatenates 2x-upsampled e4 with e3 and e2, and later the
            2x-upsampled result with e1. So e2 and e3 must share a spatial size,
            e4 must be half of it, and e1 must be twice it.
        decoder_channels: (d1, d2) output channels of the two decoding stages.

    Reads the module-level flags ABLATE_GC, ABLATE_MHSA and ABLATE_AUX at
    construction (True = component enabled). The aux heads output NUM_CLASSES
    channels. ``c4`` must be divisible by the 8 attention heads.
    """

    def __init__(self, encoder_channels=(96, 192, 192, 384),
                 decoder_channels=(256, 96)):
        super().__init__()

        c1, c2, c3, c4 = encoder_channels
        d1, d2 = decoder_channels

        # Ablation toggles
        self.use_gc = ABLATE_GC
        self.use_mhsa = ABLATE_MHSA
        self.use_aux = ABLATE_AUX

        # Context modules are constructed regardless of the flags and only
        # applied in forward() when the matching flag is on.
        self.gc_e4 = GlobalContextBlock(c4)
        self.gc_e3 = GlobalContextBlock(c3)
        self.att_block = SelfAttentionBlock(c4, num_heads=8, dropout=0.1)

        self.up3 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)

        # Input = up3(e4) [c3] ++ e3 [c3] ++ e2 [c2]. This was previously
        # hard-coded to 576, which only matched the default channel config.
        self.dec3 = DepthwiseSeparableConv(c3 + c3 + c2, d1, use_dropout=True)

        self.up2 = nn.ConvTranspose2d(d1, 128, kernel_size=2, stride=2)
        self.dec2 = DepthwiseSeparableConv(128 + c1, d2, use_dropout=True)

        self.out_ch = d2

        # NOTE: despite the names, aux_head_16 runs at e2/e3's resolution and
        # aux_head_8 at e1's. The names are kept so checkpoint keys stay valid.
        self.aux_head_16 = nn.Conv2d(d1, NUM_CLASSES, kernel_size=1)
        self.aux_head_8  = nn.Conv2d(d2, NUM_CLASSES, kernel_size=1)

    def forward(self, e1, e2, e3, e4):
        """Decode encoder features.

        Returns:
            (x, aux_16, aux_8): x has ``out_ch`` channels at e1's resolution.
            aux_16 and aux_8 are NUM_CLASSES-channel logits at e2/e3's and e1's
            resolution respectively, or None when aux heads are disabled.
        """
        if self.use_gc:
            e4 = self.gc_e4(e4)
            e3 = self.gc_e3(e3)

        if self.use_mhsa:
            e4 = self.att_block(e4)

        x = self.up3(e4)
        x = torch.cat([x, e3, e2], dim=1)
        x = self.dec3(x)
        aux_16 = self.aux_head_16(x) if self.use_aux else None

        x = self.up2(x)
        x = torch.cat([x, e1], dim=1)
        x = self.dec2(x)
        aux_8 = self.aux_head_8(x) if self.use_aux else None

        return x, aux_16, aux_8



class DetailBranch(nn.Module):
    """Shallow high-resolution path.

    Two stride-2 ConvBNReLUs (3 -> 32 -> 64 channels, 4x spatial downsampling)
    are followed by a DepthwiseSeparableConv with dropout to ``out_ch`` channels.
    """

    def __init__(self, in_ch=3, out_ch=96):
        super().__init__()
        stem_layers = [
            ConvBNReLU(in_ch, 32, k=3, s=2, p=1),   # 1/2
            ConvBNReLU(32, 64, k=3, s=2, p=1),      # 1/4
        ]
        self.down = nn.Sequential(*stem_layers)
        self.block = DepthwiseSeparableConv(64, out_ch, use_dropout=True)

    def forward(self, x):
        return self.block(self.down(x))


class SegmentationModel(nn.Module):
    """ConvNeXt-Tiny encoder + Decoder + optional detail branch.

    forward(x) takes a [B, 3, H, W] batch and returns a dict:
        "logits":          [B, num_classes, H, W]
        "aux_16", "aux_8": [B, NUM_CLASSES, H, W], or None if aux heads are off
        "boundary_logits": [B, 1, H, W], or None if the boundary head is off
    All heads run at the decoder's output resolution (e1's) and are bilinearly
    upsampled to the input size.

    Reads the module-level flags ABLATE_DETAIL and ABLATE_BOUNDARY at
    construction (True = enabled).

    NOTE(review): the aux heads inside Decoder are sized by the global
    NUM_CLASSES, not by ``num_classes``; keep the two equal.
    """
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()

        self.encoder = ConvNeXtTinyEncoder(pretrained=True)

        self.decoder = Decoder(
            encoder_channels=(96, 192, 192, 384),
            decoder_channels=(256, 96)
        )

        self.use_detail = ABLATE_DETAIL
        self.use_boundary = ABLATE_BOUNDARY

        detail_ch = 96
        self.detail_branch = DetailBranch(in_ch=3, out_ch=detail_ch)
        # fuse receives the decoder features, plus the detail features only when
        # the detail branch is enabled. The old fixed `out_ch + 96` crashed
        # with a channel mismatch whenever ABLATE_DETAIL was False.
        fuse_in = self.decoder.out_ch + (detail_ch if self.use_detail else 0)
        self.fuse = DepthwiseSeparableConv(fuse_in, 96, use_dropout=True)

        self.seg_head = nn.Conv2d(96, num_classes, kernel_size=1)
        self.boundary_head = nn.Conv2d(96, 1, kernel_size=1)

    def forward(self, x):
        H, W = x.shape[2], x.shape[3]

        e1, e2, e3, e4 = self.encoder(x)
        dec_1_4, aux_16, aux_8 = self.decoder(e1, e2, e3, e4)

        if self.use_detail:
            fused = torch.cat([dec_1_4, self.detail_branch(x)], dim=1)
        else:
            fused = dec_1_4
        fused = self.fuse(fused)

        seg_logits_1_4 = self.seg_head(fused)
        boundary_logits_1_4 = self.boundary_head(fused) if self.use_boundary else None

        def upsample(t):
            # Resize a head's logits to the input resolution (None passes through).
            if t is None:
                return None
            return F.interpolate(t, size=(H, W), mode="bilinear", align_corners=False)

        return {
            "logits": upsample(seg_logits_1_4),
            "aux_16": upsample(aux_16),
            "aux_8": upsample(aux_8),
            "boundary_logits": upsample(boundary_logits_1_4),
        }


def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension of the Jaccard loss w.r.t. sorted errors.

    ``gt_sorted`` is a 1-D 0/1 tensor of ground-truth indicators ordered by
    decreasing prediction error. Returns a tensor of the same length: the
    cumulative Jaccard loss at each prefix, first-differenced.
    """
    total_fg = gt_sorted.sum()
    inter = total_fg - torch.cumsum(gt_sorted, dim=0)
    uni = total_fg + torch.cumsum(1 - gt_sorted, dim=0)
    iou_loss = 1.0 - inter / uni
    if gt_sorted.numel() <= 1:
        return iou_loss
    return torch.cat([iou_loss[:1], iou_loss[1:] - iou_loss[:-1]])


def lovasz_softmax_flat(probs, labels):
    """Lovasz-Softmax loss on flattened inputs.

    Args:
        probs: [N, C] class probabilities.
        labels: [N] integer labels.

    Returns:
        Mean, over the classes that occur in ``labels``, of the Lovasz
        extension of the Jaccard loss. Each is computed from the sorted errors
        |1[label == c] - probs[:, c]|. Returns a 0-dim zero tensor if no class
        occurs (e.g. N == 0).
    """
    per_class = []
    for cls in range(probs.size(1)):
        target = (labels == cls).float()
        if target.sum() == 0:
            continue
        err = (target - probs[:, cls]).abs()
        err_sorted, order = torch.sort(err, descending=True)
        per_class.append(torch.dot(err_sorted, lovasz_grad(target[order])))

    if not per_class:
        return torch.tensor(0., device=probs.device)
    return sum(per_class) / len(per_class)


def lovasz_softmax(probs, labels, ignore=None):
    """Lovasz-Softmax loss for dense predictions.

    Args:
        probs: [B, C, H, W] class probabilities (e.g. a softmax output).
        labels: [B, H, W] integer labels.
        ignore: optional label value whose pixels are dropped before the loss is
            computed. The default None keeps every pixel, which was the only
            behaviour before this parameter existed.

    Returns:
        0-dim tensor from lovasz_softmax_flat on the (filtered) pixels.
    """
    num_classes = probs.size(1)
    # reshape (not view) also accepts non-contiguous label tensors.
    probs_flat = probs.permute(0, 2, 3, 1).reshape(-1, num_classes)
    labels_flat = labels.reshape(-1)
    if ignore is not None:
        keep = labels_flat != ignore
        probs_flat = probs_flat[keep]
        labels_flat = labels_flat[keep]
    return lovasz_softmax_flat(probs_flat, labels_flat)


class CombinedLoss(nn.Module):
    """Multi-term segmentation loss.

    total = lambda_main * (CE + Dice [+ lambda_lovasz * Lovasz])
          + lambda_aux * 0.5 * (CE(aux_16) + CE(aux_8))        [aux enabled]
          + lambda_boundary * BCE(boundary_logits, edges(gt))  [boundary enabled]

    CE, Dice and Lovasz all skip pixels whose target equals ``ignore_index``.
    The boundary BCE is computed over every pixel. The optional terms are
    switched by the module-level flags ABLATE_AUX / ABLATE_BOUNDARY /
    ABLATE_LOVASZ (True = enabled) and contribute 0 when disabled.

    forward(outputs, targets) returns (total_loss, loss_dict):
        outputs: the model's dict ("logits", "aux_16", "aux_8", "boundary_logits").
        targets: [B, H, W] integer label tensor.
        loss_dict: each individual term as a Python float, for logging.
    """

    def __init__(self,
                 num_classes=NUM_CLASSES,
                 ignore_index=IGNORE_INDEX,
                 lambda_main=1.0,
                 lambda_aux=0.3,
                 lambda_boundary=0.5,
                 lambda_lovasz=1.0):

        super().__init__()
        self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.ignore_index = ignore_index
        self.num_classes = num_classes

        self.lambda_main = lambda_main
        self.lambda_aux = lambda_aux
        self.lambda_boundary = lambda_boundary
        self.lambda_lovasz = lambda_lovasz

        self.bce_boundary = nn.BCEWithLogitsLoss()

        # Module-level switches (True = term enabled).
        self.use_aux = ABLATE_AUX
        self.use_boundary = ABLATE_BOUNDARY
        self.use_lovasz = ABLATE_LOVASZ

    def dice_loss(self, logits, targets, eps=1e-6):
        """Soft multi-class Dice loss: 1 - mean Dice over non-ignored classes.

        Probabilities and one-hot targets are zeroed at ignore_index pixels, and
        the ignore_index class is dropped from the mean. Intersections and
        cardinalities are pooled over the whole batch (dims 0, 2, 3).
        """
        num_classes = logits.shape[1]
        probs = F.softmax(logits, dim=1)

        one_hot = F.one_hot(
            torch.clamp(targets, 0, num_classes - 1),
            num_classes=num_classes
        ).permute(0, 3, 1, 2).float()

        # Zero out ignored pixels in both prediction and target.
        ignore_mask = (targets == self.ignore_index).unsqueeze(1)  # [B,1,H,W]
        one_hot = torch.where(ignore_mask, torch.zeros_like(one_hot), one_hot)
        probs = torch.where(ignore_mask, torch.zeros_like(probs), probs)

        dims = (0, 2, 3)
        intersection = torch.sum(probs * one_hot, dims)
        cardinality = torch.sum(probs + one_hot, dims)
        dice = (2. * intersection + eps) / (cardinality + eps)

        valid = torch.arange(num_classes, device=logits.device) != self.ignore_index
        dice = dice[valid]

        return 1. - dice.mean()

    def boundary_from_mask(self, mask: torch.Tensor, dilation: int = 1) -> torch.Tensor:
        """Binary boundary map of an integer label mask.

        A pixel is a boundary pixel when its (2*dilation+1) x (2*dilation+1)
        neighbourhood (clipped at the image border) contains more than one
        label.

        mask: [B,H,W] integer labels
        returns: [B,1,H,W] float map with values in {0, 1}
        """
        # The window must grow with the padding to keep the output at H x W.
        # The old fixed kernel_size=3 produced an (H+2d-2)-sized map for d > 1.
        kernel = 2 * dilation + 1
        m = mask.unsqueeze(1).float()  # [B,1,H,W]
        max_pool = F.max_pool2d(m, kernel_size=kernel, stride=1, padding=dilation)
        min_pool = -F.max_pool2d(-m, kernel_size=kernel, stride=1, padding=dilation)
        return (max_pool != min_pool).float()

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor):

        logits = outputs["logits"]               # [B,C,H,W]
        aux_16 = outputs["aux_16"]              # or None
        aux_8 = outputs["aux_8"]                # or None
        boundary_logits = outputs["boundary_logits"]  # or None

        loss_ce_main = self.ce(logits, targets)
        loss_dice_main = self.dice_loss(logits, targets)
        loss_main = loss_ce_main + loss_dice_main

        if self.use_lovasz:
            probs = F.softmax(logits, dim=1)
            # Flatten to [N, C] / [N] and drop ignore_index pixels. Previously
            # Lovasz treated them as ground truth for class ignore_index,
            # contradicting CE and Dice, which both skip them.
            num_ch = probs.size(1)
            probs_flat = probs.permute(0, 2, 3, 1).reshape(-1, num_ch)
            labels_flat = targets.reshape(-1)
            keep = labels_flat != self.ignore_index
            loss_lovasz = lovasz_softmax_flat(probs_flat[keep], labels_flat[keep])
            loss_main = loss_main + self.lambda_lovasz * loss_lovasz
        else:
            loss_lovasz = torch.tensor(0.0, device=logits.device)

        if self.use_aux and (aux_16 is not None) and (aux_8 is not None):
            loss_aux_16 = self.ce(aux_16, targets)
            loss_aux_8 = self.ce(aux_8, targets)
            loss_aux = 0.5 * (loss_aux_16 + loss_aux_8)
        else:
            loss_aux = torch.tensor(0.0, device=logits.device)

        if self.use_boundary and (boundary_logits is not None):
            with torch.no_grad():
                boundary_gt = self.boundary_from_mask(targets)
            loss_boundary = self.bce_boundary(boundary_logits, boundary_gt)
        else:
            loss_boundary = torch.tensor(0.0, device=logits.device)

        total_loss = (
            self.lambda_main * loss_main +
            self.lambda_aux * loss_aux +
            self.lambda_boundary * loss_boundary
        )

        loss_dict = {
            "loss_total": float(total_loss.item()),
            "loss_main": float(loss_main.item()),
            "loss_ce": float(loss_ce_main.item()),
            "loss_dice": float(loss_dice_main.item()),
            "loss_lovasz": float(loss_lovasz.item()),
            "loss_aux": float(loss_aux.item()),
            "loss_boundary": float(loss_boundary.item()),
        }

        return total_loss, loss_dict



def compute_confusion_matrix(pred: torch.Tensor,
                             target: torch.Tensor,
                             num_classes: int,
                             ignore_index: int) -> np.ndarray:
    """Count a [num_classes, num_classes] confusion matrix.

    Rows are ground-truth labels and columns are predicted labels. Pixels whose
    target equals ``ignore_index`` are dropped before counting.

    Args:
        pred: integer predictions of any shape. The tensor is flattened and
            need not be contiguous.
        target: integer labels with the same number of elements as ``pred``.
        num_classes: number of classes C; valid labels are 0..C-1.
        ignore_index: target value to exclude.

    Returns:
        int64 array of shape [C, C].

    Raises:
        ValueError: if a kept pred/target value lies outside [0, C). Such values
            used to overflow the C*C bincount layout and surface as an opaque
            reshape/bincount error.
    """
    pred = pred.reshape(-1).cpu().numpy().astype(np.int64)
    target = target.reshape(-1).cpu().numpy().astype(np.int64)

    keep = target != ignore_index
    pred = pred[keep]
    target = target[keep]

    for name, arr in (("target", target), ("pred", pred)):
        if np.any((arr < 0) | (arr >= num_classes)):
            raise ValueError(f"{name} contains labels outside [0, {num_classes})")

    cm = np.bincount(
        num_classes * target + pred,
        minlength=num_classes ** 2
    ).reshape(num_classes, num_classes)
    return cm


def metrics_from_confusion_matrix(cm: np.ndarray, ignore_index: int):
    """Summary segmentation metrics from a [C, C] confusion matrix (rows = GT).

    Returns:
        (metrics, iou, acc):
        metrics: dict with "mIoU", "pixel_acc", "mAcc", "FWIoU" as floats.
        iou, acc: length-C per-class arrays (0 where the denominator is 0).

    mIoU averages only the non-ignored classes that occur in GT or predictions
    (union > 0). mAcc averages only the non-ignored classes with GT pixels. The
    previous code clamped the denominators to 1e-6, so absent classes scored 0
    instead of NaN, np.nanmean never skipped them, and both means were dragged
    down. Both means are 0.0 when no such class exists.
    """
    tp = np.diag(cm)
    sum_rows = cm.sum(axis=1)
    sum_cols = cm.sum(axis=0)
    union = sum_rows + sum_cols - tp
    total = np.maximum(cm.sum(), 1e-6)

    iou = tp / np.maximum(union, 1e-6)
    acc = tp / np.maximum(sum_rows, 1e-6)

    valid = np.arange(cm.shape[0]) != ignore_index
    has_iou = valid & (union > 0)
    has_gt = valid & (sum_rows > 0)

    mIoU = iou[has_iou].mean() if has_iou.any() else 0.0
    mAcc = acc[has_gt].mean() if has_gt.any() else 0.0

    pixel_acc = tp.sum() / total

    # Frequency-weighted IoU: absent classes have freq 0, so they drop out.
    freq = sum_rows / total
    fwIoU = (freq[valid] * iou[valid]).sum()

    metrics = {
        "mIoU": float(mIoU),
        "pixel_acc": float(pixel_acc),
        "mAcc": float(mAcc),
        "FWIoU": float(fwIoU),
    }
    return metrics, iou, acc


def boundary_f1(pred: torch.Tensor, target: torch.Tensor,
                ignore_index: int, dilation: int = 1) -> float:
    """Boundary F1 between predicted and ground-truth label maps.

    A pixel is a boundary pixel when its (2*dilation+1) x (2*dilation+1)
    neighbourhood (clipped at the image border) contains more than one label.
    Pixels where ``target == ignore_index`` are set to 0 in both maps before
    boundaries are extracted. The two boundary maps are then compared pixel by
    pixel.

    Args:
        pred, target: [B, H, W] integer label tensors.
        ignore_index: target label treated as "don't care".
        dilation: neighbourhood radius; larger values thicken the boundaries.
            Previously the window stayed 3x3 and only the padding grew, so
            dilation > 1 yielded a shifted (H+2d-2)-sized map instead of
            thicker boundaries.

    Returns:
        F1 in [0, 1]. It is 0.0 when no boundary pixels match, including when
        neither map has any boundaries.
    """
    kernel = 2 * dilation + 1

    def extract_boundary(mask: torch.Tensor) -> torch.Tensor:
        # [B,H,W] labels -> [B,1,H,W] {0,1} boundary map
        m = mask.unsqueeze(1).float()
        max_pool = F.max_pool2d(m, kernel_size=kernel, stride=1, padding=dilation)
        min_pool = -F.max_pool2d(-m, kernel_size=kernel, stride=1, padding=dilation)
        return (max_pool != min_pool).float()

    with torch.no_grad():
        target_mask = target.clone()
        target_mask[target_mask == ignore_index] = 0
        pred_mask = pred.clone()
        pred_mask[target == ignore_index] = 0

        b_gt = extract_boundary(target_mask).view(-1)
        b_pred = extract_boundary(pred_mask).view(-1)

        tp = ((b_gt == 1) & (b_pred == 1)).sum().item()
        fp = ((b_gt == 0) & (b_pred == 1)).sum().item()
        fn = ((b_gt == 1) & (b_pred == 0)).sum().item()

        precision = tp / max(tp + fp, 1e-6)
        recall = tp / max(tp + fn, 1e-6)
        f1 = 2 * precision * recall / max(precision + recall, 1e-6)
        return float(f1)


def train_one_epoch(model, dataloader, optimizer, loss_fn, epoch, device, scaler):
    """Run one training epoch with optional mixed precision.

    Args:
        model: segmentation model returning the output dict ``loss_fn`` expects.
        dataloader: yields (images, masks) batches.
        optimizer: optimizer stepped once per batch through ``scaler``.
        loss_fn: callable (outputs, masks) -> (loss, loss_dict).
        epoch: epoch number, used only in the progress-bar label.
        device: device for the batches (str or torch.device).
        scaler: object with scale/step/update. Presumably a
            torch GradScaler created in main() — confirm.

    Returns:
        Mean over processed batches of "loss_total", "loss_main", "loss_aux"
        and "loss_boundary". All are 0.0 when the loader yields no batches
        (this used to raise ZeroDivisionError).
    """
    model.train()
    running = {"loss_total": 0.0, "loss_main": 0.0, "loss_aux": 0.0, "loss_boundary": 0.0}

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast. AMP only
    # applies on CUDA, matching the old CUDA-only context manager.
    device_type = torch.device(device).type
    use_amp = AMP_ENABLED and device_type == "cuda"
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Train Epoch {epoch}", leave=False)
    for imgs, masks in pbar:
        imgs = imgs.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device_type, enabled=use_amp):
            outputs = model(imgs)
            loss, loss_dict = loss_fn(outputs, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        for k in running:
            running[k] += loss_dict.get(k, 0.0)
        num_batches += 1

    return {k: v / max(num_batches, 1) for k, v in running.items()}


def eval_one_epoch(model, dataloader, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        (loss_stats, metric_stats, iou_per_class, acc_per_class):
        loss_stats: mean over batches of "loss_total", "loss_main", "loss_aux"
            and "loss_boundary" (0.0 each for an empty loader, which used to
            raise ZeroDivisionError).
        metric_stats: mIoU / pixel_acc / mAcc / FWIoU from the confusion
            matrix accumulated over all batches, plus "BF1", the mean of the
            per-batch boundary F1 scores.
        iou_per_class, acc_per_class: per-class arrays from
            metrics_from_confusion_matrix.
    """
    model.eval()
    running = {"loss_total": 0.0, "loss_main": 0.0, "loss_aux": 0.0, "loss_boundary": 0.0}
    cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
    bf1_values = []

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    device_type = torch.device(device).type
    use_amp = AMP_ENABLED and device_type == "cuda"
    num_batches = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validation", leave=False)
        for imgs, masks in pbar:
            imgs = imgs.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)

            with torch.autocast(device_type=device_type, enabled=use_amp):
                outputs = model(imgs)
                _, loss_dict = loss_fn(outputs, masks)

            for k in running:
                running[k] += loss_dict.get(k, 0.0)
            num_batches += 1

            preds = torch.argmax(outputs["logits"], dim=1)
            cm += compute_confusion_matrix(preds, masks, NUM_CLASSES, IGNORE_INDEX)
            bf1_values.append(boundary_f1(preds, masks, IGNORE_INDEX))

    denom = max(num_batches, 1)
    loss_stats = {k: v / denom for k, v in running.items()}
    metric_stats, iou_per_class, acc_per_class = metrics_from_confusion_matrix(cm, IGNORE_INDEX)
    metric_stats["BF1"] = float(np.mean(bf1_values)) if bf1_values else 0.0
    return loss_stats, metric_stats, iou_per_class, acc_per_class


def enable_mc_dropout(model: nn.Module):
    """Put every nn.Dropout / nn.Dropout2d submodule of ``model`` back into
    training mode so dropout stays active at inference (MC dropout). Other
    modules keep their current mode."""
    dropout_types = (nn.Dropout, nn.Dropout2d)
    for layer in filter(lambda mod: isinstance(mod, dropout_types), model.modules()):
        layer.train()


def mc_dropout_predict(model: nn.Module,
                       imgs: torch.Tensor,
                       mc_passes: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Monte-Carlo-dropout prediction: average the softmax over stochastic passes.

    The model is put in eval mode and its Dropout / Dropout2d layers are
    switched back to train mode. Then ``mc_passes`` forward passes run without
    gradients.

    NOTE: the model is left in that mixed state on return. Call
    ``model.eval()`` before any deterministic inference.

    Args:
        model: network whose output dict has a "logits" entry [B, C, H, W].
        imgs: input batch. It is not moved here, so it must already be on the
            model's device.
        mc_passes: number of stochastic passes (>= 1).

    Returns:
        (probs_mean, entropy): [B, C, H, W] mean class probabilities and the
        [B, 1, H, W] predictive entropy (natural log) of that mean.

    Raises:
        ValueError: if mc_passes < 1 (previously torch.stack failed on an empty list).
    """
    if mc_passes < 1:
        raise ValueError(f"mc_passes must be >= 1, got {mc_passes}")

    model.eval()
    enable_mc_dropout(model)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
    device_type = imgs.device.type
    use_amp = AMP_ENABLED and device_type == "cuda"

    # Keep a running sum instead of stacking every pass, so peak memory no
    # longer grows with mc_passes ([B,C,H,W] instead of [T,B,C,H,W]).
    probs_sum = None
    with torch.no_grad():
        for _ in range(mc_passes):
            with torch.autocast(device_type=device_type, enabled=use_amp):
                probs = F.softmax(model(imgs)["logits"], dim=1)
            probs_sum = probs if probs_sum is None else probs_sum + probs

    probs_mean = probs_sum / mc_passes
    entropy = -(probs_mean * probs_mean.clamp(min=1e-8).log()).sum(dim=1, keepdim=True)
    return probs_mean, entropy



def decode_segmap(mask: np.ndarray, color_map=None) -> np.ndarray:
    """Map integer class labels to RGB colours.

    Args:
        mask: integer label array of any shape (typically [H, W]). This used to
            require exactly 2-D input.
        color_map: optional [K, 3] palette indexed by label. Defaults to the
            module-level COLOR_MAP.

    Returns:
        Array of shape ``mask.shape + (3,)`` with the palette's dtype
        (uint8 for COLOR_MAP).

    Raises:
        IndexError: if a label is >= the number of palette rows.
    """
    palette = COLOR_MAP if color_map is None else color_map
    # Fancy indexing maps every label to its palette row in one vectorised step.
    return palette[np.asarray(mask)]


def save_test_visuals(image_tensor: torch.Tensor,
                      pred_mask: torch.Tensor,
                      entropy_map: torch.Tensor,
                      save_prefix: str):
    """Save a 3-panel figure and a colour-coded mask PNG for one test image.

    Writes ``<save_prefix>_viz.png`` (image | prediction overlay | uncertainty)
    and ``<save_prefix>_mask.png`` (RGB-decoded prediction).

    Args:
        image_tensor: [3, H, W] ImageNet-normalised image; de-normalised here
            for display.
        pred_mask: [H, W] integer class labels.
        entropy_map: per-pixel uncertainty, either [H, W] or with singleton
            dims (e.g. a [1, H, W] slice of mc_dropout_predict's keepdim
            output). Singleton dims are squeezed away.
        save_prefix: path prefix for both output files.
    """
    # Undo the ImageNet normalisation applied by the dataset transforms.
    img_np = np.transpose(image_tensor.cpu().numpy(), (1, 2, 0))
    img_np = (img_np * np.array([0.229, 0.224, 0.225]) +
              np.array([0.485, 0.456, 0.406]))
    img_np = np.clip(img_np, 0, 1)

    mask_np = pred_mask.cpu().numpy()
    entropy_np = np.squeeze(entropy_map.cpu().numpy())

    color_mask = decode_segmap(mask_np)   # decoded once, reused for both outputs
    overlay = np.clip(0.6 * img_np + 0.4 * (color_mask.astype(np.float32) / 255.0), 0, 1)

    # Min-max normalise the entropy to [0, 1] for display.
    ent_min, ent_max = entropy_np.min(), entropy_np.max()
    if ent_max > ent_min:
        ent_vis = (entropy_np - ent_min) / (ent_max - ent_min)
    else:
        ent_vis = np.zeros_like(entropy_np)

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        axes[0].imshow(img_np)
        axes[0].set_title("Image")
        axes[1].imshow(overlay)
        axes[1].set_title("Prediction Overlay")
        # Keep the entropy 2-D so `cmap` is honoured. imshow ignores cmap for
        # RGB input, so the old 3-channel stack rendered grey instead of magma.
        axes[2].imshow(ent_vis, cmap="magma", vmin=0.0, vmax=1.0)
        axes[2].set_title("Uncertainty")
        for ax in axes:
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(save_prefix + "_viz.png", dpi=150)
    finally:
        # Close even if saving fails so figures don't accumulate across calls.
        plt.close(fig)

    Image.fromarray(color_mask).save(save_prefix + "_mask.png")


def per_class_dice(cm: np.ndarray, ignore_index: int):
    """Per-class Dice (F1) from a [C, C] confusion matrix (rows = GT, cols = pred).

    Dice_c = 2*TP / (2*TP + FP + FN), which equals 2*TP / (pred_total + gt_total).
    Classes with an empty denominator get 0. The ``ignore_index`` entry is
    removed from the returned array.
    """
    true_pos = np.diag(cm)
    pred_total = cm.sum(axis=0)    # TP + FP per class
    gt_total = cm.sum(axis=1)      # TP + FN per class
    dice = 2 * true_pos / np.maximum(pred_total + gt_total, 1e-6)

    keep = np.arange(dice.shape[0]) != ignore_index
    return dice[keep]


def cohens_kappa(cm: np.ndarray):
    """Cohen's kappa between ground truth (rows) and predictions (columns).

    kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed accuracy and
    p_e is the chance agreement implied by the row/column marginals.

    Returns 0.0 for degenerate inputs where kappa is undefined: an empty
    matrix, or p_e == 1 (a single class on both sides).

    Previous behaviour, now fixed:
        - 1e-6 was always added to the denominator, biasing every result
          (perfect agreement came out slightly below 1).
        - An empty matrix returned NaN.
        - Marginal products were taken in int64, which can silently overflow
          for large pixel counts. They are now computed in float64.
    """
    cm = np.asarray(cm, dtype=np.float64)
    total = cm.sum()
    if total == 0:
        return 0.0

    po = np.trace(cm) / total                                   # observed accuracy
    pe = np.sum(cm.sum(axis=1) * cm.sum(axis=0)) / (total * total)  # chance agreement
    if pe >= 1.0:
        return 0.0
    return float((po - pe) / (1.0 - pe))


def boundary_iou(pred: torch.Tensor, target: torch.Tensor, ignore_index: int):
    """IoU between the label-boundary maps of ``pred`` and ``target``.

    Accepts [H, W] or [B, H, W] integer label tensors of the same shape; any
    other rank raises ValueError. A pixel is on a boundary when its 3x3
    neighbourhood (clipped at the image border) holds more than one label.
    Pixels where ``target == ignore_index`` are set to 0 in both inputs first.
    Counts are pooled over the whole batch. Returns 0.0 when neither map has
    any boundary pixel.
    """
    def _edges(labels: torch.Tensor) -> torch.Tensor:
        # -> bool boundary map [B, 1, H, W]
        if labels.dim() == 2:
            x = labels[None, None]
        elif labels.dim() == 3:
            x = labels[:, None]
        else:
            raise ValueError("mask must be 2D or 3D")
        x = x.float()
        hi = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        lo = -F.max_pool2d(-x, kernel_size=3, stride=1, padding=1)
        return hi != lo

    with torch.no_grad():
        t = target.clone()
        p = pred.clone()
        t[t == ignore_index] = 0
        p[target == ignore_index] = 0

        gt_edges = _edges(t)
        pred_edges = _edges(p)

        union = (gt_edges | pred_edges).sum().item()
        if union == 0:
            return 0.0
        inter = (gt_edges & pred_edges).sum().item()
        return float(inter / union)



def expected_calibration_error(probs: torch.Tensor, targets: torch.Tensor, n_bins=15,
                               ignore_index=None):
    """Expected calibration error (ECE) of per-pixel top-1 predictions.

    Args:
        probs: [B, C, H, W] class probabilities.
        targets: [B, H, W] integer labels.
        n_bins: number of equal-width confidence bins over [0, 1].
        ignore_index: label to exclude. Defaults to the module-level IGNORE_INDEX.

    Returns:
        The sum over bins of |accuracy - mean confidence|, weighted by each
        bin's share of the kept pixels. Returns 0.0 when no pixels are kept.

    Bins are [l, r) except the last, which is [l, 1.0]. Previously a
    confidence of exactly 1.0 fell into no bin, so those pixels were dropped
    from the numerator while still counted in the denominator.
    """
    if ignore_index is None:
        ignore_index = IGNORE_INDEX

    with torch.no_grad():
        conf, pred = probs.max(dim=1)  # [B,H,W]

        pred = pred.reshape(-1).cpu().numpy()
        # .float() so half/bfloat16 probabilities convert to NumPy
        # (bfloat16 has no NumPy dtype).
        conf = conf.reshape(-1).float().cpu().numpy()
        labels = targets.reshape(-1).cpu().numpy()

        keep = labels != ignore_index
        pred, conf, labels = pred[keep], conf[keep], labels[keep]
        n = conf.size
        if n == 0:
            return 0.0

        edges = np.linspace(0.0, 1.0, n_bins + 1)
        ece = 0.0
        for i in range(n_bins):
            lo, hi = edges[i], edges[i + 1]
            upper_ok = (conf <= hi) if i == n_bins - 1 else (conf < hi)
            in_bin = (conf >= lo) & upper_ok
            count = int(in_bin.sum())
            if count == 0:
                continue
            acc_bin = np.mean(pred[in_bin] == labels[in_bin])
            conf_bin = np.mean(conf[in_bin])
            ece += abs(acc_bin - conf_bin) * (count / n)

        return float(ece)


def plot_confusion_matrix(cm, class_names, save_path, normalize=True):
    """Render ``cm`` as an annotated seaborn heatmap and save it to ``save_path``.

    Rows are ground truth and columns are predictions. With ``normalize=True``
    each row is divided by its total (all-zero rows stay 0), and cells are
    printed with two decimals. Otherwise the raw counts are printed as
    integers.
    """
    data = cm
    if normalize:
        data = data.astype(np.float32)
        row_totals = np.maximum(data.sum(axis=1, keepdims=True), 1e-6)
        data = data / row_totals

    cell_fmt = ".2f" if normalize else "d"
    title_suffix = " (Normalized)" if normalize else ""

    plt.figure(figsize=(8, 6))
    sns.heatmap(data, annot=True, fmt=cell_fmt, cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("Ground Truth")
    plt.title("Confusion Matrix" + title_suffix)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()


def _compute_model_complexity(model, img_size, device):
    """Profile one forward pass with thop; return (GFLOPs, parameters in millions).

    NOTE(review): thop's documentation describes the operation count returned
    by ``profile`` as MACs; it is reported under the "GFLOPs" label by main().
    """
    from thop import profile  # optional dependency, imported only where used

    model.eval()
    dummy = torch.randn(1, 3, img_size, img_size).to(device)
    flops, params = profile(model, inputs=(dummy,), verbose=False)
    return flops / 1e9, params / 1e6


def _load_best_checkpoint(model):
    """Load the weights saved at BEST_MODEL_PATH into ``model`` and set eval mode.

    Returns True if a checkpoint was loaded. If no checkpoint file exists (e.g.
    validation mIoU never exceeded its initial value), the current in-memory
    weights are kept and False is returned instead of crashing.
    """
    if not os.path.exists(BEST_MODEL_PATH):
        print(f"[WARN] No checkpoint found at {BEST_MODEL_PATH}; using current weights.")
        model.eval()
        return False

    ckpt = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.to(DEVICE)
    model.eval()
    return True


def _collect_final_val_stats(model, val_loader):
    """Run ``model`` over ``val_loader`` and gather final-evaluation statistics.

    Returns:
        (confusion matrix [NUM_CLASSES, NUM_CLASSES] int64,
         list of per-image Boundary IoU values,
         list of per-batch ECE values)
    """
    cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
    boundary_ious = []
    eces = []

    model.eval()
    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc="Final Metrics"):
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)

            with torch.cuda.amp.autocast(enabled=AMP_ENABLED):
                logits = model(imgs)["logits"]
                probs = F.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)

            cm += compute_confusion_matrix(preds, masks, NUM_CLASSES, IGNORE_INDEX)

            # One device->host copy per batch, then per-image Boundary IoU.
            preds_cpu = preds.cpu()
            masks_cpu = masks.cpu()
            for b in range(imgs.size(0)):
                boundary_ious.append(boundary_iou(preds_cpu[b], masks_cpu[b], IGNORE_INDEX))

            # ECE is computed per batch; main() averages batches with equal weight.
            eces.append(expected_calibration_error(probs, masks))

    return cm, boundary_ious, eces


def _summarize_confusion(cm):
    """Return (per-class IoU, per-class Dice) arrays of length NUM_CLASSES from ``cm``.

    Dice at the ignored index 0 is NaN.
    """
    tp = np.diag(cm)
    union = cm.sum(axis=1) + cm.sum(axis=0) - tp
    iou_per_class = tp / np.maximum(union, 1e-6)

    # per_class_dice's result is read at position cls - 1 for classes
    # 1..NUM_CLASSES-1, i.e. it presumably excludes the ignored class 0
    # (this mapping assumes IGNORE_INDEX == 0 — TODO confirm).
    dice_valid = per_class_dice(cm, IGNORE_INDEX)
    dice_for_classes = np.full(NUM_CLASSES, np.nan)
    for cls in range(1, NUM_CLASSES):
        dice_for_classes[cls] = dice_valid[cls - 1]

    return iou_per_class, dice_for_classes


def _run_test_inference(model, test_loader):
    """Predict every test image, estimate MC-dropout uncertainty and save visuals for each."""
    print(f"Running test inference on {len(test_loader.dataset)} images...")
    with torch.no_grad():
        for imgs, names in tqdm(test_loader, desc="Test Inference"):
            imgs = imgs.to(DEVICE, non_blocking=True)

            # mc_dropout_predict (defined elsewhere) may switch dropout layers to
            # train mode; re-assert eval so the deterministic prediction stays so.
            model.eval()
            with torch.cuda.amp.autocast(enabled=AMP_ENABLED):
                preds = torch.argmax(model(imgs)["logits"], dim=1)

            _, entropy = mc_dropout_predict(model, imgs, mc_passes=5)

            # Save visuals for every image in the batch, not just the first.
            for b, name in enumerate(names):
                save_test_visuals(
                    image_tensor=imgs[b].cpu(),
                    pred_mask=preds[b].cpu(),
                    entropy_map=entropy[b, 0].cpu(),
                    save_prefix=os.path.join(TEST_OUT_DIR, os.path.splitext(name)[0]),
                )

    print(f"Test predictions + overlays + uncertainty saved in: {TEST_OUT_DIR}")


def main():
    """Train, evaluate the best checkpoint and run test inference.

    Pipeline:
      1. Build datasets/loaders and the model; report parameter/op counts and
         encoder feature shapes.
      2. Optionally resume from BEST_MODEL_PATH, then train with AdamW + cosine
         LR, checkpointing whenever validation mIoU improves and stopping early
         after EARLY_STOP_PATIENCE epochs without improvement.
      3. Plot loss/mIoU curves.
      4. Reload the best checkpoint, compute final validation metrics (IoU,
         Dice, Cohen's kappa, Boundary IoU, ECE, confusion matrix) and write
         them to disk.
      5. Save test-set predictions and uncertainty maps for every test image.
    """
    print("Loading datasets...")

    train_ds = EarthVQADataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR, img_size=IMG_SIZE)
    val_ds = EarthVQADataset(VAL_IMG_DIR, VAL_MASK_DIR, img_size=IMG_SIZE)
    test_ds = EarthVQATestDataset(TEST_IMG_DIR, img_size=IMG_SIZE)

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        drop_last=True,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=8,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )

    print("Building model (ConvNeXt-Tiny backbone)...")
    model = SegmentationModel(num_classes=NUM_CLASSES).to(DEVICE)

    gflops, params_m = _compute_model_complexity(model, IMG_SIZE, DEVICE)
    print("\n================ MODEL COMPLEXITY ================")
    print(f"Parameters (M): {params_m:.2f}")
    print(f"GFLOPs @ {IMG_SIZE}x{IMG_SIZE}: {gflops:.2f}")
    print("==================================================")

    # Sanity check: print the shapes of the four encoder feature maps.
    with torch.no_grad():
        x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE)
        e1, e2, e3, e4 = model.encoder(x)
        print("ENCODER SHAPES:")
        print("e1:", e1.shape)
        print("e2:", e2.shape)
        print("e3:", e3.shape)
        print("e4:", e4.shape)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    scaler = torch.cuda.amp.GradScaler(enabled=AMP_ENABLED)

    loss_fn = CombinedLoss(
        num_classes=NUM_CLASSES,
        ignore_index=IGNORE_INDEX,
        lambda_main=LAMBDA_MAIN,
        lambda_aux=LAMBDA_AUX,
        lambda_boundary=LAMBDA_BOUNDARY,
    )

    start_epoch = 1
    best_mIoU = 0.0
    no_improve_epochs = 0

    if RESUME_TRAINING and os.path.exists(BEST_MODEL_PATH):
        print(f"\n[INFO] Resuming training from: {BEST_MODEL_PATH}")

        ckpt = torch.load(BEST_MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        scheduler.load_state_dict(ckpt["scheduler_state"])
        if AMP_ENABLED and ckpt.get("scaler_state") is not None:
            scaler.load_state_dict(ckpt["scaler_state"])

        start_epoch = ckpt["epoch"] + 1
        best_mIoU = ckpt["best_mIoU"]
        no_improve_epochs = ckpt.get("no_improve_epochs", 0)

        print(f"[INFO] Last epoch finished : {ckpt['epoch']}")
        print(f"[INFO] Resuming from epoch : {start_epoch}")
        print(f"[INFO] Best mIoU so far    : {best_mIoU:.4f}")

    epochs_run = []
    train_losses = []
    val_losses = []
    val_mious = []

    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        print(f"\n=== Epoch {epoch}/{NUM_EPOCHS} ===")

        train_stats = train_one_epoch(
            model, train_loader, optimizer, loss_fn, epoch, DEVICE, scaler
        )
        val_loss_stats, val_metrics, _, _ = eval_one_epoch(
            model, val_loader, loss_fn, DEVICE
        )
        scheduler.step()

        epochs_run.append(epoch)
        train_losses.append(train_stats["loss_total"])
        val_losses.append(val_loss_stats["loss_total"])
        val_mious.append(val_metrics["mIoU"])

        print(f"Train Loss : {train_stats['loss_total']:.4f}")
        print(
            f"Val Loss   : {val_loss_stats['loss_total']:.4f} | "
            f"mIoU: {val_metrics['mIoU']:.4f} | "
            f"PA: {val_metrics['pixel_acc']:.4f} | "
            f"BF1: {val_metrics['BF1']:.4f}"
        )

        mIoU = val_metrics["mIoU"]
        if mIoU > best_mIoU:
            best_mIoU = mIoU
            no_improve_epochs = 0
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "scaler_state": scaler.state_dict() if AMP_ENABLED else None,
                "best_mIoU": best_mIoU,
                "no_improve_epochs": no_improve_epochs,
            }, BEST_MODEL_PATH)
            print(f"[INFO] New best mIoU: {best_mIoU:.4f} → checkpoint saved")
        else:
            no_improve_epochs += 1
            print(f"[INFO] No improvement for {no_improve_epochs} epoch(s)")

        if no_improve_epochs >= EARLY_STOP_PATIENCE:
            print("[INFO] Early stopping triggered.")
            break

    print(f"\nTraining finished. Best mIoU: {best_mIoU:.4f}")

    # Plot against the real epoch numbers so resumed runs are labelled correctly.
    plt.figure(figsize=(8, 5))
    plt.plot(epochs_run, train_losses, label="Train Loss")
    plt.plot(epochs_run, val_losses, label="Val Loss")
    plt.plot(epochs_run, val_mious, label="Val mIoU")
    plt.xlabel("Epoch")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(PLOT_PATH, dpi=150)
    plt.close()

    print(f"Training curves saved to {PLOT_PATH}")

    # Final evaluation must use the best checkpoint, not the last-epoch weights
    # (after early stopping the in-memory model is the worse one).
    print("\nLoading best model for final evaluation and TEST inference...")
    _load_best_checkpoint(model)

    print("\nRunning FINAL METRICS on VALIDATION SET...")
    cm_final, boundary_iou_list, ece_list = _collect_final_val_stats(model, val_loader)

    iou_per_class, dice_for_classes = _summarize_confusion(cm_final)
    kappa_final = cohens_kappa(cm_final)
    boundary_iou_final = float(np.mean(boundary_iou_list)) if boundary_iou_list else 0.0
    ece_final = float(np.mean(ece_list)) if ece_list else 0.0

    print("\n================ FINAL METRICS ================")
    print("Per-class IoU:", iou_per_class)
    print("Per-class Dice:", dice_for_classes)
    print("Cohen's Kappa:", kappa_final)
    print("Boundary IoU:", boundary_iou_final)
    print("ECE:", ece_final)
    print("==============================================")

    # Index i names label i (NUM_CLASSES entries, 0 = ignored "no-data").
    CLASS_NAMES = [
        "no-data (ignored)",  # 0
        "Background",         # 1
        "Building",           # 2
        "Road",               # 3
        "Water",              # 4
        "Barren",             # 5
        "Forest",             # 6
        "Agriculture",        # 7
    ]

    cm_path = os.path.join(WORK_DIR, "confusion_matrix.png")
    plot_confusion_matrix(
        cm_final,
        class_names=CLASS_NAMES,
        save_path=cm_path,
        normalize=True,
    )
    print(f"Confusion matrix saved to: {cm_path}")

    import pandas as pd

    df = pd.DataFrame({
        "class": list(range(NUM_CLASSES)),
        "IoU": iou_per_class,
        "Dice": dice_for_classes,
    })
    df2 = pd.DataFrame({
        "Metric": ["Kappa", "BoundaryIoU", "ECE"],
        "Value": [kappa_final, boundary_iou_final, ece_final],
    })

    final_csv = os.path.join(WORK_DIR, "final_metrics.csv")
    # pandas recommends newline="" when to_csv writes to a text file handle.
    with open(final_csv, "w", newline="", encoding="utf-8") as f:
        df.to_csv(f, index=False)
        f.write("\n")
        df2.to_csv(f, index=False)

    print(f"\nFinal metrics saved to: {final_csv}")

    _run_test_inference(model, test_loader)


# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
