In [None]:
# Pin this notebook to one GPU. The mask is set before TensorFlow is imported
# so that it is already in place when CUDA gets initialised.
import os
os.environ.update(CUDA_VISIBLE_DEVICES="3")

import tensorflow as tf

# Fail early if this TensorFlow build was compiled without CUDA support
# (only the build is checked here, not whether a GPU is actually usable).
assert tf.test.is_built_with_cuda()

from tensorflow.python.client import device_lib

# Display every device TensorFlow can see; with the mask above at most one GPU
# is expected to show up.
device_lib.list_local_devices()

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from shutil import copy
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import CARE

In [None]:
subpath = "beads/f8_01highc"
model_name = "v0_spe1000_on_56x80x80"

# Per-experiment settings:
#   models -- names of the trained CARE models that may be evaluated on this dataset
#   crop   -- slices for the Z, Y, X axes, applied to every predicted and GT stack
#   shape  -- shape a stack must have after cropping
EXPERIMENTS = {
    "beads/f4_01highc": {
        "models": ["ep100", "ep400"],
        "crop": (slice(3, -2), slice(2, -2), slice(None)),
        "shape": (51, 180, 280),
    },
    "beads/f8_01highc": {
        "models": ["v0_spe400_on_56x64x64", "v0_spe400_on_56x80x80", "v0_spe1000_on_56x80x80"],
        "crop": (slice(3, -2), slice(None), slice(None)),
        "shape": (51, 376, 576),
    },
    "heart/static": {
        "models": ["v0_spe1000_on_48x88x88"],
        "crop": (slice(None), slice(None), slice(None)),
        "shape": (48, 376, 576),
    },
}


def make_postprocess(subpath, model_name):
    """Return the crop-and-check function for one experiment.

    Parameters
    ----------
    subpath : str
        Experiment key, e.g. ``"beads/f8_01highc"``; must be a key of ``EXPERIMENTS``.
    model_name : str
        Name of the CARE model; must be listed for ``subpath`` in ``EXPERIMENTS``.

    Returns
    -------
    callable
        ``postprocess(stack)``: takes a 3-D array, crops it with the experiment's
        slices and asserts the resulting shape before returning the cropped view.

    Raises
    ------
    NotImplementedError
        If ``subpath`` is not a known experiment.
    AssertionError
        If ``model_name`` is not one of the experiment's models.
    """
    if subpath not in EXPERIMENTS:
        # Previously `raise NotImplementedError(experiment)`, which referenced an
        # undefined name and surfaced as a NameError instead.
        raise NotImplementedError(subpath)
    settings = EXPERIMENTS[subpath]
    assert model_name in settings["models"], (subpath, model_name)
    crop = settings["crop"]
    expected_shape = settings["shape"]

    def postprocess(restored):
        """Crop a 3-D stack to the experiment's evaluation region and check its shape."""
        assert len(restored.shape) == 3, restored.shape
        restored = restored[crop]
        assert restored.shape == expected_shape, (restored.shape, expected_shape)
        return restored

    return postprocess


# The same crop is applied to CARE predictions and to the ground truth.
postprocess = make_postprocess(subpath, model_name)
y_postprocess = postprocess
    
# Ground-truth subfolder name depends on the dataset family.
if "beads" in subpath:
    ls = "ls_reg"
else:
    ls = "ls_trf"

# Locations of the trained models, the test data and the prediction outputs.
model_basedir = f"/g/kreshuk/LF_computed/lnet/care/models/{subpath}"
assert (Path(model_basedir) / model_name).exists(), Path(model_basedir) / model_name
data_path = Path("/scratch/beuttenm/lnet/care/") / subpath
results = Path("/g/kreshuk/LF_computed/lnet/care/results").joinpath(subpath, model_name)
results.mkdir(exist_ok=True, parents=True)

<hr style="height:2px;">

# Raw low-SNR image and associated high-SNR ground truth

Plot the test stack pair and define its image axes, which will be needed later for CARE prediction.

In [None]:
# Pick one test stack and load its ground truth (y) and low-quality input (x).
idx = 0
axes = "ZYX"  # axis order of the stacks; required by model.predict below
y = imread(str(data_path / "test" / ls / f"{idx:05}.tif"))
x = imread(str(data_path / "test" / "lr" / f"{idx:05}.tif"))
print('image size =', x.shape)

# Maximum projections of the input and the ground truth.
plt.figure(figsize=(16, 10))
plot_some(
    np.stack([x, y]),
    title_list=[['low (maximum projection)', 'GT (maximum projection)']],
    pmin=2,
    pmax=99.8,
);

<hr style="height:2px;">

# CARE model

Load the trained model from disk: it is stored in the base directory `model_basedir` under the name `model_name` (both set in the configuration cell at the top).  
The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`.

In [None]:
# config=None makes CARE load the configuration that was saved during training
# from model_basedir/model_name.
model = CARE(config=None, basedir=model_basedir, name=model_name)

## Apply CARE network to raw image

Predict the restored image (image will be successively split into smaller tiles if there are memory issues).

In [None]:
%%time
restored = model.predict(x, axes)

Alternatively, one can directly set `n_tiles` to avoid the time overhead from multiple retries in case of memory issues.

**Note**: *Out of memory* problems during `model.predict` can also indicate that the GPU is used by another process. In particular, shut down the training notebook before running the prediction (you may need to restart this notebook).

In [None]:
# %%time
# restored = model.predict(x, axes, n_tiles=(1,4,4))

## Save restored image

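Save the restored image stack as an ImageJ-compatible TIFF image, i.e. the image can be opened in ImageJ/Fiji with correct axes semantics. The call below is commented out because the batch loop at the end of the notebook writes every (cropped) test prediction to `results`, including this one.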

In [None]:
# save_tiff_imagej_compatible(str(results / f"{idx:05}.tif"), restored, axes)

<hr style="height:2px;">

# Raw low/high-SNR image and denoised image via CARE network

Plot the test stack pair and the predicted restored stack (middle).

In [None]:
# Maximum projections: input, CARE prediction (middle) and ground truth.
plt.figure(figsize=(16, 10))
plot_some(
    np.stack([x, restored, y]),
    title_list=[[
        'low (maximum projection)',
        'CARE (maximum projection)',
        'GT (maximum projection)',
    ]],
    pmin=2,
    pmax=99.8,
);

In [None]:
# Batch export of the whole test set:
#   * each ground-truth stack is cropped with `y_postprocess` and written under
#     `gt_dir` (skipped when the file already exists),
#   * each low-quality stack is restored with CARE, cropped with `postprocess`
#     and written to `results`.
# The loop uses its own variable names so that `x`, `y` and `restored` from the
# single-stack cells above are left untouched. Reusing them used to break a
# re-run of the comparison plot, because the cropped prediction no longer has
# the same shape as `x`.
gt_dir = Path("/g/kreshuk/LF_computed/lnet/care/gt") / subpath / "test" / ls

# Sorted so the processing order (and tqdm progress) is deterministic.
for lr_path in tqdm(sorted(data_path.glob("test/lr/*.tif"))):
    gt_out_path = gt_dir / lr_path.name
    if not gt_out_path.exists():
        gt_out_path.parent.mkdir(parents=True, exist_ok=True)
        gt_stack = imread(str(data_path / "test" / ls / lr_path.name))
        save_tiff_imagej_compatible(str(gt_out_path), y_postprocess(gt_stack), axes)

    lr_stack = imread(str(lr_path))
    pred_stack = postprocess(model.predict(lr_stack, axes))
    save_tiff_imagej_compatible(str(results / lr_path.name), pred_stack, axes)

print('done')