# DataLoader Test

We need to write a DataLoader for our custom images (and specify each image's label as well!).

In [1]:
# General Imports
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms

In [2]:
# Custom Image Dataset & Pipeline

import os
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.io import read_image, ImageReadMode

# Seed torch's global RNG so the random split, loader shuffling, and random
# transforms used below are repeatable from a fresh kernel.
torch.manual_seed(42) #set global seed


# PlanetaryImages dataset (wrapped by the DataLoaders defined further down)
class PlanetaryImages(Dataset):
    """Map-style dataset of images listed in a CSV annotations file.

    The CSV is loaded with ``pd.read_csv`` defaults. For every row, column 0
    is the image file name (joined onto ``img_dir``) and column 1 is the label.

    Args:
        annotations_file: Path to the CSV of ``(file name, label)`` rows.
        img_dir: Directory that the file names are relative to.
        transform: Optional callable applied to the decoded image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Number of annotated images (one per CSV row)."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx``.

        The image is decoded with ``ImageReadMode.RGB``; the optional
        transforms are applied afterwards, image first, then label.
        """
        annotations = self.img_labels

        # Decode the file named in column 0, requesting RGB output.
        file_name = annotations.iloc[idx, 0]
        sample = read_image(os.path.join(self.img_dir, file_name), mode=ImageReadMode.RGB)

        # Column 1 carries the label for this image.
        target = annotations.iloc[idx, 1]

        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            target = self.target_transform(target)
        return sample, target



# Pipeline
input_size = 224  # GhostNet required input size
TEST_SIZE = 872   # samples held out for validation; everything else is used for training
BATCH_SIZE = 64   # shared by the train and validation loaders

# The dataset hands each transform a tensor from read_image, so convert to PIL
# first, apply the random crop/flip augmentation, then go back to a tensor.
transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Normalization with hard-coded mean/std values is currently disabled:
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# Build the custom dataset with the transforms applied
planetary_data = PlanetaryImages("../PlanetaryImages.csv", "../PlanetaryImages/", transform=transform)

# Train / test split.
# The train size is derived from the dataset length rather than hard-coded, so
# the split keeps working when images are added to or removed from the CSV
# (for the 7372-image set this is still 6500 / 872).
num_images = len(planetary_data)
if num_images <= TEST_SIZE:
    raise ValueError(
        f"Dataset has {num_images} images; need more than {TEST_SIZE} to hold out a test set."
    )
train_set, test_set = random_split(planetary_data, [num_images - TEST_SIZE, TEST_SIZE])

# NOTE(review): both subsets wrap the same dataset object, so the validation
# images also go through the random crop/flip above -- confirm this is intended.

# DataLoaders, plus a dict to look them up by phase.
# Only the training loader shuffles; a fixed validation order keeps evaluation
# runs comparable.
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

dataloaders = {'train': trainloader, 'val': testloader}



In [5]:
# Pull a single batch from the training loader as a quick sanity check.
train_batches = iter(dataloaders["train"])
s = next(train_batches)

# The first entry of the batch holds the images; display its shape.
s[0].shape



torch.Size([64, 3, 224, 224])