In [11]:
!pip install gradio



In [12]:
import torchvision

num_classes = 3

model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=None)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(256, num_classes, 28)

In [13]:
import gradio as gr
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import torchvision


# Define the image transformation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load("mask_rcnn_model.pth", map_location=device))
model.to(device)
model.eval()

# Define the prediction function
def predict(image):
    image = Image.fromarray(image).convert("RGB")
    original_image = np.array(image)

    # Transform the image
    transformed_image = transform(image).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        outputs = model(transformed_image)[0]

    # Overlay masks on the original image
    for i, mask in enumerate(outputs['masks']):
        score = outputs['scores'][i].item()
        if score > 0.5:  # Confidence threshold
            mask = mask[0].cpu().numpy() > 0.5
            color = np.random.rand(3)
            for c in range(3):  # Apply mask to all RGB channels
                original_image[:, :, c] = np.where(mask, original_image[:, :, c] * 0.5 + color[c] * 255 * 0.5, original_image[:, :, c])

    return original_image

# Define the Gradio interface
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.Image(type="numpy", label="Predicted Image"),
    title="Mask R-CNN Object Detection",
    description="Upload an image to detect and segment objects using Mask R-CNN model."
).launch()


  model.load_state_dict(torch.load("mask_rcnn_model.pth", map_location=device))


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://b07a7e032bdfc4c094.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


