# YOLOv8 Model Training

This notebook trains a YOLOv8 object-detection model on the Roboflow dataset. Model evaluation is done in a separate notebook.

## 1. Imports and variable setup

In [1]:
# Import all libraries
import os
from roboflow import Roboflow
from IPython import display
import ultralytics
from ultralytics import YOLO
from IPython.display import display, Image
from pathlib import Path
from dotenv import find_dotenv, load_dotenv
import sys
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch

# Record the notebook's working directory (a plain Python variable, not an OS environment variable).
HOME = os.getcwd()
print(HOME)

# Choose the best available torch backend: CUDA first, then Apple-silicon MPS, otherwise CPU.
# Availability checks are evaluated lazily, in order, so MPS is only probed when CUDA is absent.
_device_candidates = (
    ("cuda", torch.cuda.is_available),
    ("mps", torch.backends.mps.is_available),
)
Device = torch.device(
    next((backend for backend, is_available in _device_candidates if is_available()), "cpu")
)
print(Device)

/Users/mreagles524/Documents/gitrepos/projects/DBD-Killer-AI/notebooks
mps


In [2]:
# Make the repository root importable so the local `dbdkillerai` package can be imported
# later in this notebook. The membership guard keeps re-runs of this cell from appending
# the same entry to sys.path again and again.
_repo_root = str(Path.cwd().parent)
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# Print Ultralytics / Python / torch versions and hardware info to confirm the install works.
ultralytics.checks()

Ultralytics 8.3.53 🚀 Python-3.11.8 torch-2.5.1 CPU (Apple M1 Pro)
Setup complete ✅ (8 CPUs, 16.0 GB RAM, 261.4/460.4 GB disk)


In [3]:
# Project configuration (plain Python constants, not environment variables):
# the repository root — the notebook runs from <root>/notebooks — and the
# Roboflow dataset version to train on.
PROJECT_DIR = Path(os.getcwd()).parent
DATASET_VERSION = 9

## 2. Load Pre-trained Model

Fine-tuning from the pre-trained `yolov8n.pt` weights is preferred over training from scratch.

In [9]:
# Load the pre-trained yolov8n.pt weights stored under <project>/models as the starting point.
# Build the path with pathlib instead of string concatenation; YOLO receives it as a str.
BASE_WEIGHTS = PROJECT_DIR / "models" / "yolov8n.pt"
model = YOLO(model=str(BASE_WEIGHTS))

## 3. Train Model

In [10]:
# Project helpers; importable because the repository root was appended to sys.path earlier.
from dbdkillerai.data.make_dataset import roboflow_connect, roboflow_download
from dbdkillerai.models.train_model import train_yolo, validate_yolo

# Connect to Roboflow and fetch the requested dataset version in YOLOv8 format.
# NOTE(review): with overwrite=True the dataset IS downloaded and extracted into
# data/external on every run (see the captured output). Presumably overwrite=False
# would reuse an existing copy — confirm against roboflow_download before relying on it.
rf_conn, rf_project = roboflow_connect()
data_location = roboflow_download(rf_project=rf_project,
                                  rf_data_version=DATASET_VERSION,
                                  data_format="yolov8",
                                  project_dir=str(PROJECT_DIR / "data" / "external"),
                                  overwrite=True)

# Path to the dataset's data.yaml, which is passed to YOLO training/validation below.
yml_location = str(Path(data_location.location) / "data.yaml")
print(f"Pulling yml from: \n{yml_location}")

loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in /Users/mreagles524/Documents/gitrepos/projects/DBD-Killer-AI/data/external/deadbydaylightkillerai/killer_ai_object_detection/9 to yolov8:: 100%|██████████| 116616/116616 [00:01<00:00, 59596.85it/s]





Extracting Dataset Version Zip to /Users/mreagles524/Documents/gitrepos/projects/DBD-Killer-AI/data/external/deadbydaylightkillerai/killer_ai_object_detection/9 in yolov8:: 100%|██████████| 5714/5714 [00:00<00:00, 8089.19it/s]

Pulling yml from: 
/Users/mreagles524/Documents/gitrepos/projects/DBD-Killer-AI/data/external/deadbydaylightkillerai/killer_ai_object_detection/9/data.yaml





In [None]:
# Train the model on the device detected in the first cell instead of a hard-coded "mps",
# so this cell also runs on CUDA and CPU-only machines.
# Ultralytics takes a GPU index (0) for CUDA and the strings "mps" / "cpu" otherwise.
# NOTE(review): assumes train_yolo forwards `device` to Ultralytics' model.train — confirm.
train_device = 0 if Device.type == "cuda" else Device.type
model, results_train = train_yolo(yolo_model=model,
                                  data_yml=yml_location,
                                  epochs=150,
                                  imgsz=800,
                                  plots=True,
                                  workers=0,
                                  device=train_device)

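## 4. Validate Custom Model

> **Note:** the saved output below comes from an earlier run (Ultralytics 8.0.227 on a CUDA GPU, dataset version 7), not from the training cell above. Re-run this cell after training to refresh it.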

In [None]:
# Reload the best checkpoint written by the training run and validate it against the dataset.
# Kept as a str so any later use of best_weights_location behaves as before.
best_weights_location = str(Path(results_train.save_dir) / "weights" / "best.pt")
# Fail fast with a clear message if training did not produce best.pt,
# rather than letting YOLO fail on a missing file in a less obvious way.
if not Path(best_weights_location).is_file():
    raise FileNotFoundError(
        f"No trained weights found at {best_weights_location}; did the training cell finish?"
    )
best_model = YOLO(best_weights_location)
results_val = validate_yolo(yolo_model=best_model,
                            data_yml=yml_location)


Ultralytics YOLOv8.0.227 🚀 Python-3.10.13 torch-2.1.2 CUDA:0 (NVIDIA GeForce RTX 3080, 10240MiB)
Model summary (fused): 168 layers, 3006428 parameters, 0 gradients


[34m[1mval: [0mScanning /home/mreag/repos/DBD-Killer-AI/data/external/deadbydaylightkillerai/killer_ai_object_detection/7/valid/labels.cache... 238 images, 80 backgrounds, 0 corrupt: 100%|██████████| 238/238 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 15/15 [00:02<00:00,  6.26it/s]


                   all        238        343      0.806      0.713      0.765       0.52
              activity        238         12      0.815      0.733      0.777      0.501
             generator        238        189      0.862      0.884      0.904      0.656
                  hook        238        120      0.855      0.734      0.791      0.526
              survivor        238         22      0.694        0.5      0.588      0.396
Speed: 0.4ms preprocess, 2.8ms inference, 0.0ms loss, 1.0ms postprocess per image
Results saved to [1mruns/detect/val5[0m
