# Junk Food Multi-label Classification with KNN

This notebook implements a **K-Nearest Neighbors (KNN)** model for multi-label image classification on a **COCO JSON dataset**, using embeddings from several pretrained vision backbones as features.

## Before you start

Make sure you have access to a GPU. If you run into problems, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, click `Save`, and try again.

In [None]:
!nvidia-smi

Sun Dec 28 08:44:49 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   70C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import os
HOME = os.getcwd()
print("HOME:", HOME)

HOME: /content


In [None]:
!mkdir -p {HOME}/datasets
%cd {HOME}/datasets


/content/datasets


## Install packages using pip

In [None]:
!pip install roboflow==1.2.11 torch==2.9.0 torchvision==0.24.0 scikit-learn==1.6.1 tqdm==4.67.1

Collecting roboflow==1.2.11
  Downloading roboflow-1.2.11-py3-none-any.whl.metadata (9.7 kB)
Collecting idna==3.7 (from roboflow==1.2.11)
  Downloading idna-3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting opencv-python-headless==4.10.0.84 (from roboflow==1.2.11)
  Downloading opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting pi-heif<2 (from roboflow==1.2.11)
  Downloading pi_heif-1.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.5 kB)
Collecting pillow-avif-plugin<2 (from roboflow==1.2.11)
  Downloading pillow_avif_plugin-1.5.2-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (2.1 kB)
Collecting filetype (from roboflow==1.2.11)
  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading roboflow-1.2.11-py3-none-any.whl (89 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading idna-3.7-py

## Download dataset from Roboflow

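Before running the cell below, add your Roboflow credentials as Colab secrets named `ROBOFLOW_API_KEY`, `ROBOFLOW_WORKSPACE_ID`, `ROBOFLOW_PROJECT_ID`, and `ROBOFLOW_DATASET_VERSION`.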
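The download cell reads these values through `google.colab.userdata`, which only exists inside a Colab runtime. If you run the notebook elsewhere, a small helper can fall back to environment variables with the same names. The `get_secret` function below is a hypothetical sketch, not part of the Roboflow or Colab APIs.

```python
import os


def get_secret(name: str) -> str:
    """Hypothetical helper: read a secret from Colab, else from the environment."""
    try:
        # google.colab is only importable inside a Colab runtime
        from google.colab import userdata
        return userdata.get(name)
    except ImportError:
        # Outside Colab: expect an environment variable with the same name
        value = os.environ.get(name)
        if value is None:
            raise KeyError(f"Secret {name!r} is not set as an environment variable")
        return value
```

With it, the first line of the download cell would become `rf = Roboflow(api_key=get_secret('ROBOFLOW_API_KEY'))`, and likewise for the other three values.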

The dataset is annotated for object detection and exported in COCO format. Instead of training a detector, we treat the task as whole-image multi-label classification with KNN: each image's label is the set of classes that appear among its bounding boxes.

In [None]:
from roboflow import Roboflow
from google.colab import userdata

rf = Roboflow(api_key=userdata.get('ROBOFLOW_API_KEY'))
project = rf.workspace(userdata.get('ROBOFLOW_WORKSPACE_ID')).project(userdata.get('ROBOFLOW_PROJECT_ID'))
version = project.version(userdata.get('ROBOFLOW_DATASET_VERSION'))
dataset = version.download("coco")

loading Roboflow workspace...
loading Roboflow project...


Downloading Dataset Version Zip in Junk-Food-Detection-10 to coco:: 100%|██████████| 293482/293482 [00:10<00:00, 27415.50it/s]





Extracting Dataset Version Zip to Junk-Food-Detection-10 in coco:: 100%|██████████| 5280/5280 [00:02<00:00, 2024.61it/s]


In [None]:
%cd {HOME}

/content


## Dataset Loading and Label Extraction

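We frame the task as multi-label classification: each image gets a multi-hot vector with one entry per food category, set to 1 if the image contains at least one bounding box of that category. The `junk-food` category is excluded, since it appears to be the dataset's parent class rather than a specific food type. An image with no boxes gets an all-zero vector and corresponds to a non-junk-food ad; an image with at least one positive entry is a junk-food ad.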
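The labeling rule can be illustrated on a tiny, made-up COCO dictionary. The hypothetical `multi_hot_labels` function below mirrors the logic of `process_dataset_part` without reading any files: it drops the `junk-food` category, builds one column per remaining class in sorted order, and derives the binary junk-food-ad flag as "at least one label set".

```python
import numpy as np

# Made-up COCO-style data: 3 images, category 0 is the excluded parent class
coco = {
    "categories": [
        {"id": 0, "name": "junk-food"},
        {"id": 1, "name": "pizza"},
        {"id": 2, "name": "soda"},
    ],
    "images": [{"id": 10}, {"id": 11}, {"id": 12}],
    "annotations": [
        {"image_id": 10, "category_id": 1},
        {"image_id": 10, "category_id": 1},  # a second pizza box changes nothing
        {"image_id": 10, "category_id": 2},
        {"image_id": 11, "category_id": 0},  # only the excluded class, row stays all-zero
    ],
}


def multi_hot_labels(coco_data, excluded=("junk-food",)):
    """Hypothetical helper: one multi-hot row per image, in coco_data['images'] order."""
    id_to_name = {c["id"]: c["name"] for c in coco_data["categories"] if c["name"] not in excluded}
    classes = sorted(set(id_to_name.values()))
    col = {name: i for i, name in enumerate(classes)}
    row = {img["id"]: i for i, img in enumerate(coco_data["images"])}

    labels = np.zeros((len(row), len(classes)), dtype=int)
    for ann in coco_data["annotations"]:
        name = id_to_name.get(ann["category_id"])
        if name is not None:  # skip annotations of excluded categories
            labels[row[ann["image_id"]], col[name]] = 1
    return labels, classes


labels, classes = multi_hot_labels(coco)
is_junk_food_ad = labels.any(axis=1)  # binary view: at least one food label present
print(classes)                   # ['pizza', 'soda']
print(labels.tolist())           # [[1, 1], [0, 0], [0, 0]]
print(is_junk_food_ad.tolist())  # [True, False, False]
```

Using a `{name: column}` dictionary instead of `list.index` keeps each lookup constant-time, although for seven classes the difference is negligible.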

In [None]:
import json
import os
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np


def load_coco_annotations(json_path: str) -> Dict:
    with open(json_path, 'r') as f:
        return json.load(f)


def process_dataset_part(
    part_dir: str,
    annotations_filename: str = "_annotations.coco.json"
) -> Tuple[List[str], np.ndarray, List[str]]:
    annotations_path = os.path.join(part_dir, annotations_filename)

    if not os.path.exists(annotations_path):
        raise FileNotFoundError(f"Annotations file not found: {annotations_path}")

    # Load annotations
    coco_data = load_coco_annotations(annotations_path)

    # Create category mapping, excluding "junk-food"
    category_id_to_name = {
        cat['id']: cat['name']
        for cat in coco_data['categories']
        if cat['name'] != 'junk-food'
    }
    all_category_names = sorted(set(category_id_to_name.values()))

    # Create a mapping of image_id to set of category names
    image_to_categories = {}
    for annotation in coco_data['annotations']:
        image_id = annotation['image_id']
        category_id = annotation['category_id']

        # Skip if category is not in our filtered mapping
        if category_id not in category_id_to_name:
            continue

        category_name = category_id_to_name[category_id]

        if image_id not in image_to_categories:
            image_to_categories[image_id] = set()
        image_to_categories[image_id].add(category_name)

    # Process images in order
    image_paths = []
    labels_list = []

    for image_info in coco_data['images']:
        image_id = image_info['id']
        file_name = image_info['file_name']

        image_path = os.path.join(part_dir, file_name)
        image_paths.append(image_path)

        # Create multi-hot encoded label vector
        label_vector = np.zeros(len(all_category_names), dtype=int)
        if image_id in image_to_categories:
            for category_name in image_to_categories[image_id]:
                idx = all_category_names.index(category_name)
                label_vector[idx] = 1

        labels_list.append(label_vector)

    labels_array = np.array(labels_list)
    return image_paths, labels_array, all_category_names


def process_full_dataset(
    dataset_root: str,
    parts: List[str] = ['train', 'valid', 'test']
) -> Tuple[Dict[str, List[str]], Dict[str, np.ndarray], List[str]]:

    all_image_paths = {}
    all_labels = {}
    classes = None

    for part in parts:
        part_dir = os.path.join(dataset_root, part)

        if not os.path.exists(part_dir):
            print(f"Warning: Directory not found: {part_dir}. Skipping...")
            continue

        image_paths, labels, part_classes = process_dataset_part(part_dir)

        # Ensure all parts have the same classes
        if classes is None:
            classes = part_classes
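        elif classes != part_classes:
            # Label columns follow the sorted class list, so a mismatch would misalign labels
            raise ValueError(f"Classes in {part} ({part_classes}) differ from the first part ({classes})")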

        all_image_paths[part] = image_paths
        all_labels[part] = labels

    return all_image_paths, all_labels, classes


image_paths_dict, labels_dict, classes = process_full_dataset(dataset.location)

print("\n" + "="*50)
print("DATASET SUMMARY (Multi-label)")
print("="*50)
print(f"\nClasses ({len(classes)}): {classes}")
print(f"\nDataset parts processed:")

for part in image_paths_dict.keys():
    print(f"\n{part.upper()}:")
    print(f"  Total images: {len(image_paths_dict[part])}")
    print(f"  Label matrix shape: {labels_dict[part].shape}")
    print(f"  Label distribution:")
    for i, cls in enumerate(classes):
        count = labels_dict[part][:, i].sum()
        percentage = (count / len(labels_dict[part]) * 100) if len(labels_dict[part]) > 0 else 0
        print(f"    - {cls}: {count} ({percentage:.1f}%)")

    # Multi-label statistics
    labels_per_image = labels_dict[part].sum(axis=1)
    print(f"  Labels per image:")
    print(f"    - Mean: {labels_per_image.mean():.2f}")
    print(f"    - Min: {labels_per_image.min()}")
    print(f"    - Max: {labels_per_image.max()}")
    print(f"    - Images with 0 labels: {(labels_per_image == 0).sum()}")


DATASET SUMMARY (Multi-label)

Classes (7): ['french_fries', 'fried_chicken', 'hamburger', 'ice_cream', 'junk_food_logo', 'pizza', 'soda']

Dataset parts processed:

TRAIN:
  Total images: 4614
  Label matrix shape: (4614, 7)
  Label distribution:
    - french_fries: 388 (8.4%)
    - fried_chicken: 315 (6.8%)
    - hamburger: 379 (8.2%)
    - ice_cream: 468 (10.1%)
    - junk_food_logo: 1863 (40.4%)
    - pizza: 411 (8.9%)
    - soda: 603 (13.1%)
  Labels per image:
    - Mean: 0.96
    - Min: 0
    - Max: 5
    - Images with 0 labels: 1689

VALID:
  Total images: 440
  Label matrix shape: (440, 7)
  Label distribution:
    - french_fries: 41 (9.3%)
    - fried_chicken: 36 (8.2%)
    - hamburger: 34 (7.7%)
    - ice_cream: 42 (9.5%)
    - junk_food_logo: 180 (40.9%)
    - pizza: 40 (9.1%)
    - soda: 67 (15.2%)
  Labels per image:
    - Mean: 1.00
    - Min: 0
    - Max: 4
    - Images with 0 labels: 163

TEST:
  Total images: 218
  Label matrix shape: (218, 7)
  Label distribution:
 

## Feature extraction of the train set using pretrained models

KNN cannot extract visual features by itself; it only compares numeric vectors.  
Therefore, we use **pretrained** models (with their classification heads removed) to extract image embeddings for the train set.

These embeddings (feature vectors) place each image in a high-dimensional space where visually similar images lie close together.  
The extracted features are stored as a NumPy matrix of shape `(n_images, feature_dim)` and later fed into the KNN classifier.
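To make "compares numeric vectors" concrete, here is a minimal NumPy sketch of the lookup KNN performs on such a feature matrix, using random vectors as stand-ins for real embeddings. It also shows an optional variant that this notebook does not use: after L2-normalizing the rows, ranking by Euclidean distance gives the same neighbors as ranking by cosine similarity, because `||a - b||^2 = 2 - 2 * cos(a, b)` for unit vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
train_feats = rng.normal(size=(100, 768))  # stand-in for a (n_train, feature_dim) matrix
query = rng.normal(size=(768,))            # stand-in for one image embedding
k = 5


def k_nearest(train, q, k):
    """Indices of the k rows of `train` closest to `q` in Euclidean distance."""
    dists = np.linalg.norm(train - q, axis=1)
    return np.argsort(dists)[:k]


def l2_normalize(x):
    """Scale each row (or a single vector) to unit length."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


train_unit, query_unit = l2_normalize(train_feats), l2_normalize(query)

euclid_idx = k_nearest(train_unit, query_unit, k)
cosine_idx = np.argsort(-(train_unit @ query_unit))[:k]  # highest cosine similarity first
print(euclid_idx, cosine_idx)  # the two rankings agree on unit vectors
```

The notebook passes the raw backbone outputs to `KNeighborsClassifier`, whose magnitudes differ between models; normalizing first is one possible experiment, not something the results below reflect.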

In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
import numpy as np

def extract_features(image_paths, model, feature_dim, transform, model_name):
    features = []
    with torch.no_grad():
        for path in tqdm(image_paths, desc=f"Extracting features - {model_name}"):
            try:
                img = Image.open(path).convert("RGB")
                tensor = transform(img).unsqueeze(0).to(device)
                feat = model(tensor).squeeze().cpu().numpy()
                features.append(feat)
            except Exception as e:
                print(f"Error with {path}: {e}")
                features.append(np.zeros(feature_dim))
    return np.array(features)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_configs = [
    {
        'name': 'ResNeXt-101',
        'loader': lambda: models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT),
        'modifier': lambda m: torch.nn.Sequential(*list(m.children())[:-1]),
        'transform': models.ResNeXt101_32X8D_Weights.DEFAULT.transforms(),  # preprocessing bundled with the ResNeXt weights
        'feature_dim': 2048
    },
    {
        'name': 'EfficientNet V2',
        'loader': lambda: models.efficientnet_v2_m(weights=models.EfficientNet_V2_M_Weights.DEFAULT),
        'modifier': lambda m: torch.nn.Sequential(*list(m.children())[:-1]),
        'transform': models.EfficientNet_V2_M_Weights.DEFAULT.transforms(),
        'feature_dim': 1280
    },
    {
        'name': 'ConvNeXt',
        'loader': lambda: models.convnext_base(weights=models.ConvNeXt_Base_Weights.DEFAULT),
        'modifier': lambda m: torch.nn.Sequential(*list(m.children())[:-1]),
        'transform': models.ConvNeXt_Base_Weights.DEFAULT.transforms(),
        'feature_dim': 1024
    },
    {
        'name': 'ViT',
        'loader': lambda: models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT),
        'modifier': lambda m: (setattr(m.heads, 'head', torch.nn.Identity()), m)[1],
        'transform': models.ViT_B_16_Weights.DEFAULT.transforms(),
        'feature_dim': 768
    },
    {
        'name': 'Swin Transformer',
        'loader': lambda: models.swin_v2_b(weights=models.Swin_V2_B_Weights.DEFAULT),
        'modifier': lambda m: torch.nn.Sequential(*list(m.children())[:-1]),
        'transform': models.Swin_V2_B_Weights.DEFAULT.transforms(),  # must match the swin_v2_b weights loaded above
        'feature_dim': 1024
    },
    {
        'name': 'DINOv2',
        'loader': lambda: torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14'),
        'modifier': lambda m: m,  # No modification needed
        'transform': transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]),
        'feature_dim': 768
    }
]

loaded_models = []
for config in model_configs:
    model = config['loader']()
    model = config['modifier'](model)
    model = model.to(device)
    model.eval()
    loaded_models.append({
        'model': model,
        'name': config['name'],
        'transform': config['transform'],
        'feature_dim': config['feature_dim']
    })

# Extract features for all models on train set
all_features = {}
for model_info in loaded_models:
    features = extract_features(
        image_paths_dict['train'],
        model_info['model'],
        model_info['feature_dim'],
        model_info['transform'],
        model_info['name']
    )
    all_features[model_info['name']] = features
    print(f"Feature matrix shape - {model_info['name']}: {features.shape}")

Using device: cuda
Downloading: "https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth" to /root/.cache/torch/hub/checkpoints/resnext101_32x8d-110c445d.pth


100%|██████████| 340M/340M [00:06<00:00, 53.4MB/s]


Downloading: "https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_m-dc08266a.pth


100%|██████████| 208M/208M [00:01<00:00, 204MB/s]


Downloading: "https://download.pytorch.org/models/convnext_base-6075fbad.pth" to /root/.cache/torch/hub/checkpoints/convnext_base-6075fbad.pth


100%|██████████| 338M/338M [00:02<00:00, 133MB/s]


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:01<00:00, 201MB/s]


Downloading: "https://download.pytorch.org/models/swin_v2_b-781e5279.pth" to /root/.cache/torch/hub/checkpoints/swin_v2_b-781e5279.pth


100%|██████████| 336M/336M [00:01<00:00, 185MB/s]


Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip


xFormers is not available (SwiGLU)
xFormers is not available (Attention)
xFormers is not available (Block)


Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth


100%|██████████| 330M/330M [00:00<00:00, 428MB/s]
Extracting features - ResNeXt-101: 100%|██████████| 4614/4614 [02:05<00:00, 36.69it/s]


Feature matrix shape - ResNeXt-101: (4614, 2048)


Extracting features - EfficientNet V2: 100%|██████████| 4614/4614 [03:00<00:00, 25.54it/s]


Feature matrix shape - EfficientNet V2: (4614, 1280)


Extracting features - ConvNeXt: 100%|██████████| 4614/4614 [01:31<00:00, 50.16it/s]


Feature matrix shape - ConvNeXt: (4614, 1024)


Extracting features - ViT: 100%|██████████| 4614/4614 [01:25<00:00, 53.94it/s]


Feature matrix shape - ViT: (4614, 768)


Extracting features - Swin Transformer: 100%|██████████| 4614/4614 [03:21<00:00, 22.95it/s]


Feature matrix shape - Swin Transformer: (4614, 1024)


Extracting features - DINOv2: 100%|██████████| 4614/4614 [02:02<00:00, 37.70it/s]

Feature matrix shape - DINOv2: (4614, 768)





## "Training" the KNN Classifiers

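KNN has no real training phase: "fitting" only stores the training embeddings and their labels. Predictions follow a simple distance-based rule:
- For each label, an image is assigned the majority vote of its *k* nearest training images in feature space (Euclidean distance by default). `MultiOutputClassifier` wraps one KNN per class to produce the full multi-hot prediction.
- We use `k=5` neighbors for this experiment.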
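The per-label vote can be written out in a few lines of NumPy. This is a sketch of what the `MultiOutputClassifier(KNeighborsClassifier(n_neighbors=5))` combination computes with uniform weights, not scikit-learn's actual implementation; the function name `knn_multilabel_predict` and the toy data are hypothetical. Since every per-class KNN sees the same features, they all find the same neighbors, so a single neighbor search followed by a column-wise vote is enough.

```python
import numpy as np


def knn_multilabel_predict(train_X, train_Y, test_X, k=5):
    """Hypothetical sketch: per-label majority vote over the k nearest training rows."""
    # Pairwise Euclidean distances, shape (n_test, n_train); only suitable for small arrays
    dists = np.linalg.norm(test_X[:, None, :] - train_X[None, :, :], axis=2)
    neighbors = np.argsort(dists, axis=1)[:, :k]  # (n_test, k) neighbor indices
    votes = train_Y[neighbors].sum(axis=1)         # (n_test, n_labels) positive votes
    return (votes * 2 > k).astype(int)             # strict majority; no ties when k is odd


# Toy data: two clusters in 2-D, label columns [pizza, soda]
train_X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1], [0.2, 0.0],
                    [5.0, 5.0], [5.1, 5.0], [5.0, 5.1], [5.1, 5.1], [5.2, 5.0]])
train_Y = np.array([[1, 0], [1, 0], [1, 1], [1, 0], [1, 1],
                    [0, 1], [0, 1], [0, 0], [0, 1], [0, 1]])
test_X = np.array([[0.05, 0.05], [5.05, 5.05]])

print(knn_multilabel_predict(train_X, train_Y, test_X, k=5))
```

The first query sits in the left cluster, where all 5 neighbors have pizza but only 2 have soda, so it is predicted `[1, 0]`; the second gets `[0, 1]`. The full broadcast distance matrix would not fit in memory for the real embedding matrices, which is one reason to rely on scikit-learn there.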

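After fitting, we define an evaluation helper that reports three multi-label metrics: subset accuracy (the fraction of images whose entire label vector is predicted exactly), micro F1 (F1 over all image-label pairs pooled together), and macro F1 (the unweighted mean of the per-class F1 scores).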
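A small worked example makes the difference between the three metrics visible. The label matrices below are made up, and the formulas are written out by hand in NumPy as a sketch of what `accuracy_score` and `f1_score(..., average='micro')` / `average='macro'` compute on multi-hot input. Every label here has at least one true or predicted positive, so `zero_division` does not come into play.

```python
import numpy as np

# Made-up ground truth and predictions: 4 images x 3 labels
y_true = np.array([[1, 0, 0],
                   [0, 1, 1],
                   [0, 0, 0],
                   [1, 1, 0]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 0],
                   [1, 0, 1]])

# Subset accuracy: a row counts only if every label matches
subset_acc = np.mean(np.all(y_true == y_pred, axis=1))

# Per-label true positives, false positives, false negatives
tp = np.sum((y_true == 1) & (y_pred == 1), axis=0)
fp = np.sum((y_true == 0) & (y_pred == 1), axis=0)
fn = np.sum((y_true == 1) & (y_pred == 0), axis=0)

# Micro F1: pool TP/FP/FN over all labels, then compute F1 = 2TP / (2TP + FP + FN)
micro_f1 = 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())

# Macro F1: compute F1 per label, then take the unweighted mean
per_label_f1 = 2 * tp / (2 * tp + fp + fn)
macro_f1 = per_label_f1.mean()

print(subset_acc, micro_f1, macro_f1)
```

Here subset accuracy is 0.5 (two of four rows match exactly), micro F1 is 6/9 ≈ 0.667, and macro F1 is 5/9 ≈ 0.556: the third label is never predicted correctly, and macro averaging weights it as heavily as the others. This is why macro F1 is sensitive to rare classes such as `fried_chicken`, while micro F1 is dominated by frequent ones such as `junk_food_logo`.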

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

RESULTS_PATH = os.path.join(HOME, "runs", "classify")
os.makedirs(RESULTS_PATH, exist_ok=True)

y_train = labels_dict['train']

# Train KNN classifiers with MultiOutputClassifier for all models
trained_models = {}
for model_name, features in all_features.items():
    base_knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)
    multi_label_knn = MultiOutputClassifier(base_knn, n_jobs=-1)
    multi_label_knn.fit(features, y_train)

    trained_models[model_name] = multi_label_knn
    print(f"Trained Multi-label KNN for {model_name}")


def evaluate_model(X, y, split_name, model, model_name):
    split_dir = os.path.join(RESULTS_PATH, split_name)
    os.makedirs(split_dir, exist_ok=True)
    y_pred = model.predict(X)

    # Subset Accuracy (exact match ratio)
    subset_accuracy = accuracy_score(y, y_pred)

    # Micro F1 (aggregate across all label-sample pairs)
    micro_f1 = f1_score(y, y_pred, average='micro', zero_division=0)

    # Macro F1 (average F1 across labels)
    macro_f1 = f1_score(y, y_pred, average='macro', zero_division=0)

    return {
        'subset_accuracy': subset_accuracy,
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'y_pred': y_pred
    }

Trained Multi-label KNN for ResNeXt-101
Trained Multi-label KNN for EfficientNet V2
Trained Multi-label KNN for ConvNeXt
Trained Multi-label KNN for ViT
Trained Multi-label KNN for Swin Transformer
Trained Multi-label KNN for DINOv2


## Predictions on valid set

In [None]:
# Extract features for validation set
all_valid_features = {}
for model_info in loaded_models:
    features = extract_features(
        image_paths_dict['valid'],
        model_info['model'],
        model_info['feature_dim'],
        model_info['transform'],
        model_info['name']
    )
    all_valid_features[model_info['name']] = features

y_valid = labels_dict['valid']

# Evaluate all models on valid set
for model_name, model in trained_models.items():
    results = evaluate_model(all_valid_features[model_name], y_valid, "valid", model, model_name)
    print(f"{model_name} - Valid Set:")
    print(f"  Subset Accuracy: {results['subset_accuracy']:.4f}")
    print(f"  Micro F1:        {results['micro_f1']:.4f}")
    print(f"  Macro F1:        {results['macro_f1']:.4f}")
    print()

Extracting features - ResNeXt-101: 100%|██████████| 440/440 [00:11<00:00, 37.36it/s]
Extracting features - EfficientNet V2: 100%|██████████| 440/440 [00:17<00:00, 25.45it/s]
Extracting features - ConvNeXt: 100%|██████████| 440/440 [00:08<00:00, 54.47it/s]
Extracting features - ViT: 100%|██████████| 440/440 [00:08<00:00, 53.98it/s]
Extracting features - Swin Transformer: 100%|██████████| 440/440 [00:19<00:00, 22.86it/s]
Extracting features - DINOv2: 100%|██████████| 440/440 [00:11<00:00, 38.20it/s]


ResNeXt-101 - Valid Set:
  Subset Accuracy: 0.7409
  Micro F1:        0.8153
  Macro F1:        0.7816

EfficientNet V2 - Valid Set:
  Subset Accuracy: 0.7455
  Micro F1:        0.8142
  Macro F1:        0.7957

ConvNeXt - Valid Set:
  Subset Accuracy: 0.7455
  Micro F1:        0.8089
  Macro F1:        0.7850

ViT - Valid Set:
  Subset Accuracy: 0.7591
  Micro F1:        0.8289
  Macro F1:        0.8083

Swin Transformer - Valid Set:
  Subset Accuracy: 0.7500
  Micro F1:        0.8310
  Macro F1:        0.8187

DINOv2 - Valid Set:
  Subset Accuracy: 0.7886
  Micro F1:        0.8530
  Macro F1:        0.8581



## Metrics on test set

In [None]:
all_test_features = {}
for model_info in loaded_models:
    features = extract_features(
        image_paths_dict['test'],
        model_info['model'],
        model_info['feature_dim'],
        model_info['transform'],
        model_info['name']
    )
    all_test_features[model_info['name']] = features

y_test = labels_dict['test']

# Evaluate all models on test set
for model_name, model in trained_models.items():
    results = evaluate_model(all_test_features[model_name], y_test, "test", model, model_name)
    print(f"{model_name} - Test Set:")
    print(f"  Subset Accuracy: {results['subset_accuracy']:.4f}")
    print(f"  Micro F1:        {results['micro_f1']:.4f}")
    print(f"  Macro F1:        {results['macro_f1']:.4f}")
    print()

Extracting features - ResNeXt-101: 100%|██████████| 218/218 [00:06<00:00, 35.64it/s]
Extracting features - EfficientNet V2: 100%|██████████| 218/218 [00:07<00:00, 27.76it/s]
Extracting features - ConvNeXt: 100%|██████████| 218/218 [00:04<00:00, 47.89it/s]
Extracting features - ViT: 100%|██████████| 218/218 [00:03<00:00, 54.97it/s]
Extracting features - Swin Transformer: 100%|██████████| 218/218 [00:09<00:00, 22.88it/s]
Extracting features - DINOv2: 100%|██████████| 218/218 [00:05<00:00, 39.16it/s]


ResNeXt-101 - Test Set:
  Subset Accuracy: 0.7202
  Micro F1:        0.7928
  Macro F1:        0.8051

EfficientNet V2 - Test Set:
  Subset Accuracy: 0.7156
  Micro F1:        0.7959
  Macro F1:        0.8047

ConvNeXt - Test Set:
  Subset Accuracy: 0.7385
  Micro F1:        0.8117
  Macro F1:        0.8319

ViT - Test Set:
  Subset Accuracy: 0.7706
  Micro F1:        0.8293
  Macro F1:        0.8343

Swin Transformer - Test Set:
  Subset Accuracy: 0.7294
  Micro F1:        0.8128
  Macro F1:        0.8178

DINOv2 - Test Set:
  Subset Accuracy: 0.7936
  Micro F1:        0.8722
  Macro F1:        0.8915

