
# Nutrition5k × YOLO11 — Classification Training (Calorie/Macro buckets)

**Nutrition5k** is primarily a *regression* dataset (predict calories/macros), not a labeled classification set.  
To reuse the same YOLO‑classification workflow you used for Food‑11/101, this notebook derives **coarse labels** from nutrition metadata:

- **`primary_macro`** (default): `carb`, `protein`, or `fat` depending on which macronutrient (g) is highest.
- **`calorie_bin`** (optional): 5 quantile bins based on total calories per dish.

It then builds `train/`, `val/`, `test/` splits and trains **`yolo11n-cls.pt`**.

> If your Kaggle ref stores images as **bytes inside tables**, this notebook can *materialize* RGB images to JPEGs.


In [1]:

# --- 1) Paths & knobs ----------------------------------------------------------
import os, pathlib, random, shutil, sys, re, ast, base64, io
from pathlib import Path
from collections import defaultdict

# Root where you copied the Kaggle cache
DATASET_ROOT = "/home/kristoffel/datasets/dataset-01-Nutrition5k"   # ← edit if needed
MODEL_DIR    = "/home/kristoffel/models"

# Where to place materialized RGB images (if starting from bytes/tables)
RGB_DIR = os.path.join(DATASET_ROOT, "images_rgb")  # we'll create <dish_id> subfolders

# Output splits for YOLO classification
TRAIN_DIR = os.path.join(DATASET_ROOT, "train")
VAL_DIR   = os.path.join(DATASET_ROOT, "val")
TEST_DIR  = os.path.join(DATASET_ROOT, "test")

# Labeling strategy ('primary_macro' or 'calorie_bin')
LABEL_STRATEGY = "primary_macro"   # ← edit me

# Split ratios
VAL_RATIO  = 0.10
TEST_RATIO = 0.10
SEED = 42
random.seed(SEED)

os.makedirs(MODEL_DIR, exist_ok=True)
assert os.path.isdir(DATASET_ROOT), f"Missing dataset root: {DATASET_ROOT}"
print("Paths OK. Root:", DATASET_ROOT)


Paths OK. Root: /home/kristoffel/datasets/dataset-01-Nutrition5k


In [2]:

# --- 2) Locate source tables and/or images ------------------------------------
# Nutrition5k Kaggle variants differ. We'll try to find:
#  - a dish-level nutrition table (for calories/macros) -> used to derive labels
#  - an images table containing byte-encoded RGB (if no JPEGs are present)
#
# We also support a mirror which already has JPEGs in directories.

from glob import glob
import pandas as pd

def find_first_file(root: str, patterns):
    for pat in patterns:
        hits = sorted(glob(os.path.join(root, "**", pat), recursive=True))
        if hits:
            return hits[0]
    return None

# Candidate files for nutrition (dish-level)
NUTR_PATTERNS = [
    "*dish_nutrition.*", "*dish*nutrition.*", "*nutrition*.csv", "*nutrition*.parquet", "*nutrition*.feather", "*nutrition*.pkl"
]

# Candidates for image tables
IMG_PATTERNS = [
    "*dish_images.*", "*images_table.*", "*images*.parquet", "*images*.feather", "*images*.pkl", "*images*.csv"
]

# Already-materialized JPEGs (side-angle mirror or preprocessed)
JPEG_ROOT = None
for cand in [os.path.join(DATASET_ROOT, "images"), os.path.join(DATASET_ROOT, "images_rgb"), DATASET_ROOT]:
    if os.path.isdir(cand) and any(Path(cand).glob("**/*.jpg")):
        JPEG_ROOT = cand
        break

nutr_file = find_first_file(DATASET_ROOT, NUTR_PATTERNS)
img_table = find_first_file(DATASET_ROOT, IMG_PATTERNS)

print("nutrition file:", nutr_file)
print("image table   :", img_table)
print("jpeg root     :", JPEG_ROOT)


nutrition file: None
image table   : /home/kristoffel/datasets/dataset-01-Nutrition5k/dish_images.pkl
jpeg root     : None


In [4]:

# --- 3) Load nutrition and derive labels --------------------------------------
import pandas as pd
import numpy as np

assert nutr_file, "Could not find a nutrition table. Please point NUTR_PATTERNS to your dish-level nutrition file."

def read_any_table(path: str) -> pd.DataFrame:
    low = path.lower()
    if low.endswith(".csv"):
        return pd.read_csv(path)
    if low.endswith(".parquet"):
        return pd.read_parquet(path)
    if low.endswith(".feather"):
        import pyarrow  # noqa
        return pd.read_feather(path)
    if low.endswith((".pkl", ".pickle")):
        import pickle
        return pd.read_pickle(path)
    raise ValueError(f"Unsupported table format: {path}")

df_nutr = read_any_table(nutr_file)
cols = {c.strip(): c for c in df_nutr.columns}
norm = {k.lower().replace(" ", "").replace("(", "").replace(")", ""): v for k,v in cols.items()}

def pick(name_alts):
    for a in name_alts:
        key = a.lower().replace(" ", "").replace("(", "").replace(")", "")
        if key in norm:
            return norm[key]
    return None

# Try to identify key columns (robust to different namings)
dish_col = pick(["dish_id","id","sample_id","plate_id","dish"])
cal_col  = pick(["calories","kcal","energykcal","energy_kcal"])
pro_col  = pick(["protein","protein_g","proteing"])
fat_col  = pick(["fat","fat_g","fatg","totalfat_g","totalfat"])
carb_col = pick(["carbohydrate","carbohydrates","carbs","carb_g","carbohydrates_g","carbohydrateg"])

for need, col in [("dish id", dish_col), ("calories", cal_col), ("protein", pro_col), ("fat", fat_col), ("carbohydrate", carb_col)]:
    print(f"{need:12s}:", col)

assert dish_col is not None, "Could not find a dish ID column in the nutrition table."

# Keep rows that have nutrition values (when available)
df = df_nutr.copy()
df = df.dropna(subset=[dish_col])

# Build labels
if LABEL_STRATEGY == "primary_macro" and all(x is not None for x in [pro_col, fat_col, carb_col]):
    def primary_macro(r):
        vals = [(r[pro_col], "protein"), (r[fat_col], "fat"), (r[carb_col], "carb")]
        # Handle missing / NaN
        vals = [(0 if pd.isna(v) else float(v), name) for v,name in vals]
        return max(vals)[1]
    df["label"] = df.apply(primary_macro, axis=1)
    class_names = ["carb","protein","fat"]
elif LABEL_STRATEGY == "calorie_bin" and cal_col is not None:
    # 5 quantile bins for calories
    df = df.dropna(subset=[cal_col])
    df["label"] = pd.qcut(df[cal_col].astype(float), q=5, duplicates="drop").astype(str)
    class_names = sorted(df["label"].unique().tolist())
else:
    raise RuntimeError("Chosen LABEL_STRATEGY is not supported by available columns. "
                       "Try LABEL_STRATEGY='primary_macro' or ensure calories/macros exist.")

# Keep only dish_id, label
df_lbl = df[[dish_col, "label"]].drop_duplicates().reset_index(drop=True)
df_lbl.columns = ["dish_id","label"]
print("Label distribution:")
print(df_lbl["label"].value_counts())


AssertionError: Could not find a nutrition table. Please point NUTR_PATTERNS to your dish-level nutrition file.

In [None]:

# --- 4) Materialize JPEGs if needed -------------------------------------------
# Many Nutrition5k mirrors store images as bytes in a table.
# We'll attempt to decode to JPEGs in images_rgb/<dish_id>/img_<idx>.jpg

from io import BytesIO
from PIL import Image
import pandas as pd
import numpy as np
import ast, base64, re

if JPEG_ROOT is None:
    assert img_table, "No JPEGs found and no image table located. Provide an images table or use a mirror with JPEGs."
    df_img = read_any_table(img_table)
    print("Images table columns:", df_img.columns.tolist())

    # Try to find dish id column and an RGB bytes column
    cols = {c.strip(): c for c in df_img.columns}
    norm = {k.lower().replace(" ", "").replace("(", "").replace(")", ""): v for k,v in cols.items()}

    dish_img_col = None
    for a in ["dish_id","id","sample_id","plate_id","dish"]:
        k = a.lower().replace(" ", "").replace("(", "").replace(")", "")
        if k in norm:
            dish_img_col = norm[k]; break
    assert dish_img_col is not None, "Could not find a dish id column in images table."

    rgb_col = None
    for a in ["rgb","rgb_bytes","image","image_bytes","rgbimage","rgbimagebytes"]:
        k = a.lower().replace(" ", "").replace("(", "").replace(")", "")
        if k in norm:
            rgb_col = norm[k]; break
    assert rgb_col is not None, "Could not find an RGB image bytes column."

    # Make output root
    out_root = Path(RGB_DIR)
    out_root.mkdir(parents=True, exist_ok=True)

    def decode_to_bytes(x):
        # Handle: bytes, base64 str, python bytes literal in string
        if isinstance(x, (bytes, bytearray)):
            return bytes(x)
        if isinstance(x, memoryview):
            return x.tobytes()
        if isinstance(x, str):
            s = x.strip()
            # Python bytes literal
            if (s.startswith("b'") or s.startswith('b"')) and s.endswith(("'",'"')):
                return ast.literal_eval(s)
            # Base64?
            try:
                return base64.b64decode(s, validate=True)
            except Exception:
                pass
        raise ValueError("Unsupported RGB encoding type")

    count, skipped = 0, 0
    for ridx, row in df_img.iterrows():
        dish_id = str(row[dish_img_col])
        try:
            rgb_bytes = decode_to_bytes(row[rgb_col])
            img = Image.open(BytesIO(rgb_bytes)).convert("RGB")
        except Exception as e:
            skipped += 1
            continue
        dish_dir = out_root / dish_id
        dish_dir.mkdir(exist_ok=True, parents=True)
        # sequential filename
        out_path = dish_dir / f"img_{len(list(dish_dir.glob('img_*.jpg'))):05d}.jpg"
        img.save(out_path, format="JPEG", quality=90)
        count += 1
        if count % 500 == 0:
            print("saved", count, "images...")
    print(f"Saved {count} images to {out_root} ({skipped} skipped).")
    JPEG_ROOT = str(out_root)
else:
    print("Found JPEGs under:", JPEG_ROOT)


In [None]:

# --- 5) Pair images to labels and build splits --------------------------------
import os, glob, random, shutil
from pathlib import Path
from collections import defaultdict

# Build map dish_id -> list of image paths
dish_to_paths = defaultdict(list)
for p in Path(JPEG_ROOT).glob("**/*.jpg"):
    # expect path .../<dish_id>/filename.jpg
    dish_id = p.parent.name
    dish_to_paths[dish_id].append(str(p))

# Join with labels
label_map = dict(df_lbl.values)  # dish_id -> label
pairs = [(dish, label_map.get(dish, None), imgs) for dish, imgs in dish_to_paths.items()]
pairs = [(d,l,imgs) for (d,l,imgs) in pairs if l is not None and len(imgs)>0]

print("Dishes with labels & images:", len(pairs))

# Split by dish (to avoid leakage across splits)
by_label = defaultdict(list)
for dish, label, imgs in pairs:
    by_label[label].append((dish, imgs))

splits = {"train": [], "val": [], "test": []}
for label, items in by_label.items():
    random.shuffle(items)
    n = len(items)
    n_val  = max(1, int(n * VAL_RATIO))
    n_test = max(1, int(n * TEST_RATIO))
    n_train = max(0, n - n_val - n_test)
    splits["val"].extend([(label, d, imgs) for d,imgs in items[:n_val]])
    splits["test"].extend([(label, d, imgs) for d,imgs in items[n_val:n_val+n_test]])
    splits["train"].extend([(label, d, imgs) for d,imgs in items[n_val+n_test:]])

for d in (TRAIN_DIR, VAL_DIR, TEST_DIR):
    if os.path.isdir(d):
        shutil.rmtree(d)
    os.makedirs(d, exist_ok=True)
    for cls in sorted(set(df_lbl["label"])):
        os.makedirs(os.path.join(d, cls), exist_ok=True)

def safe_link(src, dst):
    try:
        os.symlink(src, dst)
    except Exception:
        shutil.copy2(src, dst)

def materialize(split_name):
    cnt = 0
    for label, dish, imgs in splits[split_name]:
        for src in imgs:
            dst = os.path.join(DATASET_ROOT, split_name, label, f"{dish}__{Path(src).name}")
            if not os.path.exists(dst):
                safe_link(src, dst)
                cnt += 1
    print(f"{split_name}: wrote {cnt} image links/files.")

materialize("train"); materialize("val"); materialize("test")
print("✅ Splits ready:")
print("  train:", sum(len(files) for _,_,files in splits["train"]))
print("  val  :", sum(len(files) for _,_,files in splits["val"]))
print("  test :", sum(len(files) for _,_,files in splits["test"]))


In [None]:

# --- 6) Train YOLO11n-cls on derived labels -----------------------------------
# !pip install -U ultralytics
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from ultralytics import YOLO
import torch, ultralytics

print("Ultralytics:", ultralytics.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    import torch
    print("Device:", torch.cuda.get_device_name(0))

EPOCHS = 30   # Nutrition5k is large; start moderate, adjust as needed
IMGSZ  = 224
BATCH  = 64
RUN_NAME = "nutrition5k_yolo11n_cls_" + ("macro" if "macro" in LABEL_STRATEGY else "calbins")

model = YOLO("yolo11n-cls.pt")
results = model.train(
    data=DATASET_ROOT,     # directory with train/val/test
    epochs=EPOCHS,
    imgsz=IMGSZ,
    batch=BATCH,
    lr0=1e-3,
    patience=10,
    project=MODEL_DIR,
    name=RUN_NAME,
    plots=True,
    device=0 if torch.cuda.is_available() else "cpu",
)
print("Training run saved to:", results.save_dir)


In [None]:

# --- 7) Evaluate on val and test ----------------------------------------------
from ultralytics import YOLO
import glob, os, torch

RUN_PREFIX = "nutrition5k_yolo11n_cls_"
cands = glob.glob(os.path.join(MODEL_DIR, RUN_PREFIX + "*", "weights", "best.pt"))
assert cands, f"No best.pt found under {MODEL_DIR}/{RUN_PREFIX}*/weights/"
best_path = max(cands, key=os.path.getmtime)
print("Using best:", best_path)

model = YOLO(best_path)

metrics_val = model.val(
    data=DATASET_ROOT,
    split="val",
    imgsz=224,
    project=MODEL_DIR,
    name=RUN_PREFIX + "val",
    device=0 if torch.cuda.is_available() else "cpu",
)

metrics_test = model.val(
    data=DATASET_ROOT,
    split="test",
    imgsz=224,
    project=MODEL_DIR,
    name=RUN_PREFIX + "test",
    device=0 if torch.cuda.is_available() else "cpu",
)
print("Done.")
