In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from pathlib import Path
from csbdeep.utils import download_and_extract_zip_file, plot_some
from csbdeep.data import RawData, create_patches

In [None]:
# Dataset to process. Other datasets previously used with this notebook:
#   /scratch/beuttenm/lnet/care/beads/01highc
#   /scratch/beuttenm/lnet/care/beads/f8_01highc
#   /scratch/beuttenm/lnet/care/heart/static
data_path = Path("/scratch/beuttenm/lnet/care/heart/static3")

# Sub-folder of `train/` that holds the input stacks: either "pred" or "lr".
ipt = "lr"
# Fail fast here instead of with a confusing missing-file error when reading stacks.
assert ipt in ("pred", "lr"), f"ipt must be 'pred' or 'lr', got {ipt!r}"

# Sub-folder of `train/` that holds the target stacks; beads datasets use "ls_reg".
ls = "ls_reg" if "beads" in str(data_path) else "ls_trf"

We can plot the training stack pair via maximum-projection:

In [None]:
# Load one training pair: the input stack from train/<ipt>/ and the target from train/<ls>/.
idx = 0
y = imread(str(data_path / f"train/{ls}/{idx:05}.tif"))
x = imread(str(data_path / f"train/{ipt}/{idx:05}.tif"))
# np.stack below requires equal shapes; fail with a readable message instead.
assert x.shape == y.shape, f"input {x.shape} and target {y.shape} stack shapes differ"
print('image size =', x.shape)

plt.figure(figsize=(16,10))
# Titles use the actual folder names so they stay correct when `ipt` / `ls` change.
plot_some(np.stack([x,y]),
          title_list=[[f'{ipt} (maximum projection)', f'{ls} (maximum projection)']],
          pmin=2,pmax=99.8);

<hr style="height:2px;">

# Generate training data for CARE

We first need to create a `RawData` object, which defines how to get the pairs of source (input) and target stacks and the semantics of each axis (e.g. which one is considered a color channel, etc.).

Here the input stacks are in the folder `train/<ipt>` (e.g. `lr`) and the target stacks in `train/<ls>` (`ls_trf`, or `ls_reg` for the beads data); corresponding stacks are TIFF images with identical filenames.  
For this case, we can simply use `RawData.from_folder` and set `axes = 'ZYX'` to indicate the semantic order of the image axes. 

In [None]:
# Pair every input stack in train/<ipt>/ with the identically named target stack
# in train/<ls>/; each stack is interpreted with axis order Z, Y, X.
train_dir = data_path / "train"
raw_data = RawData.from_folder(
    basepath=str(train_dir),
    source_dirs=[ipt],
    target_dir=ls,
    axes='ZYX',
)

From corresponding stacks, we now generate some 3D patches. As a general rule, use a patch size that is a power of two along XYZT, or at least divisible by 8.  
Typically, you should use more patches the more training stacks you have. By default, patches are sampled from non-background regions (i.e. that are above a relative threshold), see the documentation of `create_patches` for details.

Note that the values `(X, Y, XY_axes)` returned by `create_patches` are not to be confused with the image axes X and Y.  
By convention, the variable name `X` (or `x`) refers to an input variable for a machine learning model, whereas `Y` (or `y`) indicates an output variable.

In [None]:
# Patch size in the raw data's axis order (Z, Y, X); every entry is divisible by 8.
patch_size = (48, 88, 88)

# Patches drawn per stack: enough tiles to cover the Y and X extent of the example
# stack `x` loaded above. The Z extent is not tiled.
# NOTE(review): relies on `x` from the plotting cell and assumes all stacks share its
# YX size. The check `x.shape[0] == patch_size[0]` (Z depth equals the patch depth)
# is disabled here. Confirm this is intended.
tiles_y, tiles_x = (int(np.ceil(x.shape[axis] / patch_size[axis])) for axis in (1, 2))
patch_tag = 'x'.join(map(str, patch_size))

X, Y, XY_axes = create_patches(
    raw_data=raw_data,
    patch_size=patch_size,
    n_patches_per_image=tiles_y * tiles_x,
    save_file=str(data_path / f"train_patches_{patch_tag}.npz"),
)

In [None]:
# Sanity-check the generated patch arrays before they are used for training.
assert X.shape == Y.shape, f"source {X.shape} and target {Y.shape} patch shapes differ"
# XY_axes names one semantic axis per array dimension.
assert len(XY_axes) == X.ndim, f"axes {XY_axes!r} do not match X.ndim={X.ndim}"
print("shape of X,Y =", X.shape)
print("axes  of X,Y =", XY_axes)

## Show

This shows the maximum projection of some of the generated patch pairs (odd rows: *source*, even rows: *target*).

In [None]:
# Show up to two figures of `n_per_row` patch pairs each, titled with the patch index.
# The index 0 on the second axis selects a single slice of that axis
# (presumably the channel axis of XY_axes; verify).
n_per_row = 8
for i in range(2):
    start = n_per_row * i
    # Clamp to the number of patches so the titles always match the plotted patches.
    stop = min(n_per_row * (i + 1), len(X))
    if start >= stop:
        break  # fewer patches than requested; nothing left to show
    plt.figure(figsize=(16,4))
    sl = slice(start, stop), 0
    plot_some(X[sl], Y[sl], title_list=[np.arange(start, stop)])
    plt.show()