In [None]:
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
import torch

class FoodDataset(Dataset):
    """Image dataset backed by a CSV file.

    The CSV is read with pandas and must provide the columns ``Image``
    (image path, joined onto ``img_root_dir``), ``class`` (class name) and
    ``mos`` (numeric value). Each item is returned as
    ``(image, label_class, label_mos)``.
    """

    def __init__(self, csv_file, img_root_dir, label_encoder=None, transform=None):
        """
        csv_file: path to train.csv
        img_root_dir: root path to image folder (not including relative path)
        label_encoder: sklearn LabelEncoder instance (optional); when given it
            must already be fitted, because only ``transform`` is called on it.
            When omitted, a new LabelEncoder is fitted on this CSV's ``class``
            column and exposed as ``self.le``.
        transform: torchvision transforms (any callable taking the PIL image)
        """
        self.df = pd.read_csv(csv_file).reset_index(drop=True)
        self.img_root_dir = img_root_dir
        self.transform = transform

        # Convert class names to integer ids.
        if label_encoder is None:
            # Imported lazily so sklearn is only needed when no encoder is supplied.
            from sklearn.preprocessing import LabelEncoder
            self.le = LabelEncoder()
            self.df['class_encoded'] = self.le.fit_transform(self.df['class'])
        else:
            self.le = label_encoder
            self.df['class_encoded'] = self.le.transform(self.df['class'])

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample.

        idx: integer row position; a single-element tensor is also accepted.

        Returns a tuple ``(image, label_class, label_mos)`` where ``image`` is
        the RGB image (after ``transform`` if one was given), ``label_class``
        is a ``torch.long`` tensor and ``label_mos`` a ``torch.float32`` tensor.
        """
        # A tensor index is unwrapped to a plain Python scalar before it
        # reaches pandas positional indexing.
        if torch.is_tensor(idx):
            idx = idx.item()

        row = self.df.iloc[idx]

        img_path = os.path.join(self.img_root_dir, row['Image'])  # img_root + relative path
        # The context manager closes the source image deterministically, also
        # for multi-frame files and when decoding fails. convert() hands back
        # a separate image object, which is what we keep.
        with Image.open(img_path) as source_image:
            image = source_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label_class = torch.tensor(row['class_encoded'], dtype=torch.long)
        label_mos = torch.tensor(row['mos'], dtype=torch.float32)

        return image, label_class, label_mos

In [None]:
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader

# Input locations: the labels CSV and the image root directory.
CSV_PATH = "train.csv"
IMG_ROOT = "images"  # root dir that contains e.g. train/img1.jpg

# Image preprocessing: resize to 224x224, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Build the dataset (it fits its own label encoder) and a shuffling loader.
dataset = FoodDataset(
    csv_file=CSV_PATH,
    img_root_dir=IMG_ROOT,
    transform=transform,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# Smoke test: fetch the first sample and show what came back.
img, label_class, label_mos = dataset[0]
print("Image shape:", img.shape)
print("Class label (int):", label_class)
print("MOS:", label_mos)
# Map the integer label back to its class name through the fitted encoder.
print(
    "Class name:",
    dataset.le.inverse_transform([label_class.item()])[0],
)