### Experiment Details

- **Dataset:** Bijie Dataset
- **Regularized:** Yes
- **Model:** DiGATe_Unet
- **Backbone:** EfficientNet (tf_efficientnet_b4)
- **Data:** 6 Bands, RGB, DEM

In [None]:
EX_NO = 'E01' # Experiment number: a unique identifier used in the saved weights file name and passed to the plotting helper.
# Data directory
DATA_DIR = "ADD YOUR DATA DIRECTORY HERE" # e.g., /home/user1/ms/Datasets/bijie
# Base directory
BASE_DIR = "ADD YOUR BASE DIRECTORY HERE" # Project root; model weights are stored under BASE_DIR/weights

**Load Data**

In [None]:
import torch
import os 
from torch.utils.data import random_split, ConcatDataset
from dataset import BijieRawDataset

def check_path(path):
    """Print whether ``path`` exists on disk.

    The path is printed followed by a check mark when it exists and by a
    cross mark when it does not. This is purely informational: nothing is
    returned and a missing path does not raise.

    Args:
        path: File or directory path to test with ``os.path.exists``.
    """
    # Both outcomes use the same "<path> <marker>" layout.
    marker = "✅" if os.path.exists(path) else "❌"
    print(f"{path} {marker}")

# Report whether the configured data directory exists (informational only;
# execution continues either way).
check_path(DATA_DIR)
# Build one raw dataset per sub-folder of DATA_DIR.
landslide_ds = BijieRawDataset(f"{DATA_DIR}/landslide", phase="landslide")
nonlandslide_ds = BijieRawDataset(f"{DATA_DIR}/non-landslide", phase="non-landslide")

# Set a fixed seed for reproducibility
seed = 42
generator = torch.Generator().manual_seed(seed)

# split each one into train/val/test using the generator
def split(ds, ratios=(.7,.2,.1), generator=None):
    """Randomly split ``ds`` into one subset per entry of ``ratios``.

    Args:
        ds: Dataset to split; it must support ``len`` and indexing as
            required by ``torch.utils.data.random_split``.
        ratios: Fraction of the dataset assigned to each subset, in order.
            Any number of ratios (>= 1) is accepted.
        generator: Optional ``torch.Generator`` that drives the shuffling;
            pass a seeded generator for reproducible splits.

    Returns:
        The list of subsets returned by ``random_split``, one per ratio.
    """
    n = len(ds)
    sizes = [int(r * n) for r in ratios]
    # Flooring can leave samples unassigned; the last subset absorbs the
    # remainder so that the sizes always sum to len(ds).
    sizes[-1] = n - sum(sizes[:-1])
    return random_split(ds, sizes, generator=generator)

# Apply the split with reproducible shuffling
# Split each class separately so every split contains both classes.
# The same generator object is consumed by both calls, so the second split
# depends on the first having run; re-run the seeding cell to reproduce.
tl, vl, sl = split(landslide_ds, generator=generator)
tn, vn, sn = split(nonlandslide_ds, generator=generator)

# Concatenate landslide + non-landslide subsets for each split
train_ds = ConcatDataset([tl, tn])
val_ds   = ConcatDataset([vl, vn])
test_ds  = ConcatDataset([sl, sn])

# Sanity check: report the size of every split.
print(f"Number of training samples: {len(train_ds)}")
print(f"Number of validation samples: {len(val_ds)}")
print(f"Number of test samples: {len(test_ds)}")

**Processing**

In [None]:
from dataset import TwoComposites, DualStreamTransform

# Wrap each split in TwoComposites; only the training split receives a
# transform (DualStreamTransform), validation and test are left untransformed.
train_dataset = TwoComposites(train_ds, bands='RGB&DEM', resize_to=256, transform=DualStreamTransform())
val_dataset = TwoComposites(val_ds, bands='RGB&DEM', resize_to=256, transform=None)
test_dataset = TwoComposites(test_ds, bands='RGB&DEM', resize_to=256, transform=None)

# Each sample unpacks into two input images and a mask.
image1, image2, mask = train_dataset[0]

# Quick inspection of type, shape and value range of the first training sample.
print(type(image1), image1.shape, image1.min().item(), image1.max().item())
print(type(image2), image2.shape, image2.min().item(), image2.max().item())
print(type(mask), mask.shape, mask.min().item(), mask.max().item())

**Model**

In [None]:
import torch
from models import DiGATe_Unet

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}')

# Instantiate the dual-input U-Net and move it to the selected device.
# NOTE(review): the exact meaning of each keyword is defined in
# models.DiGATe_Unet, which is not visible here — confirm there.
model = DiGATe_Unet(
        n_classes=1,
        backbone="tf_efficientnet_b4",
        n_channels=3,
        pretrained=True,          # set True if downloads are allowed
        pretrained_path=None,      # or local .pth file
        use_input_adapter=False,
        freeze_backbone=True,
        share_backbone=False
    ).to(DEVICE)

# Count only parameters that require gradients, reported in millions.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
print(f"Trainable parameters: {trainable_params:.2f}M")

**Hyperparameters**

In [None]:
# Training hyperparameters
# NOTE(review): PIN_MEMORY, PATIENCE_LIMIT and NUM_CLASSES are defined here
# but are not referenced by any other visible cell of this notebook (they are
# not part of the config passed to train) — confirm whether they are needed.
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
PIN_MEMORY = True
PATIENCE_LIMIT = 20
NUM_CLASSES = 1

# Destination of the model weights: <BASE_DIR>/weights/<EX_NO>.pth
SAVE_PATH = os.path.join(BASE_DIR, "weights", f"{EX_NO}.pth")

**Training**

In [None]:
from train.train_reg import Config, train

# Collect the run settings that are forwarded to the training Config.
config_dict = {  
    'num_epochs': NUM_EPOCHS,   
    'learning_rate': LEARNING_RATE,  
    'weight_decay': WEIGHT_DECAY,  
    'batch_size': BATCH_SIZE,  
    'model_save_path': SAVE_PATH,  
    'device': DEVICE
}

# Build the config and train; the returned history is plotted in the next cell.
config = Config(**config_dict)
history = train(train_dataset, val_dataset, model, config)

In [None]:
from utils import plot_training_metrics
# Plot the training history, tagging the output with the experiment number.
plot_training_metrics(history, ex=EX_NO)

**Load the saved Model**

In [None]:
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# only load checkpoint files you trust. If the file holds a plain state dict,
# weights_only=True is the safer setting.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# machine (and vice versa).
checkpoint = torch.load(
    os.path.join(BASE_DIR, "weights", f"{EX_NO}.pth"),
    map_location=DEVICE,
    weights_only=False,
)
# The checkpoint is used directly as a state dict. If the file wraps the
# weights in a dictionary, load checkpoint['model_state_dict'] instead.
model.load_state_dict(checkpoint)

# Put Model in Evaluation mode so that layers such as dropout and batch
# normalisation behave deterministically in the evaluation cells that follow.
model.eval()

**Evaluate the model**

In [None]:
from utils import evaluate_model

# Score the model on every split, in the order train -> validation -> test.
for split_name, split_dataset in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    evaluate_model(model, split_dataset, DEVICE, split_name)

**Extended Evaluation**

In [None]:
import os, json, numpy as np, torch, torch.nn.functional as F
from typing import Dict, List, Tuple
import models.smp_metrics as sm
from skimage import measure
from torch.utils.data import DataLoader


def get_dataloaders(test_dataset, batch_size=32):
    """Build a non-shuffling ``DataLoader`` for evaluation.

    Args:
        test_dataset: Dataset to iterate over.
        batch_size: Number of samples per batch. Defaults to 32, the value
            that used to be hard-coded.

    Returns:
        A ``DataLoader`` that yields the dataset in its original order.
    """
    return DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

def get_debug(d):
    """Print shape, type and dtype of the first batch produced by ``d``.

    ``d`` must be iterable and its first item must unpack into exactly three
    tensors ``(x1, x2, y)``. For ``y`` the distinct values are printed too.
    A trailing blank line separates the report from later output.
    """
    # Pull a single batch; the iterable itself is not consumed any further.
    x1, x2, y = next(iter(d))
    print(f"Shape of x1: {x1.shape}, Type: {type(x1)}, Dtype: {x1.dtype}")
    print(f"Shape of x2: {x2.shape}, Type: {type(x2)}, Dtype: {x2.dtype}")
    print(f"Shape of y: {y.shape}, Type of y: {type(y)}, Dtype: {y.dtype}, y classes: {y.unique()}")
    print()

def prep_batch(x1, x2, y, device):
    """Cast a batch to the expected dtypes and move it to ``device``.

    Args:
        x1, x2: Input tensors; non-floating tensors are converted to float.
        y: Target tensor; a floating target is rounded and cast to long.
            A 3-D target gets a channel axis inserted at position 1.
        device: Destination device for all three tensors.

    Returns:
        Tuple ``(x1, x2, y)`` of the converted tensors on ``device``.
    """
    def _as_float(tensor):
        # Leave floating tensors untouched (keeps their precision).
        return tensor if torch.is_floating_point(tensor) else tensor.float()

    x1 = _as_float(x1).to(device, non_blocking=True)
    x2 = _as_float(x2).to(device, non_blocking=True)

    target = y.round().long() if torch.is_floating_point(y) else y
    target = target.to(device, non_blocking=True)
    if target.dim() == 3:
        target = target.unsqueeze(1)
    return x1, x2, target

def _binarize(prob_map: np.ndarray, thr: float) -> np.ndarray:
    """Return a uint8 mask that is 1 where ``prob_map`` is at least ``thr``."""
    above = np.greater_equal(prob_map, thr)
    return above.astype(np.uint8)

def _remove_small(mask: np.ndarray, min_area: int) -> np.ndarray:
    """Drop connected components smaller than ``min_area`` pixels.

    Args:
        mask: Binary mask; non-zero pixels are foreground.
        min_area: Minimum number of pixels a component needs to be kept.
            For ``min_area <= 1`` nothing can be removed and ``mask`` itself
            is returned unchanged (same object, original dtype).

    Returns:
        A uint8 mask of the same shape with 1 on the kept components.
    """
    if min_area <= 1:
        return mask
    lab = measure.label(mask, connectivity=1)
    # Component sizes in one pass: counts[k] is the pixel count of label k.
    # minlength=1 keeps index 0 valid even for an empty label image.
    counts = np.bincount(lab.ravel(), minlength=1)
    keep = counts >= min_area
    keep[0] = False  # label 0 is the background and is never kept
    # Look up every pixel's label in the keep table instead of re-scanning
    # the whole image once per region.
    return keep[lab].astype(np.uint8)

def mask_to_instances(mask_bin: np.ndarray, min_area: int = 20) -> List[np.ndarray]:
    """Split a binary mask into one uint8 mask per connected component.

    Components smaller than ``min_area`` pixels are discarded first; the
    remaining foreground is labelled with 1-connectivity and each label is
    returned as its own mask, in increasing label order.
    """
    cleaned = _remove_small(mask_bin, min_area)
    labelled = measure.label(cleaned, connectivity=1)
    candidates = (
        (labelled == label_id).astype(np.uint8)
        for label_id in range(1, labelled.max() + 1)
    )
    # Keep only masks that actually contain pixels.
    return [candidate for candidate in candidates if candidate.sum() > 0]

def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection-over-union of two integer or boolean masks.

    The intersection is the sum of the element-wise bitwise AND; the union is
    ``a.sum() + b.sum() - intersection``. Returns exactly 0.0 when the masks
    do not overlap; otherwise a small epsilon guards the division.
    """
    overlap = np.bitwise_and(a, b).sum()
    if overlap == 0:
        return 0.0
    combined = a.sum() + b.sum() - overlap
    return float(overlap) / float(combined + 1e-6)

def greedy_match_ious(preds: List[np.ndarray], gts: List[np.ndarray], iou_thr=0.5):
    """Greedily pair predicted and ground-truth masks by descending IoU.

    The pair with the highest IoU is matched first, both masks are then
    removed from further consideration, and the process repeats until no
    remaining pair has an IoU strictly greater than ``iou_thr``. Ties go to
    the lowest prediction index, then the lowest ground-truth index.

    Returns:
        List of ``(pred_index, gt_index)`` tuples in matching order.
    """
    n_pred, n_gt = len(preds), len(gts)
    iou_mat = np.zeros((n_pred, n_gt), dtype=np.float32)
    for i, pred_mask in enumerate(preds):
        for j, gt_mask in enumerate(gts):
            iou_mat[i, j] = mask_iou(pred_mask, gt_mask)

    matches = []
    if iou_mat.size == 0:
        # Nothing to pair when either side is empty.
        return matches

    # Work on a copy in which consumed rows/columns are set to -inf so they
    # can never win the argmax nor pass the threshold test again.
    available = iou_mat.copy()
    while True:
        # argmax returns the first maximum in row-major order.
        i, j = divmod(int(np.argmax(available)), n_gt)
        if not available[i, j] > iou_thr:
            break
        matches.append((i, j))
        available[i, :] = -np.inf
        available[:, j] = -np.inf
    return matches

def instance_scores(prob_map: np.ndarray, insts: List[np.ndarray]) -> List[float]:
    """Score each instance by the mean of ``prob_map`` over its pixels.

    An instance mask without any foreground pixel scores 0.0.
    """
    scores = []
    for inst in insts:
        if inst.sum():
            scores.append(float(prob_map[inst.astype(bool)].mean()))
        else:
            scores.append(0.0)
    return scores

def average_precision_at_iou(pred_instances, pred_scores, gt_instances, iou_thr=0.5) -> float:
    """11-point (VOC-style) average precision for one image.

    Predictions are visited in order of decreasing score. Each prediction is
    compared with the ground-truth instances that are still unmatched; it is
    a true positive when its best IoU is at least ``iou_thr``, otherwise a
    false positive. Returns 0.0 when there are no predictions.
    """
    if len(pred_instances) == 0:
        return 0.0

    ranking = np.argsort(-np.array(pred_scores))
    ranked_preds = [pred_instances[k] for k in ranking]

    used_gt = set()
    hit_flags = []
    for pred_mask in ranked_preds:
        top_iou, top_idx = 0.0, -1
        for idx, gt_mask in enumerate(gt_instances):
            if idx in used_gt:
                continue
            overlap = mask_iou(pred_mask, gt_mask)
            if overlap > top_iou:
                top_iou, top_idx = overlap, idx
        is_hit = top_idx != -1 and top_iou >= iou_thr
        if is_hit:
            # Each ground-truth instance may be claimed only once.
            used_gt.add(top_idx)
        hit_flags.append(1 if is_hit else 0)

    tps = np.array(hit_flags)
    fps = 1 - tps
    cum_tp, cum_fp = np.cumsum(tps), np.cumsum(fps)
    recalls = cum_tp / (len(gt_instances) + 1e-6)
    precisions = cum_tp / (cum_tp + cum_fp + 1e-6)

    # 11-point VOC: average the best precision reached at recall >= r
    ap = 0.0
    for level in np.linspace(0, 1, 11):
        reached = recalls >= level
        best_prec = precisions[reached].max() if np.any(reached) else 0.0
        ap += best_prec / 11.0
    return float(ap)

# ---- image-level curves (AUROC/AUPRC) ----
def _pr_curve(scores: np.ndarray, labels: np.ndarray):
    """Precision/recall points obtained by sweeping the score threshold.

    Samples are visited by decreasing score. Whenever a new distinct score is
    reached, one point is recorded from the counts accumulated so far (before
    the current sample is counted); a final point with threshold 0.0 covers
    all samples.

    Returns:
        Tuple ``(precisions, recalls, thresholds)`` of equally long arrays.
    """
    ranking = np.argsort(-scores)
    ranked_scores, ranked_labels = scores[ranking], labels[ranking]
    n_pos = ranked_labels.sum()

    hits = misses = 0
    prec_pts, rec_pts, thr_pts = [], [], []

    def _record(threshold):
        # Snapshot of the running counts at this threshold.
        prec_pts.append(hits / (hits + misses + 1e-6))
        rec_pts.append(hits / (n_pos + 1e-6))
        thr_pts.append(threshold)

    previous = None
    for score, label in zip(ranked_scores, ranked_labels):
        if previous is None or score != previous:
            _record(score)
            previous = score
        if label == 1:
            hits += 1
        else:
            misses += 1
    _record(0.0)
    return np.array(prec_pts), np.array(rec_pts), np.array(thr_pts)

def _auprc(prec, rec):
    """Area under the precision-recall curve by trapezoidal integration.

    Args:
        prec: Array of precision values.
        rec: Array of recall values, same length as ``prec``.

    Returns:
        The area as a Python float.
    """
    # A stable sort keeps points that share the same recall in their given
    # order, so the result does not depend on the sort implementation.
    order = np.argsort(rec, kind="stable")
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # prefer the new name and fall back for older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return float(integrate(prec[order], rec[order]))

def _roc_curve(scores: np.ndarray, labels: np.ndarray):
    """ROC points obtained by sweeping the score threshold.

    Samples are visited by decreasing score. A point is recorded at every
    boundary between two distinct scores, i.e. only after *all* samples
    sharing a score have been counted. Recording a point in the middle of a
    group of tied scores would make the curve depend on the arbitrary order
    of the tied samples. The curve always starts at (0, 0) and ends at (1, 1).

    Args:
        scores: 1-D array of per-sample scores (higher = more positive).
        labels: 1-D array of 0/1 ground-truth labels.

    Returns:
        Tuple ``(FPR, TPR)`` of equally long arrays.
    """
    order = np.argsort(-scores)
    scores = scores[order]
    labels = labels[order]
    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    tp = fp = 0
    tpr_pts = [0.0]
    fpr_pts = [0.0]
    last = None
    for s, y in zip(scores, labels):
        if last is not None and s != last:
            # The previous tie group is complete: emit its operating point.
            tpr_pts.append(tp / (n_pos + 1e-6))
            fpr_pts.append(fp / (n_neg + 1e-6))
        last = s
        if y == 1:
            tp += 1
        else:
            fp += 1
    # Accepting every sample yields the (1, 1) end point.
    tpr_pts.append(1.0)
    fpr_pts.append(1.0)
    return np.array(fpr_pts), np.array(tpr_pts)

def _auroc(fpr,tpr):
    """Area under the ROC curve by trapezoidal integration.

    Args:
        fpr: Array of false-positive rates.
        tpr: Array of true-positive rates, same length as ``fpr``.

    Returns:
        The area as a Python float.
    """
    # ROC curves contain vertical segments (equal FPR, rising TPR). A stable
    # sort preserves the given order of those points; an unstable sort could
    # swap them and change the integrated area.
    order = np.argsort(fpr, kind="stable")
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # prefer the new name and fall back for older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return float(integrate(tpr[order], fpr[order]))

def eval_seg(model, dataset, device, threshold=0.5) -> Dict[str, float]:
    """Pixel-level segmentation metrics, averaged over batches.

    Every metric is computed once per batch and the per-batch values are
    averaged, so a smaller final batch carries the same weight as a full one.

    Args:
        model: Two-input model called as ``model(x1, x2)``; when it returns a
            tuple or list, the first element is used as the prediction.
        dataset: Dataset yielding ``(x1, x2, y)`` samples.
        device: Device the batches are moved to.
        threshold: Threshold forwarded to ``sm.get_statistics``.
            NOTE(review): confirm whether ``sm.get_statistics`` expects logits
            or probabilities; the raw model output is passed here.

    Returns:
        Dict with accuracy, recall, precision, f1(dice), iou and the number
        of batches evaluated.
    """
    loader = get_dataloaders(dataset)
    get_debug(loader)

    tot_acc = tot_rec = tot_f1 = tot_iou = tot_prec = 0.0
    batches = 0

    model.eval()
    with torch.no_grad():
        for x1, x2, y in loader:
            x1, x2, y = prep_batch(x1, x2, y, device)
            out = model(x1, x2)
            y_main = out[0] if isinstance(out, (tuple, list)) else out
            # Align the prediction with the target resolution, exactly as the
            # other evaluators in this file do before comparing the two.
            if y_main.shape[-2:] != y.shape[-2:]:
                y_main = F.interpolate(y_main, size=y.shape[-2:], mode='bilinear', align_corners=False)
            tp, fp, fn, tn = sm.get_statistics(y_main, y, mode='binary', threshold=threshold)
            tot_acc += sm.acc(tp, fp, fn, tn)
            tot_rec += sm.recall(tp, fp, fn, tn)
            tot_f1 += sm.f1(tp, fp, fn, tn)  # Dice for binary = F1
            tot_iou += sm.iou(tp, fp, fn, tn)
            tot_prec += sm.prec(tp, fp, fn, tn)
            batches += 1

    return {
        "task":"segmentation(pixel)",
        "accuracy": float(tot_acc/batches),
        "recall":   float(tot_rec/batches),
        "precision":float(tot_prec/batches),
        "f1(dice)": float(tot_f1/batches),
        "iou":      float(tot_iou/batches),
        "batches":  int(batches)
    }

def eval_det(model, dataset, device, iou_thresh=0.5, prob_thr=0.5, min_area=20) -> Dict[str, float]:
    """Instance-level detection metrics derived from segmentation logits.

    For every image the sigmoid of the model output is thresholded at
    ``prob_thr`` and split into connected components; the ground-truth mask is
    split the same way. Predicted and ground-truth instances are paired with
    ``greedy_match_ious`` to count TP/FP/FN over the whole dataset, and a
    per-image average precision is computed and averaged over images.

    Args:
        model: Two-input model called as ``model(x1, x2)``; when it returns a
            tuple or list, the first element is used as the prediction.
        dataset: Dataset yielding ``(x1, x2, y)`` samples.
        device: Device the batches are moved to.
        iou_thresh: IoU threshold for counting a prediction as a match.
        prob_thr: Probability threshold used to binarize predictions.
        min_area: Components smaller than this many pixels are ignored, for
            predictions and ground truth alike.

    Returns:
        Dict with precision, recall, f1, ``AP@<iou_thresh>``, the TP/FP/FN
        counts and ``num_images``, the number of images evaluated.
    """
    loader = get_dataloaders(dataset)
    get_debug(loader)

    total_matches = total_pred = total_gt = 0
    num_images = 0
    ap_list = []
    model.eval()
    with torch.no_grad():
        for x1, x2, y in loader:
            x1, x2, y = prep_batch(x1, x2, y, device)
            out = model(x1, x2)
            y_main = out[0] if isinstance(out, (tuple, list)) else out
            if y_main.shape[-2:] != y.shape[-2:]:
                y_main = F.interpolate(y_main, size=y.shape[-2:], mode='bilinear', align_corners=False)
            prob = torch.sigmoid(y_main).detach().cpu().numpy()
            gt = y.detach().cpu().numpy().astype(np.uint8)
            for b in range(prob.shape[0]):
                # Channel 0 holds the single foreground class.
                p = prob[b, 0]
                g = gt[b, 0]
                pred_bin = _binarize(p, prob_thr)
                pred_inst = mask_to_instances(pred_bin, min_area=min_area)
                gt_inst = mask_to_instances(g, min_area=min_area)
                scores = instance_scores(p, pred_inst)
                matches = greedy_match_ious(pred_inst, gt_inst, iou_thr=iou_thresh)

                total_matches += len(matches)
                total_pred += len(pred_inst)
                total_gt += len(gt_inst)
                num_images += 1
                ap_list.append(average_precision_at_iou(pred_inst, scores, gt_inst, iou_thr=iou_thresh))

    tp = total_matches
    fp = total_pred - tp
    fn = total_gt - tp
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2*precision*recall/(precision+recall+1e-6)
    ap_mean = float(np.mean(ap_list) if ap_list else 0.0)
    return {
        "task":"detection(instance)",
        "precision": float(precision),
        "recall":    float(recall),
        "f1":        float(f1),
        f"AP@{iou_thresh:.2f}": ap_mean,
        "TP": int(tp), "FP": int(fp), "FN": int(fn),
        # Actual number of images traversed (one increment per sample).
        "num_images": int(num_images)
    }

def eval_image(model, dataset, device, prob_thr_for_instances=0.5, min_area=20) -> Dict[str, float]:
    """Image-level presence metrics (AUROC/AUPRC).

    An image is labelled positive when its ground-truth mask has at least one
    foreground pixel. Its score is the highest mean probability among the
    predicted instances (components of the thresholded probability map that
    reach ``min_area`` pixels), or 0.0 when no instance is predicted.

    Returns:
        Dict with AUROC, AUPRC, the best F1 along the PR curve together with
        its threshold, and the positive / negative / total image counts.
    """
    loader = get_dataloaders(dataset)
    model.eval()
    collected_scores = []
    collected_labels = []
    with torch.no_grad():
        for x1, x2, y in loader:
            x1, x2, y = prep_batch(x1, x2, y, device)
            output = model(x1, x2)
            logits = output[0] if isinstance(output, (tuple, list)) else output
            if logits.shape[-2:] != y.shape[-2:]:
                logits = F.interpolate(logits, size=y.shape[-2:], mode='bilinear', align_corners=False)
            prob_batch = torch.sigmoid(logits).detach().cpu().numpy()
            gt_batch = y.detach().cpu().numpy().astype(np.uint8)
            for idx in range(prob_batch.shape[0]):
                prob_map = prob_batch[idx, 0]
                gt_mask = gt_batch[idx, 0]
                collected_labels.append(1 if gt_mask.sum() > 0 else 0)

                binary = (prob_map >= prob_thr_for_instances).astype(np.uint8)
                components = mask_to_instances(binary, min_area=min_area)
                if len(components) > 0:
                    image_score = float(np.max(instance_scores(prob_map, components)))
                else:
                    image_score = 0.0
                collected_scores.append(image_score)

    scores = np.asarray(collected_scores, dtype=np.float32)
    labels = np.asarray(collected_labels, dtype=np.int32)

    # PR & ROC
    prec, rec, thr = _pr_curve(scores, labels)
    fpr, tpr = _roc_curve(scores, labels)
    auprc = _auprc(prec, rec)
    auroc = _auroc(fpr, tpr)

    # F1-optimal threshold (optional calibration)
    f1s = 2*prec*rec/(prec+rec+1e-6)
    best_idx = int(np.nanargmax(f1s))

    return {
        "task":"image-level(presence)",
        "AUROC": float(auroc),
        "AUPRC": float(auprc),
        "best_F1": float(f1s[best_idx]),
        "best_threshold": float(thr[best_idx]),
        "positives": int(labels.sum()),
        "negatives": int((1-labels).sum()),
        "num_images": int(len(labels))
    }

def evaluate(model, dataset, device, task="seg", **kwargs) -> Dict[str, float]:
    """
    Run the evaluator selected by ``task``.

    task: "seg" (pixel-level) | "det" (instance-level) | "image" (image-level)
    Extra keyword arguments are passed through to the chosen evaluator.
    Raises ValueError for any other task name.
    """
    # Reject unknown tasks up front so the dispatch below stays flat.
    if task not in ("seg", "det", "image"):
        raise ValueError(f"Unknown task: {task}")
    if task == "seg":
        return eval_seg(model, dataset, device, **kwargs)
    if task == "det":
        return eval_det(model, dataset, device, **kwargs)
    return eval_image(model, dataset, device, **kwargs)

import numpy as np
import torch
import torch.nn.functional as F
from typing import Dict, List
import models.smp_metrics as sm

@torch.no_grad()
def collect_pixel_metrics(model, dataset, device, threshold=0.5):
    """Collect per-image pixel metrics and dataset-wide confusion counts.

    Args:
        model: Two-input model called as ``model(x1, x2)``; when it returns a
            tuple or list, the first element is used as the prediction.
        dataset: Dataset yielding ``(x1, x2, y)`` samples.
        device: Device the batches are moved to.
        threshold: Decision threshold for foreground pixels.

    Returns:
      per_image: dict of lists (iou, dice, precision, recall, accuracy),
        one entry per image
      totals: dict with aggregated TP/FP/FN/TN over all pixels

    NOTE(review): the aggregated totals pass the raw model output to
    ``sm.get_statistics`` while the per-image metrics threshold the sigmoid
    of that output — confirm that ``sm.get_statistics`` expects logits.
    """
    loader = get_dataloaders(dataset)
    get_debug(loader)

    per_image = dict(iou=[], dice=[], precision=[], recall=[], accuracy=[])
    totals = dict(TP=0, FP=0, FN=0, TN=0)

    # The metrics module may or may not provide the simple helper; check once.
    has_simple_stats = hasattr(sm, "get_stats_simple")

    model.eval()
    for x1, x2, y in loader:
        x1, x2, y = prep_batch(x1, x2, y, device)
        out = model(x1, x2)
        y_main = out[0] if isinstance(out, (tuple, list)) else out
        if y_main.shape[-2:] != y.shape[-2:]:
            y_main = F.interpolate(y_main, size=y.shape[-2:], mode='bilinear', align_corners=False)

        # Batch-level confusion counts feed the dataset-wide totals.
        tp, fp, fn, tn = sm.get_statistics(y_main, y, mode='binary', threshold=threshold)
        totals["TP"] += int(tp.sum().item())
        totals["FP"] += int(fp.sum().item())
        totals["FN"] += int(fn.sum().item())
        totals["TN"] += int(tn.sum().item())

        # Per-image metrics: recompute the statistics one sample at a time
        # from the thresholded predictions.
        preds = (torch.sigmoid(y_main) >= threshold).long()
        for b in range(y.shape[0]):
            pred_b = preds[b:b+1]
            y_b = y[b:b+1]
            if has_simple_stats:
                tp_b, fp_b, fn_b, tn_b = sm.get_stats_simple(pred_b, y_b)
            else:
                # preds are already 0/1, so thresholding at 0.5 keeps them as is.
                tp_b, fp_b, fn_b, tn_b = sm.get_statistics(pred_b.float(), y_b, mode='binary', threshold=0.5)
            per_image["accuracy"].append(float(sm.acc(tp_b, fp_b, fn_b, tn_b)))
            per_image["recall"].append(float(sm.recall(tp_b, fp_b, fn_b, tn_b)))
            per_image["dice"].append(float(sm.f1(tp_b, fp_b, fn_b, tn_b)))  # Dice == F1 in binary segmentation
            per_image["iou"].append(float(sm.iou(tp_b, fp_b, fn_b, tn_b)))
            per_image["precision"].append(float(sm.prec(tp_b, fp_b, fn_b, tn_b)))

    return per_image, totals


In [None]:
# Recomputed here so this cell can be run without the model cell's DEVICE.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SEGMENTATION (pixel-level)
seg_metrics = evaluate(model, test_dataset, device, task="seg", threshold=0.5)
print(seg_metrics)

# IMAGE-LEVEL PRESENCE
# NOTE(review): the instance-level evaluator (task="det") is defined above
# but is not run in this cell.
img_metrics = evaluate(model, test_dataset, device, task="image",
                       prob_thr_for_instances=0.5, min_area=20)
print(img_metrics)