# Import

In [None]:
import cv2
import numpy as np
import csv
import glob
import os


from skimage import measure
from scipy import ndimage
from tqdm import tqdm
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Config
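Attention:
* The accepted mask area range (`min_pixel`, `max_pixel`, in pixels) depends on the camera height during recording and must be adjusted accordingly.
* `default_ratio` is the expected width/height ratio of the plug's bounding box; it must be adjusted if a different plug is used.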


In [None]:
# Dataset root directory, stored with a trailing slash.
path = 'dataset/mini_testdatensatz/'
# Device the SAM model is moved to.
device = "cuda"
# Expected bounding-box width/height ratio of the plug; masks are ranked by
# how close their ratio is to this value.
default_ratio = 3.5
# Accepted mask area in pixels (both bounds exclusive). Tune for camera height:
# for a camera height of 100, 12000 / 24000 worked very well.
min_pixel, max_pixel = 5000, 10000

# Init model

In [None]:
# Load SAM from the ViT-H checkpoint file, move it to the configured device and
# wrap it in the automatic (prompt-free) mask generator.
sam = sam_model_registry["default"](
    checkpoint="sam_model/sam_vit_h_4b8939.pth",
)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(model=sam)

# Folder for Segmentation

Attention: This cell creates the `segmentation_images` folder inside `path` if it does not exist, and deletes any files already in it (e.g. segmentation images from a previous run).

In [None]:
# Create <path>/segmentation_images (including any missing parent directories)
# and empty it, so masks from a previous run never mix with the current one.
seg_dir = os.path.join(path, 'segmentation_images')
os.makedirs(seg_dir, exist_ok=True)

for entry in os.listdir(seg_dir):
    entry_path = os.path.join(seg_dir, entry)
    # Only delete regular files: os.remove raises on a sub-directory, which
    # would otherwise abort the whole cell.
    if os.path.isfile(entry_path):
        os.remove(entry_path)

# Define Functions for the loop

In [None]:
def find_biggest_contiguous_area(array):
    """Return a boolean mask of the largest connected foreground region.

    Regions are found with ``scipy.ndimage.label`` using its default
    structuring element (edge-adjacent neighbours only, i.e. 4-connectivity
    for 2-D input). If several regions share the maximum size, the one with
    the lowest label (first reached in raster-scan order) is returned.

    Args:
        array: Array whose non-zero entries are foreground, e.g. a SAM
            ``segmentation`` mask.

    Returns:
        Boolean array with the same shape as ``array`` that is True only on
        the largest region; all False if ``array`` has no foreground pixels.
    """
    labels, num_regions = ndimage.label(array)
    if num_regions == 0:
        # Without this guard np.argmax would receive an empty slice and raise.
        return np.zeros(labels.shape, dtype=bool)
    counts = np.bincount(labels.ravel())
    # Skip index 0 (background) and shift back to the 1-based label id.
    max_label = np.argmax(counts[1:]) + 1
    return labels == max_label

# List with all image paths

In [None]:
# All files in the training folder, in name order so runs are reproducible.
img_path_list = glob.glob(path + '/train/*')
img_path_list.sort()

# SAM Segmentation Loop

In [None]:
def _select_plug_mask(masks, min_area, max_area, target_ratio):
    """Pick the SAM mask most likely to show the plug.

    Keeps masks whose ``area`` lies strictly between ``min_area`` and
    ``max_area``, reduces each to its largest connected region and returns the
    region whose bounding-box width/height ratio is closest to
    ``target_ratio``.

    Args:
        masks: Records as returned by ``SamAutomaticMaskGenerator.generate``;
            only the ``'area'`` and ``'segmentation'`` keys are read.
        min_area: Exclusive lower bound on the mask area in pixels.
        max_area: Exclusive upper bound on the mask area in pixels.
        target_ratio: Expected width/height ratio of the plug.

    Returns:
        The chosen region as a uint8 array (1 = plug, 0 = background), or
        None if no mask passes the area filter.
    """
    candidates = []  # (width/height ratio, uint8 region mask)
    # Largest masks first: on equal distance to the target ratio, min() below
    # keeps the first candidate, i.e. the larger mask.
    for mask in sorted(masks, key=lambda m: m['area'], reverse=True):
        if not (min_area < mask['area'] < max_area):
            continue
        region = find_biggest_contiguous_area(mask['segmentation']).astype(np.uint8)
        # The region is one connected component, so label 1 is its bounding box.
        _, _, stats, _ = cv2.connectedComponentsWithStats(region)
        width = stats[1, cv2.CC_STAT_WIDTH]
        height = stats[1, cv2.CC_STAT_HEIGHT]
        candidates.append((width / height, region))

    if not candidates:
        return None
    _, best_region = min(candidates, key=lambda c: abs(c[0] - target_ratio))
    return best_region


for img_path in tqdm(img_path_list):
    img = cv2.imread(img_path)
    if img is None:
        # glob('*') may match files OpenCV cannot decode; imread returns None
        # for those instead of raising.
        tqdm.write(f'Skipping {img_path}: not a readable image')
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(img)

    plug_mask = _select_plug_mask(masks, min_pixel, max_pixel, default_ratio)
    if plug_mask is None:
        # Without this guard min() over no candidates raises and aborts the run.
        tqdm.write(
            f'Skipping {img_path}: no mask with {min_pixel} < area < {max_pixel} pixels'
        )
        continue

    plug_image = Image.fromarray(plug_mask.astype('uint8') * 255, mode='L')
    # basename handles OS-specific separators that glob may return.
    img_name = os.path.basename(img_path)
    plug_image.save(os.path.join(path, 'segmentation_images', img_name))

# Test single images

In [None]:
# Index of the picture to inspect (files are named picture_<number>.png).
number = 100

mask_path = os.path.join(path, 'segmentation_images', f'picture_{number}.png')
image_path = os.path.join(path, 'train', f'picture_{number}.png')

# Mode '1' turns the saved 0/255 mask into a boolean array for indexing below.
mask = Image.open(mask_path).convert('1')
mask = np.array(mask)

img = cv2.imread(image_path)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None, which
    # would otherwise surface as an obscure cvtColor error.
    raise FileNotFoundError(f'Could not read image: {image_path}')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Keep the original pixels inside the plug mask and black out everything else.
result_array = np.zeros_like(img)
result_array[mask] = img[mask]
result_image = Image.fromarray(result_array)
display(result_image)
