# IMPORTS

In [None]:
import os
import random
import shutil
from PIL import Image, ImageFile
from ultralytics import YOLO
import albumentations as A
from collections import Counter

# Pillow setting: decode images whose data stream ends early instead of
# raising, so partially written files can still be opened and re-saved by the
# repair pass later in this notebook.
ImageFile.LOAD_TRUNCATED_IMAGES = True  # allow truncated image loading

### 1. Organize Data

In [None]:
"""
Dataset splitter for YOLO‑style images/labels.

Reads raw JPEG/PNG images plus corresponding YOLO‑format `.txt` label files
from `Dataset/raw_images` and `Dataset/raw_labels`, performs a stratified
80 / 20 split (per class) into *train* and *val* subsets, and copies the
results to:

    Dataset/images/train/   *.jpg / *.png
    Dataset/labels/train/   *.txt
    Dataset/images/val/
    Dataset/labels/val/

The split preserves class balance as far as possible by allocating images
class‑wise before shuffling. A summary of the original class distribution
and final split sizes is printed.
"""

import os
import random
import shutil
from collections import Counter

# --------------------------------------------------------------------------- #
# 0. Define source / target directories
# --------------------------------------------------------------------------- #

base_dir = "Dataset"

# Inputs: the unsplit images and their YOLO label files.
raw_images_dir, raw_labels_dir = (
    os.path.join(base_dir, sub) for sub in ("raw_images", "raw_labels")
)

# Outputs: roots that will receive the train/val sub-folders.
images_dir, labels_dir = (
    os.path.join(base_dir, sub) for sub in ("images", "labels")
)

# --------------------------------------------------------------------------- #
# 1. Create train/val sub‑folders (if they do not already exist)
# --------------------------------------------------------------------------- #

# For each split, make sure both the image folder and the label folder exist.
# `exist_ok=True` makes re-running this cell harmless.
for split in ("train", "val"):
    for root in (images_dir, labels_dir):
        os.makedirs(os.path.join(root, split), exist_ok=True)

# --------------------------------------------------------------------------- #
# 2. Collect list of all image filenames
# --------------------------------------------------------------------------- #

# `os.listdir` returns entries in arbitrary order. Sorting gives the seeded
# shuffle in the split step a stable starting order, so the same files end up
# in train/val on every run and every machine.
image_files = sorted(
    f
    for f in os.listdir(raw_images_dir)
    if f.lower().endswith((".jpg", ".jpeg", ".png"))  # case-insensitive match
)

# --------------------------------------------------------------------------- #
# 3. Utility: count class occurrences in a single YOLO label file
# --------------------------------------------------------------------------- #

def count_classes_in_label_file(label_path: str) -> Counter:
    """
    Count how many instances of each class index appear in a YOLO label file.

    Parameters
    ----------
    label_path : str
        Path to a `.txt` file whose lines follow the YOLO format:
        ``<class_id> x_center y_center width height``

    Returns
    -------
    collections.Counter
        Mapping ``class_id -> instance_count``. Empty if the file cannot be
        read.

    Notes
    -----
    * Rows with fewer than five whitespace-separated fields are ignored.
    * A row whose class id is not an integer is reported and skipped; the
      remaining rows of the file are still counted.
    * A missing or unreadable file is reported on stdout rather than raised,
      so one bad label file does not abort the whole dataset scan.
    """
    class_counts: Counter[int] = Counter()
    try:
        with open(label_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if len(parts) < 5:           # blank or incomplete row
                    continue
                try:
                    class_id = int(parts[0])
                except ValueError:
                    # Skip only this row instead of abandoning the file.
                    print(
                        f"Skipping malformed row {line_no} in {label_path}: "
                        f"{line.strip()!r}"
                    )
                    continue
                class_counts[class_id] += 1
    except (OSError, UnicodeError) as e:    # file missing / unreadable
        print(f"Error reading {label_path}: {e}")
    return class_counts


# --------------------------------------------------------------------------- #
# 4. Compute overall class distribution in the raw dataset
# --------------------------------------------------------------------------- #

# Accumulate per-class instance counts over every image that has a label file.
total_class_counts: Counter[int] = Counter()
for img_file in image_files:
    stem = os.path.splitext(img_file)[0]
    label_path = os.path.join(raw_labels_dir, f"{stem}.txt")
    if not os.path.exists(label_path):
        continue                             # image without labels: ignore
    total_class_counts.update(count_classes_in_label_file(label_path))

# Report the distribution in ascending class-id order.
print("Class Distribution in original dataset:")
for cid in sorted(total_class_counts):
    print(f"  Class {cid}: {total_class_counts[cid]} instances")

# --------------------------------------------------------------------------- #
# 5. Group image filenames by the classes they contain
# --------------------------------------------------------------------------- #

# One (initially empty) bucket per class id seen in the dataset.
class_to_images: dict[int, list[str]] = {}
for cid in total_class_counts:
    class_to_images[cid] = []

# An image is added once to the bucket of every class that appears in its
# label file, so multi-class images show up in several buckets.
for img_file in image_files:
    stem, _ext = os.path.splitext(img_file)
    lbl_path = os.path.join(raw_labels_dir, stem + ".txt")
    if not os.path.exists(lbl_path):
        continue
    for cid in count_classes_in_label_file(lbl_path).keys():
        class_to_images[cid].append(img_file)

# --------------------------------------------------------------------------- #
# 6. Stratified 80/20 split per class
# --------------------------------------------------------------------------- #

random.seed(42)                                 # reproducibility
train_files: set[str] = set()
val_files: set[str] = set()

for imgs in class_to_images.values():
    random.shuffle(imgs)
    # 80 % of the class goes to train. Keep at least one image in train so a
    # class with a single image is not placed in val only, which would leave
    # the model with no training example for it.
    split_idx = max(1, int(len(imgs) * 0.8))
    train_files.update(imgs[:split_idx])        # ~80 % to train
    val_files.update(imgs[split_idx:])          # remainder to val

# An image containing several classes can be picked for train by one class
# and for val by another; train takes precedence so the two sets are disjoint.
# Sorting replaces the hash-dependent set order with a stable one, so later
# slices such as `val_files[:5]` select the same files on every run.
val_files = sorted(val_files - train_files)
train_files = sorted(train_files)

# --------------------------------------------------------------------------- #
# 7. File‑copy helper
# --------------------------------------------------------------------------- #

def move_files(file_list: list[str], img_dst: str, lbl_dst: str) -> None:
    """
    Copy each listed image, together with its YOLO label file, into the
    given destination folders.

    Parameters
    ----------
    file_list : list[str]
        Image filenames (with extension), looked up in ``raw_images_dir``.
    img_dst : str
        Folder that receives the image copies.
    lbl_dst : str
        Folder that receives the matching ``.txt`` label copies.

    Notes
    -----
    * Despite the name, nothing is moved: ``shutil.copy`` leaves the source
      files in place.
    * Entries whose image or label file is missing are skipped silently.
    """
    for image_name in file_list:
        stem = os.path.splitext(image_name)[0]
        label_name = stem + ".txt"
        image_src = os.path.join(raw_images_dir, image_name)
        label_src = os.path.join(raw_labels_dir, label_name)

        # Only copy complete image/label pairs.
        if not (os.path.exists(image_src) and os.path.exists(label_src)):
            continue

        shutil.copy(image_src, os.path.join(img_dst, image_name))
        shutil.copy(label_src, os.path.join(lbl_dst, label_name))


# --------------------------------------------------------------------------- #
# 8. Copy train and val splits
# --------------------------------------------------------------------------- #

# Copy the train split first, then the val split, into their image/label
# folders under the destination roots.
for split_name, split_files in (("train", train_files), ("val", val_files)):
    move_files(
        split_files,
        os.path.join(images_dir, split_name),
        os.path.join(labels_dir, split_name),
    )

# --------------------------------------------------------------------------- #
# 9. Report final stats
# --------------------------------------------------------------------------- #

# "Total" counts every image found in the raw folder, including any that
# were not assigned to a split.
total_images = len(image_files)
print(f"Total images: {total_images}")
print(f"Train images: {len(train_files)}")
print(f"Val images:   {len(val_files)}")


### 2. Fix Truncated Images

In [None]:
def reload_and_save_images(folder_path: str) -> int:
    """
    Re-encode every image inside *folder_path* to RGB and overwrite it in place.

    Purpose
    -------
    Re-opening and re-saving with Pillow forces a full decode followed by a
    clean RGB encode, which normalises images stored in other colour modes
    (CMYK, indexed, ...) and rewrites files whose data is incomplete.

    Parameters
    ----------
    folder_path : str
        Directory that holds the images to repair.
        Only filenames ending with ``.jpg``, ``.jpeg`` or ``.png``
        (case-insensitive) are processed; all others are ignored.

    Returns
    -------
    int
        Count of images successfully rewritten.

    Notes
    -----
    * The original files are **overwritten**; make a backup first if needed.
    * Any file that fails to open, convert or save is skipped with a console
      message and not counted.
    * The source file handle is closed before the file is rewritten, so the
      function does not leak handles or write to a file it still has open.

    Examples
    --------
    >>> n_fixed = reload_and_save_images("Dataset/images/train")
    >>> print(f"{n_fixed} images repaired.")
    """
    fixed_count = 0

    for filename in os.listdir(folder_path):
        # process only recognised image extensions
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        path = os.path.join(folder_path, filename)

        try:
            # `convert` decodes the pixel data into a new in-memory image,
            # so the source handle can be released before overwriting.
            with Image.open(path) as source:
                rgb_image = source.convert("RGB")
            rgb_image.save(path, optimize=True)
        except Exception as e:
            # Deliberately broad: this is a best-effort repair pass and one
            # unreadable file must not stop the remaining files.
            print(f"Skipping {filename}: {e}")
        else:
            fixed_count += 1

    return fixed_count


# --------------------------------------------------------------------------- #
# Run the repair pass on both training and validation image folders
# --------------------------------------------------------------------------- #

# Repair the train folder first, then the val folder, reporting how many
# images were rewritten in each.
for split_name, split_label in (("train", "training"), ("val", "validation")):
    n_fixed = reload_and_save_images(os.path.join(images_dir, split_name))
    print(f"Fixed {n_fixed} {split_label} images")


### 3. Create data.yaml file

In [None]:
"""
Compute inverse‑frequency class weights and build a YOLO data YAML that
includes those weights for imbalance‑aware training.

Given `total_class_counts` (a Counter produced earlier), we:
1. Calculate an inverse‑frequency weight for each class
2. Normalise so the largest weight equals 1.0
3. Print the weights for inspection
4. Assemble a `data_balanced.yaml` file pointing to train/val folders and
   embedding the weight list.

The resulting YAML can be passed to Ultralytics YOLOv8:

    yolo detect train data=data_balanced.yaml model=yolov8n.pt ...
"""

# --------------------------------------------------------------------------- #
# 1. Compute inverse‑frequency weights
# --------------------------------------------------------------------------- #

# Fail early with a clear message: without any counted instance the
# normalisation below would otherwise die on `max()` of an empty sequence.
if not total_class_counts:
    raise ValueError(
        "No labelled instances were counted; cannot compute class weights."
    )

total_instances = sum(total_class_counts.values())  # total labelled objects
num_classes = len(total_class_counts)

# weight ∝ 1 / frequency  (= N_total / (K * n_i))
class_weights: dict[int, float] = {
    class_id: total_instances / (num_classes * count)
    for class_id, count in total_class_counts.items()
}

# --------------------------------------------------------------------------- #
# 2. Normalise so the largest weight is 1.0 (keeps values in a nicer range)
# --------------------------------------------------------------------------- #

max_weight = max(class_weights.values())
class_weights = {cid: w / max_weight for cid, w in class_weights.items()}

# Print in ascending class-id order, matching the distribution report.
print("Class weights for training:")
for cid, wt in sorted(class_weights.items()):
    print(f"  Class {cid}: {wt:.2f}")

# --------------------------------------------------------------------------- #
# 3. Create list in class-index order to drop into YAML
#    (missing classes default to weight 1.0)
# --------------------------------------------------------------------------- #

weight_list = [
    class_weights.get(i, 1.0) for i in range(max(class_weights) + 1)
]

# --------------------------------------------------------------------------- #
# 4. Assemble the YOLO data YAML with the computed weights
# --------------------------------------------------------------------------- #

# NOTE(review): `class_weights` does not appear to be a key that Ultralytics
# reads from a data YAML, so the list may have no effect on training.
# Confirm against the Ultralytics dataset-YAML documentation.
data_yaml = f"""
path: {os.path.abspath(base_dir)}        # dataset root
train: images/train                      # relative to `path`
val: images/val

names:                                   # class index → label
    0: bubble
    1: narration
    2: other
    3: text
    4: ui

# Automatically generated class weights
class_weights: {weight_list}
"""

# The template contains a non-ASCII character ("→"). Writing with the
# platform default encoding raises UnicodeEncodeError on locales such as
# Windows cp1252, so the encoding is fixed to UTF-8.
with open("data_balanced.yaml", "w", encoding="utf-8") as f:
    f.write(data_yaml)

print("Wrote data_balanced.yaml with class weights.")


### 4. Train YOLOv8 Model

In [None]:
"""
Train a YOLOv8 model on the balanced manga‑bubble dataset, print core
validation metrics, and define a heuristic post‑processing helper.
"""

from ultralytics import YOLO

# --------------------------------------------------------------------------- #
# Start from the tiny YOLOv8‑N weights to avoid bias from previous finetunes
# --------------------------------------------------------------------------- #
model = YOLO('yolov8n.pt')  # pretrained nano checkpoint as the starting point

# --------------------------------------------------------------------------- #
# Launch training with data-augmentation and early-stopping params
# --------------------------------------------------------------------------- #
train_settings = dict(
    data="data_balanced.yaml",  # dataset YAML written by the previous cell
    epochs=50,
    imgsz=640,
    batch=16,
    patience=15,     # early stopping after 15 epochs without improvement
    cos_lr=True,     # cosine learning-rate schedule
    degrees=10.0,    # random rotation of up to ±10°
    scale=0.5,       # random scaling (0.5–1.5)
    mixup=0.1,       # mixup augmentation probability
    # NOTE(review): copy-paste may only apply to segmentation data in some
    # Ultralytics versions; confirm it has an effect for detection.
    copy_paste=0.1,  # copy-paste augmentation probability
)
results = model.train(**train_settings)

# --------------------------------------------------------------------------- #
# Print the metrics stored inside the object returned by `train`
# --------------------------------------------------------------------------- #
metrics = results.results_dict
print("Training Complete")

# (printed prefix, preferred key, fallback key): the key names differ
# slightly across package versions, hence the fallbacks.
metric_rows = (
    ("Best mAP@0.5:      ", "metrics/mAP50(B)", "mAP50"),
    ("Best mAP@0.5:95:   ", "metrics/mAP50-95(B)", "mAP50-95"),
    ("Best Precision:    ", "metrics/precision(B)", "precision"),
    ("Best Recall:       ", "metrics/recall(B)", "recall"),
)

try:
    for prefix, primary_key, fallback_key in metric_rows:
        value = metrics.get(primary_key, metrics.get(fallback_key, 0))
        print(f"{prefix}{value:.4f}")
except Exception:
    print("Could not access metrics directly. Check the results object for details.")

# --------------------------------------------------------------------------- #
# Post‑processing: simple heuristics to fix common mis‑classifications
# --------------------------------------------------------------------------- #
def _detection_rows(result):
    """Return the per-box rows of *result* for either supported results layout."""
    # YOLOv8-style result: rows live in `result.boxes.data`.
    boxes = getattr(result, "boxes", None)
    if boxes is not None:
        return boxes.data
    # Older layout: rows live in `result.xyxy[0]`.
    legacy = getattr(result, "xyxy", None)
    if legacy is not None:
        return legacy[0]
    return []                            # result carries no boxes


def apply_post_processing_rules(results):
    """
    Apply rule-based tweaks to raw YOLO detections for manga bubble layouts.

    Parameters
    ----------
    results : iterable
        Detection results, one item per image. Each item must expose its
        boxes either as ``item.boxes.data`` or as ``item.xyxy[0]``.
        Every row holds ``x1, y1, x2, y2`` first and ``conf, cls`` last;
        extra columns in between are ignored. Rows may be tensors (anything
        with ``tolist()``) or plain sequences.

    Returns
    -------
    list[dict]
        One dict per box, in input order, with keys:
        ``x, y, width, height, confidence, class``.
        ``x``/``y`` are the top-left corner (``x1``/``y1``).

    Rules implemented
    -----------------
    1. **Square boxes**: aspect ratio strictly between 0.9 and 1.1, predicted
       as class 0 with conf < 0.9, are re-labelled to class 1.
    2. **Wide boxes**: width/height > 3, not already class 3, conf < 0.85,
       are re-labelled to class 3.

    Notes
    -----
    * The original `results` object is **not** modified.
    * Boxes with zero height get aspect ratio 0, so neither rule fires for
      them (previously Rule 2 raised ``ZeroDivisionError``).
    * NOTE(review): Rule 2 was described as producing a "UI element", but
      the data YAML in this notebook maps class 3 to ``text`` and 4 to
      ``ui``. Confirm which class id is intended.
    """
    processed_results = []

    for result in results:              # one item per image
        for row in _detection_rows(result):
            values = row.tolist() if hasattr(row, "tolist") else list(row)
            x1, y1, x2, y2 = values[:4]
            conf, cls = values[-2:]      # last two columns: confidence, class

            width = x2 - x1
            height = y2 - y1
            aspect_ratio = width / height if height else 0

            # Rule 1: almost-square class-0 box with modest confidence
            if 0.9 < aspect_ratio < 1.1 and cls == 0 and conf < 0.9:
                cls = 1

            # Rule 2: extra-wide box with modest confidence
            if aspect_ratio > 3.0 and cls != 3 and conf < 0.85:
                cls = 3

            processed_results.append({
                'x': x1,
                'y': y1,
                'width': width,
                'height': height,
                'confidence': conf,
                'class': cls
            })

    return processed_results


### Prediction test

| using best.pt

In [None]:
"""
Inference pipeline:

1. Load the best fine‑tuned weights.
2. Ensure a *test_set* folder exists (create and pre‑fill with a few val images
   if necessary).
3. Re‑save test images to catch hidden corruptions.
4. Run YOLOv8 inference, then apply heuristic post‑processing.
5. Save visualised predictions to disk.
"""

# --------------------------------------------------------------------------- #
# 1. Load the trained checkpoint
# --------------------------------------------------------------------------- #
# NOTE(review): the run folder name `train12` is hard-coded; confirm it is
# the folder produced by the training run whose weights should be used.
best_model = YOLO("runs/detect/train12/weights/best.pt")   # path to best.pt

# --------------------------------------------------------------------------- #
# 2. Prepare a small test directory
# --------------------------------------------------------------------------- #
test_dir = os.path.join(base_dir, "test_set")

if not os.path.exists(test_dir):
    print(f"Warning: Test directory {test_dir} does not exist. Creating it.")
    os.makedirs(test_dir, exist_ok=True)

# Seed the folder whenever it holds no images, not only when it was just
# created, so an existing but empty folder still gets demo images.
has_test_images = any(
    entry.lower().endswith((".jpg", ".jpeg", ".png"))
    for entry in os.listdir(test_dir)
)
if not has_test_images:
    # copy up to 5 validation images for a quick demo
    for file in val_files[:5]:
        shutil.copy(os.path.join(raw_images_dir, file), test_dir)

# --------------------------------------------------------------------------- #
# 3. Sanitise test images (repair colour mode / corruption)
# --------------------------------------------------------------------------- #
reload_and_save_images(test_dir)

# --------------------------------------------------------------------------- #
# 4. Run inference and apply rule‑based clean‑up
# --------------------------------------------------------------------------- #
from PIL import ImageDraw  # local import: only this visualisation step needs it

results = best_model.predict(test_dir, save=True)  # save=True: Ultralytics also
                                                   # writes its own annotated copies

# --------------------------------------------------------------------------- #
# 5. Save the processed detections as visualisations
# --------------------------------------------------------------------------- #
# The post-processed detections are plain dicts and have no `save()` method,
# so each image is handled on its own: post-process its detections, draw them
# on a copy of the source image, and write that copy as `processed_<i>.jpg`.
processed_results = []

for i, result in enumerate(results):
    detections = apply_post_processing_rules([result])
    processed_results.extend(detections)      # flat list across all images

    with Image.open(result.path) as source_img:
        canvas = source_img.convert("RGB")    # detached in-memory copy
    draw = ImageDraw.Draw(canvas)

    for det in detections:
        x1, y1 = det["x"], det["y"]
        x2, y2 = x1 + det["width"], y1 + det["height"]
        class_id = int(det["class"])
        # Fall back to the numeric id when the model has no name for it.
        caption = f"{result.names.get(class_id, class_id)} {det['confidence']:.2f}"
        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
        draw.text((x1 + 2, y1 + 2), caption, fill="red")

    canvas.save(f"processed_{i}.jpg")

print("Inference complete with post-processing rules applied")
