<a href="https://colab.research.google.com/github/elliemci/vision-transformer-models/blob/main/model_deployment/app.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

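# Deploying a Brain Tumor Segmentation App for MRI Images

This notebook loads the fine-tuned MaskFormer model `elliemci/maskformer_tumor_segmentation` from the Hugging Face Hub and serves it in a Gradio app: upload one or more brain MRI images (or pick an example) to see each image next to its predicted tumor mask.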

In [None]:
!pip install --upgrade gradio

In [None]:
!pip install transformers

In [None]:
!pip install gradio

In [4]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/ColabNotebooks/ExplainableAI/model_deployment

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/ColabNotebooks/ExplainableAI/model_deployment


In [11]:
import torch
import gradio as gr

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
from PIL import Image

# Hugging Face Hub repository holding both the fine-tuned MaskFormer weights
# and the matching image-processor configuration.
MODEL_ID = "elliemci/maskformer_tumor_segmentation"

# Prefer the GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# from_pretrained returns the model in evaluation mode; nn.Module.to moves it
# in place and returns the same module, so chaining keeps `model` identical.
model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_ID).to(device)
image_processor = MaskFormerImageProcessor.from_pretrained(MODEL_ID)

# Define a custom dataset class to handle images
class ImageDataset(Dataset):
    """Map-style dataset that reads images from disk on demand.

    Args:
        image_paths: Indexable sequence of image file paths.
        transform: Optional callable applied to each RGB ``PIL.Image``
            (for example ``transforms.ToTensor()``). When it is falsy, the
            RGB PIL image itself is returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return how many image paths the dataset holds."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open the ``idx``-th image, force RGB, and apply ``transform`` if set."""
        path = self.image_paths[idx]
        rgb = Image.open(path).convert('RGB')
        return self.transform(rgb) if self.transform else rgb


def segment(image_files):
    """Predict a semantic segmentation mask for each image file.

    Args:
        image_files: Sequence of image file paths. The upload widget in this
            app is created with ``type="filepath"``.

    Returns:
        list: One mask per input image, in input order, as produced by
        ``image_processor.post_process_semantic_segmentation``. Each mask is
        resized to its image's original (height, width). Empty or ``None``
        input returns an empty list.
    """
    if not image_files:
        # DataLoader rejects batch_size=0, so short-circuit empty uploads.
        return []

    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # Use one image per batch. Uploaded MRI scans may differ in size, and the
    # default collate function cannot stack tensors of different shapes.
    # NOTE(review): pixels are fed as raw [0, 1] tensors without the image
    # processor's resize/normalization. Confirm this matches how the model
    # was fine-tuned.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    predicted_masks = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)

            # Resize each prediction back to its input image's (height, width).
            target_sizes = [(image.shape[-2], image.shape[-1]) for image in batch]
            predicted_masks.extend(
                image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
            )

    return predicted_masks


def update_gallery(images):
    """Build Gallery items pairing each uploaded image with its predicted mask.

    Args:
        images: List of image file paths from the Gradio upload component, or
            ``None``/empty when the upload is cleared.

    Returns:
        list[tuple]: ``(PIL.Image, caption)`` pairs. For each input, the
        original image is followed by its mask rendered as a grayscale image.
        If rendering the mask fails, the original image is repeated with the
        error as its caption. If the image file itself cannot be opened, that
        input is skipped and the error is printed.
    """
    print(f"Images received: {images}")

    gallery_data = []
    if not images:
        return gallery_data

    segmented_images = segment(images)

    for i, image_path in enumerate(images):
        # Reset on every iteration. Otherwise the error handler would raise
        # NameError (first file) or show the previous file's image when
        # opening this one fails.
        image = None
        try:
            # The context manager closes the file handle; convert() returns
            # an independent in-memory copy.
            with Image.open(image_path) as source:
                image = source.convert("RGB")

            segmented_mask = segmented_images[i].to(dtype=torch.float32, device="cpu")
            # ToPILImage scales a float tensor by 255 into an 8-bit grayscale
            # image, so label 1 renders white.
            # NOTE(review): labels > 1 exceed the 8-bit range and will not
            # render correctly. This is only suitable for binary masks.
            segmented_image_pil = transforms.ToPILImage()(segmented_mask)

            gallery_data.extend([(image, "Original Image"), (segmented_image_pil, "Segmented Image")])
        except Exception as e:
            print(f"Error processing image {i}: {e}")
            if image is not None:
                gallery_data.extend([(image, "Original Image"), (image, f"Error: {e}")])

    return gallery_data


# Example MRI images, given as paths relative to the working directory.
EXAMPLE_IMAGES = [
    "examples/Te-me_0194.jpg", "examples/Te-me_0111.jpg",
    "examples/Te-me_0295.jpg", "examples/Te-me_0228.jpg",
    "examples/Te-me_0218.jpg", "examples/Te-me_0164.jpg",
]

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")

    with gr.Column():
        with gr.Column():
            # Multi-file upload. Every change sends the list of file paths
            # to update_gallery, which refreshes the gallery.
            image_files = gr.Files(label="Upload MRI files", file_count="multiple", type="filepath")

            with gr.Row():
                gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
                image_files.change(fn=update_gallery, inputs=[image_files], outputs=[gallery])

            with gr.Column():
                # Hidden image component. Clicking an example loads its path
                # here, where the button below can read it.
                example_image = gr.Image(type="filepath", label="MRI Image", visible=False)
                examples = gr.Examples(examples=EXAMPLE_IMAGES, inputs=[example_image])

            with gr.Column(scale=0):
                example_button = gr.Button("Process Example Image", variant="secondary")
                # With no example selected, clear the gallery. Otherwise
                # segment the single example image.
                example_button.click(
                    fn=lambda img: [] if not img else update_gallery([img]),
                    inputs=[example_image],
                    outputs=[gallery],
                )

demo.launch(debug=True)

  return func(*args, **kwargs)


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://2543c7837e632ceab6.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Images received: ['/tmp/gradio/7930ef0afa88b9b32b074b4a1f2a0ca3d42a90afbb8d14d6da1cdbdf9c3ecf7b/Te-me_0218.jpeg']
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://2543c7837e632ceab6.gradio.live


