# Upsampling 3D with single-channel stacks

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some, axes_dict, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import save_training_data, load_training_data, save_tiff_imagej_compatible
from csbdeep.data import RawData, create_patches
from csbdeep.data.transform import anisotropic_distortions
from csbdeep.models import Config, UpsamplingCARE

# Define variables

Specify data directories, image axes, aspect ratio and subsampling values:

In [None]:
# --- Paths -------------------------------------------------------------------
datapath = '0-training'                      # folder holding the training data
training_file = datapath + '/training.npz'   # output file for generated patches
models = '1-models'                          # base directory for trained models
modelname = 'proper'                         # name of the model to train/load

# --- Image geometry ------------------------------------------------------------
axes = 'ZYX'   # axis order of the input stacks

# Aspect ratio used when plotting the training stacks:
# Z res training / XY res training --> 1.0 / 0.2690547 = 3.716716341
aspect_mod = 3.716716341

# Subsampling factor used to degrade the training data along Z:
# Z res experiment / Z res training --> 3.0 / 1.0 = 3
subsample = 3

# TensorFlow claims all GPU memory by default; uncomment to limit it:
# limit_gpu_memory(fraction=1/2)

Load Hi/Lo training stacks:

In [None]:
# Paired training stack: the same filename exists in both the 'low' and 'high' folders
training_stack = 'slpGap_CARE_1.tif'

# High and low training stacks
x = imread('{}/low/{}'.format(datapath, training_stack))
y = imread('{}/high/{}'.format(datapath, training_stack))

# Show shapes, labelled so the two stacks can be told apart
print('low  image size =', x.shape)
print('high image size =', y.shape)

# Inspect input stacks

Plot training data for inspection:

In [None]:
# Show the low/high training stacks side by side along three orthogonal planes.
# Each entry: (plane label, axis brought to the front, index along it, extra plot_some kwargs)
_training_views = [
    ('XY', 0, 50, {}),
    ('XZ', 1, 800, {'aspect': aspect_mod}),
    ('YZ', 2, 600, {'aspect': aspect_mod}),
]
for _plane, _front_axis, _index, _extra in _training_views:
    plt.figure(figsize=(16, 15))
    _pair = [np.moveaxis(stack, _front_axis, 0)[_index] for stack in (x, y)]
    plot_some(np.stack(_pair),
              title_list=[['{} slice (low)'.format(_plane), '{} slice (high)'.format(_plane)]],
              pmin=2, pmax=99.8, **_extra)

# Generate training data

The training data should be in two folders, "low" and "high", inside `datapath`; corresponding low- and high-SNR stacks must be TIFF images with identical filenames.

In [None]:
# Pair every stack in <datapath>/low with the identically named stack in <datapath>/high.
raw_data = RawData.from_folder(basepath=datapath, source_dirs=['low'],
                               target_dir='high', axes=axes)

Specify how to modify the Z axis to mimic the Z resolution of the experimental stack.

In [None]:
# Degrade the training data along Z by `subsample` so the inputs mimic the
# Z resolution of the experimental stack (no PSF is supplied).
anisotropic_transform = anisotropic_distortions(
    subsample=subsample, psf=None, subsample_axis='Z', yield_target='target',
)

Generate 3D patches from the synthetically undersampled low quality input stack and its corresponding high quality stack.

Use a patch size that is a power of two along every axis (here Z, Y and X), or at least divisible by 8. Typically, the more training stacks you have, the more patches you should use.

In [None]:
# Cut 1024 patch pairs of 64x64x64 voxels per stack, apply the Z degradation,
# and save the result to `training_file`.
patch_shape = (64, 64, 64)
X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=patch_shape,
    n_patches_per_image=1024,
    transforms=[anisotropic_transform],
    save_file=training_file,
)

# Inspect 3D patches

Check dimensions are ok:

In [None]:
# Sanity check: every source patch needs a target patch of identical shape.
assert Y.shape == X.shape
for _label, _value in (('shape of X,Y =', X.shape), (' axes of X,Y =', XY_axes)):
    print(_label, _value)

Plot ZY slices of some of the generated patch pairs (odd rows: *source*, even rows: *target*):

In [None]:
# Two figures of 8 patch pairs each. Fixing index 0 on the fourth axis gives a
# ZY slice (assumes patch axes are S, Z, Y, X, C — TODO confirm via XY_axes).
for row in range(2):
    first, last = 8 * row, 8 * (row + 1)
    plt.figure(figsize=(12, 2))
    sl = slice(first, last), slice(None), slice(None), 0
    plot_some(X[sl], Y[sl],
              title_list=[np.arange(first, last)],
              aspect=aspect_mod)
    plt.show()
None;

# Load training data from disk

In [None]:
# Reload the saved patches, holding out 10% of them as validation data.
(X, Y), (X_val, Y_val), training_axes = load_training_data(
    training_file, validation_split=0.1, verbose=True,
)

# Channel counts are read from the 'C' axis of the loaded arrays.
c = axes_dict(training_axes)['C']
n_channel_in = X.shape[c]
n_channel_out = Y.shape[c]

# Inspect validation data

Plot some of the validation data (the 10% of the generated patches held out by `validation_split=0.1`):

In [None]:
# First eight validation pairs as ZY slices (fourth axis fixed at 0): source above, target below.
_first_eight = (slice(0, 8), slice(None), slice(None), 0)
plt.figure(figsize=(12, 2))
plot_some(X_val[_first_eight], Y_val[_first_eight], aspect=aspect_mod)
plt.suptitle('8 example validation patches (ZY slice, top row: source, bottom row: target)');

# Configure CARE model

This configuration already uses `train_steps_per_epoch=400` and `train_batch_size=16`, which are suitable for obtaining a well-trained model. For quick feedback (e.g. to check that everything runs), reduce these numbers considerably.

In [None]:
# Network configuration derived from the patch axes and channel counts.
config = Config(
    training_axes, n_channel_in, n_channel_out,
    train_steps_per_epoch=400,
    train_batch_size=16,
)
print(config)
vars(config)

In [None]:
# New model named `modelname`, stored under the `models` base directory.
model = UpsamplingCARE(config=config, name=modelname, basedir=models)

# Train CARE model

Training the model will take some time. Use TensorBoard to inspect the losses and predictions during training.

In [None]:
# Time budget: 8 h = 28800 s, i.e. 288 s per epoch if training runs 100 epochs
# (assumes config.train_epochs == 100 — check the printed config above).
history = model.train(X, Y, validation_data=(X_val, Y_val))

Plot the final training history:

In [None]:
# List the recorded metric names, then plot the losses and the MSE/MAE metrics.
print(sorted(history.history.keys()))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae']);

# Evaluate CARE model

Plot example validation patches:

In [None]:
_n_examples = 8
plt.figure(figsize=(12, 4))

# Predict the first validation patches directly with the underlying Keras model.
_P = model.keras_model.predict(X_val[:_n_examples])

# A probabilistic model's last axis holds two halves; only the first half is
# plotted (presumably the per-pixel mean — confirm against CSBDeep docs).
if config.probabilistic:
    _P = _P[..., :_P.shape[-1] // 2]

# ZY slices (fourth axis fixed at 0) of input, ground truth and prediction.
_zy = (slice(0, _n_examples), slice(None), slice(None), 0)
plot_some(X_val[_zy], Y_val[_zy], _P[_zy], pmax=99.5, aspect=aspect_mod)

plt.suptitle('8 example validation patches (ZY slice)\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source');

# Export CARE model to Fiji/KNIME

In [None]:
# Export the trained model for use from Fiji/KNIME.
model.export_TF()

<hr style="height:2px;">

# Restore experiment stack

Specify the aspect ratio for plotting and load the stack to be restored:

In [None]:
# Aspect ratio of experiment stack, needed for plotting. It is calculated as
# Z res experiment / XY res experiment --> 3.0 / 0.2853695 = 10.512686...
# Computed from the ratio (instead of a hand-typed literal) so that the value
# and the comment cannot drift apart.
aspect_exp = 3.0 / 0.2853695

# Define experiment stack
stack_exp = 'slpGap_t70s_z3_C1_t35.tif'

# Load experiment stack. It is read from '2-results'; the commented line is an
# alternative location inside the training data folder.
#x = imread('{}/test/{}'.format(datapath, stack_exp))
x = imread('2-results/{}'.format(stack_exp))

Check out image dimensions:

In [None]:
# Summarise what will be passed to the network.
for _label, _value in (('      image size =', x.shape),
                       ('      image axes =', axes),
                       ('subsample factor =', subsample)):
    print(_label, _value)

Plot stack to be restored:

In [None]:
# Show the experiment stack along three orthogonal planes.
# Each tuple: (title, axis brought to the front, index along it, imshow aspect)
# aspect=None makes imshow fall back to its default, as in an XY view without aspect.
for _title, _front_axis, _index, _aspect in (
        ('XY slice', 0, 18, None),
        ('XZ slice', 1, 800, aspect_exp),
        ('YZ slice', 2, 600, aspect_exp)):
    plt.figure(figsize=(12, 12))
    plt.imshow(np.moveaxis(x, _front_axis, 0)[_index], aspect=_aspect, cmap='magma')
    plt.title(_title)
    plt.axis('off')
None;

# Apply Upsampling 3D CARE model

Load trained model from disk:

In [None]:
# config=None: reload the previously trained model `modelname` from `models`.
model = UpsamplingCARE(name=modelname, basedir=models, config=None)

Check out stack dimensions:

In [None]:
print('input size = {}'.format(x.shape))  # shape of the stack to be restored

Apply model to experiment stack:

In [None]:
%%time

restored = model.predict(x, axes, subsample, n_tiles=(8, 8, 8))

print(' input size =', x.shape)
print('output size =', restored.shape)
print()

Save the restored image stack as an ImageJ-compatible TIFF image:

In [None]:
# Write the restored stack into '2-results', prefixed with the model name.
_results_dir = Path('2-results')
_results_dir.mkdir(exist_ok=True)
save_tiff_imagej_compatible('{}/{}_{}'.format(_results_dir, model.name, stack_exp), restored, axes)

# Inspect denoised/upsampled image

Check dimensions before plotting:

In [None]:
# Shapes before and after upsampling.
print(' input size = {}'.format(x.shape))
print('output size = {}'.format(restored.shape))

Inspect XY slice:

In [None]:
# Compare the same XY plane: source slice 18 and its counterpart in the
# upsampled output, whose Z index is scaled by the subsampling factor.
_z_source = 18
_z_network = int(subsample * _z_source)
plt.figure(figsize=(16, 15))
plot_some(np.stack([x[_z_source], restored[_z_network]]),
          title_list=[['XY slice (source)', 'XY slice (network)']],
          pmin=2, pmax=99.8);

Inspect XZ slices:

In [None]:
# Plot XZ slices of the source stack and of the network output at the same Y index.
plt.figure(figsize=(16,15))
plt.imshow(np.moveaxis(x, 1, 0)[800], aspect=aspect_exp, cmap='magma')
plt.title('XZ slice (source)')
plt.axis('off')

# The network output keeps the experiment's XY pixel size, but its Z spacing is
# divided by the upsampling factor, so its aspect ratio is aspect_exp / subsample.
# (aspect_mod describes the training data, whose XY pixel size is different.)
aspect_restored = aspect_exp / subsample
plt.figure(figsize=(16,15))
plt.imshow(np.moveaxis(restored, 1, 0)[800], aspect=aspect_restored, cmap='magma')
plt.title('XZ slice (network)')
plt.axis('off')
None;

Inspect YZ slices:

In [None]:
# Plot YZ slices of the source stack and of the network output at the same X index.
plt.figure(figsize=(16,15))
plt.imshow(np.moveaxis(x, 2, 0)[450], aspect=aspect_exp, cmap='magma')
plt.title('YZ slice (source)')
plt.axis('off')

# The network output keeps the experiment's XY pixel size, but its Z spacing is
# divided by the upsampling factor, so its aspect ratio is aspect_exp / subsample.
# (aspect_mod describes the training data, whose XY pixel size is different.)
aspect_restored = aspect_exp / subsample
plt.figure(figsize=(16,15))
plt.imshow(np.moveaxis(restored, 2, 0)[450], aspect=aspect_restored, cmap='magma')
plt.title('YZ slice (network)')
plt.axis('off')
None;