# Notebook for training YOLO detection and classification models

In [None]:
%matplotlib inline

from ultralytics import YOLO
from ultralytics.models.yolo.detect.train import DetectionTrainer
from ultralytics.models.yolo.classify.train import ClassificationTrainer

import pandas as pd
from ultralytics.cfg import cfg2dict
import os
import ultralytics
import seaborn as sns
import matplotlib.pyplot as plt

# Apply the seaborn theme first: set_theme() changes matplotlib's rcParams,
# so the figure/axes overrides must be applied after it to take effect.
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'axes.labelsize': 12,
    'axes.titlesize': 14,
})

# Reset the Ultralytics settings before any trainer is created.
ultralytics.settings.reset()

In [None]:
# Dataset YAML; its "path" entry is used as the dataset root below.
DATA_CONFIG_FILE = "data_config.yaml"


def _read_split_entries(split_file):
    """Return the non-blank lines of *split_file*, stripped of surrounding whitespace."""
    with open(split_file, 'r') as f:
        stripped = (line.strip() for line in f)
        return [entry for entry in stripped if entry]


data_config = cfg2dict(DATA_CONFIG_FILE)
ROOT_DIR = data_config["path"]

# The classification data is looked up in a "classification" folder next to
# ROOT_DIR, normalised to forward slashes.
# NOTE(review): os.path.dirname() returns ROOT_DIR itself (minus the slash)
# when ROOT_DIR ends with a path separator, which would place the folder
# inside ROOT_DIR instead of beside it -- confirm the YAML "path" has no
# trailing slash.
CLASSIFY_DATA_DIR = os.path.join(os.path.dirname(ROOT_DIR), "classification/")
CLASSIFY_DATA_DIR = os.path.normpath(CLASSIFY_DATA_DIR).replace("\\", "/")

train_path = os.path.join(ROOT_DIR, "autosplit_train.txt")
val_path = os.path.join(ROOT_DIR, "autosplit_val.txt")

# Stay empty when the split files are missing so later cells can still
# reference these names.
train_lines = []
val_lines = []

if os.path.exists(train_path) and os.path.exists(val_path):
    # One image path per entry; blank lines are dropped so they are not
    # counted as images, and entries carry no trailing newline.
    train_lines = _read_split_entries(train_path)
    val_lines = _read_split_entries(val_path)

    # Create the figure only after both files were read successfully.
    plt.figure(figsize=(10, 6))
    ax = sns.barplot(
        x=["train", "val"],
        y=[len(train_lines), len(val_lines)],
    )

    plt.title("Dataset Split Distribution")
    ax.bar_label(ax.containers[0])  # print the count on top of each bar
    plt.ylabel("Number of Images")
    plt.show()

else:
    print("Dataset split files (autosplit_train.txt or autosplit_val.txt) not found. Autosplit may have failed or not been run.")


In [None]:

def collect_object_data(root_dir, data):
    """Collect one record per labelled object for every dataset split.

    Args:
        root_dir: Dataset root; label files are looked up directly inside
            ``<root_dir>/labels`` using the image's base name with a
            ``.txt`` extension.
        data: Mapping of split name to an iterable of image paths (trailing
            newlines are tolerated). The ``"names"`` key is skipped.

    Returns:
        A list of ``{'Dataset': split_name, 'Class': class_id}`` dicts, where
        ``class_id`` is the first whitespace-separated token of a label line,
        kept as a string. Images without a label file contribute nothing.
    """
    labels_dir = os.path.join(root_dir, "labels")
    object_data = []

    for dataset_name, file_list in data.items():
        if dataset_name == "names":  # class-name table, not a split
            continue
        for entry in file_list:
            image_path = entry.strip()
            if not image_path:
                # Blank line in a split file: there is no image to look up.
                continue

            stem = os.path.splitext(os.path.basename(image_path))[0]
            label_path = os.path.normpath(os.path.join(labels_dir, stem + ".txt"))
            if not os.path.exists(label_path):
                continue

            with open(label_path, 'r') as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # Blank label line; indexing fields[0] would raise.
                        continue
                    object_data.append({
                        'Dataset': dataset_name,
                        'Class': fields[0],
                    })
    return object_data

def plot_dataset_class_counts(data, root_dir):
    """Plot a grouped bar chart of object counts per dataset split and class.

    Args:
        data: Mapping with a ``"names"`` entry (class id -> class name, given
            either as a dict or as a list indexed by id) plus one entry per
            split holding that split's image paths.
        root_dir: Dataset root, passed through to ``collect_object_data``.

    Prints a message and returns without plotting when no labelled objects
    were found.
    """
    object_data = collect_object_data(root_dir, data)
    if not object_data:
        # An empty DataFrame has no 'Dataset'/'Class' columns, so countplot
        # would raise instead of drawing an empty chart.
        print("No labelled objects found; skipping class count plot.")
        return

    names = data["names"]
    id_name_pairs = names.items() if isinstance(names, dict) else enumerate(names)
    # Class ids are read from the label files as text, so key the lookup by str(id).
    class_names = {str(class_id): name for class_id, name in id_name_pairs}

    df = pd.DataFrame(object_data)
    # Translate every id that has a name; unknown ids are shown unchanged.
    # Only the 'Class' column is touched, never the split names.
    df['Class'] = df['Class'].map(lambda class_id: class_names.get(class_id, class_id))

    plt.figure()
    ax = sns.countplot(x='Dataset', hue='Class', data=df)

    for p in ax.patches:
        height = p.get_height()
        if height > 0:  # Avoid displaying 0 values
            ax.text(p.get_x() + p.get_width() / 2., height, f'{int(height)}',
                    ha='center', va='bottom')

    plt.title('Object Counts per Dataset, Grouped by Class')
    plt.show()

# Chart the per-class object counts for the train/val entries read from the
# autosplit files (requires data_config, train_lines, val_lines and ROOT_DIR).
plot_dataset_class_counts(
    data=dict(names=data_config["names"], train=train_lines, val=val_lines),
    root_dir=ROOT_DIR,
)

In [None]:
# Overrides handed to the Ultralytics DetectionTrainer.
detection_training_params = {
    # --- What to train ---
    "task": "detect",
    "mode": "train",
    "model": "yolo11n.pt",
    # Reuse the constant from the config cell so the trainer and the dataset
    # plots can never point at different YAML files.
    "data": DATA_CONFIG_FILE,

    # --- Run length, hardware and output ---
    "epochs": 300,
    "time": None,
    "patience": 20,
    "batch": 16,
    "imgsz": 640,
    "save": True,
    "save_period": -1,
    "cache": False,
    "device": 0,
    "workers": 8,
    "project": None,
    "name": None,
    "exist_ok": False,
    "pretrained": True,
    "optimizer": "auto",
    "verbose": True,
    "seed": 0,
    "deterministic": True,
    # NOTE(review): the dataset plot above shows several classes, yet the
    # detector is trained single-class; presumably the classifier trained
    # later separates them -- confirm this is intended.
    "single_cls": True,
    "rect": False,
    "cos_lr": False,
    "close_mosaic": 10,
    "resume": False,
    "amp": True,
    "fraction": 1.0,
    "profile": False,
    "freeze": None,
    "multi_scale": False,
    "overlap_mask": True,
    "mask_ratio": 4,
    "dropout": 0.0,

    # --- Validation / prediction settings ---
    "val": True,
    "split": "val",
    "save_json": False,
    "save_hybrid": False,
    # NOTE(review): an explicit conf here is presumably also used by the
    # validation runs during training -- confirm 0.25 is wanted for those.
    "conf": 0.25,
    "iou": 0.7,
    "max_det": 300,
    "half": False,
    "dnn": False,
    "plots": True,

    # --- Optimizer, warmup and loss gains ---
    "lr0": 0.01,
    "lrf": 0.01,
    "momentum": 0.937,
    "weight_decay": 0.0005,
    "warmup_epochs": 3.0,
    "warmup_momentum": 0.8,
    "warmup_bias_lr": 0.1,
    "box": 7.5,
    "cls": 1.0,
    "dfl": 1.5,
    "nbs": 32,

    # --- Augmentation ---
    "hsv_h": 0.015,
    "hsv_s": 0.1,
    "hsv_v": 0.1,
    "degrees": 90.0,
    "translate": 0.1,
    "scale": 0.2,
    "shear": 0.0,
    "perspective": 0.0002,
    "flipud": 0.0,
    "fliplr": 0.5,
    "bgr": 0.0,
    "mosaic": 0.0,
    "mixup": 0.0,
    "copy_paste": 0.0,
    "copy_paste_mode": "flip",
    "auto_augment": "randaugment",
    "erasing": 0.1,
    "crop_fraction": 1.0,
}

# Build the trainer from the overrides and start training.
detection_training = DetectionTrainer(overrides=detection_training_params)
detection_training.train()

In [None]:
# Classification overrides, written as themed groups. The groups are merged in
# the order listed, so the final dict holds the same keys, values and key
# order as one flat literal would.

# Which task/model to run and where the data lives.
cls_target = {
    "task": "classify",
    "mode": "train",
    "model": "yolo11n-cls",
    "data": CLASSIFY_DATA_DIR,
}

# Run length, hardware and output handling.
cls_run = {
    "epochs": 100,
    "time": None,
    "patience": 10,
    "batch": 64,
    "imgsz": 224,
    "save": True,
    "save_period": -1,
    "cache": False,
    "device": 0,
    "workers": 4,
    "project": None,
    "name": None,
    "exist_ok": False,
    "pretrained": True,
    "optimizer": "auto",
    "verbose": True,
    "seed": 0,
    "deterministic": True,
    "rect": False,
    "cos_lr": False,
    "resume": False,
    "amp": True,
    "fraction": 1.0,
    "profile": False,
    "freeze": None,
    "multi_scale": False,
    "dropout": 0.0,
}

# Validation during training.
cls_validation = {
    "val": True,
    "split": "val",
    "plots": True,
}

# Optimizer, warmup and loss settings.
cls_optimization = {
    "lr0": 0.001,
    "lrf": 0.01,
    "momentum": 0.9,
    "weight_decay": 0.0001,
    "warmup_epochs": 3.0,
    "warmup_momentum": 0.8,
    "warmup_bias_lr": 0.1,
    "cls": 1.0,
    "nbs": 64,
}

# Augmentation settings.
cls_augmentation = {
    "hsv_h": 0.005,
    "hsv_s": 0.05,
    "hsv_v": 0.05,
    "degrees": 10.0,
    "translate": 0.1,
    "scale": 0.1,
    "shear": 0.0,
    "perspective": 0.0,
    "flipud": 0.5,
    "fliplr": 0.5,
    "bgr": 0.0,
    "copy_paste": 0.0,
    "copy_paste_mode": "flip",
    "auto_augment": "randaugment",
    "erasing": 0.2,
    "crop_fraction": 1,
}

classification_training_params = {
    **cls_target,
    **cls_run,
    **cls_validation,
    **cls_optimization,
    **cls_augmentation,
}

# Show the effective overrides before the run starts.
print(classification_training_params)

classification_training = ClassificationTrainer(overrides=classification_training_params)
classification_training.train()