# Data normalization

Data normalization means transforming the pixel values so that each image channel has a chosen mean and standard deviation (std).
Usually, the mean is set to 0 and the standard deviation to 1.
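The arithmetic behind this is a per-channel standardization: subtract the channel's mean and divide by its standard deviation. The minimal sketch below applies it to a single synthetic image (the random `img` array is an assumption for illustration). In this notebook the statistics come from the whole dataset rather than from one image, but the operation is the same.

```python
import numpy as np

# Hypothetical example image: height x width x 3 channels, 8-bit value range
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64, 3)).astype(np.float32)

# Per-channel mean and standard deviation, computed over height and width
mean = img.mean(axis=(0, 1))  # shape (3,)
std = img.std(axis=(0, 1))    # shape (3,)

# Broadcasting subtracts/divides each channel by its own statistics
normalized = (img - mean) / std

print(normalized.mean(axis=(0, 1)))  # close to 0 for every channel
print(normalized.std(axis=(0, 1)))   # close to 1 for every channel
```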

1. This keeps the input values in the range where the activation functions are most responsive.
2. It makes the value distributions of the 3 color channels similar, so that a single learning rate is appropriate for all channels.

Normalization should therefore make training easier, typically faster and more stable.

The strategy here is as follows:

1. Calculate the mean and standard deviation of each image channel in our dataset.
2. Apply this normalization in our transform pipeline when we get images from the dataset.
3. Apply the same normalization to new images we process (assuming they come from the same distribution as the images in the dataset).
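The key point of this strategy is that the statistics are computed once, on the dataset, and then reused unchanged for every image, including new ones. The sketch below illustrates this with NumPy on a synthetic dataset; the arrays and the `normalize` helper are hypothetical and do not come from unetTracker.

```python
import numpy as np

rng = np.random.default_rng(1)

# Step 1: a hypothetical "dataset" of 10 RGB images (N x H x W x C), 8-bit range
dataset_images = rng.integers(0, 256, size=(10, 32, 32, 3)).astype(np.float64)
dataset_mean = dataset_images.mean(axis=(0, 1, 2))  # one value per channel
dataset_std = dataset_images.std(axis=(0, 1, 2))    # one value per channel


def normalize(img, mean, std):
    """Apply fixed per-channel statistics to an image whose last axis is C."""
    return (img - mean) / std


# Step 2: the same function is applied to every image the dataset returns
train_img = normalize(dataset_images[0], dataset_mean, dataset_std)

# Step 3: a new image is normalized with the *dataset* statistics,
# not with its own mean and std
new_img = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float64)
new_img_normalized = normalize(new_img, dataset_mean, dataset_std)
```

Recomputing the mean and std on each new image would hide real differences in brightness or color between frames, which is why the dataset statistics are stored and reused.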

In this notebook, we will calculate the mean and std of each channel and save the information in the project configuration file for later use.

In [None]:
!pip install albumentations==1.3.0
!git clone https://github.com/kevin-allen/unetTracker
!pip install -r unetTracker/requirements.txt
!pip install -e unetTracker

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

fn = "/content/drive/My Drive/dsfolder"
if os.path.exists(fn):
  print("We can access the dsfolder directory.")
else:
  raise IOError("Problem accessing the dsfolder directory.")

In [None]:
# this will run the code in the setup_project.py and create a variable called `project`
fn = "/content/drive/My Drive/dsfolder/setup_project.py"
if os.path.exists(fn):
  print("We can access the file.")
else:
  raise IOError("Problem accessing the file.")

%run "/content/drive/My Drive/dsfolder/setup_project.py"

In [None]:
from torch import optim
import torch
import torch.nn as nn
import albumentations as A
import matplotlib.pyplot as plt  # used for the plots at the end of the notebook

from unetTracker.dataset import UNetDataset
from unetTracker.unet import Unet

In [None]:
dataset = UNetDataset(image_dir=project.image_dir, mask_dir=project.mask_dir, coordinate_dir=project.coordinate_dir,
                      image_extension=project.image_extension)

In [None]:
len(dataset)

## Calculate means and standard deviation of each color channel

The `UNetDataset` class has a method that does just this, `get_normalization_values()`. It loads all images, calculates the mean and standard deviation of each color channel, and returns them.
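We have not reproduced the internals of `get_normalization_values()` here, so the following is only a sketch of one common way to compute pooled per-channel statistics, assuming H x W x C arrays in the 0-255 range. It accumulates per-channel sums and sums of squares one image at a time and uses `Var[x] = E[x^2] - E[x]^2`, so the whole dataset never has to be in memory. The function name `channel_stats` is hypothetical. Other approaches, such as averaging per-image statistics, can give slightly different standard deviations.

```python
import numpy as np


def channel_stats(images):
    """Per-channel mean and std pooled over all pixels of all images.

    `images` is any iterable of H x W x C arrays (sizes may differ), so
    images can be loaded one at a time instead of all at once.
    """
    total = None
    total_sq = None
    count = 0
    for img in images:
        arr = np.asarray(img, dtype=np.float64)
        pixels = arr.reshape(-1, arr.shape[-1])  # (H*W, C)
        if total is None:
            total = np.zeros(pixels.shape[1])
            total_sq = np.zeros(pixels.shape[1])
        total += pixels.sum(axis=0)
        total_sq += (pixels ** 2).sum(axis=0)
        count += pixels.shape[0]
    if count == 0:
        raise ValueError("no images given")
    mean = total / count
    # Var[x] = E[x^2] - E[x]^2; clip tiny negative values caused by rounding
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


# Usage with hypothetical images of different sizes
rng = np.random.default_rng(2)
images = [rng.integers(0, 256, size=(h, w, 3)) for h, w in [(20, 30), (25, 25), (10, 40)]]
m, s = channel_stats(images)
```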

In [None]:
means,stds = dataset.get_normalization_values()

In [None]:
print("means:",means)
print("stds:", stds)

## Saving normalization values in the project configuration

We can save the normalization values in the project object and in the configuration file.

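We divide the values by the maximum pixel value (255 for 8-bit images) because the Albumentations `Normalize` transform expects the mean and std on a 0-1 scale and multiplies them by `max_pixel_value` (255.0 by default).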

The formula used by the [Albumentations normalization](https://albumentations.ai/docs/api_reference/augmentations/transforms/) transform is: `img = (img - mean * max_pixel_value) / (std * max_pixel_value)`
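A quick numerical check shows why the division by 255 is needed: with the scaled values, the Albumentations formula reduces to plain standardization with the raw 0-255 statistics. The sketch re-implements the formula in NumPy rather than calling Albumentations, and the statistics and pixel values are made up for illustration.

```python
import numpy as np

max_pixel_value = 255.0

# Hypothetical per-channel statistics on the 0-255 scale
means_raw = np.array([120.0, 110.0, 100.0])
stds_raw = np.array([60.0, 55.0, 50.0])

# What we store in the project configuration: values scaled to 0-1
means_scaled = means_raw / max_pixel_value
stds_scaled = stds_raw / max_pixel_value

img = np.array([[[255.0, 0.0, 100.0]]])  # a single pixel, H=1, W=1, C=3

# Albumentations formula: (img - mean * max_pixel_value) / (std * max_pixel_value)
albu = (img - means_scaled * max_pixel_value) / (stds_scaled * max_pixel_value)

# Direct standardization with the raw statistics
direct = (img - means_raw) / stds_raw

print(np.allclose(albu, direct))  # True
```

If the unscaled 0-255 values were passed to `A.Normalize` instead, they would be multiplied by 255 a second time and all pixel values would be squeezed into a very narrow band.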

In [None]:
project.set_normalization_values(means/255.0,stds/255.0)
project.normalization_values

In [None]:
project.save_configuration()

In [None]:
project.load_configuration()

## Apply normalization when training the network

You can apply normalization by passing an Albumentations transform as the `transform` argument of the dataset.

Here is an example.

In [None]:
means = project.normalization_values["means"]
stds = project.normalization_values["stds"]

transform = A.Compose([A.Normalize(mean=means, std=stds)])
transform

Here is one dataset without normalization and one with normalization.

In [None]:
datasetNoNorm = UNetDataset(image_dir=project.image_dir,
                            mask_dir=project.mask_dir,
                            coordinate_dir=project.coordinate_dir,
                            image_extension=project.image_extension)
datasetWithNorm = UNetDataset(image_dir=project.image_dir,
                              mask_dir=project.mask_dir,
                              coordinate_dir=project.coordinate_dir,
                              transform=transform, # we pass our transform function to the UNetDataset object
                              image_extension=project.image_extension)

In [None]:
imgNoNorm,_,_ = datasetNoNorm[0]
imgWithNorm,_,_ = datasetWithNorm[0]

In [None]:
fig,ax = plt.subplots(1,2,figsize=(6,3))
# permute from C x H x W (PyTorch) to H x W x C (matplotlib); divide by 255 to get floats in [0, 1]
ax[0].imshow(imgNoNorm.permute(1,2,0)/255)
ax[0].set_title("No normalization")
ax[1].imshow(imgWithNorm.permute(1,2,0))
ax[1].set_title("With normalization")
plt.show()
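Matplotlib's `imshow` expects float RGB data in [0, 1] and clips values outside this range (printing a warning), so the "With normalization" panel shows a clipped version of the normalized image. A minimal sketch of a display-only fix is to rescale the image linearly to [0, 1] before plotting. The helper name `to_displayable` is hypothetical, and it changes only how the image is shown, not the values the network receives.

```python
import numpy as np


def to_displayable(img_hwc):
    """Rescale an H x W x C float image linearly to [0, 1] for plt.imshow.

    A single min and max over all channels is used, so the relative
    differences between channels are preserved.
    """
    arr = np.asarray(img_hwc, dtype=np.float64)
    lo, hi = arr.min(), arr.max()
    if hi == lo:  # constant image: avoid division by zero
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)
```

In the cell above, this could be used as `ax[1].imshow(to_displayable(imgWithNorm.permute(1,2,0).numpy()))`.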

In [None]:
fig,ax = plt.subplots(1,2,figsize=(6,3),layout="constrained")
ax[0].hist(imgNoNorm.flatten(),bins=50)
ax[0].set_xlabel("Pixel values")
ax[0].set_title("No normalization")
ax[1].hist(imgWithNorm.flatten(),bins=50)
ax[1].set_xlabel("Pixel values")
ax[1].set_title("With normalization")
plt.show()
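Besides the histograms, a numerical summary per channel can confirm that the transform was applied. The sketch below assumes a C x H x W array (for the notebook's tensors, `imgWithNorm.numpy()`); the function name `per_channel_summary` is hypothetical. Because the statistics were computed over the whole dataset, a single normalized image will only have a mean roughly near 0 and a std roughly near 1, not exactly.

```python
import numpy as np


def per_channel_summary(img_chw):
    """Return the per-channel (mean, std) of a C x H x W array."""
    arr = np.asarray(img_chw, dtype=np.float64)
    return arr.mean(axis=(1, 2)), arr.std(axis=(1, 2))


# Demo on a hypothetical 3 x 4 x 4 array
demo = np.stack([
    np.full((4, 4), 2.0),            # constant channel: mean 2, std 0
    np.arange(16.0).reshape(4, 4),   # values 0 to 15
    -np.arange(16.0).reshape(4, 4),  # values 0 to -15
])
means_c, stds_c = per_channel_summary(demo)
```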