# 3D segmentation of the Golgi apparatus with a 3D U-Net

![](../figs/golgi.png)



This notebook demonstrates how to train a 3D U-Net model to perform semantic segmentation of the Golgi apparatus from 3D FIB-SEM data, as described in the paper:

*Müller, Andreas, et al. "3D FIB-SEM reconstruction of microtubule–organelle interaction in whole primary mouse β cells." Journal of Cell Biology 220.2 (2021).*



1. Install TensorFlow with GPU support

2. Install csbdeep and dependencies:

    - `pip install csbdeep tqdm gputools`
    - `pip install git+https://github.com/stardist/augmend.git`
    
    
3. Download the example data (or adapt your own data into the same format)

    - `wget https://syncandshare.desy.de/index.php/s/FikPy4k2FHS5L4F/download/data_golgi.zip`
    - `unzip data_golgi.zip`

   which should result in the following folder structure:
    ```
    data_golgi
    ├── train
    │   ├── images
    │   └── masks
    └── val
        ├── images
        └── masks
    ```

In [None]:
import sys
import numpy as np
from tqdm import tqdm
from tifffile import imread, imwrite
from itertools import chain
from skimage.segmentation import find_boundaries
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from datetime import datetime
from csbdeep.internals.nets import custom_unet
from csbdeep.utils import Path, normalize
from csbdeep.utils.tf import CARETensorBoard, limit_gpu_memory
limit_gpu_memory(fraction=0.8, total_memory=12000)
from csbdeep.data.generate import sample_patches_from_multiple_stacks
from augmend import Augmend, BaseTransform, Elastic, Identity, FlipRot90, AdditiveNoise, CutOut, GaussianBlur, IntensityScaleShift
from model import UNetConfig, UNet
np.random.seed(42)


In [None]:
# Root folder of the dataset; get_data() reads the "<subset>/images" and
# "<subset>/masks" sub-folders below it.
root = Path('data_golgi')

In [None]:
def get_data(subset = "train", nfiles = None, inds = None, shuffle = True):
    """Load the image and mask volumes of one dataset split.

    Reads every ``*.tif`` file in ``root/<subset>/images`` and
    ``root/<subset>/masks``, pairs images and masks by sorted file order,
    prints the paired file names, and crops every axis of each volume down
    to the nearest multiple of 8.

    Parameters
    ----------
    subset : str
        Sub-folder of ``root`` to read, e.g. ``"train"`` or ``"val"``.
    nfiles : int or None
        Keep only the first ``nfiles`` pairs (after the optional shuffle).
        Only used when ``inds`` is None; ``None`` keeps all pairs.
    inds : index or None
        NumPy-style index selecting the pairs to keep (applied after the
        optional shuffle). Takes precedence over ``nfiles``.
    shuffle : bool
        If True, permute the pairs with a fixed seed. Note that this
        reseeds NumPy's *global* random state with 42 as a side effect.

    Returns
    -------
    X : list of numpy.ndarray
        Images as float32, divided by 255.
    Y : list of numpy.ndarray
        Masks cast to uint8.

    Raises
    ------
    ValueError
        If the number of image files and mask files differ.
    FileNotFoundError
        If no ``*.tif`` file is found for the requested subset.
    """
    src = root/subset
    images_dir = src/"images"
    masks_dir = src/"masks"
    fx = sorted(images_dir.glob("*.tif"))
    fy = sorted(masks_dir.glob("*.tif"))

    # Explicit exceptions instead of `assert`: asserts vanish under `python -O`,
    # and an empty file list would otherwise only fail much later, far away
    # from the actual cause (a wrong or missing data folder).
    if len(fx) != len(fy):
        raise ValueError(
            f"found {len(fx)} image file(s) in '{images_dir}' "
            f"but {len(fy)} mask file(s) in '{masks_dir}'")
    if len(fx) == 0:
        raise FileNotFoundError(f"no *.tif files found in '{images_dir}'")

    # Show the pairing so that mismatched image/mask names are easy to spot.
    for f1, f2 in zip(fx,fy):
        print(f"{Path(f1).name}")
        print(f"{Path(f2).name}")

    if shuffle:
        # Fixed seed -> the same permutation on every call (reseeds the
        # global NumPy RNG, see docstring).
        np.random.seed(42)
        inds0 = np.arange(len(fx))
        np.random.shuffle(inds0)
        fx = np.array(fx)[inds0]
        fy = np.array(fy)[inds0]

    if inds is not None:
        fx = np.array(fx)[inds]
        fy = np.array(fy)[inds]
    else:
        fx = fx[:nfiles]
        fy = fy[:nfiles]

    def crop(x):
        # Trim each axis to the largest multiple of 8 that fits.
        return x[tuple(slice(0,(s//8)*8) for s in x.shape)]

    # assumes 8-bit image data (division by 255 maps it to [0, 1]) — TODO confirm
    X = [crop(imread(str(f))).astype(np.float32)/255. for f in tqdm(fx)]

    # NOTE(review): label values above 255 wrap around in the uint8 cast.
    Y = [crop(imread(str(f)).astype(np.uint8)) for f in tqdm(fy)]

    return X,Y

def batch_generator(X,Y, patch_size=(32,112,112), batch_size=4, shuffle = True):
    """Endlessly yield batches of randomly sampled image/mask patches.

    Every batch draws one patch pair from each of the first ``batch_size``
    stacks of a rotating index order; after each batch the order is rolled
    by ``batch_size`` so that subsequent batches use other stacks.

    Parameters
    ----------
    X, Y : sequences of arrays
        Image stacks and the corresponding mask stacks (same length).
    patch_size : tuple of int
        Patch shape forwarded to ``sample_patches_from_multiple_stacks``.
    batch_size : int
        Number of stacks (and therefore patches) per batch.
    shuffle : bool
        If True, the stack order is shuffled initially and reshuffled
        whenever the running count approaches ``len(X)``.

    Yields
    ------
    (X_batch, Y_batch) : tuple of numpy.ndarray
        Stacked patches with the per-call sample axis dropped.

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` differ in length or hold fewer stacks than
        ``batch_size``. As this is a generator function, the checks run on
        the first ``next()`` call, not when the generator is created.
    """
    if len(X) != len(Y):
        raise ValueError("len(X) != len(Y)")
    if len(X) < batch_size:
        raise ValueError("len(X) < batch_size")

    order = np.arange(len(X))
    if shuffle:
        np.random.shuffle(order)

    consumed = 0
    while True:
        # One (image_patch, mask_patch) sample per selected stack.
        samples = [
            sample_patches_from_multiple_stacks([X[k], Y[k]],
                                                patch_size = patch_size,
                                                n_samples=1)
            for k in order[:batch_size]
        ]
        image_patches, mask_patches = zip(*samples)

        # [:, 0] removes the axis of the single sample drawn per stack
        # (assumes the sampler returns arrays shaped (n_samples, ...) — TODO confirm).
        X_batch = np.stack(image_patches)[:,0]
        Y_batch = np.stack(mask_patches)[:,0]
        yield X_batch, Y_batch

        consumed += batch_size
        if shuffle and consumed + batch_size >= len(X):
            np.random.shuffle(order)
        # Rotate so the next batch starts batch_size positions further on.
        order = np.roll(order, -batch_size)
        consumed %= len(X)

### Training


The following code trains a 3D U-Net model (using a sum of binary crossentropy and dice loss). You can monitor the progress of the model and its losses with tensorboard:

`tensorboard --logdir=models` 

In [None]:
# Load the image volumes and masks of the training and validation splits.
X, Y = get_data("train")
Xv, Yv = get_data("val")


# Augmentation pipeline. Every add() receives a list of two transforms,
# matching the two-element [x, y] list that proc_image() passes to aug();
# presumably the first transform is applied to the image and the second to
# the mask — confirm against the augmend documentation.
aug = Augmend()
# Flips / 90-degree rotations restricted to axes (1,2), same transform for both arrays.
aug.add([FlipRot90(axis = (1,2)),FlipRot90(axis = (1,2))])
# Elastic deformation over all three axes, added with probability 0.8.
aug.add([Elastic(grid=5, amount=5, order=0, use_gpu=True, axis = (0,1,2)),
             Elastic(grid=5, amount=5, order=0, use_gpu=True, axis = (0,1,2))],
            probability=.8)
# Intensity-only transforms are paired with Identity() for the second array.
aug.add([AdditiveNoise(sigma=(0,0.05)),Identity()], probability=.5)
aug.add([IntensityScaleShift(scale=(.7,1.2), shift=(-0.1,0.1), axis = (0,1,2)),Identity()])


def proc_image(x,y, augment = 0):
    """Turn one image/mask patch pair into a network input/target pair.

    Parameters
    ----------
    x : numpy.ndarray
        Image patch.
    y : numpy.ndarray
        Mask patch; every value greater than 0 counts as foreground.
    augment : int
        If greater than 0, the pair is first passed through the
        module-level ``aug`` pipeline.

    Returns
    -------
    (x, y) : tuple of numpy.ndarray
        The image with a trailing axis of length 1 appended, and the
        binarized mask as float32 with the same trailing axis appended.
    """
    if augment > 0:
        x, y = aug([x, y])

    # Binarize the mask, then store it as float32.
    foreground = y > 0
    target = foreground.astype(np.float32)

    # Append a trailing singleton axis to both arrays.
    return x[..., np.newaxis], target[..., np.newaxis]

def class_generator(gen, augment = 0):
    """Wrap a batch generator so that every sample goes through proc_image.

    Parameters
    ----------
    gen : iterable
        Yields ``(x_batch, y_batch)`` pairs that can be iterated sample by
        sample in lockstep.
    augment : int
        Forwarded unchanged to ``proc_image`` for every sample.

    Yields
    ------
    (images, targets) : tuple of numpy.ndarray
        The processed samples of one batch, re-stacked along a new first axis.
    """
    for x_batch, y_batch in gen:
        processed = [proc_image(image, mask, augment)
                     for image, mask in zip(x_batch, y_batch)]
        # Transpose the list of (image, target) pairs into two sequences.
        images, targets = zip(*processed)
        yield np.stack(images), np.stack(targets)

# Training generator: augmented patches of size (48,128,128).
# min(1,len(X)) evaluates to 1 whenever X is non-empty, i.e. one patch per batch.
gen = class_generator(batch_generator(X,Y,
                                      patch_size=(48,128,128),
                                      batch_size=min(1,len(X))),augment = 1)
# Validation generator: same patch size, up to 3 stacks per batch,
# no shuffling of the stack order and no augmentation.
gen_val = class_generator(batch_generator(Xv,Yv,batch_size=min(3,len(Xv)),
                                          patch_size=(48,128,128),
                                          shuffle = False),augment = 0)

# Model configuration (UNetConfig comes from the local `model` module).
# NOTE(review): the patch size is presumably chosen to be divisible by
# pool_size**depth (2**3 = 8 along Z, 4**3 = 64 along Y/X) — confirm.
conf = UNetConfig(axes = "ZYX",
                  unet_n_depth = 3,
                  unet_pool_size = (2,4,4), 
                  train_reduce_lr = {'factor': 0.5,
                                     'patience': 50,
                                    'min_delta': 0},
                  train_class_weight = (1,5))


# The model name carries the start time, so every run gets its own name
# below basedir "models".
timestamp = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
model = UNet(conf, name = f"{timestamp}_unet",basedir = "models")


# Draw a single batch from the validation generator; it is reused as the
# fixed validation data for the whole training run.
Xvv, Yvv = next(gen_val)

# Train from the generator (X and Y are passed as None).
model.train(X=None, Y= None,data_gen = gen, validation_data=[Xvv, Yvv],
            epochs = 300,
            steps_per_epoch = 512)

### Prediction 

We will now apply the trained model to a new stack.

In [None]:
def apply(model, x0, tile_size=196, threshold=0.5):
    """Predict a binary segmentation mask for one volume.

    Parameters
    ----------
    model : object
        Model exposing ``predict(x, axes=..., normalizer=..., n_tiles=...)``.
    x0 : numpy.ndarray
        Input volume; it is converted to float32 and divided by 255
        (assumes 8-bit intensities and ZYX axis order — TODO confirm).
    tile_size : int or float, optional
        Approximate maximal tile edge length. Each axis of length ``s`` is
        split into ``ceil(s / tile_size)`` tiles (at least one).
    threshold : float, optional
        Prediction values greater than or equal to this become foreground.

    Returns
    -------
    numpy.ndarray
        Boolean array, True where the prediction is ``>= threshold``.

    Raises
    ------
    ValueError
        If ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive")

    x = x0.astype(np.float32)/255.

    # Tiled prediction keeps memory bounded; max(1, ...) guards against a
    # tile count of 0 for zero-length axes.
    n_tiles = tuple(max(1, int(np.ceil(s/tile_size))) for s in x0.shape)

    # The input is already scaled above, so no normalizer is passed on.
    y_full = model.predict(x, axes = "ZYX", normalizer =None, n_tiles = n_tiles)

    return y_full >= threshold


# Path of the stack to segment (placeholder, set it to a real file) and
# the folder that receives the result.
fname_input = "..."
outdir = 'output'

# load file
# (fixed: this previously read the undefined name `fname`, raising NameError)
x0 = imread(fname_input)

# NOTE(review): the training cell names its model f"{timestamp}_unet";
# replace "unet" with the name of the trained model that should be used.
# Passing None as config presumably loads the stored configuration — confirm.
model = UNet(None, "unet", basedir = "models")

y = apply(model, x0)


# save output
out = Path(outdir)

out.mkdir(exist_ok=True, parents=True)
# The boolean mask is written as uint16 (values 0 and 1).
imwrite(out/f"{Path(fname_input).stem}.unet.tif",y.astype(np.uint16))