In [1]:
import os
import pickle
import sys
from collections import defaultdict

import cv2
import pandas as pd
from tqdm import tqdm

from autodistill.detection import CaptionOntology
from autodistill_grounded_sam import GroundedSAM





In [None]:
## project root: all relative paths below (data_raw/..., scripts/) are resolved from here.
## Set the GROUNDED_SAM_INTRO_DIR env var to run on another machine instead of editing this line.
project_dir = os.environ.get("GROUNDED_SAM_INTRO_DIR", r"D:\git-repos\mluerig\grounded-sam-intro")
os.chdir(project_dir)

## make scripts/utils importable; absolute + guarded so re-running the cell
## neither depends on a later chdir nor stacks duplicate sys.path entries
scripts_dir = os.path.join(os.getcwd(), "scripts")
if scripts_dir not in sys.path:
    sys.path.append(scripts_dir)
from utils import model_helpers

In [None]:
## GroundedSAM model prompted with a single caption -> class mapping ("mouse" -> "mouse").
## Both thresholds are deliberately low (0.1) so weak detections are kept here;
## small masks are filtered later in the segmentation loop via `min_area`.
base_model = GroundedSAM(
    ontology=CaptionOntology(dict(mouse="mouse")),
    box_threshold=0.1,
    text_threshold=0.1,
)


In [None]:
## point to image root dir (os.path.join so the nested path works on Windows and POSIX alike)
root_dir = os.path.join("data_raw", "input_imgs", "test_mice_copy")
if not os.path.isdir(root_dir):
    ## os.walk silently yields nothing for a missing dir -- fail loudly instead
    raise FileNotFoundError(f"image root dir not found: {os.path.abspath(root_dir)}")

## loop through all subfolders; records are grouped by file name so that
## identically named files from different subfolders end up adjacent in the table.
## NOTE(review): the segmentation loop keys its results by image/mask name only, so
## same-named images in different subfolders would overwrite each other's results.
dict_imgs = defaultdict(list)
for root, _dirs, files in tqdm(list(os.walk(root_dir))):
    ## name of the folder directly containing the files
    parent_name = os.path.basename(os.path.normpath(root))
    for file_name in files:
        dict_imgs[file_name].append({
            "image_name": file_name,
            "parent_name": parent_name,
            "image_path": os.path.join(root, file_name),
        })

## make a df: one row per file, columns image_name / parent_name / image_path
data_imgs = pd.DataFrame([info for infos in dict_imgs.values() for info in infos])


In [None]:
## label used to name every output location of this run
group = "test_mice"

## where cropped mask PNGs go (one subfolder per parent folder), plus the
## pickled results dict and the final results table
path_mask_root_dir = "data_raw/segmentation_masks-{}/".format(group)
path_seg_dict = "data_raw/results/segmentations-{}.pkl".format(group)
path_seg_data = "data_raw/results/segmentations-{}.csv".format(group)

## accumulator filled by the segmentation loop (entry name -> info dict)
dict_results = dict()


In [None]:
min_area = 1000  ## min mask area in px; smaller masks are recorded but not cropped/saved
save_intervall = 100  ## checkpoint dict_results to disk every x images


def _segment_image(image, image_name, parent_name):
    """Run `base_model` on one loaded image and record the outcome in `dict_results`.

    Every predicted mask gets an entry keyed "<base_name>_<k>.png" (k starting at 1).
    Masks with area > min_area are additionally cropped with model_helpers.filter_mask,
    whose info dict becomes the entry, and the ROI is written to
    <path_mask_root_dir>/<parent_name>/<mask_name>. Images without any mask get a
    single "no detections" entry keyed by the base image name.
    """
    base_image_name = os.path.splitext(image_name)[0]
    result = base_model.predict(image)

    if len(result.mask) == 0:
        ## key by the image's own base name: there is no mask name for this image
        dict_results[base_image_name] = {"result": "no detections", "parent_name": parent_name, "image_name": image_name}
        return

    for mask_idx, (area, mask) in enumerate(zip(result.area, result.mask), start=1):
        mask_name = base_image_name + f"_{mask_idx}.png"
        if area > min_area:
            roi, info = model_helpers.filter_mask(image, mask, min_area)
            path_mask = os.path.join(path_mask_root_dir, parent_name, mask_name)
            os.makedirs(os.path.dirname(path_mask), exist_ok=True)
            if not cv2.imwrite(path_mask, roi):
                ## cv2.imwrite reports failure via its return value, not an exception
                tqdm.write(f"WARNING: failed to write mask {path_mask}")
        else:
            info = {}

        info["confidence"] = result.confidence[mask_idx - 1]
        info["area"] = area
        info["mask_idx"] = mask_idx
        info["image_name"] = image_name
        info["parent_name"] = parent_name
        dict_results[mask_name] = info


def _checkpoint_results():
    """Pickle the current `dict_results` to `path_seg_dict`, creating its folder if needed."""
    os.makedirs(os.path.dirname(path_seg_dict) or ".", exist_ok=True)
    with open(path_seg_dict, "wb") as fh:
        pickle.dump(dict_results, fh)


with tqdm(total=len(data_imgs), position=0, leave=False, desc="Segmenting images") as pbar:
    for n_done, (_, row) in enumerate(data_imgs.iterrows(), start=1):
        image_name = row["image_name"]
        parent_name = row["parent_name"]

        ## cv2.imread returns None for missing/unreadable files; a decoder error
        ## is treated the same way so one bad file does not stop the whole run
        try:
            image = cv2.imread(row["image_path"])
        except cv2.error:
            image = None

        if image is None:
            base_image_name = os.path.splitext(image_name)[0]
            dict_results[base_image_name] = {"result": "image not loaded", "parent_name": parent_name, "image_name": image_name}
        else:
            _segment_image(image, image_name, parent_name)

        pbar.update(1)  ## exactly once per image
        if n_done % save_intervall == 0:
            _checkpoint_results()
    

In [None]:
## final save of the raw results dict (same pickle file as the in-loop checkpoints)
os.makedirs(os.path.dirname(path_seg_dict) or ".", exist_ok=True)
with open(path_seg_dict, "wb") as fh:
    pickle.dump(dict_results, fh)

## to df: one row per dict_results entry; the entry key becomes the "mask_name" column
## NOTE(review): the original selected and sorted by a "species" column that nothing in
## this notebook creates (KeyError); the parent folder name is used in its place -- confirm.
int_cols = ["mask_idx", "area", "diameter"]
out_cols = ["parent_name", "image_name", "mask_idx", "mask_name", "confidence_seg", "area", "bbox", "center", "diameter"]

data_results = (
    pd.DataFrame.from_dict(dict_results, orient="index")
    .reset_index()
    .rename(columns={"index": "mask_name"})
)

if "bbox" in data_results.columns:
    ## keep only masks that went through model_helpers.filter_mask -- the other entries
    ## built in the segmentation loop (no detections, failed loads, masks <= min_area)
    ## carry no "bbox" value
    data_results = (
        data_results[data_results["bbox"].notna()]
        .rename(columns={"confidence": "confidence_seg"})
        .astype({col: "int" for col in int_cols})
        .loc[:, out_cols]
        .sort_values(by=["parent_name", "image_name", "mask_idx"])
    )
else:
    ## no mask exceeded min_area: write a header-only table instead of failing on "bbox"
    data_results = pd.DataFrame(columns=out_cols)

os.makedirs(os.path.dirname(path_seg_data) or ".", exist_ok=True)
data_results.to_csv(path_seg_data, index=False)