In [None]:
import os
from datetime import datetime
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
import random

In [None]:
# --- Run configuration -------------------------------------------------------
# Input image and the root folders under which per-run outputs are created.
image_path = "images/Image_09L.jpg"
base_dataset_dir = "dataset/"
base_results_dir = "results/"
# The preprocessed image is cut into grid_size x grid_size tiles.
grid_size = 5
# SAM2 checkpoint and model config; both are handed to build_sam2 later on.
sam2_checkpoint = "../checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
# Keep preferring the second GPU, but only when it actually exists: on a
# single-GPU machine "cuda:1" is not a valid device, so fall back to "cuda:0",
# and to the CPU when CUDA is not available at all.
if not torch.cuda.is_available():
    device = "cpu"
elif torch.cuda.device_count() > 1:
    device = "cuda:1"
else:
    device = "cuda:0"

In [None]:
# Every run writes into its own folders, named after a
# month/day/hour/minute stamp taken when this cell executes.
timestamp = datetime.now().strftime("%m%d%H%M")
dataset_dir = os.path.join(base_dataset_dir, timestamp)
results_dir = os.path.join(base_results_dir, timestamp)
unmasked_dir, masked_dir = (
    os.path.join(results_dir, leaf) for leaf in ("unmasked", "masked")
)

# Create all output folders up front; already-existing ones are left alone.
for _out_dir in (dataset_dir, results_dir, unmasked_dir, masked_dir):
    os.makedirs(_out_dir, exist_ok=True)

print("Starting preprocessing...")

# A None result from cv2.imread is turned into an explicit error here.
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Image not found at path: {image_path}. Please check the file path.")

In [None]:
# --- Preprocessing -----------------------------------------------------------
# Only channel index 1 of the loaded image is used from here on
# (the green channel, given OpenCV's default BGR channel order).
green_channel = image[:, :, 1]

# Contrast enhancement with CLAHE on the single-channel image.
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced_img = clahe.apply(green_channel)

# Light Gaussian smoothing before sharpening.
blurred_img = cv2.GaussianBlur(enhanced_img, (5, 5), sigmaX=1.0)

# 3x3 sharpening kernel (centre weight 5, the four direct neighbours -1).
kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
sharpened_img = cv2.filter2D(blurred_img, -1, kernel)

preprocessed_image_path = os.path.join(results_dir, "preprocessed_image.png")
# cv2.imwrite reports failure through its return value; surface it instead of
# announcing a file that was never written.
if not cv2.imwrite(preprocessed_image_path, sharpened_img):
    raise OSError(f"Failed to write preprocessed image to {preprocessed_image_path}")
print(f"Preprocessed image saved to {preprocessed_image_path}")

# --- Grid split --------------------------------------------------------------
print("Splitting image into grids...")
h, w = sharpened_img.shape
grid_h, grid_w = h // grid_size, w // grid_size

if grid_h == 0 or grid_w == 0:
    raise ValueError("Grid size is too large for the image dimensions. Reduce the grid size.")

# NOTE: because of the floor division above, the last h % grid_size rows and
# w % grid_size columns of the image are not part of any tile.
for i in range(grid_size):
    for j in range(grid_size):
        y1, y2 = i * grid_h, (i + 1) * grid_h
        x1, x2 = j * grid_w, (j + 1) * grid_w
        # grid_h and grid_w are >= 1 and every slice lies inside the image,
        # so a tile can never be empty; no per-tile emptiness check is needed.
        sub_image = sharpened_img[y1:y2, x1:x2]

        # Tiles are numbered row-major: index = row * grid_size + column.
        grid_path = os.path.join(dataset_dir, f"grid_{i * grid_size + j}.png")
        if not cv2.imwrite(grid_path, sub_image):
            raise OSError(f"Failed to write grid image to {grid_path}")

print(f"Grid images saved in {dataset_dir}")

In [None]:
print("Initializing SAM2 model...")
sam2 = build_sam2(
    model_cfg,
    sam2_checkpoint,
    device=device,
    apply_postprocessing=False,
)

# All automatic-mask-generator options collected in one place.
mask_generator_kwargs = dict(
    points_per_side=128,
    points_per_batch=16,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.50,
    stability_score_offset=0.7,
    crop_n_layers=4,
    box_nms_thresh=0.7,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=30.0,
    use_m2m=True,
)
mask_generator = SAM2AutomaticMaskGenerator(model=sam2, **mask_generator_kwargs)

# Full-image canvases that the per-tile results are written into:
#   final_unmasked_area        - starts as all ones (every pixel "unmasked")
#   final_segmentation_colored - starts as an all-black 3-channel image
final_unmasked_area = np.full((h, w), 1, dtype=np.uint8)
final_segmentation_colored = np.zeros((h, w, 3), dtype=np.uint8)

In [None]:
def random_color():
    """Return a random three-channel colour as a list of ints in [0, 255]."""
    return [random.randint(0, 255) for _channel in range(3)]

print("Processing grids with SAM2...")
for grid_file in sorted(os.listdir(dataset_dir)):
    # Only handle tiles named "grid_<index>.png"; any other entry in the folder
    # would otherwise crash the index parsing below.
    grid_stem, grid_ext = os.path.splitext(grid_file)
    grid_index_text = grid_stem[len("grid_"):]
    if not (
        grid_stem.startswith("grid_")
        and grid_ext == ".png"
        and grid_index_text.isdecimal()
    ):
        continue

    # Recover the tile's row/column from its row-major index, then its pixel
    # window inside the full image.
    grid_y, grid_x = divmod(int(grid_index_text), grid_size)
    y1, y2 = grid_y * grid_h, (grid_y + 1) * grid_h
    x1, x2 = grid_x * grid_w, (grid_x + 1) * grid_w

    grid_path = os.path.join(dataset_dir, grid_file)
    # The tiles were written as single-channel PNGs, but the mask generator
    # expects an H x W x 3 uint8 image, so expand to RGB on load.
    sub_image = np.array(Image.open(grid_path).convert("RGB"), dtype=np.uint8)

    try:
        masks = mask_generator.generate(sub_image)
        torch.cuda.empty_cache()  # periodically release cached GPU memory
    except RuntimeError as e:
        # A failed tile is skipped; its region keeps the initial canvas values.
        print(f"Error processing {grid_file}: {e}")
        continue

    combined_mask = np.zeros((sub_image.shape[0], sub_image.shape[1]), dtype=np.uint8)
    unmasked_area = np.ones_like(combined_mask, dtype=np.uint8)

    for mask in masks:
        single_mask = mask['segmentation'].astype(np.uint8)
        # Union of all masks / pixels that no mask has covered so far.
        combined_mask = np.maximum(combined_mask, single_mask)
        unmasked_area = np.minimum(unmasked_area, 1 - single_mask)

        # Paint this mask with its own random colour; later masks overwrite
        # earlier ones where they overlap.
        color = random_color()
        final_segmentation_colored[y1:y2, x1:x2][single_mask > 0] = color

    final_unmasked_area[y1:y2, x1:x2] = unmasked_area

# Save the stitched colour segmentation.
final_segmentation_path = os.path.join(results_dir, "final_segmentation_colored.png")
Image.fromarray(final_segmentation_colored).save(final_segmentation_path)

# Save the pixels not covered by any mask, drawn in green on black.
final_unmasked_colored = np.zeros((*final_unmasked_area.shape, 3), dtype=np.uint8)
final_unmasked_colored[final_unmasked_area > 0] = [0, 255, 0]
final_unmasked_path = os.path.join(results_dir, "final_unmasked_colored.png")
Image.fromarray(final_unmasked_colored).save(final_unmasked_path)

# NOTE(review): mask_generator was constructed with model=sam2 and presumably
# still references the model, so this alone may not free its memory - confirm.
del sam2
torch.cuda.empty_cache()

print(f"Final segmentation result saved to {final_segmentation_path}")
print(f"Final unmasked result saved to {final_unmasked_path}")
