In [5]:
import os
from pathlib import Path
import torch
import wandb

from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import transforms

In [6]:
import sys
# Root directory for local modules/notebooks and for the Data/ folder used by
# the loaders; "./" is resolved relative to the current working directory.
BASE_PATH = "./"
# Make modules/notebooks that live in BASE_PATH importable.
sys.path.append(BASE_PATH)

# Imported for its side effect: it lets the following line import the
# `utils` notebook (utils.ipynb) as if it were a regular Python module.
import import_ipynb
from utils import get_num_cpu_cores, is_linux, is_windows

importing Jupyter notebook from utils.ipynb


In [57]:
def calculate_mean_std(data_loader):
    """Compute per-channel mean and standard deviation of a loader's dataset.

    Every sample of ``data_loader.dataset`` is read individually (the loader's
    batching, shuffling and workers are not used) and all samples are stacked
    into a single tensor, so the whole dataset must fit in memory.

    Args:
        data_loader: an object with a ``dataset`` attribute yielding
            ``(image, label)`` pairs, where ``image`` is a 3-D tensor with the
            channel dimension first, e.g. ``(C, H, W)`` as produced by
            ``transforms.ToTensor``.

    Returns:
        (mean, std): two 1-D tensors of length ``C`` (one value per channel),
        suitable for ``transforms.Normalize``. ``std`` is the unbiased
        (Bessel-corrected) estimate computed by ``Tensor.std``.
    """
    # Resulting shape: (C, H, W, N) - samples stacked along a new last dim.
    imgs = torch.stack([img_t for img_t, _ in data_loader.dataset], dim=3)
    print(imgs.shape)

    # Keep the channel dim and flatten everything else, so statistics are
    # per-channel. Flattening to (1, -1) would mix all channels together and
    # is only correct for single-channel images.
    per_channel = imgs.reshape(imgs.shape[0], -1)
    mean = per_channel.mean(dim=-1)
    std = per_channel.std(dim=-1)

    print("mean : {0}, std : {1}".format(mean, std))
    return mean, std

In [61]:
def get_fashion_mnist_data(validation_size=5_000, seed=None, num_workers=1, batch_size=None):
    """Build Fashion-MNIST train/validation loaders and a normalisation transform.

    The Fashion-MNIST training split is downloaded (if needed) to
    ``BASE_PATH/Data/fashion_mnist`` and randomly divided into a training and a
    validation subset. Mean/std are computed on the training subset only, so
    validation data does not influence preprocessing.

    Args:
        validation_size: number of samples held out for validation. The
            default (5_000) reproduces the original 55_000 / 5_000 split.
        seed: optional integer. When given, the train/validation split is
            reproducible across runs; ``None`` keeps the previous behaviour of
            using PyTorch's global default generator.
        num_workers: number of DataLoader worker processes (default 1, as
            before; e.g. pass ``get_num_cpu_cores()`` to use all cores).
        batch_size: loader batch size. ``None`` reads
            ``wandb.config.batch_size``, so ``wandb.init`` must have been
            called with that key.

    Returns:
        (train_data_loader, validation_data_loader, f_mnist_transforms) where
        ``f_mnist_transforms`` is ``nn.Sequential(ConvertImageDtype(float),
        Normalize(mean, std))``.

    Raises:
        ValueError: if ``validation_size`` is not in ``[0, len(dataset))``.
    """
    data_path = os.path.join(BASE_PATH, "Data", "fashion_mnist")

    f_mnist_full = datasets.FashionMNIST(data_path, train=True, download=True, transform=transforms.ToTensor())

    # Derive the train size from the dataset rather than hard-coding it, so
    # random_split's lengths always sum to len(dataset).
    num_samples = len(f_mnist_full)
    if not 0 <= validation_size < num_samples:
        raise ValueError(
            "validation_size must be in [0, {0}), got {1}".format(num_samples, validation_size)
        )
    split_generator = (
        torch.Generator().manual_seed(seed) if seed is not None else torch.default_generator
    )
    f_mnist_train, f_mnist_validation = random_split(
        f_mnist_full, [num_samples - validation_size, validation_size], generator=split_generator
    )

    print("Num Train Samples: ", len(f_mnist_train))
    print("Num Validation Samples: ", len(f_mnist_validation))
    print("Sample Shape: ", f_mnist_train[0][0].shape)  # torch.Size([1, 28, 28])

    print("Number of Data Loading Workers:", num_workers)

    if batch_size is None:
        batch_size = wandb.config.batch_size

    # Pinned host memory only helps host-to-GPU copies; without CUDA it is
    # useless and PyTorch warns about it.
    pin_memory = torch.cuda.is_available()

    train_data_loader = DataLoader(
        dataset=f_mnist_train, batch_size=batch_size, shuffle=True,
        pin_memory=pin_memory, num_workers=num_workers
    )

    validation_data_loader = DataLoader(
        dataset=f_mnist_validation, batch_size=batch_size,
        pin_memory=pin_memory, num_workers=num_workers
    )

    # Statistics from the training subset only.
    mean, std = calculate_mean_std(train_data_loader)

    f_mnist_transforms = nn.Sequential(
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=mean, std=std),
    )

    return train_data_loader, validation_data_loader, f_mnist_transforms

In [62]:
def get_fashion_mnist_test_data(mean=None, std=None):
    """Load the Fashion-MNIST test split, a single-batch loader and a normalisation transform.

    Args:
        mean: optional per-channel mean for ``Normalize``. Pass the statistics
            computed on the training data so the test set is normalised exactly
            like the training data. When omitted (the default, kept for
            backward compatibility), mean/std are computed from the test set
            itself, which lets test data influence preprocessing.
        std: optional per-channel std; must be given together with ``mean``.

    Returns:
        (f_mnist_test_images, test_data_loader, f_mnist_transforms):
        - f_mnist_test_images: the test split with no transform applied
          (presumably for displaying raw samples - confirm with callers);
        - test_data_loader: a DataLoader yielding the whole test set as one batch;
        - f_mnist_transforms: ``nn.Sequential(ConvertImageDtype(float),
          Normalize(mean, std))``.

    Raises:
        ValueError: if exactly one of ``mean`` / ``std`` is provided.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be provided together")

    data_path = os.path.join(BASE_PATH, "Data", "fashion_mnist")

    f_mnist_test_images = datasets.FashionMNIST(data_path, train=False, download=True)
    f_mnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transforms.ToTensor())

    print("Num Test Samples: ", len(f_mnist_test))
    print("Sample Shape: ", f_mnist_test[0][0].shape)  # torch.Size([1, 28, 28])

    # A single batch containing the entire test set.
    test_data_loader = DataLoader(dataset=f_mnist_test, batch_size=len(f_mnist_test))

    if mean is None:
        # Fallback: statistics of the test set itself (original behaviour).
        mean, std = calculate_mean_std(test_data_loader)

    f_mnist_transforms = nn.Sequential(
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=mean, std=std),
    )

    return f_mnist_test_images, test_data_loader, f_mnist_transforms

In [63]:
if __name__ == "__main__":
    # Smoke-test the loaders when this notebook is run directly. wandb runs in
    # "disabled" mode so nothing is logged, but wandb.config still provides
    # batch_size to the loader functions.
    config = {'batch_size': 2048,}
    wandb.init(mode="disabled", config=config)

    train_data_loader, validation_data_loader, f_mnist_train_transforms = get_fashion_mnist_data()
    print()
    # Use a distinct name: reusing `f_mnist_transforms` would silently discard
    # the training-set normalisation built above.
    f_mnist_test_images, test_data_loader, f_mnist_test_transforms = get_fashion_mnist_test_data()

Num Train Samples:  55000
Num Validation Samples:  5000
Sample Shape:  torch.Size([1, 28, 28])
Number of Data Loading Workers: 1
torch.Size([1, 28, 28, 55000])
mean : tensor([0.2858]), std : tensor([0.3529])

Num Test Samples:  10000
Sample Shape:  torch.Size([1, 28, 28])
torch.Size([1, 28, 28, 10000])
mean : tensor([0.2868]), std : tensor([0.3524])
