In [2]:
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import cv2
import numpy as np
import csv
from PIL import Image
import glob
import os
from skimage import measure
from scipy import ndimage
from tqdm import tqdm

In [3]:
# --- Run configuration -------------------------------------------------------
# Dataset root; the notebook reads <path>/train and writes <path>/segmentation_images.
path = 'dataset_2_2023_05_04'
# Device the SAM model is moved to.
device = "cuda"

# --- Mask selection ----------------------------------------------------------
# A SAM mask is only considered when min_pixel < area < max_pixel (area in pixels).
min_pixel, max_pixel = 12000, 24000
# Target bounding-box width/height ratio; the candidate closest to it is kept.
default_ratio = 3.5
# NOTE(review): not referenced by any cell visible here - confirm it is still needed.
min_overlap_percentage = 85

In [4]:
# Build the SAM network registered under "default" from the local checkpoint file.
sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")
# Move the model to the configured device (`device` from the configuration cell).
sam.to(device=device)
# Automatic mask generator with its default settings; used later to segment each picture.
mask_generator = SamAutomaticMaskGenerator(sam)

In [5]:
# Directory that receives one mask image per processed training picture.
segmentation_dir = os.path.join(path, 'segmentation_images')

# Create the output directory on first run; `path` itself must already exist.
if not os.path.isdir(segmentation_dir):
    os.mkdir(segmentation_dir)

# Start from a clean slate: delete the regular files left by a previous run.
# Sub-directories (e.g. a `.ipynb_checkpoints` folder) are skipped, because
# os.remove() raises on directories and would abort the cell.
filelist = [
    f for f in os.listdir(segmentation_dir)
    if os.path.isfile(os.path.join(segmentation_dir, f))
]
for f in filelist:
    os.remove(os.path.join(segmentation_dir, f))

In [6]:
# Every entry directly under <path>/train, in sorted order so runs are reproducible.
img_path_list = glob.glob(path+'/train/*')
img_path_list.sort()

In [7]:
def find_biggest_contiguous_area(array):
    """Return a boolean mask of the largest connected foreground region.

    Non-zero entries of ``array`` count as foreground. Regions are found with
    ``scipy.ndimage.label`` using its default structuring element.

    Args:
        array: Array-like mask; any non-zero value is foreground.

    Returns:
        Boolean ndarray with the same shape as ``array`` that is True exactly
        on the largest region. When several regions share the largest size the
        one with the lowest label wins. When ``array`` has no foreground at
        all, an all-False mask is returned.
    """
    labels, num_regions = ndimage.label(array)
    if num_regions == 0:
        # No foreground: counts[1:] would be empty and np.argmax would raise.
        return np.zeros(labels.shape, dtype=bool)

    # counts[k] is the pixel count of label k; index 0 is the background.
    counts = np.bincount(labels.ravel())
    max_label = np.argmax(counts[1:]) + 1  # +1 undoes the offset of the [1:] slice

    return labels == max_label

In [8]:
# For every training picture: run SAM, keep the masks whose pixel area lies
# strictly between min_pixel and max_pixel, and save the candidate whose
# bounding-box ratio (width / height) is closest to default_ratio as a
# black/white image in <path>/segmentation_images.
# Pictures that cannot be read or that yield no candidate are reported and
# skipped, so a single bad file does not abort the whole run.
for img_path in tqdm(img_path_list):
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        tqdm.write(f"skipping {img_path}: could not be read as an image")
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    masks = mask_generator.generate(img)
    # Largest first: when two candidates are equally close to default_ratio,
    # min() below returns the first one, i.e. the larger mask.
    sorted_masks = sorted(masks, key=(lambda x: x['area']), reverse=True)

    candidates = []  # (width / height ratio, uint8 0/1 mask) pairs
    for mask in sorted_masks:
        if not (min_pixel < mask['area'] < max_pixel):
            continue
        # Drop stray fragments: only the largest connected part of the mask is kept.
        maybe_plugmask = find_biggest_contiguous_area(mask['segmentation'])
        maybe_plugmask = np.array(maybe_plugmask, dtype=np.uint8)
        # The mask now holds a single region, so row 1 of `stats` (row 0 is the
        # background) carries that region's bounding box.
        _, _, stats, _ = cv2.connectedComponentsWithStats(maybe_plugmask)
        w = stats[1, cv2.CC_STAT_WIDTH]
        h = stats[1, cv2.CC_STAT_HEIGHT]
        candidates.append((w / h, maybe_plugmask))

    if not candidates:
        # min() on an empty sequence would raise ValueError and stop the loop.
        tqdm.write(f"skipping {img_path}: no mask with area between {min_pixel} and {max_pixel}")
        continue

    _, plug_mask = min(candidates, key=lambda c: abs(c[0] - default_ratio))

    # Scale the 0/1 mask to 0/255 so it is saved as a visible black/white image.
    plug_image = Image.fromarray(plug_mask * 255, mode='L')
    img_name = os.path.basename(img_path)
    plug_image.save(os.path.join(path, 'segmentation_images', img_name))


  1%|          | 20/2400 [01:22<2:45:38,  4.18s/it]

In [None]:
# Visual sanity check: show one training picture with every pixel outside its
# saved mask blacked out.
number = 100

# Saved mask converted to Pillow's 1-bit mode; the resulting array is used as
# a boolean index below.
mask = np.array(
    Image.open(f"{path}/segmentation_images/picture_{number}.png").convert('1')
)

# Matching picture from the train folder; OpenCV loads BGR, so convert to RGB.
img = cv2.cvtColor(
    cv2.imread(f"{path}/train/picture_{number}.png"),
    cv2.COLOR_BGR2RGB,
)

# Black canvas with the masked pixels copied over from the picture.
result_array = np.zeros_like(img)
result_array[mask] = img[mask]
result_image = Image.fromarray(result_array)
display(result_image)
