In [1]:
#source: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

In [7]:
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

In [8]:
#NOTE: __init__, __len__ and __getitem__ are Python's built-in special ("dunder") methods; a custom PyTorch Dataset overrides __len__ and __getitem__ so it can be sized and indexed (e.g. by a DataLoader)!
class CustomImageDataset(Dataset):
    """Map-style dataset pairing image files on disk with labels listed in a CSV file.

    Parameters
    ----------
    annotations_file : str or path-like
        CSV loaded with ``pd.read_csv``. Column 0 holds each image's file name
        (relative to ``img_dir``); column 1 holds its label.
        NOTE: ``pd.read_csv`` is called with its default header handling, so the
        first row is treated as column names -- a CSV without a header row would
        silently lose its first sample.
    img_dir : str or path-like
        Directory that the file names from ``annotations_file`` are joined onto.
    transform : callable, optional
        Applied to each loaded image when not None.
    target_transform : callable, optional
        Applied to each label when not None.
    target_Transform : callable, optional
        Legacy misspelled alias of ``target_transform``, kept so existing keyword
        callers keep working; ignored when ``target_transform`` is given.
    """

    #the constructor gathers everything needed to load the data:
    #the image folder, the annotations file (which stores the labels), and any optional transforms
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None, target_Transform=None):
        #annotations file is a two column file in the form of "{file name}, {label}"
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        #previously the parameter was spelled `target_Transform` while this line read
        #`target_transform`, raising NameError on every construction; accept both spellings
        self.target_transform = target_transform if target_transform is not None else target_Transform

    #the __len__ function simply returns the number of items (CSV rows) in the dataset
    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair described by row ``idx`` of the annotations file."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        #read_image comes from torchvision.io (presumably yields a uint8 image tensor -- confirm in torchvision docs)
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]

        #explicit None checks so only an omitted transform is skipped
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label