# Evaluation of a Trained Segmentation Model (UNet + ResNet50)

**Classes** (example mapping):
- 0: background
- 1: foliage
- 2: wood
- 3: ivy

## Notes
- Provide your own local test data. This notebook loads NumPy arrays (`.npy`); if your data are stored as image folders, convert them to arrays first.
- Expected NumPy shapes:
  - `X_test`: `(N, 256, 256, 3)` RGB
  - `y_test`: either `(N, 256, 256)` integer masks **or** `(N, 256, 256, 4)` one-hot masks (one channel per class)


## 1. Setup

In [None]:
# Standard library
import os
from pathlib import Path

# Third-party numerics / data handling / deep learning
import numpy as np
import pandas as pd
import tensorflow as tf

# Make segmentation_models use tf.keras.
# This must run before the import below: segmentation_models selects its Keras
# framework from SM_FRAMEWORK when the package is imported (see its README).
# setdefault keeps any value already exported in the shell.
os.environ.setdefault("SM_FRAMEWORK", "tf.keras")
import segmentation_models as sm

# Record library versions next to the results for reproducibility; getattr
# guards against builds that do not expose __version__.
print("TensorFlow:", tf.__version__)
print("segmentation_models:", getattr(sm, "__version__", "unknown"))


## 2. Configuration

In [None]:
# ---- Paths (edit for your local machine) ----
DATA_DIR = Path("data_private")  # private data, kept out of version control
OUT_DIR = Path("outputs_eval")   # every CSV / PNG produced below goes here
OUT_DIR.mkdir(exist_ok=True, parents=True)

# Trained weights are hosted separately (Zenodo): download them into ./model/
MODEL_PATH = Path("model", "unet_resnet50_ash_tree_segmentation.hdf5")

# Test-set arrays -- placeholder names, rename to match your local files
X_TEST_NPY = DATA_DIR.joinpath("X_test.npy")
Y_TEST_NPY = DATA_DIR.joinpath("y_test.npy")

# ---- Model / dataset settings (must match how the model was trained) ----
BACKBONE = "resnet50"
IMAGE_SIZE = 256   # square input side length, in pixels
N_CLASSES = 4      # background, foliage, wood, ivy

# ---- Visualisation ----
N_EXAMPLES_TO_PLOT = 6
RANDOM_SEED = 42   # fixes which test samples get plotted


## 3. Load model

In [None]:
# Fail fast with an actionable hint when the weights have not been downloaded.
if not MODEL_PATH.exists():
    hint = "Tip: download weights (from Zenodo) and place them under ./model/"
    raise FileNotFoundError(f"Model not found: {MODEL_PATH.resolve()}\n" + hint)

# Register any custom layers/losses/metrics the saved model refers to here.
CUSTOM_OBJECTS = {}

# compile=False: only inference is needed, so the model is loaded uncompiled
# and training-time loss/metric objects do not have to be resolved.
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects=CUSTOM_OBJECTS,
    compile=False,
)
model.summary()


## 4. Load test data

In [None]:
def load_npy(path: Path) -> np.ndarray:
    """Load a single ``.npy`` array stored at *path*.

    Raises:
        FileNotFoundError: if *path* does not exist. The message shows the
            fully resolved path so a misconfigured data directory is obvious.
    """
    if path.exists():
        return np.load(path)
    raise FileNotFoundError(f"Missing file: {path.resolve()}\n")

import warnings

X_test = load_npy(X_TEST_NPY)
y_test = load_npy(Y_TEST_NPY)

print("X_test:", X_test.shape, X_test.dtype)
print("y_test:", y_test.shape, y_test.dtype)

# Basic checks on the image array
if X_test.ndim != 4 or X_test.shape[-1] != 3:
    raise ValueError("X_test must have shape (N, H, W, 3).")

if X_test.shape[1] != IMAGE_SIZE or X_test.shape[2] != IMAGE_SIZE:
    raise ValueError(f"Expected IMAGE_SIZE={IMAGE_SIZE}, got {X_test.shape[1:3]}.")

# Convert y to an integer label mask of shape (N, H, W).
if y_test.ndim == 4 and y_test.shape[-1] == N_CLASSES:
    # One-hot masks: the label is the index of the active channel.
    y_true = np.argmax(y_test, axis=-1).astype(np.int32)
elif y_test.ndim == 4 and y_test.shape[-1] == 1:
    # Integer masks saved with a trailing singleton channel, (N, H, W, 1).
    y_true = y_test[..., 0].astype(np.int32)
elif y_test.ndim == 3:
    y_true = y_test.astype(np.int32)
else:
    raise ValueError(
        "y_test must be either (N,H,W) integer masks or "
        f"(N,H,W,{N_CLASSES}) one-hot masks; got {y_test.shape}."
    )

# Masks must cover exactly the same pixels as the images; otherwise the
# pixel-wise metrics compare unrelated locations or fail on a size mismatch.
if y_true.shape != X_test.shape[:3]:
    raise ValueError(
        f"y_test spatial shape {y_true.shape} does not match "
        f"X_test {X_test.shape[:3]}."
    )

# The per-class metrics only loop over labels 0..N_CLASSES-1, so any other
# label value is never counted as a true class yet still lowers the overall
# accuracy. Warn rather than fail so datasets with e.g. an "ignore" label
# still run.
if y_true.size and (y_true.min() < 0 or y_true.max() >= N_CLASSES):
    warnings.warn(
        f"y_true contains labels outside [0, {N_CLASSES - 1}] "
        f"(min={y_true.min()}, max={y_true.max()}); check the class mapping."
    )

print("y_true:", y_true.shape, y_true.dtype)


## 5. Predict

In [None]:
preprocess_input = sm.get_preprocessing(BACKBONE)

# Depending on BACKBONE, the preprocessing function may modify float inputs in
# place (e.g. Keras "caffe"-mode mean subtraction). Feed it a float32 copy so
# X_test keeps its original pixel values for the visualisation step.
# Trade-off: this holds a second copy of the test images in memory.
X_pp = preprocess_input(X_test.astype(np.float32, copy=True))

# Predict per-pixel class scores: (N, H, W, C)
probs = model.predict(X_pp, batch_size=8, verbose=1)

# A channel-count mismatch would make argmax labels meaningless for the
# metrics below, so stop here instead of reporting misleading numbers.
if probs.shape[-1] != N_CLASSES:
    raise ValueError(
        f"Model outputs {probs.shape[-1]} channels, expected N_CLASSES={N_CLASSES}."
    )

# Convert to predicted class labels: (N, H, W)
y_pred = np.argmax(probs, axis=-1).astype(np.int32)
print("y_pred:", y_pred.shape, y_pred.dtype)


## 6. Metrics (overall + per class)

We compute pixel-wise precision, recall, F1 and IoU for each class, as well as the overall pixel accuracy.

In [None]:
def per_class_metrics(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int):
    """Compute pixel-wise overall accuracy and per-class segmentation metrics.

    Args:
        y_true: integer ground-truth label mask (any shape).
        y_pred: integer predicted label mask with the same shape as ``y_true``.
        n_classes: number of classes; rows are produced for labels
            ``0 .. n_classes - 1``.

    Returns:
        ``(overall_acc, df)``. ``overall_acc`` is the fraction of pixels where
        ``y_true == y_pred`` (NaN for empty input). ``df`` has one row per
        class with columns ``class, tp, fp, fn, precision, recall, f1, iou``.
        A ratio whose denominator is zero (class absent from both masks) is
        reported as 0.0.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different shapes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true shape {y_true.shape} != y_pred shape {y_pred.shape}"
        )

    # Flatten for pixel-wise evaluation
    yt = y_true.reshape(-1)
    yp = y_pred.reshape(-1)

    overall_acc = float(np.mean(yt == yp)) if yt.size else float("nan")

    def _ratio(num: int, den: int) -> float:
        # Exact division (no epsilon, so perfect scores are exactly 1.0);
        # 0.0 when the metric is undefined.
        return num / den if den else 0.0

    metrics = []
    for c in range(n_classes):
        is_true = yt == c
        is_pred = yp == c
        tp = int(np.count_nonzero(is_true & is_pred))
        # Every predicted-c pixel that is not a tp is an fp, and every
        # true-c pixel that is not a tp is an fn.
        fp = int(np.count_nonzero(is_pred)) - tp
        fn = int(np.count_nonzero(is_true)) - tp

        metrics.append({
            "class": c,
            "tp": tp, "fp": fp, "fn": fn,
            "precision": _ratio(tp, tp + fp),
            "recall": _ratio(tp, tp + fn),
            # 2PR / (P + R) rewritten in counts; identical value, one division.
            "f1": _ratio(2 * tp, 2 * tp + fp + fn),
            "iou": _ratio(tp, tp + fp + fn),
        })

    return overall_acc, pd.DataFrame(metrics)

overall_acc, df = per_class_metrics(y_true, y_pred, N_CLASSES)

# `display` is injected by IPython/Jupyter; fall back to a plain-text table
# so this cell also runs as an ordinary Python script (e.g. after nbconvert).
try:
    display(df)
except NameError:
    print(df.to_string(index=False))
print("Overall pixel accuracy:", overall_acc)


## 7. Save results

In [None]:
per_class_csv = OUT_DIR / "per_class_metrics.csv"
summary_csv = OUT_DIR / "summary_metrics.csv"

results = {
    "backbone": BACKBONE,
    "image_size": IMAGE_SIZE,
    "n_classes": N_CLASSES,
    "overall_pixel_accuracy": overall_acc,
    # Macro (unweighted) means over classes, the usual headline segmentation
    # numbers. Note: a class absent from both masks contributes 0 here.
    "mean_iou": float(df["iou"].mean()),
    "mean_f1": float(df["f1"].mean()),
}

df.to_csv(per_class_csv, index=False)
pd.DataFrame([results]).to_csv(summary_csv, index=False)

print("Saved:")
print("-", per_class_csv.resolve())
print("-", summary_csv.resolve())


## 8. Visualise a few examples

In [None]:
import matplotlib.pyplot as plt

# Seeded generator so the same test samples are plotted on every run.
rng = np.random.default_rng(RANDOM_SEED)
# Distinct indices; never request more examples than there are test images.
idx = rng.choice(
    len(X_test),
    size=min(len(X_test), N_EXAMPLES_TO_PLOT),
    replace=False,
)

def show_example(i: int):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(X_test[i].astype(np.uint8))
    ax[0].set_title("Input (RGB)")
    ax[0].axis("off")

    ax[1].imshow(y_true[i], vmin=0, vmax=N_CLASSES-1)
    ax[1].set_title("Ground truth")
    ax[1].axis("off")

    ax[2].imshow(y_pred[i], vmin=0, vmax=N_CLASSES-1)
    ax[2].set_title("Prediction")
    ax[2].axis("off")

    plt.tight_layout()
    return fig

# Write one PNG per sampled index; close each figure right away so open
# figures do not accumulate in memory.
for i in map(int, idx):
    fig = show_example(i)
    out_png = OUT_DIR / f"example_{i:04d}.png"
    fig.savefig(out_png, dpi=200)
    plt.close(fig)

print(f"Saved {len(idx)} example figures to:", OUT_DIR.resolve())
