# Introduction to image segmentation

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from PIL import Image
import matplotlib.pyplot as plt

from torchvision.models.detection import maskrcnn_resnet50_fpn

## Creating binary masks

In [None]:
# Load mask image
mask = Image.open("annotations/Egyptian_Mau_123.png")

# Transform mask to tensor. For 8-bit images ToTensor divides pixel values
# by 255, so a stored label L ends up as the float L/255.
transform = transforms.Compose([transforms.ToTensor()])
mask_tensor = transform(mask)

# Create binary mask: 1.0 where the pixel carries OBJECT_LABEL, 0.0 elsewhere.
# Undo the /255 scaling and round back to integer labels rather than testing
# exact float equality against 1/255, which silently depends on how the
# division happened to round.
OBJECT_LABEL = 1  # label value treated as "the object" in this annotation
label_tensor = torch.round(mask_tensor * 255)
binary_mask = (label_tensor == OBJECT_LABEL).float()

# Print unique mask values
print(binary_mask.unique())

## Segmenting an image with a mask

In [None]:
# Read the photo that the annotation above belongs to and turn it into a
# float tensor with ToTensor.
image = Image.open("images/Egyptian_Mau_123.jpg")
transform = transforms.Compose([transforms.ToTensor()])
image_tensor = transform(image)

# Keep only the object: multiplying by the 0/1 mask leaves object pixels
# unchanged and sets every background pixel to 0 (black).
object_tensor = torch.mul(image_tensor, binary_mask)

# Go back to a PIL image so matplotlib can show the cut-out object.
to_pil_image = transforms.ToPILImage()
object_image = to_pil_image(object_tensor)

plt.imshow(object_image)
plt.show()

# Instance segmentation with Mask R-CNN

## Segmenting with a pre-trained Mask R-CNN

In [None]:
# Load a pre-trained Mask R-CNN model and put it in evaluation mode so it
# runs inference rather than training-time behavior.
model = maskrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Load an image and force 3-channel RGB: a JPEG may decode as grayscale or
# CMYK, while the pretrained weights expect RGB input.
image = Image.open("images/two_cats.jpg").convert("RGB")
transform = transforms.Compose([transforms.ToTensor()])
# unsqueeze(0) adds a leading batch dimension of size 1 (a batch of one image).
image_tensor = transform(image).unsqueeze(0)

# Perform inference without tracking gradients.
with torch.no_grad():
    prediction = model(image_tensor)
    print(prediction)

## Displaying soft masks

In [None]:
# Extract masks and labels from prediction; prediction[0] holds the results
# for the single image passed to the model.
masks = prediction[0]["masks"]
labels = prediction[0]["labels"]

# Plot the image with up to two overlaid masks. The model can detect fewer
# than two objects, so never index past the last detection.
num_to_show = min(2, len(masks))
for i in range(num_to_show):
    plt.imshow(image)
    # Overlay the i-th soft mask on top of the image. masks[i, 0] selects
    # channel 0 of the i-th mask (assumed shape (N, 1, H, W) — per torchvision
    # docs; TODO confirm for the installed version).
    plt.imshow(masks[i, 0], cmap="jet", alpha=0.5)
    # Show the numeric label id; index a list of class names with it
    # (e.g. class_names[labels[i]]) to display a readable name instead.
    plt.title(f"Object label: {labels[i].item()}")
    plt.show()