# In [None]:


import os
import random
from glob import glob
from pathlib import Path
from tqdm import tqdm

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import albumentations as A

from sklearn.metrics import confusion_matrix

# -------------------------
# Config / Hyperparameters (Kaggle-friendly)
# -------------------------
# Global RNG seed shared by the reproducibility helper below.
SEED = 42

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Root of the attached dataset under /kaggle/input.
DATA_ROOT = "/kaggle/input/tuc-cd-new/TUC-CD_new/TUE_new"   # <<-- change if needed

# Sub-folders of DATA_ROOT: "A_new", "B_new" and the label masks.
A_DIR, B_DIR, MASK_DIR = (
    os.path.join(DATA_ROOT, sub_dir) for sub_dir in ("A_new", "B_new", "label_new")
)

# Artifacts are written to the working directory so they show up in Outputs.
BEST_MODEL_PATH = "/kaggle/working/best_unet4_model_new.pth"
BEST_CHECKPOINT = "/kaggle/working/best_unet4_checkpoint_new.pth"  # optional

# Glob patterns used to discover image files.
IMG_EXT = ["*.png", "*.jpg", "*.jpeg", "*.tif"]

# Target (height, width) fed to the network; a safer default on Kaggle.
IMG_SIZE = (256, 256)

# Optimisation schedule.
BATCH_SIZE = 8
EPOCHS = 200
EARLY_STOPPING_PATIENCE = 10
LR = 1e-4
WEIGHT_DECAY = 1e-5

# DataLoader settings for Kaggle (Linux).
NUM_WORKERS = 2
PIN_MEMORY = True

# -------------------------
# Reproducibility
# -------------------------
def seed_everything(seed=SEED):
    """Seed every random number generator this script draws from.

    Args:
        seed: Integer passed, in order, to Python's ``random``, NumPy's
            global generator, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed once at import time so the shuffle/split below is repeatable.
seed_everything()

# -------------------------
# Helper: find files (assumes aligned filenames across folders)
# -------------------------
def list_images(dir_path, patterns=None):
    """Return the sorted paths of all image files directly inside a directory.

    Args:
        dir_path: Directory to search (not recursive).
        patterns: Iterable of glob patterns such as ``"*.png"``. When omitted,
            the module-level ``IMG_EXT`` list is used.

    Returns:
        A sorted list of matching paths. A file matched by more than one
        pattern is listed only once.
    """
    if patterns is None:
        patterns = IMG_EXT
    matches = set()
    for pattern in patterns:
        matches.update(glob(os.path.join(dir_path, pattern)))
    return sorted(matches)

pre_list = list_images(A_DIR)
post_list = list_images(B_DIR)
mask_list = list_images(MASK_DIR)

if not (len(pre_list) == len(post_list) == len(mask_list)):
    raise ValueError(f"Unequal counts: pre {len(pre_list)}, post {len(post_list)}, mask {len(mask_list)}")

# Fail here with a clear message instead of later, when a DataLoader is
# built on an empty dataset.
if not pre_list:
    raise ValueError(f"No images found under {DATA_ROOT} (patterns: {IMG_EXT})")

# Triplets are matched purely by sorted position, so equal counts alone do not
# guarantee that the three files of a triplet belong together.
pairs = list(zip(pre_list, post_list, mask_list))

# Sanity check: flag triplets whose file stems differ. This only warns, since
# some datasets legitimately use different names per folder.
_misaligned = [
    triplet for triplet in pairs
    if len({Path(path).stem for path in triplet}) != 1
]
if _misaligned:
    print(
        f"WARNING: {len(_misaligned)} of {len(pairs)} triplets have differing file stems "
        f"(first: {[Path(path).name for path in _misaligned[0]]}); verify the pairing."
    )

print(f"Found {len(pairs)} triplets")

# -------------------------
# Albumentations Transforms (use 3-channel mean/std per image)
# -------------------------
# ImageNet per-channel statistics; the same values normalise both images.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

_target_h, _target_w = IMG_SIZE

# Extra inputs accepted by both pipelines: "post_img" is handled like the
# main "image", and "mask" like a mask.
_train_targets = {"post_img": "image", "mask": "mask"}
_val_targets = {"post_img": "image", "mask": "mask"}

# Training pipeline: resize, geometric and photometric augmentation, then
# normalisation as the final step.
_train_steps = [
    A.Resize(_target_h, _target_w),
    # NOTE: the crop size equals the resize size above, so the full frame is kept.
    A.RandomCrop(_target_h, _target_w),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.Rotate(limit=20, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=0, p=0.3),
    A.RandomBrightnessContrast(p=0.5),
    A.GaussianBlur(blur_limit=(3, 5), p=0.2),
    # Normalize is applied to both 'image' (pre) and 'post_img'.
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
]
train_transform = A.Compose(_train_steps, additional_targets=_train_targets)

# Validation/test pipeline: deterministic resize followed by normalisation.
_val_steps = [
    A.Resize(_target_h, _target_w),
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
]
val_transform = A.Compose(_val_steps, additional_targets=_val_targets)

# -------------------------
# Dataset
# -------------------------
class ChangeDetectionDataset(Dataset):
    """Dataset of (pre image, post image, change mask) triplets.

    Each item is a pair ``(img, mask)`` where ``img`` is the pre and post RGB
    images stacked along the channel axis (6 channels, channel-first) and
    ``mask`` is a single-channel float tensor holding 0/1 labels.

    Args:
        triplets: Sequence of ``(pre_path, post_path, mask_path)`` tuples.
        transform: Optional albumentations pipeline called as
            ``transform(image=..., post_img=..., mask=...)``. When falsy, the
            images are only resized to ``IMG_SIZE``.
        photometric_independent: When true, a colour augmentation is sampled
            separately for the pre and the post image.
    """

    def __init__(self, triplets, transform=None, photometric_independent=False):
        self.triplets = triplets
        self.transform = transform
        self.photometric_independent = photometric_independent
        # Colour-only augmentation, drawn independently per image.
        self.color_aug = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.3),
        ])

    def __len__(self):
        return len(self.triplets)

    @staticmethod
    def _load(path, mode):
        """Read an image file into a NumPy array, closing the file afterwards."""
        with Image.open(path) as handle:
            return np.array(handle.convert(mode))

    @staticmethod
    def _to_tensors(pre, post, mask):
        """Pack H,W,3 pre/post arrays and an H,W mask into (6,H,W) and (1,H,W) float tensors."""
        # Arrays that are still uint8 have not been normalised; scale to [0, 1].
        if pre.dtype == np.uint8:
            pre = pre.astype(np.float32) / 255.0
        if post.dtype == np.uint8:
            post = post.astype(np.float32) / 255.0

        stacked = np.concatenate([pre, post], axis=2)  # H,W,6
        img = torch.from_numpy(stacked).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask).unsqueeze(0).float()  # 1,H,W
        return img, mask

    def __getitem__(self, idx):
        pre_path, post_path, mask_path = self.triplets[idx]
        pre = self._load(pre_path, "RGB")
        post = self._load(post_path, "RGB")
        mask = self._load(mask_path, "L")

        # Binarise the label: grey values above 127 count as "change".
        # NOTE(review): assumes masks are stored as 0/255 - confirm for this dataset.
        mask = (mask > 127).astype("uint8")

        # The colour augmentation must see raw uint8 RGB. It therefore runs
        # BEFORE self.transform, whose pipelines in this file end with
        # A.Normalize and so hand back mean/std-normalised float arrays.
        if self.photometric_independent:
            pre = self.color_aug(image=pre)["image"]
            post = self.color_aug(image=post)["image"]

        if self.transform:
            augmented = self.transform(image=pre, post_img=post, mask=mask)
            pre = augmented["image"]
            post = augmented["post_img"]
            mask = augmented["mask"]
        else:
            # PIL expects (width, height); IMG_SIZE is (height, width).
            size = IMG_SIZE[::-1]
            pre = np.array(Image.fromarray(pre).resize(size))
            post = np.array(Image.fromarray(post).resize(size))
            # Nearest-neighbour keeps the label map strictly binary.
            mask = np.array(Image.fromarray(mask).resize(size, resample=Image.NEAREST))

        return self._to_tensors(pre, post, mask)

# -------------------------
# Create datasets and splits (60/20/20 random)
# -------------------------
# Shuffle in place, then cut the list into 60% train / 20% val / remainder test.
random.shuffle(pairs)
n_total = len(pairs)
n_train = int(0.6 * n_total)
n_val = int(0.2 * n_total)
n_test = n_total - n_train - n_val

_val_end = n_train + n_val
train_pairs, val_pairs, test_pairs = (
    pairs[:n_train],
    pairs[n_train:_val_end],
    pairs[_val_end:],
)

print(f"Split: train {len(train_pairs)}, val {len(val_pairs)}, test {len(test_pairs)}")

# Only the training set is augmented; val and test share the deterministic pipeline.
train_ds = ChangeDetectionDataset(train_pairs, transform=train_transform, photometric_independent=True)
val_ds, test_ds = (
    ChangeDetectionDataset(subset, transform=val_transform, photometric_independent=False)
    for subset in (val_pairs, test_pairs)
)

# Options common to all three loaders; only the training loader shuffles.
_loader_options = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    "pin_memory": PIN_MEMORY,
}
train_loader = DataLoader(train_ds, shuffle=True, **_loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_options)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_options)

# -------------------------
# U-Net (custom, 4 encoders + bottleneck + 4 decoders)
# -------------------------
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    The convolutions use padding 1 and no bias, so the spatial size is
    preserved and only the channel count changes (``in_ch`` -> ``out_ch``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.double_conv(x)

class UNet4(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Args:
        in_channels: Channels of the input tensor (6 = stacked pre + post RGB).
        out_channels: Channels of the output logit map.
        features: Channel widths of the four encoder stages; the bottleneck
            uses twice the last width. Any indexable sequence works.

    The forward pass returns raw logits (no sigmoid) with the same spatial
    size as the input.
    """

    # The default is an immutable tuple so it cannot be shared and mutated
    # between instances.
    def __init__(self, in_channels=6, out_channels=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.enc1 = DoubleConv(in_channels, features[0])
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(features[0], features[1])
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(features[1], features[2])
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(features[2], features[3])
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(features[3], features[3]*2)

        # Each decoder stage upsamples x2, concatenates the matching encoder
        # output (hence the doubled input channels) and refines it.
        self.up4 = nn.ConvTranspose2d(features[3]*2, features[3], kernel_size=2, stride=2)
        self.dec4 = DoubleConv(features[3]*2, features[3])
        self.up3 = nn.ConvTranspose2d(features[3], features[2], kernel_size=2, stride=2)
        self.dec3 = DoubleConv(features[2]*2, features[2])
        self.up2 = nn.ConvTranspose2d(features[2], features[1], kernel_size=2, stride=2)
        self.dec2 = DoubleConv(features[1]*2, features[1])
        self.up1 = nn.ConvTranspose2d(features[1], features[0], kernel_size=2, stride=2)
        self.dec1 = DoubleConv(features[0]*2, features[0])

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` so its height/width equal those of ``skip``.

        Max-pooling floors odd sizes, so after the x2 upsampling a feature map
        can be one pixel smaller than its skip connection. When the sizes
        already agree the tensor is returned untouched.
        """
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h or diff_w:
            # F.pad order for the last two dims: (left, right, top, bottom).
            upsampled = nn.functional.pad(
                upsampled,
                [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
            )
        return upsampled

    def forward(self, x):
        """Map an ``(N, in_channels, H, W)`` batch to ``(N, out_channels, H, W)`` logits."""
        # Encoder path; e1..e4 are kept for the skip connections.
        e1 = self.enc1(x); p1 = self.pool1(e1)
        e2 = self.enc2(p1); p2 = self.pool2(e2)
        e3 = self.enc3(p2); p3 = self.pool3(e3)
        e4 = self.enc4(p3); p4 = self.pool4(e4)
        b = self.bottleneck(p4)
        # Decoder path: upsample, align with the skip tensor, concatenate, refine.
        d4 = self._match_size(self.up4(b), e4)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self._match_size(self.up3(d4), e3)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self._match_size(self.up2(d3), e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self._match_size(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        out = self.final_conv(d1)
        return out

# -------------------------
# Losses & Metrics
# -------------------------
def dice_coef(pred, target, smooth=1e-6):
    """Return the mean soft Dice coefficient of a batch.

    Args:
        pred: 4-D tensor of predictions; sums are taken over dims 2 and 3.
        target: Tensor of the same shape as ``pred``.
        smooth: Constant added to numerator and denominator so that two empty
            maps score 1 instead of dividing by zero.

    Returns:
        Scalar tensor: the Dice score averaged over the remaining dimensions.
    """
    spatial_dims = (2, 3)
    pred, target = pred.contiguous(), target.contiguous()
    overlap = torch.sum(pred * target, dim=spatial_dims)
    total = torch.sum(pred, dim=spatial_dims) + torch.sum(target, dim=spatial_dims)
    per_map = (2. * overlap + smooth) / (total + smooth)
    return torch.mean(per_map)

class DiceLoss(nn.Module):
    """Loss module computing ``1 - dice_coef(pred, target)``.

    Args:
        smooth: Smoothing constant forwarded to ``dice_coef``.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return one minus the batch-mean Dice coefficient."""
        return 1.0 - dice_coef(pred, target, smooth=self.smooth)

# BCE computed from raw logits; shared by every call to combined_loss.
bce_loss = nn.BCEWithLogitsLoss()


def combined_loss(logits, mask):
    """Return BCE-with-logits plus Dice loss for a batch.

    Args:
        logits: Raw model outputs (before sigmoid).
        mask: Target tensor passed to both loss terms.
    """
    # The Dice term is computed on sigmoid probabilities; the BCE term
    # receives the raw logits directly.
    dice_term = DiceLoss()(torch.sigmoid(logits), mask)
    bce_term = bce_loss(logits, mask)
    return bce_term + dice_term

@torch.no_grad()
def compute_metrics_batch(logits, masks, thresh=0.5):
    """Compute pixel-level confusion counts and derived metrics for one batch.

    The counts are computed with tensor operations on the device the inputs
    live on, and only four integers are transferred to the host.

    Args:
        logits: Raw model outputs; a sigmoid is applied before thresholding.
        masks: Ground-truth tensor with the same number of elements as
            ``logits``. Pixels equal to 1 are positives, pixels equal to 0 are
            negatives; any other value is ignored.
        thresh: Probabilities ``>= thresh`` are predicted positive.

    Returns:
        Dict with integer ``tn``/``fp``/``fn``/``tp`` counts and float
        ``iou``, ``dice``, ``precision``, ``recall``, ``f1`` and ``acc``.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    pred_pos = (torch.sigmoid(logits) >= thresh).reshape(-1)
    truth = masks.reshape(-1)
    true_pos = truth == 1
    true_neg = truth == 0

    counts = torch.stack([
        (~pred_pos & true_neg).sum(),
        (pred_pos & true_neg).sum(),
        (~pred_pos & true_pos).sum(),
        (pred_pos & true_pos).sum(),
    ])
    tn, fp, fn, tp = counts.tolist()  # single device-to-host transfer

    eps = 1e-8  # keeps every ratio defined when a denominator would be zero
    iou = tp / (tp + fp + fn + eps)
    dice = (2 * tp) / (2 * tp + fp + fn + eps)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    return {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
            "iou": float(iou), "dice": float(dice), "precision": float(precision),
            "recall": float(recall), "f1": float(f1), "acc": float(acc)}

# -------------------------
# Training / Validation loops
# -------------------------
def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(imgs, masks)`` batches exposing ``.dataset``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    for batch_imgs, batch_masks in tqdm(loader, desc="Train batch"):
        batch_imgs, batch_masks = batch_imgs.to(device), batch_masks.to(device)

        optimizer.zero_grad()
        batch_loss = combined_loss(model(batch_imgs), batch_masks)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the mean.
        loss_sum += batch_loss.item() * batch_imgs.size(0)
    return loss_sum / len(loader.dataset)

@torch.no_grad()
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Confusion counts are summed over all batches first, and the metrics are
    derived from those totals rather than averaged per batch.

    Returns:
        Dict with ``loss``, ``iou``, ``dice``, ``precision``, ``recall``,
        ``f1``, ``acc`` and a 2x2 ``confusion`` array ``[[tn, fp], [fn, tp]]``.
    """
    model.eval()
    count_keys = ("tn", "fp", "fn", "tp")
    totals = dict.fromkeys(count_keys, 0)
    loss_sum = 0.0

    for batch_imgs, batch_masks in tqdm(loader, desc="Val batch"):
        batch_imgs, batch_masks = batch_imgs.to(device), batch_masks.to(device)
        batch_logits = model(batch_imgs)
        loss_sum += combined_loss(batch_logits, batch_masks).item() * batch_imgs.size(0)
        batch_stats = compute_metrics_batch(batch_logits, batch_masks)
        totals = {key: totals[key] + batch_stats[key] for key in count_keys}

    tn, fp, fn, tp = (totals[key] for key in count_keys)
    eps = 1e-8  # avoids division by zero when a class is absent
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return {
        "loss": loss_sum / len(loader.dataset),
        "iou": tp / (tp + fp + fn + eps),
        "dice": (2 * tp) / (2 * tp + fp + fn + eps),
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall + eps),
        "acc": (tp + tn) / (tp + tn + fp + fn + eps),
        "confusion": np.array([[tn, fp], [fn, tp]]),
    }

# -------------------------
# Main training
# -------------------------
def main():
    """Train UNet4 with early stopping, then evaluate on the test split.

    The weights with the lowest validation loss are written to
    ``BEST_MODEL_PATH`` (plus a fuller checkpoint to ``BEST_CHECKPOINT``) and
    reloaded for the final test evaluation. If no epoch of this run produced
    a checkpoint, the final in-memory weights are evaluated instead of
    loading a missing or stale file.

    Returns:
        Tuple ``(history, test_metrics)``: the per-epoch loss/metric lists and
        the dict returned by ``validate`` on the test loader.
    """
    model = UNet4(in_channels=6, out_channels=1).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_val_loss = float("inf")
    epochs_no_improve = 0
    saved_best = False  # becomes True once this run has written BEST_MODEL_PATH
    history = {"train_loss": [], "val_loss": [], "val_iou": [], "val_dice": []}

    for epoch in range(1, EPOCHS+1):
        print(f"\nEpoch {epoch}/{EPOCHS}")
        train_loss = train_one_epoch(model, train_loader, optimizer)
        val_metrics = validate(model, val_loader)

        print(f"Train Loss: {train_loss:.6f}")
        print(f"Val Loss: {val_metrics['loss']:.6f} | IoU: {val_metrics['iou']:.4f} | Dice: {val_metrics['dice']:.4f} | F1: {val_metrics['f1']:.4f}")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_metrics["loss"])
        history["val_iou"].append(val_metrics["iou"])
        history["val_dice"].append(val_metrics["dice"])

        # Early stopping + save best. The 1e-6 margin ignores negligible
        # improvements; a NaN loss never satisfies the comparison.
        if val_metrics["loss"] < best_val_loss - 1e-6:
            best_val_loss = val_metrics["loss"]
            epochs_no_improve = 0
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            # optional: save checkpoint including optimizer
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'best_val_loss': best_val_loss}, BEST_CHECKPOINT)
            saved_best = True
            print(f"Saved best model to {BEST_MODEL_PATH}.")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping triggered after {epoch} epochs (patience={EARLY_STOPPING_PATIENCE})")
            break

    # Load best model for final evaluation on test set, but only if this run
    # actually produced it.
    if saved_best:
        print("Loading best model for test evaluation:", BEST_MODEL_PATH)
        model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))
    else:
        print("No checkpoint was saved during this run; evaluating the final weights instead.")
    model.to(device)
    test_metrics = validate(model, test_loader)

    print("\n--- TEST METRICS ---")
    print(f"Test Loss: {test_metrics['loss']:.6f}")
    print(f"IoU: {test_metrics['iou']:.4f}")
    print(f"Dice: {test_metrics['dice']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1: {test_metrics['f1']:.4f}")
    print(f"Accuracy: {test_metrics['acc']:.4f}")
    print("Confusion matrix (pixel-level):")
    print(test_metrics["confusion"])

    return history, test_metrics


if __name__ == "__main__":
    main()

