In [1]:
import tempfile
from pathlib import Path

import clip
import cv2
import matplotlib.pyplot as plt
import mlflow
import numpy as np
import torch
import yaml
from PIL import Image, ImageDraw
from segment_anything import build_sam, SamAutomaticMaskGenerator, build_sam_vit_b

In [2]:
# Parse the experiment definition; safe_load keeps the YAML restricted to plain data.
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

# Pull out the three top-level sections used by the experiment cells:
# image paths, text queries, and the SAM parameter sets to sweep.
images, queries, sam_configs = (
    config[section] for section in ("images", "queries", "sam_configs")
)


In [3]:
# Echo the image paths read from config.yaml as a quick sanity check.
images

['test_images_for_mlflow/1.png',
 'test_images_for_mlflow/2.png',
 'test_images_for_mlflow/3.png']

In [4]:
def log_image_to_mlflow(image: Image.Image, name: str) -> None:
    """Log a PIL image to MLflow as the artifact ``results/<name>.png``.

    The image is written into a throwaway temporary directory instead of a
    fixed ``/tmp/<name>.png`` path. This keeps the helper portable to systems
    without ``/tmp``, avoids leaving stale files behind, and prevents two
    kernels logging the same ``name`` from overwriting each other's file.

    Args:
        image: Image to store; saved to a file with a ``.png`` suffix.
        name: File stem of the artifact (the part before ``.png``).
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / f"{name}.png"
        image.save(path)
        # Log while the directory still exists; it is removed on exiting the block.
        mlflow.log_artifact(str(path), artifact_path="results")

In [5]:
from segment_anything import build_sam, SamAutomaticMaskGenerator
from extract_masks import get_semantic_match

# Mask generator with SamAutomaticMaskGenerator's default settings, backed by
# the model loaded from the ViT-H checkpoint file.
# NOTE(review): the experiment loop cell rebinds `mask_generator` for every
# configuration, so this default instance looks unused — confirm before removing.
mask_generator = SamAutomaticMaskGenerator(
    model=build_sam(checkpoint="sam_vit_h_4b8939.pth"),
)

In [None]:
# Run one MLflow run per SAM configuration: for every (image, query) pair,
# compute the semantic match and log the resulting image as an artifact.
mlflow.set_tracking_uri("http://localhost:5001")
mlflow.set_experiment("semantic_retrieval_sam")

# The checkpoint does not depend on the configuration, so load the model once
# instead of re-reading the checkpoint file on every loop iteration.
sam_model = build_sam(checkpoint="sam_vit_h_4b8939.pth")

# The loop variable is `sam_config` (not `config`) so it does not overwrite
# the `config` dict loaded from config.yaml in the earlier cell.
for sam_config in sam_configs:
    with mlflow.start_run(run_name=sam_config["name"]):
        # Record the full parameter set (including its name) on the run.
        mlflow.log_params(sam_config)

        mask_generator = SamAutomaticMaskGenerator(
            model=sam_model,
            points_per_side=sam_config["points_per_side"],
            pred_iou_thresh=sam_config["pred_iou_thresh"],
            stability_score_thresh=sam_config["stability_score_thresh"],
            box_nms_thresh=sam_config["box_nms_thresh"],
            crop_n_layers=sam_config["crop_n_layers"],
            min_mask_region_area=sam_config["min_mask_region_area"],
        )
        for image_path in images:
            for query in queries:
                result_image = get_semantic_match(image_path, query, mask_generator)

                # Artifact name: <config name>_<image file stem>_<query with spaces as underscores>
                name = f"{sam_config['name']}_{Path(image_path).stem}_{query.replace(' ', '_')}"
                log_image_to_mlflow(result_image, name)