# 3D segmentation of secretory granules with StarDist

![](../figs/granules.png)



This notebook demonstrates how to train a StarDist model to segment secretory granules in 3D FIB-SEM data, as described in the paper:

*Müller, Andreas, et al. "3D FIB-SEM reconstruction of microtubule–organelle interaction in whole primary mouse β cells." Journal of Cell Biology 220.2 (2021).*


For general questions about StarDist, please see https://github.com/stardist/stardist.

1. Install TensorFlow with GPU support.

2. Install stardist and dependencies:

    - `pip install stardist tqdm`
    - `pip install git+https://github.com/stardist/augmend.git`
    
3. Download the example data (or adapt your own data into the same format)

    - `wget https://syncandshare.desy.de/index.php/s/5SJFRtAckjBg5gx/download/data_granules.zip`
    - `unzip data_granules.zip`

   which should result in the following folder structure:
    ```
    data_granules
    ├── train
    │   ├── images
    │   └── masks
    └── val
        ├── images
        └── masks
    ```
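Because `get_data` (defined below) pairs images and masks purely by sorted order, it can help to check the folder layout before training. The following is a minimal sketch using only the standard library; `check_layout` is a hypothetical helper, and the `match_names=True` check assumes that each mask has the same file name as its image, which may not hold for your own data.

```python
from pathlib import Path


def check_layout(root, splits=("train", "val"), match_names=True):
    """Return {split: number of image/mask pairs} after basic sanity checks.

    Both images/ and masks/ must contain the same number of .tif files.
    With match_names=True we additionally require identical file names
    (an assumption about the naming scheme, not a StarDist requirement).
    """
    root = Path(root)
    counts = {}
    for split in splits:
        images = sorted(p.name for p in (root / split / "images").glob("*.tif"))
        masks = sorted(p.name for p in (root / split / "masks").glob("*.tif"))
        if not images:
            raise FileNotFoundError(f"no .tif files in {root / split / 'images'}")
        if len(images) != len(masks):
            raise ValueError(f"{split}: {len(images)} images but {len(masks)} masks")
        if match_names and images != masks:
            raise ValueError(f"{split}: image and mask file names differ")
        counts[split] = len(images)
    return counts
```

For example, `check_layout("data_granules")` returns the number of pairs per split, or raises if the folders do not line up.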

In [None]:
from datetime import datetime
from pathlib import Path

import numpy as np
from tifffile import imread
from tqdm import tqdm

from csbdeep.utils.tf import limit_gpu_memory
# let TensorFlow use at most 80% of 12000 MB of GPU memory
limit_gpu_memory(fraction=0.8, total_memory=12000)

from stardist import fill_label_holes, calculate_extents
from stardist.models import Config3D, StarDist3D
from augmend import Augmend, FlipRot90, Elastic, Identity, \
    IntensityScaleShift, AdditiveNoise

In [None]:
root = Path("data_granules")

In [None]:
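def get_data(subfolder="train", n=None, normalize_img=True):
    """Load up to n image/mask pairs from root/subfolder/{images,masks}.

    Images and masks are paired by sorted file name; holes in the labels are filled.
    If normalize_img is True, 8-bit images are scaled to float32 in [0, 1].
    """
    src = root / subfolder
    fx = sorted((src / "images").glob("*.tif"))[:n]
    fy = sorted((src / "masks").glob("*.tif"))[:n]
    assert len(fx) == len(fy) > 0, f"found {len(fx)} images but {len(fy)} masks in {src}"

    X = tuple(imread(str(f)) for f in tqdm(fx))
    Y = tuple(fill_label_holes(imread(str(f))) for f in tqdm(fy))

    if normalize_img:
        # assumes 8-bit input (values 0..255)
        X = tuple(_X.astype(np.float32) / 255 for _X in X)

    return X, Y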

### Training


The following code trains a 3D StarDist model for 150 epochs. Model properties that may need adjustment (e.g. the number of input channels, `train_patch_size`, ...) are set via the `Config3D` object.
For general questions regarding these parameters, please see https://github.com/stardist/stardist.
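One of these parameters, `anisotropy`, is derived in the next cell from the typical object size per axis returned by `calculate_extents`, via `tuple(np.max(extents) / extents)`. The pure-Python sketch below reproduces that arithmetic so the resulting factors are easy to sanity-check; `anisotropy_from_extents` is a hypothetical helper and the extents in the example are made up.

```python
def anisotropy_from_extents(extents):
    """Pure-Python equivalent of tuple(np.max(extents) / extents).

    The axis with the largest typical object extent maps to 1.0; axes with
    smaller extents get proportionally larger anisotropy factors.
    """
    largest = max(extents)
    return tuple(largest / e for e in extents)


# hypothetical extents (Z, Y, X): objects half as large along Z as along Y and X
print(anisotropy_from_extents((10.0, 20.0, 20.0)))  # → (2.0, 1.0, 1.0)
```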

In [None]:
X, Y = get_data("train")
Xv, Yv = get_data("val")

extents = calculate_extents(Y)
anisotropy = tuple(np.max(extents) / extents)

n_rays = 96
grid = (2, 2, 2)

print(f"empirical anisotropy of labeled objects = {anisotropy}")
print(f"using grid = {grid}")

conf = Config3D(
    rays=n_rays,
    grid=grid,
    anisotropy=anisotropy,
    use_gpu=False,  # OpenCL (gputools) for data generation, not TensorFlow
    n_channel_in=1,
    backbone="unet",
    unet_n_depth=3,
    train_patch_size=[160, 160, 160],
    train_batch_size=1,
    train_loss_weights=[1, 0.1],  # weights of the probability and distance losses
)
print(conf)

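# each aug.add() takes one transform per array in [image, mask];
# augmend applies them with the same random parameters, so image and mask stay aligned
aug = Augmend()
# random flips and 90 degree rotations in all three axes
aug.add([FlipRot90(axis=(0, 1, 2)), FlipRot90(axis=(0, 1, 2))])

# elastic deformation; order=0 (nearest neighbour) keeps mask labels integer.
# set use_gpu=False if no GPU is available for augmentation
aug.add([Elastic(axis=(0, 1, 2), amount=5, grid=6,
                 order=0, use_gpu=True),
         Elastic(axis=(0, 1, 2), amount=5, grid=6,
                 order=0, use_gpu=True)],
        probability=.7)

# noise and intensity jitter on the image only; the mask is left unchanged (Identity)
aug.add([AdditiveNoise(sigma=0.05), Identity()], probability=.5)
aug.add([IntensityScaleShift(scale=(.8, 1.2), shift=(-.1, .1)),
         Identity()], probability=.5)

def simple_augmenter(x, y):
    # called by StarDist on every training patch; returns the augmented (x, y)
    return aug([x, y])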
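The configuration above combines `train_patch_size=[160, 160, 160]` with `grid=(2, 2, 2)` and `unet_n_depth=3`. Our understanding (an assumption, worth checking against the StarDist source for your version) is that with the U-Net backbone, StarDist3D needs input sizes that are divisible, per axis, by `unet_pool**unet_n_depth * grid`, where `unet_pool` defaults to `(2, 2, 2)`. The hypothetical helpers below compute that divisor so you can check a new patch size before starting a long training run.

```python
def required_divisor(grid, unet_n_depth, unet_pool=(2, 2, 2)):
    """Per-axis factor that patch sizes must be divisible by.

    Assumption: pool**depth * grid along each axis (our reading of StarDist's
    U-Net backbone; not taken from this notebook).
    """
    return tuple(p ** unet_n_depth * g for p, g in zip(unet_pool, grid))


def check_patch_size(patch_size, grid, unet_n_depth, unet_pool=(2, 2, 2)):
    """Return True if every patch dimension is a multiple of the required divisor."""
    div = required_divisor(grid, unet_n_depth, unet_pool)
    return all(s % d == 0 for s, d in zip(patch_size, div))


print(required_divisor((2, 2, 2), 3))                   # → (16, 16, 16)
print(check_patch_size([160, 160, 160], (2, 2, 2), 3))  # → True
print(check_patch_size([100, 160, 160], (2, 2, 2), 3))  # → False
```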


The next cell starts the training. While it runs, you can monitor the model's progress and losses with TensorBoard:

`tensorboard --logdir=models`


In [None]:
timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
name = f"{timestamp}_stardist"
basedir = "models"

model = StarDist3D(conf, name=name, basedir=basedir)

# validate on the held-out val split, not on the training data
model.train(X, Y, validation_data=(Xv, Yv),
            augmenter=simple_augmenter,
            epochs=150)

# tune the probability and NMS thresholds on the validation data
model.optimize_thresholds(Xv, Yv, nms_threshs=[0.1, 0.2, 0.3])
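`optimize_thresholds` stores the selected probability and NMS thresholds with the model, and `predict_instances` uses them when no thresholds are passed explicitly. To inspect them without loading TensorFlow, the sketch below reads them from disk. It assumes (based on the StarDist versions we are familiar with) that they are saved as `thresholds.json` with keys `"prob"` and `"nms"` inside `<basedir>/<name>/`; `load_thresholds` is a hypothetical helper.

```python
import json
from pathlib import Path


def load_thresholds(basedir, name):
    """Return (prob, nms) as saved by model.optimize_thresholds.

    Assumes the file <basedir>/<name>/thresholds.json contains
    {"prob": ..., "nms": ...}.
    """
    path = Path(basedir) / name / "thresholds.json"
    with path.open() as f:
        t = json.load(f)
    return float(t["prob"]), float(t["nms"])
```

For example, `prob, nms = load_thresholds(basedir, name)` right after training.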

### Prediction

We now apply the trained model to a new stack.
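To keep GPU memory bounded, `apply` below splits each axis of the stack into `ceil(size / 160)` tiles, matching the 160-voxel training patches. The pure-Python sketch reproduces that computation for a made-up stack shape; `tiles_for_shape` is a hypothetical helper. StarDist additionally adds overlap around each tile, so the actual tile inputs are somewhat larger than `size / n_tiles`.

```python
import math


def tiles_for_shape(shape, tile_size=160):
    """Number of tiles per axis, mirroring the n_tiles line in apply()."""
    return tuple(math.ceil(s / tile_size) for s in shape)


# hypothetical stack of 400 x 1000 x 1000 voxels
print(tiles_for_shape((400, 1000, 1000)))  # → (3, 7, 7)
```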

In [None]:
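from tifffile import imwrite  # tifffile's imsave is deprecated in favour of imwrite


def apply(model, x0):
    """Predict instance labels for a single 3D stack x0 (assumed 8-bit)."""
    print("normalizing...")
    x = x0.astype(np.float32) / 255

    # split each axis into tiles of roughly 160 voxels to limit GPU memory
    n_tiles = tuple(int(np.ceil(s / 160)) for s in x.shape)

    print(f"using {n_tiles} tiles")
    y, polys = model.predict_instances(x, n_tiles=n_tiles)

    # replace the Rays object by plain arrays so np.savez does not have to pickle it
    rays = polys.pop("rays")
    polys["rays_vertices"] = rays.vertices
    polys["rays_faces"] = rays.faces

    return y, polys


fname_input = "..."
outdir = "output"

# load file
x0 = imread(fname_input)

# load the trained model (name and basedir from the training cell) and apply it
model = StarDist3D(None, name=name, basedir=basedir)
y, polys = apply(model, x0)

# save the label image and the polyhedron parameters
out = Path(outdir)
out.mkdir(exist_ok=True, parents=True)
imwrite(out / f"{Path(fname_input).stem}.stardist.tif", y)
np.savez(out / f"{Path(fname_input).stem}.stardist.npz", **polys)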
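The `.npz` file stores each granule as a star-convex polyhedron. The sketch below reads it back and computes explicit vertex coordinates. It assumes that the saved details contain `points` (centres, shape `(N, 3)`), `dist` (ray lengths, shape `(N, R)`) and `prob`, as in the StarDist versions we are familiar with, and that vertex `j` of object `i` lies at `points[i] + dist[i, j] * rays_vertices[j]`, with any anisotropy already folded into `rays_vertices`. Both helpers are hypothetical; verify the geometry against StarDist's own rendering before relying on it.

```python
import numpy as np


def polyhedron_vertices(points, dist, rays_vertices):
    """Vertex coordinates of each star-convex polyhedron.

    Assumption: vertex j of object i = points[i] + dist[i, j] * rays_vertices[j].
    Shapes: points (N, 3), dist (N, R), rays_vertices (R, 3) -> (N, R, 3).
    """
    points = np.asarray(points, dtype=np.float32)
    dist = np.asarray(dist, dtype=np.float32)
    rays_vertices = np.asarray(rays_vertices, dtype=np.float32)
    return points[:, None, :] + dist[:, :, None] * rays_vertices[None, :, :]


def load_polyhedra(npz_path):
    """Load a saved .stardist.npz and return (vertices, faces, prob)."""
    with np.load(npz_path) as d:
        verts = polyhedron_vertices(d["points"], d["dist"], d["rays_vertices"])
        return verts, d["rays_faces"], d["prob"]
```

The faces in `rays_faces` index into the `R` ray vertices and are shared by all objects, so `vertices[i][faces]` gives the triangles of granule `i`.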