# Import Necessary Libraries

In [None]:
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
from PIL import Image
import supervision as sv
import numpy as np

# To save long lists to pickle files
import pickle

# Converting the data to a COCO JSON file
import json
import copy
import os

# Set Paths to Image Names and Model Checkpoints

In [None]:
# Load in the CLAHE corrected images
image_path = "your_path_here"

# Get the names of all the images in the folder, in sorted order.
# `sorted` already returns a flat list of file names, which is what the loops
# below iterate over. Wrapping it in another list would turn each "name" into
# a whole list and break the path building further down.
img_names = sorted(os.listdir(image_path))

In [None]:
# Build the SAM model from its registry entry and the checkpoint that you downloaded.
# ViT-H is the active choice (it had the highest iou scores); the other
# variants are kept below so they can be swapped in.
sam = sam_model_registry["vit_h"](
    checkpoint="your_path_here/sam_vit_h_4b8939.pth"
)
# sam = sam_model_registry["vit_l"](checkpoint="your_path_here/sam_vit_l_0b3195.pth")
# sam = sam_model_registry["vit_b"](checkpoint="your_path_here/sam_vit_b_01ec64.pth")

# Wrap the same model in the automatic mask generator and in the predictor
mask_generator = SamAutomaticMaskGenerator(sam)
predictor = SamPredictor(sam)

# Get SAM Results and Convert to COCO-Style JSON Annotations

In [None]:
#### Code here can be found at https://stackoverflow.com/questions/49494337/encode-numpy-array-using-uncompressed-rle-for-coco-dataset

def encode(binary_mask):
    """Encode a 2-D mask as COCO-style uncompressed RLE.

    The mask is read in column-major (Fortran) order and every non-zero
    value counts as foreground, so 0/1, 0/255 and boolean masks all encode
    the same way.

    Args:
        binary_mask: array-like mask; zero is background, anything else is
            foreground.

    Returns:
        A dict with "counts" (list of run lengths, alternating background and
        foreground and always starting with a background run, which is 0 when
        the first pixel is foreground) and "size" (the mask shape as a list).
        An empty mask yields an empty "counts" list.
    """
    mask = np.asarray(binary_mask)
    rle = {"counts": [], "size": list(mask.shape)}

    # Normalise to booleans so the run detection does not depend on the
    # particular non-zero value used for foreground.
    flat = mask.ravel(order="F") != 0
    if flat.size == 0:
        # Nothing to encode; also avoids indexing flat[0] on an empty array.
        return rle

    # Positions where the value differs from its predecessor start a new run.
    run_starts = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate(([0], run_starts, [flat.size]))
    run_lengths = np.diff(boundaries)

    # The first count is always a run of zeros, so a mask that starts with
    # foreground needs an explicit zero-length background run in front.
    if flat[0]:
        run_lengths = np.concatenate(([0], run_lengths))

    rle["counts"] = run_lengths.tolist()

    return rle

In [None]:
# Get the SAM results for all of the cryo images and write them out as a
# COCO-style JSON annotation file.

cryo_sam_results_combined = []  # converted annotations of every image, flattened
images_dict = []                # one COCO "images" entry per image
sam_results_by_image = []       # unmodified SAM output, one list per image

# Annotation ids have to be unique across the whole file, so use one running
# counter instead of restarting at 1 for every image.
next_annotation_id = 1

# get each of the images from the image_path one at a time and perform the sam mask generation
for i, img_name in enumerate(img_names):
    image_id = i + 1

    # Open the file once, take everything needed from it and close it again.
    # os.path.join works whether or not image_path ends with a separator.
    with Image.open(os.path.join(image_path, img_name)) as pil_img:
        width, height = pil_img.size
        # show the image
        plt.imshow(pil_img)
        img_rgb = np.array(pil_img).astype(np.uint8)

    # The mask generator takes an HxWx3 image: replicate a single-channel
    # image into three identical channels, and leave an image that already
    # has a channel axis untouched.
    if img_rgb.ndim == 2:
        img_rgb = np.repeat(img_rgb[:, :, np.newaxis], 3, axis=2)
    sam_result = mask_generator.generate(img_rgb)

    # Keep an untouched copy, because the annotations are rewritten in place below.
    sam_results_by_image.append(copy.deepcopy(sam_result))

    for annotation in sam_result:
        # annotation_id
        annotation["id"] = next_annotation_id
        next_annotation_id += 1

        # segmentation: replace the mask array with its RLE encoding
        # (bbox and area are already included in the SAM output)
        binary_mask = annotation["segmentation"].astype(np.uint8)
        annotation["segmentation"] = encode(binary_mask)

        # iscrowd (the segmentation is stored as RLE rather than polygons)
        annotation["iscrowd"] = 1

        # attributes
        annotation["attributes"] = {"occluded": "false"}

        # image_id
        annotation["image_id"] = image_id

        # category_id
        annotation["category_id"] = 1

        # remove the SAM-specific fields that are not part of the COCO annotation
        for sam_only_key in ("point_coords", "stability_score", "crop_box", "predicted_iou"):
            annotation.pop(sam_only_key)

    cryo_sam_results_combined.extend(sam_result)

    image_dict = {
        "id": image_id,
        "width": width,
        "height": height,
        "file_name": img_name,
        "license": 0,
        "flickr_url": "",
        "coco_url": "",
        "date_captured": 0,
    }
    images_dict.append(image_dict)

# licenses, info, categories
info_dict = {"licenses":[{"name":"","id":0,"url":""}],
             "info":{"contributor":"","date_created":"","description":"","url":"","version":"","year":""},
             "categories":[{"id":1,"name":"Full","supercategory":""},{"id":2,"name":"Partial","supercategory":""},{"id":3,"name":"Empty","supercategory":""}]}

# Combine the dicts:
coco_json = dict(info_dict, images = images_dict, annotations = cryo_sam_results_combined)

# Save the coco_json to a json file
with open('your_path_here/cryo_sam_results_singlepic.json', 'w') as f:
    json.dump(coco_json, f)

# Save the unmodified per-image SAM results as well; the loading code further
# down reads them back from this pickle file.
with open('your_path_here/sam_results_by_image_singlepic.pickle', 'wb') as f:
    pickle.dump(sam_results_by_image, f)

In [None]:
# Read the COCO annotations back from the json file
with open('your_path_here/cryo_sam_results_singlepic.json', 'r') as json_file:
    coco_json = json.load(json_file)

# Read the per-image SAM results back from the pickle file
# (only unpickle files that you created yourself)
with open('your_path_here/sam_results_by_image_singlepic.pickle', 'rb') as pickle_file:
    sam_results_by_image = pickle.load(pickle_file)

# Display the SAM Results for a Single Image

In [None]:
#vit b example

# for each of the images in img_names display the original image and the annotated image
# (only the first image is shown here)
mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.CLASS)

for i, img_name in enumerate(img_names[:1]):
    # Read the pixels and close the file again; os.path.join works whether or
    # not image_path ends with a separator.
    with Image.open(os.path.join(image_path, img_name)) as pil_img:
        img_rgb = np.array(pil_img).astype(np.uint8)

    # Replicate a single-channel image into three identical channels; an image
    # that already has a channel axis is used as it is.
    if img_rgb.ndim == 2:
        img_rgb = np.repeat(img_rgb[:, :, np.newaxis], 3, axis=2)

    detections = sv.Detections.from_sam(sam_result=sam_results_by_image[i])

    # Give every mask of this image the class id 1, so that the annotator's
    # class-based colour lookup has a class id to work with.
    class_ids = np.ones(len(sam_results_by_image[i]), dtype=int)
    detections.class_id = class_ids

    # Annotate a copy so the original stays available for the side-by-side plot.
    annotated_img = mask_annotator.annotate(img_rgb.copy(), detections=detections)
    sv.plot_images_grid(
        images=[img_rgb, annotated_img],
        grid_size=(1, 2),
        titles=["Original", "Annotated"],
    )