In [None]:
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import os
import torch
import gc
from pathlib import Path
from tqdm import tqdm
import time

# Number of fold checkpoints to evaluate; one model is loaded per fold index.
FOLD_NUM = 7

# The image list does not depend on the fold, so it is resolved once up front
# instead of re-globbing the directory for every model.
source_dir = "./aorta_detection/datasets/all_test/"
image_paths = sorted(Path(source_dir).glob("*.png"))
total_images = len(image_paths)

# Make sure the destination exists so open() below cannot fail on a fresh checkout.
output_dir = Path("./kfold_predict_txt/positive")
output_dir.mkdir(parents=True, exist_ok=True)

for idx in range(FOLD_NUM):
    # Load model using SAHI's AutoDetectionModel
    model_path = f"/mnt/data/aorta/yolo12x_fold7_best/yolo12x_fold7_1_{idx}/weights/best.pt"

    print(f"\n{'='*60}")
    print(f"Loading Model {idx}...")
    print(f"{'='*60}")

    detection_model = AutoDetectionModel.from_pretrained(
        model_type='ultralytics',
        model_path=model_path,
        confidence_threshold=0.1,
        device="cuda:0"
    )

    # Bound before the loop so the cleanup below works even when no image
    # was processed (empty directory or a failure on the first image).
    result = None

    try:
        print(f"Found {total_images} images to process")

        # NOTE(review): the file name says "conf1e-5" while the model above is
        # loaded with confidence_threshold=0.1 — confirm which one is intended.
        output_path = output_dir / f"best_model{idx}_conf1e-5_sahi.txt"

        # Track timing
        start_time = time.time()

        # The context manager guarantees the file is flushed and closed even
        # if inference raises part-way through the image list.
        with open(output_path, 'w', encoding="utf-8") as output_file:
            # Process each image with progress bar
            for img_idx, img_path in enumerate(tqdm(image_paths, desc=f"Model {idx}", unit="img")):
                img_start = time.time()

                # Perform sliced inference with SAHI
                result = get_sliced_prediction(
                    str(img_path),
                    detection_model,
                    slice_height=180,
                    slice_width=180,
                    overlap_height_ratio=0.2,
                    overlap_width_ratio=0.2,
                    postprocess_match_metric="IOS",
                    postprocess_match_threshold=0.75,
                    postprocess_class_agnostic=False
                )

                # Get filename without extension
                filename = img_path.stem

                # One output line per detection:
                # "<image stem> <class id> <score> <x1> <y1> <x2> <y2>",
                # with box coordinates truncated to integers.
                det_count = 0
                for obj_pred in result.object_prediction_list:
                    bbox = obj_pred.bbox
                    x1, y1, x2, y2 = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
                    label = obj_pred.category.id
                    conf = obj_pred.score.value

                    output_file.write(
                        f"{filename} {label} {conf:.4f} {int(x1)} {int(y1)} {int(x2)} {int(y2)}\n"
                    )
                    det_count += 1

                img_time = time.time() - img_start

                # Print detailed progress every 10 images; the detection count
                # and time shown are those of the most recent image only.
                if (img_idx + 1) % 10 == 0:
                    elapsed = time.time() - start_time
                    avg_time = elapsed / (img_idx + 1)
                    remaining = avg_time * (total_images - img_idx - 1)
                    print(f"\n  Progress: {img_idx + 1}/{total_images} | "
                          f"Detections: {det_count} | "
                          f"Time: {img_time:.2f}s | "
                          f"ETA: {remaining/60:.1f}m")

        # Print summary
        total_time = time.time() - start_time
        # Guard against an empty image directory, which would otherwise
        # raise ZeroDivisionError here.
        avg_per_image = total_time / total_images if total_images else 0.0
        print(f"\n{'='*60}")
        print(f"Model {idx} Complete!")
        print(f"Total time: {total_time/60:.2f} minutes")
        print(f"Average time per image: {avg_per_image:.2f}s")
        print(f"{'='*60}\n")
    finally:
        # Clean up memory before the next checkpoint is loaded; running this
        # in `finally` also drops the model reference when a fold fails.
        del detection_model, result
        gc.collect()
        torch.cuda.empty_cache()

print("\n✓ All models processed successfully!")