<a href="https://colab.research.google.com/github/elliemci/vision-transformer-models/blob/main/model_deployment/mri_classification_app.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deploying an MRI classification app with Hugging Face and Gradio

In [None]:
!pip install --upgrade gradio

In [None]:
!pip install transformers

In [3]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/ColabNotebooks/ExplainableAI/model_deployment

Mounted at /content/drive
/content/drive/MyDrive/ColabNotebooks/ExplainableAI/model_deployment


## 1. Create a Gradio App that exposes the MRI brain tumor classification model

In [4]:
import gradio as gr

def update_gallery(images):
    """Return the uploaded image file paths unchanged so the gallery displays them"""
    return images

# create the Gradio Blocks app
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Classification App</h1>")
    with gr.Column():
        image_input = gr.Files(label="Upload MRI Images",
                               file_count="multiple",
                               type="filepath")

        gallery = gr.Gallery(label="MRI Brain Images")

    # update the gallery whenever the uploaded files change
    image_input.change(
        fn=update_gallery,
        inputs=[image_input],
        outputs=[gallery])

# launch the app in debug mode so errors and logs appear in the notebook
demo.launch(debug=True)

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://de2911cb3b502f8d15.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://de2911cb3b502f8d15.gradio.live




## 2. Gradio UI for MRI classification

### Load the Vision Transformer model trained for MRI classification and its image processor

In [None]:
import torch
from transformers import ViTImageProcessor, ViTForImageClassification

# load the fine-tuned ViT classifier and its matching image processor from the Hugging Face Hub
model = ViTForImageClassification.from_pretrained("elliemci/vit_tumor_classification_model")
image_processor = ViTImageProcessor.from_pretrained("elliemci/vit_tumor_classification_model")

# run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
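The classifier returns one raw, unnormalized score (logit) per class in `outputs.logits`, and Hugging Face model configs carry an `id2label` dict that maps class indices to names. A minimal pure-Python sketch of turning one image's logits into a `{label: probability}` dict with a softmax, a format Gradio's `gr.Label` component can display. The mapping below is hypothetical: it assumes, as the prediction code in this notebook does, that class 1 means "Tumor"; we have not checked what this model's `model.config.id2label` actually contains.

```python
import math

def logits_to_confidences(logits, id2label):
    """Convert one image's raw logits into a {label: probability} dict via softmax."""
    # subtract the largest logit before exponentiating for numerical stability
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    return {id2label[i]: e / total for i, e in enumerate(exps)}

# hypothetical index-to-name mapping (assumption: class 1 is "Tumor")
id2label = {0: "No Tumor", 1: "Tumor"}
# hypothetical logits, shaped like outputs.logits[0].tolist() for a two-class model
confidences = logits_to_confidences([0.5, 2.0], id2label)
print(max(confidences, key=confidences.get))  # Tumor
```

In the app, the real logits would come from `outputs.logits[0].tolist()` after a forward pass.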

### Image class prediction function

In [11]:
import torch
from torch.utils.data import DataLoader

def predict(images):
    """Predict class indices for a sequence of examples that each hold a PIL image under 'image'."""
    if len(images) == 0:
        return []

    # a batch holds at most 8 images, and never more images than were provided
    batch_size = min(len(images), 8)

    # create the data loader for all input images; the collate function keeps the
    # PIL images in a plain list instead of trying to stack them into a tensor
    inputs_loader = DataLoader(images, batch_size=batch_size, shuffle=False,
                               collate_fn=lambda examples: {'image': [ex['image'] for ex in examples]})

    predictions = []
    # set the model to evaluation mode
    model.eval()

    # disable gradient calculation during inference
    with torch.no_grad():
        for batch in inputs_loader:
            # preprocess the batch of images; the result holds the pixel_values tensor
            inputs = image_processor(images=batch['image'], return_tensors="pt").to(device)
            # unpack the preprocessed data (pixel_values) as keyword arguments
            outputs = model(**inputs)
            # take the highest-scoring class for each image
            preds = outputs.logits.argmax(dim=-1)

            # add the individual predictions to the list
            predictions.extend(preds.cpu().tolist())

    return predictions
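The prediction loop relies on `DataLoader` with `shuffle=False` (and the default `drop_last=False`) to walk the images in their original order, in batches of at most 8. A minimal pure-Python sketch of the resulting split, assuming only those two settings; `batch_indices` is a hypothetical helper for illustration, not part of PyTorch.

```python
def batch_indices(num_items, max_batch_size=8):
    """Return the index groups an in-order loader with drop_last=False would yield."""
    if num_items == 0:
        return []
    # a batch never exceeds max_batch_size or the number of available items
    batch_size = min(num_items, max_batch_size)
    # consecutive slices; the final one keeps whatever is left over
    return [list(range(start, min(start + batch_size, num_items)))
            for start in range(0, num_items, batch_size)]

print([len(b) for b in batch_indices(10)])  # [8, 2]
print([len(b) for b in batch_indices(3)])   # [3]
```

Because the order is preserved, the i-th entry of the returned predictions corresponds to the i-th input image.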

### Gradio UI functions for uploading images and displaying their predictions

In [None]:
import torch

from PIL import Image

def predict_image(image_path):
  """ Take the path to an image and return the model's prediction """

  try:
    # open the image and convert it to RGB
    image = Image.open(image_path).convert('RGB')
    # preprocess the image with the pretrained image processor
    inputs = image_processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
      outputs = model(**inputs)
      # extract the raw (unnormalized) outputs
      logits = outputs.logits
      predicted_class = logits.argmax(-1).item()
      prediction = "Tumor" if predicted_class == 1 else "No Tumor"

    return prediction

  except Exception as e:
    return f"Error processing image: {e}"


def update_gallery(images):
  """ Process the uploaded images and return a list of (image_path, prediction)
      tuples for the gallery. """

  predictions = []
  image_paths = []

  if images:
    for image in images:
      # with type="filepath" each entry is a path string; file objects expose it as .name
      image_path = image if isinstance(image, str) else image.name
      image_paths.append(image_path)
      # predict and store
      prediction = predict_image(image_path)
      predictions.append(prediction)

  # the gallery expects one (image_path, caption) tuple per image,
  # so combine image_paths and predictions into a list of tuples
  image_paths_with_predictions = list(zip(image_paths, predictions))

  return image_paths_with_predictions
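Depending on the Gradio version and the `type` argument, the uploaded entries handed to the callback may be plain path strings (what `type="filepath"` yields in recent Gradio releases, as far as we know) or file objects exposing the path as `.name`. A minimal sketch that accepts either form and builds the `(path, caption)` pairs the gallery displays; `to_path` and `gallery_items` are hypothetical helpers, and the stand-in predictor replaces the model so the example runs on its own.

```python
import os
from types import SimpleNamespace

def to_path(uploaded):
    """Return the filesystem path for one uploaded entry (path string or object with .name)."""
    return uploaded if isinstance(uploaded, (str, os.PathLike)) else uploaded.name

def gallery_items(uploads, predict_fn):
    """Pair every uploaded image path with its predicted caption as (path, caption)."""
    # treat a cleared upload (None) as an empty list
    paths = [to_path(u) for u in uploads or []]
    return [(path, predict_fn(path)) for path in paths]

# hypothetical uploads: one plain path and one object with a .name attribute
uploads = ["/tmp/scan_1.png", SimpleNamespace(name="/tmp/scan_2.png")]
# stand-in predictor so the example runs without the model
items = gallery_items(uploads, lambda path: "Tumor" if path.endswith("1.png") else "No Tumor")
print(items)  # [('/tmp/scan_1.png', 'Tumor'), ('/tmp/scan_2.png', 'No Tumor')]
```

In the app, `predict_image` would take the place of the stand-in predictor.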

### Gradio Interface

In [17]:
# create and launch the Gradio web interface
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Classification App</h1>")

    with gr.Column():
        # create the upload and gallery components
        image_input = gr.Files(label="Upload MRI Images",
                               file_count="multiple",
                               type="filepath")
        gallery = gr.Gallery(label="MRI Brain Images with Predictions")

    # run the predictions and refresh the gallery whenever the uploads change
    image_input.change(
        fn=update_gallery,
        inputs=[image_input],
        outputs=[gallery]
    )

# launch the app in debug mode, after the Blocks context is closed
demo.launch(debug=True)

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://dc20b644505a65352f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://dc20b644505a65352f.gradio.live


## 3. Gradio UI for MRI segmentation

### Image segmentation prediction function

### Gradio UI for uploading image and making a prediction

### Gradio components to display segmentation result

## 4. Gradio Interface

### App contents

#### Text section

#### Gallery section

#### Image analysis section

### App