# Model Fine-Tuning and Export

## Preamble

### Imports

In [None]:
import os
import shutil
import zipfile
from collections import Counter
from typing import Optional

import albumentations as A
import cv2
import fiftyone as fo
import matplotlib.pyplot as plt
import pymongo
import yaml
from roboflow import Roboflow
from roboflow.core.dataset import Dataset as RoboflowDataset
from ultralytics import YOLO

### General Configuration

In [None]:
# Folder layout: everything is resolved relative to the current working directory.
ROOT_PATH = os.path.abspath(os.getcwd())
IMAGES_PATH = os.path.join(ROOT_PATH, "images")

# Export format requested when downloading Roboflow datasets.
RF_DATASET_FORMAT = "yolov8"

# Export format used when exporting FiftyOne datasets.
# YOLOv5 and YOLOv8 use the same format.
FO_DATASET_FORMAT = fo.types.YOLOv5Dataset

FiftyOne uses MongoDB to manage its datasets. When possible, FiftyOne sets up the database for you automatically. When it cannot, you must set up a MongoDB database yourself. The code below checks whether FiftyOne can reach a database — if not, you are asked for the connection string of your own instance. After installing MongoDB, run `mongod --dbpath <DBPATH>`, replacing `<DBPATH>` with any path of your choice. With the default setup (no authentication, default port), the connection string is `mongodb://localhost:27017`.

In [None]:
# Keep asking for a connection string until the database config can be fetched.
while True:
    print("Trying to reach MongoDB...")
    try:
        fo.core.odm.database.get_db_config()
    except (fo.core.config.FiftyOneConfigError, pymongo.errors.ServerSelectionTimeoutError):
        # Point FiftyOne at the user-supplied URI and retry on the next iteration.
        print("Failed to reach a running MongoDB instance. Enter a valid MongoDB connection string:")
        db_uri = input()
        fo.config.database_uri = db_uri
    else:
        print("MongoDB is reachable.")
        break

## Datasets

Helper function to `gitignore` a directory:

In [None]:
def gitignore(directory: str):
    """
    Tell Git to ignore everything inside `directory`.

    The path is used exactly as given (nothing is prepended to it) and must
    point to an existing directory; otherwise a `ValueError` is raised.

    A `.gitignore` file holding the single wildcard pattern "*" is written
    into the directory, replacing any `.gitignore` already there.
    """
    if os.path.isdir(directory):
        with open(os.path.join(directory, ".gitignore"), "w") as ignore_file:
            ignore_file.write("*")
        return

    raise ValueError("The given path does not exist or is not a directory.")

We'll create the `IMAGES_PATH` directory early to make `git` ignore it:

In [None]:
# Create the images directory on first run; leave an existing one untouched.
if os.path.exists(IMAGES_PATH):
    print(f"'{IMAGES_PATH}' exists — nothing to do.")
else:
    os.makedirs(IMAGES_PATH)
    print(f"Created '{IMAGES_PATH}' directory.")

In [None]:
# Only write the ignore file when the images directory does not already have one.
images_gitignore = os.path.join(IMAGES_PATH, ".gitignore")
if os.path.exists(images_gitignore):
    print(f"'{IMAGES_PATH}/.gitignore' exists — skipping.")
else:
    gitignore(IMAGES_PATH)
    print(f"Gitignored '{IMAGES_PATH}'.")

### Roboflow Datasets

To download datasets from Roboflow, you must have a Roboflow API key. This notebook will attempt to load the API key from the `ROBOFLOW_API_KEY` environment variable. If the variable does not exist, then you will be prompted for it.

In [None]:
# Imported here so this cell is self-contained.
import getpass

# Prefer the key from the environment; fall back to an interactive prompt.
rf_api_key = os.environ.get("ROBOFLOW_API_KEY")
if rf_api_key is None:
    print("Could not find Roboflow API key from environment.")
    # getpass does not echo what is typed, so the secret does not end up in the
    # notebook's cell output the way it does with input().
    rf_api_key = getpass.getpass("Please enter your Roboflow API key: ")

# Roboflow client used by the dataset download helper.
rf = Roboflow(api_key=rf_api_key)

In [None]:
def download_roboflow_dataset(workspace: str, project: str, version: str, directory: str, dataset_format=RF_DATASET_FORMAT):
    """
    Fetch a Roboflow dataset version into `IMAGES_PATH/<directory>`.

    The project version is always looked up through the Roboflow client `rf`.
    If the target directory does not exist yet, the dataset is downloaded into
    it in `dataset_format`. If it already exists, nothing is downloaded and a
    Roboflow `Dataset` object describing the existing directory is built instead.

    Returns the Roboflow `Dataset` object in both cases.
    """
    target_dir = os.path.join(IMAGES_PATH, directory)
    rf_version = rf.workspace(workspace).project(project).version(version)

    if not os.path.exists(target_dir):
        dataset = rf_version.download(dataset_format, location=target_dir)
        print(f"Dataset downloaded to: {target_dir}")
        return dataset

    # Never overwrite an existing download; describe what is already on disk.
    print(f"Path '{target_dir}' exists — refusing to overwrite.")
    print("If you want to redownload the dataset, please manually remove the directory.")
    return RoboflowDataset(rf_version.name, rf_version.version, dataset_format, target_dir)

In [None]:
# Weapon classification dataset, version "2", stored under IMAGES_PATH/guns.
gun_ds = download_roboflow_dataset(
    workspace="liteye-systems",
    project="weapon-classification",
    version="2",
    directory="guns",
)

### COCO Dataset

In [None]:
def download_coco2017(
    categories: Optional[list[str]] = ["person"],
    max_samples: Optional[int] = None,
    directory: str = "coco-2017",
    dataset_format=FO_DATASET_FORMAT,
    **kwargs,
):
    """
    Load COCO 2017 through the FiftyOne zoo and export it under `IMAGES_PATH/<directory>`.

    The "train", "test" and "validation" splits are requested with detection
    labels. `max_samples` and any extra keyword arguments are forwarded to
    `fo.zoo.load_zoo_dataset`. Samples tagged "validation" are re-tagged "val"
    before exporting, so the exported splits are "train", "test" and "val".

    When `categories` is not None, only "ground_truth" labels whose label is in
    `categories` are kept. Pass `dataset_format` to export in a format other
    than `FO_DATASET_FORMAT`.
    """
    zoo_splits = ["train", "test", "validation"]

    dataset = fo.zoo.load_zoo_dataset(
        "coco-2017",
        splits=zoo_splits,
        label_types=["detections"],
        max_samples=max_samples,
        **kwargs,
    )

    # Re-tag the 'validation' samples as 'val'.
    validation_view = dataset.match_tags("validation")
    validation_view.tag_samples("val")
    validation_view.untag_samples("validation")

    # Split names to export: same order as requested, with 'val' replacing 'validation'.
    export_splits = [name for name in zoo_splits if name != "validation"] + ["val"]

    ds_view = dataset.view()

    # Filter to the requested categories by hand
    # due to a bug: https://github.com/voxel51/fiftyone/issues/4570
    # Workaround based on: https://github.com/voxel51/fiftyone/issues/4570#issuecomment-2392548410
    # The workaround downloads images we don't need and filters them afterwards,
    # which wastes some disk space and network bandwidth.
    # NOTE(review): confirm the zoo's "test" split carries ground-truth detections;
    # if it does not, this filter presumably leaves that split empty.
    if categories is not None:
        label_matches = fo.ViewField("label").is_in(categories)
        ds_view = ds_view.filter_labels("ground_truth", label_matches)

    # Export in YOLOv8 format.
    # According to https://github.com/voxel51/fiftyone/issues/3392#issuecomment-1666520356,
    # splits must be exported separately.
    export_dir = os.path.join(IMAGES_PATH, directory)
    for split in export_splits:
        ds_view.match_tags(split).export(
            export_dir=export_dir,
            dataset_type=dataset_format,
            split=split,
            classes=categories,
        )
        print(f"Split '{split}' exported to '{export_dir}/{split}'")

In [None]:
# Download and export COCO 2017 with the default arguments, passing max_samples=1000.
download_coco2017(max_samples=1000)

## Unzip dataset zip file (Not used anymore)

In [None]:
def unzip_data(zip_file, folder_path):
    """
    Extract every member of a zip archive into a destination folder.

    Args:
        zip_file: Path of the zip archive to read.
        folder_path: Directory to extract into; created (including parents)
            when it does not exist yet.

    Raises:
        FileNotFoundError: If `zip_file` does not exist.
        zipfile.BadZipFile: If `zip_file` is not a valid zip archive.
    """
    # Open the archive first so a bad path does not leave an empty
    # destination folder behind.
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        # exist_ok avoids the race of a separate exists-then-create check.
        os.makedirs(folder_path, exist_ok=True)
        zip_ref.extractall(folder_path)

    print(f"{zip_file} unzip to {folder_path}")

## Combined datasets

In [None]:
# Destination folder that the individual datasets are merged into.
# NOTE(review): DATAFOLDER is not assigned anywhere in the visible part of this
# notebook (only ROOT_PATH and IMAGES_PATH are) — confirm it is defined before
# this cell runs, otherwise the next line raises NameError.
COMBINED_FOLDER = os.path.join(DATAFOLDER, 'combined-images')

# Create the combined folder on first run.
if not os.path.exists(COMBINED_FOLDER):
    os.makedirs(COMBINED_FOLDER)

def combine_and_rename(src_folder, dataset_name, class_offset):
    """
    Copy a YOLO-format dataset into `COMBINED_FOLDER`, renaming files and shifting class IDs.

    For each of the 'train', 'valid' and 'test' splits, the files in
    `<src_folder>/<split>/images` are copied (in sorted order) to
    `COMBINED_FOLDER/<split>/images` as `<dataset_name>-img-<n>.jpg`, with `n`
    counting from 1. The new name always ends in `.jpg`, whatever the source
    extension was. A matching `<dataset_name>-img-<n>.txt` label file is written
    to `COMBINED_FOLDER/<split>/labels` with `class_offset` added to the class ID
    at the start of every line.

    Images and labels are paired by file name without extension. An image with
    no label file gets an empty label file. Blank label lines are dropped. A
    split whose images folder does not exist is skipped.

    Args:
        src_folder: Root folder of the source dataset.
        dataset_name: Prefix used for the renamed files.
        class_offset: Integer added to every class ID in the label files.
    """
    for split in ['train', 'valid', 'test']:
        img_src_folder = os.path.join(src_folder, split, 'images')
        lbl_src_folder = os.path.join(src_folder, split, 'labels')

        # Not every dataset ships all three splits.
        if not os.path.isdir(img_src_folder):
            print(f"Split '{split}' not found in '{src_folder}' — skipping.")
            continue

        img_dest_folder = os.path.join(COMBINED_FOLDER, split, 'images')
        lbl_dest_folder = os.path.join(COMBINED_FOLDER, split, 'labels')
        os.makedirs(img_dest_folder, exist_ok=True)
        os.makedirs(lbl_dest_folder, exist_ok=True)

        # Index label files by stem so each image gets its own label; pairing
        # two sorted listings by position breaks once any label is missing.
        labels_by_stem = {}
        if os.path.isdir(lbl_src_folder):
            for lbl_name in os.listdir(lbl_src_folder):
                labels_by_stem[os.path.splitext(lbl_name)[0]] = lbl_name

        for i, img_file in enumerate(sorted(os.listdir(img_src_folder))):
            new_stem = f"{dataset_name}-img-{i+1}"
            img_src_path = os.path.join(img_src_folder, img_file)
            img_dest_path = os.path.join(img_dest_folder, f"{new_stem}.jpg")
            lbl_dest_path = os.path.join(lbl_dest_folder, f"{new_stem}.txt")

            # Read the source label lines, if the image has a label file at all.
            lines = []
            src_lbl_name = labels_by_stem.get(os.path.splitext(img_file)[0])
            if src_lbl_name is not None:
                with open(os.path.join(lbl_src_folder, src_lbl_name), 'r') as src_lbl:
                    lines = src_lbl.readlines()

            # Rewrite the labels with shifted class IDs.
            with open(lbl_dest_path, 'w') as new_lbl_file:
                for line in lines:
                    parts = line.split()
                    if not parts:
                        continue  # blank line: nothing to rewrite
                    class_id = int(parts[0]) + class_offset
                    new_lbl_file.write(" ".join([str(class_id), *parts[1:]]) + "\n")

            # Copy the image file
            shutil.copy(img_src_path, img_dest_path)

            # NOTE: augmentation of under-represented classes (via `augment_image`
            # and a per-class occurrence threshold) is currently disabled here.

## Create yaml file for combined dataset

In [None]:
# Image folders of the combined dataset, one per split, plus the YAML file describing it.
TRAIN_PATH, VAL_PATH, TEST_PATH = (
    os.path.join(COMBINED_FOLDER, split_name, 'images')
    for split_name in ('train', 'valid', 'test')
)
OUTPUT_PATH = os.path.join(COMBINED_FOLDER, 'data.yaml')


def load_yaml(yaml_path):
    """Parse the YAML file at `yaml_path` with `yaml.safe_load` and return the result."""
    with open(yaml_path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed

def combine_yaml_new(yaml_files, class_offsets):
    """
    Merge the class names of several dataset YAML files into one `data.yaml`.

    Each file's `names` entry is placed into a shared list starting at that
    file's class offset, so class `j` of file `i` ends up at index
    `class_offsets[i] + j`. `names` may be a list (position = class ID) or a
    mapping of class ID to name. Later files overwrite earlier ones where
    their index ranges overlap; indices no file fills are named 'unknown'.

    The result (train/val/test image paths, `names`, and `nc`) is written to
    `OUTPUT_PATH`, replacing any existing file.

    Args:
        yaml_files: Paths of the per-dataset YAML files.
        class_offsets: One class ID offset per entry of `yaml_files`.
    """
    combined_names = []

    for i, yaml_file in enumerate(yaml_files):
        data = load_yaml(yaml_file)
        offset = class_offsets[i]  # offset for the current dataset

        names = data['names']
        # Normalize both supported shapes to (class ID, name) pairs; iterating
        # a mapping directly would yield its integer keys instead of the names.
        if isinstance(names, dict):
            indexed_names = [(int(idx), name) for idx, name in names.items()]
        else:
            indexed_names = list(enumerate(names))

        # Grow the combined list so every target index exists.
        span = max(idx for idx, _ in indexed_names) + 1 if indexed_names else 0
        missing = offset + span - len(combined_names)
        if missing > 0:
            combined_names.extend([None] * missing)

        for idx, class_name in indexed_names:
            combined_names[offset + idx] = class_name

    combined_data = {
        'train': TRAIN_PATH,
        'val': VAL_PATH,
        'test': TEST_PATH,
        # Gaps left between offset ranges get a placeholder name.
        'names': [name if name is not None else 'unknown' for name in combined_names],
    }
    combined_data['nc'] = len(combined_data['names'])  # number of class slots

    # Write combined data to a new YAML file
    with open(OUTPUT_PATH, 'w') as yaml_f:
        yaml.dump(combined_data, yaml_f)

    print(f"Combined YAML file created at {OUTPUT_PATH}")

## Check for Data Imbalance 

In [None]:
def count_class_occurrences(folder_path):
    """
    Count how often each class ID appears in a YOLO-format dataset's label files.

    Every file in `<folder_path>/<split>/labels` is read for the 'train',
    'valid' and 'test' splits; the first whitespace-separated field of each
    non-blank line is taken as the integer class ID. Splits without a labels
    folder are skipped.

    Args:
        folder_path: Root folder of the dataset.

    Returns:
        A `Counter` mapping class ID to its number of label lines. The counter
        is also printed.
    """
    class_counter = Counter()

    for split in ['train', 'valid', 'test']:
        lbl_folder = os.path.join(folder_path, split, 'labels')
        if not os.path.isdir(lbl_folder):
            continue  # not every dataset ships all three splits

        for file_name in os.listdir(lbl_folder):
            with open(os.path.join(lbl_folder, file_name), 'r') as f:
                for line in f:
                    parts = line.split()
                    if parts:  # ignore blank lines
                        class_counter[int(parts[0])] += 1

    print(f"Class Occurrences: {class_counter}")
    return class_counter

## Handle Imbalanced Data (Does not work right now)

In [None]:
# Define the augmentation pipeline
# `augment_image` applies each entry on its own, producing one augmented copy
# per transform. Every transform is constructed with p=1.0.
AUGMENTATIONS = [
    A.HorizontalFlip(p=1.0),
    A.Rotate(limit=20, p=1.0),
    A.RandomBrightnessContrast(p=1.0),
    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=20, p=1.0)
]

def augment_image(image_path, label_path, aug_folder, index, dataset_name):
    """
    Write one augmented copy of an image, with matching labels, per entry of `AUGMENTATIONS`.

    Each transform is applied on its own to the image and to its bounding
    boxes, so geometric transforms (flip, rotate, shift/scale) keep the labels
    aligned with the augmented pixels. Outputs are written to `aug_folder` as
    `<dataset_name>-img-<index+1>-aug-<k>.jpg` and `.txt`, with `k` counting
    from 1.

    Args:
        image_path: Path of the source image.
        label_path: Path of the source label file; each non-blank line must be
            `class x_center y_center width height` (five fields).
        aug_folder: Folder receiving both the augmented images and labels.
        index: Zero-based image index used in the output file names.
        dataset_name: Prefix used in the output file names.

    Raises:
        ValueError: If the image cannot be read, or a label line does not have
            exactly five fields.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image: {image_path}")

    # Read bounding box labels
    class_labels = []
    bboxes = []
    with open(label_path, 'r') as file:
        for line in file:
            parts = line.split()
            if not parts:
                continue
            if len(parts) != 5:
                raise ValueError(
                    f"Unsupported label line in {label_path}: expected 5 fields, got {len(parts)}"
                )
            class_labels.append(parts[0])
            bboxes.append([float(value) for value in parts[1:]])

    bbox_params = A.BboxParams(format='yolo', label_fields=['class_labels'])

    for i, aug in enumerate(AUGMENTATIONS):
        # Wrap the single transform so the boxes are transformed with the image.
        pipeline = A.Compose([aug], bbox_params=bbox_params)
        augmented = pipeline(image=image, bboxes=bboxes, class_labels=class_labels)

        # Save the augmented image with a new name
        aug_img_name = f"{dataset_name}-img-{index+1}-aug-{i+1}.jpg"
        aug_img_path = os.path.join(aug_folder, aug_img_name)
        cv2.imwrite(aug_img_path, augmented['image'])

        # Save the corresponding augmented label file
        aug_lbl_name = f"{dataset_name}-img-{index+1}-aug-{i+1}.txt"
        aug_lbl_path = os.path.join(aug_folder, aug_lbl_name)
        with open(aug_lbl_path, 'w') as aug_label_file:
            for cls, box in zip(augmented['class_labels'], augmented['bboxes']):
                x_center, y_center, width, height = box[:4]
                aug_label_file.write(
                    f"{cls} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
                )

    print(f"Augmented image and label saved to {aug_folder}")

## Fine Tune Model

In [None]:
# Build the YOLO model from the 'yolov8n.pt' weights; this is the model later passed to fine_tune.
model = YOLO('yolov8n.pt')

In [None]:
def fine_tune(model, yaml_path, epochs=5, imgsz=640, batch=16, device=None):
    """
    Train `model` on the dataset described by `yaml_path` and return it.

    `yaml_path`, `epochs`, `imgsz` and `batch` are always forwarded to
    `model.train` as `data`, `epochs`, `imgsz` and `batch`. `device` is
    forwarded only when it is not None. The default of 5 epochs is meant for
    quick test runs.
    """
    train_kwargs = dict(data=yaml_path, epochs=epochs, imgsz=imgsz, batch=batch)

    # Leave 'device' out entirely unless the caller picked one.
    if device is not None:
        train_kwargs.update(device=device)

    model.train(**train_kwargs)
    return model

In [None]:
def save_model(model, path='yolo_fine_tuned.pt'):
    """
    Save `model` by calling `model.save(path)`.

    Args:
        model: Object exposing a `save(path)` method.
        path: Destination file; defaults to 'yolo_fine_tuned.pt'.
    """
    model.save(path)

## Combining Datasets (Will create a new section for this)

In [None]:
# NOTE(review): DATAFOLDER is not assigned anywhere in the visible part of this
# notebook — confirm it is defined before running this cell.

# dataset source folders: each archive is extracted into its own folder
parcel_folder = os.path.join(DATAFOLDER, "parcel-images")
gun_folder = os.path.join(DATAFOLDER, "gun-images")
human_folder = os.path.join(DATAFOLDER, "human-images")

# extract the parcel, gun and human archives
unzip_data(os.path.join(DATAFOLDER, "parcel.v1i.yolov8.zip"), parcel_folder)
unzip_data(os.path.join(DATAFOLDER, "Weapon classification.v2i.yolov8.zip"), gun_folder)
unzip_data(os.path.join(DATAFOLDER, "Crowd Detection.v3i.yolov8.zip"), human_folder)

# class ID offsets, in parcel / gun / human order
class_offsets = [0, 1, 3]

# copy every dataset into the combined folder, shifting class IDs by its offset
# (class-occurrence reporting before/after combining is currently disabled)
for folder, name, offset in zip(
    [parcel_folder, gun_folder, human_folder],
    ['parcel', 'gun', 'human'],
    class_offsets,
):
    combine_and_rename(folder, name, class_offset=offset)

# yaml configuration shipped with each dataset
parcel_yaml = os.path.join(parcel_folder, "data.yaml")
gun_yaml = os.path.join(gun_folder, "data.yaml")
human_yaml = os.path.join(human_folder, "data.yaml")
yaml_list = [parcel_yaml, gun_yaml, human_yaml]

# write the merged data.yaml, then report the class distribution
combine_yaml_new(yaml_list, class_offsets)

count_class_occurrences(COMBINED_FOLDER)

## New Dataset Combination

In [None]:
# Copy the gun dataset into the combined folder with class offset 0 and
# regenerate data.yaml from the gun dataset's YAML alone.
# NOTE(review): DATAFOLDER is not assigned anywhere in the visible part of this
# notebook — confirm it is defined before running this cell.
gun_folder = os.path.join(DATAFOLDER, "gun-images")
gun_yaml = os.path.join(gun_folder, "data.yaml")

class_offsets = [0]
yaml_list = [gun_yaml]

combine_and_rename(gun_folder, 'gun', class_offset=class_offsets[0])
combine_yaml_new(yaml_list, class_offsets)

# report the class distribution of the combined folder
count_class_occurrences(COMBINED_FOLDER)

## Fine Tune Model 

In [None]:
# Fine-tune the YOLO model on the combined dataset, using the data.yaml inside
# COMBINED_FOLDER and fine_tune's default training settings.
fined_tuned_model = fine_tune(model, os.path.join(COMBINED_FOLDER, "data.yaml"))

# Alternative: fine-tune on the gun dataset's own YAML instead.
#fined_tuned_model = fine_tune(model, gun_yaml)

# Save the fine-tuned model.
save_model(fined_tuned_model)

## Load the Fine-Tuned Model

In [None]:
def load_model(model_name):
    """Construct a `YOLO` model from `model_name` and return it."""
    return YOLO(model_name)

# Reload the fine-tuned model from 'yolo_fine_tuned.pt' (the file name save_model writes to).
fined_tuned_model = load_model('yolo_fine_tuned.pt')

## Test Model 

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
def test_model(model, img_path, conf=0.25):
    """
    Run detection on one image and display the annotated result with matplotlib.

    Args:
        model: Callable model; invoked as `model(img_path, conf=conf)`.
        img_path: Path of the image to run detection on.
        conf: Confidence threshold forwarded to the model.
    """
    # Perform object detection
    results = model(img_path, conf=conf)

    # Retrieve the annotated image (with bounding boxes and labels).
    # plot() returns the image in BGR channel order while matplotlib expects
    # RGB, so reverse the channel axis to show the true colors.
    annotated_img = results[0].plot()[..., ::-1]

    # display the image
    plt.figure(figsize=(10, 10))
    plt.imshow(annotated_img)
    plt.axis('off')
    plt.show()


# Run the fine-tuned model on the sample images, in this order.
# NOTE(review): FOLDER_PATH is not assigned anywhere in the visible part of this
# notebook — confirm it is defined before running this cell.
sample_dir = os.path.join(FOLDER_PATH, 'sentinel_model_gen')
for sample_name in ('guy_w_box.png', 'guy.png', 'guy_w_gun.png', 'guy_w_gun_2.png'):
    test_model(fined_tuned_model, os.path.join(sample_dir, sample_name))

## Export the model to TFLite Format

In [None]:
def export_model(model, export_format='tflite'):
    """
    Export `model` for use in the detection system.

    Args:
        model: Object exposing an `export(format=...)` method.
        export_format: Target format passed as `format`; defaults to 'tflite'.

    Returns:
        Whatever `model.export` returns.
    """
    return model.export(format=export_format)