In [8]:
import torch
import os
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from PIL import Image

In [11]:
class CustomDatasetFromImages(Dataset):
    """Image dataset described by a header-less CSV file.

    The first CSV column holds image file names and the second column holds
    the matching labels. An image is only opened when its item is requested.
    """

    def __init__(self, csv_path, img_path="data/training/", transform=None,
                 verbose=True):
        """
        Args:
            csv_path (string): path to csv file (read with ``header=None``)
            img_path (string): path to the folder where images are; the
                default is the folder that used to be hard-coded
            transform (callable, optional): applied to every opened PIL
                image; ``transforms.ToTensor()`` is used when omitted
            verbose (bool): print the parsed csv table when True, which is
                the default so existing callers see the same output
        """
        # Folder that the file names from the csv are resolved against
        self.img_path = img_path
        # Transforms: the ToTensor default is only built when none is supplied
        self.transform = transform if transform is not None else transforms.ToTensor()
        # Previous attribute name, kept so code that reads it keeps working
        self.to_tensor = self.transform
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        if verbose:
            print(self.data_info)
        # First column contains the image file names
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def _image_path(self, index):
        """Return the path of the image stored at row ``index`` of the csv."""
        return os.path.join(self.img_path, self.image_arr[index])

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for row ``index`` of the csv.

        Raises:
            IndexError: if ``index`` is outside the rows read from the csv.
        """
        # Open the image; the context manager closes the file handle even if
        # decoding or the transform raises
        with Image.open(self._image_path(index)) as img_as_img:
            # Image.open is lazy, so decode the pixels while the file is open
            img_as_img.load()
            # Transform image (to a tensor by default)
            img_as_tensor = self.transform(img_as_img)

        # Get label(class) of the image from the second csv column
        single_image_label = self.label_arr[index]

        return (img_as_tensor, single_image_label)

    def __len__(self):
        """Return the number of rows read from the csv file."""
        return self.data_len

#if __name__ == "__main__":
    # Call dataset (the argument is the path to the label csv, not the image folder)
#    custom_mnist_from_images =  \
#        CustomDatasetFromImages('data/training/labels-tab-csv.csv')

In [12]:
# Build the training dataset from its label csv.
# TRAIN_PATH = 'data/training'
train_data = CustomDatasetFromImages('data/training/labels-tab-csv.csv')
# Batching is not enabled yet; the loader below is kept for later use.
# train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)

# Sanity check: fetch the first (image, label) pair and show it.
first_sample = train_data[0]
print(first_sample)

            0   1
0    0001.png   1
1    0002.png   2
2    0003.png   3
3    0004.png   4
4    0005.png   5
..        ...  ..
855  0856.png   6
856  0857.png   7
857  0858.png   8
858  0859.png   9
859  0860.png  10

[860 rows x 2 columns]
(tensor([[[0.9922, 0.9922, 0.9922,  ..., 0.9922, 0.9922, 0.9922],
         [0.9922, 0.9922, 0.9922,  ..., 0.9922, 0.9922, 0.9922],
         [0.9922, 0.9922, 0.9922,  ..., 0.9922, 0.9922, 0.9922],
         ...,
         [0.9922, 0.9922, 0.9922,  ..., 0.9922, 0.9922, 0.9922],
         [0.9922, 0.9922, 0.9922,  ..., 0.9922, 0.9922, 0.9922],
         [0.9922, 0.9922, 0.9922,  ..., 0.9922, 0.9922, 0.9922]]]), 1)
