In [None]:
import math
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.font_manager import FontProperties
import sys
import torch
import cv2
import supervision as sv


sys.path.insert(0, "../CLIP")
sys.path.append("..")
from industrial_clip.evaluation import Evaluation

from segment_anything import build_sam, sam_model_registry, SamAutomaticMaskGenerator

from evaluation.utils import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# The SAM model loaded further down is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Visualize results

In [None]:
# plot SAM results
def plot_sam_result(masks, image):
    """Show the input image next to the same image with all SAM masks overlaid.

    Args:
        masks: Output of the SAM automatic mask generator, in the form
            accepted by ``sv.Detections.from_sam``.
        image: Image array. It is converted with ``cv2.COLOR_RGB2BGR`` before
            it is annotated and plotted; the caller's array is not modified.
    """
    # cvtColor returns a new array, so the caller's image stays untouched.
    source_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    detections = sv.Detections.from_sam(sam_result=masks)
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

    # Annotate a copy: the annotator may draw directly onto `scene`, which
    # would otherwise turn the "Source Image" panel into a second annotated one.
    annotated_image = mask_annotator.annotate(scene=source_bgr.copy(), detections=detections)

    sv.plot_images_grid(
        images=[source_bgr, annotated_image],
        grid_size=(1, 2),
        titles=['Source Image', 'Segmented Image']
    )

In [None]:
# visualize the individual segments
def visualize_segments(segments, result, prompt_idx=0):
    """Show every extracted segment in a grid with 10 columns.

    Args:
        segments: Sequence of images accepted by ``Axes.imshow``.
        result: Per-segment scores. Currently unused, because the score title
            below is commented out; kept so existing calls keep working.
        prompt_idx: Column of ``result`` that the disabled title would show.
            Currently unused.

    An empty ``segments`` is a no-op: there is nothing to draw and a grid
    with zero rows cannot be created.
    """
    count = len(segments)
    if count == 0:
        return

    cols = 10
    rows = math.ceil(count / cols)

    # squeeze=False always returns a 2-D array of axes, so flatten() is valid
    # for every grid shape.
    fig, axs = plt.subplots(rows, cols, figsize=(cols*5, rows*5), squeeze=False)
    axs = axs.flatten()

    for ax, img in zip(axs, segments):
        ax.imshow(img)
        # score title, currently disabled:
        #ax.set_title(round(result[i][prompt_idx].item(), 3), fontsize=30)

    # hide the axes of used and unused grid cells alike
    for ax in axs:
        ax.axis('off')

    plt.subplots_adjust(wspace=0.2, hspace=0.2)
    plt.show()

In [None]:
# plot language-guided segmentation results
def plot_results(axs, res, pred_prompt_idx, threshold, resize=False, image_shape=None, sam_masks=None):
    """Overlay in red every SAM mask that is selected for one prompt.

    Args:
        axs: A single matplotlib Axes to draw on (despite the plural name).
        res: Prediction scores, passed to ``get_predictions_above_threshold``.
        pred_prompt_idx: Index of the prompt whose predictions are drawn.
        threshold: Threshold passed to ``get_predictions_above_threshold``.
        resize: If True, each mask is resized to ``image_shape`` before drawing.
        image_shape: ``(height, width, ...)`` of the target image; required
            when ``resize`` is True.
        sam_masks: SAM mask dicts, indexed by the values returned by
            ``get_predictions_above_threshold``. When omitted, the
            notebook-level ``masks`` variable is used, which keeps existing
            calls working.

    Raises:
        ValueError: If ``resize`` is True but ``image_shape`` is not given.
    """
    if sam_masks is None:
        # Backward-compatible fallback to the notebook-level variable.
        sam_masks = masks

    if resize and image_shape is None:
        raise ValueError("image_shape is required when resize=True")

    cmap = mcolors.ListedColormap(['none', 'red'])

    # cv2.resize expects (width, height); computed once, outside the loop.
    target_size = (image_shape[1], image_shape[0]) if resize else None

    indices = get_predictions_above_threshold(res, pred_prompt_idx, threshold=threshold)
    for idx in indices:
        mask = sam_masks[idx]["segmentation"]

        if resize:
            mask = cv2.resize(mask.astype(float), target_size)

        # overlay the mask
        axs.imshow(mask, cmap=cmap, alpha=0.8)

In [None]:
# visualize language-guided segmentation results in one plot
def visualize(res_zeroshotclip, res_iclip, pred_prompt_idx, threshold, image, resize=False, vertical=False, name=None):
    """Draw the input image and the two prediction overlays as three panels.

    Args:
        res_zeroshotclip: Scores drawn on the "Zero-shot CLIP" panel.
        res_iclip: Scores drawn on the "Results" panel.
        pred_prompt_idx: Prompt index forwarded to ``plot_results``.
        threshold: Threshold forwarded to ``plot_results``.
        image: Image shown on every panel; its ``shape`` is forwarded to
            ``plot_results`` as ``image_shape``.
        resize: Forwarded to ``plot_results``.
        vertical: Stack the panels in a 3x1 column instead of a 1x3 row.
        name: If not None, the figure is saved to this path.
    """
    if vertical:
        grid_shape, fig_size = (3, 1), (7.8, 17.5)
        layout = dict(pad=0, h_pad=1.0)
        margins = dict(bottom=0, top=0.97, left=0.0, right=1.0)
    else:
        grid_shape, fig_size = (1, 3), (31.5, 6.3)
        layout = dict(pad=0, h_pad=0.0, w_pad=0.05)
        margins = dict(bottom=-0.10, top=0.90, left=-0.01, right=1.01)

    fig, axs = plt.subplots(*grid_shape, figsize=fig_size)

    title_font = FontProperties()
    title_font.set_name('Arial')

    def _show_panel(ax, title):
        # every panel shows the bare image under a bold title, without axes
        ax.axis('off')
        ax.imshow(image)
        ax.set_title(title, fontproperties=title_font, fontweight='bold', fontsize=30, pad=9)

    _show_panel(axs[0], "Input")

    overlays = (
        (1, "Zero-shot CLIP", res_zeroshotclip),
        (2, "Results", res_iclip),
    )
    for position, title, scores in overlays:
        panel = axs[position]
        _show_panel(panel, title)
        plot_results(panel, scores, pred_prompt_idx, threshold, resize=resize, image_shape=image.shape)

    fig.tight_layout(**layout)
    fig.subplots_adjust(**margins)

    if name is not None:
        fig.savefig(name)

## Eval SAM + CLIP

In [None]:
# SAM variant (key into sam_model_registry) and the path of its checkpoint file
sam_model_type = "vit_h"
sam_checkpoint = "sam_vit_h_4b8939.pth"

# Build SAM from the checkpoint and move it to the selected device.
# The result of .to() is assigned so that the cell does not end in a bare
# expression and therefore does not echo the whole module repr as its output.
sam_model = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
sam_model = sam_model.to(device)

In [None]:
def eval(clip_model, clip_model_dir, clip_epoch, image_path, prompts, segments_with_background=False):
    """Segment an image with SAM and score every segment against the prompts.

    The function reads these notebook-level names when it is called, so they
    must be defined beforehand: ``sam_model``, ``sam_image_size``,
    ``sam_points_per_side``, ``sam_predicted_iou_threshold``,
    ``sam_stability_score_thresh``, ``sam_stability_score_offset``,
    ``sam_box_nms_thresh``, ``sam_min_mask_region_area``,
    ``sam_min_mask_size`` and ``dilation``.

    NOTE(review): the name shadows Python's builtin ``eval`` inside this
    notebook; it is kept because existing cells call it by this name.

    Args:
        clip_model: Trainer name passed to ``Evaluation`` for the second model.
        clip_model_dir: Passed to ``Evaluation`` as ``model_dir``.
        clip_epoch: Passed to ``Evaluation`` as ``epoch``.
        image_path: Path passed to ``load_and_resize_image``.
        prompts: List of text prompts; all of them are scored.
        segments_with_background: If True, segments are cropped from the
            resized image itself; otherwise from the output of
            ``segment_image`` for the corresponding mask.

    Returns:
        Tuple ``(full_res_image, resized_image, res_zeroshotclip, res_iclip,
        masks, segments)``, where ``masks`` and ``segments`` are index-aligned
        and contain only the entries that passed all filters.
    """
    full_res_image, resized_image = load_and_resize_image(image_path, sam_image_size)

    mask_generator = SamAutomaticMaskGenerator(
        model=sam_model,
        points_per_side=sam_points_per_side,
        pred_iou_thresh=sam_predicted_iou_threshold,
        stability_score_thresh=sam_stability_score_thresh,
        stability_score_offset=sam_stability_score_offset,
        box_nms_thresh=sam_box_nms_thresh,
        min_mask_region_area=sam_min_mask_region_area
    )

    # generate masks and drop the ones that are too small
    masks = [mask for mask in mask_generator.generate(resized_image) if mask["area"] > sam_min_mask_size]

    # extract one crop per mask; masks and segments stay index-aligned
    segments = []
    kept_masks = []
    for mask in masks:
        bbox = convert_box_xyxy_dilate_square(mask["bbox"], image_size=resized_image.shape[0:2], make_squared=False, dilation=dilation)
        x_min, y_min, x_max, y_max = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

        if segments_with_background:
            crop_source = resized_image
        else:
            # Only needed in this branch, so it is not computed when the crop
            # is taken from the resized image.
            # NOTE(review): assumes segment_image has no side effects - confirm.
            crop_source = segment_image(resized_image, mask["segmentation"])
        segment = crop_source[y_min:y_max, x_min:x_max]

        # skip empty crops
        if segment.shape[0] == 0 or segment.shape[1] == 0:
            continue

        # skip crops that span (almost) the whole image in either dimension
        if segment.shape[0] + 10 >= resized_image.shape[0] or segment.shape[1] + 10 >= resized_image.shape[1]:
            continue

        kept_masks.append(mask)
        segments.append(segment)

    masks = kept_masks

    cfg_list = [
        "TRAINER.ZSCLIP.PROMPT_TEMPLATE", "a photo of an industrial product {}",
    ]
    zeroshotclip = Evaluation("ZeroshotCLIP", cfg_list, prompts)
    zeroshotclip.build()

    cfg_list = [
        "TRAINER.COOP.CTX_INIT", "X X X X a photo of an industrial product",
        "TRAINER.COOP.CLASS_TOKEN_POSITION", "end"
    ]
    iclip = Evaluation(clip_model, cfg_list, prompts, model_dir=clip_model_dir, epoch=clip_epoch)
    iclip.build()

    # both models are fed the same tensors, sized for the zero-shot model
    image_tensors = numpy_to_tensor(segments, image_size=zeroshotclip.visual_resolution)
    prompt_idx = list(range(len(prompts)))

    res_zeroshotclip = zeroshotclip.forward(image_tensors, prompt_idx)[0]
    res_iclip = iclip.forward(image_tensors, prompt_idx)[0]

    return (full_res_image, resized_image, res_zeroshotclip, res_iclip, masks, segments)

### Evaluation

In [None]:
# set SAM parameters
# These are notebook-level names that eval() reads when it is called, so this
# cell has to be run before eval() is invoked.

# size passed to load_and_resize_image() to produce the image SAM runs on
sam_image_size = 1024
# forwarded to SamAutomaticMaskGenerator as pred_iou_thresh
sam_predicted_iou_threshold = 0.90
# forwarded to SamAutomaticMaskGenerator as stability_score_thresh
sam_stability_score_thresh = 0.90
# forwarded to SamAutomaticMaskGenerator as stability_score_offset
sam_stability_score_offset = 1.0
# forwarded to SamAutomaticMaskGenerator as box_nms_thresh
sam_box_nms_thresh = 0.6
# forwarded to SamAutomaticMaskGenerator as min_mask_region_area
sam_min_mask_region_area = 700
# eval() keeps only the generated masks whose "area" is greater than this value
sam_min_mask_size = 700
# forwarded to SamAutomaticMaskGenerator as points_per_side
sam_points_per_side = 160
# forwarded to convert_box_xyxy_dilate_square() as dilation
dilation = 5

In [None]:
%%capture

# set the path to the image, prompt, model, path to model, and epoch (=40)
# Replace every <<<...>>> placeholder below before running this cell.

image_path = "<<<path to image>>>"
# prompts[0] is the prompt that is visualized further down (pred_prompt_idx = 0).
# NOTE(review): the empty second prompt presumably acts as a contrast class
# for the scoring - confirm.
prompts = [
    "<<<add prompt>>>",
    ""
]
# Positional arguments: model name, path to the model, epoch, image path, prompts.
# segments_with_background=True crops the segments from the resized image itself.
full_res_image, resized_image, res_zeroshotclip, res_iclip, masks, segments = eval(
    "CoOpIATA",
    "<<<path to model>>>",
    40,
    image_path, 
    prompts, 
    segments_with_background=True
)

In [None]:
# show the resized input next to the same image with the kept SAM masks overlaid
plot_sam_result(masks, resized_image)

In [None]:
# show every extracted segment in a grid (res_iclip is accepted but currently not displayed)
visualize_segments(segments, res_iclip)

In [None]:
# index into `prompts` of the prompt whose predictions are highlighted
pred_prompt_idx = 0
# passed through visualize() and plot_results() to get_predictions_above_threshold()
threshold = 0.90
print("Prompt: %s" % prompts[pred_prompt_idx])
# three stacked panels: input, zero-shot CLIP overlay, and the overlay for res_iclip
visualize(res_zeroshotclip, res_iclip, pred_prompt_idx, threshold, resized_image, vertical=True)