## Calculate mean & std
Calculates the mean and std from the dataset, so that the values can be used in a normalization transform

In [15]:
# https://saturncloud.io/blog/how-to-normalize-image-dataset-using-pytorch/#:~:text=This%20is%20done%20by%20scaling,train%20and%20improve%20its%20accuracy.
import torch
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as v2
from ipynb.fs.full.preprocessing import BacteriaDataset, train_data

device = (
    'cuda'
    if torch.cuda.is_available()
    else 'mps'
    if torch.backends.mps.is_available()
    else 'cpu'
)

batch_size = 10
image_size = 2048

resize_transform = v2.Compose([
    v2.Resize([image_size,image_size]),
    v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
    ])
calc_train_dataset = BacteriaDataset(annotations=train_data['encoded_cats'],images=train_data['image_path'],transform=resize_transform,device=device) 
calc_train_dataloader = DataLoader(calc_train_dataset,batch_size=batch_size,shuffle=True)

def get_mean_std(loader:DataLoader, batch_size:int):
    # Compute the mean and standard deviation of all pixels in the dataset
    tot_mean = torch.zeros(3)
    tot_std = torch.zeros(3)
    num_imgs = 0
    for image, _ in loader:
        batch_size, num_channels, height, width = image.shape
        num_imgs += batch_size
        current_mean = image.mean(axis=(0, 2, 3)) # calculated across channels
        current_std = image.std(axis=(0, 2, 3))
        tot_mean += current_mean
        tot_std += current_std

    tot_mean /= num_imgs
    tot_std /= num_imgs

    return tot_mean, tot_std

mean, std = get_mean_std(calc_train_dataloader,batch_size)
print(f'mean: {mean}\n std: {std}')
print('mean_std_done')

mean: tensor([0.0550, 0.0316, 0.0273])
 std: tensor([0.0205, 0.0260, 0.0257])
mean_std_done


## test other mean & std

In [None]:
psum = 0.0
psum_sq = 0.0
train_dataset = BacteriaDataset(annotations=train_data['encoded_cats'],images=train_data['image_path'],transform=resize_transform,device=device) 
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
for image in train_dataloader:
    psum += image.sum(axis = [0,2,3])
    psum_sq += (image ** 2).sum(axis = [0, 2, 3])

count = len()


KeyboardInterrupt: 