In [10]:
import os 
import cv2
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset

In [2]:
train_data_path = './dataset/dataset/train'
test_data_path = './dataset/dataset/test'

classes_train = os.listdir(train_data_path)
classes_test = os.listdir(train_data_path)

assert classes_train==classes_test

In [3]:
# train_images_path and test_images_path
def get_image_paths(data_path, classes_train):
    images_paths = []
    for class_ in classes_train:
        image_directory = os.path.join(data_path, class_)
        images_paths = [os.path.join(image_directory, image_name) for image_name in os.listdir(image_directory)]
    return images_paths

In [4]:
train_image_paths = get_image_paths(train_data_path, classes_train)
test_image_paths = get_image_paths(test_data_path, classes_train)

In [5]:
assert all(list(map(lambda filepath: os.path.exists(filepath), train_image_paths)))
assert all(list(map(lambda filepath: os.path.exists(filepath), test_image_paths)))

In [6]:
thresh = int(0.8*len(train_image_paths))
train_img_paths = train_image_paths[:thresh]
valid_img_paths = train_image_paths[thresh:] 

len(train_image_paths), len(train_img_paths), len(valid_img_paths), len(test_image_paths)

(623, 498, 125, 100)

In [8]:
image = cv2.imread(train_img_paths[0])#.shape
image = np.array(image, np.int32)
# image

image = torch.from_numpy(image)
type(image), image.dtype, image.permute(2, 0, 1).shape

(torch.Tensor, torch.int32, torch.Size([3, 480, 640]))

In [9]:
class_to_index = {class_name:index for index, class_name in enumerate(classes_train)}
class_to_index

{'Closed': 0, 'no_yawn': 1, 'Open': 2, 'yawn': 3}

In [11]:
# Create Custom Dataset
class ClassificationDataset(Dataset):
    def __init__(self, image_paths, transform=None) -> None:
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, index):
        image_filepath = self.image_paths[index]
        image = cv2.imread(image_filepath)
        image = np.array(image, np.float32)

        image = torch.from_numpy(image).permute(2, 0, 1)

        label = Path(image_filepath).parts[-2]
        label = class_to_index[label]

        return image, label