In [5]:
import os
from pathlib import Path
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.metrics import f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt


In [6]:
class Food101Dataset(Dataset):
    """Food-101 image-classification dataset driven by the ``meta/`` split files.

    Expected layout under ``root``::

        root/meta/{split}.txt     one ``<class>/<image_id>`` entry per line
        root/meta/classes.txt     (optional) one class name per line
        root/images/<class>/<image_id>.jpg

    Args:
        root: Dataset root directory (str or Path).
        split: Split file name in ``meta/`` without the ``.txt`` extension
            (e.g. ``"train"`` or ``"test"``).
        transform: Optional callable applied to each loaded PIL image
            (e.g. a torchvision transform pipeline).

    Attributes:
        class_names: Sorted list of class names; position ``i`` is label ``i``.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(image_path, label)`` tuples, one per split entry.

    Raises:
        FileNotFoundError: If ``meta/{split}.txt`` does not exist.
        ValueError: If a split entry names a class absent from ``meta/classes.txt``.
    """

    def __init__(self, root, split="train", transform=None):
        self.root = Path(root)
        self.transform = transform

        # Load the text file containing image paths (relative, without extension)
        txt_file = self.root / "meta" / f"{split}.txt"
        if not txt_file.exists():
            raise FileNotFoundError(f"{txt_file} not found. Check your dataset path!")

        # Drop blank lines (e.g. a trailing empty line) so they don't turn into
        # a bogus "" class that would shift every other label index.
        with open(txt_file, "r", encoding="utf-8") as f:
            lines = [entry for entry in (raw.strip() for raw in f) if entry]

        # Prefer the dataset-wide class list so every split shares the same
        # label indices even if a split omits a class; otherwise fall back to
        # the classes present in this split. Sorting keeps the mapping stable.
        classes_file = self.root / "meta" / "classes.txt"
        if classes_file.exists():
            with open(classes_file, "r", encoding="utf-8") as f:
                class_names = sorted({name.strip() for name in f if name.strip()})
        else:
            class_names = sorted({line.split("/")[0] for line in lines})
        self.class_names = class_names
        self.class_to_idx = {c: i for i, c in enumerate(class_names)}

        # Prepare sample list = (image_path, label)
        self.samples = []
        for line in lines:
            cls = line.split("/")[0]
            if cls not in self.class_to_idx:
                # Only reachable when classes.txt exists but is missing this class.
                raise ValueError(
                    f"Class {cls!r} from {txt_file} is not listed in {classes_file}"
                )
            self.samples.append(
                (self.root / "images" / f"{line}.jpg", self.class_to_idx[cls])
            )

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB and passed through ``transform`` when one was given.
        """
        img_path, label = self.samples[idx]

        # The context manager closes the file handle deterministically;
        # convert() returns an independent, fully loaded image, so it stays
        # usable after the source file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label
