In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import numpy as np

In [2]:
# Notebook-wide settings: where the spectrogram images live, the input
# geometry, and the parameters that control batching and splitting.

# Root folder of the spectrogram images
DATA_DIR = '../data/processed/fan'

# Fraction of the data held out for validation, and the seed that makes
# the split repeatable
VALIDATION_SPLIT = 0.2
RANDOM_SEED = 42

# Target (height, width) -- ResNet50 was trained on 224x224 images
IMAGE_SIZE = (224, 224)
# Number of samples per batch
BATCH_SIZE = 32

In [3]:
# Preprocessing pipeline that turns each spectrogram image into the kind of
# input ResNet50 expects. The steps are listed in the order they are applied.
preprocessing_steps = [
    # 1) Bring every image to the resolution expected by ResNet50.
    transforms.Resize(IMAGE_SIZE),
    # 2) Turn the image into a PyTorch tensor. This also reorders the axes
    #    from (H x W x C) to (C x H x W) and rescales pixel values to [0, 1].
    transforms.ToTensor(),
    # 3) Standardize each channel with the mean/std that are conventionally
    #    used for models pre-trained on ImageNet.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

data_transforms = transforms.Compose(preprocessing_steps)

In [4]:
# Build the dataset straight from the directory tree.
# ImageFolder derives the class labels from the sub-folder names
# ('abnormal', 'normal') and runs the preprocessing pipeline defined above
# on every image it loads.
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=data_transforms)

# Report the classes that were discovered and the integer index assigned
# to each of them, e.g. {'abnormal': 0, 'normal': 1}.
print(f"Classes found: {full_dataset.classes}")
print(f"Class to index mapping: {full_dataset.class_to_idx}")

Classes found: ['abnormal', 'normal']
Class to index mapping: {'abnormal': 0, 'normal': 1}


In [5]:
# Stratified train/validation split: both subsets keep the same class
# proportions as the full dataset.

# One class label per image, as recorded by the dataset
labels = full_dataset.targets
# Positions 0..N-1, one for each of the N images
indices = list(range(len(labels)))

# Options for the split:
#   test_size    -> share of the data that goes to the validation set
#   stratify     -> keep the split proportional to the class labels
#   random_state -> fixed seed so the split is repeatable
split_options = {
    "test_size": VALIDATION_SPLIT,
    "stratify": labels,
    "random_state": RANDOM_SEED,
}
train_indices, val_indices = train_test_split(indices, **split_options)

# Wrap each list of generated indices in a PyTorch Subset of the full dataset
train_dataset, validation_dataset = (
    Subset(full_dataset, subset_indices)
    for subset_indices in (train_indices, val_indices)
)

print(f"Total images: {len(full_dataset)}")
print(f"Number of training images: {len(train_dataset)}")
print(f"Number of validation images: {len(validation_dataset)}")

Total images: 1440
Number of training images: 1152
Number of validation images: 288


In [6]:
# DataLoaders wrap the datasets and provide an easy way to iterate over
# data in batches, with options for shuffling and parallel data loading.

# Dedicated, seeded generator for the training shuffle. Without it the batch
# order depends on the global RNG state, so it changes between kernel
# restarts even though the train/validation split itself is seeded.
# Re-running this cell recreates the generator and restarts the same sequence.
shuffle_generator = torch.Generator().manual_seed(RANDOM_SEED)

# The training loader shuffles data to ensure the model doesn't learn the
# order of the data.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=shuffle_generator,
)

# No need to shuffle the validation data
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [7]:
# Sanity check: pull a single batch from the training loader and report the
# shapes of the tensors the model will receive.
first_batch = next(iter(train_loader))
train_features, train_labels = first_batch

print(f"\nFeature batch shape: {train_features.size()}")
print(f"\nLabels batch shape: {train_labels.size()}")


Feature batch shape: torch.Size([32, 3, 224, 224])

Labels batch shape: torch.Size([32])


In [8]:
import sys

# Make the project's source directory importable. The membership check keeps
# the cell idempotent: re-running it does not append duplicate entries to
# sys.path.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Must come after the sys.path setup above, since data_setup lives in SRC_DIR.
from data_setup import create_dataloaders

# Configuration (DATA_DIR, IMAGE_SIZE, BATCH_SIZE, VALIDATION_SPLIT,
# RANDOM_SEED) is taken from the configuration cell at the top of the
# notebook. It is deliberately not re-declared here, so that editing the
# configuration cell cannot leave this cell running with stale values.

# Create the dataloaders through the reusable helper.
# NOTE(review): create_dataloaders is presumably the packaged version of the
# manual steps in the cells above -- confirm against src/data_setup.py.
# This rebinds train_loader and validation_loader to the helper's results.
train_loader, validation_loader, class_to_idx = create_dataloaders(
    data_dir=DATA_DIR,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    random_seed=RANDOM_SEED
)

print("\nDataLoaders are ready.")
print(f"Classes: {class_to_idx}")

[INFO] Total images: 1440
[INFO] Training images: 1152
[INFO] Validation images: 288

DataLoaders are ready.
Classes: {'abnormal': 0, 'normal': 1}
