# Dataset loaders for PyTorch

In [1]:
import torch
import torchvision
from torchvision import transforms

# Image-folder dataset whose samples are converted to tensors on load
dataset = torchvision.datasets.ImageFolder(
    root='./data/hymenoptera_data/train',
    transform=transforms.ToTensor(),
)

# Loader that serves shuffled batches of 32 samples using 4 worker processes
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

Custom datasets

In [2]:
from torch.utils.data import Dataset
import pandas as pd

# Create a custom Dataset
class CustomImageDataset(Dataset):
    """Image dataset described by a CSV annotations file.

    The CSV is loaded with ``pd.read_csv``. Column 0 is used as the image file
    name (joined onto ``img_dir``) and column 1 as the label.

    Args:
        annotations_file: Path of the CSV file listing image names and labels.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the image tensor.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)   # table of image names and labels
        self.img_dir = img_dir                            # directory containing the images
        self.transform = transform                        # applied to the image, if given
        self.target_transform = target_transform          # applied to the label, if given

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the sample at row ``idx``.

        Returns:
            A tuple ``(image_name, image, label)`` where ``image`` is the decoded
            image tensor divided by 255, with single-channel images expanded to
            three channels, after the optional transforms have been applied.
        """
        # Imported here so the class works on a fresh kernel: `os` is otherwise
        # only imported by a later cell, and that same cell defines a
        # TensorFlow helper also named `read_image`, which would shadow a
        # module-level torchvision import.
        import os
        from torchvision.io import read_image

        img_name = self.img_labels.iloc[idx, 0]
        label = self.img_labels.iloc[idx, 1]

        img_path = os.path.join(self.img_dir, img_name)
        image = read_image(img_path) / 255

        # Convert grayscale to RGB by repeating the single channel three times
        if image.shape[0] == 1:
            image = torch.cat([image, image, image], dim=0)

        # Apply the optional transforms
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return img_name, image, label

In [3]:
# Preprocessing pipelines keyed by split name.
# 'train' applies a random resized crop and random horizontal flip; 'val' applies a
# resize followed by a center crop. Both end with the same per-channel normalization.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)


In [4]:
data_dir = './datasets/images'

# Training dataset built from the CSV split file, using the 'train' transform pipeline
data_train = CustomImageDataset(
    annotations_file="./datasets/TrainSplit.csv",
    img_dir=data_dir,
    transform=data_transforms['train'],
)
# Shuffled loader yielding one sample per batch
dataloader_train = torch.utils.data.DataLoader(
    dataset=data_train,
    batch_size=1,
    shuffle=True,
    drop_last=True,
)

In [5]:
# Sanity check: number of batches the training loader yields per pass
# (the loader above is built with batch_size=1, so this equals the number of samples)
len(dataloader_train)

3751

# Dataset loaders for TensorFlow

In [6]:
import tensorflow as tf

# Build a batched image dataset from the training folder
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    directory='./data/hymenoptera_data/train',
    seed=123,
    image_size=(224, 224),
    batch_size=32,
)

# Prefetch upcoming batches, letting tf.data choose the buffer size
dataloader = dataset.prefetch(tf.data.AUTOTUNE)

Found 245 files belonging to 2 classes.


Custom datasets

In [7]:
import os
# NOTE(review): this variable is presumably read when TensorFlow is first loaded.
# An earlier cell already runs `import tensorflow as tf`, so on a top-to-bottom run
# this assignment may come too late to take effect - consider setting it before
# the first TensorFlow import.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  #set the environment variable to suppress Tensorflow warning messages
import tensorflow as tf
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers
directory = './datasets/images'

# Load the split file and pull out the image-name and label columns
df = pd.read_csv('./datasets/TrainSplit.csv')
file_paths, labels = df['image'].values, df['category'].values

# Dataset of (image name, label) pairs
ds_train = tf.data.Dataset.from_tensor_slices((file_paths, labels))

def read_image(image_file,label):
    """Load and decode one image for the ``tf.data`` pipeline.

    Args:
        image_file: Image file name, resolved relative to the module-level
            ``directory``.
        label: Label associated with the image; passed through unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the file decoded to a
        3-channel ``tf.float32`` tensor.
    """
    # Join with an explicit separator: plain `directory + image_file` would
    # produce './datasets/imagesNAME' because `directory` has no trailing slash.
    image_path = tf.strings.join([directory, image_file], separator='/')
    raw_bytes = tf.io.read_file(image_path)
    # expand_animations=False makes decode_image return a 3-D tensor in all
    # cases (for GIFs, only the first frame), so later ops see a fixed rank.
    image = tf.image.decode_image(
        raw_bytes, channels=3, dtype=tf.float32, expand_animations=False
    )
    return image, label

def augment(image,label):
    """Placeholder augmentation step: hands back the (image, label) pair unchanged."""
    sample = (image, label)
    return sample

# Decode each file, run the augmentation hook, then emit one sample per batch
ds_train = (
    ds_train
    .map(read_image)
    .map(augment)
    .batch(1)
)
len(ds_train)

3751