In [1]:
import os
from custom_lib.data_prep import data_transformation_pipeline, data_loader
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split



In [9]:
data_dir = "~/Documents/data/"
data_folder = "kaggle_tb"

rotate_angle = 5
center_crop = 224
image_size = 224
normalize = True
train_prop = .8
batch_size = 32

In [3]:
train_transform = data_transformation_pipeline(image_size=image_size, center_crop=center_crop, rotate_angle=5,
                                               normalize=True, is_train=True)
val_transform = data_transformation_pipeline(image_size=image_size, center_crop=center_crop, is_train=False, normalize=True)

In [4]:
data_path = f"{data_dir}{data_folder}"




In [7]:
def data_loader(data_path, train_transform, val_transform, train_prop, batch_size):
    
    # Load full dataset without transform (we'll assign it later)
    full_dataset = datasets.ImageFolder(root=data_path)
    
    # Create indices
    num_samples = len(full_dataset)
    indices = torch.arange(num_samples)

    # Extract class labels for stratification
    class_labels = [full_dataset.targets[i] for i in indices]

    # Split indices into train, validation, and test sets (stratified)
    train_indices, temp_indices = train_test_split(indices, train_size=train_prop, random_state=42, stratify=class_labels)
    val_indices, test_indices = train_test_split(temp_indices, train_size=0.5, random_state=42, stratify=[class_labels[i] for i in temp_indices])

    # Create separate datasets with appropriate transformations
    train_dataset = datasets.ImageFolder(root=data_path, transform=train_transform)
    val_dataset = datasets.ImageFolder(root=data_path, transform=val_transform)
    test_dataset = datasets.ImageFolder(root=data_path, transform=val_transform)

    # Create DataLoaders using `SubsetRandomSampler`
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
    val_loader = DataLoader(val_dataset, batch_size=batch_size * 2, sampler=SubsetRandomSampler(val_indices))
    test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, sampler=SubsetRandomSampler(test_indices))

    print(f"Train size {len(train_indices)}. Val size {len(val_indices)}. Test size {len(test_indices)}.")

    return train_loader, val_loader, test_loader

In [10]:
data_loader(data_path = data_path, train_transform = train_transform, val_transform=val_transform, 
             train_prop=train_prop, batch_size=batch_size)

Train size 3360. Val size 420. Test size 420.


(<torch.utils.data.dataloader.DataLoader at 0x15a70af10>,
 <torch.utils.data.dataloader.DataLoader at 0x15a61e410>,
 <torch.utils.data.dataloader.DataLoader at 0x15a61e4d0>)