<a href="https://colab.research.google.com/github/basakesin/InsectAI-WG3-STSM/blob/main/test_model_with_gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Test Your Model (Keras & PyTorch)

This notebook lets you **upload a trained image classification model** and **interactively test predictions** via a simple **Gradio web interface** — no extra coding required.  
It supports both **TensorFlow/Keras** and **PyTorch** models.


## What You Can Load
- **Keras / TensorFlow**
  - `.keras`, `.h5`, or a `.zip` that contains a **SavedModel**  
  - Optional: `class_names.txt` (one label per line) to map outputs to readable names  
  - Built-in handling for models that embed a backbone's `preprocess_input` (e.g. in a `Lambda` layer): **EfficientNetB0**, **MobileNetV2**, **ResNet50**, **InceptionV3**

- **PyTorch**
  - `.pt` / `.pth` **full model** (`torch.save(model, ...)`)  
  - `.pt` / `.pth` **state_dict** (`torch.save(model.state_dict(), ...)`)  
  - Optional: `class_names.txt` for readable labels  
  - If you upload a **state_dict**, you must select the correct **Backbone** in the UI (**ResNet50**, **MobileNetV2**, **EfficientNetB0**, **InceptionV3**)
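Here is a minimal sketch of how `class_names.txt` is expected to look and how it maps to model outputs. It assumes the labels are listed in the same order as the model's output units. The reader mirrors the notebook's `_read_class_names`. `write_class_names`, `label_outputs`, and the insect labels are hypothetical names used only for illustration.

```python
from pathlib import Path
import tempfile


def write_class_names(path, names):
    """Hypothetical helper: write one label per line, in output-unit order."""
    Path(path).write_text("\n".join(names) + "\n", encoding="utf-8")


def read_class_names(path):
    """Same logic as the notebook's reader: strip whitespace, skip blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        names = [ln.strip() for ln in f if ln.strip()]
    return names or None


def label_outputs(probs, names):
    """Use readable names only if the count matches; otherwise class_0, class_1, ..."""
    if names and len(names) == len(probs):
        return dict(zip(names, probs))
    return {f"class_{i}": p for i, p in enumerate(probs)}


with tempfile.TemporaryDirectory() as tmp:
    txt = Path(tmp) / "class_names.txt"
    write_class_names(txt, ["bee", "wasp", "hoverfly"])
    names = read_class_names(txt)

print(label_outputs([0.7, 0.2, 0.1], names))  # readable labels
print(label_outputs([0.6, 0.4], names))       # count mismatch -> fallback names
```

A file with the wrong number of lines is not an error. The app simply falls back to generic `class_i` labels.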



## Key Features
- **Upload & run** your trained model (Keras or PyTorch) in one place  
- **Optional class names** for human-readable outputs  
- **Automatic preprocessing**  
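  - Keras → if the model embeds a backbone's `preprocess_input` (e.g. in a `Lambda` layer), that function is restored at load time (the selected **Backbone** is tried first) and runs inside the model; otherwise pixels are scaled to `[0, 1]` (`/255`)  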
  - PyTorch → standard **ImageNet normalization (mean/std)**  
- **Drag-and-drop image testing**  
- **Interactive results**: top-5 predictions with confidence scores  
- **Binary models**: outputs shown as **positive/negative probabilities**
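The preprocessing and binary-output rules above reduce to simple per-value arithmetic. This pure-Python sketch reproduces them for a single RGB pixel and a single model output. The helper names are hypothetical. The notebook itself applies the same operations to whole arrays: `/255` in `_prep_image_keras`, and `T.ToTensor()` plus `T.Normalize` in `_torch_transform`.

```python
import math

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def scale_255(rgb):
    """Keras fallback: map 0..255 pixel values to [0, 1]."""
    return tuple(v / 255.0 for v in rgb)


def imagenet_normalize(rgb):
    """PyTorch path: scale to [0, 1] (like ToTensor), then (x - mean) / std per channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def keras_binary(value):
    """Keras single-output rule: values outside [0, 1] are treated as logits (sigmoid)."""
    p = value
    if p < 0 or p > 1:
        p = 1.0 / (1.0 + math.exp(-p))
    return {"positive": p, "negative": 1.0 - p}


def torch_binary(logit):
    """PyTorch single-output rule: always apply sigmoid."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return {"positive": p, "negative": 1.0 - p}


print(scale_255((255, 128, 0)))
print(imagenet_normalize((255, 255, 255)))
print(keras_binary(0.8), keras_binary(2.0), torch_binary(0.0))
```

Note the asymmetry between the two frameworks. A Keras logit that happens to fall inside `[0, 1]` (for example `0.0`) is read as a probability. The PyTorch path maps the same logit to `0.5`.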



## Quick Start
1. In the Colab menu, go to **Runtime → Run all**  
   - This installs dependencies, defines the helper functions, and launches the Gradio app automatically
2. When the last cell finishes, a **Gradio link** will appear  
   - Click it to open the interactive web interface
3. In the UI:
   - Select **Framework** → Keras or PyTorch  
   - **Upload your model file**  
     - Keras: `.keras`, `.h5`, or `.zip`  
     - PyTorch: `.pt` / `.pth`  
   - (Optional) Upload `class_names.txt`  
   - If using a PyTorch **state_dict**, choose the right **Backbone**  
   - Click **Load Model** and check the status message
4. Upload a test image and click **Predict** to see the predictions
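The input resolution the app resizes to depends on what you load and which Backbone you select. The sketch below shows that rule. It mirrors `_canon`, `_torch_default_input_size`, and `_infer_input_size` from the interface cell, but is reimplemented on plain tuples so it runs without TensorFlow or PyTorch. The function names here are hypothetical.

```python
def canon(name):
    """Lowercase and keep only letters/digits: 'EfficientNet_B0' -> 'efficientnetb0'."""
    return "".join(ch for ch in (name or "").lower() if ch.isalnum())


def torch_input_size(backbone):
    """PyTorch: 299x299 for InceptionV3, 224x224 for every other backbone."""
    return (299, 299) if canon(backbone) == "inceptionv3" else (224, 224)


def keras_input_size(input_shape):
    """Keras: read (batch, h, w, c) from the model; use 224x224 if h or w is None.

    Returns ((w, h), c) because PIL's resize expects (width, height).
    """
    if isinstance(input_shape, list):  # multi-input models: use the first input
        input_shape = input_shape[0]
    _, h, w, c = input_shape
    if h is None or w is None:
        h, w = 224, 224
    return (w, h), c


print(torch_input_size("Inception_V3"), torch_input_size("MobileNetV2"))
print(keras_input_size((None, 299, 299, 3)), keras_input_size((None, None, None, 3)))
```

For PyTorch, the Backbone dropdown therefore also sets the resize target for full models and TorchScript files. Pick InceptionV3 there if your model expects 299×299 input.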



## Notes & Tips
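- **Input size**: Keras → read from the model's input shape (`224×224` if undefined); PyTorch → InceptionV3 uses `299×299`, other backbones `224×224`, based on the selected **Backbone**  
- **Binary outputs**: If the model returns a single value, the app shows `positive` / `negative` probabilities. PyTorch outputs always go through a **sigmoid**; Keras outputs get a sigmoid only when the value lies outside `[0, 1]` (i.e. it looks like a logit)  
- **SavedModel ZIPs**: The archive must contain a `saved_model.pb`, either at the top level or in a subfolder (the first folder containing it is loaded)  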
- **Common issues**:  
  - Wrong backbone for PyTorch `state_dict` → select the correct one  
  - `class_names.txt` length doesn’t match model outputs → fallback to `class_0, class_1, ...`  
  - Environment glitches → use **Runtime → Restart session and run all** to reset
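To check a SavedModel ZIP before uploading it, the sketch below extracts an archive and searches it for `saved_model.pb` in the same way `_resolve_zip_savedmodel` does. The archive it builds is a stand-in with empty placeholder files, not a loadable model. `make_placeholder_zip` and `find_savedmodel_dir` are hypothetical names.

```python
import os
import tempfile
import zipfile


def make_placeholder_zip(zip_path):
    """Build a ZIP with a SavedModel-like layout nested one folder deep (empty files)."""
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("my_model/saved_model.pb", b"")
        zf.writestr("my_model/variables/variables.index", b"")


def find_savedmodel_dir(zip_path, extract_to):
    """Extract the archive and return the first folder that contains saved_model.pb."""
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(extract_to)
    for root, _dirs, files in os.walk(extract_to):
        if "saved_model.pb" in files:
            return root
    raise RuntimeError("No 'saved_model.pb' found inside ZIP.")


with tempfile.TemporaryDirectory() as tmp:
    archive = os.path.join(tmp, "model.zip")
    make_placeholder_zip(archive)
    found = find_savedmodel_dir(archive, os.path.join(tmp, "extracted"))
    print(os.path.relpath(found, tmp))  # the nested my_model folder is found
```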



💡 **That’s it!** Upload your model, select the right options, and start testing predictions interactively.


In [None]:
# @title Load Required Packages
%pip -q install "tensorflow==2.19.0" "tf-keras==2.19.0" "keras==3.5.0" "numpy==2.0.2" "Pillow>=10.3.0" gradio
import tensorflow as tf, keras, numpy as np
import os, io, zipfile, tempfile, json
from PIL import Image
import gradio as gr
print("TF", tf.__version__, "| Keras", keras.__version__, "| NumPy", np.__version__)
# Expected: TF 2.19.0 | Keras 3.5.x | NumPy 2.0.2


In [None]:
# @title Run Your Interface (Keras + PyTorch)
import os, io, zipfile, tempfile, json
from typing import Optional, Tuple

import numpy as np
from PIL import Image
import gradio as gr

# ====== KERAS / TF ======
import tensorflow as tf, keras
# Allow unsafe deserialization (for trusted models with Lambda/preprocess layers)
try:
    keras.config.enable_unsafe_deserialization()
except Exception:
    pass

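# DepthwiseConv2D compatibility patch: some older saved models store a `groups`
# argument in DepthwiseConv2D configs that newer Keras rejects, so accept and ignore it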
from tensorflow.keras.layers import DepthwiseConv2D as _TFDepthwiseConv2D
class DepthwiseConv2DCompat(_TFDepthwiseConv2D):
    def __init__(self, *args, groups=None, **kwargs):
        super().__init__(*args, **kwargs)
CUSTOM_OBJECTS_BASE = {"DepthwiseConv2D": DepthwiseConv2DCompat}

# Candidate Keras preprocess functions
from tensorflow.keras.applications.efficientnet import preprocess_input as eff_pre
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mob_pre
from tensorflow.keras.applications.resnet50 import preprocess_input as res_pre
from tensorflow.keras.applications.inception_v3 import preprocess_input as inc_pre

KERAS_PREPROC = {
    "EfficientNetB0": eff_pre,
    "MobileNetV2":   mob_pre,
    "ResNet50":      res_pre,
    "InceptionV3":   inc_pre,
}

def _infer_input_size(model):
    ishape = model.input_shape
    if isinstance(ishape, list): ishape = ishape[0]
    _, h, w, c = ishape
    if h is None or w is None:
        h, w = 224, 224
    return (w, h), c

def _prep_image_keras(pil_img, size, channels, preprocess=None, scale_255=True):
    img = pil_img.convert("RGB").resize(size)
    x = np.array(img, dtype=np.float32)
    if channels == 1:
        x = np.mean(x, axis=-1, keepdims=True)
    if preprocess is not None:
        x = preprocess(x)  # do not divide by 255 here
    else:
        if scale_255:
            x = x / 255.0
    return np.expand_dims(x, 0)

def _try_load_no_custom(path_or_dir):
    co = dict(CUSTOM_OBJECTS_BASE)
    return tf.keras.models.load_model(
        path_or_dir, compile=False, safe_mode=False, custom_objects=co
    )

def _read_class_names(txt_path: str):
    if not txt_path:
        return None
    with open(txt_path, "r", encoding="utf-8") as f:
        names = [ln.strip() for ln in f if ln.strip()]
    return names or None

def _resolve_zip_savedmodel(path):
    workdir = tempfile.mkdtemp()
    with zipfile.ZipFile(path, "r") as zf:
        zf.extractall(workdir)
    cand = None
    for root, _, files in os.walk(workdir):
        if "saved_model.pb" in files:
            cand = root; break
    if not cand:
        raise RuntimeError("No 'saved_model.pb' found inside ZIP.")
    return cand

def keras_load_model_any(model_path: str, backbone_hint: Optional[str]):
    low = model_path.lower()
    target = _resolve_zip_savedmodel(model_path) if low.endswith(".zip") else model_path

    # 1) Try with only the compatibility objects (works when the model needs no custom `preprocess_input`)
    try:
        m = _try_load_no_custom(target)
        return m, None, None, False
    except Exception:
        pass

    # 2) Try with candidate preprocess functions (prioritize selected backbone)
    order = []
    if backbone_hint in KERAS_PREPROC:
        order.append((backbone_hint, KERAS_PREPROC[backbone_hint]))
    for name, fn in KERAS_PREPROC.items():
        if name != backbone_hint:
            order.append((name, fn))

    last_err = None
    for name, fn in order:
        try:
            co = dict(CUSTOM_OBJECTS_BASE); co["preprocess_input"] = fn
            m = tf.keras.models.load_model(target, compile=False, safe_mode=False, custom_objects=co)
            return m, name, fn, True
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Keras model could not be deserialized. Last error: {last_err}")

# ====== PYTORCH ======
import torch
import torch.nn as nn
import torchvision
from torchvision import models as tvm
from torchvision import transforms as T

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def _canon(name: str) -> str:
    """Normalize names: 'MobileNetV2' -> 'mobilenetv2', 'EfficientNet_B0' -> 'efficientnetb0'"""
    return "".join(ch for ch in (name or "").lower() if ch.isalnum())

def _torch_default_input_size(backbone: str) -> Tuple[int,int]:
    key = _canon(backbone)
    return (299, 299) if key == "inceptionv3" else (224, 224)

def _torch_transform(size: Tuple[int,int]):
    w, h = size
    return T.Compose([
        T.Resize((h, w)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

def _torch_build_backbone(name: str, num_classes: Optional[int] = None):
    key = _canon(name)
    if key == "resnet50":
        m = tvm.resnet50(weights=None)
        in_feats = m.fc.in_features
        m.fc = nn.Linear(in_feats, num_classes or m.fc.out_features)
    elif key == "mobilenetv2":
        m = tvm.mobilenet_v2(weights=None)
        in_feats = m.classifier[-1].in_features
        m.classifier[-1] = nn.Linear(in_feats, num_classes or m.classifier[-1].out_features)
    elif key == "efficientnetb0":
        m = tvm.efficientnet_b0(weights=None)
        in_feats = m.classifier[-1].in_features
        m.classifier[-1] = nn.Linear(in_feats, num_classes or m.classifier[-1].out_features)
    elif key == "inceptionv3":
        m = tvm.inception_v3(weights=None, aux_logits=False)
        in_feats = m.fc.in_features
        m.fc = nn.Linear(in_feats, num_classes or m.fc.out_features)
    else:
        raise ValueError(f"Unsupported PyTorch backbone: {name}")
    return m

def torch_load_model_any(model_path: str, backbone_hint: Optional[str], class_names: Optional[list]):
    """
    Tries in order:
      1) torch.jit.load (scripted/trace)
      2) torch.load full model object
      3) torch.load state_dict -> rebuild backbone (with class count if available)
    """
    # 1) JIT
    try:
        m = torch.jit.load(model_path, map_location="cpu")
        m.eval()
        return m, (backbone_hint or "Unknown"), _torch_default_input_size(backbone_hint or "resnet50")
    except Exception:
        pass

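    last_err = None
    try:
        # weights_only=False lets torch.load unpickle full model objects
        # (recent PyTorch versions default to weights_only=True). Only load files you trust.
        obj = torch.load(model_path, map_location="cpu", weights_only=False)
        # Full model?
        if hasattr(obj, "state_dict") and callable(obj.state_dict):
            try: obj.eval()
            except Exception: pass
            return obj, (backbone_hint or "Unknown"), _torch_default_input_size(backbone_hint or "resnet50")
        # state_dict?
        if isinstance(obj, dict):
            if backbone_hint is None:
                raise RuntimeError("State dict detected. Please select a Backbone.")
            # Strip the "module." prefix added by nn.DataParallel, if present
            obj = {(k[len("module."):] if isinstance(k, str) and k.startswith("module.") else k): v
                   for k, v in obj.items()}
            # Prefer the class count stored in the classifier weights (strict=False does not
            # tolerate shape mismatches); fall back to class_names if the head key is absent
            head_key = "fc.weight" if _canon(backbone_hint) in ("resnet50", "inceptionv3") else "classifier.1.weight"
            head_w = obj.get(head_key)
            if torch.is_tensor(head_w) and head_w.ndim == 2:
                num_classes = int(head_w.shape[0])
            else:
                num_classes = (len(class_names) if class_names else None)
            m = _torch_build_backbone(backbone_hint, num_classes=num_classes)
            missing, unexpected = m.load_state_dict(obj, strict=False)
            if missing or unexpected:
                print("State dict load: missing keys:", missing, "| unexpected keys:", unexpected)
            m.eval()
            return m, backbone_hint, _torch_default_input_size(backbone_hint)
        last_err = RuntimeError(f"Unsupported checkpoint content: {type(obj).__name__}")
    except Exception as e:
        last_err = e

    raise RuntimeError(f"PyTorch model could not be loaded. Last error: {last_err}")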

def _torch_predict(model: torch.nn.Module, img: Image.Image, size: Tuple[int,int], class_names: Optional[list]):
    tfm = _torch_transform(size)
    x = tfm(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        out = model(x)
        if isinstance(out, (list, tuple)):
            out = out[0]
        out = out.squeeze(0)
        if out.ndim == 0 or out.numel() == 1:   # binary logit
            p = torch.sigmoid(out.flatten()[0]).item()
            return {"positive": float(p), "negative": float(1.0 - p)}
        # multiclass
        probs = torch.softmax(out, dim=-1).cpu().numpy()
        if class_names and len(class_names) == probs.shape[-1]:
            return {cls: float(p) for cls, p in zip(class_names, probs)}
        return {f"class_{i}": float(p) for i, p in enumerate(probs)}

# ====== GRADIO UI ======
with gr.Blocks(title="Image Classification (Keras / PyTorch)") as demo:
    gr.Markdown(
        "## Image Classification (Keras / PyTorch)\n"
        "Upload your model and (optionally) **class_names.txt**. Select the **Framework** and, if necessary, the **Backbone**, then click **Load Model**. Provide an image and click **Predict** to get predictions.\n\n"
        "**Supported formats**  \n"
        "- **Keras**: `.keras`, `.h5`, or `.zip` containing a SavedModel  \n"
        "- **PyTorch**: `.pt` / `.pth` (full model **or** `state_dict`)  \n"
        "Note: For `state_dict`, you must select the correct backbone."
    )

    with gr.Row():
        framework_dd = gr.Dropdown(
            choices=["Keras", "PyTorch"],
            value="Keras",
            label="Framework"
        )
        model_file = gr.File(
            label="Model File (.keras / .h5 / .zip / .pt / .pth)",
            file_types=[".keras", ".h5", ".zip", ".pt", ".pth"],
            type="filepath"
        )
        class_file = gr.File(
            label="Class Names (class_names.txt) — optional",
            file_types=[".txt"],
            type="filepath"
        )
        backbone_dropdown = gr.Dropdown(
            choices=["EfficientNetB0", "MobileNetV2", "ResNet50", "InceptionV3"],
            label="Backbone (for PyTorch state_dict or Keras deserialization)",
            value="MobileNetV2"
        )

    load_btn = gr.Button("Load Model", variant="primary")
    status = gr.Markdown()

    # States
    framework_state = gr.State()
    model_state = gr.State()
    input_size_state = gr.State()      # (w,h)
    channels_state = gr.State()        # Keras only; None for PyTorch
    class_names_state = gr.State()
    preprocess_fn_state = gr.State()   # Keras external preprocess fn
    backbone_used_state = gr.State()
    has_internal_pre_state = gr.State()  # Keras only; None for PyTorch

    with gr.Row():
        image_in = gr.Image(type="pil", label="Input Image")
        predict_btn = gr.Button("Predict", variant="primary")

    label_out = gr.Label(num_top_classes=5, label="Top-5 Predictions")

    # ------- LOAD -------
    def on_load(framework, model_path, class_path, backbone_hint):
        if not model_path:
            return ("Please choose a model file.",
                    framework, None, None, None, None, None, None, None)

        classes = _read_class_names(class_path)

        if framework == "Keras":
            try:
                m, used_backbone, used_pre, has_internal = keras_load_model_any(model_path, backbone_hint)
                (w, h), c = _infer_input_size(m)
                msg = f"✅ [Keras] Model loaded. Expected input: {(h,w,c)}"
                if has_internal:
                    msg += " | Preprocessing: inside model (no external /255)."
                else:
                    msg += f" | External preprocessing: {used_backbone or '/255'}"
                if classes:
                    msg += f" | {len(classes)} classes loaded."
                else:
                    msg += " | No class_names.txt found; class_0, class_1, ... will be used."
                preprocess_for_runtime = None if has_internal else used_pre
                return (msg, framework, m, (w, h), c, classes, preprocess_for_runtime, (used_backbone or "/255"), has_internal)
            except Exception as e:
                return (f"❌ Keras model failed to load: {e}", framework, None, None, None, None, None, None, None)

        # PyTorch
        try:
            m, used_bb, size = torch_load_model_any(model_path, backbone_hint, classes)
            w, h = size
            msg = f"✅ [PyTorch] Model loaded. Expected input: {(h,w,3)} | Backbone: {used_bb}"
            if classes:
                msg += f" | {len(classes)} classes loaded."
            else:
                msg += " | No class_names.txt found; class_0, class_1, ... will be used."
            return (msg, framework, m, (w, h), None, classes, None, used_bb, None)
        except Exception as e:
            return (f"❌ PyTorch model failed to load: {e}", framework, None, None, None, None, None, None, None)

    load_btn.click(
        on_load,
        inputs=[framework_dd, model_file, class_file, backbone_dropdown],
        outputs=[status, framework_state, model_state, input_size_state, channels_state,
                 class_names_state, preprocess_fn_state, backbone_used_state, has_internal_pre_state]
    )

    # ------- PREDICT -------
    def on_predict(img, framework, model, input_size, channels, class_names, preprocess_fn, backbone_used, has_internal_pre):
        if model is None:
            return {"No model loaded (click Load Model first)": 1.0}
        if img is None:
            return {"No image provided": 1.0}

        if framework == "Keras":
            try:
                x = _prep_image_keras(
                    img, input_size, channels,
                    preprocess=(None if has_internal_pre else preprocess_fn),
                    scale_255=(False if has_internal_pre else (preprocess_fn is None))
                )
                y = model.predict(x, verbose=0)
                y = y[0] if isinstance(y, (list, tuple)) else y
                y = np.array(y).reshape(-1)

                # Binary case
                if y.shape == () or y.shape == (1,):
                    p = float(y if y.shape == () else y[0])
                    if p < 0 or p > 1:  # looks like a logit, apply sigmoid
                        p = 1 / (1 + np.exp(-p))
                    return {"positive": p, "negative": 1.0 - p}

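                # Multiclass case: apply softmax when the outputs do not already look
                # like probabilities (negative values or a sum far from 1), as the PyTorch path does
                if np.any(y < 0) or not np.isclose(y.sum(), 1.0, atol=1e-3):
                    e = np.exp(y - np.max(y))
                    y = e / e.sum()
                if class_names and len(class_names) == y.shape[-1]:
                    return {cls: float(p) for cls, p in zip(class_names, y)}
                return {f"class_{i}": float(p) for i, p in enumerate(y)}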
            except Exception as e:
                return {f"error: {e}": 1.0}

        # PyTorch
        try:
            return _torch_predict(model, img, input_size, class_names)
        except Exception as e:
            return {f"error: {e}": 1.0}

    predict_btn.click(
        on_predict,
        inputs=[image_in, framework_state, model_state, input_size_state, channels_state,
                class_names_state, preprocess_fn_state, backbone_used_state, has_internal_pre_state],
        outputs=[label_out]
    )

demo.queue().launch()
