**In case of problems or questions, please first check the list of [Frequently Asked Questions (FAQ)](https://stardist.net/docs/faq.html).**

Please shut down all other training/prediction notebooks before running this notebook (otherwise they might occupy the GPU memory).

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

from stardist import random_label_cmap, calculate_extents, gputools_available
from stardist import Rays_GoldenSpiral
from stardist.matching import matching_dataset
from stardist.models import Config3D, StarDist3D

# Seed NumPy's global random number generator so that runs are repeatable.
np.random.seed(42)
# Colormap used for displaying label images (passed as cmap=lbl_cmap in the plotting cells).
lbl_cmap = random_label_cmap()

from os import listdir, makedirs
from pathlib import Path
import tifffile

In [None]:
# Use OpenCL-based computations for data generator during training (requires 'gputools').
# Change `True` to `False` to force the CPU code path even when gputools is installed.
use_gpu = True and gputools_available()

# The backslash is escaped explicitly: a bare '\ ' is an invalid escape sequence and
# triggers a SyntaxWarning on recent Python versions. The printed text is unchanged.
print('/!\\ USING GPU: ', use_gpu)

# If you need to limit the GPU memory used by TensorFlow/StarDist, you can specify that here
# from csbdeep.utils.tf import limit_gpu_memory
# # adjust as necessary: limit GPU memory to be used by TensorFlow to leave some to OpenCL-based computations
# limit_gpu_memory(0.8, total_memory=48682)

# Data

<div class="alert alert-block alert-info">
Training data (for input `X` with associated label masks `Y`) can be provided via lists of numpy arrays, where each image can have a different size. Alternatively, a single numpy array can also be used if all images have the same size.  
Input images can either be three-dimensional (single-channel) or four-dimensional (multi-channel) arrays, where the channel axis comes last. Label images need to be integer-valued.
</div>

In [None]:
# Number of input channels, passed to the model configuration as n_channel_in.
n_channel = 1

# `_dh` is the directory-history list that IPython puts into the notebook namespace;
# its first entry is used here as the starting point for locating the 'data' folder.
path_to_data = folder = Path(globals()['_dh'][0]).parents[1] / 'data'
path_to_trainval_data = path_to_data / 'datasets_for_stardist/trainval_dataset'


def _read_tiff_dir(directory):
    """Read every file in `directory` with tifffile, in sorted file-name order.

    os.listdir returns entries in arbitrary order, so images and masks listed
    separately are not guaranteed to line up. Sorting both listings makes the
    order deterministic and pairs image i with mask i.
    NOTE(review): this assumes matching image/mask files sort to the same position
    (e.g. identical file names in both folders); confirm for your dataset.
    """
    return [tifffile.imread(directory / file) for file in sorted(listdir(directory))]


X_trn = _read_tiff_dir(path_to_trainval_data / 'train/imgs')
Y_trn = _read_tiff_dir(path_to_trainval_data / 'train/masks')

X_val = _read_tiff_dir(path_to_trainval_data / 'val/imgs')
Y_val = _read_tiff_dir(path_to_trainval_data / 'val/masks')

In [None]:
def plot_img_label(img, lbl, img_title="image (XY slice)", lbl_title="label (XY slice)", z=None, **kwargs):
    """Show one slice of an image next to the same slice of its label image.

    img, lbl:   arrays indexed along their first axis to pick the slice
    img_title:  title of the left (image) panel
    lbl_title:  title of the right (label) panel
    z:          index of the slice to show; defaults to the middle of `img`'s first axis
    **kwargs:   accepted but not used
    """
    slice_index = img.shape[0] // 2 if z is None else z

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), gridspec_kw={'width_ratios': (1.25, 1)})
    img_ax, lbl_ax = axes

    # Left panel: grayscale image with a colorbar, display range fixed to [0, 1].
    shown = img_ax.imshow(img[slice_index], cmap='gray', clim=(0, 1))
    img_ax.set_title(img_title)
    fig.colorbar(shown, ax=img_ax)

    # Right panel: label slice drawn with the notebook-wide label colormap.
    lbl_ax.imshow(lbl[slice_index], cmap=lbl_cmap)
    lbl_ax.set_title(lbl_title)

    plt.tight_layout()

In [None]:
# Preview every 8th slice (at most eight of them) of the first training volume,
# each followed by the corresponding slice of its label mask.
preview_img, preview_lbl = X_trn[0], Y_trn[0]

# Bound the range by the actual depth of both volumes so that stacks with fewer
# than 57 slices do not raise an IndexError; deeper stacks show slices 0, 8, ..., 56.
z_stop = min(64, len(preview_img), len(preview_lbl))
for z in range(0, z_stop, 8):
    plt.figure()
    plt.imshow(preview_img[z], cmap='gray', clim=(0,1))
    plt.figure()
    plt.imshow(preview_lbl[z], cmap=lbl_cmap)

# Configuration

A `StarDist3D` model is specified via a `Config3D` object.

In [None]:
# Print the docstring of Config3D (presumably describes the available configuration options).
print(Config3D.__doc__)

In [None]:
# Object extents aggregated over the training label masks
# (presumably one value per image axis; see the stardist documentation).
extents = calculate_extents(Y_trn)

# Ratio of the largest extent to each individual extent.
extent_ratios = np.max(extents) / extents
anisotropy = tuple(extent_ratios)

print(extents)
print(f'empirical anisotropy of labeled objects = {anisotropy}')

In [None]:
# Number of rays per object. 96 is a good default choice (see 1_data.ipynb); 64 is used here.
# NOTE(review): 64 equals the ray count written into model.config.rays_json in the model cell
# of this notebook — presumably chosen to match the pretrained model; confirm.
n_rays = 64

# Predict on subsampled grid for increased efficiency and larger field of view
#grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)
grid = (2,2,2)
#grid = (1,1,1)

# Use rays on a Fibonacci lattice adjusted for measured anisotropy of the training data
rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

# NOTE(review): in the code visible in this notebook, `conf` is only built and displayed;
# the model is created via StarDist3D.from_pretrained and is not given `conf` — confirm
# whether that is intended.
conf = Config3D (
    rays             = rays,
    grid             = grid,
    anisotropy       = anisotropy,
    use_gpu          = use_gpu,
    n_channel_in     = n_channel,
    # adjust for your data below (make patch size as large as possible)
    train_patch_size = (64,64,64),
    train_batch_size = 8,
    train_epochs = 500,
    train_steps_per_epoch = 20,
    train_reduce_lr={'factor': 0.3, 'patience': 30}
)
#print(conf)
# Last expression of the cell: displays the configuration's attributes as a dict.
vars(conf)

In [None]:
# Load the pretrained model registered under the name '3D_demo'.
model = StarDist3D.from_pretrained('3D_demo')

# Override selected entries of the loaded configuration.
model.config.anisotropy = (2,2,2)
model.config.rays_json = dict(
    name='Rays_GoldenSpiral',
    kwargs=dict(n=64, anisotropy=(1.0, 1.0, 1.0)),
)
model.config.train_learning_rate = 0.0002
model.config.name = 'new_model'

# Create the output folder if needed and point both the model's base directory
# and its log directory at it.
models_dir = path_to_data / 'stardist_models'
makedirs(models_dir, exist_ok=True)
model.basedir = models_dir
model.logdir = models_dir

**Note:** The trained `StarDist3D` model will *not* predict completed shapes for partially visible objects at the image boundary.

Check if the neural network has a large enough field of view to see up to the boundary of most objects.

In [None]:
# Median object extents of the training labels vs. the network's tile overlap along Z, Y, X,
# which this notebook uses as the field of view.
median_size = calculate_extents(Y_trn, np.median)
fov = np.array(model._axes_tile_overlap('ZYX'))

print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")

# Warn as soon as the median object is larger than the field of view along any axis.
exceeds_fov = median_size > fov
if np.any(exceeds_fov):
    print("WARNING: median object size larger than field of view of the neural network.")

# Data Augmentation

You can define a function/callable that applies augmentation to each batch of the data generator.  
We here use an `augmenter` that applies random rotations, flips, and intensity changes, which are typically sensible for (3D) microscopy images (but you can disable augmentation by setting `augmenter = None`).

If `augmend` is not installed, install it by running `!pip install git+https://github.com/stardist/augmend.git` in a notebook cell.

In [None]:
#!pip install git+https://github.com/stardist/augmend.git

In [None]:
from augmend import Augmend, FlipRot90, Elastic, Identity, IntensityScaleShift, AdditiveNoise, Scale, Rotate


# Settings shared by both rotation transforms: rotate within the plane of axes (1,2),
# and follow the notebook-wide `use_gpu` flag instead of hard-coding True, so the
# pipeline also works when gputools is not available.
rotation_kwargs = dict(axis=(1,2), use_gpu=use_gpu)

# Each aug.add call receives one transform per array of the [image, label] pair and
# the probability with which that step is applied.
aug = Augmend()
# aug.add([Scale(amount=(.5,2), order=1, use_gpu=use_gpu), Scale(amount=(.5,2), order=0, use_gpu=use_gpu)], probability=0.5)

# Random rotation: interpolation order 1 for the image, order 0 for the label image
# (presumably so that label values stay integer — confirm against the augmend docs).
aug.add([Rotate(order=1, **rotation_kwargs), Rotate(order=0, **rotation_kwargs)], probability=0.5)

# Random flips / 90-degree rotations over all three axes, applied to image and label.
aug.add([FlipRot90(axis=(0,1,2)),FlipRot90(axis=(0,1,2))],probability=0.75)

# Intensity changes affect the image only; the label image passes through unchanged.
aug.add([IntensityScaleShift(),Identity()],probability=0.75)

aug.add([AdditiveNoise(sigma=0.05), Identity()],probability=0.5)


def augmenter(x,y):
    """Run the `aug` pipeline on one input/label image pair.

    x: input image
    y: ground-truth label image belonging to x

    Returns the result of calling `aug` on the two-element list [x, y].
    """
    pair = [x, y]
    return aug(pair)


In [None]:
# plot some augmented examples
img, lbl = X_trn[1],Y_trn[1]
plot_img_label(img, lbl)

for _ in range(10):
    img_aug, lbl_aug = augmenter(img,lbl)
    plot_img_label(img_aug, lbl_aug, img_title="image augmented (XY slice)", lbl_title="label augmented (XY slice)")

# Training

We recommend monitoring the progress during training with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard). You can start it in the shell from the current working directory like this:

    $ tensorboard --logdir=.

Then connect to [http://localhost:6006/](http://localhost:6006/) with your browser.


In [None]:
# Train on (X_trn, Y_trn), using (X_val, Y_val) as validation data and
# `augmenter` as the augmentation callable.
model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)

In [None]:
# in case the training has been interrupted:
model._training_finished()

# Threshold optimization

In [None]:
# Optimize the model's thresholds on the validation data, using candidate values
# 0.1, 0.2, ..., 0.6 for both the NMS and the IoU threshold lists.
model.optimize_thresholds(
    X_val,
    Y_val,
    nms_threshs=np.linspace(0.1, 0.6, 6),
    iou_threshs=np.linspace(0.1, 0.6, 6),
)

# Evaluation and Detection Performance

In [None]:
def _predict_labels(x):
    """Return the first element of `model.predict_instances` for image `x`."""
    prediction = model.predict_instances(
        x,
        n_tiles=model._guess_n_tiles(x),
        show_tile_progress=False,
    )
    return prediction[0]


# One prediction per validation image, with a progress bar over the images.
Y_val_pred = [_predict_labels(x) for x in tqdm(X_val)]

Plot a GT/prediction example  

In [None]:
# First validation image, shown once with its ground truth and once with the prediction.
example_img = X_val[0]
label_panels = (
    (Y_val[0], "label GT (XY slice)"),
    (Y_val_pred[0], "label Pred (XY slice)"),
)
for example_lbl, panel_title in label_panels:
    plot_img_label(example_img, example_lbl, lbl_title=panel_title)

In [None]:
# Thresholds passed as `thresh` to matching_dataset (plotted as IoU threshold tau in the next cell).
taus = [0.5, 0.6, 0.7, 0.8, 0.9]

# One matching result per threshold, comparing ground truth with the predictions.
stats = []
for t in tqdm(taus):
    stats.append(matching_dataset(Y_val, Y_val_pred, thresh=t, show_progress=False))

In [None]:
def _plot_stat_curves(ax, names, ylabel):
    """Plot each field in `names` of the `stats` entries against `taus` on axis `ax`."""
    for name in names:
        values = [s._asdict()[name] for s in stats]
        ax.plot(taus, values, '.-', lw=2, label=name)
    ax.set_xlabel(r'IoU threshold $\tau$')
    ax.set_ylabel(ylabel)
    ax.grid()
    ax.legend()


fig, (ax1,ax2) = plt.subplots(1,2, figsize=(15,5))

# Left panel: score-type metrics; right panel: raw fp/tp/fn counts.
score_names = ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality')
count_names = ('fp', 'tp', 'fn')

_plot_stat_curves(ax1, score_names, 'Metric value')
_plot_stat_curves(ax2, count_names, 'Number #')