## Calculate mean & std
Calculates the mean and std from the dataset, so that the values can be used in a normalization transform

In [33]:
# https://saturncloud.io/blog/how-to-normalize-image-dataset-using-pytorch/#:~:text=This%20is%20done%20by%20scaling,train%20and%20improve%20its%20accuracy.
import torch
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as v2
from ipynb.fs.full.preprocessing import BacteriaDataset, train_data
import json

image_size = 2048
batch_size = 10

device = (
    'cuda'
    if torch.cuda.is_available()
    else 'mps'
    if torch.backends.mps.is_available()
    else 'cpu'
)

resize_transform = v2.Compose([
    v2.Resize([image_size,image_size]),
    v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
    ])
calc_train_dataset = BacteriaDataset(annotations=train_data['encoded_cats'],images=train_data['image_path'],transform=resize_transform,device=device) 
calc_train_dataloader = DataLoader(calc_train_dataset,batch_size=batch_size,shuffle=True)

def get_mean_std(loader:DataLoader, batch_size:int,device = 'cpu'):
    # Compute the mean and standard deviation of all pixels in the dataset
    tot_mean = torch.zeros(3,device=device)
    tot_std = torch.zeros(3,device=device)
    num_imgs = 0
    for image, _ in loader:
        batch_size, num_channels, height, width = image.shape
        num_imgs += batch_size
        current_mean = image.mean(axis=(0, 2, 3)) # calculated across channels
        current_std = image.std(axis=(0, 2, 3))
        tot_mean += current_mean
        tot_std += current_std

    tot_mean /= num_imgs
    tot_std /= num_imgs

    return tot_mean, tot_std

mean, std = get_mean_std(calc_train_dataloader,batch_size,device)
mean_list = mean.numpy(force=True).tolist()
std_list = std.numpy(force=True).tolist()

image_size_str = str(image_size)
current_mean_std_dict = {image_size_str:{
        'mean':mean_list,
        'std':std_list,
    },
}
#write stats to file, based on image size
stats_file_path = 'means_stds.txt'
with open(stats_file_path, 'r') as stats_file:
        mean_std_dict = json.load(stats_file)
mean_std_dict[image_size_str] = current_mean_std_dict[image_size_str]
with open(stats_file_path,'w') as stats_file:
    json.dump(mean_std_dict,stats_file)

print(f'Image size: {image_size}')
print(f'mean: {mean}\n std: {std}')
print('mean_std_done')

Image size: 2048
mean: tensor([0.0550, 0.0317, 0.0274], device='mps:0')
 std: tensor([0.0205, 0.0260, 0.0258], device='mps:0')
mean_std_done


## test other mean & std

In [37]:
psum = 0.0
psum_sq = 0.0
train_dataset = BacteriaDataset(annotations=train_data['encoded_cats'],images=train_data['image_path'],transform=resize_transform,device=device) 
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
for image in train_dataloader:
    psum += image.sum(axis = [0,2,3])
    psum_sq += (image ** 2).sum(axis = [0, 2, 3])

count = len()


AttributeError: 'list' object has no attribute 'sum'