In [None]:
import os
import json
import logging
import glob
import pathlib
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm.auto import tqdm
from IPython.display import HTML, FileLink, display
import warnings
warnings.filterwarnings("ignore")

# Install the notebook's third-party dependencies; only ultralytics is version-pinned.
# NOTE(review): this runs after the imports above, so cv2/matplotlib/tqdm must already
# be present in the runtime. `%pip` would target the kernel's environment explicitly.
print("Installing required packages...")
!pip install -q ultralytics==8.3.36 kagglehub opencv-python-headless matplotlib tqdm seaborn

from ultralytics import YOLO
import kagglehub

# Configure root logging at INFO level with a timestamped format, then create the
# named logger that the functions below use for status messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('YOLOv12_BoneFracture')

# ==============================================================
# CONFIGURATION
# ==============================================================

In [None]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    """Notebook-wide settings for dataset location, training, and video output.

    Instantiating the class creates ``output_dir`` on disk when it does not
    exist yet, so later cells can write into it without further checks.
    """

    # Root folder of the dataset; stays None until the dataset has been located.
    data_dir: Optional[str] = None
    # Folder that receives data.yaml and the training run directory.
    output_dir: str = "/content/output"
    # Training hyperparameters forwarded to model.train().
    epochs: int = 35
    imgsz: int = 640
    batch: int = 16
    patience: int = 8
    conf: float = 0.25
    iou: float = 0.45
    # Upper bound on the number of validation images rendered into the video.
    max_frames: int = 30
    # Frame rate of the generated video.
    fps: int = 2

    def __post_init__(self):
        """Create the output directory (and any missing parents) if needed."""
        os.makedirs(self.output_dir, exist_ok=True)

# ==============================================================
# DATASET DOWNLOAD AND SETUP
# ==============================================================

In [None]:
def setup_dataset():
    """Fetch the bone fracture dataset via kagglehub and return its data root.

    The data root is the first candidate folder that contains at least one
    file matching ``train/images/*.*``.

    Raises:
        FileNotFoundError: when none of the candidate folders has train images.
    """
    slug = "jockeroika/human-bone-fractures-image-dataset"
    print(f"Downloading dataset: {slug}...")

    download_root = kagglehub.dataset_download(slug)
    print(f"Dataset downloaded to: {download_root}")

    # Candidate layouts, from the most nested to the download folder itself.
    wrapper_dir = "Human Bone Fractures Multi-modal Image Dataset (HBFMID)"
    detection_dir = "Bone Fractures Detection"
    candidates = (
        os.path.join(download_root, wrapper_dir, detection_dir),
        os.path.join(download_root, detection_dir),
        download_root,
    )

    def has_train_images(root):
        """True when ``root`` exists and holds at least one train image file."""
        if not os.path.exists(root):
            return False
        return bool(glob.glob(os.path.join(root, "train", "images", "*.*")))

    data_root = next((root for root in candidates if has_train_images(root)), None)
    if data_root is None:
        raise FileNotFoundError(f"Could not locate train/images in {download_root}")

    print(f"Data directory: {data_root}")
    return data_root

# Resolve the dataset location once and build the notebook-wide settings object from it.
DATA_DIR = setup_dataset()
config = Config(data_dir=DATA_DIR)

# ==============================================================
# CREATE DATA CONFIGURATION
# ==============================================================

In [None]:
def create_data_yaml():
    """Write the dataset description file (data.yaml) into ``config.output_dir``.

    Locates the ``train/images`` and ``valid/images`` folders under
    ``config.data_dir`` and records them relative to the parent of
    ``config.data_dir``, together with the class list.

    Returns:
        tuple: ``(yaml_path, class_names)`` - path of the written file and the
        list of class names stored in it.

    Raises:
        FileNotFoundError: when a split's ``images`` folder cannot be found.
    """
    # Imported before the file is opened so that a missing PyYAML cannot leave
    # behind an empty, truncated data.yaml.
    import yaml

    def _find_images_dir(split):
        """Return the first ``<split>/images`` folder found under config.data_dir."""
        matches = glob.glob(os.path.join(config.data_dir, "**", split, "images"), recursive=True)
        if not matches:
            raise FileNotFoundError(
                f"Could not locate {split}/images under {config.data_dir}"
            )
        return matches[0]

    train_img_dir = _find_images_dir("train")
    val_img_dir = _find_images_dir("valid")

    print(f"Train images: {train_img_dir}")
    print(f"Val images: {val_img_dir}")

    # Define class names from the dataset
    # NOTE(review): this list is hard-coded; its order must match the class ids
    # used in the dataset's label files - confirm against the dataset itself.
    class_names = [
        'elbow_positive', 'finger_positive', 'forearm_positive', 'humerus_positive',
        'shoulder_positive', 'wrist_positive', 'elbow_negative', 'finger_negative',
        'forearm_negative', 'humerus_negative', 'shoulder_negative', 'wrist_negative'
    ]

    # Build data configuration; split paths are stored relative to 'path' with
    # forward slashes regardless of the host OS.
    root_for_yaml = os.path.dirname(config.data_dir)
    data_yaml = {
        'path': root_for_yaml,
        'train': os.path.relpath(train_img_dir, root_for_yaml).replace(os.sep, '/'),
        'val': os.path.relpath(val_img_dir, root_for_yaml).replace(os.sep, '/'),
        'nc': len(class_names),
        'names': class_names
    }

    yaml_path = os.path.join(config.output_dir, "data.yaml")
    with open(yaml_path, "w") as f:
        yaml.dump(data_yaml, f, default_flow_style=False)

    logger.info(f"data.yaml created: {len(class_names)} classes")
    return yaml_path, class_names

# Write data.yaml and keep its path (for training) plus the class list (for drawing labels).
yaml_path, class_names = create_data_yaml()

# ==============================================================
# MODEL TRAINING
# ==============================================================

In [None]:
def train_model():
    """Fine-tune a pretrained ``yolov8m.pt`` checkpoint on the bone fracture dataset.

    Hyperparameters come from the module-level ``config`` and the dataset
    description from ``yaml_path``. The run is stored under
    ``config.output_dir`` with the run name ``yolo_bone_fracture``.

    Returns:
        str: Path to ``weights/best.pt`` inside the run's save directory.
    """
    # Local import: torch is only needed here to probe for a GPU.
    import torch

    print("Starting model training...")

    # A hard-coded device=0 aborts on runtimes without CUDA, so fall back to CPU.
    device = 0 if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("No CUDA device available; training on CPU will be slow.")

    # Use YOLOv8 as a reliable base model
    model = YOLO("yolov8m.pt")

    # NOTE(review): conf/iou are forwarded to train(); presumably they apply to the
    # validation pass run during training - confirm this threshold is intended there.
    results = model.train(
        data=yaml_path,
        epochs=config.epochs,
        imgsz=config.imgsz,
        batch=config.batch,
        patience=config.patience,
        lr0=0.01,
        optimizer='AdamW',
        weight_decay=0.0005,
        momentum=0.937,
        conf=config.conf,
        iou=config.iou,
        project=config.output_dir,
        name="yolo_bone_fracture",
        exist_ok=True,
        plots=True,
        save=True,
        device=device,
        workers=2,
        cache=False
    )

    # Get best model path
    best_model_path = str(pathlib.Path(results.save_dir) / "weights" / "best.pt")
    logger.info(f"Training complete. Best model: {best_model_path}")
    return best_model_path

# Run training and remember where the best checkpoint was written for the inference step.
best_model_path = train_model()

# ==============================================================
# INFERENCE AND VIDEO GENERATION
# ==============================================================

In [None]:
def create_detection_video():
    """Render model detections on validation images into an MP4 and offer it for download.

    Loads the checkpoint at ``best_model_path``, annotates up to
    ``config.max_frames`` images from the ``valid/images`` folder, writes them
    at ``config.fps`` to ``/content/bone_fracture_detection.mp4``, then shows a
    file link and, when the Colab helper is importable, triggers a download.

    Returns:
        str | None: Path of the written video, or None when nothing could be
        produced (validation folder missing, no frames, or the writer failed).
    """
    frame_size = (640, 640)  # (width, height) of every video frame

    print("🎥 Starting video generation...")

    # Load the trained model
    model = YOLO(best_model_path)
    val_dirs = glob.glob(os.path.join(config.data_dir, "**", "valid", "images"), recursive=True)
    if not val_dirs:
        print(f"❌ Could not locate valid/images under {config.data_dir}")
        return None
    val_images = sorted(glob.glob(os.path.join(val_dirs[0], "*.*")))[:config.max_frames]

    print(f"Processing {len(val_images)} images for video...")

    def draw_detections(img_path, model, img_size=frame_size):
        """Return an RGB frame of ``img_size`` with the model's detections drawn on it.

        Raises:
            ValueError: when the image file cannot be decoded.
        """
        out_w, out_h = img_size

        bgr = cv2.imread(img_path)
        if bgr is None:
            # Let the caller substitute a labelled placeholder frame.
            raise ValueError(f"Could not read image: {img_path}")
        src_h, src_w = bgr.shape[:2]
        img = cv2.cvtColor(cv2.resize(bgr, (out_w, out_h)), cv2.COLOR_BGR2RGB)

        # Predicted boxes are expressed in the source image's pixels, while the
        # frame has been resized, so every coordinate is rescaled before drawing.
        scale_x = out_w / src_w
        scale_y = out_h / src_h

        # Run inference on the very pixels that were resized above.
        results = model.predict(
            source=bgr,
            conf=0.15,  # Lower confidence for better detection
            iou=0.4,
            verbose=False
        )

        # Draw detections
        for r in results:
            if r.boxes is not None and len(r.boxes) > 0:
                xyxy = r.boxes.xyxy.cpu().numpy()
                scores = r.boxes.conf.cpu().numpy()
                class_ids = r.boxes.cls.cpu().numpy().astype(int)

                for (bx1, by1, bx2, by2), score, cls_id in zip(xyxy, scores, class_ids):
                    # Rescale to the frame and keep the box inside its bounds.
                    x1 = int(np.clip(bx1 * scale_x, 0, out_w - 1))
                    y1 = int(np.clip(by1 * scale_y, 0, out_h - 1))
                    x2 = int(np.clip(bx2 * scale_x, 0, out_w - 1))
                    y2 = int(np.clip(by2 * scale_y, 0, out_h - 1))

                    if x2 > x1 and y2 > y1:  # Validate coordinates
                        cls_id = int(cls_id)
                        label = f"{class_names[cls_id] if cls_id < len(class_names) else 'unknown'} {score:.2f}"

                        # Generate color based on class (RGB, matching the frame)
                        color = plt.cm.Set3(cls_id % 12)[:3]
                        color = tuple(int(c*255) for c in color)

                        # Draw bounding box
                        cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)

                        # Draw label above the box, pushed down when it would
                        # otherwise fall outside the top edge of the frame.
                        (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
                        label_top = max(y1 - th - 10, 0)
                        label_bottom = label_top + th + 10
                        cv2.rectangle(img, (x1, label_top), (x1 + tw, label_bottom), color, -1)
                        cv2.putText(img, label, (x1, label_bottom - 5),
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            else:
                # Add "No Detection" text if no detections found
                cv2.putText(img, "No Fracture Detected", (50, 50),
                           cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        return img

    # Generate frames
    frames = []
    print("Generating frames...")

    for i, img_path in enumerate(tqdm(val_images, desc="Processing images")):
        try:
            frames.append(draw_detections(img_path, model))
        except Exception as e:
            # Best effort: keep the video going with a labelled placeholder frame.
            print(f"Error processing image {i}: {e}")
            placeholder = np.random.randint(
                100, 200, (frame_size[1], frame_size[0], 3), dtype=np.uint8
            )
            cv2.putText(placeholder, "Processing Error", (50, 50),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            frames.append(placeholder)

    if not frames:
        print("❌ No frames were generated")
        return None

    # Create video
    video_path = "/content/bone_fracture_detection.mp4"
    h, w = frames[0].shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(video_path, fourcc, config.fps, (w, h))
    if not writer.isOpened():
        # Without this check, writes fail silently when the path or codec is unusable.
        writer.release()
        print("❌ Video file was not created")
        return None

    print("Creating video...")
    try:
        for frame in tqdm(frames, desc="Writing video frames"):
            # Convert RGB to BGR for OpenCV
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the container, even when a write raises.
        writer.release()
    print(f"✅ Video created: {video_path}")

    if not os.path.exists(video_path):
        print("❌ Video file was not created")
        return None

    file_size = os.path.getsize(video_path) / (1024 * 1024)  # Size in MB
    print(f"📦 Video file size: {file_size:.1f} MB")

    # Create download link
    print("🔽 Downloading video automatically...")
    display(FileLink(video_path))

    # The Colab helper is optional: outside Colab the link above is the only
    # download route, and the finished video must not be lost to an ImportError.
    try:
        from google.colab import files
    except ImportError:
        print("google.colab is not available; use the link above to download the video.")
    else:
        files.download(video_path)

    return video_path

# Generate the detection video and offer it for download; a falsy result means failure.
video_path = create_detection_video()

if video_path:
    print("\n" + "="*60)
    print("🎉 VIDEO GENERATION COMPLETED SUCCESSFULLY!")
    print("="*60)
    print(f"📹 Output: bone_fracture_detection.mp4")
    # NOTE(review): this prints the configured cap (config.max_frames), not the number
    # of frames actually written, which is lower when fewer validation images exist.
    print(f"🖼️  Frames: {config.max_frames}")
    print(f"🎬 FPS: {config.fps}")
    print(f"💾 Location: {video_path}")
    print("="*60)
else:
    print("\n❌ Video generation failed. Please check the error messages above.")