In [30]:
import os
import numpy as np
from skimage import io
import keras.backend as K
from unet import unet, preproc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Model prediction and ensembling

## Set parameters 
__Laboratory, consensus strategy, and model weights__
- `LAB`: one of `inns1`, `inns2`, `mue`, `wue1`, `wue2`
- `INIT`: original model initialization, e.g., `from-scratch` (available for _lab_wue1_) or `fine-tuned` (available for the other labs)
- `ENSEMBLE`: Ensemble name (`consensus1` available)
- `CHECKPOINT_DIR`: relative path to stored model weights

In [31]:
# --- Which lab / initialization / ensemble to run, and where the weights live ---
LAB = "inns2"
INIT = "fine-tuned"
ENSEMBLE = "consensus1"
CHECKPOINT_DIR = "model_library"

# --- Other parameters as used in our paper ---
BATCH_SIZE = 4
# TILE_SHAPE and PADDING are handed to preproc.TileGenerator when the data is prepared.
TILE_SHAPE = (540, 540)
PADDING = (184, 184)

## Load and prepare data

(Data are not included in the GitHub repository.)

In [33]:
# Locations derived from the parameters above: the lab's input images and the
# directory that holds the selected ensemble's model sub-directories.
DATA_PATH = "bioimage_data/lab-{}/images/".format(LAB)
CHECKPOINT_PATH = '{}/lab-{}/{}/{}'.format(CHECKPOINT_DIR, LAB, INIT, ENSEMBLE)

# os.listdir() returns entries in arbitrary order; sort so the model order
# (and therefore the order of saved outputs) is reproducible between runs.
MODELS = sorted(x for x in os.listdir(CHECKPOINT_PATH) if x.startswith('model'))

# Only consider .tif files. Stray entries in the image folder (hidden files,
# thumbnails, ...) would otherwise be turned into non-existent '<name>.tif'
# paths and make io.imread fail.
image_files = sorted(x for x in os.listdir(DATA_PATH) if x.endswith('.tif'))

# Get image IDs (file name without the extension); index i of file_ids,
# images and the later predictions all refer to the same image.
file_ids = [x.rsplit('.', 1)[0] for x in image_files]

# Load images as grayscale and add an extra axis at position 2
images = [np.expand_dims(io.imread(os.path.join(DATA_PATH, x), as_gray=True), axis=2)
          for x in image_files]

# Create generator
data = [{'rawdata': img, 'element_size_um': [1,1]} for img in images]
tile_generator = preproc.TileGenerator(data, TILE_SHAPE, PADDING)

  4%|▍         | 1/26 [00:00<00:03,  8.03it/s]

Processing test samples


100%|██████████| 26/26 [00:03<00:00,  8.08it/s]


## Compute masks and ensemble predictions

In [34]:
# Without any model there is nothing to predict or average; fail early with a
# clear message instead of a NameError / NaN result further down.
if not MODELS:
    raise FileNotFoundError(
        "No 'model*' entries found in {}".format(CHECKPOINT_PATH))

# Root output directory for this lab/ensemble. Defined before the loop because
# it is also needed by the ensemble step after the loop.
path = os.path.join('pred_masks', 'lab-'+LAB, ENSEMBLE)

# Maps '<model>/<checkpoint>' -> list with one 2D array per image (same order as file_ids)
softmax_dict = {}

for model in MODELS:

    # Get checkpoint names
    # NOTE(review): takes whichever entry os.listdir() returns first; this
    # presumably relies on exactly one checkpoint per model directory - confirm.
    cp_name = os.listdir(os.path.join(CHECKPOINT_PATH, model))[0]
    mod_cp = os.path.join(model, cp_name)

    # Create unet model
    cp_model = unet.Unet2D(snapshot= os.path.join(CHECKPOINT_PATH, mod_cp))

    # Predict new masks from selected checkpoints
    # NOTE(review): predictions[0] is treated as per-image softmax output and
    # predictions[1] as per-image binary masks - confirm against Unet2D.predict.
    predictions = cp_model.predict(tile_generator)
    # Keep channel 1 of the last axis for every image
    softmax_dict[mod_cp] = [predictions[0][i][:,:,1] for i in range(len(file_ids))]

    # Free GPU RAM
    sess = K.get_session()
    K.clear_session()
    sess.close()

    # Save binary predictions (scaled to 0/255, 8-bit) under
    # pred_masks/lab-<LAB>/<ENSEMBLE>/<model>/<checkpoint>/<image id>.png
    bin_path = os.path.join(path, mod_cp)
    os.makedirs(bin_path, exist_ok=True)
    for i, idx in enumerate(file_ids):
        file_name = idx + '.png'
        io.imsave(os.path.join(bin_path, file_name), (predictions[1][i]*255).astype('uint8'))

# Calculate and save softmax average
# The output directory does not depend on the image, so create it once.
bin_path = os.path.join(path, 'ensemble', LAB+'_'+ENSEMBLE+'_ensemble')
os.makedirs(bin_path, exist_ok=True)
for i, idx in enumerate(file_ids):
    # Average image i over all models, then threshold at 0.5 -> 0/255 mask
    stack_tmp = [softmax_dict[cp][i] for cp in softmax_dict]
    stack_tmp = np.mean(stack_tmp, axis=0)
    file_name = idx + '.png'
    io.imsave(os.path.join(bin_path, file_name), (stack_tmp > 0.5).astype('uint8') * 255)

100%|██████████| 312/312 [00:14<00:00, 20.82it/s]
100%|██████████| 312/312 [00:15<00:00, 20.94it/s]
100%|██████████| 312/312 [00:15<00:00, 20.66it/s]
100%|██████████| 312/312 [00:15<00:00, 20.59it/s]
100%|██████████| 312/312 [00:15<00:00, 20.81it/s]
