### Inference

- **Dataset:** Bijie landslide dataset
- **Backbone:** EfficientNet-B4 (`tf_efficientnet_b4`)

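**Prepare the dataset**

Split the landslide and non-landslide images 70/20/10 into train/val/test with a fixed seed. Then merge the two classes within each split.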

In [None]:
import os 
import torch
from torch.utils.data import random_split, ConcatDataset

from dataset import BijieRawDataset, TwoComposites, DualStreamTransform

EX_NO = 'ef_b4'
# Data and project roots. Each can be overridden with an environment variable,
# so the notebook runs on machines other than the original workstation.
# The defaults keep the original paths.
DATA_DIR = os.environ.get("BIJIE_DATA_DIR", "/home/user1/ms/Datasets/Bijie-landslide-dataset")
BASE_DIR = os.environ.get("DIGATE_BASE_DIR", "/home/user1/ms/DiGATe-UNet-LandSlide-Segmentation")

# One raw dataset per class folder.
landslide_ds = BijieRawDataset(f"{DATA_DIR}/landslide", phase="landslide")
nonlandslide_ds = BijieRawDataset(f"{DATA_DIR}/non-landslide", phase="non-landslide")

# Fixed seed so random_split reproduces the same train/val/test partition on every run.
# NOTE(review): presumably this must match the seed used at training time, so the test
# split stays held-out. Verify against the training notebook.
seed = 42
generator = torch.Generator().manual_seed(seed)

# split each one into train/val/test using the generator
def split(ds, ratios=(.7, .2, .1), generator=None):
    """Randomly split ``ds`` into consecutive subsets sized by ``ratios``.

    Every subset except the last gets ``int(ratio * len(ds))`` samples (floored).
    The last subset absorbs the remainder, so the sizes always sum to ``len(ds)``.

    Args:
        ds: sized dataset accepted by ``torch.utils.data.random_split``.
        ratios: one fraction per subset; any number (>= 1) of them. The value of
            the last ratio is not used, because that subset takes whatever is left.
        generator: optional ``torch.Generator`` for reproducible shuffling.

    Returns:
        One ``Subset`` per ratio, as returned by ``random_split``.

    Raises:
        ValueError: if ``ratios`` is empty, or the leading ratios claim more
            samples than ``ds`` contains.
    """
    if len(ratios) == 0:
        raise ValueError("ratios must contain at least one fraction")
    n = len(ds)
    sizes = [int(r * n) for r in ratios[:-1]]
    # The remainder goes to the last split, so rounding never loses samples.
    sizes.append(n - sum(sizes))
    if sizes[-1] < 0:
        # random_split only checks the total, so a negative length must be caught here.
        raise ValueError(f"ratios {tuple(ratios)} request more than {n} samples")
    return random_split(ds, sizes, generator=generator)

# Apply the split with reproducible shuffling.
# ORDER MATTERS: both calls draw from the same `generator`, so the non-landslide
# split depends on the landslide split having consumed its random numbers first.
# Reordering these lines, or drawing from `generator` in between, gives a different
# partition. NOTE(review): presumably that would no longer be the held-out split
# the checkpoint was trained against. Verify before changing.
tl, vl, sl = split(landslide_ds, generator=generator)
tn, vn, sn = split(nonlandslide_ds, generator=generator)

# concat landslide + non-landslide for each split
train_ds = ConcatDataset([tl, tn])
val_ds   = ConcatDataset([vl, vn])
test_ds  = ConcatDataset([sl, sn])

# Wrap each split with the project's TwoComposites at 256 px using the 'RGB&DEM' bands.
# Only the training split gets DualStreamTransform; validation and test pass transform=None.
train_dataset = TwoComposites(train_ds, bands='RGB&DEM', resize_to=256, transform=DualStreamTransform())
val_dataset = TwoComposites(val_ds, bands='RGB&DEM', resize_to=256, transform=None)
test_dataset = TwoComposites(test_ds, bands='RGB&DEM', resize_to=256, transform=None)

# Sanity check: one sample yields two input tensors plus a mask.
# NOTE(review): this is taken from the augmented training set, so the printed value
# ranges may vary between runs if DualStreamTransform is random. TODO confirm.
image1, image2, mask = train_dataset[0]

print(f"Number of training samples: {len(train_ds)}")
print(f"Number of validation samples: {len(val_ds)}")
print(f"Number of test samples: {len(test_ds)}")

# Type, shape and value range of each stream and of the mask.
print(type(image1), image1.shape, image1.min().item(), image1.max().item())
print(type(image2), image2.shape, image2.min().item(), image2.max().item())
print(type(mask), mask.shape, mask.min().item(), mask.max().item())

Number of training samples: 1941
Number of validation samples: 554
Number of test samples: 278
<class 'torch.Tensor'> torch.Size([3, 256, 256]) -0.14067722856998444 1.0951991081237793
<class 'torch.Tensor'> torch.Size([3, 256, 256]) -0.14775483310222626 1.1545225381851196
<class 'torch.Tensor'> torch.Size([256, 256]) 0.0 1.0


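**Load the Model**

Build DiGATe-UNet with the EfficientNet-B4 backbone and restore the trained weights from `weights/<EX_NO>.pth`.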

In [None]:
import torch
from models import DiGATe_Unet
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}')

# Dual-stream DiGATe-UNet, single output class, EfficientNet-B4 backbone.
# NOTE(review): pretrained=True fetches ImageNet weights that the checkpoint below
# overwrites anyway (hence the "Unexpected keys" warnings). Presumably it is kept so
# the constructor mirrors the training configuration. Confirm before switching it off.
# .eval() puts the model in inference mode for everything that follows.
model = DiGATe_Unet(
        n_classes=1,
        backbone="tf_efficientnet_b4",
        n_channels=3,
        pretrained=True,
        pretrained_path=None,
        use_input_adapter=False,
        freeze_backbone=True,
        share_backbone=False
    ).to(DEVICE).eval()

CHECKPOINT_PATH = os.path.join(BASE_DIR, "weights", f"{EX_NO}.pth")
# map_location lets a checkpoint saved on GPU load when DEVICE fell back to 'cpu'.
# SECURITY: weights_only=False unpickles arbitrary objects. Only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
# Accept both a bare state_dict and a training checkpoint that wraps it.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
# Last expression, so the notebook displays the key-matching report.
model.load_state_dict(state_dict)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.
Unexpected keys (bn2.bias, bn2.num_batches_tracked, bn2.running_mean, bn2.running_var, bn2.weight, classifier.bias, classifier.weight, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


<All keys matched successfully>

**Evaluation**

In [3]:
from utils.evaluate import evaluate_model

# Score the held-out splits in a fixed order: validation first, then test.
for eval_label, eval_split_ds in (("Validation", val_dataset), ("Test", test_dataset)):
    evaluate_model(model, eval_split_ds, DEVICE, eval_label)

Evaluating: 100%|██████████| 18/18 [00:07<00:00,  2.47it/s]



--- Evaluation Metrics on Validation Set---
Acc       : 0.9892
Recall    : 0.9295
Prec      : 0.9434
F1        : 0.9102
Iou       : 0.8826
--------------------------


Evaluating: 100%|██████████| 9/9 [00:03<00:00,  2.52it/s]


--- Evaluation Metrics on Test Set---
Acc       : 0.9897
Recall    : 0.9249
Prec      : 0.9453
F1        : 0.9111
Iou       : 0.8840
--------------------------





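**Save Predictions**

For every sample, render an Image | Ground Truth | Prediction | Image + Heatmap figure and save it as a PNG.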

In [4]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

def _predict_probability(model, x1, x2, device):
    """Run ``model`` on a single (x1, x2) pair and return sigmoid probabilities.

    Each input gets a leading batch dimension of 1 before the forward pass. If the
    model returns a tuple/list, only its first element is used as the main output.
    The result stays on ``device``. NOTE(review): its shape follows the model
    output, presumably (1, 1, H, W) logits -> probabilities. TODO confirm.
    """
    with torch.no_grad():
        out = model(x1.unsqueeze(0).to(device), x2.unsqueeze(0).to(device))
        logits = out[0] if isinstance(out, (tuple, list)) else out
        return torch.sigmoid(logits)


def _to_display_image(x):
    """Convert the first three channels of a channels-first tensor to an (H, W, C) array for imshow.

    The values are min-max rescaled to [0, 1] only when they fall outside that
    range, so imshow receives valid float RGB. The 1e-6 avoids division by zero
    on a constant image.
    """
    img = x[:3].permute(1, 2, 0).cpu().numpy()
    if img.max() > 1.0 or img.min() < 0.0:
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)
    return img


def _plot_prediction_panel(img, true_mask, pred_mask, heat, alpha, save_path):
    """Save an Image | Ground Truth | Prediction | Image + Heatmap figure to ``save_path``.

    Both masks are drawn on a fixed 0..1 gray scale. Without it, an all-0 or all-1
    mask would be auto-normalised and painted with the bottom colour, showing an
    all-landslide ground truth as black. The figure is always closed, even if
    drawing or saving fails, so long loops do not accumulate open figures.
    """
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    try:
        binary_scale = {"cmap": "gray", "vmin": 0, "vmax": 1}
        panels = (
            (img, "Image", {}),
            (true_mask, "Ground Truth", binary_scale),
            (pred_mask, "Prediction", binary_scale),
            (img, "Image + Heatmap", {}),
        )
        for ax, (data, title, imshow_kwargs) in zip(axes, panels):
            ax.imshow(data, **imshow_kwargs)
            ax.set_title(title)
            ax.axis("off")

        # Probability overlay on the last panel (blue = 0, red = 1).
        hm = axes[3].imshow(heat, cmap="bwr", vmin=0.0, vmax=1.0, alpha=alpha)
        cbar = fig.colorbar(hm, ax=axes[3], fraction=0.046, pad=0.04)
        cbar.set_label("Predicted probability", rotation=270, labelpad=12)

        fig.tight_layout()
        fig.savefig(save_path, dpi=200, bbox_inches="tight")
    finally:
        plt.close(fig)


def save_all_predictions_with_heatmap(model, dataset, device, ex_no, alpha=0.45,
                                      threshold=0.5, output_dir=None):
    """Render and save one prediction figure per sample in ``dataset``.

    Each sample is written as ``sample_<idx>.png`` with four panels:
    Image | Ground Truth | Prediction | Image + Heatmap.

    Args:
        model: dual-input model, called as ``model(x1, x2)``. It is moved to
            ``device`` and switched to eval mode, which mutates the caller's model.
        dataset: indexable dataset yielding ``(x1, x2, y_true)`` tensors.
        device: torch device (or device string) to run inference on.
        ex_no: experiment tag, used for the default folder ``visuals/<ex_no>/``.
        alpha: opacity of the probability heatmap overlay.
        threshold: probability above which a pixel is positive in the binary
            prediction panel. Defaults to 0.5.
        output_dir: destination folder. Defaults to ``visuals/<ex_no>/``.

    Returns:
        The directory the figures were written to.
    """
    if output_dir is None:
        output_dir = f'visuals/{ex_no}/'
    os.makedirs(output_dir, exist_ok=True)
    model.to(device).eval()

    for idx in tqdm(range(len(dataset)), desc="Inference"):
        x1, x2, y_true = dataset[idx]
        prob = _predict_probability(model, x1, x2, device)
        heat = prob.squeeze().cpu().numpy()                 # smooth probability map in [0, 1]
        pred_mask = (heat > threshold).astype(np.float32)   # binarised prediction
        _plot_prediction_panel(
            img=_to_display_image(x1),
            true_mask=y_true.squeeze().cpu().numpy(),
            pred_mask=pred_mask,
            heat=heat,
            alpha=alpha,
            save_path=os.path.join(output_dir, f"sample_{idx}.png"),
        )

    print(f"✅ Saved all predictions (with heatmaps) to: {output_dir}")
    return output_dir

In [None]:
# Set to True to render one PNG per test sample into visuals/<EX_NO>/.
# Reuses DEVICE from the model-loading cell instead of redefining a device.
SAVE_PREDICTIONS = False

if SAVE_PREDICTIONS:
    save_all_predictions_with_heatmap(model, test_dataset, DEVICE, ex_no=EX_NO)