# Notebook for training a YOLO object-detection model

In [None]:
%matplotlib inline

from ultralytics import YOLO
from ultralytics.models.yolo.detect.train import DetectionTrainer
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import autosplit
from ultralytics.cfg import cfg2dict, copy_default_cfg
import os
import ultralytics
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Shared styling for every chart drawn in this notebook.
sns.set_theme(style="whitegrid")
plt.rcParams.update(
    {
        "figure.figsize": (10, 6),
        "axes.labelsize": 12,
        "axes.titlesize": 14,
    }
)

# Reset the Ultralytics settings object to its defaults before doing anything else.
ultralytics.settings.reset()

In [None]:
DATA_CONFIG_FILE = "data_config.yaml"
data_config = cfg2dict(DATA_CONFIG_FILE)
ROOT_DIR = data_config["path"]

# The split files are expected directly under ROOT_DIR (the parent of the
# "images" folder handed to autosplit below).
train_path = os.path.join(ROOT_DIR, "autosplit_train.txt")
val_path = os.path.join(ROOT_DIR, "autosplit_val.txt")

# Only generate the 80/20 train/val split once so re-running the cell keeps it stable.
if not os.path.exists(train_path):
    autosplit(os.path.join(ROOT_DIR, "images"), (0.8, 0.2, 0))

train_lines = []
val_lines = []

if os.path.exists(train_path) and os.path.exists(val_path):
    plt.figure(figsize=(10, 6))
    # Each non-blank line of a split file names one image; blank lines are ignored
    # so they are not counted as images.
    with open(train_path, "r", encoding="utf-8") as f:
        train_lines = [line for line in f if line.strip()]
    with open(val_path, "r", encoding="utf-8") as f:
        val_lines = [line for line in f if line.strip()]

    ax = sns.barplot(
        x=["train", "val"],
        y=[len(train_lines), len(val_lines)],
    )

    plt.title("Dataset Split Distribution")
    ax.bar_label(ax.containers[0])
    plt.ylabel("Number of Images")
    plt.show()

else:
    print("Dataset split files (autosplit_train.txt or autosplit_val.txt) not found. Autosplit may have failed or not been run.")


In [None]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def collect_object_data(root_dir, data):
    """Collect one record per labelled object from YOLO label files.

    For every split in ``data`` (every key except ``"names"``), each entry is
    treated as an image path. Its label file is looked up as
    ``<root_dir>/labels/<image stem>.txt``, and every non-empty line of that
    file contributes one record.

    NOTE(review): only the image's base name is used, so a labels folder that
    mirrors nested image subdirectories is not supported; confirm the dataset
    uses a flat ``labels`` directory.

    Args:
        root_dir: Dataset root directory containing a ``labels`` folder.
        data: Mapping of split name -> iterable of image paths (for example,
            lines read from ``autosplit_*.txt``; surrounding whitespace and
            trailing newlines are allowed). A ``"names"`` key is ignored.

    Returns:
        A list of ``{'Dataset': split_name, 'Class': class_id}`` dicts, where
        ``class_id`` is the first whitespace-separated token of a label line
        (as a string). Images without a label file are skipped.
    """
    object_data = []
    labels_dir = os.path.join(root_dir, "labels")

    for dataset_name, file_list in data.items():
        if dataset_name == "names":  # class-name metadata, not a split
            continue
        for entry in file_list:
            entry = entry.strip()
            if not entry:  # blank line in the split file
                continue
            stem = os.path.splitext(os.path.basename(entry))[0]
            label_path = os.path.normpath(os.path.join(labels_dir, stem + ".txt"))
            if not os.path.exists(label_path):
                continue
            with open(label_path, "r", encoding="utf-8") as f:
                for line in f:
                    fields = line.split()
                    if not fields:  # blank or trailing lines would otherwise raise IndexError
                        continue
                    object_data.append({
                        'Dataset': dataset_name,
                        'Class': fields[0],
                    })
    return object_data

def plot_dataset_class_counts(data, root_dir):
    """Plots a bar chart of class counts per dataset (with hue)."""
    
    object_data = collect_object_data(root_dir, data)
    df = pd.DataFrame(object_data).replace("1", data["names"][1]).replace("2", data["names"][2])
    
    plt.figure()
    ax = sns.countplot(x='Dataset', hue='Class', data=df)
    
    for p in ax.patches:
        height = p.get_height()
        if height > 0: # Avoid displaying 0 values
            ax.text(p.get_x() + p.get_width() / 2., height, f'{int(height)}', 
                    ha='center', va='bottom')
    
    plt.title('Object Counts per Dataset, Grouped by Class')
    plt.show()

# Plot per-split object counts for the train/val split files read in the previous cell.
split_data = {
    "names": data_config["names"],
    "train": train_lines,
    "val": val_lines,
}
plot_dataset_class_counts(split_data, ROOT_DIR)

In [None]:
# Build trainer overrides from a local, editable copy of the Ultralytics default config.
if not os.path.exists("config.yaml"):
    copy_default_cfg()  # produces default_copy.yaml, renamed below
    os.rename("default_copy.yaml", "config.yaml")
trainer_args = cfg2dict("config.yaml")
# If config.yaml does not name a dataset, use the data config loaded earlier
# in this notebook rather than starting the trainer without a dataset.
if not trainer_args.get("data"):
    trainer_args["data"] = DATA_CONFIG_FILE
trainer = DetectionTrainer(overrides=trainer_args)

In [None]:
# Start training with the DetectionTrainer built from config.yaml in the previous cell.
trainer.train()