In [1]:
import glob
import urllib
import numpy as np
import matplotlib.pyplot as plt
import json
from tqdm.notebook import tqdm
import tifffile

In [2]:
from deep_sxt.models import double_decoder_cnn
from deep_sxt.data_loader.image_loader import image_preprocessing_function
from deep_sxt.inference.tomogram_segmentation import load_tomogram_from_file, segment_slice

2022-11-17 05:34:48.416943: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-17 05:34:48.562031: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-17 05:34:49.087287: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /srv/data/Apps/hdf5/lib:/usr/local/cuda-11.2/lib64:/srv/data/gcc-11.2.0/libexec/gcc/x86_64-pc-linux-gnu/11.2.0:/srv/data/gcc-11.2.0/lib:/srv/data/gcc-11.2.0/lib32:/srv/data/gcc-11.2.0/lib64:/srv/data/Apps

## Checking TensorFlow installation and GPU availability

In [3]:
import tensorflow as tf

# Query the visible GPUs through the stable `tf.config` API; the
# `tf.config.experimental` variant is only kept as a compatibility alias.
gpu_devices = tf.config.list_physical_devices('GPU')

print(f"TensorFlow version: {tf.__version__}")
print(f"Number of GPUs available: {len(gpu_devices)}")

TensorFlow version: 2.10.1
Number of GPUs available: 1


## Building the deep network

Network parameters are loaded from the ```json``` file with the given path.

In [None]:
# Load the network configuration that is passed to the model constructor below.
# JSON text is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's default text encoding.
with open('../network_params/model_config_semisupervised_2021_05.json', 'r', encoding='utf-8') as config_file:
    model_config = json.load(config_file)

The ```model``` object contains the deep network used for image segmentation.

In [None]:
# Instantiate the double-decoder segmentation network from the loaded configuration.
model = double_decoder_cnn.MyDoubleDecoderNet(config=model_config)

# Build the network for single-channel 256x256 inputs with an unspecified batch
# size, so that its variables exist before pre-trained weights are loaded.
model_input_shape = (None, 256, 256, 1)
model.build(input_shape=model_input_shape)

model.summary()

### Loading weights of the pre-trained model


The weights are saved as ```HDF5``` files. A specific set of weights will be downloaded when you first run the following cell.

In [None]:
# `import urllib` alone does not load the `request` submodule, so import it explicitly.
import os
import urllib.request

weights_url = 'https://ftp.mi.fu-berlin.de/pub/cmb-data/deep_sxt/weights_semisupervised_2021_05.h5'
weights_path = '../network_params/weights_semisupervised_2021_05.h5'

# Check for the exact file that is loaded in the next cell; an unrelated *.h5
# file in the same directory must not suppress the download.
if not os.path.isfile(weights_path):
    print("Downloading network weights...")
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)
    # Download under a temporary name and rename only on success, so an
    # interrupted download does not leave a truncated file that would be
    # mistaken for valid weights on the next run.
    partial_path = weights_path + '.part'
    urllib.request.urlretrieve(weights_url, partial_path)
    os.replace(partial_path, weights_path)

In [None]:
# Restore the pre-trained parameters into the already-built network.
pretrained_weights_file = '../network_params/weights_semisupervised_2021_05.h5'
model.load_weights(pretrained_weights_file)

## Preparing tomogram for processing

Please provide the full path to your tomogram file in the following cell:

In [None]:
# Replace this placeholder with the full path to the tomogram you want to segment.
tomogram_file_path = (
    "path_to_your_tomogram_file"
)

In [None]:
import os

# Fail early with an actionable message if the placeholder path in the previous
# cell was not replaced, instead of surfacing an obscure error from the loader.
if not os.path.exists(tomogram_file_path):
    raise FileNotFoundError(
        f"Tomogram file not found: {tomogram_file_path!r}. "
        "Set `tomogram_file_path` in the previous cell to the full path of your tomogram."
    )

tomogram = load_tomogram_from_file(tomogram_file_path)

### Hint:
You might want to crop the tomogram to remove uninteresting slices before it is processed by the network.

Just change ```start_index``` and ```end_index``` in the following cell accordingly: slices from ```start_index``` up to, but not including, ```end_index``` are kept.

In [None]:
# Keep slices start_index (inclusive) .. end_index (exclusive) along the first axis.
start_index = 0
end_index = tomogram.shape[0]

assert 0 <= start_index < end_index <= tomogram.shape[0], (
    f"Invalid slice range [{start_index}, {end_index}) for a tomogram with "
    f"{tomogram.shape[0]} slices"
)

sel_slices = tomogram[start_index:end_index, :, :]

n_slices = sel_slices.shape[0]
slice_shape = sel_slices.shape[1:]

print(f"Number of slices: {n_slices}")
print(f"Slice shape: {slice_shape}")

# Number of evenly spaced slices to preview (also used after segmentation).
n_sample_slices = 5

# Never preview more slices than are available. Otherwise the step below
# becomes 0 and every panel would show index -1 (i.e. the last slice).
n_shown = min(n_sample_slices, n_slices)
step = n_slices // n_shown
sample_indices = [step * (i + 1) - 1 for i in range(n_shown)]

fig, axes = plt.subplots(1, n_shown, figsize=(10 * n_shown, 20), squeeze=False)

for ax, _index in zip(axes[0], sample_indices):
    ax.set_axis_off()
    ax.set_title(f"slice #{_index}", fontsize=24)
    ax.imshow(sel_slices[_index, :, :], cmap='Greys_r')

## Feeding the tomogram to the segmentation network

This process can take up to several minutes, depending on the performance of the graphics card in your machine.

The ```segment_slice``` function chops each image into smaller tiles of ```chunk_size``` × ```chunk_size``` pixels, so that it can be processed without running out of memory. It takes two optional arguments: ```chunk_size``` and ```stride```. If you encounter an out-of-memory (OOM) error during the following step, try reducing ```chunk_size```.

In [None]:
# Tile size and step used by `segment_slice` to process each slice piecewise.
# Reduce `chunk_size` if you run out of GPU memory.
chunk_size = 600
stride = 400

# A tile cannot be larger than the slice itself. All selected slices share the
# same shape, so compute the effective tile size once instead of once per slice.
effective_chunk_size = int(np.amin([chunk_size, *sel_slices.shape[1:]]))
# NOTE(review): if `effective_chunk_size` ends up smaller than `stride`, confirm
# that `segment_slice` still covers the whole slice.

segmented_tomogram = []

for _slice in tqdm(sel_slices, desc="Segmenting slices"):

    output = segment_slice(model, image_preprocessing_function(_slice),
                           chunk_size=effective_chunk_size, stride=stride)

    # Binarize the network output: strictly positive values become 1.0, the rest 0.0.
    # `astype` already returns a new array, so no extra copy is needed.
    segmented_tomogram.append((output > 0.0).astype(np.float32))

segmented_tomogram = np.array(segmented_tomogram)

In [None]:
# Show evenly spaced input slices (top row) next to their segmentation masks (bottom row).
assert n_slices > 0, "No slices were selected for segmentation"

# Never show more slices than are available. Otherwise the step becomes 0 and
# every panel would show index -1 (i.e. the last slice).
n_shown = min(n_sample_slices, n_slices)
step = n_slices // n_shown
sample_indices = [step * (i + 1) - 1 for i in range(n_shown)]

fig, axes = plt.subplots(2, n_shown, figsize=(10 * n_shown, 20), squeeze=False)

for col, _index in enumerate(sample_indices):
    ax_orig, ax_seg = axes[0, col], axes[1, col]

    ax_orig.set_axis_off()
    ax_orig.set_title(f"original slice #{_index}", fontsize=24)
    ax_orig.imshow(sel_slices[_index, :, :], cmap='Greys_r')

    ax_seg.set_axis_off()
    ax_seg.set_title(f"segmented slice #{_index}", fontsize=24)
    ax_seg.imshow(segmented_tomogram[_index, :, :], cmap='hot')

## Saving the segmented tomogram to file

for further processing/visualization.

In [None]:
import os

# np.save does not create missing directories, so make sure the output folder
# exists. np.save appends the ".npy" extension to the file name.
os.makedirs("../outputs", exist_ok=True)
np.save("../outputs/segmented_tomogram", segmented_tomogram)

### Hint:

For 3D reconstruction, the numpy file created by the previous cell suffices.

You can additionally save the segmented output as a TIFF file to use with software such as **ImageJ**:

In [None]:
import os

# The output folder may not exist yet (e.g. if the .npy cell was skipped).
os.makedirs("../outputs", exist_ok=True)

# Scale the binary {0, 1} mask to the full uint16 range (0 and 65535) before writing.
tifffile.imwrite("../outputs/segmented_tomogram.tiff",
                 data=(segmented_tomogram * 65535.0).astype(np.uint16), compression='zlib',
                 imagej=True)