In [None]:
import concurrent
import concurrent.futures
from pathlib import Path
from typing import Tuple


import matplotlib.pyplot as plt
import numpy as np
from edt import edt
from scipy.ndimage import gaussian_filter, grey_closing
from tifffile import imread, imwrite
from tqdm import tqdm

In [None]:
def bcm3d_targets(
    labels: np.ndarray,
    sigma=(2, 2, 2),
    closing_size=(2, 2, 2),
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the two BCM3D regression targets from a label volume.

    Parameters
    ----------
    labels : np.ndarray
        Label volume; 0 is background and every other value marks one cell.
    sigma : float or sequence of float, optional
        Gaussian blur applied to both targets. The BCM3D paper describes
        sigma = (5, 5, 5); (2, 2, 2) is used here by default.
    closing_size : int or sequence of int, optional
        Footprint size of the grey closing applied to the second target.
        The kernel size is not specified in the paper.

    Returns
    -------
    cell_ext_dist : np.ndarray
        Per-cell distance transform normalized to a maximum of 1 inside each
        cell, cubed, then Gaussian-blurred.
    next_cell_dist : np.ndarray
        Inside each cell: (1 - normalized distance) multiplied by the inverse
        distance to the nearest other cell; then grey-closed and blurred.
    """
    # Per-label Euclidean distance transform (the `edt` package treats label
    # changes as boundaries, per its documentation).
    euclidean_dist = edt(labels)
    binary_mask = labels > 0

    # 1 on background, 0 on every cell. Per iteration the current cell is set
    # back to 1, so the EDT measures distance to the nearest *other* cell.
    outside_cells = (~binary_mask).astype(np.uint8)
    next_cell_weight = np.zeros_like(euclidean_dist)

    label_vals = np.unique(labels)
    label_vals = label_vals[label_vals != 0]
    has_neighbours = len(label_vals) > 1

    for v in label_vals:
        mask = labels == v

        # Normalize the distance to a maximum of 1 within this cell.
        selection = euclidean_dist[mask]
        euclidean_dist[mask] = selection / selection.max()

        if has_neighbours:
            others = outside_cells.copy()
            others[mask] = 1
            proximity = edt(others)
            next_cell_weight[mask] = 1 / proximity[mask]
        # else: there is no other cell, so the weight stays 0 ("infinitely far").

    cell_ext_dist = gaussian_filter(euclidean_dist ** 3, sigma=sigma)

    next_cell_dist = (binary_mask - euclidean_dist) * next_cell_weight
    next_cell_dist = grey_closing(next_cell_dist, size=closing_size)
    next_cell_dist = gaussian_filter(next_cell_dist, sigma=sigma)

    return cell_ext_dist, next_cell_dist

In [None]:
# Load one example mask volume to try `bcm3d_targets` on.
labels = imread(
    'training_data/patches-semimanual-raw-64x128x128/train/masks/im1.tif'
)

In [None]:
# a -> cell_ext_dist target, b -> next_cell_dist target for the example mask.
a, b = bcm3d_targets(labels=labels)

In [None]:
# Load every training mask volume, in sorted filename order.
labels = list(map(
    imread,
    sorted(Path('training_data/patches-semimanual-raw-64x128x128/train/masks').glob('*.tif')),
))

In [None]:
# Compute the BCM3D targets for every mask volume in a thread pool.
# results[i] is the (cell_ext_dist, next_cell_dist) pair for labels[i].
# numpy arrays are unhashable, so results are keyed by index, not by array.
results = [None] * len(labels)
with tqdm(total=len(labels)) as pbar, concurrent.futures.ThreadPoolExecutor() as executor:
    futures = {executor.submit(bcm3d_targets, lbl): i for i, lbl in enumerate(labels)}
    for future in concurrent.futures.as_completed(futures):
        results[futures[future]] = future.result()
        pbar.update(1)

# Training

In [None]:
from argparse import ArgumentParser, Namespace
from glob import glob
from pathlib import Path
from typing import Tuple

from csbdeep.utils import normalize
from csbdeep.data import shuffle_inplace
import matplotlib.pyplot as plt
import numpy as np
from tifffile import imread
from tqdm import tqdm


from iterative_biofilm_annotation.unet.utils import crop, SegConfig, CustomDataGenerator

In [None]:
from csbdeep.models import BaseModel
from csbdeep.data import PadAndCropResizer
from csbdeep.internals.nets import common_unet
from csbdeep.utils.tf import CARETensorBoardImage

from csbdeep.utils.tf import keras_import

# Keras module handle from csbdeep (presumably tf.keras); used below for
# keras.optimizers, keras.losses and keras.callbacks.
keras = keras_import()

In [None]:
class BCM3DModel(BaseModel):
    """Regression U-Net for BCM3D targets, built on csbdeep's ``BaseModel``.

    The network is csbdeep's ``common_unet`` (3D, residual, linear output) and
    is trained with a mean-absolute-error loss.
    """

    @property
    def _config_class(self):
        """Config class of this model (``SegConfig``), used by csbdeep's ``BaseModel``."""
        return SegConfig

    def _build(self):
        """Create the Keras U-Net.

        Input shape is ``(None, None, None, n_channel_in)``. The network has
        ``config.unet_depth`` levels and ``config.n_channel_out`` linear outputs.
        """
        # Import locally: a later notebook cell assigns a bool flag to the global
        # name ``common_unet``, which would otherwise make this call fail.
        from csbdeep.internals.nets import common_unet

        unet = common_unet(n_dim=3, n_depth=self.config.unet_depth,
                           n_first=32, residual=True,
                           last_activation='linear',
                           n_channel_out=self.config.n_channel_out)
        return unet((None, None, None, self.config.n_channel_in))

    def _prepare_for_training(self, validation_data, lr):
        """Compile the model and register checkpoint/TensorBoard callbacks.

        Parameters
        ----------
        validation_data : tuple
            ``(X_val, Y_val)``; also used for the TensorBoard image previews.
        lr : float
            Adam learning rate.
        """
        # NOTE(review): 'accuracy' is not meaningful for an MAE regression
        # target; it is kept so the logged metrics stay unchanged.
        self.keras_model.compile(optimizer=keras.optimizers.Adam(lr),
                                 loss=keras.losses.MeanAbsoluteError(),
                                 metrics=['mae', 'accuracy'])
        self.callbacks = self._checkpoint_callbacks()
        self.callbacks.append(keras.callbacks.TensorBoard(log_dir=str(self.logdir / 'logs'),
                                                          write_graph=False, profile_batch=0))

        self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=validation_data,
                                                   log_dir=str(self.logdir / 'logs' / 'images'),
                                                   n_images=3, prob_out=False))
        self._model_prepared = True

    def train(self, X, Y, validation_data, lr, batch_size, epochs, steps_per_epoch):
        """Fit the model on ``(X, Y)`` and return the Keras ``History``.

        The model is compiled on the first call (with ``lr``); later calls reuse
        that compilation. Training batches come from ``CustomDataGenerator``
        (project helper; presumably a Keras ``Sequence`` of size ``batch_size``).
        """
        if not self._model_prepared:
            self._prepare_for_training(validation_data, lr)

        training_data = CustomDataGenerator(X, Y, batch_size)

        history = self.keras_model.fit(training_data, validation_data=validation_data,
                                       epochs=epochs, steps_per_epoch=steps_per_epoch,
                                       callbacks=self.callbacks, verbose=1)
        self._training_finished()
        return history

    def predict(self, img, axes=None, normalizer=None, resizer=PadAndCropResizer()):
        """Predict the network output for a single image.

        Parameters
        ----------
        img : np.ndarray
            Input image with axes ``axes`` (defaults to ``config.axes``).
        normalizer : csbdeep Normalizer, optional
            Applied to the input via ``normalizer.before``.
        resizer : csbdeep Resizer, optional
            Pads the spatial axes to a multiple of ``2**unet_depth`` and crops
            the prediction back. NOTE(review): the default instance is shared
            across calls.

        Returns
        -------
        np.ndarray
            Prediction in the network's axes order (``config.axes``).
        """
        from csbdeep.utils import axes_check_and_normalize

        normalizer, resizer = self._check_normalizer_resizer(normalizer, resizer)
        axes_net = self.config.axes
        if axes is None:
            axes = axes_net
        axes = axes_check_and_normalize(axes, img.ndim)
        # Spatial axes must be divisible by 2**depth for the U-Net down/up-sampling.
        axes_net_div_by = tuple((2 ** self.config.unet_depth if a in 'XYZ' else 1) for a in axes_net)
        x = self._make_permute_axes(axes, axes_net)(img)
        # csbdeep normalizers are applied via `.before(x, axes)`; they are not callable.
        x = normalizer.before(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)
        pred = self.keras_model.predict(x[np.newaxis])[0]
        return resizer.after(pred, axes_net)

In [None]:
def train(
    basedir: Path,
    modelname: str,
    dataset: str,
    patch_size: Tuple[int, ...],
    epochs: int,
    steps: int,
    ) -> None:
    """Placeholder for a scripted training entry point.

    Currently does nothing and returns ``None``; the training steps are run in
    the notebook cells below, which use module-level variables of the same names.
    ``patch_size`` is a variable-length tuple (e.g. ``(48, 96, 96)``), hence
    ``Tuple[int, ...]`` rather than the 1-tuple ``Tuple[int]``.
    """
    return None

In [None]:
# Training configuration.
basedir = Path('models')
modelname = 'care_bcm3d_target1_v2'
patch_size = (48, 96, 96)  # central crop size per spatial axis (presumably Z, Y, X)
epochs, steps = 100, 100   # number of epochs and steps per epoch


In [None]:
    # Load the train/valid patches, crop out the central `patch_size` patch
    # (for simplicity) and normalize the inputs.
    # NOTE(review): this cell is indented (apparently left over from the `train`
    # function body); it presumably only runs because IPython strips a common
    # leading indent from cells.

    def _load_split(split, patch_size, target_dir='target_bacm3d_1'):
        """Load one data split as stacked (X, Y) arrays with a trailing channel axis.

        Inputs come from ``.../<split>/images`` and targets from
        ``.../<split>/<target_dir>``. Files are paired by sorted filename order.
        Inputs are percentile-normalized (1, 99.8); targets are left unchanged.
        """
        x_paths = sorted(glob(f'training_data/patches-semimanual-raw-64x128x128/{split}/images/*.tif'))
        y_paths = sorted(glob(f'training_data/patches-semimanual-raw-64x128x128/{split}/{target_dir}/*.tif'))
        # Inputs and targets are paired by position, so the counts must agree.
        assert len(x_paths) == len(y_paths), (
            f'{split}: {len(x_paths)} images but {len(y_paths)} targets')
        assert x_paths, f'{split}: no .tif files found'

        X = [crop(imread(x), patch_size) for x in x_paths]
        Y = [crop(imread(y), patch_size) for y in y_paths]
        X = [normalize(x, 1, 99.8) for x in tqdm(X, desc=f'normalize {split}')]

        # Stack samples along a new first axis and append a singleton channel axis.
        return np.expand_dims(np.stack(X), -1), np.expand_dims(np.stack(Y), -1)

    X_train, Y_train = _load_split('train', patch_size)
    X_valid, Y_valid = _load_split('valid', patch_size)

In [None]:
# Feature flags: high_std filters low-variation patches; common_unet trains BCM3DModel.
# NOTE: assigning `common_unet` here rebinds the name of the csbdeep
# `common_unet` function imported earlier; code that later calls
# `common_unet(...)` through the global name receives this bool instead.
high_std, common_unet = False, False

In [None]:
# Optional filtering: keep only samples whose input standard deviation (over all
# non-sample axes) exceeds STD_THRESHOLD, i.e. drop patches with little intensity variation.
STD_THRESHOLD = 0.16


def _keep_high_std(X, Y, threshold=STD_THRESHOLD):
    """Return the (X, Y) samples whose per-sample std of X exceeds `threshold`."""
    keep = X.std(axis=tuple(range(1, X.ndim))) > threshold
    return X[keep], Y[keep]


if high_std:
    train_std = X_train.std(axis=tuple(range(1, X_train.ndim)))
    fig, ax = plt.subplots()
    ax.hist(train_std, 100)
    ax.axvline(STD_THRESHOLD, color='r', linestyle='--')
    ax.set(title='Per-patch std of normalized training inputs', xlabel='std', ylabel='count')

    X_train, Y_train = _keep_high_std(X_train, Y_train)
    X_valid, Y_valid = _keep_high_std(X_valid, Y_valid)

In [None]:
    # Sanity check: shapes of the training and validation input arrays.
    (X_train.shape, X_valid.shape)

# Common UNet

In [None]:
# NOTE: `common_unet` here is the bool flag from the flags cell (it shadows the
# csbdeep function of the same name imported earlier).
if common_unet:
    # Train the custom BCM3DModel U-Net on the patches loaded above.
    config = SegConfig(n_channel_in=1, n_channel_out=1, unet_depth=2)
    model = BCM3DModel(config, modelname, basedir=str(basedir))

    # Shuffle inputs and targets; csbdeep's shuffle_inplace applies the same
    # permutation to every array it is given.
    shuffle_inplace(X_train, Y_train, seed=0)
    shuffle_inplace(X_valid, Y_valid, seed=0)

    # TODO: check whether a larger batch size (e.g. 16) improves results.
    history = model.train(X_train, Y_train, validation_data=(X_valid, Y_valid),
                          lr=4e-4, batch_size=1, epochs=epochs, steps_per_epoch=steps)

# Alternative: train a standard CARE model on the same data

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import axes_dict, plot_some, plot_history
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.models import Config, CARE

In [None]:
# Visual sanity check: the first 5 validation inputs and their targets.
plt.figure(figsize=(12, 5))
plot_some(X_valid[:5], Y_valid[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

In [None]:
#limit_gpu_memory(fraction=1/2)

In [None]:
# Train a standard CARE network on the same (X, Y) patches. Epochs and steps
# come from the configuration cell instead of being hard-coded here.
# NOTE(review): `modelname` is the same name used for BCM3DModel above, so both
# models would share one model directory under `basedir`.
axes = 'SZYXC'
config = Config(axes, n_channel_in=1, n_channel_out=1,
                train_epochs=epochs, train_steps_per_epoch=steps)
model = CARE(config, modelname, basedir=basedir)
history = model.train(X_train, Y_train, validation_data=(X_valid, Y_valid))

In [None]:
# List the recorded metrics, then plot the loss and error curves.
print(sorted(history.history))
plt.figure(figsize=(16, 5))
plot_history(history, ['loss', 'val_loss'], ['mse', 'val_mse', 'mae', 'val_mae']);