<a href="https://colab.research.google.com/github/cfoli/Multi-label-Medical-Image-Classification/blob/main/gradio_app_chestvision_PRO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --upgrade gradio

In [None]:
!pip install lightning torchmetrics

### Import dependencies

In [None]:
import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch.utils.data import random_split
import pytorch_lightning as torch_light
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import torchmetrics
from torchmetrics import Metric
import os
import shutil
import subprocess
import pandas as pd
from PIL import Image
import gradio
from functools import partial

### Set parameters

In [None]:

# Global settings shared by preprocessing, the model and the Gradio app.
configs = dict(
    IMAGE_SIZE=(224, 224),           # target image size, stored as (W, H)
    NUM_CHANNELS=3,                  # RGB input
    NUM_CLASSES=15,                  # number of output labels

    # Per-channel ImageNet statistics expected by the pretrained backbones
    MEAN=(0.485, 0.456, 0.406),
    STD=(0.229, 0.224, 0.225),

    DEFAULT_BACKBONE="ViT-base-16",  # must be a key of MODEL_REGISTRY
    THRESHOLD=0.5,                   # probability at/above which a label is reported
)

# Display name -> Hugging Face Hub repository id passed to AutoModel.from_pretrained
MODEL_REGISTRY = {
    "CheXFormer-small": "m42-health/CXformer-small",
    "ViT-base-16":      "google/vit-base-patch16-224",
}

# Loaded models keyed by backbone display name; filled lazily at inference time
MODEL_CACHE = dict()


### Define helper functions

In [None]:
# helper function for loading pre-trained model
# ===================================================================================================
class get_pretrained_model(nn.Module):
    """Pretrained Hugging Face vision backbone topped with a new classification head.

    The backbone is frozen entirely, then (optionally) its last ``num_layers_to_unfreeze``
    encoder layers are made trainable again. The head maps the embedding at token
    position 0 (the CLS token) of the backbone's last hidden state to ``num_classes`` logits.

    Args:
        model_name: key of ``MODEL_REGISTRY``; its value is the repo id handed to
            ``AutoModel.from_pretrained``.
        num_classes: number of output logits.
        num_layers_to_unfreeze: how many trailing encoder layers stay trainable
            (0 keeps the whole backbone frozen).
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        num_layers_to_unfreeze: int = 0):
        super().__init__()

        # ``AutoModel`` is not imported by the notebook's import cell; import it here so that
        # constructing this module does not fail with a NameError.
        from transformers import AutoModel

        print(f"Loading pretrained [{model_name}] model")

        self.backbone = AutoModel.from_pretrained(
            MODEL_REGISTRY[model_name],
            # Lets Hub repositories supply their own model code (which is executed locally).
            trust_remote_code=True)

        hidden_size = self.backbone.config.hidden_size

        # Freeze entire backbone first
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Selectively unfreeze last N layers
        if num_layers_to_unfreeze > 0:
            self._unfreeze_last_n_layers(num_layers_to_unfreeze)

        # Single classification head on top of the CLS embedding
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.4),
            nn.Linear(hidden_size, num_classes) )

    def forward(self, x):
        """Return classification logits for the batch ``x``.

        ``x`` is passed straight to the backbone; presumably a (batch, C, H, W) pixel
        tensor as produced by ``preprocess_fxn`` plus a batch dimension — TODO confirm.
        """
        outputs = self.backbone(x)

        # Use the CLS token (sequence position 0) as the image embedding
        img_embeddings = outputs.last_hidden_state[:, 0]

        logits = self.classifier(img_embeddings)
        return logits

    def _unfreeze_last_n_layers(self, n: int):
        """Set ``requires_grad=True`` on the last ``n`` encoder layers of the backbone.

        ``n`` is clamped to ``[0, number of layers]``. A non-positive ``n`` is a no-op:
        the slice ``encoder_layers[-0:]`` would otherwise select *every* layer.

        Raises:
            ValueError: if the backbone exposes neither ``encoder.layer`` nor
                ``vision_model.encoder.layer``.
        """
        if hasattr(self.backbone, "encoder"):
            encoder_layers = self.backbone.encoder.layer
        elif hasattr(self.backbone, "vision_model"):
            encoder_layers = self.backbone.vision_model.encoder.layer
        else:
            raise ValueError("Cannot find encoder layers in backbone.")

        total_layers = len(encoder_layers)
        n = max(0, min(n, total_layers))

        print(f"Unfreezing last {n} of {total_layers} transformer layers.")

        if n == 0:
            return

        for layer in encoder_layers[-n:]:
            for param in layer.parameters():
                param.requires_grad = True


# helper function for preprocessing input images
# ===================================================================================================
# PIL image -> normalized float tensor of shape (3, H, W) ready for the backbone.
preprocess_fxn = transforms.Compose([
    # Uploaded PNGs may be single-channel ("L") or RGBA; the 3-value MEAN/STD used by
    # the in-place Normalize below only fit a 3-channel tensor, so force RGB first.
    transforms.Lambda(lambda img: img.convert("RGB")),
    # IMAGE_SIZE is stored as (W, H) while Resize expects (H, W), hence the reversal.
    transforms.Resize(size=configs["IMAGE_SIZE"][::-1]),
    transforms.ToTensor(),
    transforms.Normalize(configs["MEAN"], configs["STD"], inplace=True),
])

# Map numeric output indices (model logit positions) to string labels
labels_dict = dict(enumerate([
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Mass",
    "No finding",
    "Nodule",
    "Pleural_Thickening",
    "Pneumonia",
    "Pneumothorax",
]))


### Create the PyTorch Lightning model (classifier) module

In [None]:
class modelModule(torch_light.LightningModule):
    """Lightning wrapper that trains/evaluates ``get_pretrained_model`` as a multi-label classifier.

    Each label gets an independent sigmoid output (``BCEWithLogitsLoss``). Loss, accuracy,
    F1, AUROC and mAP are logged per epoch for the train, validation and test stages.

    Args:
        num_classes: number of labels (= output logits).
        backbone_model_name: ``MODEL_REGISTRY`` key forwarded to ``get_pretrained_model``.
        num_layers_to_unfreeze: number of trailing backbone layers left trainable.
        learning_rate: Adam learning rate (defaults to the previously hard-coded 3e-5).
        threshold: probability cut-off used by the accuracy and F1 metrics (default 0.5).
    """

    def __init__(self, num_classes, backbone_model_name, num_layers_to_unfreeze,
                 learning_rate=3e-5, threshold=0.5):
        super().__init__()
        # Record the constructor arguments in ``self.hparams`` and in saved checkpoints so that
        # ``modelModule.load_from_checkpoint(path)`` can rebuild the module on its own instead
        # of every caller having to re-supply every argument.
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.backbone_model_name = backbone_model_name
        self.num_layers_to_unfreeze = num_layers_to_unfreeze
        self.learning_rate = learning_rate
        self.threshold = threshold

        # Pretrained backbone topped with a freshly initialised classification head
        self.model = get_pretrained_model(
            num_classes=self.num_classes,
            model_name=self.backbone_model_name,
            num_layers_to_unfreeze=self.num_layers_to_unfreeze)

        # Multi-label loss: per-label sigmoid + binary cross-entropy, applied to raw logits
        self.loss_function = torch.nn.BCEWithLogitsLoss()

        # Metric settings:
        #   average="weighted" -> per-label scores averaged with weights proportional to each
        #                         label's support ("macro" = plain mean, "micro" = pooled counts).
        #   threshold          -> probability cut-off turning predictions into 0/1 (accuracy, F1).
        #   thresholds=10      -> AUROC / average precision are computed on 10 evenly spaced
        #                         thresholds in [0, 1] (binned: constant memory, approximate).
        self.accuracy_function  = torchmetrics.classification.MultilabelAccuracy(num_labels=self.num_classes, average="weighted", threshold=self.threshold)
        self.f1_score_function  = torchmetrics.classification.MultilabelF1Score(num_labels=self.num_classes, average="weighted", threshold=self.threshold)
        self.auroc_function     = torchmetrics.classification.MultilabelAUROC(num_labels=self.num_classes, average="weighted", thresholds=10)
        self.map_score_function = torchmetrics.classification.MultilabelAveragePrecision(num_labels=self.num_classes, average="weighted", thresholds=10)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits from the wrapped backbone + head."""
        return self.model(x)

    def _common_step(self, batch, batch_idx):
        """Shared logic for train / val / test steps.

        Returns:
            (loss, logits, targets, accuracy, f1_score, auroc, mAP) for the batch.
        """
        x, y = batch

        # Model outputs: raw logits and the per-label probabilities derived from them
        y_logits = self.forward(x)
        y_prob = torch.sigmoid(y_logits)

        # Loss on raw logits (BCEWithLogitsLoss applies the sigmoid internally)
        loss = self.loss_function(y_logits, y.float())

        # Batch-level metrics on probabilities vs. targets
        accuracy = self.accuracy_function(y_prob, y)
        f1_score = self.f1_score_function(y_prob, y)
        auroc    = self.auroc_function(y_prob, y)
        mAP      = self.map_score_function(y_prob, y)  # mean average precision

        return loss, y_logits, y, accuracy, f1_score, auroc, mAP

    def training_step(self, batch, batch_idx):
        """Log epoch-level training metrics and return the loss for backprop."""
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)

        self.log_dict(
            {"train_loss": loss, "train_accuracy": accuracy, "train_f1_score": f1_score, "train_auroc": auroc, "train_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

        # Lightning expects the "loss" key for backprop
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation metrics."""
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)

        self.log_dict(
            {"val_loss": loss, "val_accuracy": accuracy,"val_f1_score": f1_score, "val_auroc": auroc, "val_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log epoch-level test metrics."""
        loss, y_logits, y, accuracy, f1_score, auroc, mAP = self._common_step(batch, batch_idx)

        self.log_dict(
            {"test_loss": loss, "test_accuracy": accuracy,"test_f1_score": f1_score, "test_auroc": auroc, "test_mAP": mAP},
            on_step=False, on_epoch=True, prog_bar=True)

    def predict_step(self, batch, batch_idx):
        """Prediction logic used by ``trainer.predict()``.

        Accepts either a bare input batch or an ``(x, ...)`` tuple/list and returns
        per-label sigmoid probabilities (no loss is computed).
        """
        x = batch if not isinstance(batch, (tuple, list)) else batch[0]
        logits = self.forward(x)

        # Convert logits to probabilities for inference
        return torch.sigmoid(logits)

    def configure_optimizers(self):
        """Adam over ``self.parameters()`` with the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)


### Create the inference function (assistive medical diagnosis)

In [None]:
@torch.inference_mode()
def run_diagnosis(
    backbone_name,
    input_image,
    preprocess_fn=None,
    Idx2labels=None,
    threshold=configs["THRESHOLD"]):
    """Classify one chest X-ray with the selected backbone.

    Args:
        backbone_name: key of ``MODEL_REGISTRY`` (e.g. "ViT-base-16").
        input_image: image from the Gradio ``Image(type="pil")`` input.
        preprocess_fn: callable turning ``input_image`` into an un-batched tensor.
        Idx2labels: mapping from output index to label name.
        threshold: probability at or above which a label is reported.

    Returns:
        Tuple of (newline-joined labels with probability >= threshold,
        {label: probability} for every output).

    Raises:
        ValueError: no image was supplied, or ``preprocess_fn`` / ``Idx2labels`` is missing.
        FileNotFoundError: the checkpoint for ``backbone_name`` is not in ``CKPT_ROOT``.
    """
    # Gradio's Image input yields None when nothing has been uploaded
    if input_image is None:
        raise ValueError("Please load a chest X-ray image first.")
    if preprocess_fn is None or Idx2labels is None:
        raise ValueError("run_diagnosis requires both preprocess_fn and Idx2labels.")

    # Preprocess into a single-image batch
    x = preprocess_fn(input_image).unsqueeze(0)

    # Resolve the checkpoint. MODEL_REGISTRY maps each name to a Hub repo id (a plain str),
    # so it carries no checkpoint file name (indexing it with "ckpt" raised TypeError).
    # NOTE(review): "<backbone_name>.ckpt" is an assumed naming convention — confirm it
    # matches the checkpoint files actually placed in CKPT_ROOT.
    ckpt_path = os.path.join(CKPT_ROOT, f"{backbone_name}.ckpt")

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # Load model once per backbone and reuse it (cache for speed)
    if backbone_name not in MODEL_CACHE:
        # Map weights onto the available device so a GPU-trained checkpoint also loads on a
        # CPU-only runtime.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Build the model outside inference mode so its parameters are ordinary tensors
        # rather than inference tensors.
        with torch.inference_mode(False):
            MODEL_CACHE[backbone_name] = modelModule.load_from_checkpoint(
                ckpt_path,
                map_location=device,
                num_classes=configs["NUM_CLASSES"],
                # get_pretrained_model expects the MODEL_REGISTRY key itself
                backbone_model_name=backbone_name,
                # Only toggles requires_grad on backbone layers; irrelevant for inference
                num_layers_to_unfreeze=2)
    model = MODEL_CACHE[backbone_name]

    model.eval()

    # Forward on the model's device, then bring probabilities back to the CPU
    logits = model(x.to(model.device))
    probs = torch.sigmoid(logits)[0].cpu().numpy()

    output_probs = {
        Idx2labels[i]: float(p) for i, p in enumerate(probs)
    }

    predicted_classes = [
        Idx2labels[i] for i, p in enumerate(probs) if p >= threshold
    ]

    return "\n".join(predicted_classes), output_probs


### Gradio app

In [None]:

# Example inputs: every .png in the working directory (sorted for a stable order), each
# paired with the default backbone. Curated samples could instead live in a dedicated
# sub-folder such as "Curated test samples".
example_list_img_names = sorted(os.listdir(os.getcwd()))

# Directory in which run_diagnosis looks for checkpoint files
CKPT_ROOT = os.getcwd()

example_list = [
    [configs["DEFAULT_BACKBONE"], os.path.join(os.getcwd(), example_img)]
    for example_img in example_list_img_names
    if example_img.lower().endswith(".png")]

gradio_app = gradio.Interface(
    fn=partial(run_diagnosis, preprocess_fn=preprocess_fxn, Idx2labels=labels_dict, threshold=configs["THRESHOLD"]),

    # Offer exactly the backbones known to MODEL_REGISTRY, defaulting to the configured one
    inputs=[gradio.Dropdown(list(MODEL_REGISTRY), value=configs["DEFAULT_BACKBONE"], label="Select Backbone Model"),
            gradio.Image(type="pil", label="Load chest-X-ray image here")],

    outputs=[gradio.Textbox(label="Predicted Medical Conditions"),
             gradio.Label(label="Predicted Probabilities", show_label=False)],

    # With no .png files present, pass None instead of an empty list and skip caching
    # (caching runs the prediction function on every example at launch).
    examples=example_list or None,
    cache_examples=bool(example_list),
    title="ChestVision",
    description="Vision-Transformer solutions for assistive medical diagnosis with Vision-Language-based prediction justification",
    article="Author: C. Foli (02.2026) | Website: coming soon...")

gradio_app.launch()


Caching examples at: '/content/.gradio/cached_examples/497'
It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://a01ce5b19a6adc4294.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


