In [None]:
import os
import cv2
import numpy as np
from ultralytics import YOLO
from tqdm import tqdm

# -----------------------------------------------------------------
# Configuration: paths and sizes used by the rest of the notebook
# -----------------------------------------------------------------
# Path to the trained YOLO segmentation weights (best checkpoint).
MODEL_PATH: str = "runs/segment/rice_yolo_runs/yolov8m_seg/weights/best.pt"
# Folder of source images to run inference on; change to valid/test if needed.
IMAGE_DIR: str = "rice_dataset/train/images"
# Root folder that receives the extracted crops.
SAVE_DIR: str = "rice_crops"
# Side length in pixels that each crop is resized to (for CNNs later).
IMG_SIZE: int = 224

# -----------------------------------------------------------------
# Model loading and output-folder setup
# -----------------------------------------------------------------
model = YOLO(MODEL_PATH)

# Create the output root up front; exist_ok makes re-running the cell safe.
os.makedirs(SAVE_DIR, exist_ok=True)


# =============================
# MASK CROP FUNCTION
# =============================
def crop_from_mask(img, mask):
    """
    Extract a tight, background-free crop of ``img`` around ``mask``.

    Args:
        img: image array of shape (H, W) or (H, W, C).
        mask: 2-D array aligned with the first two axes of ``img``. Any
            nonzero element counts as foreground, which is the same convention
            OpenCV uses for masks. Soft/probability masks must be binarized by
            the caller first.

    Returns:
        A new array with ``img``'s dtype. It covers the smallest axis-aligned
        rectangle containing every foreground pixel, and pixels inside that
        rectangle that are not in the mask are set to 0. If the mask has no
        foreground pixels, a zero-size array is returned so callers can
        check ``.size == 0``.
    """
    # Work on a boolean view instead of casting to uint8. The cast silently
    # turned fractional values into 0 and wrapped values >= 256 to 0.
    keep = np.asarray(mask) != 0

    rows = np.flatnonzero(keep.any(axis=1))
    cols = np.flatnonzero(keep.any(axis=0))
    if rows.size == 0:
        # Empty mask: nothing to crop. Return a zero-size array with img's
        # trailing (channel) shape, without calling into OpenCV.
        return img[0:0, 0:0].copy()

    # Inclusive min/max of foreground coordinates become half-open slices.
    # This is the same rectangle cv2.boundingRect reports.
    y0, y1 = rows[0], rows[-1] + 1
    x0, x1 = cols[0], cols[-1] + 1

    crop = img[y0:y1, x0:x1].copy()
    # Remove background: zero every pixel (all channels) outside the mask.
    crop[~keep[y0:y1, x0:x1]] = 0
    return crop


# =============================
# PREDICT + SAVE CROPS
# =============================
def _unletterbox_mask(mask, orig_shape):
    """
    Map a mask from the model's letterboxed input frame onto the original image.

    Unless ``retina_masks=True`` is passed to ``predict``, ``r.masks.data``
    is produced at the inference resolution. That frame is the original image
    scaled by a single gain and padded symmetrically up to the model's input
    shape. A plain resize to the original size would stretch that padding
    into the picture and shift every mask, so the padding is cropped away
    first. This mirrors ultralytics ``ops.scale_image``.

    Args:
        mask: 2-D mask at inference resolution (presumably float32 from
            ``r.masks.data``; TODO confirm for the installed ultralytics).
        orig_shape: shape of the original image; only (height, width) is used.

    Returns:
        2-D mask with the original image's (height, width).
    """
    mask_h, mask_w = mask.shape[:2]
    orig_h, orig_w = orig_shape[:2]
    if (mask_h, mask_w) == (orig_h, orig_w):
        return mask

    gain = min(mask_h / orig_h, mask_w / orig_w)
    pad_w = (mask_w - orig_w * gain) / 2
    pad_h = (mask_h - orig_h * gain) / 2
    top, left = int(pad_h), int(pad_w)
    bottom, right = int(mask_h - pad_h), int(mask_w - pad_w)

    unpadded = np.ascontiguousarray(mask[top:bottom, left:right])
    return cv2.resize(unpadded, (orig_w, orig_h))


results = model.predict(
    source=IMAGE_DIR,
    stream=True,     # important for memory
    conf=0.25,
    verbose=False
)

print("Extracting grains...")

# Class folders already created, so os.makedirs runs once per class
# instead of once per saved crop.
created_dirs = set()

for r in tqdm(results):

    if r.masks is None:
        continue

    img = r.orig_img
    image_name = os.path.splitext(os.path.basename(r.path))[0]

    masks = r.masks.data.cpu().numpy()
    classes = r.boxes.cls.cpu().numpy()

    for i, (mask, cls_id) in enumerate(zip(masks, classes)):

        # Undo letterbox padding before scaling back to original pixels,
        # then binarize.
        mask = _unletterbox_mask(mask, img.shape)
        mask = (mask > 0.5).astype(np.uint8)

        crop = crop_from_mask(img, mask)
        if crop.size == 0:
            continue

        # Resize for CNN.
        # NOTE(review): this stretches every crop to a square, discarding the
        # grain's length/width ratio; pad to square first if shape matters.
        crop = cv2.resize(crop, (IMG_SIZE, IMG_SIZE))

        class_name = model.names[int(cls_id)]
        class_folder = os.path.join(SAVE_DIR, class_name)
        if class_folder not in created_dirs:
            os.makedirs(class_folder, exist_ok=True)
            created_dirs.add(class_folder)

        save_path = os.path.join(class_folder, f"{image_name}_{i}.jpg")

        # cv2.imwrite signals failure through its return value, not an
        # exception, so check it instead of losing crops silently.
        if not cv2.imwrite(save_path, crop):
            tqdm.write(f"Warning: could not write {save_path}")


print("✅ Done! Crops saved to:", SAVE_DIR)
