# Instance segmentation

### Libraries and Variables

In [None]:
# general
import os
import random
import mmcv
import numpy as np
import torch
import ssl
# disable HTTPS certificate verification for downloads (insecure, only use as a workaround)
ssl._create_default_https_context = ssl._create_unverified_context

# detection
from mmdet.apis import init_detector, inference_detector
from mmdet.utils import register_all_modules

# segmentation
from segment_anything import sam_model_registry, SamPredictor
import matplotlib.pyplot as plt

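# __file__ is not defined in notebooks; abspath("__file__") resolves relative to the
# current working directory, so exercise_dir is the directory the notebook runs in
exercise_dir = os.path.dirname(os.path.abspath("__file__"))
danuma_dir = os.path.dirname(os.path.dirname(exercise_dir))
raw_data_dir = os.path.join(danuma_dir, 'data/raw_data')
output_data_dir = os.path.join(danuma_dir, 'data/output_data')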

### Overview

Implementing neural networks on your own is an important skill for becoming an applied machine learning researcher. It forces you to think about how neural networks work, which strengthens your understanding of the subject. However, it is **equally important** to be able to quickly modify and combine existing code bases and packages for your own purposes. This is the goal of this notebook. You will (1) use existing functionality for pig detection to obtain bounding boxes on unlabeled images from a pig barn, and (2) use these bounding boxes as input to SAM (short for "Segment Anything Model") to obtain instance segmentations.

You do not need to fully understand the methods you use. In this exercise, the focus is on understanding how to use and adapt the functionality provided by other repositories to solve tasks!

### Helper functions

These functions are used later to plot bounding boxes, points and instance segmentation masks.

In [4]:
def show_box(box, ax, random_color=False):
    # box is in (x1, y1, x2, y2) format; matplotlib expects the top-left corner plus width and height
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]

    if random_color:
        edge_color = (random.random(), random.random(), random.random())
    else:
        edge_color = 'green'
    # transparent face color so that only the box outline is drawn
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edge_color, facecolor=(0,0,0,0), lw=2))

def show_points(coords, labels, ax, marker_size=375):
    # coords: (N, 2) array of (x, y) points; labels: (N,) array with 1 = foreground, 0 = background
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_mask(mask, ax, random_color=False):
    # overlay the mask as a semi-transparent RGBA image (alpha 0.6)
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
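The overlay in `show_mask` relies on NumPy broadcasting: an `(H, W, 1)` mask multiplied by a `(1, 1, 4)` RGBA color yields an `(H, W, 4)` image whose alpha channel is 0 wherever the mask is `False`, so those pixels stay transparent in `imshow`. A minimal, self-contained check of that behavior (the tiny 3x4 mask is made up for illustration):

```python
import numpy as np

# toy boolean mask: 3 rows (height) x 4 columns (width)
mask = np.zeros((3, 4), dtype=bool)
mask[1, 1:3] = True  # two "object" pixels in the middle row

# same default color as show_mask: RGB plus alpha = 0.6
color = np.array([30/255, 144/255, 255/255, 0.6])

h, w = mask.shape[-2:]
# (h, w, 1) * (1, 1, 4) broadcasts to (h, w, 4)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

print(mask_image.shape)    # (3, 4, 4)
print(mask_image[..., 3])  # alpha channel: 0.6 on the mask, 0.0 elsewhere
```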

### 1. Pig detection

For pig detection, we will rely on a pig detection repository that is part of ongoing research in the DaNuMa project: https://github.com/jonaden94/PigDetect. \
When you open the repository in a browser, you will see the README right below the repository's directory structure. The README usually contains all the information on how to use the repository (e.g. setup, download of relevant files, references to demo notebooks). In practice, you would first have to clone the repository yourself, install the required packages (setup) and download the pretrained model weights. For this exercise, the repository has already been cloned (``repos/PigDetect``) and all necessary installations and downloads have been performed, so you don't need to worry about this. The object detection model we will use is called "codino". A link to the paper is in the "Further reads" section in case you are interested (not relevant for the exercise!).

The demo notebook in the PigDetect repository (https://github.com/jonaden94/PigDetect/blob/main/tools/inference/inference_demo.ipynb) provides code to use pretrained models for pig detection. Inspect the notebook and copy/modify the relevant code from it:
1. Initialize the codino model. You will need to adjust ``config_path`` and ``checkpoint_path``. The config is located at ``repos/PigDetect/configs/co-detr/co_dino_swin.py`` and the pretrained model weights at ``repos/DaNuMa2024/data/raw_data/7_instance_segmentation/pretrained/codino_swin.pth``.
2. Run model inference on one of the example images under ``DaNuMa2024/data/raw_data/7_instance_segmentation/images``.
3. Plot the bounding boxes to verify that the model works (code is already provided below). Do you obtain the correct bounding boxes? Hint: You might need to filter the bounding boxes based on their scores!
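The score-filtering hint can be illustrated without running the detector. Assuming mmdet 3.x (as the `register_all_modules` import suggests), `inference_detector` returns a `DetDataSample` whose `pred_instances` field holds `bboxes` (in `(x1, y1, x2, y2)` pixel format) and `scores` as torch tensors, which would first be moved to NumPy with `.cpu().numpy()`. The sketch below uses made-up arrays in place of real detections, and the helper `filter_by_score` and its threshold `0.5` are hypothetical choices to be tuned by looking at the plotted boxes.

```python
import numpy as np

def filter_by_score(bboxes, scores, score_thr=0.5):
    """Hypothetical helper: keep detections whose score is at least score_thr.

    bboxes: (N, 4) array of boxes in (x1, y1, x2, y2) format
    scores: (N,) array of confidence scores
    """
    keep = scores >= score_thr  # boolean mask over the N detections
    return bboxes[keep], scores[keep]

# made-up detections standing in for pred_instances.bboxes / pred_instances.scores
bboxes = np.array([[100, 120, 300, 260],
                   [410,  90, 620, 240],
                   [ 50,  40,  70,  60]], dtype=np.float32)
scores = np.array([0.92, 0.81, 0.07], dtype=np.float32)

kept_boxes, kept_scores = filter_by_score(bboxes, scores, score_thr=0.5)
print(kept_boxes.shape)  # (2, 4): the low-confidence third box is dropped
```

Filtering with a boolean mask keeps boxes and scores aligned, since both arrays are indexed by the same mask.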

You will get some warnings even if the code is correct. You can ignore them.

In [None]:
######### YOUR CODE HERE:
# 1. initialize the model

In [31]:
######### YOUR CODE HERE:
# 2. run model inference on image

In [None]:
# 3. plot bounding boxes
bboxes = [] # DELETE THIS LINE ONCE YOU OBTAINED THE REAL BBOXES
image_path = os.path.join(raw_data_dir, '7_instance_segmentation/images/danuma_1578.jpg') # path to the image you obtained the bboxes for
image = mmcv.imread(image_path, channel_order='rgb')
plt.figure(figsize=(10, 10))
plt.imshow(image)
for box in bboxes:
    show_box(box, plt.gca(), random_color=True)
plt.axis('off')
plt.show()

### 2. Instance mask from single-point prompt

For instance segmentation, we will rely on the repository of the SAM model introduced by Facebook AI Research: https://github.com/facebookresearch/segment-anything. \
The model generates segmentation masks for objects in images without having to be retrained for the task at hand (hence the name "Segment Anything"). If you want to know more about this groundbreaking model, a link to the paper is given in the "Further reads" section (not relevant for this exercise). \
Once again, all necessary packages have already been installed and the pretrained models have already been downloaded. In this part of the notebook, we will use the model with a **point prompt**. That means we provide the model with a point on the image, and if the model lives up to its name, it generates a segmentation mask of the object marked by the point. For this, we do not need the bounding boxes yet.

There is also a demo notebook for SAM (https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb) that provides all the functionality you need. If you want to try out the demo notebook before working on this exercise, you can do so by clicking the "Open in Colab" button. A quick guide on how to use Google Colab is provided in the pad.

For this exercise, you do **NOT** need to install or import any packages. All relevant functions are already imported and can be used directly in this notebook, so you can ignore any installation/import statements in the demo notebook. Inspect the notebook and copy/modify the relevant code from it:
1. Load the SAM model and predictor. The pretrained model checkpoint for SAM is located at ``repos/DaNuMa2024/data/raw_data/7_instance_segmentation/pretrained/sam_vit_h_4b8939.pth``. The model type is ``vit_h`` and the model should be moved to the GPU (``cuda``).
2. Set the image of the predictor to the image from the pig barn.
3. Define the point to provide as input to the model (code is already provided).
4. Pass the point to the predictor to obtain segmentation masks. As recommended in the demo notebook, set ``multimask_output=True`` and then select the mask with the highest score.
5. Plot the mask (code is already provided).
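With `multimask_output=True`, `SamPredictor.predict` returns several candidate masks (three in the demo notebook) together with one predicted quality score per mask, following the demo's `masks, scores, logits = predictor.predict(...)` pattern, with `masks` of shape `(C, H, W)`. Selecting the best candidate is then an `argmax` over the scores. The sketch below fakes that output with random arrays so it runs without the model; the image size and score values are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W = 480, 640  # arbitrary image size for this illustration

# stand-ins for the outputs of predictor.predict(..., multimask_output=True)
masks = rng.random((3, H, W)) > 0.5    # three boolean candidate masks
scores = np.array([0.71, 0.95, 0.88])  # one quality score per mask

best_idx = int(np.argmax(scores))      # index of the highest-scoring mask
best_mask = masks[best_idx]            # (H, W) boolean mask, ready for show_mask

print(best_idx, best_mask.shape)       # 1 (480, 640)
```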


In [None]:
# use same image as for detection
image_path = os.path.join(raw_data_dir, '7_instance_segmentation/images/danuma_1578.jpg') # path to the image you obtained the bboxes for
image = mmcv.imread(image_path, channel_order='rgb')
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('off')
plt.show()
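`SamPredictor.set_image` expects an `H x W x 3` `uint8` array and, by default, RGB channel order (its `image_format` argument defaults to `"RGB"`; passing `image_format="BGR"` is the alternative). `mmcv.imread` returns BGR unless `channel_order='rgb'` is passed, as done above. The helper `to_sam_rgb` below is hypothetical and not part of either library: a small sketch that checks these assumptions and flips BGR to RGB if needed.

```python
import numpy as np

def to_sam_rgb(image, channel_order="rgb"):
    """Hypothetical helper: validate an image before SamPredictor.set_image.

    Returns an (H, W, 3) uint8 array in RGB order.
    """
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got shape {image.shape}")
    if image.dtype != np.uint8:
        raise TypeError(f"expected dtype uint8, got {image.dtype}")
    if channel_order == "bgr":
        # reverse the last axis: B, G, R -> R, G, B
        image = image[..., ::-1]
    return np.ascontiguousarray(image)

# toy 2x2 BGR image with a pure blue pixel at row 0, column 0
bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[0, 0] = [255, 0, 0]  # blue is the first channel in BGR order
rgb = to_sam_rgb(bgr, channel_order="bgr")
print(rgb[0, 0])  # blue has moved to index 2, the B channel in RGB order
```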

In [25]:
######### YOUR CODE HERE:
# 1. load the SAM model and predictor

In [26]:
######### YOUR CODE HERE:
# 2. set image

In [None]:
# 3. select and visualize input point based on which to segment
input_point = np.array([[700, 350]])
input_label = np.array([1])

plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

In [None]:
######### YOUR CODE HERE:
# 4. input point to predictor to obtain segmentation masks

In [None]:
# 5. plot image with mask
mask = np.zeros_like(image[..., 0]).astype('bool') # DELETE THIS LINE ONCE YOU OBTAINED THE REAL MASKS
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()  
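Prompt points are given as `(x, y)` pixel coordinates (column first), while NumPy indexes images and masks as `[row, column]`, i.e. `[y, x]`. Mixing the two up is a common source of bugs. A quick sanity check after step 4 is to look up the selected mask at the prompt location: for a positive (label `1`) point one would expect `True` there, although SAM does not guarantee it. The toy mask below is made up for illustration.

```python
import numpy as np

# toy 5x8 mask (height 5, width 8) with a rectangular "object"
mask = np.zeros((5, 8), dtype=bool)
mask[1:4, 5:8] = True  # rows 1-3, columns 5-7

input_point = np.array([[6, 2]])  # one point in (x, y) format: column 6, row 2
input_label = np.array([1])       # 1 = foreground point, 0 = background point

for (x, y), label in zip(input_point, input_label):
    inside = bool(mask[y, x])     # index as [row, col] = [y, x]
    print(f"point (x={x}, y={y}), label={label}: inside mask = {inside}")
    # swapping the order (mask[x, y]) would read row 6, which is out of bounds here
```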

### 3. Instance mask from multi-box prompt

SAM is a flexible model that works with different kinds of prompts. We will now use the bounding boxes obtained in the first part of the exercise as prompts. The SAM demo notebook provides functionality to input all of these bounding boxes at once (batched prompt inputs) to obtain segmentation masks for all pigs. Inspect the notebook and copy/modify the relevant code to do this:

1. Use the image and bounding boxes that you obtained in the first part of this exercise. You have to convert the bounding boxes to a torch tensor (``torch.tensor``).
2. Perform batched inference to obtain instance segmentation masks. What is the shape of the masks and how is it related to the shape of the image? What values do the masks contain?
3. Plot the image with bounding boxes and masks (code is already provided).
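Why do the boxes need extra handling before batched inference? SAM resizes every image so that its longer side is 1024 pixels, and prompts must be expressed in that resized coordinate frame. The demo notebook handles this with `predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])` before calling `predictor.predict_torch(...)`. The NumPy sketch below is a simplified re-implementation of the idea behind SAM's `ResizeLongestSide` transform (not the library code itself), written only so the scaling can be inspected on an example; `preprocess_shape` and `resize_boxes` are hypothetical names.

```python
import numpy as np

def preprocess_shape(old_h, old_w, long_side=1024):
    """Size (h, w) after scaling so that the longer side equals long_side."""
    scale = long_side / max(old_h, old_w)
    new_h = int(old_h * scale + 0.5)  # round to the nearest integer
    new_w = int(old_w * scale + 0.5)
    return new_h, new_w

def resize_boxes(boxes, original_size, long_side=1024):
    """Map (N, 4) xyxy boxes from original image pixels to the resized frame."""
    old_h, old_w = original_size
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    coords = boxes.reshape(-1, 2, 2).astype(float)  # (N, 2 corners, (x, y))
    coords[..., 0] *= new_w / old_w  # x coordinates scale with the width
    coords[..., 1] *= new_h / old_h  # y coordinates scale with the height
    return coords.reshape(-1, 4)

# example: a 1080 x 1920 (H x W) frame and two boxes in original pixel coordinates
boxes = np.array([[100, 200, 500, 600],
                  [960, 540, 1920, 1080]])
print(preprocess_shape(1080, 1920))  # (576, 1024)
print(resize_boxes(boxes, (1080, 1920)))
```

In the notebook itself, the library's `predictor.transform.apply_boxes_torch` should be used rather than this sketch, since it operates directly on a `torch.Tensor`.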

In [33]:
######### YOUR CODE HERE:
# 1. image and bounding boxes

In [None]:
######### YOUR CODE HERE:
# 2. get segmentation based on bounding boxes

In [None]:
# 3. plot image with bounding boxes and segmentation masks
bboxes = [] # DELETE THIS LINE ONCE YOU OBTAINED THE REAL BBOXES
masks = [] # DELETE THIS LINE ONCE YOU OBTAINED THE REAL MASKS
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in bboxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
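Once the batched masks are available (moved to NumPy with `.cpu().numpy()`), a common next step is to merge them into a single instance label map, where each pixel stores the index of the pig it belongs to and 0 marks background. The sketch below assumes the masks are stacked as an `(N, 1, H, W)` boolean array, following the batched example in the SAM demo notebook, and uses made-up data. The helper `masks_to_label_map` is hypothetical, and assigning overlapping pixels to the later instance is an arbitrary choice.

```python
import numpy as np

def masks_to_label_map(masks):
    """Hypothetical helper: merge stacked boolean instance masks into one label map.

    masks: (N, 1, H, W) or (N, H, W) boolean array, one mask per instance
    returns: (H, W) int array with 0 = background and i + 1 = instance i
    """
    masks = np.asarray(masks, dtype=bool)
    if masks.ndim == 4:
        masks = masks[:, 0]  # drop the singleton channel axis -> (N, H, W)
    label_map = np.zeros(masks.shape[1:], dtype=np.int32)
    for i, m in enumerate(masks):
        label_map[m] = i + 1  # later instances overwrite earlier ones on overlap
    return label_map

# made-up example: two 4x6 instance masks that overlap in one pixel
masks = np.zeros((2, 1, 4, 6), dtype=bool)
masks[0, 0, 0:2, 0:3] = True
masks[1, 0, 1:4, 2:6] = True

label_map = masks_to_label_map(masks)
print(label_map)
print("pixels per instance:", np.bincount(label_map.ravel())[1:])
```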

### Further reads

The object detection model used in this notebook: https://openaccess.thecvf.com/content/ICCV2023/papers/Zong_DETRs_with_Collaborative_Hybrid_Assignments_Training_ICCV_2023_paper.pdf \
Starting with such a complicated object detection model right away might not be the best way to dive into the topic. The following simpler approaches have shaped the field of object detection for years:
* https://arxiv.org/pdf/1504.08083 (Fast R-CNN)
* https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Redmon_You_Only_Look_CVPR_2016_paper.pdf (YOLO)
* https://people.ee.duke.edu/~lcarin/Christy10.9.2020.pdf (DETR)

Segment Anything Model: https://arxiv.org/pdf/2304.02643