# StarDist 3D - Training

Let's check that TensorFlow can use the GPU ;)

In [None]:
import tensorflow as tf

# Check which GPUs TensorFlow can see.
# `tf.test.is_gpu_available()` is deprecated in TF2 and enumerates devices through
# device_lib, which may initialize the GPUs (and reserve their memory) right away;
# that can prevent `limit_gpu_memory` further down from taking effect.
# Listing the *physical* devices only queries them.
try:
    # tf.config.list_physical_devices exists in TF >= 2.1; older versions only have the experimental one
    list_gpus = getattr(tf.config, 'list_physical_devices', None) or tf.config.experimental.list_physical_devices
except AttributeError:
    list_gpus = None

if list_gpus is not None:
    gpus = list_gpus('GPU')
    print('GPUs visible to TensorFlow: %s' % (gpus,))
    gpu_available = len(gpus) > 0
else:
    # very old TensorFlow without tf.config: fall back to the original check
    gpu_available = tf.test.is_gpu_available(cuda_only=False, min_cuda_compute_capability=None)
gpu_available

The code below is adapted from the [StarDist 3D training example](https://github.com/stardist/stardist/blob/master/examples/3D/2_training.ipynb).

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import sys
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = None
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from tifffile import imread
from csbdeep.utils import Path, normalize

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist import Rays_GoldenSpiral
from stardist.matching import matching, matching_dataset
from stardist.models import Config3D, StarDist3D, StarDistData3D

np.random.seed(42)
lbl_cmap = random_label_cmap()

## Data

In [None]:
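# Your images and masks must be in two separate sub-folders of `main_image_dir`,
# and each mask must have the same file name as its image:
#
#main_dir
#|_main_image_dir
#    |_images
#        |_img1.tif
#        |_...
#    |_masks
#        |_img1.tif
#        |_...
#|_models
#|_1-Training_notebook.ipynb
#|_2-QC_notebook.ipynb
#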

In [None]:
# Fraction of the image/mask pairs held out for validation
val_fraction = 0.25
# Folder containing the `images/` and `masks/` sub-folders
main_image_dir = "crops_BIOP_v1"
# Seed of the random train/validation split
rdmSeed = 42

In [None]:
# Collect the image and mask file paths; every mask must share its image's file name.
X = sorted(glob(main_image_dir+'/images/*.tif'))
Y = sorted(glob(main_image_dir+'/masks/*.tif'))
# zip() below silently stops at the shorter list, so check the counts explicitly first.
assert len(X) > 0, "no .tif images found in %s" % (main_image_dir+'/images/')
assert len(X) == len(Y), "found %d images but %d masks" % (len(X), len(Y))
assert all(Path(x).name==Path(y).name for x,y in zip(X,Y)), "image and mask file names do not match"

X = list(map(imread,X))
Y = list(map(imread,Y))
# 3D images without a channel axis count as single-channel; otherwise the channel
# axis is assumed to be the last one.
n_channel = 1 if X[0].ndim == 3 else X[0].shape[-1]

In [None]:
# Axes over which each image is normalized (channel axis assumed last):
# (0,1,2) normalizes every channel on its own, (0,1,2,3) normalizes channels jointly.
axis_norm = (0,1,2)   # normalize channels independently
# axis_norm = (0,1,2,3) # normalize channels jointly
if n_channel > 1:
    joint = axis_norm is None or 3 in axis_norm
    print("Normalizing image channels %s." % ('jointly' if joint else 'independently'))
    sys.stdout.flush()

# csbdeep percentile normalization (low/high percentiles 1 and 99.8) of every image,
# and hole filling of every label image.
X = [normalize(img, 1, 99.8, axis=axis_norm) for img in tqdm(X)]
Y = [fill_label_holes(lbl) for lbl in tqdm(Y)]

In [None]:
assert len(X) > 1, "not enough training data"
rng = np.random.RandomState(rdmSeed)
ind = rng.permutation(len(X))
# Hold out at least one image for validation, but never all of them: an empty
# training set would otherwise go unnoticed here.
n_val = max(1, int(round(val_fraction * len(ind))))
assert n_val < len(ind), "val_fraction=%s leaves no images for training" % val_fraction
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val]  , [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
print('number of images: %3d' % len(X))
print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

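Here we save the images and masks used for validation, so that they can be reused later (e.g. in the QC notebook). Note that the images are saved after normalization.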

In [None]:
from skimage.io import imsave
import os

# Save the validation split (normalized images, hole-filled masks) as
# val/images/<k>.tif and val/masks/<k>.tif, numbered from 1 in split order.
val_dir = 'val/'
X_val_dir = val_dir+'images/'
Y_val_dir = val_dir+'masks/'

for out_dir in (X_val_dir, Y_val_dir):
    os.makedirs(out_dir, exist_ok=True)   # also creates val_dir itself

# NOTE: files left over from a previous run with more validation images are not removed.
for cnt, (img, lbl) in enumerate(zip(X_val, Y_val), start=1):
    # check_contrast=False: label masks routinely trigger skimage's "low contrast" warning
    imsave(X_val_dir+str(cnt)+'.tif', img, check_contrast=False)
    imsave(Y_val_dir+str(cnt)+'.tif', lbl, check_contrast=False)

## Check anisotropy

In [None]:
# Per-axis object extents aggregated over all label images (stardist helper;
# presumably the median bounding-box size per axis). Each axis is then expressed
# relative to the largest extent, so the longest axis gets anisotropy 1.
extents = calculate_extents(Y)
largest_extent = np.max(extents)
anisotropy = tuple(largest_extent / extents)
print('empirical anisotropy of labeled objects = %s' % str(anisotropy))

## Define Configuration

In [None]:
# Number of rays per object; 96 is a good default choice
# (see 1_data.ipynb of the StarDist 3D examples).
n_rays = 96

# NOTE: this anisotropy is set by hand and replaces the empirical value computed in
# the "Check anisotropy" cell; keep both consistent with your data.
anisotropy = (1.6,1,1)
train_patch = (48,64,64)
# Use OpenCL-based computations for the data generator during training (requires 'gputools');
# change True to False to force the CPU.
use_gpu = True and gputools_available()

# Predict on a subsampled grid for increased efficiency and larger field of view, e.g.
#grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)
grid = (1,1,1)
# Rays on a Fibonacci lattice adjusted for the anisotropy defined above
rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

model_name = "n1_stardist_{}_{}_{}_{}".format(n_rays, anisotropy, train_patch, grid)

conf = Config3D(
    rays=rays,
    grid=grid,
    anisotropy=anisotropy,
    n_channel_in=n_channel,
    # adjust for your data below (make patch size as large as possible)
    train_patch_size=train_patch,
    train_batch_size=1,
)
print(model_name)
print(conf)
vars(conf)

In [None]:
# When the data generator uses OpenCL (gputools) on the GPU, cap the share of GPU
# memory TensorFlow may take so that the OpenCL computations still have room.
if use_gpu:
    from csbdeep.utils.tf import limit_gpu_memory
    gpu_memory_fraction = 0.9   # adjust as necessary
    limit_gpu_memory(gpu_memory_fraction)
    # alternatively, let TensorFlow grow its allocation on demand:
    # limit_gpu_memory(None, allow_growth=True)

In [None]:
# Create the model from `conf`; it is stored in the 'models' folder under `model_name`.
model = StarDist3D(
    conf,
    name=model_name,
    basedir='models',
)

In [None]:
def random_fliprot(img, mask, axis=None): 
    if axis is None:
        axis = tuple(range(mask.ndim))
    axis = tuple(axis)
            
    assert img.ndim>=mask.ndim
    perm = tuple(np.random.permutation(axis))
    transpose_axis = np.arange(mask.ndim)
    for a, p in zip(axis, perm):
        transpose_axis[a] = p
    transpose_axis = tuple(transpose_axis)
    img = img.transpose(transpose_axis + tuple(range(mask.ndim, img.ndim))) 
    mask = mask.transpose(transpose_axis) 
    for ax in axis: 
        if np.random.rand() > 0.5:
            img = np.flip(img, axis=ax)
            mask = np.flip(mask, axis=ax)
    return img, mask 

def random_intensity_change(img):
    """Randomly rescale and shift image intensities.

    The image is multiplied by a factor drawn uniformly from [0.6, 2) and then
    shifted by an offset drawn uniformly from [-0.2, 0.2).
    """
    scale = np.random.uniform(0.6,2)
    offset = np.random.uniform(-0.2,0.2)
    return img * scale + offset

def augmenter(x, y):
    """Randomly augment a single input/label image pair.

    Parameters
    ----------
    x : ndarray
        Input image.
    y : ndarray
        Corresponding ground-truth label image.

    Returns
    -------
    (x, y) : tuple of ndarray
        The augmented image and label image.
    """
    # Permute/flip only axes 1 and 2 (the YX plane): 3D microscopy acquisitions
    # are usually not axially symmetric, so axis 0 (Z) is left untouched.
    x_aug, y_aug = random_fliprot(x, y, axis=(1,2))
    # Intensity augmentation touches the image only, never the labels.
    return random_intensity_change(x_aug), y_aug

## Finally, we can start training!

In [None]:
# Train on the training split, passing the held-out split as validation data and
# the `augmenter` function defined above for data augmentation.
model.train(
    X_trn, Y_trn,
    validation_data=(X_val, Y_val),
    augmenter=augmenter,
)

## Optimize the thresholds and plot some metrics

In [None]:
# Tune the model's thresholds on the validation split
model.optimize_thresholds(
    X_val,
    Y_val,
)

In [None]:
# Predict a label image for every validation image, tiling each volume as
# suggested by the model; only the label image (first return value) is kept.
Y_val_pred = []
for img in tqdm(X_val):
    n_tiles = model._guess_n_tiles(img)
    Y_val_pred.append(model.predict_instances(img, n_tiles=n_tiles, show_tile_progress=False)[0])

In [None]:
# IoU thresholds 0.1, 0.2, ..., 0.9 (k/10 is the same float as the literal 0.k)
taus = [k / 10 for k in range(1, 10)]
# Matching statistics between ground truth and prediction at every threshold
stats = []
for t in tqdm(taus):
    stats.append(matching_dataset(Y_val, Y_val_pred, thresh=t, show_progress=False))

In [None]:
# Left panel: matching metrics vs. IoU threshold; right panel: TP/FP/FN counts vs. IoU threshold.
fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

panels = (
    (ax1,
     ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'),
     'Metric value',
     'Matching metrics (validation set)'),
    (ax2,
     ('fp', 'tp', 'fn'),
     'Number #',
     'Object counts (validation set)'),
)
# Both panels share the same plotting code; only the metrics and labels differ.
for ax, metric_names, ylabel, title in panels:
    for m in metric_names:
        ax.plot(taus, [s._asdict()[m] for s in stats], '.-', lw=2, label=m)
    ax.set_xlabel(r'IoU threshold $\tau$')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid()
    ax.legend()

You can find your newly trained model in the `models` folder. Open the [QC_notebook](2-QC_notebook.ipynb) to look at the metrics in more detail.