In [None]:
# #!/usr/bin/env python

# """Notebook for rapid prompt-based annotation: converts hand-drawn bounding boxes into segmentation masks with the Segment Anything Model (SAM)."""

# __author__      = "Sahib Julka <sahib.julka@uni-passau.de>"
# __copyright__   = "GPL"


In [None]:
import torch
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from utils import encode_image
from jupyter_bbox_widget import BBoxWidget


In [None]:
# Registry key and checkpoint file used to build the SAM model in the next cell.
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = os.path.join("sam", "sam_vit_h_4b8939.pth")
# Use the first GPU when torch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Look up the builder for MODEL_TYPE, load the checkpoint, and place the model on DEVICE.
build_sam = sam_model_registry[MODEL_TYPE]
sam = build_sam(checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
# Predictor wrapper used later for set_image / predict_torch.
mask_predictor = SamPredictor(sam)

In [None]:
DATA_PATH = 'processed/chorus/'
# File extensions accepted as input images. Other directory entries (hidden files,
# sub-folders such as .ipynb_checkpoints) would otherwise be handed to cv2.imread later.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

IMAGES = []
filenames = []
# os.listdir returns entries in arbitrary order; sort so that an index into
# `filenames` / `IMAGES` refers to the same file on every run.
for fn in sorted(os.listdir(DATA_PATH)):
    path = os.path.join(DATA_PATH, fn)
    if not (os.path.isfile(path) and fn.lower().endswith(IMAGE_EXTENSIONS)):
        continue
    filenames.append(fn)
    IMAGES.append(path)

assert IMAGES, f"No images found in {DATA_PATH}"

In [None]:
# Pick a random image and show it in the box-drawing widget.
# randint's upper bound is exclusive, so randint(len(filenames)) can return any
# index 0..len-1. A lower bound of 1 would never select the first image and
# would raise ValueError when there is only one image.
assert filenames, "No images to annotate"
i = np.random.randint(len(filenames))
widget = BBoxWidget()
widget.image = encode_image(IMAGES[i])
widget

In [None]:
# Inspect the boxes drawn on the widget above. The next cell reads the
# 'x', 'y', 'width' and 'height' keys of each entry.
widget.bboxes

In [None]:
import numpy as np

# default_box is used when no box has been drawn on the image above.
default_box = {'x': 68, 'y': 247, 'width': 555, 'height': 678, 'label': ''}

# Fall back to default_box so later cells always receive at least one prompt box.
drawn_boxes = widget.bboxes if widget.bboxes else [default_box]

# Convert each widget box from (x, y, width, height) to corner form
# (x_min, y_min, x_max, y_max).
boxes = [
    np.array([
        b['x'],
        b['y'],
        b['x'] + b['width'],
        b['y'] + b['height'],
    ])
    for b in drawn_boxes
]

In [None]:
# Combine the per-box corner arrays into a single array; with at least one box
# this is shaped (n_boxes, 4). Display it as the cell output.
boxes = np.asarray(boxes)
boxes

In [None]:
import cv2
import numpy as np
import supervision as sv

assert len(boxes) > 0, "Draw at least one box on the widget first"

# cv2.imread returns None instead of raising when a file cannot be read.
image_bgr = cv2.imread(IMAGES[i])
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image: {IMAGES[i]}")
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# Build a separate float tensor on DEVICE instead of overwriting `boxes`,
# so this cell can be re-run without re-running the previous cells.
boxes_tensor = torch.as_tensor(boxes, dtype=torch.float32, device=DEVICE)

mask_predictor.set_image(image_rgb)
transformed_boxes = mask_predictor.transform.apply_boxes_torch(boxes_tensor, image_rgb.shape[:2])

masks, scores, logits = mask_predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
# Union of the per-box masks (leading axis) as a boolean array.
# NOTE: supervision's MaskAnnotator uses each detection mask as a boolean index.
# An integer sum would be interpreted as fancy (row) indices and would also count
# overlapping pixels more than once.
mask = masks.any(dim=0).cpu().numpy()

In [None]:
# Wrap the combined mask as detections and keep only the one(s) with the largest area.
detections = sv.Detections(mask=mask, xyxy=sv.mask_to_xyxy(masks=mask))
largest_area = detections.area.max()
detections = detections[detections.area == largest_area]

# Draw the bounding box and the filled mask on separate copies of the image.
box_annotator = sv.BoxAnnotator(color=sv.Color.red())
mask_annotator = sv.MaskAnnotator(color=sv.Color.red())
source_image = box_annotator.annotate(scene=image_bgr.copy(), detections=detections, skip_label=True)
segmented_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

# Show both overlays side by side.
sv.plot_images_grid(
    images=[source_image, segmented_image],
    titles=['source image', 'segmented image'],
    grid_size=(1, 2),
)

In [None]:
# The alias must be `sv`, the name used below. The previous `as v` only worked
# because `sv` was still defined by an earlier cell.
import supervision as sv

# Plot the raw mask array(s) along the leading axis of `mask` on a 1x4 grid.
sv.plot_images_grid(
    images=mask,
    grid_size=(1, 4),
    size=(16, 4)
)

In [None]:
MASK_DIR = os.path.join("processed", "masks")
os.makedirs(MASK_DIR, exist_ok=True)
# splitext removes only the final extension, whatever it is. split('.png')[0]
# cut at the first '.png' anywhere in the name and kept non-PNG extensions.
mask_stem = os.path.splitext(filenames[i])[0]
np.save(os.path.join(MASK_DIR, mask_stem + '.npy'), mask)

In [24]:
# NOTE(review): this cell repeats the save in the cell above and writes the same
# file again; it can be deleted. Kept consistent with the cell above in the meantime.
MASK_DIR = os.path.join("processed", "masks")
os.makedirs(MASK_DIR, exist_ok=True)
# splitext removes only the final extension, whatever it is.
mask_stem = os.path.splitext(filenames[i])[0]
np.save(os.path.join(MASK_DIR, mask_stem + '.npy'), mask)