<img src="https://www.luxonis.com/logo.svg" width="400">

# Training a Luxonis-Train Model Using a **Custom Data Loader**

## 🌟 Overview
This tutorial focuses on using a **custom data loader**.  
You'll learn how to create and integrate a custom loader to train a model with `luxonis-train`, which is useful when your dataset doesn't follow a standard format or requires special preprocessing.

We’ll cover:
- Creating your own **custom Loader**
- Configuring the training pipeline
- Running training with `luxonis-train` using your Loader

## 📜 Table of Contents
- [🛠️ Installation](#️installation)
- [🗃️ Data Preparation](#data-preparation)  
  - [📥 Download VOCDetection Dataset](#download-coco-people-subset-dataset)  
  - [🎨 Creating a **Custom Loader**](#creating-a-custom-loader)
  - [🧐 Inspecting Dataset using a **Custom Loader**](#inspecting-dataset-using-a-custom-loader)
- [🏋️‍♂️ Training](#training)  
  - [⚙️ Configuration](#configuration)  
  - [🦾 Train](#train)


<a name="️installation"></a>

## 🛠️ Installation


This tutorial is built around [`LuxonisTrain`](https://github.com/luxonis/luxonis-train), a user-friendly tool designed to streamline the training of deep learning models, especially for edge devices. We'll also use [`LuxonisML`](https://github.com/luxonis/luxonis-ml), which provides a collection of utility functions and `LuxonisDataset`, an easy way to create and manage computer vision datasets.



In [None]:
# Quote the requirement so the shell does not treat ">=" as an output redirect
%pip install -q "luxonis-train>=0.3.6"

<a name ="data-preparation"></a>

## 🗃️ Data Preparation

<a name ="download-coco-people-subset-dataset"></a>

### 📥 Download VOCDetection dataset

First, we download and extract the Pascal VOC 2007 `trainval` archive, the detection dataset used throughout this tutorial.


In [None]:
import urllib.request
import tarfile
import os

url = "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar"
output_dir = "./data"
output_path = os.path.join(output_dir, "VOCtrainval_06-Nov-2007.tar")

os.makedirs(output_dir, exist_ok=True)
urllib.request.urlretrieve(url, output_path)

# Extract the tar file
with tarfile.open(output_path) as tar:
    tar.extractall(path=output_dir)
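
Before writing the loader, it can help to confirm that the archive extracted into the layout the loader will read. The following is a minimal sketch, assuming the usual `VOCdevkit/VOC2007` folder structure; `read_image_ids` is a hypothetical helper and not part of `luxonis-train`.

```python
from pathlib import Path


def read_image_ids(root: str, year: str, view: str) -> list[str]:
    """Return the image IDs listed in ImageSets/Main/<view>.txt."""
    split_file = Path(root) / f"VOC{year}" / "ImageSets" / "Main" / f"{view}.txt"
    if not split_file.is_file():
        raise FileNotFoundError(f"Missing split file: {split_file}")
    with split_file.open() as f:
        # Each non-empty line starts with the image ID
        return [line.split()[0] for line in f if line.strip()]


# Only runs the check when the dataset has already been extracted
if Path("./data/VOCdevkit").is_dir():
    for view in ("train", "val"):
        print(view, len(read_image_ids("./data/VOCdevkit", "2007", view)))
```

If a split file is missing, the function raises an error that names the expected path, which is usually enough to spot a wrong `root`.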


<a name ="creating-a-custom-loader"></a>

### 🎨 Creating a Custom Loader

🚀 **Implementing a Data Loader using `BaseLoaderTorch`**

When creating a custom data loader class by extending [`BaseLoaderTorch`](https://github.com/luxonis/luxonis-train/blob/d8a4e5b090a4806ce12f4d21ffa1cd5a41ad48dc/luxonis_train/loaders/base_loader.py), there are specific methods and properties that you **must implement** for the class to work within Luxonis-Train. Here's a concise overview of the requirements, using `VOCLoaderTorch` as an example:

### Required Methods to Implement:

- `__init__`:
  - Must call `super().__init__(**kwargs)` to initialize the base class properly.

- `__len__(self) -> int`:
  - Must return the total number of samples available in the dataset.

- `get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]`:
  - Responsible for fetching and preprocessing a single sample (image and labels) from the dataset.
  - Includes reading images, parsing annotations, applying resizing, augmentations, and returning formatted tensors.
  - **Important**: The returned labels must follow the required `Labels` format expected by Luxonis-Train. For example:

    ```python
    labels: Labels = {
        "/boundingbox": torch.tensor(bboxes, dtype=torch.float32),  # [N,5] where 5 is [cls_id, xmin, ymin, xmax, ymax], with coordinates normalized to [0,1]
        "/classification": one_hot,  # [num_classes]
    }

    return img_tensor, labels
    ```

  For more supported annotation formats, refer to the documentation here: [Luxonis-Train Loaders](https://github.com/luxonis/luxonis-train/tree/main/luxonis_train/loaders#loaders)


- `get_classes(self) -> dict[str, dict[str, int]]`:
  - Must return a dictionary structured by **task names**.
  - The default task name key must be `""`.
  - Each entry should map class names to integer IDs. For example:

    ```python
    def get_classes(self) -> dict[str, dict[str, int]]:
        voc_classes = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]
        return {'': {name: idx for idx, name in enumerate(voc_classes)}}
    ```

    This format ensures the dataset is structured according to task-based expectations in Luxonis-Train.


- `input_shapes(self) -> dict[str, torch.Size]` (a property):
  - Defines the shape of each input tensor without the batch dimension.
  - **Important**: The dictionary key must be `"image"`, representing the input image. For example:

    ```python
    @property
    def input_shapes(self) -> dict[str, torch.Size]:
        return {"image": torch.Size([3, self.height, self.width])}
    ```
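
As a small illustration of the class mapping and the `/classification` vector described above, the sketch below uses plain Python lists instead of tensors. `build_class_map` and `multi_hot` are hypothetical helper names, and the class list is shortened for brevity.

```python
# Shortened class list, for illustration only
VOC_CLASSES = ["aeroplane", "bicycle", "bird"]


def build_class_map(names: list[str]) -> dict[str, dict[str, int]]:
    """Return {task_name: {class_name: class_id}} using the default task name ""."""
    return {"": {name: idx for idx, name in enumerate(names)}}


def multi_hot(present: list[str], class_map: dict[str, int]) -> list[int]:
    """Set a 1 for every class that occurs at least once in the image."""
    vector = [0] * len(class_map)
    for name in present:
        vector[class_map[name]] = 1
    return vector


classes = build_class_map(VOC_CLASSES)[""]
print(multi_hot(["bird", "bird", "aeroplane"], classes))  # [1, 0, 1]
```

In the loader itself the same vector is built as a `torch` tensor, and repeated objects of one class still produce a single `1`.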


### Example Implementation Highlights (`VOCLoaderTorch`):

- **Data preparation**: Parsing XML annotations and mapping bounding boxes correctly to resized image dimensions.
- **Augmentations**: Implementing random horizontal flips, color jittering, grayscale conversions.
- **Normalization**: Converting images to tensors and normalizing them with the ImageNet mean and standard deviation.

**Following these guidelines ensures that your loader integrates seamlessly into training pipelines, enabling model training, visualization and metric logging during training, and final model export for deployment on cameras.**
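
The bounding-box mapping is the step most likely to go wrong, so it is worth checking in isolation. The sketch below repeats the letterbox arithmetic of the `VOCLoaderTorch` implementation without any image data; `letterbox_box` is a hypothetical helper, and sizes are given as `(width, height)`.

```python
def letterbox_box(box, orig_size, target_size, flip=False):
    """Map a pixel box (xmin, ymin, xmax, ymax) to normalized letterboxed coordinates."""
    orig_w, orig_h = orig_size
    target_w, target_h = target_size

    # Uniform scale that fits the image inside the target, then centered padding
    scale = min(target_w / orig_w, target_h / orig_h)
    new_w, new_h = int(orig_w * scale), int(orig_h * scale)
    pad_left = (target_w - new_w) // 2
    pad_top = (target_h - new_h) // 2

    xmin, ymin, xmax, ymax = box
    xmin, xmax = xmin * scale + pad_left, xmax * scale + pad_left
    ymin, ymax = ymin * scale + pad_top, ymax * scale + pad_top

    # A horizontal flip mirrors x and swaps the left and right edges
    if flip:
        xmin, xmax = target_w - xmax, target_w - xmin

    return (xmin / target_w, ymin / target_h, xmax / target_w, ymax / target_h)


print(letterbox_box((100, 50, 300, 200), (1024, 512), (512, 512)))
# (0.09765625, 0.298828125, 0.29296875, 0.4453125)
```

Here a 1024x512 image is scaled by 0.5 and padded by 128 pixels at the top and bottom, so only the y coordinates pick up an offset.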


In [None]:
import os
import random
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
from torchvision import transforms
from luxonis_ml.typing import Labels
from luxonis_train.registry import LOADERS

from luxonis_train import BaseLoaderTorch

class VOCLoaderTorch(BaseLoaderTorch):
    def __init__(
        self,
        root: str = "./data/VOCdevkit",
        year: str = "2007",
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )

        self.root = root
        self.year = year
        self.is_training = "train" == self.view[0]

        # Get VOC classes and create mapping
        self.class_map = self.get_classes()[""]
        self.class_name_to_id = {name: idx for idx, name in enumerate(self.class_map)}

        # Collect image and annotation paths
        self.image_ids = []
        self.image_paths = []
        self.annotation_paths = []
        
        # Load image set
        image_set_file = os.path.join(
            self.root, f"VOC{year}", "ImageSets", "Main", f"{self.view[0]}.txt"
        )
        with open(image_set_file) as f:
            for line in f:
                image_id = line.strip().split()[0]
                self.image_ids.append(image_id)

        # Prepare paths
        for img_id in self.image_ids:
            img_path = os.path.join(
                self.root, f"VOC{year}", "JPEGImages", f"{img_id}.jpg"
            )
            ann_path = os.path.join(
                self.root, f"VOC{year}", "Annotations", f"{img_id}.xml"
            )
            self.image_paths.append(img_path)
            self.annotation_paths.append(ann_path)

    def __len__(self) -> int:
        return len(self.image_ids)

    def get(self, idx: int) -> tuple[torch.Tensor, Labels]:
        # Read image
        img_path = self.image_paths[idx]
        img = self.read_image(img_path)

        # Parse annotation
        tree = ET.parse(self.annotation_paths[idx])
        root_xml = tree.getroot()

        # Get original dimensions
        size = root_xml.find("size")
        original_height = int(size.find("height").text)
        original_width = int(size.find("width").text)

        # Extract objects
        objects = []
        for obj in root_xml.findall("object"):
            name = obj.find("name").text
            bbox = obj.find("bndbox")
            xmin = float(bbox.find("xmin").text)
            ymin = float(bbox.find("ymin").text)
            xmax = float(bbox.find("xmax").text)
            ymax = float(bbox.find("ymax").text)
            objects.append({
                "name": name,
                "xmin": xmin,
                "ymin": ymin,
                "xmax": xmax,
                "ymax": ymax,
            })

        # Letterbox resize
        # Compute scale and new size
        scale = min(self.width / original_width, self.height / original_height)
        new_w = int(original_width * scale)
        new_h = int(original_height * scale)
        # Resize image to new size
        resized_img = cv2.resize(img, (new_w, new_h))
        # Compute padding
        pad_w = self.width - new_w
        pad_h = self.height - new_h
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        # Apply padding (using a constant value, e.g., 128)
        img = cv2.copyMakeBorder(resized_img, pad_top, pad_bottom, pad_left, pad_right,
                                 cv2.BORDER_CONSTANT, value=(128, 128, 128))

        # Apply augmentations
        flip = False
        if self.is_training:
            # Color jitter
            if random.random() < 0.5:
                img = self.apply_color_jitter(img)

            # Random grayscale
            if random.random() < 0.1:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

            # Horizontal flip
            if random.random() < 0.5:
                img = cv2.flip(img, 1)
                flip = True

        # Process bounding boxes using letterbox adjustments
        bboxes = []
        class_ids = []
        for obj in objects:
            xmin = obj["xmin"]
            ymin = obj["ymin"]
            xmax = obj["xmax"]
            ymax = obj["ymax"]

            # Scale coordinates with letterbox scale and add padding offset
            xmin = xmin * scale + pad_left
            xmax = xmax * scale + pad_left
            ymin = ymin * scale + pad_top
            ymax = ymax * scale + pad_top

            # Apply flip if required
            if flip:
                xmin, xmax = self.width - xmax, self.width - xmin

            # Normalize to [0,1]
            xmin /= self.width
            xmax /= self.width
            ymin /= self.height
            ymax /= self.height

            cls_id = self.class_name_to_id[obj["name"]]
            bboxes.append([cls_id, xmin, ymin, xmax, ymax])
            class_ids.append(cls_id)

        # Convert to tensor and normalize image
        img_tensor = transforms.ToTensor()(img)
        img_tensor = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )(img_tensor)

        # Compute one-hot encoded vector for classification
        num_classes = len(self.class_map)
        one_hot = torch.zeros(num_classes, dtype=torch.int64)
        for cls in class_ids:
            one_hot[cls] = 1

        # Prepare labels
        labels: Labels = {
            # Rows of [cls_id, xmin, ymin, xmax, ymax]; the reshape keeps shape [N, 5] even when N == 0
            "/boundingbox": torch.tensor(bboxes, dtype=torch.float32).reshape(len(bboxes), 5),
            "/classification": one_hot,  # [num_classes]
        }

        return img_tensor, labels

    def get_classes(self) -> dict[str, dict[str, int]]:
        voc_classes = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]
        return {'': {name: idx for idx, name in enumerate(voc_classes)}} 

    @property
    def input_shapes(self) -> dict[str, torch.Size]:
        return {"image": torch.Size([3, self.height, self.width])}

    def apply_color_jitter(self, img: np.ndarray) -> np.ndarray:
        """Applies random color jitter to image"""
        # Random brightness, contrast, and saturation
        brightness = random.uniform(0.8, 1.2)
        contrast = random.uniform(0.8, 1.2)
        saturation = random.uniform(0.8, 1.2)

        # Convert to HSV
        img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
        
        # Apply transformations
        img_hsv[..., 1] *= saturation
        img_hsv[..., 1] = np.clip(img_hsv[..., 1], 0, 255)
        img_hsv[..., 2] *= brightness * contrast
        img_hsv[..., 2] = np.clip(img_hsv[..., 2], 0, 255)

        # Convert back to RGB
        img = cv2.cvtColor(img_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
        return img
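
To test the annotation parsing on its own, the same `ElementTree` calls can be run against a small XML string. This is a sketch with a hand-written sample that imitates the VOC annotation layout; `parse_voc_objects` is a hypothetical helper, and the sample values are made up.

```python
import xml.etree.ElementTree as ET

# Hand-written sample in the VOC annotation layout (values are illustrative)
SAMPLE_XML = """
<annotation>
  <size><width>500</width><height>375</height><depth>3</depth></size>
  <object>
    <name>dog</name>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
  <object>
    <name>person</name>
    <bndbox><xmin>8</xmin><ymin>12</ymin><xmax>352</xmax><ymax>370</ymax></bndbox>
  </object>
</annotation>
"""


def parse_voc_objects(xml_text: str) -> tuple[tuple[int, int], list[dict]]:
    """Return ((width, height), objects) parsed from a VOC-style annotation."""
    root = ET.fromstring(xml_text.strip())
    size = root.find("size")
    width = int(size.find("width").text)
    height = int(size.find("height").text)

    objects = []
    for obj in root.findall("object"):
        bbox = obj.find("bndbox")
        entry = {"name": obj.find("name").text}
        for key in ("xmin", "ymin", "xmax", "ymax"):
            entry[key] = float(bbox.find(key).text)
        objects.append(entry)
    return (width, height), objects


size, objects = parse_voc_objects(SAMPLE_XML)
print(size, [o["name"] for o in objects])  # (500, 375) ['dog', 'person']
```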

<a name ="inspecting-dataset-using-a-custom-loader"></a>

### 🧐 Inspecting Dataset using a **Custom Loader**

Below we show an example of inspecting the dataset.  
When instantiating the loader directly, we **must** pass `height` and `width`, since `BaseLoaderTorch` requires them.


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Function to denormalize the image
def denormalize(img_tensor, mean, std):
    img = img_tensor.clone()
    for t, m, s in zip(img, mean, std):
        t.mul_(s).add_(m)
    return img

# Convert tensor to NumPy array for display
def tensor_to_numpy(img_tensor):
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    img = denormalize(img_tensor, mean, std)
    # Convert from CxHxW to HxWxC
    img = img.numpy().transpose(1, 2, 0)
    # Clip and convert to uint8 (values in 0-255)
    img = (img * 255).clip(0, 255).astype('uint8')
    return img

def plot_image_with_boxes(inputs, labels):
    # Indexing the loader yields a dict of inputs; the image is stored under the "image" key
    img = tensor_to_numpy(inputs['image'])
    h, w, _ = img.shape

    fig, ax = plt.subplots(1)
    ax.imshow(img)

    # Bounding boxes are stored as [class_id, xmin, ymin, xmax, ymax] in normalized coordinates
    boxes = labels["/boundingbox"].numpy()
    for box in boxes:
        class_id, xmin, ymin, xmax, ymax = box
        # Convert from normalized to pixel coordinates
        x = xmin * w
        y = ymin * h
        box_w = (xmax - xmin) * w
        box_h = (ymax - ymin) * h

        rect = patches.Rectangle(
            (x, y),
            box_w,
            box_h,
            linewidth=2,
            edgecolor='red',
            facecolor='none'
        )
        ax.add_patch(rect)
        ax.text(
            x,
            y,
            str(int(class_id)),
            color='yellow',
            bbox=dict(facecolor='red', alpha=0.5)
        )

    plt.axis('off')
    plt.show()

loader = VOCLoaderTorch(view=["train"], height=512, width=512) # BaseLoaderTorch expects height and width to be set

image, labels = loader[np.random.randint(0, len(loader))]

plot_image_with_boxes(image, labels)
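
Besides plotting, a numeric check can catch label mistakes that are easy to miss visually. The sketch below validates rows in the `[class_id, xmin, ymin, xmax, ymax]` layout used in this tutorial; `check_boxes` is a hypothetical helper that works on plain lists.

```python
def check_boxes(rows, num_classes):
    """Return a list of human-readable problems found in the bounding-box rows."""
    problems = []
    for i, row in enumerate(rows):
        if len(row) != 5:
            problems.append(f"row {i}: expected 5 values, got {len(row)}")
            continue
        cls_id, xmin, ymin, xmax, ymax = row
        if not 0 <= int(cls_id) < num_classes:
            problems.append(f"row {i}: class id {cls_id} out of range")
        if not all(0.0 <= v <= 1.0 for v in (xmin, ymin, xmax, ymax)):
            problems.append(f"row {i}: coordinates not normalized to [0, 1]")
        if xmin >= xmax or ymin >= ymax:
            problems.append(f"row {i}: box has no area")
    return problems


print(check_boxes([[11, 0.1, 0.2, 0.5, 0.9], [3, 0.6, 0.4, 0.3, 0.8]], num_classes=20))
# ['row 1: box has no area']
```

With the loader from this tutorial, the rows could be passed as `labels["/boundingbox"].tolist()` together with `num_classes=20`.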

<a name ="training"></a>

## 🏋️‍♂️ Training

<a name="configuration"></a>

### ⚙️ Configuration

Below we define the training configuration for model training. Important notes:

- We **must** set `trainer.preprocessing.train_image_size`, which defines the width and height for the `BaseLoaderTorch`, effectively telling the loader what size of images we will train with.
- Additionally, it is important to set the normalization parameters, even though we are using our own custom implementation.  
  This is necessary **only** for visualizations during training and does not affect the actual data loading, since we have our own custom loader.


👉 For the full list of all parameters, please check [Luxonis-Train](https://github.com/luxonis/luxonis-train/tree/main).


In [None]:
%%writefile detection_light_model.yaml
model:
  name: detection_light
  predefined_model:
    name: DetectionModel
    params:
      variant: light
      loss_params:
        iou_type: "siou"

        # Should be 2.5 * accumulate_grad_batches for best results
        iou_loss_weight: 20

        # Should be 1 * accumulate_grad_batches for best results
        class_loss_weight: 8

loader:
  test_view: val # the downloaded trainval archive has no test split, so we reuse val
  name: VOCLoaderTorch
  params:
    root: ./data/VOCdevkit/
    year: 2007
  

trainer:
  preprocessing:
    train_image_size: [384, 512] # Needed for the BaseLoaderTorch
    normalize: # Needed just for the visualization denormalization in luxonis-train
      active: true
      params:
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
  
  precision: "16-mixed"
  batch_size: 8
  epochs: 300
  # For best results, accumulate gradients so that the
  # effective batch size is 64 (here 8 * 8)
  accumulate_grad_batches: 8
  n_workers: 8
  validation_interval: 1
  n_log_images: 50

  callbacks:
    - name: EMACallback
      params:
        decay: 0.9999
        use_dynamic_decay: True
        decay_tau: 2000
    - name: ExportOnTrainEnd
    - name: TestOnTrainEnd

  training_strategy:
    name: "TripleLRSGDStrategy"
    params:
      warmup_epochs: 2
      warmup_bias_lr: 0.05
      warmup_momentum: 0.5
      lr: 0.0032
      lre: 0.000384
      momentum: 0.843
      weight_decay: 0.00036
      nesterov: True
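
The comments in the configuration tie several values to `accumulate_grad_batches`. As a quick sketch of that arithmetic, the hypothetical helper `derived_settings` recomputes them so they can be kept consistent when the batch size or accumulation changes.

```python
def derived_settings(batch_size: int, accumulate_grad_batches: int) -> dict[str, float]:
    """Recompute the values that the configuration comments relate to each other."""
    return {
        # Gradients are accumulated over several batches before each optimizer step
        "effective_batch_size": batch_size * accumulate_grad_batches,
        # Multipliers taken from the comments in the YAML file
        "iou_loss_weight": 2.5 * accumulate_grad_batches,
        "class_loss_weight": 1.0 * accumulate_grad_batches,
    }


print(derived_settings(batch_size=8, accumulate_grad_batches=8))
# {'effective_batch_size': 64, 'iou_loss_weight': 20.0, 'class_loss_weight': 8.0}
```

These match the values in the file above: an effective batch size of 64, `iou_loss_weight: 20`, and `class_loss_weight: 8`.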


<a name ="train"></a>

### 🦾 Train


To start the training, we need to initialize the `LuxonisModel`, pass it the path to the configuration file, and call the `train()` method on it.

**Note**: LuxonisTrain also supports all these commands through its CLI ([documentation here](https://github.com/luxonis/luxonis-train/tree/main?tab=readme-ov-file#-cli)), no code required. For custom components such as our loader, simply provide the `--source` flag with the path to the file where they are defined:

```bash
luxonis_train --source custom_components.py train --config detection_light_model.yaml
```

We won't use the CLI for this tutorial, but feel free to use it in your own projects.

In [None]:
from luxonis_train import LuxonisModel

path = "./detection_light_model.yaml"
model = LuxonisModel(cfg=path)

model.train()

`LuxonisTrain` also tracks training runs automatically. By default it logs to `TensorBoard`, so you can inspect the losses, metrics, and visualizations during training. Every run creates a new directory in the `output` folder, and its training logs are stored under `./output/tensorboard_logs`, in a folder whose name matches the run's name.
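
To see which runs have been logged so far, the log folder can be listed directly. This is a minimal sketch, assuming the `./output/tensorboard_logs` location mentioned above; `list_runs` is a hypothetical helper.

```python
from pathlib import Path


def list_runs(log_root: str = "./output/tensorboard_logs") -> list[str]:
    """Return the names of the run folders found under the log directory."""
    root = Path(log_root)
    if not root.is_dir():
        # Nothing has been trained yet, or the logs live somewhere else
        return []
    return sorted(p.name for p in root.iterdir() if p.is_dir())


print(list_runs())
```

Pointing TensorBoard at the same folder, for example with `tensorboard --logdir ./output/tensorboard_logs`, then shows the curves for every listed run.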