# Noise2Void - 3D Example

In [None]:
import os
import urllib
import urllib.request  # `import urllib` alone does not load the `request` submodule used for the download
import zipfile

import numpy as np
from matplotlib import pyplot as plt

from csbdeep.models import Config, CARE
from csbdeep.utils import plot_some, plot_history
from csbdeep.utils.n2v_utils import manipulate_val_data

In [None]:
# Download the example data (only once) and unpack it into ./data/N2V_exampleData3D.
import urllib.request  # `import urllib` alone does not guarantee the `request` submodule is loaded

data_dir = './data'
zip_path = os.path.join(data_dir, 'N2V_exampleData3D.zip')
extract_dir = os.path.join(data_dir, 'N2V_exampleData3D')

os.makedirs(data_dir, exist_ok=True)

if not os.path.exists(zip_path):
    # Download to a temporary name first so that an interrupted download does not
    # leave a truncated archive behind that later runs would treat as complete.
    tmp_path = zip_path + '.part'
    urllib.request.urlretrieve('https://cloud.mpi-cbg.de/index.php/s/JVxU9uiwM5f0Raz/download', tmp_path)
    os.replace(tmp_path, zip_path)

# Extract whenever the unpacked folder is missing, not only right after a fresh
# download, so a previously downloaded archive still gets unpacked.
if not os.path.isdir(extract_dir):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

## Configure

To use Noise2Void with the CARE framework we have to switch the <code>train_scheme</code> from <code>'CARE'</code> to <code>'Noise2Void'</code>. This turns on the pixel masking that Noise2Void training needs. Furthermore, it is recommended to increase the <code>train_batch_size</code> and to enable <code>batch_norm</code>. 

To keep the network from learning the identity we have to manipulate the input pixels during training. For this we have the parameter <code>n2v_manipulator</code> with default value <code>'uniform_withCP'</code>. Most pixel manipulators will compute the replacement value based on a neighborhood. With <code>n2v_neighborhood_radius</code> we can control its size. 

Other pixel manipulators:
* normal_withoutCP: samples from the neighborhood according to a normal (Gaussian) distribution, but without the center pixel.
* normal_additive: adds a random number to the original pixel value. The random number is sampled from a Gaussian distribution with zero mean and sigma = <code>n2v_neighborhood_radius</code>.
* normal_fitted: uses a random value from a Gaussian distribution whose mean and standard deviation equal the mean and standard deviation of the neighborhood.
* identity: performs no pixel manipulation

For faster training, multiple pixels per input patch can be manipulated. In our experiments we manipulated about 1.6% of the input pixels per patch. For a patch size of 64 by 64 pixels we manipulated <code>n2v_num_pix</code> = 64 pixels simultaneously. For the 3D patches of 32 × 64 × 64 pixels used below, the same ratio gives <code>n2v_num_pix</code> = 2048. 

For Noise2Void training it is possible to pass arbitrarily large patches to the training method. From these patches, random subpatches of size <code>n2v_patch_shape</code> are extracted during training. The default patch shape is (64, 64); for this 3D example we set it to (32, 64, 64).  

In [None]:
# Noise2Void configuration for 3D data (axes 'SZYXC'; the training data gets a
# trailing channel axis when loaded). train_scheme='Noise2Void' enables the pixel
# masking; n2v_num_pix=2048 manipulates ~1.6% of the 32*64*64 pixels of each
# training patch, the same ratio as 64 of 64*64 pixels in 2D.
config = Config('SZYXC', n_channel_in=1, n_channel_out=1, unet_kern_size=3, train_steps_per_epoch=50, train_loss='mse',
                batch_norm=True, train_scheme='Noise2Void', train_batch_size=4, n2v_num_pix=2048,
                n2v_patch_shape=(32, 64, 64), n2v_manipulator='uniform_withCP',
                # The neighborhood radius is a numeric size (it also serves as a sigma
                # for some manipulators), so pass an int instead of the string '5'.
                n2v_neighborhood_radius=5,
                train_reduce_lr={'factor': 0.5, 'patience': 20, 'min_delta': 0},
                train_epochs=100)

In [None]:
# Display the attributes of the configuration object as a dict, to double-check the settings above.
vars(config)

In [None]:
# Build the model from the configuration above. It is named 'n2v_3D' with base
# directory 'models' (presumably checkpoints such as the weights loaded later end up there — verify).
model = CARE(config=config, name='n2v_3D', basedir='models')

## Training Data Preparation

For training we load __one__ set of low-SNR images and normalize them to zero mean and unit standard deviation. This data is used as input and stored in the variable <code>X</code>. Our target <code>Y</code> is <code>X</code> concatenated, along the channel axis, with a zero-tensor of the same shape. This zero-tensor is used for the masking of the pixels during training. 

In [None]:
# We need to normalize the data before we feed it into our network, and denormalize it afterwards.
def normalize(img, mean, std):
    """Standardize ``img``: subtract ``mean``, then divide by ``std``.

    Works elementwise, so scalars and (broadcast-compatible) numpy arrays are
    both accepted. Use ``denormalize`` with the same statistics to undo it.
    """
    return (img - mean) / std

def denormalize(x, mean, std):
    """Undo ``normalize``: scale ``x`` by ``std`` and shift it back by ``mean``."""
    rescaled = x * std
    return rescaled + mean

In [None]:
# Load the training data
# Load the training stacks, append a trailing singleton channel axis, and
# standardize them. The statistics (mean, std) are reused for validation and test data.
train_path = 'data/N2V_exampleData3D/Fly_train.npy'
X = np.expand_dims(np.load(train_path), axis=-1)
mean = X.mean()
std = X.std()
X = normalize(X, mean, std)

In [None]:
# Inspect the shape of the normalized training data; the trailing axis of size 1
# is the channel axis added when loading.
X.shape

In [None]:
# We concatenate an extra channel filled with zeros. It will be internally used for the masking.
# Build the target Y: X plus an extra all-zero channel that Noise2Void uses
# internally for the pixel masking during training.
# np.zeros_like keeps X's dtype (plain np.zeros is always float64 and would
# upcast all of Y, e.g. doubling memory for float32 data), and axis=-1 targets
# the trailing channel axis regardless of how many spatial axes there are.
Y = np.concatenate((X, np.zeros_like(X)), axis=-1)
print(X.shape, Y.shape)

### Validation Data Preparation

There are two possibilities to build the validation set:

1. Training-data like: the validation loss is only computed on a fixed number of manipulated pixels. We randomly select a fixed number of pixels before training and manipulate them the same way as during training. 
2. Test-Data like: Meaning that the validation loss is computed on all __not__ manipulated pixels of the validation set. This setup is more like the setup during testing.

In our paper we chose option (1) to have the same loss during validation as during training. However, option (2) results in a more stable validation loss, since it is computed over __all__ pixels instead of a subset. For 3D data, option (1) is not implemented yet, so option (2) is used below.

In [None]:
# load the validation data
X_val = np.load('data/N2V_exampleData3D/Fly_val.npy')[..., np.newaxis]
# Normalize with the *training* statistics so inputs share the same scale.
X_val = normalize(X_val, mean, std)

# 1. Option (is not implemented yet for 3D data)

# 2. Option: the extra target channel is all ones, i.e. the validation loss is
# computed over every pixel (see "Test-Data like" above).
# np.ones_like keeps X_val's dtype (np.ones would upcast Y_val to float64), and
# axis=-1 is the trailing channel axis. np.concatenate already returns a new
# array, so an explicit X_val.copy() is unnecessary.
Y_val = np.concatenate((X_val, np.ones_like(X_val)), axis=-1)
print(X_val.shape, Y_val.shape)

## Training

In [None]:
# Train the model. Y carries the extra (zero) masking channel; the extra channel
# of Y_val is all ones, so the validation loss covers all pixels.
history = model.train(X,Y, validation_data=(X_val,Y_val))

In [None]:
# List the recorded metrics, then plot training vs. validation loss.
recorded_metrics = sorted(history.history)
print(recorded_metrics)
plt.figure(figsize=(16,5))
plot_history(history, ['loss', 'val_loss']);

## Evaluation

We do not have ground truth data to calculate a PSNR with this data.
Instead, we will simply look at the denoised images.

In [None]:
# Optionally load weights of a model that was trained before.
# By default the last saved weights are used; uncomment one of the alternatives
# below to load a different checkpoint instead (presumably 'weights_best.h5' is
# the best-validation checkpoint — verify against the training setup).

model.load_weights( name='weights_last.h5')
#model.load_weights( name='weights_now.h5')
#model.load_weights( name='weights_best.h5')

In [None]:
# Keep the raw test stacks for display, and normalize a copy with the training
# statistics as network input.
test_path = 'data/N2V_exampleData3D/Fly_test.npy'
test_lowSNR_raw = np.load(test_path)
test_lowSNR = normalize(test_lowSNR_raw, mean, std)
print(test_lowSNR_raw.shape)

In [None]:
# Denoise the first test stack. Normalization is handled manually (normalizer=None),
# so the prediction is mapped back to raw intensities with the training statistics.
first_stack = test_lowSNR[0]
prediction_normalized = model.predict(first_stack, axes='ZYX', normalizer=None)
predictions = denormalize(prediction_normalized, mean, std)
print(predictions.shape)

In [None]:
# Let's have a look at the results: max-projections along Z (axis 0) of the raw
# and the denoised stack. Both use the same display range, taken from the 1st and
# 99.9th percentiles of the denoised result, so intensities are comparable.
vmi = np.percentile(predictions, 1)
vma = np.percentile(predictions, 99.9)

panels = (
    ('max-projection of raw data', test_lowSNR_raw[0]),
    ('max-projection of denoised data', predictions),
)
for title, volume in panels:
    plt.figure(figsize=(9,15))
    plt.title(title)
    plt.imshow(np.max(volume, 0), vmin=vmi, vmax=vma, cmap="magma")
    plt.show()