In [None]:
# Instance segmentation with Mask R-CNN and YOLACT: label loading, model loading,
# inference and postprocessing helpers.
# Import the necessary libraries
import torch
import torchvision
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import os
import urllib.request

# Every label file lives at $DATASET_PREFIX + <file name>. The prefix is joined by plain
# string concatenation, so a directory prefix must carry its own trailing "/".
DATASET_PREFIX = os.environ.get('DATASET_PREFIX', '')
IMAGENET_LABELS_FILE = f"{DATASET_PREFIX}imagenet_classes.txt"
CIFAR100_LABELS_FILE = f"{DATASET_PREFIX}cifar100_labels.txt"
CIFAR10_LABELS_FILE = f"{DATASET_PREFIX}cifar10_labels.meta"
PASCAL_VOC_LABELS_FILE = f"{DATASET_PREFIX}pascal_voc_labels.txt"
PLACES365_LABELS_FILE = f"{DATASET_PREFIX}categories_places365.txt"
COCO_LABELS_FILE = f"{DATASET_PREFIX}coco_labels.txt"

def get_coco_labels(labels_file=None):
    """Return the 80 COCO class names, downloading and caching them on first use.

    The names come from darknet's ``coco.names`` and are cached on disk. Later calls
    read the cached copy without touching the network.

    Args:
        labels_file: path of the cache file. Defaults to ``COCO_LABELS_FILE``.

    Returns:
        list[str]: one class name per non-blank line, in file order. This is the
        80-class ordering, which is not the same as the COCO category-id numbering
        used by torchvision's COCO detectors.
    """
    if labels_file is None:
        labels_file = COCO_LABELS_FILE

    if not os.path.exists(labels_file):
        labels_url = "https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names"
        # Create the cache's parent directory if needed (the prefix may point into one
        # that does not exist yet); "." covers a bare file name.
        os.makedirs(os.path.dirname(labels_file) or ".", exist_ok=True)
        # Download to a side file and rename only on success. Otherwise an interrupted
        # download would leave a truncated cache that every later run silently trusts.
        partial_file = labels_file + ".part"
        try:
            urllib.request.urlretrieve(labels_url, partial_file)
            os.replace(partial_file, labels_file)
        finally:
            if os.path.exists(partial_file):
                os.remove(partial_file)

    with open(labels_file, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) so they don't become '' labels.
        return [line.strip() for line in f if line.strip()]

# Load the COCO class names (downloaded on the first run, read from the cache after).
# postprocess() looks detection class names up in this module-level list.
coco_labels = get_coco_labels()
print(coco_labels)

def get_model(model_name):
    """Build a pretrained segmentation model, move it to the best device, set eval mode.

    Args:
        model_name: 'maskrcnn' (torchvision Mask R-CNN, ResNet-50 FPN backbone) or
            'yolact' (fetched through torch.hub from the 'dbolya/yolact' repo).
            NOTE(review): unverified that that repo publishes a hub entry point named
            'yolact_resnet50' — TODO confirm.

    Returns:
        The model on cuda:0 when CUDA is available (CPU otherwise), in inference mode.

    Raises:
        ValueError: if model_name is not one of the names above.
    """
    # Builders are lambdas so that only the requested model is actually constructed.
    builders = (
        ('maskrcnn', lambda: torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)),
        ('yolact', lambda: torch.hub.load('dbolya/yolact', 'yolact_resnet50', pretrained=True)),
    )
    for name, build in builders:
        if model_name == name:
            net = build()
            break
    else:
        raise ValueError('Invalid model name')

    target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    net.to(target)
    net.eval()  # inference mode: freezes batch-norm statistics, disables dropout
    return net


# Define the function for postprocessing
def postprocess(output):
    # Perform postprocessing on the model output
    # get the predicted boxes, labels, and masks for the objects in the image
    boxes = output[0]['boxes'].detach().numpy()
    labels = output[0]['labels'].detach().numpy()
    classs = np.array()
    for lable in range(labels):
        cls = coco_labels[lable]
        classs.append(cls)

    masks = output[0]['masks'].detach().numpy()

    print("boxes:",boxes)
    print("labels:",labels)
    print("classs:",classs)
    print("masks:",masks)
    return boxes,classs,masks

# Run instance segmentation on one image with the Mask R-CNN or YOLACT model, then postprocess the output
def instance_segmentation(image_path, model_name):
    """Segment the objects in one image file.

    Args:
        image_path: path to an image PIL can open (any mode; it is converted to RGB).
        model_name: 'maskrcnn' or 'yolact' (see get_model).

    Returns:
        ``(boxes, labels, masks)`` exactly as returned by ``postprocess``.
    """
    # The `with` block closes the file. RGB conversion guarantees 3 channels, so
    # greyscale/RGBA/palette images don't break the model's 3-channel input.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")

    # Only ToTensor (floats in [0, 1]). torchvision detection models normalise their
    # input internally, so an extra ImageNet Normalize here would normalise twice and
    # degrade detections. NOTE(review): preprocessing for the torch.hub YOLACT model is
    # unverified — TODO confirm it accepts the same input.
    image = transforms.ToTensor()(rgb)

    model = get_model(model_name)
    # Put the input wherever get_model placed the weights, so the two devices always match.
    image = image.to(next(model.parameters()).device)
    with torch.no_grad():
        output = model([image])

    return postprocess(output)



In [None]:

# Run Mask R-CNN on the sample image. The visualisation cell below reads boxes/labels/masks.
boxes,labels,masks = instance_segmentation("/workspace/tests/pexels-pixabay-45201.jpg","maskrcnn")

In [None]:

# Same image with YOLACT. NOTE(review): this rebinds boxes/labels/masks, so the
# visualisation cell below shows whichever model ran last. postprocess() indexes the
# output as output[0]['boxes'/'labels'/'masks']; unverified that the torch.hub YOLACT
# model returns that structure — TODO confirm.
boxes,labels,masks = instance_segmentation("/workspace/tests/pexels-pixabay-45201.jpg","yolact")

In [None]:

# Overlay the most recent instance_segmentation() results on the source image.
# NOTE: reads boxes/labels/masks produced by an earlier cell — run one of those first.
MASK_THRESHOLD = 0.5  # soft-mask cut-off: pixels at or above it are drawn as the instance

with Image.open("/workspace/tests/pexels-pixabay-45201.jpg") as img:
    rgb_image = img.convert("RGB")
# No Normalize here: imshow expects floats in [0, 1]. ImageNet-normalised values fall
# outside that range, get clipped, and distort the displayed colours.
transform = transforms.Compose([
    transforms.ToTensor(),
])
image = np.array(transform(rgb_image).permute(1, 2, 0))

fig, ax = plt.subplots(1, figsize=(10, 10))
ax.imshow(image)

for i in range(len(boxes)):
    # torchvision's Mask R-CNN returns masks at full image size (same H x W as `image`).
    # They are therefore overlaid without an `extent`: squeezing them into the box would
    # misplace them and reset the axes limits. Masking out sub-threshold pixels keeps the
    # background untinted.
    mask = masks[i, 0]
    ax.imshow(np.ma.masked_where(mask < MASK_THRESHOLD, mask),
              alpha=0.5, cmap='Reds', vmin=0, vmax=1)
    x1, y1, x2, y2 = boxes[i]
    ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                               fill=False, edgecolor='red', linewidth=2))
    label = labels[i]
    ax.text(x1, y1, f"{label}", fontsize=12, color='white', bbox=dict(facecolor='red', alpha=0.5, pad=0), verticalalignment='top')

ax.set_title("Instance segmentation results")
ax.axis("off")
plt.show()