In [3]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import pickle

In [None]:
# Flag whether torch can see a CUDA device; the bare name on the next line
# makes the notebook cell display the flag as its output.
use_gpu = bool(torch.cuda.is_available())
use_gpu

In [7]:
class DataLoader():
    """Load the ``train``/``test`` ImageFolder splits under a root directory, normalized per channel.

    For each split, per-channel mean/std statistics are read from a pickle
    cache file (``mean_std_value_<split>.pkl`` inside the root). When that file
    is missing, the statistics are computed with :meth:`getMeanStd`, which
    writes the cache first. Each split's ``ImageFolder`` in ``self.dataset``
    then uses a ``ToTensor`` + ``Normalize`` transform built from those numbers.

    Attributes:
        root: the directory passed as ``path``.
        dirs: split sub-directory names, ``["train", "test"]``.
        mean_std_path: split -> cache file name, relative to ``root``.
        means / stdevs: split -> per-channel statistics.
        transform: split -> composed ToTensor + Normalize transform.
        dataset: split -> ``ImageFolder`` using that split's transform.
    """

    def __init__(self, path):
        """Build normalized datasets for every split under ``path``.

        Args:
            path: root directory containing one sub-directory per split
                (``train`` and ``test``). A trailing separator is optional.
        """
        self.root = path
        self.dirs = ["train", "test"]
        self.mean_std_path = {x: "mean_std_value_" + x + ".pkl" for x in self.dirs}
        self.means = {x: [0, 0, 0] for x in self.dirs}
        self.stdevs = {x: [0, 0, 0] for x in self.dirs}

        # Un-normalized tensors, used only when the statistics must be computed.
        to_tensor = transforms.Compose([transforms.ToTensor()])
        for x in self.dirs:
            stats_file = os.path.join(self.root, self.mean_std_path[x])
            if not os.path.exists(stats_file):
                self.getMeanStd(x, ImageFolder(os.path.join(self.root, x), to_tensor))
            # NOTE: unpickling can execute arbitrary code; only point this at a
            # cache file written by getMeanStd, never at untrusted data.
            with open(stats_file, "rb") as f:
                self.means[x] = pickle.load(f)
                self.stdevs[x] = pickle.load(f)

        # NOTE(review): every split is normalized with its OWN statistics, so the
        # test set is scaled using test-set numbers. The usual practice is to reuse
        # the training statistics for all splits -- confirm this is intended.
        self.transform = {x: transforms.Compose([
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            tuple(float(v) for v in self.means[x]),
                                            tuple(float(v) for v in self.stdevs[x])),
                                        ]) for x in self.dirs}
        self.dataset = {x: ImageFolder(os.path.join(self.root, x), self.transform[x])
                        for x in self.dirs}

    def getMeanStd(self, type, dataset):
        """Compute, store, print and cache per-channel mean/std for one split.

        The statistics are pooled over every pixel of every image: the exact
        dataset mean and population std per channel. Averaging per-image stds
        instead would ignore the spread between image means and underestimate
        the dataset std. Results replace (rather than add to) any previous
        values for the split. The mean array and then the std array are
        pickled, in that order, to ``mean_std_path[type]`` under ``root``.

        Args:
            type: split name, the key into ``self.means``, ``self.stdevs`` and
                ``self.mean_std_path``. (The name shadows the builtin but is
                kept so existing keyword callers keep working.)
            dataset: iterable of samples whose first element is a
                channel-first image tensor, e.g. an ``ImageFolder`` with
                ``ToTensor``.

        Raises:
            ValueError: if ``dataset`` contains no pixels.
        """
        split = type
        channel_sum = None
        channel_sq_sum = None
        num_pixels = 0
        for sample in dataset:
            img = sample[0]
            # (C, ...) -> (C, pixels), in float64 so long sums stay accurate.
            flat = img.reshape(img.shape[0], -1).to(torch.float64)
            if channel_sum is None:
                channel_sum = torch.zeros(flat.shape[0], dtype=torch.float64)
                channel_sq_sum = torch.zeros(flat.shape[0], dtype=torch.float64)
            channel_sum += flat.sum(dim=1)
            channel_sq_sum += flat.pow(2).sum(dim=1)
            num_pixels += flat.shape[1]

        if num_pixels == 0:
            raise ValueError(
                "cannot compute statistics for split {!r}: dataset is empty".format(split))

        mean = channel_sum / num_pixels
        # Var = E[x^2] - E[x]^2; clamp tiny negative round-off before the sqrt.
        var = (channel_sq_sum / num_pixels - mean.pow(2)).clamp(min=0)
        self.means[split] = mean.numpy()
        self.stdevs[split] = var.sqrt().numpy()

        print("{} : normMean = {}".format(split, self.means[split]))
        print("{} : normstdevs = {}".format(split, self.stdevs[split]))

        with open(os.path.join(self.root, self.mean_std_path[split]), "wb") as f:
            pickle.dump(self.means[split], f)
            pickle.dump(self.stdevs[split], f)
            print("pickle done")
            