In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import pandas as pd
from math import sqrt

In [None]:
class MNISTMean(Dataset):
    """Map-style dataset that yields the pixel-value *sum* of each image.

    Despite the name, ``__getitem__`` returns a per-image sum rather than a
    mean. The sums can then be added up across the whole dataset and divided
    by the total pixel count once.

    Args:
        df: DataFrame with an ``"Image"`` column holding paths that
            ``PIL.Image.open`` can read.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        """Return the number of rows (images) in the backing DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return the sum of all pixel values of the image at row ``idx``.

        The file is opened in a ``with`` block so its handle is released
        deterministically, even if decoding raises. With many DataLoader
        workers touching thousands of files, leaking handles can exhaust the
        process's file-descriptor limit.
        """
        path = self.df.iloc[idx]["Image"]
        with Image.open(path) as img:
            # np.array copies the decoded pixels, so the array stays valid
            # after the file is closed.
            pixels = np.array(img)
        # Return a plain Python int rather than a NumPy scalar.
        return int(pixels.sum())

In [None]:
# Load the image index and prefix every file name in the "Image" column with
# the data directory so PIL can open the paths directly.
df = pd.read_csv("../data/mnist.csv").assign(
    Image=lambda frame: frame["Image"].apply(lambda name: f"../data/{name}")
)
df.head()

In [None]:
# Spot-check a single item: the pixel sum of the first image.
mnist_mean = MNISTMean(df=df)
mnist_mean[0]

In [None]:
# Loader batch size and the expected image dimensions, in pixels.
# NOTE(review): the mean/stddev cells divide by width * height, so every image
# is assumed to be exactly 28 x 28 — confirm against the data.
batch_size = 512
height = width = 28

In [None]:
# Batch the per-image sums. shuffle=False keeps items in dataset order, and
# 12 worker processes open and decode the image files in parallel.
mean_loader = DataLoader(
    dataset=mnist_mean,
    batch_size=batch_size,
    shuffle=False,
    num_workers=12,
)

In [None]:
# Flatten the loader's batches into one float64 vector of per-image pixel
# sums, in the order the loader yields them.
mean_values = np.fromiter(
    (float(value) for batch in mean_loader for value in batch),
    dtype=np.float64,
)

In [None]:
# Dataset-wide mean pixel value: the total of all per-image sums divided by the
# total number of pixels (height x width x number of images).
mean = mean_values.sum() / (height * width * len(mean_values))
mean

In [None]:
class MNISTStandardDeviation(Dataset):
    """Map-style dataset yielding each image's sum of squared deviations.

    ``__getitem__`` returns ``sum((pixel - mean) ** 2)`` over one image.
    Summing these values across the dataset and dividing by the pixel count
    (minus one) gives the sample variance.

    Args:
        df: DataFrame with an ``"Image"`` column holding paths that
            ``PIL.Image.open`` can read.
        mean: Dataset-wide pixel mean that deviations are measured from.
    """

    def __init__(self, df, mean):
        self.df = df
        self.mean = mean

    def __len__(self):
        """Return the number of rows (images) in the backing DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return the sum of squared deviations from ``self.mean`` for image ``idx``.

        The file is opened in a ``with`` block so its handle is released
        deterministically, even if decoding raises. The result is a Python
        float.
        """
        path = self.df.iloc[idx]["Image"]
        with Image.open(path) as img:
            # Promote to float64 before subtracting. With uint8 pixels and a
            # Python-int mean, NumPy keeps uint8 arithmetic, so both the
            # subtraction and the squaring would silently wrap modulo 256.
            pixels = np.array(img, dtype=np.float64)
        return float(((pixels - self.mean) ** 2).sum())

In [None]:
# Spot-check a single item: squared deviations of the first image from the mean.
mnist_stddev = MNISTStandardDeviation(df=df, mean=mean)
mnist_stddev[0]

In [None]:
# Batch the per-image squared deviations. shuffle=False keeps items in dataset
# order, and 12 worker processes open and decode the image files in parallel.
stddev_loader = DataLoader(
    dataset=mnist_stddev,
    batch_size=batch_size,
    shuffle=False,
    num_workers=12,
)

In [None]:
# Flatten the loader's batches into one float64 vector of per-image squared
# deviations, in the order the loader yields them.
stddev_values = np.fromiter(
    (float(value) for batch in stddev_loader for value in batch),
    dtype=np.float64,
)

In [None]:
# Sample standard deviation over every pixel in the dataset; the "- 1" in the
# denominator is Bessel's correction.
stddev = sqrt(stddev_values.sum() / (height * width * len(stddev_values) - 1))
stddev