In [1]:
import os
import torch
import torchvision.transforms as transforms
import pandas as pd
import numpy as np

from torch.utils.data import Dataset, DataLoader
from PIL import Image

In [2]:
# Root folder of the dataset; it is joined with each class name below to
# locate that class's image directory.
food_dir = './products_dataset'

# Produce names grouped by kind. Every item appears once as a "Fresh" class
# and once as a "Rotten" class.
_FRUITS = ('Apple', 'Banana', 'Mango', 'Orange', 'Strawberry')
_VEGETABLES = ('Bellpepper', 'Carrot', 'Cucumber', 'Potato', 'Tomato')

# Class names in label order: fresh fruits, rotten fruits, fresh vegetables,
# rotten vegetables. The position of a name in this list is used as its
# integer label, so the ordering must not change.
FOOD = [
    condition + produce
    for group in (_FRUITS, _VEGETABLES)
    for condition in ('Fresh', 'Rotten')
    for produce in group
]

In [3]:
class LabeledDataset(Dataset):
    """Image-classification dataset backed by one sub-directory per class.

    The expected layout is ``<food_dir>/<class name>/<image file>``. Only the
    file paths and integer labels are collected at construction time; images
    are opened lazily in ``__getitem__``.

    Args:
        food_dir: Root directory containing one sub-directory per class.
        food_classes: Sequence of class (sub-directory) names. The position
            of a name in this sequence is the integer label of its images.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.

    Raises:
        OSError: Propagated from ``os.listdir`` when a class directory cannot
            be listed (for example, it does not exist).
    """

    def __init__(self, food_dir, food_classes, transform=None):
        self.food_dir = food_dir
        self.food_classes = food_classes
        self.transform = transform
        self.images_paths = []
        self.labels = []

        # enumerate() yields the label directly instead of re-scanning the
        # class list with .index() for every single image.
        for label, cls_name in enumerate(food_classes):
            class_path = os.path.join(food_dir, cls_name)

            # os.listdir() returns entries in arbitrary order; sorting keeps
            # sample indices stable across runs and machines.
            for image_name in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_name)

                # Skip hidden files (e.g. .DS_Store) and sub-directories,
                # which would otherwise only fail later when opened as images.
                if image_name.startswith('.') or not os.path.isfile(image_path):
                    continue

                self.images_paths.append(image_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.images_paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, label)``.

        The image is converted to RGB and, when a transform was supplied,
        passed through it. ``label`` is the integer class index.
        """
        # The context manager closes the underlying file as soon as the
        # pixel data has been converted, rather than waiting for garbage
        # collection; this matters when iterating over many images.
        with Image.open(self.images_paths[index]) as source_image:
            image = source_image.convert("RGB")
        label = self.labels[index]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [4]:
# Preprocessing applied to every image: resize to a fixed 224x224 and then
# convert the PIL image to a tensor.
IMAGE_SIZE = (224, 224)

data_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

In [5]:
# Index every image under food_dir; images themselves are loaded on access.
food_dataset = LabeledDataset(
    food_dir=food_dir,
    food_classes=FOOD,
    transform=data_transforms,
)

In [6]:
# 80/20 train/test split sizes. The training size is truncated to an integer
# and the test set takes the remainder, so the two always sum to the total.
TRAIN_FRACTION = 0.8
total_samples = len(food_dataset)

trainset_len = int(TRAIN_FRACTION * total_samples)
testset_len = total_samples - trainset_len

In [7]:
# Seed the split explicitly so the same samples land in train/test on every
# run; an unseeded random_split produces a different partition each time,
# which makes results impossible to compare across runs.
SPLIT_SEED = 42

train_dataset, test_dataset = torch.utils.data.random_split(
    food_dataset,
    [trainset_len, testset_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [8]:
BATCH_SIZE = 32

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps batches
# (and any per-sample predictions) repeatable between runs.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)