In [28]:
import csv
import json
import os
import shutil
import tempfile
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
import supervision as sv

In [29]:
from ultralytics import YOLO
import matplotlib.pyplot as plt
from PIL import Image

In [30]:
# YOLO segmentation model. The relative path is resolved against the kernel's working
# directory (normally the notebook's folder), so keep the custom .pt file there.
YOLO_WEIGHTS_PATH = 'YOLOv12_trained_weights.pt'

print("Loading YOLO model...")
try:
    yolo_model = YOLO(YOLO_WEIGHTS_PATH)
    print("Model loaded successfully.")
    print("Loaded class names:", yolo_model.names)
except Exception as e:
    # Deliberate best-effort: keep the notebook usable without weights.
    # process_and_zip checks for `yolo_model is None` and reports it in the UI.
    print(f"Error loading model: {e}")
    yolo_model = None

# --- 2. Analysis Function ---
# NOTE: the RF-DETR cell further down defines analyze_segmentation_mask again; on a
# top-to-bottom run that later definition is the one process_and_zip calls. Both are kept
# behaviourally consistent.
def analyze_segmentation_mask(mask, microns_per_pixel, number_of_detections):
    """Summarize a stomata segmentation mask as display-ready statistics.

    Args:
        mask: 2-D mask, or an image with 1, 3 (BGR) or 4 (BGRA) channels. Pixels > 127 are
            foreground; for a boolean mask, True is foreground. ``None`` yields an error dict.
        microns_per_pixel: physical edge length of one pixel in microns; must be positive.
        number_of_detections: number of detected stomata. It is used directly as the object
            count; the mask is NOT split into connected components.

    Returns:
        dict of label -> string (all strings so the Gradio JSON view is uniform). The same
        keys are returned whether or not anything was detected. On invalid input the dict
        contains a single "Error" key.
    """
    if mask is None:
        return {"Error": "No image or mask provided for analysis."}
    if microns_per_pixel is None or microns_per_pixel <= 0:
        return {"Error": "Microns per pixel must be a positive number."}

    mask = np.asarray(mask)
    if mask.ndim == 3:
        channels = mask.shape[2]
        if channels == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        elif channels == 4:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGRA2GRAY)
        elif channels == 1:
            mask = mask[:, :, 0]
    if mask.ndim != 2 or mask.size == 0:
        return {"Error": f"Expected a non-empty 2-D mask, got shape {mask.shape}."}

    height, width = mask.shape
    total_image_pixels = height * width
    area_per_pixel_sq_microns = microns_per_pixel ** 2
    total_image_area_sq_microns = total_image_pixels * area_per_pixel_sq_microns
    object_count = number_of_detections

    if object_count <= 0:
        return {
            "Image Dimensions": f"{width}x{height}",
            "Stomata Count": str(0),
            "Average Stomata Area (sq microns)": "N/A",
            "Total Image Area (sq microns)": f"{total_image_area_sq_microns:.2f}",
            "Density Percentage": "0.00%",
            "Stomatal Density": f"0 per {total_image_area_sq_microns:.2f} sq microns",
            "Microns per Pixel": str(microns_per_pixel),
        }

    # Same foreground rule as cv2.threshold(mask, 127, 255, THRESH_BINARY), but valid for
    # any numeric dtype (cv2 rejects e.g. int64/float64) and for boolean masks.
    foreground = mask if mask.dtype == bool else mask > 127
    object_pixels_count = int(np.count_nonzero(foreground))

    # Average area = merged foreground area / detection count; overlapping instances in a
    # merged mask are therefore only counted once.
    average_object_area_sq_microns = object_pixels_count * area_per_pixel_sq_microns / object_count
    density = object_pixels_count / total_image_pixels * 100

    return {
        "Image Dimensions": f"{width}x{height}",
        "Stomata Count": str(object_count),
        "Average Stomata Area (sq microns)": f"{average_object_area_sq_microns:.2f}",
        "Total Image Area (sq microns)": f"{total_image_area_sq_microns:.2f}",
        "Density Percentage": f"{density:.2f}%",
        "Stomatal Density": f"{object_count} per {total_image_area_sq_microns:.2f} sq microns",
        "Microns per Pixel": str(microns_per_pixel),
    }

Loading YOLO model...
Model loaded successfully.
Loaded class names: {0: 'stomata'}


In [31]:
# RF-DETR detector (optional). Within this notebook process_and_zip only uses yolo_model,
# so a failed load here only prints a message and leaves model_a as None.
RFDETR_WEIGHTS_PATH = "./RF_DETR_trained_weights.pth"
try:
    from rfdetr import RFDETRMedium
    # NOTE(review): the saved output of this cell reports a position-embedding size mismatch
    # (checkpoint 1025 = 32*32+1 tokens vs. model 1297 = 36*36+1) and "pretrain weights has
    # 0 classes". That suggests the checkpoint was trained with a different RF-DETR
    # variant/resolution (and class count) than RFDETRMedium's defaults. Confirm which
    # configuration produced the weights before relying on model_a.
    model_a = RFDETRMedium(pretrain_weights=RFDETR_WEIGHTS_PATH)

except ImportError:
    print("rfdetr library not found. Please install it to run the models.")
    model_a = None

except Exception as e:
    print(f"Error loading model weights: {e}")
    # There is no mock/fallback model; anything using model_a must handle None.
    print("RF-DETR model unavailable; model_a is set to None.")
    model_a = None


# Class-id -> name lookup, keyed by model/task name.
CLASS_NAMES_MAP = {
    "stomata": {
        0: 'stomata'
        # Add further class ids here if the model is retrained with more classes.
    }
}

# Shared annotators. process_and_zip creates its own Mask/Label annotators, and these
# module-level ones are not referenced elsewhere in the visible notebook.
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
dot_annotator = sv.DotAnnotator(radius=5, color=sv.Color.RED)


# --- Mask analysis (redefinition) ---
# This definition replaces the one from the YOLO cell and is the version process_and_zip
# uses on a top-to-bottom run.
def analyze_segmentation_mask(mask, microns_per_pixel, number_of_detections):
    """Summarize a stomata segmentation mask as display-ready statistics.

    Args:
        mask: 2-D mask, or an image with 1, 3 (BGR) or 4 (BGRA) channels. Pixels > 127 are
            foreground; for a boolean mask, True is foreground. ``None`` yields an error dict.
        microns_per_pixel: physical edge length of one pixel in microns; must be positive
            (a cleared Gradio Number field arrives as None and is reported, not crashed on).
        number_of_detections: number of detected stomata. It is used directly as the object
            count; the mask is NOT split into connected components.

    Returns:
        dict of label -> string (all strings for Gradio compatibility). The same keys are
        returned whether or not anything was detected. On invalid input the dict contains a
        single "Error" key.
    """
    if mask is None:
        return {"Error": "No image or mask provided for analysis."}
    if microns_per_pixel is None or microns_per_pixel <= 0:
        return {"Error": "Microns per pixel must be a positive number."}

    # Reduce colour / singleton-channel inputs to a single 2-D plane.
    mask = np.asarray(mask)
    if mask.ndim == 3:
        channels = mask.shape[2]
        if channels == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        elif channels == 4:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGRA2GRAY)
        elif channels == 1:
            mask = mask[:, :, 0]
    if mask.ndim != 2 or mask.size == 0:
        return {"Error": f"Expected a non-empty 2-D mask, got shape {mask.shape}."}

    height, width = mask.shape
    total_image_pixels = height * width
    area_per_pixel_sq_microns = microns_per_pixel ** 2
    total_image_area_sq_microns = total_image_pixels * area_per_pixel_sq_microns
    object_count = number_of_detections

    # No detections: report zeros, without touching the mask contents.
    if object_count <= 0:
        return {
            "Image Dimensions": f"{width}x{height}",
            "Stomata Count": str(0),
            "Average Stomata Area (sq microns)": "N/A",
            "Total Image Area (sq microns)": f"{total_image_area_sq_microns:.2f}",
            "Density Percentage": "0.00%",
            "Stomatal Density": f"0 per {total_image_area_sq_microns:.2f} sq microns",
            "Microns per Pixel": str(microns_per_pixel),
        }

    # Same foreground rule as cv2.threshold(mask, 127, 255, THRESH_BINARY), but valid for
    # any numeric dtype (cv2 rejects e.g. int64/float64) and for boolean masks.
    foreground = mask if mask.dtype == bool else mask > 127
    object_pixels_count = int(np.count_nonzero(foreground))

    # Average area = merged foreground area / detection count; overlapping instances in a
    # merged mask are therefore only counted once.
    average_object_area_sq_microns = object_pixels_count * area_per_pixel_sq_microns / object_count
    density = object_pixels_count / total_image_pixels * 100

    return {
        "Image Dimensions": f"{width}x{height}",
        "Stomata Count": str(object_count),
        "Average Stomata Area (sq microns)": f"{average_object_area_sq_microns:.2f}",
        "Total Image Area (sq microns)": f"{total_image_area_sq_microns:.2f}",
        "Density Percentage": f"{density:.2f}%",
        "Stomatal Density": f"{object_count} per {total_image_area_sq_microns:.2f} sq microns",
        "Microns per Pixel": str(microns_per_pixel),
    }

Using a different number of positional encodings than DINOv2, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.
Using patch size 16 instead of 14, which means we're not loading DINOv2 backbone weights. This is not a problem if finetuning a pretrained RF-DETR model.
Loading pretrain weights


num_classes mismatch: pretrain weights has 0 classes, but your model has 90 classes
reinitializing detection head with 0 classes


Error loading model weights: Error(s) in loading state_dict for LWDETR:
	size mismatch for backbone.0.encoder.encoder.embeddings.position_embeddings: copying a param with shape torch.Size([1, 1025, 384]) from checkpoint, the shape in current model is torch.Size([1, 1297, 384]).
A mock model will be used for demonstration purposes.


In [32]:
# --- 3. Run configuration ---
IMAGE_PATH = 'Rice_40x_447.jpg'  # sample micrograph; not referenced by process_and_zip
MICRONS_PER_PIXEL = 0.5          # default physical pixel size (microns per pixel edge)


# --- 3b. Processing & Export Logic ---
def _export_results_zip(annotated_rgb, stats):
    """Write the annotated image (JPEG) and stats (one-row CSV) to a fresh temp dir and zip them.

    Returns the path of the created .zip file.
    """
    out_dir = tempfile.mkdtemp(prefix="stomata_")
    # Files go into a sub-directory that is archived, while the zip itself is written to its
    # parent. This way make_archive never walks over the archive it is still writing.
    payload_dir = os.path.join(out_dir, "results")
    os.makedirs(payload_dir)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    img_path = os.path.join(payload_dir, f"annotated_{timestamp}.jpg")
    # OpenCV writes BGR; the annotated image is RGB.
    if not cv2.imwrite(img_path, cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)):
        raise IOError(f"Could not write annotated image to {img_path}")

    csv_path = os.path.join(payload_dir, f"stats_{timestamp}.csv")
    with open(csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(stats.keys())
        writer.writerow(stats.values())

    return shutil.make_archive(
        os.path.join(out_dir, f"stomata_results_{timestamp}"), 'zip', payload_dir
    )


def process_and_zip(input_image, microns_per_pixel, confidence=0.25):
    """Segment stomata in one image, compute statistics and package them for download.

    Args:
        input_image: RGB uint8 array as delivered by ``gr.Image(type="numpy")``, or None
            when the button is clicked without an upload.
        microns_per_pixel: physical pixel size forwarded to analyze_segmentation_mask.
        confidence: YOLO confidence threshold (default matches the previous hard-coded 0.25).

    Returns:
        (annotated_rgb_image, stats_dict, zip_path). On failure the image and the zip path
        are None and stats_dict contains an "Error" key.
    """
    if yolo_model is None:
        return None, {"Error": "Model not loaded"}, None
    if input_image is None:
        return None, {"Error": "No image provided"}, None

    # 1. Run inference. Ultralytics interprets NumPy input as BGR (OpenCV order) but Gradio
    #    delivers RGB, so convert first, otherwise the model sees swapped channels.
    bgr_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    results = yolo_model.predict(bgr_image, conf=confidence)[0]
    detections = sv.Detections.from_ultralytics(results)

    # 2. Merge per-instance masks into one foreground mask and compute statistics.
    if detections.mask is not None:
        combined_mask = np.any(detections.mask, axis=0).astype(np.uint8) * 255
        stats = analyze_segmentation_mask(combined_mask, microns_per_pixel, len(detections))
    elif len(detections) == 0:
        stats = {"Stomata Count": "0", "Message": "No stomata detected"}
    else:
        # Boxes without masks (e.g. a detection-only checkpoint): the count is known, the area is not.
        stats = {
            "Stomata Count": str(len(detections)),
            "Message": "Model returned no segmentation masks; area statistics unavailable",
        }

    # 3. Annotate a copy of the RGB input so the caller's array is left untouched.
    mask_annotator = sv.MaskAnnotator()
    seg_label_annotator = sv.LabelAnnotator()
    annotated_img = mask_annotator.annotate(scene=input_image.copy(), detections=detections)
    annotated_img = seg_label_annotator.annotate(scene=annotated_img, detections=detections)

    # 4. Package image + CSV into a downloadable ZIP.
    zip_path = _export_results_zip(annotated_img, stats)

    return annotated_img, stats, zip_path



In [33]:
# --- 4. Gradio Interface ---
with gr.Blocks(title="Stomata Analyzer") as demo:
    gr.Markdown("# 🌿 Stomata Segmentation & Density Analysis")
    gr.Markdown("Upload a leaf micrograph to detect stomata, calculate density, and download results.")

    with gr.Row():
        with gr.Column():
            # process_and_zip expects a 3-channel RGB array, so pin the image mode explicitly.
            input_img = gr.Image(type="numpy", image_mode="RGB", label="Input Image")
            # Default scale comes from the config constant so notebook and UI agree.
            scale_input = gr.Number(value=MICRONS_PER_PIXEL, label="Microns per Pixel")
            btn = gr.Button("Analyze Stomata", variant="primary")

        with gr.Column():
            output_img = gr.Image(label="Detections (Masks & Labels)")
            output_json = gr.JSON(label="Analysis Results")
            output_file = gr.File(label="Download ZIP (Image + CSV)")

    btn.click(
        fn=process_and_zip,
        inputs=[input_img, scale_input],
        outputs=[output_img, output_json, output_file],
    )

# Render the app inline in the notebook output; share=False keeps it on localhost only.
demo.launch(inline=True, share=False)

* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.


