# Create a Custom PyTorch Dataset and Load It into PyTorch DataLoaders

In [1]:
# Import necessary libraries
import os
import glob
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from  torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import warnings
warnings.simplefilter('ignore')

In [2]:
# Function to create custom PyTorch dataset from given files.
def alzheimer_dataset(root='alzheimer_dataset', image_size=(224, 224)):
    """Build the train, validation and test ``ImageFolder`` datasets.

    Only the training split receives random augmentation (horizontal flip
    and a rotation of up to 10 degrees). Validation and test share one
    deterministic pipeline so that evaluation results are repeatable and
    comparable with each other.

    Args:
        root: Directory that contains the ``train``, ``validation`` and
            ``test`` sub-folders, each laid out one folder per class.
        image_size: ``(height, width)`` every image is resized to.

    Returns:
        Tuple ``(train_dataset, validation_dataset, test_dataset)``.
    """
    # Per-channel normalization statistics; these are the standard ImageNet
    # values expected by torchvision's pretrained models.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Training pipeline: resize, light random augmentation, normalize.
    train_transformer = transforms.Compose([
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Evaluation pipeline (validation and test): no randomness, so the same
    # image always produces the same tensor.
    eval_transformer = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Create the datasets; class labels follow ImageFolder's folder ordering.
    train_dataset = datasets.ImageFolder(
        root=os.path.join(root, 'train'), transform=train_transformer)
    validation_dataset = datasets.ImageFolder(
        root=os.path.join(root, 'validation'), transform=eval_transformer)
    test_dataset = datasets.ImageFolder(
        root=os.path.join(root, 'test'), transform=eval_transformer)

    return train_dataset, validation_dataset, test_dataset

In [None]:
# Function to visualize samples from created PyTorch dataset.
def visualize_dataset(dataset):
    """Show a 5x5 grid of randomly drawn samples titled with their class.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs,
            where ``label`` is an integer in the range 0-3.
    """
    # Class index -> class name displayed above each tile.
    class_names = dict(enumerate((
        "MildDemented",
        "ModerateDemented",
        "NonDemented",
        "VeryMildDemented",
    )))

    n_rows = n_cols = 5
    canvas = plt.figure(figsize=(15, 15))

    for position in range(1, n_rows * n_cols + 1):
        # One random index per tile; the same sample may appear twice.
        pick = torch.randint(len(dataset), size=(1,)).item()
        image, target = dataset[pick]

        # Three-channel images are collapsed to one channel by averaging.
        is_three_channel = image.shape[0] == 3
        image = image.mean(dim=0) if is_three_channel else image

        canvas.add_subplot(n_rows, n_cols, position)
        plt.title(class_names[target])
        plt.axis("off")
        plt.imshow(image.squeeze(), cmap="gray")

    plt.show()

In [None]:
# Function to wrap the PyTorch datasets in PyTorch dataloaders.
def alzheimer_dataloader(train_dataset=None, validation_dataset=None,
                         test_dataset=None, batch_size=64, num_workers=2):
    """Wrap the train, validation and test datasets in ``DataLoader`` objects.

    The datasets are explicit parameters instead of being read from
    module-level globals, so the function no longer fails with ``NameError``
    when those globals happen not to exist. Any dataset left as ``None`` is
    built with ``alzheimer_dataset()``, which keeps the zero-argument call
    working.

    Args:
        train_dataset: Dataset for the training loader, or ``None``.
        validation_dataset: Dataset for the validation loader, or ``None``.
        test_dataset: Dataset for the test loader, or ``None``.
        batch_size: Number of samples per batch for all three loaders.
        num_workers: Number of worker processes used by each loader.

    Returns:
        Tuple ``(train_loader, validation_loader, test_loader)``.
    """
    # Only touch the disk when at least one dataset was not supplied.
    if train_dataset is None or validation_dataset is None or test_dataset is None:
        default_train, default_validation, default_test = alzheimer_dataset()
        if train_dataset is None:
            train_dataset = default_train
        if validation_dataset is None:
            validation_dataset = default_validation
        if test_dataset is None:
            test_dataset = default_test

    # Create the data loaders. Train and validation are shuffled (as before);
    # the test loader keeps dataset order so predictions line up with files.
    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    return train_loader, validation_loader, test_loader

In [5]:
# Function to visualize batches from the PyTorch dataloaders.
def visualize_batch(dataloader):
    """Draw the first batch obtained from ``dataloader`` as one image grid.

    Args:
        dataloader: Iterable whose items unpack into ``(images, labels)``;
            presumably a ``DataLoader`` yielding batched image tensors -
            confirm against callers.
    """
    # A fresh iterator is created on each call and only its first item is used.
    batch = next(iter(dataloader))
    # NOTE(review): ``labels`` is unpacked but not used in the lines shown here.
    images, labels = batch

    # Tile the images 16 per row; normalize=True asks make_grid to rescale the
    # pixel values for display (see the torchvision make_grid documentation).
    grid = torchvision.utils.make_grid(images, nrow = 16, normalize=True)
    plt.figure(figsize = (36, 36))
    # Reorder axes (0, 1, 2) -> (1, 2, 0); presumably channels-first to
    # channels-last, the layout imshow expects - TODO confirm.
    plt.imshow(np.transpose(grid, (1, 2, 0)))