# <b>Deep Learning:</b>
# Improving existing segmentators performance with zero-shot segmentators

This Notebook implements the code used in our paper "Improving existing segmentators performance with zero-shot segmentators".

In our study, we used the predicted segmentation masks from state-of-the-art methods **DeepLabV3+** https://github.com/VainF/DeepLabV3Plus-Pytorch and **PVTv2** https://github.com/whai362/PVT.

From these masks, we sample checkpoints (point prompts) and feed them to **SAM** (**Segment Anything**, https://github.com/facebookresearch/segment-anything) or **SEEM** (**Segment Everything Everywhere All at Once**, https://github.com/UX-Decoder/Segment-Everything-Everywhere-All-At-Once) for *Post-Processing Segmentation Enhancement*.

The **Segment Anything Model (SAM)** produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

Similarly to SAM, **Segment Everything Everywhere All at Once (SEEM)** lets users segment an image using prompts of different types, including visual prompts (points, marks, boxes, scribbles and image segments) and language prompts (text and audio). It can also work with any combination of prompts or generalize to custom prompts.


We devised 4 different methods for producing checkpoints:
 - A: the pixel whose coordinates are, along each dimension, the average of the coordinates of all the mask's pixels
 - B: the center of mass of the mask
 - C: one (or more) pixels drawn at random (uniformly or not) inside the mask area
 - D: pixels drawn from the intersection of a uniform grid of fixed step size and the mask. The "b" variant uses the intersection between the grid and the eroded mask, where the mask is shrunk by 10 pixels.
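
To make the four strategies concrete, here is a minimal NumPy-only sketch on a toy single-blob mask. The helper name `toy_prompts`, the grid step of 4 and the fixed seed are illustrative assumptions; the notebook's own `Sampler` works blob by blob and uses SciPy/OpenCV instead. Note that for a purely binary mask, A and B give the same location up to rounding.

```python
import numpy as np

def toy_prompts(mask: np.ndarray, step: int = 4, seed: int = 0) -> dict:
    """Hypothetical helper: one (x, y) prompt set per strategy for a single blob."""
    rows, cols = np.nonzero(mask)
    # A: per-axis mean of the foreground pixel coordinates (truncated to int)
    mean_pt = (int(cols.mean()), int(rows.mean()))
    # B: center of mass; for a binary mask every pixel has the same weight
    com_pt = (float(cols.mean()), float(rows.mean()))
    # C: one foreground pixel drawn uniformly at random
    rng = np.random.default_rng(seed)
    k = int(rng.integers(len(rows)))
    rand_pt = (int(cols[k]), int(rows[k]))
    # D: foreground pixels that lie on a regular grid of the given step
    grid = np.zeros(mask.shape, dtype=bool)
    grid[::step, ::step] = True
    g_rows, g_cols = np.nonzero(grid & mask.astype(bool))
    grid_pts = list(zip(g_cols.tolist(), g_rows.tolist()))
    return {"A": mean_pt, "B": com_pt, "C": rand_pt, "D": grid_pts}

mask = np.zeros((12, 12), dtype=np.uint8)
mask[2:9, 3:10] = 1  # a 7x7 square blob
prompts = toy_prompts(mask)
for name, pts in prompts.items():
    print(name, pts)
```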
---

### In order to run the script, you need to:
 - set the path to your data folder (in the "Parameters of the script" cell)
 - set which type of DeepLabV3+ mask to consider (binary or real-valued) (in the "Run SAM" cell)
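
As a rough guide to what the data folder should contain, the sketch below rebuilds the sub-folder names that the loading code in this notebook looks for (`imgs`, `gt`, and `segmentator_<name>`). The helper name `expected_layout` and the example root folder `data` are illustrative assumptions, not part of the notebook.

```python
import os

def expected_layout(dataset_path: str, dataset_name: str, segmentator_name: str) -> dict:
    """Hypothetical helper: list the folders the loader expects for one dataset."""
    root = os.path.join(dataset_path, dataset_name)
    return {
        "images": os.path.join(root, "imgs"),
        "ground_truth": os.path.join(root, "gt"),
        "segmentator_masks": os.path.join(root, "segmentator_" + segmentator_name),
    }

layout = expected_layout("data", "ribs", "deeplab")
for role, folder in layout.items():
    print(f"{role:18s} {folder}")
```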

---



### Necessary imports

In [None]:
%matplotlib inline
import os
import glob
import sys
import math

import cv2                      ## pip install opencv-python
import torch                    ## pip install torch
import matplotlib.pyplot as plt ## pip install matplotlib
import numpy as np              ## pip install numpy
import pickle                   ## standard library, no install needed
import pandas as pd
import seaborn as sb
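from skimage import measure     ## pip install scikit-image
from sklearn.metrics import jaccard_score  ## pip install scikit-learn (used by Metrics.get_iou_locuste)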
from scipy import ndimage       ## pip install scipy
from tqdm import tqdm           ## pip install tqdm
###
## to download SAM:
## git clone git@github.com:facebookresearch/segment-anything.git
## cd segment-anything; pip install -e .
###
from segment_anything import sam_model_registry, SamPredictor

### Seed the random generators for reproducibility

In [2]:
torch.manual_seed(0)
np.random.seed(0)
np.set_printoptions(threshold=sys.maxsize)

### Script Consts

In [35]:
# Paths
DATASET_PATH = 'c:\\Users\\gustavo\\Documents\\SAM\\data\\test'
BASE_OUTPUT_PATH = 'c:\\Users\\gustavo\\Documents\\SAM\\outputs\\test'
SAM_SUBMODULES_PATH = 'c:\\Users\\gustavo\\Documents\\SAM\\pretrained_models'

In [36]:
# Execution Configs
VERBOSE = True
OVERWRITE_OUTPUTS = True

In [37]:
# Hyperparameters
SAMPLING_MODE="grid"
SAMPLING_STEP=50
BORDER_MODE="off"
PREDICTION_TH = 0.0

# Experiment configuration
DEVICE = "cpu" # ["cuda", "cpu"]
SOURCE_MASK = "deeplab" # ["oracle", "deeplab", "pvtv2"]

In [38]:
# Model config
## VIT-H Config
SAM_CHECKPOINT = os.path.join(SAM_SUBMODULES_PATH, "sam_vit_h_4b8939.pth")
MODEL_TYPE = "default"

## VIT-L Config
# SAM_CHECKPOINT = os.path.join(SAM_SUBMODULES_PATH, "sam_vit_l_0b3195.pth")
# MODEL_TYPE = "vit_l"

In [39]:
# Dataset
DATASETS = [
    # "CAMO",
    # "portrait",
    # "locuste",
    "ribs",
    # "SKIN/SKIN_COMPAQ",
    # "SKIN/SKIN_ECU",
    # "SKIN/SKIN_HANDGESTURE",
    # "SKIN/SKIN_MCG",
    # "SKIN/SKIN_Pratheepan",
    # "SKIN/SKIN_Schmugge",
    # "SKIN/SKIN_SFA",
    # "SKIN/SKIN_uchile",
    # "SKIN/SKIN_VMD",
    # "SKIN/SKIN_VT-,
    # "Butterfly/FoldDA1_1",
    # "Butterfly/FoldDA1_2",
    # "Butterfly/FoldDA1_3",
    # "Butterfly/FoldDA1_4",
    # "COCO_val2017",
    # "MARS"
]

### Helper functions for reading images and masks

In [40]:
def read_img(path:str) -> np.ndarray:
    return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)

def read_bmask(path:str) -> np.ndarray:
    return cv2.imread(path, cv2.IMREAD_GRAYSCALE) / 255.0

def read_rmask(path:str) -> np.ndarray:
    return cv2.imread(path, cv2.IMREAD_UNCHANGED) 

### Helper functions for displaying points, boxes, and masks.

In [41]:
def show_mask(mask: np.ndarray, ax, random_color:bool=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords:np.ndarray, labels:np.ndarray, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) # this is if you want the star
    #ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=10, edgecolor='green', linewidth=1.25) # this is if you want the dot
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

def draw_img(img:np.ndarray, input_point:np.ndarray=None, input_label:np.ndarray=None, \
             mask: np.ndarray=None, title:plt.title = None, plt_show:bool=True):
    plt.clf()
    plt.figure(figsize=(3,3))
    plt.imshow(img)
    if mask is not None:
        show_mask(mask, plt.gca())
    if input_point is not None and input_label is not None:
        show_points(input_point, input_label, plt.gca())
    if title is not None:
        plt.title(title, fontsize=18)
    plt.axis('off')
    if plt_show:
        plt.show()
        
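def draw_results(img:np.ndarray, masks:list, scores:list, gt_mask:np.ndarray, \
                 input_point:np.ndarray, input_label:np.ndarray):
    # `img` is passed explicitly: the original relied on undefined globals
    gt_bool = gt_mask.astype(bool)
    for i, (mask, score) in enumerate(zip(masks, scores)):
        mask_bool = mask.astype(bool)
        union = np.logical_or(mask_bool, gt_bool).sum()
        # IoU computed inline; two empty masks count as a perfect match
        iou = np.logical_and(mask_bool, gt_bool).sum() / union if union else 1.0
        title = f"Mask {i+1}, Score: {score:.3f}, IoU: {iou:.3f}"
        draw_img(img, input_point, input_label, mask, title)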

### Metrics
(between a predicted mask and a Ground Truth mask)
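
As a small worked example of the two overlap metrics, the sketch below computes IoU and Dice from the confusion counts of two tiny binary masks. The function name `iou_and_dice` and the toy arrays are assumptions made for illustration; the empty-mask convention (score 1.0 when both masks are empty) mirrors the common-mode code in the `Metrics` class.

```python
import numpy as np

def iou_and_dice(pred: np.ndarray, gt: np.ndarray) -> tuple:
    """Illustrative helper: IoU and Dice from true/false positive counts."""
    pred_b, gt_b = pred.astype(bool), gt.astype(bool)
    tp = np.logical_and(gt_b, pred_b).sum()
    fp = np.logical_and(~gt_b, pred_b).sum()
    fn = np.logical_and(gt_b, ~pred_b).sum()
    if tp + fp + fn == 0:
        # both masks are empty: treat as a perfect match
        return 1.0, 1.0
    iou = tp / (tp + fp + fn)
    dice = 2 * tp / (2 * tp + fp + fn)
    return float(iou), float(dice)

pred = np.array([[1, 1, 0], [0, 1, 0]])
gt = np.array([[1, 0, 0], [0, 1, 1]])
iou, dice = iou_and_dice(pred, gt)  # tp=2, fp=1, fn=1
print(f"IoU={iou:.3f}  Dice={dice:.3f}")  # IoU=0.500  Dice=0.667
```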

In [42]:
class Metrics():
    eps=np.finfo(np.double).eps

    def __init__(self, dataset=None):
        self.reset()
        self.step        = self.step_common
        self.get_iou     = self.get_iou_common
        self.get_dice    = self.get_dice_common
        self.get_results = self.get_results_common
        
        # normalise the name so a missing dataset or a lowercase name is handled
        name = (dataset or "").lower()
        if "skin" in name:
            self.set_mode_skin()
        elif "locuste" in name:
            self.set_mode_locuste()
    
    def reset(self):
        self.tps, self.fps, self.tns, self.fns = 0, 0, 0, 0
        self.ious, self.maes, self.dices, self.wfms, self.emes = [], [], [], [], []
    
    def step_common(self, pred, GT):
        iou       = self.get_iou(pred, GT)
        dice      = self.get_dice(pred, GT)
        mae       = self.compute_mae(pred, GT)
        fscore    = self.FbetaMeasure(pred.astype(bool), GT.astype(bool))
        e_measure = self.EMeasure(pred.astype(bool), GT.astype(bool))
        self.ious.append(iou)
        self.dices.append(dice)
        self.maes.append(mae)
        self.wfms.append(fscore)
        self.emes.append(e_measure)
    
    def step_skin(self, pred, gt):
        y_pred_bool = pred.astype(bool)
        y_true_bool = gt.astype(bool)
        self.tps += np.logical_and(y_true_bool, y_pred_bool).sum()
        self.tns += np.logical_and(~y_true_bool, ~y_pred_bool).sum()
        self.fps += np.logical_and(~y_true_bool, y_pred_bool).sum()
        self.fns += np.logical_and(y_true_bool, ~y_pred_bool).sum()
        
        
    def step_locuste(self, pred, GT):
        iou       = self.get_iou_locuste(pred, GT)
        dice      = self.get_dice_locuste(pred.astype(bool), GT.astype(bool))
        mae       = self.compute_mae(pred, GT)
        e_measure = self.EMeasure(pred.astype(bool), GT.astype(bool))
        fscore    = self.FbetaMeasure(pred.astype(bool), GT.astype(bool))
        self.ious.append(iou)
        self.dices.append(dice)
        self.maes.append(mae)
        self.wfms.append(fscore)
        self.emes.append(e_measure)
        
    
    def get_iou_common(self, pred, gt, beta=1):
        y_pred_bool = pred.astype(bool)
        y_true_bool = gt.astype(bool)
        tp = np.logical_and(y_true_bool, y_pred_bool).sum()
        tn = np.logical_and(~y_true_bool, ~y_pred_bool).sum()
        fp = np.logical_and(~y_true_bool, y_pred_bool).sum()
        fn = np.logical_and(y_true_bool, ~y_pred_bool).sum()
        if tp+fn+fp==0:
            if tp==0:
                iou=1.0
            else:
                iou=0.0
        else:
            iou = tp / (tp + fn + fp)

        return iou

    def get_iou_locuste(self, pred, target, n_classes = 2):
        return jaccard_score(target.reshape(-1).astype(bool), pred.reshape(-1).astype(bool))

    def get_dice_common(self, pred, gt):
        y_pred_bool = pred.astype(bool)
        y_true_bool = gt.astype(bool)

        tp = np.logical_and( y_true_bool,  y_pred_bool).sum()
        tn = np.logical_and(~y_true_bool, ~y_pred_bool).sum()
        fp = np.logical_and(~y_true_bool,  y_pred_bool).sum()
        fn = np.logical_and( y_true_bool, ~y_pred_bool).sum()
        
        if tp+fn+fp==0:
            dice=1.0 if tp==0 else 0.0
        else:
            dice = 2*tp / (2*tp + fn + fp)

        return dice

    def _calConfusion(self, pred, GT):
        TP=np.sum(pred[GT]==1)
        FP=np.sum(pred[~GT]==1)
        TN=np.sum(pred[~GT]==0)
        FN=np.sum(pred[GT]==0)
        return TP,FP,TN,FN

    def get_dice_locuste(self, pred, GT):
        tp, fp, tn, fn = self._calConfusion(pred, GT)
        return (2.0 * tp) / (2.0 * tp + fp + fn + 1e-7)

    def compute_mae(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
        mae = np.mean(np.abs(pred - gt))
        return mae
    
    ## F-Measure
    def FbetaMeasure(self, pred, GT, beta= math.sqrt(0.3)):
        TP,FP,TN,FN=self._calConfusion(pred, GT)
        if TP+FN+FP==0:
            if TP==0:
                Fbeta=1.0
            else:
                Fbeta=0.0
        else:
            P=TP/(TP+FP+1e-8) #precision
            R=TP/(TP+FN+1e-8) #recall
            Fbeta=(beta**2+1)*P*R/((beta**2)*P+R+1e-8)
        return Fbeta
    
    ## E-Measure
    def _EnhancedAlignmnetTerm(self, align_Matrix):
        enhanced=((align_Matrix+1)**2)/4
        return enhanced

    def _AlignmentTerm(self, dGT, dpred):
        mean_dpred=np.mean(dpred)
        mean_dGT=np.mean(dGT)
        align_dpred=dpred-mean_dpred
        align_dGT=dGT-mean_dGT
        align_matrix=2*(align_dGT*align_dpred)/(align_dGT**2+align_dpred**2+self.eps)
        return align_matrix

    def EMeasure(self, pred, GT):
        dGT,dpred=GT.astype(np.float64),pred.astype(np.float64)
        if np.sum(GT)==0:#completely black
            enhanced_matrix=1-dpred
        elif np.sum(~GT)==0:
            enhanced_matrix=dpred
        else:
            align_matrix=self._AlignmentTerm(dGT,dpred)
            enhanced_matrix=self._EnhancedAlignmnetTerm(align_matrix)
        rows,cols= GT.shape
        
        # score=np.sum(enhanced_matrix)/(rows*cols-1+self.eps)
        score=np.sum(enhanced_matrix)/(rows*cols+self.eps)
        return score
    
    def get_results_common(self) -> (float, float, float, float, float):
        return np.array(self.ious).mean(), np.array(self.dices).mean(), np.array(self.maes).mean(), np.array(self.wfms).mean(), np.array(self.emes).mean()

    def get_results_skin(self):
        iou = self.tps / (self.tps + self.fns + self.fps)
        dice = (2.0 * self.tps) / (2.0 * self.tps + self.fps + self.fns + 1e-7)
        # for skin dataset, we didn't need the other metrics. TODO: implement
        return iou, dice, None, None, None
    
    def set_mode_locuste(self):
        self.step = self.step_locuste
        self.get_iou = self.get_iou_locuste
        self.get_dice = self.get_dice_locuste
    
    def set_mode_skin(self):
        self.step = self.step_skin
        self.get_results = self.get_results_skin
            
    

### Functions for the sampling of the checkpoints.
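
The grid-based strategy and its eroded variant can be sketched in a few lines. This is a simplified stand-in under stated assumptions: erosion is done with a plain NumPy 3x3 neighbourhood instead of the elliptical OpenCV kernel used in the notebook, and the helper names `erode_once` and `grid_points` are hypothetical.

```python
import numpy as np

def erode_once(mask: np.ndarray) -> np.ndarray:
    """Shrink a boolean mask by one pixel (3x3 square neighbourhood)."""
    h, w = mask.shape
    padded = np.pad(mask.astype(bool), 1, constant_values=False)
    out = np.ones((h, w), dtype=bool)
    # a pixel survives only if all 9 pixels of its neighbourhood are foreground
    for dy in (0, 1, 2):
        for dx in (0, 1, 2):
            out &= padded[dy:dy + h, dx:dx + w]
    return out

def grid_points(mask: np.ndarray, step: int, erosions: int = 0) -> np.ndarray:
    """Return (x, y) grid points falling inside the (optionally eroded) mask."""
    inner = mask.astype(bool)
    for _ in range(erosions):
        inner = erode_once(inner)
    grid = np.zeros(mask.shape, dtype=bool)
    grid[::step, ::step] = True
    rows_cols = np.argwhere(grid & inner)
    return rows_cols[:, ::-1]  # swap (row, col) -> (x, y)

mask = np.zeros((9, 9), dtype=bool)
mask[2:7, 2:7] = True  # a 5x5 square blob
plain = grid_points(mask, step=2)
eroded = grid_points(mask, step=2, erosions=1)
print(len(plain), "points on the plain mask,", len(eroded), "on the eroded mask")
```

Eroding first keeps the prompts away from the mask border, where the source segmentator is most likely to be wrong.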

In [43]:

class Sampler:
    verbose = True
    sampling_step = None
    min_blob_count = None
    
    def __init__(self, verbose, sampling_step, min_blob_count):
        self.verbose        = verbose
        self.sampling_step  = sampling_step
        self.min_blob_count = min_blob_count
    
    def sample_pixels(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        # draw a pix for each blob
        input_point, input_label = [], []
        blob_labels, blob_sample = np.unique(mask_of_blobs, return_index=True)
        gt_fl = mask.flatten()
        for bl, bs in zip(blob_labels, blob_sample):
            mask_bool = (mask_of_blobs==bl)
            count = mask_bool.sum()
            if gt_fl[bs]>=1.0 and count>self.min_blob_count: ## it's not a background blob or a false blob
                x_center, y_center = np.argwhere(mask_bool).sum(0)/count
                x_center, y_center = int(x_center) % mask.shape[0], int(y_center) % mask.shape[1]
                input_point.append([y_center, x_center])
                input_label.append(1)
                print(f"blob #{bl} drawn point: {[x_center, y_center]}") if self.verbose else None

        # no mask? pick the center pixel of image
        if len(input_point) == 0:
            input_point, input_label = [[mask.shape[1]//2, mask.shape[0]//2]], [1]

        return np.array(input_point), np.array(input_label)
    
    def sample_pixels_center_of_mass(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        # draw a pix for each blob
        input_point, input_label = [], []
        blob_labels, blob_sample = np.unique(mask_of_blobs, return_index=True)
        gt_fl = mask.flatten()
        for bl, bs in zip(blob_labels, blob_sample):
            mask_bool = (mask_of_blobs==bl)
            count = mask_bool.sum()
            if gt_fl[bs]>=1.0 and count>self.min_blob_count: ## it's not a background blob or a false blob
                x_center, y_center = ndimage.center_of_mass(mask_bool)
                input_point.append([y_center, x_center])
                input_label.append(1)
                print(f"blob #{bl} drawn point: {[x_center, y_center]}") if self.verbose else None

        # no mask? pick the center pixel of image
        if len(input_point) == 0:
            input_point, input_label = [[mask.shape[1]//2, mask.shape[0]//2]], [1]
            print(f"empty blob -> {input_point}") if self.verbose else None

        return np.array(input_point), np.array(input_label)

    def sample_pixels_random(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        # draw a pix for each blob
        input_point, input_label = [], []
        blob_labels, blob_sample = np.unique(mask_of_blobs, return_index=True)
        gt_fl = mask.flatten()
        for bl, bs in zip(blob_labels, blob_sample):
            mask_bool = (mask_of_blobs==bl)
            count = mask_bool.sum()
            if gt_fl[bs]>=1.0 and count>self.min_blob_count: ## it's not a background blob or a false blob
                indices = np.argwhere(mask_bool)
                random_index = np.random.choice(indices.shape[0])
                x_center, y_center = indices[random_index]
                input_point.append([y_center, x_center])
                input_label.append(1)
                print(f"blob #{bl} drawn point: {[x_center, y_center]}") if self.verbose else None

        # no mask? sample a random point
        if len(input_point) == 0:
            input_point, input_label = [ \
                [np.random.randint(0, mask.shape[1]),   \
                 np.random.randint(0, mask.shape[0])]], \
            [1]

        return np.array(input_point), np.array(input_label)

    def get_grid(self, mask, offset_px_x, offset_px_y):
        row = np.zeros(mask.shape, dtype=int)
        col = np.zeros(mask.shape, dtype=int)

        for i in range(offset_px_y, row.shape[0], self.sampling_step):
            row[i, :] = 1
        for i in range(offset_px_x, col.shape[1], self.sampling_step):
            col[:, i] = 1
        res = row & col
        return res
    
    def sample_pixels_grid(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):

        # draw a pix for each blob
        input_point = []
        offset_px_x = 0
        offset_px_y = 0
        while len(input_point)==0 and offset_px_y < self.sampling_step:
            res = self.get_grid(mask, offset_px_x, offset_px_y)

            input_point = np.argwhere(res & mask.astype(np.int64))
            input_point[:, (0, 1)] = input_point[:, (1, 0)]
            
            offset_px_x += 1
            if offset_px_x>self.sampling_step:
                offset_px_x = 0
                offset_px_y += 1

        # no more sampled points than foreground blobs? fall back to random sampling
        blob_labels = np.unique(mask_of_blobs)
        if len(input_point) <= blob_labels.shape[0]-1:
            return self.sample_pixels_random(mask_of_blobs, mask)

        input_label = [1 for _ in input_point]
        

        return np.array(input_point), np.array(input_label)
    
    def sample_pixels_eroded_grid(self, mask_of_blobs: np.ndarray, mask: np.ndarray) -> (np.ndarray, np.ndarray):
        
        input_point = []
        offset_px_x = 0
        offset_px_y = 0
        
        while len(input_point)==0 and offset_px_y < self.sampling_step:
            res = self.get_grid(mask, offset_px_x, offset_px_y)
            erode_size = 10
        
            while True:
                # Erode the mask
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_size, erode_size))
                eroded_mask = cv2.erode(mask, kernel)

                input_point = np.argwhere(res & eroded_mask.astype(np.int64))
                input_point[:, (0, 1)] = input_point[:, (1, 0)]

                blob_labels, blob_sample = np.unique(mask_of_blobs, return_index=True)
                gt_fl = mask.flatten()

                blobs = np.zeros(blob_labels.shape[0], dtype=np.float32)
                for i, (bl, bs) in enumerate(zip(blob_labels, blob_sample)):
                    mask_bool = (mask_of_blobs==bl)
                    count = mask_bool.sum()
                    if not (gt_fl[bs]>=1.0 and count>self.min_blob_count): ## background blob or too-small (false) blob: ignore it
                        blobs[i] = -1

                for i in range(input_point.shape[0]):
                    fl_ip = input_point[i][0]*mask.shape[1] + input_point[i][1]
                    idx = mask_of_blobs[input_point[i][1], input_point[i][0]]
                    blobs[idx]=1.0
                
                if not np.any(blobs==0.0) or erode_size==1:
                    break
                erode_size -= 1
            
            offset_px_x += 1
            if offset_px_x>self.sampling_step:
                offset_px_x = 0
                offset_px_y += 1
            
        
        input_label = [1 for _ in input_point]

        # still no points? fall back to the plain (non-eroded) grid sampling
        if len(input_point) == 0:
            return self.sample_pixels_grid(mask_of_blobs, mask)

        return np.array(input_point), np.array(input_label)
    
    def sample(self, mode, border_mode, mask_of_blobs: np.ndarray, mask: np.ndarray):
        if mode=="common": # A
            return self.sample_pixels(mask_of_blobs, mask)
        elif mode=="center_of_mass": # B
            return self.sample_pixels_center_of_mass(mask_of_blobs, mask)
        elif mode=="random": # C
            return self.sample_pixels_random(mask_of_blobs, mask)
        elif mode=="grid": # D
            if border_mode=="on":
                return self.sample_pixels_eroded_grid(mask_of_blobs, mask)
            else:
                return self.sample_pixels_grid(mask_of_blobs, mask)


### Parameters of the script
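
The folder that stores the sampled points is named after the sampling configuration. The sketch below reproduces that naming rule in isolation so the resulting names are easy to predict; the function name `sampled_points_folder_name` is a hypothetical stand-in for the logic inside `get_spf_path`.

```python
def sampled_points_folder_name(mode: str, step: int, border_mode: str) -> str:
    """Illustrative helper: build the folder name for one sampling configuration."""
    name = mode
    if mode == "grid":
        # only the grid mode encodes its step size and the border (erosion) flag
        name += "_" + str(step)
        if border_mode == "on":
            name += "_bm"
    return name

for cfg in [("grid", 50, "on"), ("grid", 50, "off"), ("random", 50, "on")]:
    print(cfg, "->", sampled_points_folder_name(*cfg))
```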

In [44]:
def get_complete_output_path(bop, dataset_name, src_msk, model, create=False):
    results_dir = os.path.join(bop, dataset_name, src_msk, model)
    if create:
        os.makedirs(results_dir, exist_ok=True)
    return results_dir

## get sample points folder path
def get_spf_path(
    bop, dataset_name,
    src_msk, model, s_step,
    pnts_smp_mode: str='', 
    border_mode: str='on',
    create=False):
    folder_name = (
        pnts_smp_mode 
            + (("_" + str(s_step)) if pnts_smp_mode=="grid" else "")
            + ("_bm" if border_mode=="on" and pnts_smp_mode=="grid" else "")
    )
    sampled_points_folder = os.path.join(
        bop, "sampled_points_final", dataset_name, src_msk, model, folder_name
    )
    if create:
        os.makedirs(sampled_points_folder, exist_ok=True)
    return sampled_points_folder


def get_min_blob_number_based_on_dataset(dataset):
    return 20 if dataset=="portrait" else 10

In [45]:
def loadpaths(dataset_path, dataset_name, segmentator_name):
    
    if not os.path.isdir(os.path.join(dataset_path, dataset_name)):
        print("ERROR. provided dataset does not exist!")
        return None
    
    orig_images_folder = os.path.join(dataset_path, dataset_name, "imgs")
    gt_folder          = os.path.join(dataset_path, dataset_name, "gt")
    segmentator_folder = os.path.join(
        dataset_path, dataset_name, "segmentator_" + segmentator_name
    )
    
    ## Load input images ##
    test_imgs = glob.glob(os.path.join(orig_images_folder, '*'))
    bn = [os.path.basename(path[:-4]).zfill(6) for path in test_imgs]
    test_imgs = [
        test_imgs[i] for i in sorted(range(len(bn)), key=lambda k: bn[k])
    ]
    
    ## Load GT masks ##
    gt_masks = glob.glob(os.path.join(gt_folder, '*'))
    bn = [os.path.basename(path[:-4]).zfill(6) for path in gt_masks]
    gt_masks = [
        gt_masks[i] for i in sorted(range(len(bn)), key=lambda k: bn[k])
    ]
    
    print(os.path.join(segmentator_folder, '*.bmp'))

    ## Load DeepLabV3+ produced binary masks ##
    segmentator_bmasks = glob.glob(os.path.join(segmentator_folder, '*.bmp'))
    bn = [os.path.basename(path[:-4]).zfill(6) for path in segmentator_bmasks]
    segmentator_bmasks = [
        segmentator_bmasks[i]
        for i in sorted(range(len(bn)), key=lambda k: bn[k])
    ]

    ## Load DeepLabV3+ produced 3D masks ##
    segmentator_rmasks = glob.glob(os.path.join(segmentator_folder, '*.png'))
    bn = [os.path.basename(path[:-4]).zfill(6) for path in segmentator_rmasks]
    segmentator_rmasks = [segmentator_rmasks[i] for i in sorted(range(len(bn)), key=lambda k: bn[k])]
        
    return [test_imgs, gt_masks, segmentator_bmasks, segmentator_rmasks]


## Apply SAM

Predict with `SamPredictor.predict`. The model returns
 - masks (`masks.shape` is `(number_of_masks, H, W)`)
 - quality predictions for those masks
 - low resolution mask logits that can be passed to the next iteration of prediction.

The `predict()` function accepts three parameters (among many):

 - `point_coords`: an `np.ndarray` of 2D pixel coordinates that give SAM the checkpoints/seeds of the object to segment
 - `point_labels`: for each point, whether it belongs to the object (1) or not (0)

 - With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask.
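
The prompt layout and the "pick the highest score" step can be illustrated without loading the model. In this sketch the logit maps and scores are made-up stand-ins for what a multimask prediction would return, not real SAM output; the 0.0 threshold is the value the notebook uses for `PREDICTION_TH`.

```python
import numpy as np

# Stand-ins for a multimask prediction (NOT real SAM output):
# three 4x4 logit maps and one quality score per map.
rng = np.random.default_rng(0)
mask_logits = rng.normal(size=(3, 4, 4)).astype(np.float32)
scores = np.array([0.71, 0.93, 0.88], dtype=np.float32)

# Prompts are (x, y) pixel coordinates with one label per point
# (1 = belongs to the object, 0 = background).
rows_cols = np.array([[2, 1], [3, 3]])      # (row, col), e.g. from np.argwhere
point_coords = rows_cols[:, ::-1].copy()    # reorder to (x, y)
point_labels = np.ones(len(point_coords), dtype=int)

# Keep the mask the model itself rates highest, then binarise its logits.
best = int(np.argmax(scores))
binary_mask = mask_logits[best] > 0.0
print("best mask index:", best, "| prompt points (x, y):", point_coords.tolist())
```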

In [46]:
def get_data_paths(dataset_name: str, source_mask: str):
     paths = loadpaths(DATASET_PATH, dataset_name, source_mask)
     lens = list(map(lambda i: len(i), paths))
     print('paths', paths)

     # if there are no 3D logits, repeat the 2D ones for consistency
     if lens[3] == 0:
          paths[3] = paths[2]
          lens[3] = lens[2]

     # check if the images and the ground truths have the same quantity
     img_has_gt_len = lens[0] == lens[1]
     # check if either of the masks (2D or 3D logits) matches with the len of
     # the images
     mask_has_img_len = (
        (lens[0] == lens[2] == lens[3])
          or (lens[0] == lens[2])
          or (lens[0] == lens[3])
     )
     assert img_has_gt_len and mask_has_img_len, f"unbalanced datasets! {lens}"
     del lens

     return list(zip(*paths))

In [47]:
def read_imgs_from_paths(paths: tuple) -> tuple:
    # Get paths
    img_path, gt_mask_path, src_bmask_path, src_rmask_path = paths
    print("img_path          :", img_path)
    print("gt_mask_path      :", gt_mask_path)
    print("dplabv3_bmask_path:", src_bmask_path)
    print("dplabv3_rmask_path:", src_rmask_path)
    
    # Load images from disk using paths
    return (
        read_img(img_path),
        read_bmask(gt_mask_path),
        read_bmask(src_bmask_path),
        read_rmask(src_rmask_path)
    )

def img_shape_matches(loaded_images: tuple, source_mask: str = 'oracle'):
    img, gt_mask, src_bmask, src_rmask = loaded_images
    img_matches_gt = img.shape[:2] == gt_mask.shape

    if source_mask=="oracle":
        return img_matches_gt
            
    return img_matches_gt and (src_bmask.shape == src_rmask.shape[:2])

In [52]:

def perform(
        predictor,
        dataset_name: str,
        model_type: str = 'default',
        source_mask: str = 'deeplab',
        points_sampling_mode: str = 'grid',
        sampling_step: int = 50,
        border_mode: str = 'on',
        predict_threshold: float = 0.0):
    is_model_cuda = next(predictor.model.parameters()).is_cuda
    data_paths = get_data_paths(dataset_name, source_mask)
    data_len = len(data_paths)
    if VERBOSE: print("len of files:", data_len)

    # Get the Sampler (used for generating the prompts)
    sampler = Sampler(
        VERBOSE,
        sampling_step,
        get_min_blob_number_based_on_dataset(dataset_name))

    results_dir = get_complete_output_path(
        BASE_OUTPUT_PATH,
        dataset_name,
        source_mask,
        model_type,
        create=True)
    sampled_points_folder = get_spf_path(
        BASE_OUTPUT_PATH,
        dataset_name,
        source_mask,
        model_type,
        sampling_step,
        pnts_smp_mode=points_sampling_mode,
        border_mode=border_mode,
        create=True)
    print("results will be saved in results_dir:", results_dir)
    
    for idx, paths in tqdm(enumerate(data_paths), total=data_len):
        if VERBOSE: print(f" - img idx {str(idx+1).zfill(6)}/{data_len}:")
    
        # Load images from disk using paths
        loaded_images = read_imgs_from_paths(paths)

        if not img_shape_matches(loaded_images, source_mask=source_mask):
            print('A mismatch between image shapes has been found!')
            continue

        # Output paths to the resulting binary mask and logits
        basename = os.path.basename(paths[0])
        basename = basename[:basename.rfind(".") + 1]
        out_mask_path = os.path.join(results_dir, basename + "npy")
        out_logits_path = os.path.join(results_dir, 'low_'+ basename + "npy")
        out_mask_best_path = os.path.join(results_dir, basename + "jpg")
        if (not OVERWRITE_OUTPUTS
                and os.path.exists(out_mask_path)
                and os.path.exists(out_mask_best_path)
                and os.path.exists(out_logits_path)):
            print(f'"{out_mask_path}" already exists, skipping file.')
            continue

        img, gt_mask, src_bmask, _ = loaded_images

        mask_to_sample = gt_mask if source_mask == "oracle" else src_bmask
        # Count the number of distinct labels
        #  -> it corresponds to the # of blobs
        mask_of_blobs = measure.label(mask_to_sample)
        
        # Sample the checkpoints (at least one for blob)
        unique_blobs = np.unique(mask_of_blobs)
        num_blobs = unique_blobs.shape[0]
        bin_path = os.path.join(
            sampled_points_folder, os.path.basename(paths[1])[:-4] + ".bin_"
        )
        if num_blobs==1 and 0 in unique_blobs:
            input_point, input_label = np.array([]), np.array([])
            input_point.tofile(bin_path)
            
            binary_mask = mask_to_sample
            masks=[binary_mask]
            best_score_idx=0
        else:
            input_point, input_label = sampler.sample(
                points_sampling_mode,
                border_mode,
                mask_of_blobs,
                mask_to_sample)
            input_point.tofile(bin_path)
            predictor.set_image(img)
            masks, scores, _ = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
                return_logits=True
            )

            if is_model_cuda:
                torch.cuda.empty_cache()

            best_score_idx = np.argmax(scores)
            binary_mask = masks[best_score_idx] > predict_threshold
        
        # save the outputs
        cv2.imwrite(out_mask_best_path, binary_mask.astype(np.uint8) * 255)
        np.save(out_mask_path, masks)
        np.save(out_logits_path, masks[best_score_idx])


In [53]:
def perform_all(dataset_name, predictor, model_type, source_mask):
    with torch.no_grad():
        perform(
            predictor,
            dataset_name,
            model_type=model_type,
            source_mask=source_mask,
            points_sampling_mode=SAMPLING_MODE,
            sampling_step=SAMPLING_STEP,
            border_mode=BORDER_MODE,
            predict_threshold=PREDICTION_TH)

## Experiments

### Create & Run Model

In [None]:
# Model init
print(f"creating sam {MODEL_TYPE} and moving it to device")
sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
sam.to(device=DEVICE)
print("creating predictor")
predictor = SamPredictor(sam)  

In [None]:
# Run experiment on the datasets
for dataset in DATASETS:
    perform_all(dataset, predictor, MODEL_TYPE, SOURCE_MASK)

## Evaluate the results

In [55]:
def print_metrics(metrics_obj: Metrics):
    iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = metrics_obj.get_results()
    print('Total number of objects: ', len(metrics_obj.ious))    
    print(f"average iou      : {iou_avg*100:5.2f}")
    print(f"average dice     : {dice_avg*100:5.2f}")
    print(f"average mae      : {mae_avg*100:5.2f}")
    print(f"average f-measure: {wfm_avg*100:5.2f}")
    print(f"average e-measure: {eme_avg*100:5.2f}")
    print()

    return {
        'iou': iou_avg,
        'dice': dice_avg,
        'mae': mae_avg,
        'f-measure': wfm_avg,
        'e-measure': eme_avg,
    }

def dump_metrics(
        database_name: str,
        in_met: dict,
        pred_met: dict,
        fus_met: dict):
    indexes = ['iou', 'dice', 'mae', 'f-measure', 'e-measure']
    out_path = os.path.join(BASE_OUTPUT_PATH, database_name, 'metrics.csv')
    pd.DataFrame({
        'input': [in_met[idx] for idx in indexes],
        'pred': [pred_met[idx] for idx in indexes],
        'fusion': [fus_met[idx] for idx in indexes]
    }, index=indexes).to_csv(out_path)

In [40]:
def dump_stats(out_path: str, file_names: list, metrics: Metrics):
    pd.DataFrame({
        'filenames': file_names,
        'iou': metrics.ious,
        'dice': metrics.dices,
        'mae': metrics.maes,
        'f-measure': metrics.wfms,
        'e-measure': metrics.emes
    }).to_csv(out_path)

def dump_input_stats(database_name: str, file_names: list, metrics: Metrics):
    dump_stats(
        os.path.join(BASE_OUTPUT_PATH, database_name, 'input_stats.csv'),
        file_names, metrics
    )

def dump_pred_stats(database_name: str, file_names: list, metrics: Metrics):
    dump_stats(
        os.path.join(BASE_OUTPUT_PATH, database_name, 'pred_stats.csv'),
        file_names, metrics
    )

def dump_fusion_stats(database_name: str, file_names: list, metrics: Metrics):
    dump_stats(
        os.path.join(BASE_OUTPUT_PATH, database_name, 'fusion_stats.csv'),
        file_names, metrics
    )

In [41]:
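class FusionRules:
    """Rules for fusing a source segmentation mask with a SAM prediction."""

    @staticmethod
    def apply_default(b_mask, pred_mask):
        # Invert the 8-bit prediction, average it with the source mask at
        # weights 1:2, and mark a pixel True where the result is below 128.
        return (
            (abs(255 - pred_mask.astype('uint8')) + 2 * b_mask) / 3
        ).astype(np.uint8) < 128

    @staticmethod
    def __apply_kitter(b_mask, pred_mask, numpy_op):
        # Stack the binary mask with the sigmoid of channel 1 of `pred_mask`
        # and reduce the pair pixel-wise with `numpy_op`; no threshold is applied.
        bin_mask = b_mask.astype('float32')
        prob_mask = torch.sigmoid(torch.from_numpy(pred_mask[1, :, :])).numpy()

        return numpy_op([bin_mask, prob_mask], axis=0)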

    @classmethod
    def apply_sum(cls, b_mask, pred_mask):
        return cls.__apply_kitter(b_mask, pred_mask, np.sum)

    @classmethod
    def apply_min(cls, b_mask, pred_mask):
        return cls.__apply_kitter(b_mask, pred_mask, np.min)

    @classmethod
    def apply_max(cls, b_mask, pred_mask):
        return cls.__apply_kitter(b_mask, pred_mask, np.max)

    @classmethod
    def apply_avg(cls, b_mask, pred_mask):
        return cls.__apply_kitter(b_mask, pred_mask, np.average)

    @classmethod
    def apply_median(cls, b_mask, pred_mask):
        return cls.__apply_kitter(b_mask, pred_mask, np.median)
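
To make the arithmetic of `apply_default` concrete, the following sketch re-implements the same expression on a four-pixel toy example. The function name `fuse_default` and the toy values are illustrative only. The sketch assumes that both inputs live on a 0-255 scale and that the source mask is dark where it predicts foreground, which is the polarity the inverted SAM mask ends up with; that polarity is an assumption, not something this cell states.

```python
import numpy as np

def fuse_default(src_mask, sam_mask):
    """Toy re-implementation of the default rule (illustrative name).

    src_mask: source mask on a 0-255 scale (assumed: dark = foreground).
    sam_mask: SAM mask image on a 0-255 scale (bright = foreground).
    """
    inverted_sam = 255 - sam_mask.astype(np.uint8)  # SAM foreground becomes 0
    blended = (inverted_sam + 2.0 * src_mask) / 3.0  # source mask counts twice
    return blended.astype(np.uint8) < 128            # True = fused foreground

sam_mask = np.array([[255, 255, 0, 0]], dtype=np.uint8)
src_mask = np.array([[0.0, 150.0, 100.0, 255.0]])

fused = fuse_default(src_mask, sam_mask)
print(fused)
```

Under these assumptions, the first and last pixels show the two inputs agreeing. In the two middle pixels the source value sits close to the midpoint, so the SAM vote decides the outcome, whereas a confident source value (near 0 or 255) cannot be overturned because of its double weight.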


In [None]:
for dataset in DATASETS:
    input_metrics = Metrics(dataset)
    sam_metrics = Metrics(dataset)
    fusion_metrics = Metrics(dataset)

    pred_path = os.path.join(
        BASE_OUTPUT_PATH, dataset, SOURCE_MASK, MODEL_TYPE
    )
    paths = get_data_paths(dataset, SOURCE_MASK)

    basenames = []
    mask_difference = []
    print("Processing data in directory: ", pred_path, flush=True)
    for idx, img_paths in tqdm(enumerate(paths), total=len(paths)):
        img_path, gt_mask_path, bmask_path, rmask_path = img_paths

        # get the input images and groundtruth
        gt_mask = read_bmask(gt_mask_path).astype(np.float32)
        b_mask = read_bmask(bmask_path).astype(np.float32)
        del gt_mask_path, bmask_path

        # get the prediction input
        basename = os.path.basename(img_path)
        basenames.append(basename)
        pred_mask_fname = basename[:basename.rfind(".")+1] + "jpg"
        pred_mask_path = os.path.join(pred_path, pred_mask_fname)
        logit_path = pred_mask_path[:-3] + 'npy'

        pred_mask = read_bmask(pred_mask_path).astype(np.float32)
        del img_path, pred_mask_fname, pred_mask_path

        mask_difference.append((gt_mask - b_mask).sum())

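        # get the fusion mask
        logit_mask = None #np.load(logit_path)
        r_mask = read_rmask(rmask_path).astype(np.float32)
        # NOTE: this call fuses the ground truth with the source real-valued
        # mask; the SAM prediction (pred_mask) does not enter the fusion here.
        fusion_mask = FusionRules.apply_default(gt_mask, r_mask)
        del logit_path, logit_mask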

        input_metrics.step(b_mask, gt_mask)
        sam_metrics.step(pred_mask > PREDICTION_TH, gt_mask)
        fusion_metrics.step(fusion_mask, gt_mask)
        del gt_mask, b_mask, pred_mask

    print(f'Results from {SOURCE_MASK}')
    in_met_dict = print_metrics(input_metrics)
    dump_input_stats(dataset, basenames, input_metrics)
    del input_metrics

    print('Results from SAM')
    sam_met_dict = print_metrics(sam_metrics)
    dump_pred_stats(dataset, basenames, sam_metrics)
    del sam_metrics
    
    print('Results from Fusion')
    fus_met_dict = print_metrics(fusion_metrics)
    dump_fusion_stats(dataset, basenames, fusion_metrics)
    del fusion_metrics

    dump_metrics(
        dataset, in_met_dict, sam_met_dict, fus_met_dict
    )

    print(max(mask_difference))
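
The per-image CSV files written by `dump_input_stats` and `dump_fusion_stats` share a `filenames` column, so they can be joined to see which images gain or lose from the fusion. The sketch below uses small in-memory tables with made-up values in place of the real `input_stats.csv` and `fusion_stats.csv`; the column `iou_gain` is a name introduced here for illustration.

```python
import pandas as pd

# Toy stand-ins for input_stats.csv and fusion_stats.csv (values are made up).
input_stats = pd.DataFrame({
    "filenames": ["a.png", "b.png", "c.png"],
    "iou": [0.80, 0.60, 0.90],
})
fusion_stats = pd.DataFrame({
    "filenames": ["a.png", "b.png", "c.png"],
    "iou": [0.85, 0.55, 0.90],
})

# Join on the file name and compute the per-image IoU change.
merged = input_stats.merge(
    fusion_stats, on="filenames", suffixes=("_input", "_fusion")
)
merged["iou_gain"] = merged["iou_fusion"] - merged["iou_input"]

# Images whose IoU strictly improves after fusion.
improved = merged.loc[merged["iou_gain"] > 0, "filenames"].tolist()
print(merged.sort_values("iou_gain", ascending=False))
print("improved:", improved)
```

With the real files, the two frames would instead be loaded with `pd.read_csv(..., index_col=0)` from the dataset's output folder.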

In [None]:
x = np.load('/media/zanoni/SSDZ/SAM/outputs/ribs/deeplab/default/low_799.npy')
print(x.shape, x.max(), x.min())
sb.heatmap(torch.sigmoid(torch.from_numpy(x[1, :, :])))


## Evaluate fusion performance

In [None]:
model_type = "default"
#### TODO: oracle still suffers from: no folder "segmentator_oracle"...need to fix it!
source_mask = "deeplab" # "deeplab", "pvtv2", "oracle", "sota" # ...it depends!
points_sampling_mode = "D"
sampling_step = 50
border_mode = "off"

for dataset in DATASETS:
    pred_path = os.path.join(BASE_OUTPUT_PATH, dataset, source_mask, model_type)
    print(pred_path)
    os.makedirs(pred_path, exist_ok=True)

    metrics, source_mask_metrics, metrics_fusion = Metrics(dataset), Metrics(dataset), Metrics(dataset)
    test_imgs, gt_masks, src_bmasks, src_rmasks = loadpaths(DATASET_PATH, dataset, source_mask)
    toiterate = zip(test_imgs, gt_masks, src_bmasks, src_rmasks)
    
    for idx, paths in tqdm(enumerate(toiterate), total=len(gt_masks)):
        if VERBOSE:
            print(f" - img idx {str(idx+1).zfill(6)}/{len(test_imgs)}:")

        # Get paths
        img_path, gt_mask_path, src_bmask_path, src_rmask_path = paths
    
        basename = os.path.basename(img_path)
        in_mask_path = pred_path + "/" + basename[:basename.rfind(".")+1]+"jpg"
        
        # Load images from disk using paths
        gt_mask   = read_bmask(gt_mask_path).astype(np.float32)
        src_bmask = read_bmask(src_bmask_path).astype(np.float32)
        src_rmask = read_bmask(src_rmask_path).astype(np.float64) * 255
        assert gt_mask.shape == src_bmask.shape \
                == src_rmask.shape[:2], "Error: shape mismatch"
        
        img = cv2.imread(in_mask_path)
        assert img is not None, f"Error: cannot read {in_mask_path}"
        img = img[:, :, 0]
        binary_mask = img > 0
        
        metrics.step(binary_mask, gt_mask)
        source_mask_metrics.step(src_bmask, gt_mask)
        
        # Default fusion rule: same computation as FusionRules.apply_default
        fusion_mask = FusionRules.apply_default(src_rmask, img)
        metrics_fusion.step(fusion_mask, gt_mask)
        
        # break

    iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = metrics.get_results()
    print(len(metrics.ious))
    print()
    
    print("SAM alone metrics:")
    print(f"average iou      : {iou_avg*100:5.2f}")
    print(f"average dice     : {dice_avg*100:5.2f}")
    # print(f"average mae      : {mae_avg*100:5.2f}")
    # print(f"average f-measure: {wfm_avg*100:5.2f}")
    # print(f"average e-measure: {eme_avg*100:5.2f}")

    print()
    
    iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = source_mask_metrics.get_results()
    print(f"segmentator_{source_mask} metrics:")
    print(f"average iou      : {iou_avg*100:5.2f}")
    print(f"average dice     : {dice_avg*100:5.2f}")
    # print(f"average mae      : {mae_avg*100:5.2f}")
    # print(f"average f-measure: {wfm_avg*100:5.2f}")
    # print(f"average e-measure: {eme_avg*100:5.2f}")
    
    print()
    
    iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = metrics_fusion.get_results()
    print("SAM-fusion performance:")
    print(f"average iou      : {iou_avg*100:5.2f}")
    print(f"average dice     : {dice_avg*100:5.2f}")
    # print(f"average mae      : {mae_avg*100:5.2f}")
    # print(f"average f-measure: {wfm_avg*100:5.2f}")
    # print(f"average e-measure: {eme_avg*100:5.2f}")
    
    with open(f'{pred_path}/metrics.pickle', 'wb') as handle:
        pickle.dump(metrics, handle)
    with open(f'{pred_path}/source_mask_metrics.pickle', 'wb') as handle:
        pickle.dump(source_mask_metrics, handle)
    with open(f'{pred_path}/metrics_fusion_{source_mask}.pickle', 'wb') as handle:
        pickle.dump(metrics_fusion, handle)



In [None]:
metrics, source_mask_metrics = analyze(predictor, model_type, source_mask, dataset, "D", 50, "off")
# iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = metrics.get_results()
iou_avg, dice_avg, mae_avg, wfm_avg, eme_avg  = source_mask_metrics.get_results()
print(f"average iou      : {iou_avg*100:5.2f}")
print(f"average dice     : {dice_avg*100:5.2f}")
print(f"average mae      : {mae_avg*100:5.2f}")
print(f"average f-measure: {wfm_avg*100:5.2f}")
print(f"average e-measure: {eme_avg*100:5.2f}")

## Visualize images (beta)
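
The cell below reads the sampled prompt points from a raw `.bin` file and tints the mask with an RGBA color before drawing it over the image. The sketch isolates those two steps on toy data. It assumes the point files were written with `ndarray.tofile` as `int64` `(x, y)` pairs, which matches how they are read back but is not shown in this section; the file name `example.bin` is hypothetical.

```python
import os
import tempfile

import numpy as np

# Round trip of a points file: raw int64 values, reshaped to (N, 2) as (x, y).
points = np.array([[10, 20], [30, 40]], dtype=np.int64)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.bin")  # hypothetical file name
    points.tofile(path)                       # assumed writer side
    loaded = np.fromfile(path, dtype=np.int64).reshape(-1, 2)

# RGBA overlay: broadcast an (h, w, 1) mask against a (1, 1, 4) color.
mask = np.zeros((4, 5), dtype=np.float64)
mask[1:3, 2:4] = 1.0
color = np.array([1.0, 0.0, 0.80, 0.6])      # same color as in the cell below
h, w = mask.shape[-2:]
overlay = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

print(loaded.tolist(), overlay.shape)
```

Pixels outside the mask get an alpha of 0, so `imshow(overlay)` drawn after `imshow(img)` leaves the rest of the image visible.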

In [None]:


dataset_path = "/home/fusaro/segment-anything"
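def perform(predictor, model_type, source_mask, dataset, points_sampling_mode, sampling_step, border_mode):
    # NOTE: this visualization helper redefines `perform` with a different signature
    # than the one called by `perform_all`; re-run the cell that defines the
    # original `perform` before running new experiments.
    min_blob_count = 20 if dataset=="portrait" else 10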
    test_folder, gt_folder, dplabv3_folder, pvtv2_folder,\
              img_ext = loadinfo(dataset, dataset_path)
    
    if "COCO" in dataset:
        test_imgs, gt_masks, src_bmasks, src_rmasks = loadpaths_COCO(test_folder, gt_folder, dplabv3_folder)
    else:
        if source_mask=="pvtv2":
            test_imgs, gt_masks, src_bmasks, src_rmasks = loadpaths(test_folder, gt_folder, pvtv2_folder)
        else:
            test_imgs, gt_masks, src_bmasks, src_rmasks = loadpaths(test_folder, gt_folder, dplabv3_folder)

        assert len(test_imgs) == len(gt_masks) == len(src_bmasks) == len(src_rmasks),\
                  f"unbalanced datasets! {len(test_imgs)} {len(gt_masks)} {len(src_rmasks)}"
    
    toiterate = zip(test_imgs, gt_masks, src_bmasks, src_rmasks)

    if VERBOSE:
        print("len of files:", len(test_imgs))
    
    # metrics = Metrics()
    # source_mask_metrics = Metrics()
    
    results_dir = "/home/fusaro/segment-anything/SKIN_out/" + dataset + "/" + source_mask + "/" + model_type
    # os.makedirs(results_dir, exist_ok=True)
    
    print("processing for results_dir", results_dir)
    
    # sampled_points_folder = os.path.join("/home/fusaro/segment-anything/SKIN_out/sampled_points_final", dataset, source_mask, model_type, \
    #                             points_sampling_mode + (("_" + str(sampling_step)) if points_sampling_mode=="D" else "") \
    #                             + ("_bm" if border_mode=="on" and points_sampling_mode=="D" else ""))

    sampled_points_folder = os.path.join(dataset_path, "sampled_points_final", dataset, source_mask, \
                                points_sampling_mode + (("_" + str(sampling_step)) if points_sampling_mode=="D" else "") \
                                + ("_bm" if border_mode=="on" and points_sampling_mode=="D" else ""))
    
    # os.makedirs(sampled_points_folder, exist_ok=True)
    # verbose=True
    for idx, paths in tqdm(enumerate(toiterate), total=len(test_imgs)):
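        # Debug window: skip the first 46 images. Combined with the `idx > 24`
        # break at the end of the loop, at most the image at index 46 is plotted.
        if idx < 46:
            continue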
        
        if VERBOSE:
            print(f" - img idx {str(idx+1).zfill(6)}/{len(test_imgs)}:")
        
        img_path, gt_mask_path, src_bmask_path, src_rmask_path = paths
        print("img_path          :", img_path)
        print("gt_mask_path      :", gt_mask_path)
        print("src_bmask_path    :", src_bmask_path)
        print("src_rmask_path    :", src_rmask_path)

        # Load images from disk using paths
        img       = read_img(img_path)
        gt_mask   = read_bmask(gt_mask_path)
        src_bmask = 1.0 - read_bmask(src_rmask_path)
        src_rmask = read_rmask(src_rmask_path).astype(np.float64)
        
        input_point = np.fromfile(os.path.join(sampled_points_folder,os.path.basename(gt_mask_path)[:-4]+".bin"), dtype=np.int64).reshape(-1, 2)
        print(input_point)
        # if input_point.shape[0]<2:
        #     continue
#         print(os.path.join(sampled_points_folder,os.path.basename(gt_mask_path)[:-4]+".bin_"))
#         print(input_point.shape)
        
#         # input_label = [1 for _ in input_point]
        
#         # TO VISUALIZE computed mask, show the superposition with the original image
#         # and SAM's predicted score and IoU wrt Ground Truth mask
#         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        
        color = np.array([1.0, 0, 0.80, 0.6])
        
        
#         plt.clf()
#         fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 20))
#         axes[0].imshow(img)
#         axes[0].set_title('img', fontsize = 16)
#         if input_point.shape[0]>0:
#             print("true")
#             axes[0].scatter([input_point[:, 0]], [input_point[:, 1]], color='green', marker='*', s=100, edgecolor='white', linewidth=1.25) # this is if you want the star
        
#         axes[1].imshow(img)
#         h, w = gt_mask.shape[-2:]
#         mask_image = gt_mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#         axes[1].set_title('GT', fontsize = 16)
#         axes[1].imshow(mask_image)
        
#         axes[2].imshow(img)
#         h, w = pred_mask.shape[-2:]
#         mask_image = pred_mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#         axes[2].set_title('sota-dlv3', fontsize = 16)
#         axes[2].imshow(mask_image)
        
#         axes[3].imshow(img)
#         h, w = src_bmask.shape[-2:]
#         mask_image = src_bmask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#         axes[3].set_title('dlv3', fontsize = 16)
#         axes[3].scatter([input_point[:, 0]], [input_point[:, 1]], color='green', marker='*', s=100, edgecolor='white', linewidth=1.25) # this is if you want the star
#         axes[3].imshow(mask_image)
        
#         axes[4].imshow(img)
#         h, w = fusion_mask.shape[-2:]
#         mask_image = fusion_mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#         axes[4].set_title('sota-fusion', fontsize = 16)
#         axes[4].imshow(mask_image)

#         fig.tight_layout()
        
#         plt.show()
        
#         plt.clf()
#         plt.imshow(img)
#         plt.scatter([input_point[:, 0]], [input_point[:, 1]], color='blue', marker='*', s=80, edgecolor='white', linewidth=1.25) # this is if you want the star
#         plt.axis('off')
#         plt.savefig(f'imgs/new_{idx}_img.png', dpi=300, bbox_inches='tight')

#         plt.clf()
#         # color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
#         h, w = gt_mask.shape[-2:]
#         mask_image = gt_mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#         plt.imshow(img)
#         plt.imshow(mask_image)
#         plt.axis('off')
#         plt.savefig(f'imgs/new_{idx}_gt.png', dpi=300, bbox_inches='tight')
        
#         plt.clf()
#         plt.imshow(img)
#         h, w = pred_mask.shape[-2:]
#         mask_image = pred_mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#         plt.imshow(mask_image)
#         plt.axis('off')
#         plt.savefig(f'imgs/new_{idx}_samdlv3.png', dpi=300, bbox_inches='tight')
        
#         plt.clf()
#         plt.imshow(img)
#         h, w = fusion_mask.shape[-2:]
#         mask_image = fusion_mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
#         plt.imshow(mask_image)
#         plt.axis('off')
#         plt.savefig(f'imgs/new_{idx}_samfusion.png', dpi=300, bbox_inches='tight')
        plt.clf()
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(img)
        h, w = src_bmask.shape[-2:]
        mask_image = src_bmask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.scatter([input_point[:, 0]], [input_point[:, 1]], color='green', marker='*', s=100, edgecolor='white', linewidth=1.25) # this is if you want the star
        ax.imshow(mask_image)
        fig.tight_layout()
        plt.show()
    
        plt.clf()
        plt.imshow(img)
        h, w = src_bmask.shape[-2:]
        mask_image = src_bmask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        plt.scatter([input_point[:, 0]], [input_point[:, 1]], color='blue', marker='*', s=30, edgecolor='white', linewidth=0.5) # this is if you want the star
        plt.imshow(mask_image)
        plt.axis('off')
        os.makedirs('imgs', exist_ok=True)  # savefig does not create missing directories
        plt.savefig(f'imgs/{idx}_dlv3_d50.png', dpi=300, bbox_inches='tight')
        
        if idx>24:
            break

# perform(None, "default", "deeplab", "COCO_animal", "D", 50, "off")
perform(None, "default", "deeplab", "portrait", "D", 50, "off")
# perform(None, "default", "deeplab", "sota_FoldDA1_3", "D", 50, "off")