<a href="https://colab.research.google.com/github/nabincool20/colab_glaucoma/blob/main/Models_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the checkpoint paths under /content/drive/MyDrive are readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
import pandas as pd
from PIL import Image
import gradio as gr
from torchvision import transforms, models
from torchvision.models import resnet18
from transformers import ViTForImageClassification, ViTImageProcessor
import torch.nn as nn
import torch.nn.functional as F
import timm

# ====== Checkpoint locations on the mounted Google Drive ======
# Each key names one trained model and is what the UI dropdowns list. Keys that
# contain "resnet" are loaded/predicted with the ResNet-18 helpers; every other
# key is treated as a ViT checkpoint.
MODEL_PATHS = dict(
    resnet_fundus_acrima="/content/drive/MyDrive/trained_models_all/trained_resnet_models/acrima_resnet18.pth",
    resnet_fundus_origa="/content/drive/MyDrive/trained_models_all/trained_resnet_models/origa_resnet18.pth",
    resnet_fundus_rimone="/content/drive/MyDrive/trained_models_all/trained_resnet_models/rimone_resnet18.pth",
    vit_fundus_acrima="/content/drive/MyDrive/trained_models_all/trained_vit_models/acrima_label_vit.pth",
    vit_fundus_origa="/content/drive/MyDrive/trained_models_all/trained_vit_models/origa_label_vit.pth",
    vit_fundus_rimone="/content/drive/MyDrive/trained_models_all/trained_vit_models/rimone_label_vit.pth",
    resnet_oct="/content/drive/MyDrive/trained_models_all/resnet_oct.pth",
    vit_oct="/content/drive/MyDrive/trained_models_all/vit_oct.pth",
)

# ====== Load ResNet18 model function ======
def load_resnet_model(path, model_key):
    """Build a 2-class ResNet-18 and load the state_dict stored at *path*.

    The classifier head depends on *model_key*: keys containing "origa" or
    "rimone" get ``Sequential(Dropout(0.4), Linear(512 -> 2))``; every other key
    (including "acrima" and the OCT model) gets a plain ``Linear(512 -> 2)``.
    The head layout determines the state_dict key names (``fc.1.*`` versus
    ``fc.*``), so it must match the checkpoint for the strict load to succeed.

    Args:
        path: Path of a ``torch.save``-d state_dict.
        model_key: Model name; only used to pick the head layout.

    Returns:
        The model on CPU, in eval mode.
    """
    # weights=None replaces the deprecated ``pretrained=False``: random init, no download.
    model = resnet18(weights=None)
    in_features = model.fc.in_features

    if "origa" in model_key or "rimone" in model_key:
        # NOTE(review): presumably these checkpoints were trained with this
        # dropout head; the Sequential yields fc.1.weight / fc.1.bias keys.
        model.fc = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(in_features, 2),
        )
    else:
        model.fc = nn.Linear(in_features, 2)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs.
    state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# ====== Load ViT model function ======


def load_vit_model(path):
    """Build a 2-class timm ``vit_base_patch16_224`` and load the state_dict at *path*.

    No pretrained weights are fetched. Every parameter comes from the
    checkpoint, which must match this architecture exactly because
    ``load_state_dict`` is strict by default.

    Args:
        path: Path of a ``torch.save``-d state_dict.

    Returns:
        The model on CPU, in eval mode.
    """
    model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=2)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs.
    state_dict = torch.load(path, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# ====== Image transforms ======
# Both model families use the same preprocessing:
# resize to 224x224, convert to a [0, 1] tensor, then normalise every channel
# with mean 0.5 / std 0.5. Two independent Compose objects are built so either
# pipeline can be changed later without affecting the other.
resnet_transform, vit_transform = (
    transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ]
    )
    for _ in range(2)
)

# NOTE(review): not referenced by predict_vit (which uses vit_transform) or by
# anything else in this cell; loading it presumably fetches the processor
# config from the Hugging Face Hub. Confirm it is needed before keeping it.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# ====== Prediction function for each model type ======
def _classify(image, model, transform):
    """Preprocess *image* with *transform*, run *model*, and map the argmax to a label.

    A batch axis is prepended with ``unsqueeze(0)``. The model output is
    treated as ``(batch, num_classes)`` logits: softmax and argmax run over
    dim 1. Class index 0 maps to "Glaucoma" and any other index maps to
    "Non_Glaucoma".

    NOTE(review): assumes index 0 is the glaucoma class in every checkpoint's
    training label encoding. Confirm this per dataset.
    """
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
    return "Glaucoma" if pred == 0 else "Non_Glaucoma"


def predict_resnet(image, model, transform=None):
    """Classify one PIL image with a ResNet model.

    Args:
        image: Input image (callers pass RGB-converted PIL images).
        model: A loaded ResNet classifier.
        transform: Optional preprocessing callable; defaults to ``resnet_transform``.

    Returns:
        "Glaucoma" or "Non_Glaucoma".
    """
    return _classify(image, model, resnet_transform if transform is None else transform)


def predict_vit(image, model, transform=None):
    """Classify one PIL image with a ViT model.

    Args:
        image: Input image (callers pass RGB-converted PIL images).
        model: A loaded ViT classifier.
        transform: Optional preprocessing callable; defaults to ``vit_transform``.

    Returns:
        "Glaucoma" or "Non_Glaucoma".
    """
    return _classify(image, model, vit_transform if transform is None else transform)

# ====== Unified prediction wrapper ======
def _build_registry(paths):
    """Load every checkpoint in *paths* and pair it with its prediction function.

    Keys containing 'resnet' are loaded with ``load_resnet_model`` and predicted
    with ``predict_resnet``. All other keys use ``load_vit_model`` and
    ``predict_vit``. Checkpoints load in the iteration order of *paths*.

    Returns:
        ``(models, predictors)``: two dicts keyed exactly like *paths*.
    """
    models = {}
    predictors = {}
    for key, path in paths.items():
        if 'resnet' in key:
            models[key] = load_resnet_model(path, key)
            predictors[key] = predict_resnet
        else:
            models[key] = load_vit_model(path)
            predictors[key] = predict_vit
    return models, predictors


MODELS, PREDICTORS = _build_registry(MODEL_PATHS)

# ====== Function to process a folder of images with all models ======
def process_all_models(folder_path):
    """Run every model in ``MODELS`` on each image under *folder_path* and save an Excel report.

    The folder is walked recursively. Files ending in .jpg/.jpeg/.png
    (case-insensitive) are classified. Each row holds ``filename`` (the path
    relative to *folder_path*) plus one column per model key containing that
    model's label. If opening the image or running any model fails, the row
    instead holds ``filename`` and an ``error`` message.

    Args:
        folder_path: Directory to scan. The report is written into it.

    Returns:
        Path of the written ``combined_model_predictions.xlsx``.

    Raises:
        NotADirectoryError: If *folder_path* is not an existing directory.
    """
    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f"Not a directory: {folder_path!r}")

    results = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if not file.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            file_path = os.path.join(root, file)
            rel_path = os.path.relpath(file_path, folder_path)
            try:
                # Close the file handle promptly; convert() returns a loaded copy.
                with Image.open(file_path) as img:
                    image = img.convert("RGB")
                row = {"filename": rel_path}
                for model_name, model in MODELS.items():
                    row[model_name] = PREDICTORS[model_name](image, model)
            except Exception as e:  # best effort per file: record the failure and continue
                row = {"filename": rel_path, "error": str(e)}
            results.append(row)

    df = pd.DataFrame(results)
    output_excel = os.path.join(folder_path, "combined_model_predictions.xlsx")
    df.to_excel(output_excel, index=False)
    return output_excel

# ====== Gradio Interface ======
# Lower-case file suffixes that the image-scanning handlers below accept.
_IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")


def _upload_path(file):
    """Return the filesystem path for a value produced by a ``gr.File`` input.

    The value may be an object exposing the path as ``.name`` or the path
    string itself (this varies across Gradio versions); both are handled.
    """
    return getattr(file, "name", file)


def _lookup_model(model_key):
    """Return ``(model, predictor)`` for *model_key*.

    Raises:
        gr.Error: If no model (or an unknown one) is selected, so the UI shows
            a readable message instead of a raw ``KeyError``.
    """
    if model_key not in MODELS:
        raise gr.Error("Please select a model.")
    return MODELS[model_key], PREDICTORS[model_key]


def _predict_folder(root_dir, model, predictor):
    """Classify every image under *root_dir* (recursively) and return a DataFrame.

    Each row has ``filename`` (the path relative to *root_dir*) plus either
    ``prediction`` or, if opening or classifying the file failed, ``error``.
    """
    rows = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if not file.lower().endswith(_IMAGE_SUFFIXES):
                continue
            file_path = os.path.join(root, file)
            rel_path = os.path.relpath(file_path, root_dir)
            try:
                # Close the file handle promptly; convert() returns a loaded copy.
                with Image.open(file_path) as img:
                    image = img.convert("RGB")
                rows.append({"filename": rel_path, "prediction": predictor(image, model)})
            except Exception as e:  # best effort: record the failure and keep scanning
                rows.append({"filename": rel_path, "error": str(e)})
    return pd.DataFrame(rows)


with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Unified Glaucoma Detection Tool")

    with gr.Tab("1️⃣ Upload Single/Multiple Images"):
        with gr.Row():
            file_input = gr.File(label="Upload Image(s)", file_types=[".jpg", ".png", ".jpeg"], file_count="multiple")
            model_choice_1 = gr.Dropdown(choices=list(MODELS.keys()), label="Select Model")
        output_text_1 = gr.Textbox(label="Predictions")

        def predict_uploaded_images(files, model_key):
            """Classify each uploaded image and return one "name → label" line per file.

            Returns a short message instead if nothing was uploaded or no model is selected.
            Per-file failures are reported inline as "name → Error: ...".
            """
            if not files:
                return "No images uploaded."
            if model_key not in MODELS:
                return "Please select a model."
            model = MODELS[model_key]
            predictor = PREDICTORS[model_key]
            results = []
            for file in files:
                path = _upload_path(file)
                name = os.path.basename(path)
                try:
                    with Image.open(path) as img:
                        image = img.convert("RGB")
                    results.append(f"{name} → {predictor(image, model)}")
                except Exception as e:
                    results.append(f"{name} → Error: {e}")
            return "\n".join(results)

        predict_btn_1 = gr.Button("Run Prediction")
        predict_btn_1.click(predict_uploaded_images, inputs=[file_input, model_choice_1], outputs=output_text_1)

    with gr.Tab("2️⃣ Upload Folder"):
        folder_input = gr.File(label="Upload Folder (ZIP)", file_types=[".zip"])
        model_choice_2 = gr.Dropdown(choices=list(MODELS.keys()), label="Select Model")
        folder_output = gr.File(label="Excel Output")

        def process_uploaded_zip(zip_file, model_key):
            """Extract an uploaded ZIP, classify its images, and return the Excel report path.

            The extracted files are deleted afterwards. The report is written to
            a separate fresh temp directory because Gradio reads the returned
            file after this function exits.

            Raises:
                gr.Error: If no ZIP was uploaded, no model is selected, or the
                    upload is not a valid ZIP archive.
            """
            import shutil
            import tempfile
            import zipfile

            if zip_file is None:
                raise gr.Error("Please upload a ZIP file.")
            model, predictor = _lookup_model(model_key)

            extract_dir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(_upload_path(zip_file), 'r') as zip_ref:
                    zip_ref.extractall(extract_dir)
                df = _predict_folder(extract_dir, model, predictor)
            except zipfile.BadZipFile as err:
                raise gr.Error(f"Not a valid ZIP archive: {err}") from err
            finally:
                shutil.rmtree(extract_dir, ignore_errors=True)

            output_excel = os.path.join(tempfile.mkdtemp(), "predictions_from_folder.zip.xlsx")
            df.to_excel(output_excel, index=False)
            return output_excel

        predict_btn_2 = gr.Button("Run Prediction on Folder")
        predict_btn_2.click(process_uploaded_zip, inputs=[folder_input, model_choice_2], outputs=folder_output)

    with gr.Tab("3️⃣ Enter Folder Path (Google Drive/Colab)"):
        path_input = gr.Textbox(label="Enter Folder Path", placeholder="/content/drive/MyDrive/OCT_Split/test")
        model_choice_3 = gr.Dropdown(choices=list(MODELS.keys()), label="Select Model")
        path_output = gr.File(label="Excel Output")

        def process_by_path(folder_path, model_key):
            """Classify every image under a local folder and write the Excel report into it.

            Surrounding whitespace in the typed path is ignored.

            Raises:
                gr.Error: If the path is not an existing directory or no model is selected.
            """
            folder_path = (folder_path or "").strip()
            if not os.path.isdir(folder_path):
                raise gr.Error(f"Folder not found: {folder_path!r}")
            model, predictor = _lookup_model(model_key)
            df = _predict_folder(folder_path, model, predictor)
            output_excel = os.path.join(folder_path, "model_predictions_from_path.xlsx")
            df.to_excel(output_excel, index=False)
            return output_excel

        predict_btn_3 = gr.Button("Run Prediction on Path")
        predict_btn_3.click(process_by_path, inputs=[path_input, model_choice_3], outputs=path_output)

demo.launch()





It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://337b2fcd760223f716.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


