In [37]:
# This script is an example of fine-tuning one of the 
# BlastoSPIM models on other data - in this case Organoids data

# Cell 1: setting path to blastospim-processing-pipeline-Jupyter code and checking environment
# NOTE: please change path_to_code to your own path below

import sys
import numpy as np
from glob import glob
import os
import json
import tifffile as tif
from csbdeep.utils import Path, normalize
from csbdeep.io import save_tiff_imagej_compatible

from stardist import fill_label_holes, random_label_cmap, calculate_extents, gputools_available
from stardist import Rays_GoldenSpiral
from stardist.matching import matching, matching_dataset
from stardist.models import Config3D, StarDist3D, StarDistData3D
from tensorflow.keras.utils import Sequence
import tensorflow as tf
tf.get_logger().setLevel('ERROR')
print('tensorflow version ',tf.__version__)
lbl_cmap = random_label_cmap()
np.random.seed(42)

# Specify path to blastospim-processing-pipeline-Jupyter directory
# path_to_code = "/path/to/your/blastospim-processing-pipeline-Jupyter/"
path_to_code = "/Users/hnunley/Pictures/blastospim-processing-pipeline-Jupyter/"

## Augmentation temporarily commented out -- TODO: put back in
#print("Augmentation")
#from pyimgaug3d.augmentation import GridWarp, Flip, Identity
#from pyimgaug3d.augmenters import ImageSegmentationAugmenter

tensorflow version  2.10.0


In [38]:
# Cell 2: test whether you have GPU access

tf.get_logger().setLevel('ERROR')
if tf.test.gpu_device_name() == '':
    print('You do not have GPU access.')
else:
    print('You have GPU access')

if gputools_available():
    print("Stardist flag for GPU also working")
else:
    print("Stardist flag for GPU NOT working, go back check out")

You do not have GPU access.
Stardist flag for GPU also working


In [39]:
# Cell 3: set these for training for training

nephochs = 100 # number of epochs, use 200 or 400 for actual training
nsteps_per_epoch = 100 # nsteps_per_epoch = 100  -- use 100 for actual training
z_ptch_size = 32 # z_ptch_size = 32  -- z dimension of tha patch size, 32 x 256 x 256 gives best results
ptch_size = 256 # ptch_size = 256 -- x and y dimension of the patch size, don't have to be same
start_percentile = 1 # start_percentile = 1 -- start percentile for normalization of images
end_percentile = 99.8 # end_percentile = 99.8 -- end percentile for normalization of images
n_channel = 1 # n_channel = 1  # n_channel = 3 for RGB images and (2D 3 slice images) and 1 otherwise
axis_norm = (0, 1, 2) # axis_norm = (0, 1, 2)  # (0, 1, 2) for normalizing channels independently does not matter if n_channel = 1
num_val = 10000 # num_val = 10000 # for testing whether the script works set them to a small nuber such as 2 otherwise set it to 100000
num_trn = 20000 # num_trn = 20000 # for testing whether the script works set them to a small nuber such as 2 otherwise set it to 100000

# the actual validation and training examples are the min(num_val/num_trn, len(data_valX, data_trainX) )

In [40]:
# Cell 4: set up your training data

# for best performance - this data should be anisotropic (~10,1,1) along (z,y,x) -- i.e. slices in z are ~10 x further apart than the units in x and y
# data expected to be 3D and in tif or np format

from glob import glob
import os
import sys
import json

#data_path = "/path/to/your/training/and/validation/Data/"
data_path = path_to_code # this assumes you downloaded the relevant sample data into your path_to_code

# store paths to all training data (here timepoints 000, 250, 350 in annotated data)
# These data_... store paths to all raw and label images in the training set.
data_x_000 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_000/dataset_hdf5_000/images/*.npy"))
data_y_000 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_000/dataset_hdf5_000/masks/*.npy"))

data_x_250 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_250/dataset_hdf5_250/images/*.npy"))
data_y_250 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_250/dataset_hdf5_250/masks/*.npy"))

data_x_350 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_350/dataset_hdf5_350/images/*.npy"))
data_y_350 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_350/dataset_hdf5_350/masks/*.npy"))

data_x = data_x_000 + data_x_250 + data_x_350 # for training, raw images
data_y = data_y_000 + data_y_250 + data_y_350 # for training, label images

# store paths to validation data (here timepoint 100 in annotated data)
data_x_100 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_100/dataset_hdf5_100/images/*.npy"))
data_y_100 = sorted(glob(data_path + "2022_64x256x256_Platynereis/dataset_hdf5_100/dataset_hdf5_100/masks/*.npy"))

data_val_x = data_x_100
data_val_y = data_y_100

In [41]:
# Cell 5: Data loaders for intensity images 
# make changes here if your format is different (than tif or npy)

class seq_x(Sequence):

    def __init__(self, data_x_trn):
        self.data_ = data_x_trn[0:min(num_trn, len(data_x_trn))]
        print("Total images = {}, Using {}".format(len(data_x_trn), len(self.data_)))

    def __len__(self):
        return len(self.data_)

    def __getitem__(self, idx):
        if self.data_[idx][-4:] == '.npy':
            x = np.load(self.data_[idx])
        elif self.data_[idx][-4:] == '.tif':
            x = tif.imread(self.data_[idx])
        return normalize(x, start_percentile, end_percentile, axis=axis_norm)

In [42]:
# Cell 6: Data loaders for ground-truth label images 
# make changes here if your format is different

class seq_y(Sequence):

    def __init__(self, data_y_trn):
        self.data_ = data_y_trn[0:min(num_trn, len(data_y_trn))]
        print("Total images = {}, Using {}".format(len(data_y_trn), len(self.data_)))
        self.ndim = 3

    def __len__(self):
        return len(self.data_)

    def __getitem__(self, idx):
        if self.data_[idx][-4:] == '.npy':
            y = np.load(self.data_[idx])
        elif self.data_[idx][-4:] == '.tif':
            y = tif.imread(self.data_[idx])
        return fill_label_holes(y.astype("uint8"))

In [43]:
# Cell 7: constructing the training and validation data
# depending on the size, validation is limited to around 20 images when caching

rng = np.random.RandomState(42)
ind = rng.permutation(len(data_y))

ind_train = ind
ind_val = rng.permutation(len(data_val_y))
data_x_val, data_y_val = [data_val_x[i] for i in ind_val], [data_val_y[i] for i in ind_val]
data_x_trn, data_y_trn = [data_x[i] for i in ind_train], [data_y[i] for i in ind_train]
X_trn = seq_x(data_x_trn)
Y_trn = seq_y(data_y_trn)
print("Total validation images = {}, ground truth {}".format(len(data_val_x), len(data_val_y)))

X_val = seq_x(data_x_val)
Y_val = seq_y(data_y_val)

assert len(X_trn) == len(Y_trn), "len(X_trn) == len(Y_trn) not satisfied"
assert len(X_val) == len(Y_val), "len(X_val) == len(Y_val) not satisfied"

print('- training:       %3d' % len(X_trn))
print('- validation:     %3d' % len(X_val))

Total images = 150, Using 150
Total images = 150, Using 150
Total validation images = 50, ground truth 50
Total images = 50, Using 50
Total images = 50, Using 50
- training:       150
- validation:      50


In [44]:
# Cell 8 (OPTIONAL): print out information about parameters in configuration
print(Config3D.__doc__)

Configuration for a :class:`StarDist3D` model.

    Parameters
    ----------
    axes : str or None
        Axes of the input images.
    rays : Rays_Base, int, or None
        Ray factory (e.g. Ray_GoldenSpiral).
        If an integer then Ray_GoldenSpiral(rays) will be used
    n_channel_in : int
        Number of channels of given input image (default: 1).
    grid : (int,int,int)
        Subsampling factors (must be powers of 2) for each of the axes.
        Model will predict on a subsampled grid for increased efficiency and larger field of view.
    n_classes : None or int
        Number of object classes to use for multi-class predection (use None to disable)
    anisotropy : (float,float,float)
        Anisotropy of objects along each of the axes.
        Use ``None`` to disable only for (nearly) isotropic objects shapes.
        Also see ``utils.calculate_extents``.
    backbone : str
        Name of the neural network architecture to be used as backbone.
    kwargs : dict
  

In [45]:
# Cell 9: set up the model
# and specify anisotropy of images
# and specify rays for star-convex shapes
Y2 = [Y_trn[i] for i in range(0, len(Y_trn), 100)]
extents = calculate_extents(Y2)
anisotropy = tuple(np.max(extents) / extents)

print('empirical anisotropy of labeled objects = %s' % str(anisotropy))
n_rays = 96
use_gpu = gputools_available() #True #() # setting this to True did not work (No module named gputools?)
grid = tuple(1 if a > 1.5 else 2 for a in anisotropy)  # WTH is this?
grid = (1,4,4)
rays = Rays_GoldenSpiral(n_rays, anisotropy=anisotropy)

conf = Config3D(
    rays=rays,
    grid=grid,
    anisotropy=anisotropy,
    use_gpu=use_gpu,
    n_channel_in=n_channel,
    # adjust for your data below (make patch size as large as possible)
    train_patch_size=(z_ptch_size, ptch_size, ptch_size),
    # reduce batch size if run out of memory
    train_batch_size=2  #
    #train_sample_cache = False # LB could try larger batch size (not for validation I think)
    #train_learning_rate = .00003
)
print(conf)

empirical anisotropy of labeled objects = (8.285714285714286, 1.0, 1.0691244239631337)
Config3D(n_dim=3, axes='ZYXC', n_channel_in=1, n_channel_out=97, train_checkpoint='weights_best.h5', train_checkpoint_last='weights_last.h5', train_checkpoint_epoch='weights_now.h5', n_rays=96, grid=(1, 4, 4), anisotropy=(8.285714285714286, 1.0, 1.0691244239631337), backbone='unet', rays_json={'name': 'Rays_GoldenSpiral', 'kwargs': {'n': 96, 'anisotropy': (8.285714285714286, 1.0, 1.0691244239631337)}}, n_classes=None, unet_n_depth=2, unet_kernel_size=(3, 3, 3), unet_n_filter_base=32, unet_n_conv_per_depth=2, unet_pool=(2, 2, 2), unet_activation='relu', unet_last_activation='relu', unet_batch_norm=False, unet_dropout=0.0, unet_prefix='', net_conv_after_unet=128, net_input_shape=(None, None, None, 1), net_mask_shape=(None, None, None, 1), train_patch_size=(32, 256, 256), train_background_reg=0.0001, train_foreground_only=0.9, train_sample_cache=True, train_dist_loss='mae', train_loss_weights=(1, 0.2), 

In [46]:
# Cell 10: set up the model
# Note: this path is to the model you would like to train further.
# This training will update the weights stored in that folder -- files will be overwritten.

# The "python3 download_data_and_models_for_finetune.py" in the setup created a copy of the model to finetune.

model_path = path_to_code + "models"
fldr_name = 'late_blastocyst_model'

model = StarDist3D(conf, name=fldr_name, basedir=model_path)
median_size = calculate_extents(Y2, np.median)
fov = np.array(model._axes_tile_overlap('ZYX'))
print(f"median object size:      {median_size}")
print(f"network field of view :  {fov}")
if any(median_size > fov):
    print("WARNING: median object size larger than field of view of the neural network.")

Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.531428, nms_thresh=0.3.
median object size:      [ 7.   58.   54.25]
network field of view :  [26 93 93]


In [47]:
# Cell 11 (OPTIONAL): modify for your data - augment your data with realistic transformations
# TODO: update aug
class aug():
    def __init__(self):
        aug = ImageSegmentationAugmenter()
        #aug.add_augmentation(GridWarp(grid=(2,2,1), max_shift=4))
        aug.add_augmentation(Flip(0))
        aug.add_augmentation(Flip(1))
        aug.add_augmentation(Flip(2))
        aug.add_augmentation(Identity())
        #         aug.add_augmentation(Random_intensity())
        self.aug = aug

    def __call__(self, img, seg):
        img_ = np.expand_dims(img, axis=-1)
        seg_ = np.expand_dims(seg, axis=-1).astype(np.float32)
        aug_img, aug_seg = self.aug([img_, seg_])
        return aug_img[:, :, :, 0].numpy(), aug_seg[:, :, :, 0].numpy().astype(np.uint8)

In [None]:
# Cell 12: train and optimize thresholds
# Note: This may output some warnings.

augumenter_ = None #aug() # currently set augment to none
model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), augmenter=augumenter_, \
            epochs=nephochs, steps_per_epoch=nsteps_per_epoch)
model.optimize_thresholds(X_val, Y_val)