# Generate a Histogram-Based Noise Model
We will use pairs of noisy observations $x_i$ and clean signals $s_i$ to estimate the conditional distribution $p(x_i|s_i)$. The clean signals are created either by averaging many noisy images or, in bootstrap mode, by denoising the images with Noise2Void.
Note that this noise model is independent of the image content: it is a property of the camera and the imaging conditions.

In [1]:
import torch
dtype = torch.float
device = torch.device("cuda:0") 
import matplotlib.pyplot as plt
import numpy as np
import pickle
import sys
sys.path.append('../../')
import pn2v.utils
import pn2v.histNoiseModel
from pn2v import prediction
from tifffile import imread

### Download data
Download the data from https://owncloud.mpi-cbg.de/index.php/s/224xGSeHquMbQYu. The link contains three different datasets (Convallaria, mouse skull nuclei and mouse actin). Here we show the pipeline for the Convallaria dataset. Load the appropriate dataset from the path where you stored it; in our case, this is the data folder.

### The data

The noise model is a characteristic of your camera.
The downloaded data folder contains a set of calibration images. For the Convallaria dataset, the calibration images are stored in 20190726_tl_50um_500msec_wf_130EM_FD.tif, and the data to be denoised is stored in 20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif.
We can either create the histogram from noisy/ground-truth pairs obtained from the calibration images, or we can bootstrap a suitable histogram noise model by first denoising the noisy images with Noise2Void and then using these denoised images as pseudo ground truth.

Below we define some parameters that select the mode (calibration or bootstrap) and set the noise model name.

In [None]:
path="../../data/Convallaria_diaphragm/"
dataName = 'convallaria' # Name of the noise model 
mode='calibration' # Either `bootstrap` or `calibration`


In [None]:
if mode =='bootstrap': # Bootstrapping mode 
    nameN2VModel = dataName+'_n2v'
    observation= imread(path+'20190520_tl_25um_50msec_05pc_488_130EM_Conv.tif') #Load the appropriate data
    net=torch.load(path+"last_"+nameN2VModel+".net") # `path` already ends with "/"
else: 
    nameN2VModel= None
    observation= imread(path+'20190726_tl_50um_500msec_wf_130EM_FD.tif') # Load the appropriate data

nameNoiseModel = path+'noiseModelHistogram_'+ dataName+'_'+mode

In [None]:
if mode=='bootstrap':
    # This cell only runs in bootstrap mode: it denoises the data with N2V to generate pseudo ground truth.
    resultImgs=[]
    inputImgs=[]
    dataTest = observation

    for index in range(dataTest.shape[0]):

        im=dataTest[index]
        # We are using tiling to fit the image into memory
        # If you get an error try a smaller patch size (ps)
        means = prediction.tiledPredict(im, net, ps=256, overlap=48,
                                                device=device, noiseModel=None)
        resultImgs.append(means)
        inputImgs.append(im)
        print ("image:", index)
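
`prediction.tiledPredict` handles the tiling internally, so nothing below is needed to run this notebook. As an illustration only, here is a minimal sketch of how overlapping tile positions along one image axis could be computed for a patch size `ps` and an `overlap`; the helper `tile_starts` is hypothetical and is not pn2v's implementation.

```python
def tile_starts(length, ps, overlap):
    """Start offsets of patches of size `ps` that cover [0, length),
    with neighbouring patches sharing at least `overlap` pixels (hypothetical helper)."""
    if not 0 <= overlap < ps:
        raise ValueError("overlap must satisfy 0 <= overlap < ps")
    if length <= ps:
        return [0]  # one tile spans the whole axis
    step = ps - overlap
    starts = list(range(0, length - ps, step))
    starts.append(length - ps)  # the last tile ends exactly at the image border
    return starts

# 2D tiles are the product of the row and column offsets.
rows = tile_starts(512, ps=256, overlap=48)
cols = tile_starts(700, ps=256, overlap=48)
tiles = [(r, c) for r in rows for c in cols]
print(rows)        # [0, 208, 256]
print(len(tiles))  # 12
```

A smaller `ps` means more tiles (and more overlap overhead) but less memory per forward pass, which is why the comment in the cell above suggests reducing it on out-of-memory errors.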

Let us first take a look at the distribution of signals $s_i$ that are present in this data.
While most pixels are background, the signal should comfortably cover the range of intensities that occurs in the images we want to denoise: their pixel values should lie within this range.
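As a numeric complement to the plot, one can measure which fraction of signal pixels falls inside the intensity range chosen for the histogram. This is a minimal sketch; `fraction_in_range` is a hypothetical helper, the synthetic `fake_signal` merely stands in for the notebook's `signal`, and 234 and 7402 are the `minVal` and `maxVal` used later in this notebook.

```python
import numpy as np

def fraction_in_range(signal, min_val, max_val):
    """Fraction of signal pixels inside [min_val, max_val] (hypothetical helper)."""
    signal = np.asarray(signal)
    inside = (signal >= min_val) & (signal <= max_val)
    return inside.mean()

# Synthetic stand-in for `signal`; illustrative values only, not the Convallaria data.
rng = np.random.default_rng(0)
fake_signal = rng.uniform(200, 8000, size=(1, 64, 64))
print(fraction_in_range(fake_signal, 234, 7402))
```

With the real data you would call `fraction_in_range(signal, minVal, maxVal)` once the signal has been computed; a value well below 1 suggests widening the range.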

In [None]:
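# `signal` is otherwise only defined in the noise-model section below, so compute it here as well.
signal = np.array(resultImgs) if mode=='bootstrap' else np.mean(observation,axis=0)[np.newaxis,...]
checkSignalHist = np.histogram(signal, bins=256)
plt.plot(checkSignalHist[1][:-1], np.clip(checkSignalHist[0],0,20000))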

### Creating the noise model
Using the raw pixels $x_i$ and our (pseudo) ground truth $s_i$, we now create a 2D histogram. Rows correspond to different signals $s_i$ and columns to different observations $x_i$. The histogram is normalized so that every row sums to one; each row therefore describes the distribution $p(x_i|s_i)$ for one signal value $s_i$. This distribution is our noise model.
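The idea of a row-normalized 2D histogram can be sketched in plain NumPy. This is an illustration under stated assumptions, not pn2v's `createHistogram` (whose return value has an extra leading dimension; the notebook reads the normalized counts from `histogram[0]`). The helper name `conditional_histogram` is hypothetical, and, as an assumption matching the calibration mode, observation `i` is paired with `signal[i % len(signal)]` so that one averaged image can serve a whole stack.

```python
import numpy as np

def conditional_histogram(observation, signal, bins, min_val, max_val):
    """Row-normalised 2D histogram approximating p(x|s) (hypothetical helper).

    Rows index signal bins, columns index observation bins.
    """
    counts = np.zeros((bins, bins))
    for i in range(observation.shape[0]):
        s = signal[i % signal.shape[0]].ravel()
        x = observation[i].ravel()
        # The first argument (signal) indexes the rows of the returned histogram.
        h, _, _ = np.histogram2d(s, x, bins=bins,
                                 range=[[min_val, max_val], [min_val, max_val]])
        counts += h
    row_sums = counts.sum(axis=1, keepdims=True)
    # Leave empty signal rows at zero instead of dividing by zero.
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)

# Synthetic Poisson data: one clean image, 20 noisy frames of it.
rng = np.random.default_rng(0)
clean = rng.uniform(10, 90, size=(1, 32, 32))
noisy = rng.poisson(np.repeat(clean, 20, axis=0)).astype(float)
hist = conditional_histogram(noisy, clean, bins=16, min_val=0, max_val=128)
print(hist.sum(axis=1))  # non-empty rows sum to 1, empty rows stay 0
```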

In [None]:
# The calibration data contains 100 images of a static sample.
# In calibration mode, we estimate the clean signal by averaging all images.
# In bootstrap mode, we use the N2V-denoised images as pseudo ground truth.
if mode=='bootstrap':
    signal = np.array(resultImgs)   
else:
    signal=np.mean(observation,axis=0)[np.newaxis,...]

# Let's look at the raw data and our (pseudo) ground truth signal
print(signal.shape)
plt.figure(figsize=(12, 12))
plt.subplot(1, 2, 2)
plt.title(label='average (ground truth)')
plt.imshow(signal[0],cmap='gray')
plt.subplot(1, 2, 1)
plt.title(label='single raw image')
plt.imshow(observation[0],cmap='gray')
plt.show()
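Why averaging works as ground truth: for independent, zero-mean noise, averaging $N$ frames reduces the noise standard deviation by a factor of $\sqrt{N}$, so 100 frames leave roughly a tenth of the original noise. The sketch below checks this with synthetic Gaussian noise; real camera noise is not purely Gaussian, so treat this only as an illustration of the scaling.

```python
import numpy as np

rng = np.random.default_rng(1)
clean = np.full((64, 64), 1000.0)  # constant synthetic signal
n_frames = 100                     # matches the size of the calibration stack
sigma = 50.0
stack = clean + rng.normal(0.0, sigma, size=(n_frames, 64, 64))

# Same averaging as in the calibration branch above.
estimate = np.mean(stack, axis=0)[np.newaxis, ...]
residual_std = np.std(estimate[0] - clean)
print(residual_std, sigma / np.sqrt(n_frames))  # both values should be close to 5.0
```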

In [None]:
# We set the range of values we want to cover with our model.
# The pixel intensities in the images you want to denoise have to lie within this range.
# For this dataset, we use the range from 234 to 7402.
minVal, maxVal = 234, 7402
bins = 256

# Create the histogram (this can take a minute).
histogram = pn2v.histNoiseModel.createHistogram(bins,minVal,maxVal,observation,signal)

# Save the histogram to disk. `nameNoiseModel` already includes `path`.
np.save(nameNoiseModel+'.npy', histogram)
# histogram[0] is the normalized 2D histogram used in the plots below.
histogramFD=histogram[0]
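To use the histogram as a noise model, $p(x_i|s_i)$ is read off at row bin($s_i$) and column bin($x_i$). Below is a minimal nearest-bin sketch; `value_to_bin` and `lookup_likelihood` are hypothetical helpers, and pn2v's own noise model class may perform this lookup differently. The bin mapping is consistent with the bin-center formula `(index+0.5)/float(bins)*(maxVal-minVal)+minVal` used in the plotting cell.

```python
import numpy as np

def value_to_bin(v, bins, min_val, max_val):
    """Map intensities to histogram bin indices, clipping to the valid range."""
    idx = np.floor((np.asarray(v, dtype=float) - min_val) / (max_val - min_val) * bins)
    return np.clip(idx, 0, bins - 1).astype(int)

def lookup_likelihood(hist2d, x, s, min_val, max_val):
    """Approximate p(x|s) from row bin(s), column bin(x) of a row-normalised histogram."""
    bins = hist2d.shape[0]
    return hist2d[value_to_bin(s, bins, min_val, max_val),
                  value_to_bin(x, bins, min_val, max_val)]

# Toy example: a hypothetical 4x4 histogram whose rows are uniform.
toy = np.full((4, 4), 0.25)
print(lookup_likelihood(toy, x=np.array([0.0, 99.0]), s=np.array([50.0, 50.0]),
                        min_val=0.0, max_val=100.0))
# [0.25 0.25]
```

With the histogram computed above, the call would be `lookup_likelihood(histogramFD, x, s, minVal, maxVal)`; values outside `[minVal, maxVal]` are clipped to the edge bins, which is one more reason the range must cover the data.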

## Visualize the noise model for a specific signal bin

Below we visualize our histogram-based noise model for a given signal.

In [None]:
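# Let's look at the noise model.
# Raising the probabilities to the power 0.25 compresses their dynamic range so that low-probability regions stay visible.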
plt.xlabel('observation bin')
plt.ylabel('signal bin')
plt.imshow(histogramFD**0.25, cmap='gray')
plt.show()

In [None]:
# Observation values at the bin centers (the same convention as the signal values below).
xvals=(np.arange(bins)+0.5)/float(bins)*(maxVal-minVal)+minVal
plt.xlabel('observation')
plt.ylabel('probability (per bin)')

# We will now look at the noise distributions for different signals s_i
# by plotting individual rows of the histogram.
for index in [10, 50, 100, 200, 225, 250]:
    s=((index+0.5)/float(bins)*(maxVal-minVal)+minVal) # signal value at the bin center
    plt.plot(xvals,histogramFD[index,:], label='bin='+str(index)+' signal='+str(np.round(s,2)))

plt.legend()
plt.show()