<hr style="height:2px;">

# Demo: Apply trained CARE model for denoising of SEM data

The trained model is assumed to be located in the folder `models` with the name `my_SEM_model`.

More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE

<hr style="height:2px;">

# Raw low-SNR image

Load one slice of the low-SNR test stack, add a channel axis so that its image axes are `YXC` (needed later for CARE prediction), and plot it.

In [None]:
# Load the data; we look at the 1.0 usec scan (slice 3) and add a channel axis.
Xtest_1us = imread('data/SEM/test/test.tif').astype(np.float32)[3, :, :, np.newaxis]

# Show the input, zoomed in a bit.
plt.figure(figsize=(7, 7))
plt.imshow(Xtest_1us[1500:, :, 0], cmap="magma")
plt.show()
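
The indexing expression in the cell above does two things at once: `[3, :, :]` selects one 2D slice from the stack, and `np.newaxis` appends a channel axis of length 1, so the result has axes `YXC`. A minimal sketch on a synthetic array (the stack shape is made up for illustration and is not the shape of `test.tif`):

```python
import numpy as np

# Hypothetical stand-in for the test stack: 7 slices of 4 x 5 pixels.
stack = np.arange(7 * 4 * 5, dtype=np.float32).reshape(7, 4, 5)

# Select slice 3 and append a trailing channel axis in one step.
single = stack[3, :, :, np.newaxis]

print(stack.shape)   # (7, 4, 5)
print(single.shape)  # (4, 5, 1)
```

The trailing axis is what allows the image to be passed to the model with the axes string `'YXC'`.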

<hr style="height:2px;">

# CARE model

Load the trained model (located in base directory `models` with name `my_SEM_model`) from disk.  
The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`.

In [None]:
model = CARE(config=None, name='my_SEM_model', basedir='models')
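
As a sanity check before loading, one can verify that a saved configuration is actually on disk. The sketch below assumes that the configuration is stored as `config.json` inside `<basedir>/<name>`; that file name is an assumption here, and `find_model_config` is a hypothetical helper, not part of CSBDeep.

```python
from pathlib import Path

def find_model_config(name, basedir="models"):
    """Return the path <basedir>/<name>/config.json, or None if it is missing.

    The file name config.json is an assumption about how the model was saved.
    """
    config_file = Path(basedir) / name / "config.json"
    return config_file if config_file.is_file() else None

config_file = find_model_config("my_SEM_model")
if config_file is None:
    print("No saved configuration found; train the model first.")
else:
    print("Found configuration:", config_file)
```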

## Apply CARE network to raw image

Predict the restored image (image will be successively split into smaller tiles if there are memory issues).

In [None]:
%%time
restored_1us = model.predict(Xtest_1us, 'YXC')
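
To illustrate the idea behind tiled prediction, here is a minimal NumPy sketch that splits an image into a grid of tiles, applies a function to each tile, and stitches the outputs back together. It is not CSBDeep's implementation: `apply_tiled` and `fake_restore` are hypothetical names, and the sketch ignores the overlap between neighbouring tiles that a convolutional network needs to avoid border artifacts.

```python
import numpy as np

def apply_tiled(img, func, n_tiles=(2, 2)):
    """Apply func to each tile of a 2D image and stitch the outputs together."""
    out_rows = []
    # Split along Y first, then split each band along X.
    for band in np.array_split(img, n_tiles[0], axis=0):
        tiles = np.array_split(band, n_tiles[1], axis=1)
        out_rows.append(np.concatenate([func(t) for t in tiles], axis=1))
    return np.concatenate(out_rows, axis=0)

def fake_restore(tile):
    # Hypothetical stand-in for a network: a pointwise operation,
    # so tiling without overlap reproduces the untiled result exactly.
    return tile * 0.5

img = np.arange(30, dtype=np.float32).reshape(5, 6)
tiled = apply_tiled(img, fake_restore, n_tiles=(2, 3))
print(np.allclose(tiled, fake_restore(img)))  # True
```

Only one tile has to be held in memory by the processing function at a time, which is why splitting helps when the full image does not fit.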

In [None]:
# Show the input, zoomed in a bit.
plt.figure(figsize=(7,7))
plt.imshow(Xtest_1us[1500:,:,0],cmap="magma")
plt.show()

# Show the result, zoomed in a bit.
plt.figure(figsize=(7,7))
plt.imshow(restored_1us[1500:,:,0],cmap="magma")
plt.show()

<hr style="height:2px;">

# Other scan times

Let's have a look at what happens if we apply this model to the other scan times.

In [None]:
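# Load the full test stack (one slice per scan time) and add a channel axis.
Xtest = imread('data/SEM/test/test.tif').astype(np.float32)[...,np.newaxis]
# Restore each slice separately with the 2D model.
restored = np.stack([model.predict(x, 'YXC') for x in Xtest])

# Pick an image/scan time.
img_slice = 6
# Scan-time labels are not defined in this notebook, so the titles use the slice index.
# Show the input, zoomed in a bit.
plt.figure(figsize=(7,7))
plt.imshow(Xtest[img_slice,1000:1600,0:600,0],cmap="magma")
plt.title("Input for slice {}".format(img_slice))
plt.show()

# Show the result, zoomed in a bit.
plt.figure(figsize=(7,7))
plt.imshow(restored[img_slice,1000:1600,0:600,0],cmap="magma")
plt.title("Output for slice {}".format(img_slice))
plt.show()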