In [2]:
import os, time, cv2, queue, threading, glob
import numpy as np
import torch
import supervision as sv
from groundingdino.util.inference import load_model, load_image, predict, annotate

# ---------- Config ----------
# Runtime device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# GroundingDINO model files (paths relative to the working directory).
config_path = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
weights_path = "GroundingDINO/weights/groundingdino_swint_ogc.pth"

# Input images and output root (main() creates images/, labels/ and annotated/ under it).
image_folder = "/mnt/d/whales/happy-whale-and-dolphin/train_images/"
OUTPUT_FOLDER = "/mnt/d/whales/processed/4"
# Alternative image list; only used by the commented-out input code in main().
val_txt = "dev/whales/yolo/data/train.txt"

# Detection settings.
text_prompt = "whale"
BOX_THRESHOLD, TEXT_THRESHOLD = 0.4, 0.25
CLASS_ID = 5  # class index written to the YOLO label files (whale = 5 for now)

# Pipeline settings.
MAX_IMAGES = 100        # test subset (change to 50000 later) -- NOTE: not applied by main()
MAX_WORKERS = 4         # loader threads
BUFFER_SIZE = 64        # how many decoded images to keep ready
RESIZE_TO = (640, 640)  # (width, height) images are resized to before inference

def boxes_normalized_to_pixels(boxes, src, path=None):
    """Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) pixel boxes.

    Args:
        boxes: (N, 4) tensor, array or nested list of box centres/sizes expressed
            as fractions of the image width/height.
        src: Image array; only ``src.shape[:2]`` (height, width) is used.
        path: Unused; kept so existing callers that pass the image path still work.

    Returns:
        float32 tensor of shape (N, 4), or (0, 4) when ``boxes`` is empty, in the
        xyxy layout expected by ``torchvision.ops.nms``. Coordinates are not
        rounded, so sub-pixel precision is preserved.
    """
    h, w = src.shape[:2]
    # reshape(-1, 4) keeps an empty input as (0, 4) instead of a shapeless (0,).
    b = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    cx, cy, bw, bh = b.unbind(dim=1)
    half_w, half_h = bw / 2, bh / 2
    return torch.stack(
        (
            (cx - half_w) * w,
            (cy - half_h) * h,
            (cx + half_w) * w,
            (cy + half_h) * h,
        ),
        dim=1,
    )
            

# 1. Async loader with resize
def prefetch_loader(paths, loader_func, resize_to=None, max_workers=4, buffer_size=64):
    """Load (and optionally resize) images on background threads, yielding them as they become ready.

    Args:
        paths: Iterable of image paths; consumed lazily and shared by all workers.
        loader_func: Callable ``path -> (src, trans)``. When ``resize_to`` is set,
            ``src`` must be an OpenCV-compatible image and ``trans`` a 3-D (C, H, W)
            tensor, because they are resized with ``cv2.resize`` and 4-D bilinear
            ``interpolate`` respectively.
        resize_to: Optional ``(width, height)`` applied to both ``src`` and ``trans``.
        max_workers: Number of loader threads; values below 1 are treated as 1.
        buffer_size: Maximum number of loaded items waiting to be consumed.

    Yields:
        ``(path, src, trans, load_seconds)`` in completion order, not input order.
        If loading ``path`` raised, ``(path, exception, None, 0)`` is yielded instead
        so the consumer can report the error and skip the image.

    Closing the generator early (``break``, an exception in the consumer, or
    garbage collection) tells the workers to stop. They therefore do not stay
    blocked forever on a full queue, holding decoded images.
    """
    q = queue.Queue(maxsize=buffer_size)
    sentinel = object()
    it = iter(paths)
    lock = threading.Lock()
    stop = threading.Event()
    n_workers = max(1, max_workers)

    def put(item):
        """Enqueue ``item``; return False instead of blocking once the consumer has stopped."""
        while not stop.is_set():
            try:
                q.put(item, timeout=0.1)
                return True
            except queue.Full:
                pass
        return False

    def load(p):
        """Load one path, resize if requested, and time it."""
        t0 = time.time()
        src, trans = loader_func(p)
        if resize_to is not None:
            w, h = resize_to
            src = cv2.resize(src, (w, h), interpolation=cv2.INTER_LINEAR)
            trans = torch.nn.functional.interpolate(
                trans.unsqueeze(0), size=(h, w),
                mode="bilinear", align_corners=False
            ).squeeze(0)
        return p, src, trans, time.time() - t0

    def worker():
        try:
            while not stop.is_set():
                # A plain iterator is not safe to advance from several threads at once.
                with lock:
                    try:
                        p = next(it)
                    except StopIteration:
                        break
                try:
                    item = load(p)
                except Exception as e:
                    item = (p, e, None, 0)
                if not put(item):
                    break
        finally:
            # Always announce completion, even if iterating ``paths`` raised;
            # otherwise the consumer below would wait forever.
            put(sentinel)

    threads = [threading.Thread(target=worker, daemon=True) for _ in range(n_workers)]
    for t in threads:
        t.start()

    finished = 0
    try:
        while finished < n_workers:
            item = q.get()
            if item is sentinel:
                finished += 1
                continue
            yield item
    finally:
        stop.set()


# 2. Inference
def run_inference(model, batch_tensors, batch_paths, batch_src):
    """Detect ``text_prompt`` in the FIRST image of a batch with GroundingDINO.

    Only index 0 of ``batch_tensors`` / ``batch_paths`` / ``batch_src`` is used;
    callers that want every image labelled must call this once per image.

    The first pass uses BOX_THRESHOLD. If it finds nothing, the box threshold is
    lowered through 0.35, 0.34, ..., 0.29 until at least one box is found. When
    more than one box remains, overlapping boxes are suppressed with NMS
    (IoU 0.5, scored by the returned logits).

    Returns:
        ``(boxes, logits, phrases, seconds)``: the boxes, logits and phrases from
        ``predict`` (after NMS), and the wall time spent in this call. The boxes
        are in the normalized (cx, cy, w, h) form that
        ``boxes_normalized_to_pixels`` expects.
    """
    import torchvision

    t0 = time.time()
    trans = batch_tensors[0]
    src = batch_src[0]
    path = batch_paths[0]

    def _predict(box_threshold):
        # Every pass, fallbacks included, runs in the same no-grad/autocast setup.
        # Autocast is only enabled on CUDA, so CPU runs stay fp32 without a warning.
        with torch.no_grad(), torch.amp.autocast("cuda", enabled=(DEVICE == "cuda")):
            return predict(
                model=model,
                image=trans,
                caption=text_prompt,
                box_threshold=box_threshold,
                text_threshold=TEXT_THRESHOLD,
                device=DEVICE,  # predict()'s own default device is "cuda"
            )

    boxes, logits, phrases = _predict(BOX_THRESHOLD)

    # Fallback thresholds as an explicit list: repeatedly subtracting 0.01 from a
    # float drifts and makes the last step taken unpredictable.
    for box_threshold in (0.35, 0.34, 0.33, 0.32, 0.31, 0.30, 0.29):
        if len(boxes) > 0:
            break
        boxes, logits, phrases = _predict(box_threshold)

    if len(boxes) == 0:
        print(f"[{path}] still cannot find a box ...")

    # Optional NMS (keep)
    if len(boxes) > 1:
        pixel_boxes = boxes_normalized_to_pixels(boxes, src, path).float()
        # nms wants boxes and scores of the same dtype; the logits may come back in
        # reduced precision under autocast, so cast them explicitly.
        keep = torchvision.ops.nms(pixel_boxes, logits.float(), iou_threshold=0.5)
        boxes = boxes[keep]
        logits = logits[keep]
        phrases = [phrases[i] for i in keep.tolist()]

    if DEVICE == "cuda":
        # Make inf_t include all queued GPU work.
        torch.cuda.synchronize()

    inf_t = time.time() - t0
    return boxes, logits, phrases, inf_t


# 3. Annotate + save
def annotate_and_save(image_path, src, outputs):
    """Write YOLO labels, an annotated preview and the source image for one image.

    Args:
        image_path: Original image path; only its base name (without extension)
            is used to name the outputs.
        src: RGB image array. It is converted RGB->BGR before ``cv2.imwrite``,
            so it is assumed to be RGB, as produced by ``load_image``.
        outputs: ``(boxes, logits, phrases, seconds)`` as returned by
            ``run_inference``. Each box is written verbatim as ``cx cy w h``.

    Writes under OUTPUT_FOLDER:
        ``labels/<base>.txt``: one ``CLASS_ID cx cy w h`` line per box; the file
            is empty when nothing was detected.
        ``annotated/<base>.jpg``: preview with the boxes drawn.
        ``images/<base>.jpg``: ``src`` itself.

    Raises:
        OSError: if OpenCV fails to write either JPEG. ``cv2.imwrite`` only
            reports failure through its return value.
    """
    boxes, logits, phrases, _ = outputs
    # GroundingDINO's annotate() converts its RGB input to BGR internally, so its
    # result is already in OpenCV channel order and must NOT be converted again.
    annotated = annotate(src, boxes, logits, phrases)

    base = os.path.splitext(os.path.basename(image_path))[0]

    # Save YOLO labels
    with open(os.path.join(OUTPUT_FOLDER, "labels", f"{base}.txt"), "w") as f:
        for box in boxes:
            x_center, y_center, box_w, box_h = (float(v) for v in box)
            f.write(f"{CLASS_ID} {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}\n")

    images_to_write = (
        ("annotated", annotated),
        ("images", cv2.cvtColor(src, cv2.COLOR_RGB2BGR)),
    )
    for subdir, image in images_to_write:
        out_path = os.path.join(OUTPUT_FOLDER, subdir, f"{base}.jpg")
        if not cv2.imwrite(out_path, image):
            raise OSError(f"cv2.imwrite failed for {out_path}")


# 4. Dataset loop
def process_dataset(model, image_paths, resize_to, batch_size):
    """Load, detect and save results for every image in ``image_paths``.

    Images are loaded in the background by ``prefetch_loader`` (MAX_WORKERS
    threads, BUFFER_SIZE queue) and grouped into batches of ``batch_size``
    (values below 1 are treated as 1). ``run_inference`` only looks at the first
    entry of the lists it receives, so each image of a batch is sent through it
    individually. Otherwise every image after the first in a batch would be
    silently skipped. A final partial batch is also processed.

    Images that fail to load are reported and skipped.

    Returns:
        ``(load_seconds, inference_seconds, write_seconds)``. ``load_seconds`` is
        the sum of per-image load times measured inside the loader threads, so it
        can exceed wall-clock time when several threads load in parallel.
    """
    batch_size = max(1, batch_size)
    total_load = total_inf = total_write = 0.0
    batch_paths, batch_imgs, batch_src = [], [], []

    def _flush():
        """Run inference + saving for every buffered image, then clear the buffers."""
        nonlocal total_inf, total_write
        for p, trans, s in zip(batch_paths, batch_imgs, batch_src):
            outputs = run_inference(model, [trans], [p], [s])
            total_inf += outputs[3]

            t0 = time.time()
            annotate_and_save(p, s, outputs)
            total_write += time.time() - t0
        batch_paths.clear()
        batch_imgs.clear()
        batch_src.clear()

    loader = prefetch_loader(
        image_paths, load_image, resize_to=resize_to,
        max_workers=MAX_WORKERS, buffer_size=BUFFER_SIZE,
    )
    for path, src, trans, load_t in loader:
        if isinstance(src, Exception):
            print(f"❌ Error loading {path}: {src}")
            continue
        total_load += load_t

        batch_paths.append(path)
        batch_imgs.append(trans)
        batch_src.append(src)

        if len(batch_paths) >= batch_size:
            _flush()

    # Don't drop a trailing partial batch.
    _flush()

    return total_load, total_inf, total_write


def main():
    """Auto-label the image folder with GroundingDINO and report timings.

    Images are processed in chunks of ``chunk_size`` so progress is printed
    regularly. The per-stage times returned by ``process_dataset`` are summed
    across chunks and reported at the end.
    """
    print(f"🚀 Using device: {DEVICE}")
    # ---------- Load model ----------
    model = load_model(config_path, weights_path, device=DEVICE)

    # Skip the first `start_index` images of the sorted folder listing (presumably
    # a resume point from an earlier run -- TODO confirm). MAX_IMAGES is not applied.
    start_index = 40000
    image_paths = sorted(glob.glob(os.path.join(image_folder, "*.*")))[start_index:]

    # Alternative input list (kept for reference):
    # with open(val_txt, 'r') as file_object:
    #     image_paths = [
    #         os.path.join(image_folder, os.path.basename(x.strip()))
    #         for x in file_object.readlines()
    #         if "whales" in x
    #     ]

    print(f"📷 Found {len(image_paths)} images")

    # --- output ---
    for subdir in ("images", "labels", "annotated"):
        os.makedirs(os.path.join(OUTPUT_FOLDER, subdir), exist_ok=True)

    # --- process ---
    chunk_size = 100
    total_load = total_inf = total_write = 0.0
    t0 = time.time()

    for i in range(0, len(image_paths), chunk_size):
        t1 = time.time()
        images = image_paths[i:i + chunk_size]
        load_t, inf_t, write_t = process_dataset(model, images, resize_to=RESIZE_TO, batch_size=1)
        total_load += load_t
        total_inf += inf_t
        total_write += write_t

        batch_time = time.time() - t1
        total_time = time.time() - t0
        # Indices are relative to the sliced list; the last chunk may hold fewer images.
        print(f" {len(images)} images from {i} to {i + len(images)} finished, ⏱ Batch time: {batch_time:.2f}s; Total time: {total_time:.2f}s")

    # --- report ---
    total_time = time.time() - t0
    print(f"\n✅ Finished processing {len(image_paths)} images")
    print(f"⏱ Total time: {total_time:.2f}s")
    # Loading is summed over parallel loader threads, so it can exceed the total time.
    print(f"   - Loading:   {total_load:.2f}s")
    print(f"   - Inference: {total_inf:.2f}s")
    print(f"   - Writing:   {total_write:.2f}s")
    if image_paths:
        print(f"   - Avg/img:   {total_time/len(image_paths):.3f}s")


if __name__ == "__main__":
    main()


🚀 Using device: cuda
final text_encoder_type: bert-base-uncased
📷 Found 11033 images




 100 images from 0 to 100 finished, ⏱ Batch time: 40.56s; Total time: 40.56s
 100 images from 100 to 200 finished, ⏱ Batch time: 20.81s; Total time: 61.37s
 100 images from 200 to 300 finished, ⏱ Batch time: 22.69s; Total time: 84.06s
 100 images from 300 to 400 finished, ⏱ Batch time: 21.32s; Total time: 105.38s
 100 images from 400 to 500 finished, ⏱ Batch time: 22.71s; Total time: 128.08s
 100 images from 500 to 600 finished, ⏱ Batch time: 22.50s; Total time: 150.59s
 100 images from 600 to 700 finished, ⏱ Batch time: 25.60s; Total time: 176.19s
 100 images from 700 to 800 finished, ⏱ Batch time: 27.09s; Total time: 203.28s
 100 images from 800 to 900 finished, ⏱ Batch time: 28.03s; Total time: 231.31s
 100 images from 900 to 1000 finished, ⏱ Batch time: 28.90s; Total time: 260.21s
 100 images from 1000 to 1100 finished, ⏱ Batch time: 27.03s; Total time: 287.24s
 100 images from 1100 to 1200 finished, ⏱ Batch time: 25.29s; Total time: 312.53s
 100 images from 1200 to 1300 finished, 