# Import the required libraries

In [1]:
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
import cv2
import matplotlib.pyplot as plt
import numpy as np
#import yaml
from PIL import Image
import glob
import os

# Config

In [2]:
# Root folder of the dataset; the input images and the segmentation output
# folder are both looked up relative to it.
path = 'mini_testdatensatz'
# Minimum share (in percent) of a mask's pixels that must overlap the board
# mask for that mask to be counted as lying inside the board.
min_overlap_percentage = 99

# Create the segmentation output folder if needed and empty it

In [3]:
# Make sure the output folder exists and start from a clean state so that
# results from a previous run cannot be mixed with the new ones.
segmentation_dir = os.path.join(path, 'segmentation_images')

# os.mkdir (not os.makedirs) on purpose: a missing dataset root should still
# fail loudly instead of being created silently.
if not os.path.exists(segmentation_dir):
    os.mkdir(segmentation_dir)

for entry in os.listdir(segmentation_dir):
    entry_path = os.path.join(segmentation_dir, entry)
    # os.remove raises on real directories; leave them alone and only delete
    # files (and symlinks, which os.remove can unlink).
    if os.path.isdir(entry_path) and not os.path.islink(entry_path):
        continue
    os.remove(entry_path)

# Load the model

In [4]:
# Build the SAM model registered under "default" from the local checkpoint
# file; the checkpoint path is relative to the current working directory.
# NOTE(review): the checkpoint file name suggests the ViT-H weights - confirm
# that this matches the "default" registry entry.
sam = sam_model_registry["default"](checkpoint="sam_vit_h_4b8939.pth")
# Automatic mask generator with its default settings; used below to segment
# every image without any prompts.
mask_generator = SamAutomaticMaskGenerator(sam)

# Path to the images

In [5]:
# Collect every entry of the 'test' subfolder in sorted order.
# Normally this would be the 'train' subfolder: glob.glob(path+'/train/*')
img_path_list = glob.glob(f'{path}/test/*')
img_path_list.sort()

# Function for the loop

In [6]:
def count_true_values(mask):
    """Return the sum over all elements of *mask*.

    For a boolean mask this equals the number of ``True`` entries, which
    makes the function usable as a sort key for ordering masks by size.
    """
    total = np.sum(mask)
    return total

# Main loop: segment every image and save the plug mask

In [7]:
# For every input image: generate masks, pick the board mask, collect the
# masks that lie (almost) completely inside the board, and save the plug mask
# as a black/white image under the same file name.
for img_path in img_path_list:
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None instead of raising when the entry is not a
        # readable image (the glob pattern matches every entry in the folder).
        print(f'Skipping {img_path}: could not be read as an image')
        continue

    # Convert from OpenCV's BGR channel order to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(img)
    sorted_masks = sorted(masks, key=lambda m: m['area'], reverse=True)
    if len(sorted_masks) < 2:
        print(f'Skipping {img_path}: fewer than two masks were generated')
        continue

    # The second-largest mask is taken as the board.
    # NOTE(review): presumably the largest mask is the background - confirm.
    board_mask = sorted_masks[1]['segmentation']

    # Keep every smaller mask whose pixels overlap the board mask by at least
    # min_overlap_percentage percent of the smaller mask's own pixel count.
    inside_masks = []
    for candidate in sorted_masks[2:]:
        small_mask = candidate['segmentation']
        num_small_mask_true = np.sum(small_mask)
        required_overlap = int(num_small_mask_true * min_overlap_percentage / 100)
        indices = np.where(small_mask)
        num_overlapping_true = np.sum(board_mask[indices])

        if num_overlapping_true >= required_overlap:
            inside_masks.append(small_mask)

    inside_masks_sorted = sorted(inside_masks, key=count_true_values, reverse=True)
    if len(inside_masks_sorted) < 2:
        print(f'Skipping {img_path}: fewer than two masks lie inside the board mask')
        continue

    # The second-largest mask inside the board is taken as the plug.
    # NOTE(review): the choice of index 1 is dataset-specific - confirm.
    plug_mask = inside_masks_sorted[1]

    # Scale the mask values to 0/255 so the mask is stored as an 8-bit
    # grayscale image.
    plug_image = Image.fromarray(plug_mask.astype('uint8') * 255, mode='L')

    # os.path.basename handles the platform's path separators, unlike
    # splitting on '/' which leaves the folder part in place on Windows.
    img_name = os.path.basename(img_path)
    plug_image.save(os.path.join(path, 'segmentation_images', img_name))

