# *Phase 0: Sea Ice Image Segmentation with Segment Anything Model (SAM)*
***
Meta AI recently released the Segment Anything Model (SAM), a promptable segmentation system that generalizes zero-shot to unfamiliar objects and images, without the need for additional training.

Here, we generate masks for a random sample of images, previously segmented manually for comparison, using both `SamAutomaticMaskGenerator` with default settings and `SamPredictor` with seeds obtained from the pre-processing step of the GVF application.
***

## Environment setup and SAM initialization

| Package  | Purpose  |
| -------- | -------- |
| `cv2`        | Read and write images |
| `matplotlib` | Plot figures (import currently commented out) |
| `numpy`      | Manage arrays |
| `torch`      | Pass the input seeds to SAM as tensors |
| `scipy` | Save output in MATLAB (`.mat`) format |
| `segment_anything` | SAM |
| `skimage` | Clean and post-process binary masks |

In [1]:
import glob
import os

# import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
from scipy.io import savemat
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from skimage import measure, morphology, segmentation


def circular(area, perimeter):
    """Return the circularity ``4 * pi * area / perimeter**2`` of a region.

    A perfect disk scores 1. A zero perimeter (e.g. a single-pixel region)
    yields ``inf``, or ``nan`` when the area is also zero, instead of raising.
    Callers can therefore drop such regions with ``np.isfinite``.

    Parameters
    ----------
    area : float or array_like
        Region area in pixels.
    perimeter : float or array_like
        Region perimeter in pixels.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        Circularity, computed element-wise for array inputs.
    """
    area = np.asarray(area, dtype=float)
    perimeter = np.asarray(perimeter, dtype=float)
    # Plain Python ints raise ZeroDivisionError on a zero perimeter, whereas
    # NumPy scalars return inf. Always take the NumPy path and silence the
    # divide/invalid warnings so both input kinds behave the same.
    with np.errstate(divide="ignore", invalid="ignore"):
        return (4 * np.pi * area) / perimeter**2


# Function to read 1-D text file saved from MATLAB and convert to tensor
def read_tensor(file_path, device):
    with open(file_path, "r") as file:
        data = file.read().split()
    data = list(map(float, data))
    return torch.tensor(data, device=device)

SAM can be loaded with three different image encoders: ViT-B, ViT-L, and ViT-H. Set the path to the SAM checkpoint below.

To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. It generates masks for the entire image by sampling a large number of point prompts over the image.

To run prediction from prompts that indicate the desired object, provide a SAM model to the `SamPredictor` class. The model first converts the image into an image embedding, from which high-quality masks can be produced efficiently from point and box prompts, as well as from the masks of a previous prediction iteration.

In [2]:
#!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# ViT-H weights; the checkpoint must match `model_type` (download command above).
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Device used for SAM inference.
device = "cpu"

# Build the model from the registry, load the weights and move it to `device`
# (nn.Module.to returns the module itself, so the chain keeps the same object).
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

# Automatic (grid-prompted) and interactive (point-prompted) front-ends,
# both wrapping the same `sam` instance.
mask_generator, predictor = SamAutomaticMaskGenerator(sam), SamPredictor(sam)

### Define image paths

In [3]:
# Cruise day to process; it appears in every folder name below.
dayOfCruise = "2022-07-23"
# Root folder of the GVF floe-size-distribution project.
gvf_root = "/Users/giuliopasserotti/Documents/Pyth.D/seaice_fsd_w_GVF"
run_folder = f"output_left_{dayOfCruise}"

# Manually segmented images; their file names select the images to process.
outputManual_path = f"{gvf_root}/manual_fsd_pixelmator/{run_folder}"
# Cropped/diffused images and seed files from the GVF pre-processing.
seedDiffuse_path = os.path.join(f"{gvf_root}/{run_folder}", "tmp")

outputManualCropped_files = sorted(
    glob.glob(os.path.join(outputManual_path, "*croppedImage.png"))
)

# Output folders for the automatic-generator and point-prompt results.
SAMsave_path = os.path.join(f"{outputManual_path}_SAM", "tmp")
SAMpredsave_path = os.path.join(f"{outputManual_path}_SAMpred", "tmp")

## Automatic mask generation

To generate masks, run `generate` on an image. It returns a list of masks, where each mask is a dictionary containing data about the mask. The keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the boundary box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format

In [4]:
# Structuring element for the erosion; it is loop-invariant, so build it once.
se = morphology.disk(2)
# cv2.imwrite signals a missing folder only through a False return value,
# so create the output folder up front.
os.makedirs(SAMsave_path, exist_ok=True)

for file_path in outputManualCropped_files:
    # The first 18 characters of the manual file name are the prefix of the
    # matching GVF files in seedDiffuse_path.
    base_name = os.path.basename(file_path)[:18]

    croppedImage_path = os.path.join(seedDiffuse_path, base_name + "croppedImage.png")
    croppedImage = cv2.imread(croppedImage_path)
    if croppedImage is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"Cannot read image: {croppedImage_path}")
    croppedImage = cv2.cvtColor(croppedImage, cv2.COLOR_BGR2RGB)

    diffuseImage_path = os.path.join(seedDiffuse_path, base_name + "diffuseImage.png")
    diffuseImage = cv2.imread(diffuseImage_path)
    if diffuseImage is None:
        raise FileNotFoundError(f"Cannot read image: {diffuseImage_path}")
    diffuseImage = cv2.cvtColor(diffuseImage, cv2.COLOR_BGR2RGB)

    # Pure-black pixels of the cropped image (presumably the cropped-away area).
    # They are switched on in every SAM mask so that segments touching them
    # join that area; clear_border then removes both together if the black
    # area reaches the image border.
    blackMask = np.all(croppedImage == [0, 0, 0], axis=2)
    AggrSAM_masks = np.zeros_like(blackMask)  # union of all cleaned masks
    floeCoords = []  # pixel coordinates of every retained floe, in mask order

    masks = mask_generator.generate(diffuseImage)

    for mask in masks:
        SAM_mask = mask["segmentation"]
        SAM_mask[blackMask] = True

        bclearSAM = segmentation.clear_border(SAM_mask, buffer_size=1)
        erodebclearSAM = morphology.binary_erosion(bclearSAM, se)

        labelerodebclearSAM = measure.label(erodebclearSAM, connectivity=2)
        regions = measure.regionprops(labelerodebclearSAM)

        circularity = np.round(
            np.array([circular(prop.area, prop.perimeter) for prop in regions]), 2
        )
        eccentricity = np.round(np.array([prop.eccentricity for prop in regions]), 2)

        # Reject regions whose circularity is above 1 or not finite, and
        # line-like regions (eccentricity >= 0.9). Both criteria are evaluated
        # on the same labelling, so no relabelling between passes is needed.
        dropFloe = (circularity > 1) | ~np.isfinite(circularity) | (eccentricity >= 0.9)
        dropLabels = [prop.label for prop, drop in zip(regions, dropFloe) if drop]
        cleanSAM_mask = erodebclearSAM & ~np.isin(labelerodebclearSAM, dropLabels)

        AggrSAM_masks[cleanSAM_mask] = True
        # skimage coords are 0-based (row, col) pairs, one row per pixel.
        floeCoords.extend(
            prop.coords for prop, drop in zip(regions, dropFloe) if not drop
        )

    # np.array(list_of_arrays, dtype=object) builds an N-D object array when
    # all coordinate arrays share one shape (e.g. a single floe). Fill a 1-D
    # object array element by element so savemat always writes a 1-D cell array.
    AllcoordinatesSAM_masks = np.empty(len(floeCoords), dtype=object)
    for k, coords in enumerate(floeCoords):
        AllcoordinatesSAM_masks[k] = coords
    AllcoordinatesSAMmasks_path = os.path.join(
        SAMsave_path, base_name + "AllcoordinatesSAM_masks.mat"
    )
    savemat(
        AllcoordinatesSAMmasks_path,
        {"AllcoordinatesSAM_masks": AllcoordinatesSAM_masks},
    )

    AggrSAMmasks_path = os.path.join(SAMsave_path, base_name + "AggrSAM_masks.png")
    if not cv2.imwrite(AggrSAMmasks_path, np.uint8(AggrSAM_masks * 255)):
        raise OSError(f"Cannot write image: {AggrSAMmasks_path}")


| Module  | Purpose  |
| -------- | -------- |
| `segmentation` | Clear floes connected to the image borders (`clear_border`) |
| `morphology` | Binary erosion with a disk-shaped structuring element of radius 2 (`disk(2)`) |
| `measure` | Label regions, filter them by circularity and eccentricity with `regionprops`, and retrieve the floe pixel coordinates |


Note that the circularity and eccentricity computed here are lower than the values MATLAB gives. In addition, `measure.find_contours` returns a different perimeter pixel list than MATLAB's `bwboundaries`. This is why the pixel coordinates that form each floe are saved directly.

## Generate segmentation with point prompts

In [4]:
# Structuring element for the erosion; it is loop-invariant, so build it once.
se = morphology.disk(2)
# cv2.imwrite signals a missing folder only through a False return value,
# so create the output folder up front.
os.makedirs(SAMpredsave_path, exist_ok=True)

for file_path in outputManualCropped_files:
    # The first 18 characters of the manual file name are the prefix of the
    # matching GVF files (images and seeds) in seedDiffuse_path.
    base_name = os.path.basename(file_path)[:18]

    croppedImage_path = os.path.join(seedDiffuse_path, base_name + "croppedImage.png")
    croppedImage = cv2.imread(croppedImage_path)
    if croppedImage is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"Cannot read image: {croppedImage_path}")
    croppedImage = cv2.cvtColor(croppedImage, cv2.COLOR_BGR2RGB)

    diffuseImage_path = os.path.join(seedDiffuse_path, base_name + "diffuseImage.png")
    diffuseImage = cv2.imread(diffuseImage_path)
    if diffuseImage is None:
        raise FileNotFoundError(f"Cannot read image: {diffuseImage_path}")
    diffuseImage = cv2.cvtColor(diffuseImage, cv2.COLOR_BGR2RGB)

    seedX_path = os.path.join(seedDiffuse_path, base_name + "roundXseeds.txt")
    seedY_path = os.path.join(seedDiffuse_path, base_name + "roundYseeds.txt")
    seedX = read_tensor(seedX_path, predictor.device)
    seedY = read_tensor(seedY_path, predictor.device)

    # One single-point prompt per seed: shape (num_seeds, 1, 2), points in
    # (x, y) order, every point labelled 1 (foreground).
    # NOTE(review): if the MATLAB seeds are 1-based pixel indices they sit one
    # pixel off SAM's 0-based coordinates -- confirm.
    input_point = torch.stack([seedX, seedY], dim=1).unsqueeze(1)
    input_label = torch.ones((input_point.shape[0], 1), device=predictor.device)

    # Pure-black pixels of the cropped image (presumably the cropped-away area).
    # They are switched on in every SAM mask so that segments touching them
    # join that area; clear_border then removes both together if the black
    # area reaches the image border.
    blackMask = np.all(croppedImage == [0, 0, 0], axis=2)
    AggrSAMpred_masks = np.zeros_like(blackMask)  # union of all cleaned masks
    floeCoords = []  # pixel coordinates of every retained floe, in seed order

    if input_point.shape[0] == 0:
        # No seeds for this image: nothing to prompt SAM with.
        masks_binary = np.zeros((0, *blackMask.shape), dtype=bool)
    else:
        predictor.set_image(diffuseImage)
        transformed_points = predictor.transform.apply_coords_torch(
            input_point, diffuseImage.shape[:2]
        )
        masks, _, _ = predictor.predict_torch(
            point_coords=transformed_points,
            point_labels=input_label,
            boxes=None,
            multimask_output=False,
        )
        # masks is (num_seeds, 1, H, W). Select the single mask channel
        # instead of squeeze(): squeeze() also drops the seed axis when there
        # is only one seed, and the loop below would then iterate over image
        # rows. .cpu() is a no-op on CPU and lets SAM run on a GPU.
        masks_binary = masks[:, 0].cpu().numpy()

    for SAMpred_mask in masks_binary:
        SAMpred_mask[blackMask] = True

        bclearSAMpred = segmentation.clear_border(SAMpred_mask, buffer_size=1)
        erodebclearSAMpred = morphology.binary_erosion(bclearSAMpred, se)

        labelerodebclearSAMpred = measure.label(erodebclearSAMpred, connectivity=2)
        regions = measure.regionprops(labelerodebclearSAMpred)

        circularity = np.round(
            np.array([circular(prop.area, prop.perimeter) for prop in regions]), 2
        )
        eccentricity = np.round(np.array([prop.eccentricity for prop in regions]), 2)

        # Reject regions whose circularity is above 1 or not finite, and
        # line-like regions (eccentricity >= 0.9). Both criteria are evaluated
        # on the same labelling, so no relabelling between passes is needed.
        dropFloe = (circularity > 1) | ~np.isfinite(circularity) | (eccentricity >= 0.9)
        dropLabels = [prop.label for prop, drop in zip(regions, dropFloe) if drop]
        cleanSAMpred_mask = erodebclearSAMpred & ~np.isin(
            labelerodebclearSAMpred, dropLabels
        )

        AggrSAMpred_masks[cleanSAMpred_mask] = True
        # skimage coords are 0-based (row, col) pairs, one row per pixel.
        floeCoords.extend(
            prop.coords for prop, drop in zip(regions, dropFloe) if not drop
        )

    # np.array(list_of_arrays, dtype=object) builds an N-D object array when
    # all coordinate arrays share one shape (e.g. a single floe). Fill a 1-D
    # object array element by element so savemat always writes a 1-D cell array.
    AllcoordinatesSAMpred_masks = np.empty(len(floeCoords), dtype=object)
    for k, coords in enumerate(floeCoords):
        AllcoordinatesSAMpred_masks[k] = coords
    AllcoordinatesSAMpredmasks_path = os.path.join(
        SAMpredsave_path, base_name + "AllcoordinatesSAMpred_masks.mat"
    )
    savemat(
        AllcoordinatesSAMpredmasks_path,
        {"AllcoordinatesSAMpred_masks": AllcoordinatesSAMpred_masks},
    )

    AggrSAMpredmasks_path = os.path.join(
        SAMpredsave_path, base_name + "AggrSAMpred_masks.png"
    )
    if not cv2.imwrite(AggrSAMpredmasks_path, np.uint8(AggrSAMpred_masks * 255)):
        raise OSError(f"Cannot write image: {AggrSAMpredmasks_path}")
