# YOLO Food Detection Training


**Note:** For fast training, it's best to use a GPU runtime in Google Colab. You can upload this notebook to your Colab environment.

## Step 0: Environment Setup

**IMPORTANT**: Make sure to enable GPU runtime!
- Go to: **Runtime > Change runtime type > Hardware accelerator > GPU**
- Choose **T4 GPU** (free tier).


In [None]:
# Report the PyTorch build and whether a CUDA device is visible to it
import torch

cuda_ready = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if not cuda_ready:
    print("No GPU detected - please enable GPU runtime!")
else:
    # Device 0 is the only GPU inspected here
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {gpu_props.total_memory / 1024**3:.1f} GB")

In [None]:
!pip install ultralytics -q

In [None]:
# Import libraries
import os
import json
import time
import shutil
import yaml
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings

from tqdm import tqdm
from pathlib import Path
from ultralytics import YOLO
from IPython.display import display, Image, clear_output
from concurrent.futures import ThreadPoolExecutor

# Suppress warnings
warnings.filterwarnings('ignore')

## Step 1: Prepare food image dataset

Instead of using the full 3GB dataset, the following cells focus on the smaller dataset (375MB) from the [dataset-ninja website](https://datasetninja.com/food-recognition).

In [None]:
import os
import tarfile
from tqdm import tqdm


def download_and_extract(dataset_type="sample", extract_dir="/content/extracted_food_recognition"):

    urls = {
        "full": "https://assets.supervisely.com/remote/eyJsaW5rIjogImZzOi8vYXNzZXRzLzk1MF9Gb29kIFJlY29nbml0aW9uIDIwMjIvZm9vZC1yZWNvZ25pdGlvbi0yMDIyLURhdGFzZXROaW5qYS50YXIiLCAic2lnIjogIlpqZisyZURmaEoyZkhWNGRiTHBPWkEzN0NodWhlb28wNlZlQXpQQkdBc1U9In0=",
        "sample": "https://assets.supervisely.com/supervisely-supervisely-assets-public/teams_storage/Y/I/Yu/V12n4DX73dwmLY7sz7Bl83qOdACZBbaB9ctVmlEUPBx1qqaqLqnHsLubJQWGAKu3vrqBPty6hRBOrUaXzgPe1jMk7bI1MVcWCNON9vbLeZdPKuRV9Psis3STGSh7.tar"
    }

    if dataset_type not in urls:
        raise ValueError("dataset_type must be either 'full' or 'sample'")

    url = urls[dataset_type]
    tar_path = f"/content/{dataset_type}_food_dataset.tar"

    print(f"Downloading [{dataset_type}] dataset...")
    !wget -q --show-progress -O "$tar_path" "$url"

    if os.path.exists(extract_dir):
        print(f"{extract_dir} already exists. Skipping extraction.")
        return
    
    with tarfile.open(tar_path, "r") as tar:
        members = tar.getmembers()
        for member in tqdm(members, desc="Extracting"):
            tar.extract(member, path=extract_dir)
            

    print(f"Dataset extracted to {extract_dir}")

download_and_extract("sample")

### Optional: Use dataset-tools for downloading.

In [None]:
!pip install dataset-tools

import os
import dataset_tools as dtools

# NOTE(review): this path is relative (no leading "/"), unlike the "/content/..."
# paths used elsewhere in the notebook - confirm that is intended.
drive_path = 'content/datasets/FoodRecognition2022'
# meta.json is used as the "already downloaded" marker
marker_file = os.path.join(drive_path, 'meta.json')

if os.path.exists(marker_file):
    print("Dataset already exists.")
else:
    print("Downloading dataset...")
    # dataset-tools extracts the dataset itself after downloading.
    dtools.download(dataset='Food Recognition 2022', dst_dir=drive_path)

## Step 2: Training Preparation

### Filter category with a small number of training images


#### 1. Available categories:

The class/food category names that are detected in this image dataset can be found in the file `meta.json`:

```json
{
    "classes": [
        {
            "title": "bread-wholemeal",
            "shape": "polygon",
            "color": "#0F8A39",
            "geometry_config": {},
            "id": 3415,
            "hotkey": ""
        },
        {
            "title": "jam",
            "shape": "polygon",
            "color": "#8A460F",
            "geometry_config": {},
            "id": 3417,
            "hotkey": ""
        }
        ...
    ]
}
```

#### 2. Imbalanced food class distribution

You can check the food class distribution in [this section](https://datasetninja.com/food-recognition#class-balance).

To improve training efficiency, this project skips categories that have fewer than 30 images.

In [None]:
# Food categories with fewer than 30 images; hard-coded so they can be
# dropped from training without re-counting the dataset.
excluded_food_categories = {
    "oil",
    "tea-spice",
    "tea-fruit",
    "tea-ginger",
    "tea-rooibos",
    "chocolate-filled",
    "light-beer",
    "bagel-without-filling",
    "oat-milk",
    "buckwheat-pancake",
    "gummi-bears-fruit-jellies-jelly-babies-with-fruit-essence",
    "corn-flakes",
    "ice-cubes",
    "black-forest-tart",
    "m-m-s",
    "chocolate-milk-chocolate-drink",
    "cake-marble",
    "cake-salted",
    "mango-dried",
    "blackberry",
    "italian-salad-dressing",
    "soup-potato",
    "brazil-nut",
    "pastry-flaky",
    "champagne",
    "macaroon",
    "dumplings",
    "sekt",
    "soya-drink-soy-milk",
    "soya-yaourt-yahourt-yogourt-ou-yoghourt",
    "chorizo",
    "turnover-with-meat-small-meat-pie-empanadas",
    "white-chocolate",
    "margarine",
    "mix-of-dried-fruits-and-nuts",
    "white-radish",
    "grissini",
    "apricot-dried",
    "smoked-cooked-sausage-of-pork-and-beef-meat-sausag",
    "soup-cream-of-vegetables",
    "prosecco",
    "soup-miso",
    "kebab-in-pita-bread",
    "mushroom-average-stewed-without-addition-of-fat-without-addition-of-salt",
    "cooked-sausage",
    "sugar-glazing",
    "maple-syrup-concentrate",
    "philadelphia",
    "aperitif-with-alcohol-aperol-spritz",
    "damson-plum",
    "pie-rhubarb-baked-with-cake-dough",
    "linseeds",
    "lasagne-vegetable-prepared",
    "milk-chocolate-with-hazelnuts",
    "popcorn-salted",
    "rice-jasmin",
    "faux-mage-cashew-vegan-chers",
    "croque-monsieur",
    "tomato-stewed-without-addition-of-fat-without-addition-of-salt",
    "cocoa-powder",
    "perch-fillets-lake",
    "soup-tomato",
    "ham-turkey",
    "fruit-compotes",
    "french-pizza-from-alsace-baked",
    "banana-cake",
    "balsamic-vinegar",
    "eggplant-caviar",
    "naan-indien-bread",
    "chocolate-egg-small",
    "cake-oblong",
    "biscuit-with-butter",
    "pecan-nut",
    "savoury-puff-pastry-stick",
    "sweets-candies",
    "coriander",
    "fish-crunchies-battered",
    "chia-grains",
    "minced-meat",
    "bean-seeds",
    "meat-balls",
    "bouillon-vegetable",
    "coffee-decaffeinated",
    "carrot-cake",
    "paprika-chips",
    "lemon-pie",
    "fig-dried",
    "waffle",
}

In [None]:
def load_and_filter_classes(drive_path, excluded_set):
    """
    Read the class titles listed in ``<drive_path>/meta.json`` and drop every
    title contained in ``excluded_set``.

    Args:
        drive_path: Directory that contains ``meta.json``.
        excluded_set: Collection of class titles to leave out.

    Returns:
        A tuple ``(kept_titles, title_to_index)``: the remaining titles in
        file order, and a dict mapping each of them to its position in that list.

    Raises:
        FileNotFoundError: If ``meta.json`` does not exist under ``drive_path``.
    """
    meta_path = os.path.join(drive_path, 'meta.json')
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"meta.json not found at {meta_path}")

    # Every title found in the metadata, in file order
    with open(meta_path, 'r') as meta_file:
        class_entries = json.load(meta_file).get('classes', [])
    all_titles = []
    for entry in class_entries:
        all_titles.append(entry['title'])
    print(f"Loaded {len(all_titles)} total classes from meta.json")

    # Drop the excluded titles
    kept_titles = list(filter(lambda title: title not in excluded_set, all_titles))
    print(f"Removed {len(all_titles) - len(kept_titles)} specified rare classes.")
    print(f"Kept {len(kept_titles)} classes for training.")

    # Consecutive IDs (0..N-1) for the kept titles only
    title_to_index = dict(zip(kept_titles, range(len(kept_titles))))

    return kept_titles, title_to_index

In [9]:
# Dataset root. Later cells (label conversion, dataset.yaml creation) read this
# module-level name directly, so this cell must run before them.
extract_dir = "/content/extracted_food_recognition"

# class_names: kept class titles; class_map: class title -> index in class_names
class_names, class_map = load_and_filter_classes(extract_dir, excluded_food_categories)

Loaded 498 total classes from meta.json
Removed 88 specified rare classes.
Kept 410 classes for training.


### Convert annotation json to yolo detection bounding box

For example, for `training/img/006497.jpg`, annotations can be found in the relevant annotation folder: `training/ann/006497.jpg.json`

However, YOLO needs a label file in text format.

You can use this image to better understand how it works: ![yolo sample img](https://github.com/ultralytics/docs/releases/download/0/two-persons-tie.avif)

The corresponding label text is as follows:
```txt
0 0.481719 0.634028 0.690625 0.713278
0 0.741094 0.524306 0.314750 0.933389
27 0.364844 0.795833 0.078125 0.400000
```

- The first column (`0`) is the category ID. In this case, `0` refers to `Person`. 
- The second (`0.481719`) and third (`0.634028`) columns are the x and y coordinates of the object's center point, normalized to the image width and height.
- The fourth (`0.690625`) and fifth (`0.713278`) columns are the width and height of the object, normalized in the same way.


For more information, see the [Ultralytics detection dataset format documentation](https://docs.ultralytics.com/datasets/detect/).

**Note:** In our project, the original `id` values in `meta.json` are not used. Instead, `class_map` maps each class name (the key) to a new consecutive integer ID (the value), starting at 0.

In [2]:
def convert_annotations(split_name, class_map, dataset_dir=None):
    """Convert the JSON annotations of one dataset split into YOLO label files.

    For every ``<split>/ann/*.json`` file whose image exists in ``<split>/img``,
    the axis-aligned bounding box of each object's ``points.exterior`` polygon is
    written as one ``class_id x_center y_center width height`` line (all four
    values normalized by the image size) to ``<split>/labels/<image stem>.txt``.

    Args:
        split_name: Split sub-directory of the dataset root, e.g. ``"training"``.
        class_map: Mapping of class title -> integer class ID. Objects whose
            ``classTitle`` is not a key of this mapping are dropped.
        dataset_dir: Dataset root directory. When omitted, the notebook-level
            ``extract_dir`` variable is used (the original behaviour).
    """
    # Only touch the notebook global when no explicit root was given
    root_dir = extract_dir if dataset_dir is None else dataset_dir

    print(f"\nGenerating labels for {split_name} split using {len(class_map)} classes...")

    split_dir = os.path.join(root_dir, split_name)
    original_ann_dir = os.path.join(split_dir, 'ann')
    original_img_dir = os.path.join(split_dir, 'img')

    if not os.path.isdir(original_ann_dir):
        print(f"Warning: Annotation dir not found: {original_ann_dir}. Skipping.")
        return

    yolo_label_dir = os.path.join(split_dir, 'labels')
    os.makedirs(yolo_label_dir, exist_ok=True)

    img_extensions = ['.jpg', '.jpeg', '.png', '.JPG']
    ann_files = [f for f in os.listdir(original_ann_dir) if f.endswith('.json')]

    valid_labels = 0
    empty_labels = 0

    for ann_file in tqdm(ann_files, desc=f'Converting {split_name} annotations'):
        with open(os.path.join(original_ann_dir, ann_file), 'r') as f:
            data = json.load(f)

        # Annotation files are named "<image file name>.json" (e.g. 006497.jpg.json).
        # Strip ".json" first, then any image extension, so that .png/.jpeg/.JPG
        # images are matched too instead of only ".jpg".
        image_name = ann_file[:-len('.json')]
        stem, ext = os.path.splitext(image_name)
        if ext.lower() in ('.jpg', '.jpeg', '.png'):
            base_filename = stem
            candidates = [image_name] + [f"{stem}{e}" for e in img_extensions]
        else:
            # Annotation named "<stem>.json": probe the known extensions
            base_filename = image_name
            candidates = [f"{image_name}{e}" for e in img_extensions]

        if not any(os.path.exists(os.path.join(original_img_dir, c)) for c in candidates):
            print(f"Warning: No image found for {ann_file}")
            continue

        img_h, img_w = data['size']['height'], data['size']['width']
        if img_h <= 0 or img_w <= 0:
            continue

        yolo_lines = []
        for obj in data['objects']:
            class_title = obj['classTitle']
            if class_title not in class_map:
                continue
            class_id = class_map[class_title]

            points = np.array(obj.get('points', {}).get('exterior', []), dtype=float)
            if points.ndim != 2 or points.shape[0] < 1:
                continue

            # Clip the box to the image: normalized values above 1 make the
            # training loader reject the whole label file.
            x_min, y_min = np.clip(points.min(axis=0), 0, [img_w, img_h])
            x_max, y_max = np.clip(points.max(axis=0), 0, [img_w, img_h])
            box_w, box_h = x_max - x_min, y_max - y_min
            if box_w <= 0 or box_h <= 0:
                continue  # degenerate box (single point, or entirely outside the image)

            x_center, y_center = x_min + box_w / 2, y_min + box_h / 2
            x_norm, y_norm = x_center / img_w, y_center / img_h
            w_norm, h_norm = box_w / img_w, box_h / img_h
            yolo_lines.append(f"{class_id} {x_norm:.6f} {y_norm:.6f} {w_norm:.6f} {h_norm:.6f}")

        label_path = os.path.join(yolo_label_dir, f"{base_filename}.txt")
        if yolo_lines:
            with open(label_path, 'w') as f:
                f.write('\n'.join(yolo_lines))
            valid_labels += 1
        else:
            # Drop a label left over from an earlier run (e.g. with a different
            # class_map); its class IDs would no longer be valid.
            if os.path.exists(label_path):
                os.remove(label_path)
            empty_labels += 1

    print(f"Created {valid_labels} valid label files, {empty_labels} images had no valid annotations")

In [None]:
# Write YOLO label files for both splits, using the filtered class IDs
convert_annotations("training", class_map)
convert_annotations("validation", class_map)

### Prepare the dataset.yaml

For YOLO training, it's important to let YOLO know your training target. In our case, these are food categories that can be found in our image datasets, as well as the path of our training/validation dataset.

In [None]:
dataset_yaml_path = os.path.join(extract_dir, 'dataset.yaml')


def expose_images_for_yolo(split_name):
    """Return the split's image directory (relative to ``extract_dir``) for dataset.yaml.

    Ultralytics derives each label path from the image path by replacing the last
    ``/images/`` component with ``/labels/``. The extracted dataset keeps images in
    ``<split>/img``, so labels written to ``<split>/labels`` would not be found.
    This creates ``<split>/images`` and fills it with symlinks to the files in
    ``<split>/img``.

    Per-file links are used rather than one directory symlink, because the loader
    may resolve a directory symlink back to ``img`` (assumption - verify against
    the installed Ultralytics version).
    """
    source_dir = os.path.join(extract_dir, split_name, 'img')
    if not os.path.isdir(source_dir):
        # Keep the original path so the missing directory is reported downstream
        print(f"Warning: image directory not found: {source_dir}")
        return os.path.join(split_name, 'img')

    yolo_img_dir = os.path.join(extract_dir, split_name, 'images')
    os.makedirs(yolo_img_dir, exist_ok=True)
    for file_name in os.listdir(source_dir):
        if not file_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        link_path = os.path.join(yolo_img_dir, file_name)
        # lexists: also true for an existing (even dangling) link, so re-runs are safe
        if not os.path.lexists(link_path):
            os.symlink(os.path.abspath(os.path.join(source_dir, file_name)), link_path)
    return os.path.join(split_name, 'images')


# Class order must match the integer IDs stored in class_map
final_class_names = list(class_map.keys())
yaml_content = {
    'path': extract_dir,
    'train': expose_images_for_yolo('training'),
    'val': expose_images_for_yolo('validation'),
    'nc': len(final_class_names),
    'names': final_class_names
}
with open(dataset_yaml_path, 'w') as f:
    yaml.dump(yaml_content, f, sort_keys=False)

print(f"`dataset.yaml` created at: {dataset_yaml_path}")

## Step 3: Start Training

In [3]:


class FoodDetectionTrainer:
    """
    Handles dataset validation, training, evaluation, result plotting and
    single-image prediction for a YOLOv8 food detection model.

    Typical use: construct (validates the dataset), optionally call
    ``setup_training_config``, then ``train``, ``evaluate``,
    ``visualize_results`` and ``test_prediction``.
    """
    def __init__(self, dataset_path, model_size='n'):
        """
        Initializes the FoodDetectionTrainer and validates the dataset.

        Args:
            dataset_path (str): The root directory path of the dataset
                (must contain ``dataset.yaml``).
            model_size (str): The model size ('n', 's', 'm', 'l', 'x').

        Raises:
            FileNotFoundError: If ``dataset.yaml`` or a split directory is missing.
        """
        self.dataset_path = dataset_path
        self.model_size = model_size
        self.model_map = {
            'n': 'yolov8n.pt',    # Nano - fastest, smallest
            's': 'yolov8s.pt',    # Small - balanced
            'm': 'yolov8m.pt',    # Medium - better accuracy
            'l': 'yolov8l.pt',    # Large - high accuracy
            'x': 'yolov8x.pt'     # Extra Large - highest accuracy
        }

        # Validate the dataset upon initialization
        self.validate_dataset()

    def validate_dataset(self):
        """Validates the dataset structure and files.

        Loads ``dataset.yaml`` into ``self.config``, stores ``self.num_classes``,
        checks that the image and label directories of both splits exist and
        prints file counts plus a quick non-empty check on up to 50 label files.

        Raises:
            FileNotFoundError: If ``dataset.yaml``, an image directory or a
                label directory is missing.
        """
        # Check for dataset.yaml
        yaml_path = os.path.join(self.dataset_path, 'dataset.yaml')
        if not os.path.exists(yaml_path):
            raise FileNotFoundError(f"dataset.yaml not found in {self.dataset_path}")

        # Load configuration
        with open(yaml_path, 'r') as f:
            config = yaml.safe_load(f)

        self.config = config
        self.num_classes = config['nc']
        print(f"Detected {self.num_classes} food categories.")

        # Check training and validation sets
        base_path = config.get('path', self.dataset_path)
        train_img_path = os.path.join(base_path, config['train'])
        val_img_path = os.path.join(base_path, config['val'])

        for split_name, img_path in [('Training set', train_img_path), ('Validation set', val_img_path)]:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"{split_name} image directory not found: {img_path}")

            # Labels are expected in a 'labels' directory next to the image directory
            label_path = os.path.join(os.path.dirname(img_path), 'labels')
            if not os.path.exists(label_path):
                raise FileNotFoundError(f"{split_name} label directory not found: {label_path}")

            # Ultralytics pairs images and labels by swapping '/images/' for
            # '/labels/' in the image path; any other directory name breaks that.
            img_dir_name = os.path.basename(os.path.normpath(img_path))
            if img_dir_name != 'images':
                print(f"  Warning: {split_name} image directory is named '{img_dir_name}', not 'images'. "
                      f"Ultralytics looks up labels by replacing '/images/' with '/labels/' in each "
                      f"image path, so the labels in {label_path} may not be found.")

            # Count files
            img_files = [f for f in os.listdir(img_path)
                        if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            label_files = [f for f in os.listdir(label_path) if f.endswith('.txt')]

            print(f"  {split_name}: {len(img_files)} images, {len(label_files)} label files.")

            # Spot-check: only the first 50 label files are opened
            valid_labels = 0
            empty_labels = 0
            for label_file in label_files[:50]:
                with open(os.path.join(label_path, label_file), 'r') as f:
                    content = f.read().strip()
                    if content:
                        valid_labels += 1
                    else:
                        empty_labels += 1

            print(f"    Sampled Quality - Valid labels: {valid_labels}, Empty labels: {empty_labels}")

    def setup_training_config(self, epochs=None, batch_size=16, image_size=640,
                            learning_rate=0.01, patience=50):
        """Builds and stores the keyword arguments passed to ``model.train``.

        Args:
            epochs (int | None): Number of epochs. An explicit value is always
                honoured. When omitted, 100 is used, or - for datasets with more
                than 200 classes - ``max(100, min(200, num_classes // 2))``.
            batch_size (int): Batch size; capped at 32 when there are more than
                200 classes.
            image_size (int): Training image size.
            learning_rate (float): Initial learning rate; capped at 0.01 when
                there are more than 200 classes.
            patience (int): Early-stopping patience in epochs.

        Returns:
            dict: The training configuration (also stored as ``self.training_config``).
        """
        many_classes = self.num_classes > 200

        if epochs is None:
            # No explicit request: fall back to the class-count based recommendation
            epochs = max(100, min(200, self.num_classes // 2)) if many_classes else 100

        if many_classes:
            recommended_batch = min(batch_size, 32)
            recommended_lr = min(learning_rate, 0.01)
        else:
            recommended_batch = batch_size
            recommended_lr = learning_rate

        self.training_config = {
            'epochs': epochs,
            'batch': recommended_batch,
            'imgsz': image_size,
            'lr0': recommended_lr,
            'patience': patience,
            # None lets Ultralytics choose the device itself (GPU when available);
            # 'auto' is not a documented device value.
            'device': None,
            'workers': 8,
            'cache': 'disk',
            'amp': True,
            'optimizer': 'AdamW',
            'weight_decay': 0.0005,
            'warmup_epochs': 5,
            'warmup_momentum': 0.8,
            'warmup_bias_lr': 0.1,
            'mosaic': 1.0,      # Data augmentation
            'mixup': 0.1,
            'copy_paste': 0.1,
            'degrees': 10.0,    # Rotation
            'translate': 0.2,   # Translation
            'scale': 0.9,       # Scale
            'fliplr': 0.5,      # Horizontal flip
            'flipud': 0.0,      # Vertical flip (food is rarely upside down)
            'hsv_h': 0.015,     # HSV-Hue augmentation
            'hsv_s': 0.7,       # HSV-Saturation augmentation
            'hsv_v': 0.4,       # HSV-Value augmentation
        }

        return self.training_config

    def train(self, project_name='food_detection', experiment_name=None):
        """Starts the training process.

        Args:
            project_name (str): Sub-directory of ``runs/detect`` used as project.
            experiment_name (str | None): Run name; derived from the model size
                and class count when omitted.

        Returns:
            The object returned by ``model.train`` (also stored as ``self.results``;
            the model is stored as ``self.model``).
        """
        if experiment_name is None:
            experiment_name = f'food_detection_{self.model_size}_{self.num_classes}classes'

        # Allow train() without a prior setup_training_config() call
        if not hasattr(self, 'training_config'):
            self.setup_training_config()

        print(f"\n Starting food detection model training...")
        print(f" Project: {project_name}")
        print(f"Experiment: {experiment_name}")

        # Load model
        model = YOLO(self.model_map[self.model_size])
        print(f"Loaded {self.model_map[self.model_size]} pretrained model.")

        # Start training
        dataset_yaml = os.path.join(self.dataset_path, 'dataset.yaml')

        try:
            results = model.train(
                data=dataset_yaml,
                project=f'runs/detect/{project_name}',
                name=experiment_name,
                exist_ok=True,
                plots=True,
                save_json=True,
                val=True,
                verbose=True,
                **self.training_config
            )

            self.results = results
            self.model = model

            return results

        except Exception as e:
            # Log for the notebook reader, then let the error propagate
            print(f"An error occurred during training: {e}")
            raise

    def evaluate(self):
        """Runs validation on the trained model and prints the main metrics.

        Returns:
            The validation results, or ``None`` if training has not been run.
        """
        if not hasattr(self, 'results'):
            print("Please complete training before evaluating!")
            return

        # Run validation
        dataset_yaml = os.path.join(self.dataset_path, 'dataset.yaml')
        val_results = self.model.val(data=dataset_yaml, plots=True, save_json=True)

        # Display main metrics
        metrics = {
            'mAP50': val_results.box.map50,
            'mAP50-95': val_results.box.map,
            'Precision': val_results.box.mp,
            'Recall': val_results.box.mr
        }

        print(f"   mAP@0.5: {metrics['mAP50']:.4f} ({metrics['mAP50']*100:.1f}%)")
        print(f"   mAP@0.5:0.95: {metrics['mAP50-95']:.4f} ({metrics['mAP50-95']*100:.1f}%)")
        print(f"   Precision: {metrics['Precision']:.4f} ({metrics['Precision']*100:.1f}%)")
        print(f"   Recall: {metrics['Recall']:.4f} ({metrics['Recall']*100:.1f}%)")

        # Performance feedback, bucketed by mAP@0.5
        if metrics['mAP50'] > 0.7:
            print("\n Excellent.")
        elif metrics['mAP50'] > 0.5:
            print("\n Great.")
        elif metrics['mAP50'] > 0.3:
            print("\n Fair.")
        else:
            print("\n Needs Improvement.")

        return val_results

    def visualize_results(self):
        """Displays the plots saved in the training run directory, if present."""
        if not hasattr(self, 'results'):
            print("Please complete training before visualizing results!")
            return

        results_dir = self.results.save_dir

        # Display training plots
        plots = [
            ('results.png', 'Training Curves'),
            ('confusion_matrix.png', 'Confusion Matrix'),
            ('labels.jpg', 'Label Distribution'),
            ('val_batch0_pred.jpg', 'Validation Prediction Sample')
        ]

        for filename, title in plots:
            filepath = os.path.join(results_dir, filename)
            if os.path.exists(filepath):
                print(f"\n{title}:")
                try:
                    display(Image(filepath))
                except Exception as e:
                    print(f"Could not display {filename}: {e}")
            else:
                print(f"⚠️ {filename} not found.")

    def test_prediction(self, test_image_path, confidence=0.25):
        """Tests a prediction on a single image.

        Args:
            test_image_path (str): Path of the image to run the model on.
            confidence (float): Confidence threshold passed to ``predict``.

        Returns:
            The prediction results, or ``None`` if no model has been trained.
        """
        if not hasattr(self, 'model'):
            print("Please complete training before testing!")
            return

        print(f"Predicting on image: {test_image_path}")

        # Make prediction
        results = self.model.predict(
            test_image_path,
            conf=confidence,
            save=True,
            show_labels=True,
            show_conf=True
        )

        # Display results
        for result in results:
            boxes = result.boxes
            if boxes is not None and len(boxes) > 0:
                print(f"Detected {len(boxes)} food objects:")
                for i, box in enumerate(boxes):
                    class_id = int(box.cls[0])
                    confidence_score = float(box.conf[0])
                    class_name = self.config['names'][class_id]
                    print(f"  {i+1}. {class_name}: {confidence_score:.3f}")
            else:
                print("No food objects detected.")

        return results

In [None]:

def main():
    """Run the whole pipeline (configure, train, evaluate, plot) and return the trainer."""
    trainer = FoodDetectionTrainer("/content/extracted_food_recognition", model_size='s')

    # Hyper-parameters for this run
    trainer.setup_training_config(
        epochs=10,
        batch_size=16,
        image_size=640,
        learning_rate=0.01,
        patience=30,
    )

    trainer.train(
        project_name='food_detection_project',
        experiment_name='food_410classes_v1',
    )
    trainer.evaluate()
    trainer.visualize_results()

    # To try a single image afterwards:
    #     trainer.test_prediction("/path/to/your/test/image.jpg", confidence=0.3)

    return trainer

trainer = main()

In [None]:
from google.colab import files

# `save_dir` is typically a pathlib.Path, for which `save_dir + "/weights/best.pt"`
# raises TypeError; str() + os.path.join works for both Path and str.
best_model = os.path.join(str(trainer.results.save_dir), "weights", "best.pt")

if os.path.exists(best_model):
    print("Downloading the best model for local use.")
    files.download(best_model)
else:
    print(f"Best model weights not found at: {best_model}")