In [None]:
import os
import shutil
import pandas as pd

import torch
import PIL.Image as Image
import torchvision.transforms as transforms
import time
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
from importlib import reload
import numpy as np

In [None]:
import sys
# Extend the module search path before importing so that `import dataset`
# below can resolve. Presumably dataset.py lives in ../../code/scripts
# (relative to the notebook's working directory) -- TODO confirm.
sys.path.append('../../code/scripts')
import dataset

In [None]:
# Re-execute the already-imported `dataset` module so that edits made to it on
# disk are picked up without restarting the kernel.
reload(dataset)

# make sure you comment out the normalization in dataset.py first
def get_normalization_values(dataset_name):
    """Compute per-channel mean and std over the images in the training set.

    The statistics are accumulated over every pixel of every training image
    (running sum and sum of squares per channel), so the result is weighted
    correctly even when the last batch is smaller than the others.

    Args:
        dataset_name: one of 'isic_no_sonic', 'isic_with_sonic' or 'cifar4'.

    Returns:
        A tuple ``(means, sds, sds_batch_avgd)`` of numpy arrays holding one
        value per channel:
          * means -- mean over all pixels of all training images,
          * sds -- (population) standard deviation over all pixels of all
            training images,
          * sds_batch_avgd -- unweighted average of the per-batch standard
            deviations.

    Raises:
        ValueError: if ``dataset_name`` is not recognised, or if the training
            loader yields no images.
    """
    if dataset_name == 'isic_no_sonic':
        dataloaders = dataset.get_data_loaders('../../data',
                                     'isic/df_no_sonic_age_over_50_id.csv',
                                     'isic/ImagesSmaller',
                                     'benign_malignant_01',
                                     'auc_roc',
                                     all_group_colnames=['age_over_50_id'])

    elif dataset_name == 'isic_with_sonic':
        dataloaders = dataset.get_data_loaders('../../data',
                                     'isic/df_with_sonic_age_over_50_id.csv',
                                     'isic/ImagesSmallerWithSonic',
                                     'benign_malignant_01',
                                     'auc_roc',
                                     all_group_colnames=['age_over_50_id'])

    elif dataset_name == 'cifar4':
        dataloaders = dataset.get_data_loaders('../../data',
                                         'cifar4/df_cifar4_labels.csv',
                                         'cifar4/images',
                                         'air',
                                         'animal',
                                         all_group_colnames=['animal'])

    else:
        # Fail loudly instead of hitting an UnboundLocalError further down.
        raise ValueError('unknown dataset_name: {!r}'.format(dataset_name))

    # train_loader_eval does not shuffle the data; train_loader is the one
    # iterated here. The accumulated means/sds do not depend on batch order,
    # only sds_batch_avgd depends on how images are grouped into batches.
    train_loader, _train_loader_eval, _, _, _ = dataloaders

    channel_sum = None        # running per-channel sum of pixel values
    channel_sq_sum = None     # running per-channel sum of squared pixel values
    pixels_per_channel = 0    # number of pixels accumulated per channel
    sds_by_batch = []
    for sample in train_loader:
        images = sample['image']
        # Reducing over dims (0, 2, 3) keeps dim 1, which is treated as the
        # channel axis of a 4-D batch.
        sds_by_batch.append(np.array(images.std(dim=(0, 2, 3))))

        # Accumulate in float64 so that summing many pixels stays accurate.
        images64 = images.double()
        batch_sum = np.array(images64.sum(dim=(0, 2, 3)))
        batch_sq_sum = np.array((images64 ** 2).sum(dim=(0, 2, 3)))
        if channel_sum is None:
            channel_sum = batch_sum
            channel_sq_sum = batch_sq_sum
        else:
            channel_sum = channel_sum + batch_sum
            channel_sq_sum = channel_sq_sum + batch_sq_sum
        pixels_per_channel += images.shape[0] * images.shape[2] * images.shape[3]

    if pixels_per_channel == 0:
        raise ValueError(
            'training loader for {!r} yielded no images'.format(dataset_name))

    means = channel_sum / pixels_per_channel
    # Var[x] = E[x^2] - E[x]^2 over the whole training set. Averaging the
    # per-image variances instead would ignore the variation between images.
    # Clamp at zero to guard against tiny negative values from rounding.
    variances = np.maximum(channel_sq_sum / pixels_per_channel - means ** 2, 0.0)
    sds = np.sqrt(variances)

    # sds_batch_avgd is closer to what is described here:
    # https://pytorch.org/docs/stable/torchvision/models.html
    sds_batch_avgd = np.stack(sds_by_batch).mean(axis=0)

    return means, sds, sds_batch_avgd

In [None]:
# The last dataset may not be prepared on your machine yet; drop it from the
# list if so.
dataset_names = ['cifar4', 'isic_no_sonic', 'isic_with_sonic']

for dataset_name in dataset_names:
    dataset_normalization_values = get_normalization_values(dataset_name)
    # The function returns (means, precise sds, batch-averaged sds), in that order.
    channel_means, precise_sds, average_sds = dataset_normalization_values
    print(f'{dataset_name}:')
    print('means: ', channel_means)
    print('average sds: ', average_sds)
    print('precise sds: ', precise_sds)
    print()