<a href="https://colab.research.google.com/github/eWizardII/dropletcode/blob/dev/Droplet_Volumetric_Calculation_Shareable_Version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mask R-CNN Image Segmentation

## Overview

This Jupyter notebook showcases how to use the Mask R-CNN model for image segmentation. Mask R-CNN is a state-of-the-art model that segments objects in images. This demonstration takes an image as input, applies the Mask R-CNN model to detect objects, and then segments the object with the highest confidence score from the background.

## Table of Contents

1. **Git Setup and Directory Initialization**
2. **Importing Necessary Libraries**
3. **Configuring Mask R-CNN for Inference**
4. **Loading Pre-trained Weights**
5. **Image Visualization Helpers**
6. **Image Pre-processing**
7. **Model Inference and Results Visualization**
8. **Image Segmentation Helper**
9. **Displaying Segmented Results**

In [None]:
import logging
from IPython.display import Image

# Initialize logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


def clone_repo(repo_url: str) -> None:
    """
    Clone a given GitHub repository into the current working directory.

    ``git clone`` is run as a subprocess with an argument list (no shell), so
    the URL is never interpreted by a shell, and a non-zero exit status from
    git is reported as an error instead of being logged as a success.

    Failures are logged at ERROR level rather than raised (best-effort).

    Args:
        repo_url (str): URL of the GitHub repository to clone.
    """
    # Local import keeps this function self-contained within its notebook cell.
    import subprocess

    try:
        # "--" prevents a URL beginning with "-" from being parsed as a git option.
        subprocess.run(
            ["git", "clone", "--", repo_url],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        # git writes its diagnostics to stderr; fall back to the exit status.
        reason = (e.stderr or "").strip() or f"git exited with status {e.returncode}"
        logging.error(f"Error cloning the repo {repo_url}. Reason: {reason}")
    except Exception as e:
        # e.g. the git executable is not installed.
        logging.error(f"Error cloning the repo {repo_url}. Reason: {str(e)}")
    else:
        logging.info(f"Successfully cloned {repo_url}")


def display_image_from_file(filename: str) -> None:
    """
    Render an image file in the notebook output.

    Any exception raised while building or displaying the image is logged
    at ERROR level and not propagated.

    Args:
        filename (str): Path to the image file.
    """
    try:
        display(Image(filename))
    except Exception as err:
        reason = str(err)
        logging.error(f"Error displaying the image from {filename}. Reason: {reason}")


def capture_and_display_photo() -> None:
    """
    Take a photo, log where it was saved, and show it.

    Any exception raised along the way is logged at ERROR level and not
    propagated.
    """
    try:
        # `take_photo` is not defined in this cell; it has to exist in the
        # notebook namespace already, otherwise the resulting error is logged.
        photo_path = take_photo()
        logging.info(f'Saved to {photo_path}')
        display_image_from_file(photo_path)
    except Exception as err:
        reason = str(err)
        logging.error(f"Error capturing the photo. Reason: {reason}")


# Execution
# Clone the Mask_RCNN repository, then capture a photo and display it.
# Both helpers log their own failures instead of raising, so this cell keeps
# running even if a step fails.
repo_url = "https://github.com/matterport/Mask_RCNN"
clone_repo(repo_url)
capture_and_display_photo()

In [None]:
%%shell
# Enter the Mask_RCNN directory and install the package with its setup.py.
# This cell does not clone anything itself; the Mask_RCNN directory must
# already exist in the current working directory.
cd Mask_RCNN
python setup.py install

In [None]:
# Basic imports
import os
import sys
import random
import numpy as np
import colorsys
import argparse
import imutils
import cv2
from matplotlib import pyplot
from matplotlib.patches import Rectangle

# For in-notebook visualization
%matplotlib inline

# Import Mask RCNN related packages
!pip install ./Mask_RCNN
from mrcnn.config import Config
from mrcnn import model as modellib
from mrcnn import visualize
import mrcnn

# ROOT_DIR is the absolute path of the parent of the current working
# directory ("../"); it is appended to sys.path so that modules located
# there can be imported.
ROOT_DIR = os.path.abspath("../")
sys.path.append(ROOT_DIR)

In [None]:
import os
import logging
from mrcnn.config import Config
from mrcnn import model as modellib, utils

# Initialize logging
# logging.basicConfig leaves the root logger untouched when it already has
# handlers, so this only takes effect if logging was not configured earlier.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
# ROOT_DIR: absolute path of the parent of the current working directory.
# COCO_MODEL_PATH: location of the weights file "mask_rcnn_coco.h5" inside ROOT_DIR.
ROOT_DIR = os.path.abspath("../")
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

class MyMaskRCNNConfig(Config):
    """Settings used to run the Mask R-CNN model for inference in this notebook."""
    # Background class plus the 80 object classes
    NUM_CLASSES = 80 + 1

    # Number of images handled per GPU and number of GPUs to use
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

    # Recognizable name for this configuration
    NAME = "MaskRCNN_inference"

def setup_mask_rcnn_model(config: Config) -> "MaskRCNN":
    """Create a Mask R-CNN model in inference mode from ``config`` and return it."""
    logging.info("Setting up Mask R-CNN model...")
    # The current directory is passed as the model directory.
    return modellib.MaskRCNN(config=config, model_dir='./', mode="inference")

def load_model_weights(model: "MaskRCNN", model_path: str) -> None:
    """
    Load weights from ``model_path`` into ``model``, downloading them first if the file is missing.

    A failure while loading is logged at ERROR level and not propagated.
    """
    weights_on_disk = os.path.exists(model_path)
    if not weights_on_disk:
        # Fetch the weights file to the requested location before loading.
        logging.info(f"{model_path} not found. Downloading...")
        utils.download_trained_weights(model_path)

    try:
        model.load_weights(model_path, by_name=True)
    except Exception as err:
        reason = str(err)
        logging.error(f"Error loading weights: {reason}")
    else:
        logging.info(f"Loaded weights from {model_path}")

# Driver: build the inference configuration, create the model, and load the
# weights from COCO_MODEL_PATH (downloaded first if the file does not exist).
if __name__ == "__main__":
    config = MyMaskRCNNConfig()
    model = setup_mask_rcnn_model(config)
    load_model_weights(model, COCO_MODEL_PATH)

    # Human-readable labels used when visualizing predictions; index 0 is the
    # background entry 'BG'.
    # NOTE(review): presumably ordered to match the class ids produced by the
    # loaded weights — verify against the weights' class list.
    class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
 'bus', 'train', 'truck', 'boat', 'traffic light',
 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
 'kite', 'baseball bat', 'baseball glove', 'skateboard',
 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
 'teddy bear', 'hair drier', 'toothbrush']

In [None]:
from matplotlib import pyplot
from matplotlib.patches import Rectangle
from typing import List, Tuple

def draw_image_with_boxes(filename: str, boxes_list: List[Tuple[int, int, int, int]]) -> None:
    """
    Show an image file with a rectangle drawn around each given box.

    Any exception raised while reading or plotting is logged at ERROR level
    and not propagated.

    Args:
        filename (str): Path to the image file.
        boxes_list (List[Tuple[int, int, int, int]]): Bounding boxes, each given as (y1, x1, y2, x2).
    """
    try:
        pixels = pyplot.imread(filename)
        pyplot.imshow(pixels)

        # Overlay the boxes on the current axes, i.e. the one holding the image.
        _draw_boxes_on_ax(pyplot.gca(), boxes_list)

        pyplot.show()
    except Exception as err:
        reason = str(err)
        logging.error(f"Error drawing image with boxes. Reason: {reason}")

def _draw_boxes_on_ax(ax: "pyplot.Axes", boxes_list: List[Tuple[int, int, int, int]]) -> None:
    """Helper function to draw bounding boxes on a given Axes instance."""
    for box in boxes_list:
        # Get coordinates
        y1, x1, y2, x2 = box
        # Calculate width and height of the box
        width, height = x2 - x1, y2 - y1
        # Create the shape
        rect = Rectangle((x1, y1), width, height, fill=False, color='red', lw=5)
        # Draw the box
        ax.add_patch(rect)

In [None]:
from typing import Union
from keras.preprocessing.image import load_img, img_to_array
from matplotlib import pyplot as plt

# Repeated logging setup: logging.basicConfig leaves the root logger untouched
# when it already has handlers, so this only takes effect if no earlier cell
# configured logging.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


def load_and_display_image(image_path: str) -> Union[None, np.ndarray]:
    """
    Read an image from disk, show it with matplotlib, and return it as an array.

    Args:
        image_path (str): Path to the image file.

    Returns:
        np.ndarray: The image converted with ``img_to_array``, or None when
        loading, displaying, or converting fails (the failure is logged).
    """
    try:
        picture = load_img(image_path)

        # Show the picture before converting it for further processing.
        plt.imshow(picture)
        plt.show()

        pixels = img_to_array(picture)
        return pixels
    except Exception as err:
        reason = str(err)
        logging.error(f"Error loading or displaying image from {image_path}. Reason: {reason}")
        return None


# Driver: load the hard-coded input image, display it, and log its shape.
if __name__ == "__main__":
    image_path = '/content/images/image001.png'
    # The loaded array is stored in `img_array`; it is None when loading failed.
    # NOTE(review): later cells reference a variable named `img`, which is not
    # assigned here — confirm where `img` is expected to come from.
    img_array = load_and_display_image(image_path)
    if img_array is not None:
        logging.info(f"Loaded image of shape {img_array.shape}")

In [None]:
from typing import List, Tuple
from mrcnn import model as modellib

# Repeated logging setup: logging.basicConfig leaves the root logger untouched
# when it already has handlers, so this only takes effect if no earlier cell
# configured logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def predict_and_visualize(model: modellib.MaskRCNN, img: np.ndarray, image_path: str) -> None:
    """
    Run detection on one image and draw the resulting boxes over the image file.

    Any exception raised during detection or drawing is logged at ERROR level
    and not propagated.

    Args:
        model (modellib.MaskRCNN): Pre-trained Mask R-CNN model.
        img (np.ndarray): Input image in numpy array format.
        image_path (str): Path to the original image file for visualization.
    """
    try:
        detections = model.detect([img], verbose=0)

        # Nothing to draw when detection returned nothing or the first result has no 'rois'.
        if not detections or 'rois' not in detections[0]:
            logging.warning(f"No results obtained for the image at {image_path} or 'rois' not in the results.")
            return

        draw_image_with_boxes(image_path, detections[0]['rois'])
    except Exception as err:
        reason = str(err)
        logging.error(f"Error predicting or visualizing results for image at {image_path}. Reason: {reason}")


# Driver: run detection on the hard-coded image and draw the predicted boxes.
if __name__ == "__main__":
    image_path = '/content/images/image001.png'
    # Ensure that `img` and `model` are defined and loaded before calling the function
    # NOTE(review): no cell visible here assigns `img` (the loader cell stores
    # its result in `img_array`); this call raises NameError unless `img` is
    # defined elsewhere — confirm.
    predict_and_visualize(model, img, image_path)

In [None]:
from mrcnn.visualize import display_instances
from typing import Dict, List

# Repeated logging setup: logging.basicConfig leaves the root logger untouched
# when it already has handlers, so this only takes effect if no earlier cell
# configured logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def display_image_predictions(image: np.ndarray, results: Dict[str, np.ndarray], class_names: List[str]) -> None:
    """
    Show an image annotated with the boxes, masks, class labels, and scores from a prediction.

    Nothing is drawn (a warning is logged) unless ``results`` contains all of
    'rois', 'masks', 'class_ids', and 'scores'. Any exception is logged at
    ERROR level and not propagated.

    Args:
        image (np.ndarray): Input image in numpy array format.
        results (Dict[str, np.ndarray]): Predicted results dictionary from Mask R-CNN.
        class_names (List[str]): List of class names for labeling.
    """
    required_keys = ('rois', 'masks', 'class_ids', 'scores')
    try:
        # Bail out early when any required entry is missing.
        if any(key not in results for key in required_keys):
            logging.warning("The provided results dictionary doesn't have all the necessary keys for visualization.")
            return

        display_instances(
            image,
            results['rois'],
            results['masks'],
            results['class_ids'],
            class_names,
            results['scores'],
        )
    except Exception as err:
        reason = str(err)
        logging.error(f"Error displaying image predictions. Reason: {reason}")

# Driver: visualize the first prediction result on the image.
if __name__ == "__main__":
    # Ensure that `img`, `results`, and `class_names` are defined and available before calling the function
    # NOTE(review): no cell visible here assigns `img` or `results` at top
    # level (`results` only exists as a local inside predict_and_visualize);
    # this call raises NameError unless they are defined elsewhere — confirm.
    display_image_predictions(img, results[0], class_names)

In [None]:
import numpy as np
import logging
import matplotlib.pyplot as plt
from typing import Dict, Union

# Set up logging
# logging.basicConfig leaves the root logger untouched when it already has
# handlers, so this only takes effect if no earlier cell configured logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def calculate_mask_sum(mask_array: np.ndarray) -> float:
    """
    Calculate the sum of all values in a mask array.

    The total is accumulated in float64 and returned as a built-in float.
    float32 cannot represent every integer above 2**24, so accumulating in
    float32 could round the total for large masks. Summing directly, without
    reshaping to (-1, last_dim), also makes an array whose last dimension is
    0 (no masks) yield 0.0 instead of failing in the reshape.

    Args:
        mask_array (np.ndarray): The mask array whose values are summed.

    Returns:
        float: The sum of the mask values.
    """
    return float(np.sum(mask_array, dtype=np.float64))

def segment_image(image: np.ndarray, results: Dict[str, Union[np.ndarray, float]]) -> np.ndarray:
    """
    Keep only the highest-scoring detected object and paint everything else white.

    Args:
        image (np.ndarray): Image to be segmented.
        results (Dict[str, Union[np.ndarray, float]]): Prediction results from the model;
            the 'scores' and 'masks' entries are used.

    Returns:
        np.ndarray: The segmented image, or the unmodified input image when
        segmentation fails (the failure is logged).
    """
    try:
        best = results['scores'].argmax()
        best_mask = results['masks'][:, :, best]

        # Repeat the single-channel mask across three channels as 0/1 uint8 values.
        rgb_mask = np.stack([best_mask] * 3, axis=-1).astype('uint8')

        # 255 wherever the mask is 0, and 0 wherever the mask is 1.
        white_background = (1 - rgb_mask) * 255

        return image * rgb_mask + white_background
    except Exception as err:
        reason = str(err)
        logging.error(f"Error segmenting the image. Reason: {reason}")
        return image  # Fall back to the untouched input.

def display_image(image: np.ndarray, segmentation: np.ndarray) -> None:
    """
    Display the original and segmented images side by side.

    The two arrays are joined along axis 1 (left: original, right:
    segmentation). matplotlib's imshow treats floating-point RGB data as
    being in the range 0..1 and clips anything above, so floating-point data
    that exceeds 1 is assumed to be on a 0..255 scale and is converted to
    uint8 before plotting.

    Args:
        image (np.ndarray): Original image.
        segmentation (np.ndarray): Segmented version of the original image.
    """
    plt.subplots(1, figsize=(16, 16))
    plt.axis('off')

    side_by_side = np.concatenate([image, segmentation], axis=1)

    # Convert 0..255 float data to uint8 so imshow does not clip it to 0..1.
    is_float = np.issubdtype(side_by_side.dtype, np.floating)
    if is_float and side_by_side.size and side_by_side.max() > 1.0:
        side_by_side = np.clip(side_by_side, 0, 255).astype(np.uint8)

    plt.imshow(side_by_side)
    plt.show()

# Driver: log mask totals, then segment the image and show it next to the original.
if __name__ == "__main__":
    # Assuming r is defined and img is loaded
    # NOTE(review): no cell visible here assigns `r` or `img`; presumably `r`
    # is the first element of a model.detect(...) result — confirm. Without
    # them this cell raises NameError.
    mask_values = calculate_mask_sum(r['masks'])
    logging.info(f"Mask Values Sum: {mask_values}")

    # Log the sum of each mask separately: the masks are flattened to
    # (pixels, number_of_masks) and column `idx` is summed.
    for idx in range(r['masks'].shape[-1]):
        mask_sum = np.reshape(r['masks'], (-1, r['masks'].shape[-1]))[:, idx].astype(np.float32).sum()
        logging.info(f"Mask Sum for Index {idx}: {mask_sum}")

    # Keep the highest-scoring object and display it beside the original image.
    segmented = segment_image(img, r)
    display_image(img, segmented)