<h1 style="text-align: center;">BCFind training</h1>
This notebook gives an example of how to train BCFind using the modules and classes provided by the library.


In [None]:
import os
import lmdb
import pickle
import numpy as np
import tensorflow as tf

from bcfind.data import TrainingDataset
from bcfind.models import ResUNet
from bcfind.losses import FramedCrossentropy3D
from bcfind.metrics import Precision, Recall, F1
from bcfind.localizers import BlobDoG

## UNet

### 1. Dataset

- **Paired lists of input/target paths**

> Input files must be .tiff or .tif \
> Target files must be generated by Vaa3D (.marker) or 3D-Slicer (.json)


In [None]:
# Root folder of the dataset and the sub-folders holding inputs and targets.
path_to_my_data = "My_Data"
tiff_dir = f"{path_to_my_data}/Tiff_files/Train"
gt_dir = f"{path_to_my_data}/GT_files/Train"

# os.listdir() returns entries in arbitrary order. List the input folder only
# once and sort it, so that the two lists below are guaranteed to be aligned
# element by element and the order is reproducible across runs and platforms.
tiff_names = sorted(os.listdir(tiff_dir))

tiff_files = [f"{tiff_dir}/{fname}" for fname in tiff_names]
# Each target file is looked up as "<input file name>.marker" inside gt_dir.
gt_files = [f"{gt_dir}/{fname}.marker" for fname in tiff_names]

- **Data augmentation**

> BCFind-v2 offers a set of operations for data augmentation which can be selected from the following dictionary. \
> If `augmentations` is set to None, no data augmentation will be performed.


In [None]:
# Augmentation operations that are switched on, keyed by operation name.
# The commented-out entries are further operations that can be enabled by
# uncommenting them.
augmentations = dict(
    gamma=dict(param_range=[0.9, 1.1]),
    # contrast=dict(param_range=[1.0, 3.0]),
    brightness=dict(param_range=[-0.06, 0.06]),
    # zoom=dict(param_range=[1.0, 1.1], order=1),
    blur=dict(param_range=[0.0, 0.3]),
    noise=dict(param_range=[0.0, 0.03]),
    # rotation=dict(param_range=[0.0, 270.0], axes=[-2, -1]),
    flip=dict(axes=[-2]),
)

# One probability per enabled augmentation, all set to the same value.
augmentations_probs = [0.3 for _ in augmentations]

- **Pre-processing**

> `clip` sets a ceiling value for the inputs \
> `center` subtracts a specific value from all input pixels \
> `scale` divides all input pixels by a specific value


In [None]:
# Pre-processing options: each step is described by a mode and its
# companion value.
preprocessing = dict(
    # clip can be one of ['bit', 'constant', 'quantile', None]
    clip="bit",
    clip_value=15,
    # center can be one of ['constant', 'min', 'mean', None]
    center=None,
    center_value=None,
    # scale can be one of ['constant', 'bit', 'max', 'std', None]
    scale="bit",
    scale_value=15,
)

- **TrainingDataset class**


In [None]:
# Settings shared with later cells: `input_shape` is also given to the loss and
# metrics, and `voxel_resolution` to the blob detector.
batch_size = 4
voxel_resolution = (2.0, 0.65, 0.65)  # presumably one value per volume axis, in the same axis order as input_shape — TODO confirm units
input_shape = (80, 240, 240)

# Build the training dataset from the paired input/target path lists.
# Arguments are passed positionally, so their order must match the
# TrainingDataset signature; the pre-processing options are forwarded as
# keyword arguments.
train_data = TrainingDataset(
    tiff_files,
    gt_files,
    batch_size,
    voxel_resolution,
    input_shape,
    augmentations,
    augmentations_probs,
    **preprocessing
)

### 2. Model

- **Build model architecture**

> Building with an input shape of (None, None, None, None, 1) is useful for shape flexibility


In [None]:
# Architecture hyper-parameters, gathered in one place and unpacked into the
# model constructor.
resunet_config = dict(
    n_blocks=4,
    n_filters=16,
    k_size=(3, 5, 5),
    k_stride=(2, 2, 2),
    dropout=None,
    regularizer=None,
)
model = ResUNet(**resunet_config)

# Only the trailing dimension (1) is fixed; every other dimension is left
# undefined for shape flexibility.
flexible_input_shape = (None, None, None, None, 1)
model.build(flexible_input_shape)

- **Model compile**

> Mandatory definitions:
>
> - `loss`
> - `optimizer`
> - `learning-rate`
>
> Optional definitions:
>
> - `metrics`
> - `learning-rate scheduler`


In [None]:
exclude_border = (3, 9, 9)
learning_rate = 0.01

# The loss and the metrics all receive `from_logits=True`: the sigmoid is
# applied separately when predictions are exported later on.
loss = FramedCrossentropy3D(exclude_border, input_shape, from_logits=True)

# The three metrics share exactly the same constructor arguments, so they are
# built in one pass (in the order Precision, Recall, F1).
prec, rec, f1 = (
    metric_cls(0.006, input_shape, exclude_border, from_logits=True)
    for metric_cls in (Precision, Recall, F1)
)

# Cosine decay with warm restarts, starting from `learning_rate`.
lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=learning_rate,
    first_decay_steps=100,
    t_mul=2,
    m_mul=0.8,
    alpha=1e-4,
)

# SGD with Nesterov momentum, driven by the schedule above.
optimizer = tf.keras.optimizers.SGD(
    learning_rate=lr_schedule,
    momentum=0.9,
    nesterov=True,
    weight_decay=7e-4,
)

model.compile(optimizer=optimizer, loss=loss, metrics=[prec, rec, f1])

- **Callbacks**

> `ModelCheckpoint` takes care of saving the model each time the loss value improves
>
> `TensorBoard` monitors the loss and metrics during training


In [None]:
# Output locations for this experiment.
path_to_my_exp = "My_Exp"
unet_checkpoint_dir = f"{path_to_my_exp}/UNet_checkpoints"
tensorboard_dir = f"{path_to_my_exp}/UNet_tensorboard"

# Checkpointing: watch the training loss (lower is better) once per epoch and
# keep only the best model seen so far.
checkpoint_options = dict(
    monitor="loss",
    mode="min",
    save_best_only=True,
    initial_value_threshold=0.1,
    save_freq="epoch",
    save_format="tf",
    verbose=1,
)
MC_callback = tf.keras.callbacks.ModelCheckpoint(
    f"{unet_checkpoint_dir}/model.tf",
    **checkpoint_options,
)

# TensorBoard logging with graph export enabled and batch profiling disabled.
TB_callback = tf.keras.callbacks.TensorBoard(
    log_dir=tensorboard_dir,
    write_graph=True,
    profile_batch=0,
)

callbacks = [MC_callback, TB_callback]

- **Model training**


In [None]:
# Training options; no validation set is used in this example.
fit_options = dict(
    epochs=3000,
    callbacks=callbacks,
    validation_data=None,
    verbose=1,
)
model.fit(train_data, **fit_options)

## Blob Detection with DoG

### 1. Input/target pairs

> `dog_inputs`: \
> The blob detector takes UNet predictions as inputs. We therefore need to save all UNet predictions in an iterable. \
> Since all of them together can be too big to fit into memory, an lmdb database can be used.
>
> `dog_targets`: \
> DoG targets are the arrays of true coordinates: a list of them is usually small enough to fit into memory


In [None]:
from bcfind.utils.models import predict
from bcfind.utils.data import get_input_tf, get_gt_as_numpy


max_input_shape = [160, 480, 480]
lmdb_dir = f"{path_to_my_exp}/UNet_pred_train_lmdb"

n = len(tiff_files)
# Convert to a plain Python int so that the map size below cannot overflow a
# fixed-width NumPy integer.
nbytes = int(np.prod(max_input_shape)) * 1  # 4 bytes for float32: 1 byte for uint8

# Make sure the target directory (and its parents) exists before opening it.
os.makedirs(lmdb_dir, exist_ok=True)

# UNet predictions
print(f"Saving U-Net predictions in {lmdb_dir}")
db = lmdb.open(lmdb_dir, map_size=n * nbytes * 10)
try:
    # A single write transaction: it is committed when the `with` block exits
    # normally and aborted if an exception is raised.
    with db.begin(write=True) as fx:
        for i, tiff_file in enumerate(tiff_files):
            print(f"\nUnet prediction on file {i+1}/{n}")

            x = get_input_tf(tiff_file, **preprocessing)
            pred = predict(x, model)
            # The model outputs logits: map them to [0, 1], then store them as
            # uint8 in [0, 255] to keep the database small.
            pred = tf.sigmoid(tf.squeeze(pred)).numpy()
            pred = (pred * 255).astype("uint8")

            # Entries are keyed by the input file name.
            fname = os.path.basename(tiff_file)
            fx.put(key=fname.encode(), value=pickle.dumps(pred))
finally:
    # Always release the environment, even if a prediction fails.
    db.close()

dog_inputs = lmdb.open(lmdb_dir, readonly=True)

# True cell coordinates
# An lmdb cursor yields entries sorted by key bytes, not in insertion order.
# The targets are therefore loaded in that same order so that every prediction
# stays paired with its own ground truth.
# NOTE(review): this assumes the blob detector pairs its inputs and targets
# positionally — confirm against BlobDoG.fit.
key_order = sorted(
    range(n), key=lambda idx: os.path.basename(tiff_files[idx]).encode()
)
dog_targets = []
for idx in key_order:
    gt_file = gt_files[idx]
    print(f"Loading file {gt_file}")
    y = get_gt_as_numpy(gt_file)
    dog_targets.append(y)

### 2. Hyper-parameter tuning


In [None]:
# Blob-detector settings.
ndim = 3
max_match_dist = 10
iterations = 50
dog_checkpoint_dir = f"{path_to_my_exp}/DoG_checkpoints"

# Pass `ndim` instead of repeating the literal, so the setting above is the
# single place to change it.
# NOTE(review): assumes the first BlobDoG argument is the number of
# dimensions — confirm against the BlobDoG signature.
dog = BlobDoG(ndim, voxel_resolution, exclude_border)

# The cursor over the stored UNet predictions is created inside a read
# transaction, and the fit runs entirely within that transaction.
with dog_inputs.begin() as fx:
    X = fx.cursor()
    dog.fit(
        X=X,
        Y=dog_targets,
        max_match_dist=max_match_dist,
        n_iter=iterations,
        checkpoint_dir=dog_checkpoint_dir,
        n_cpu=5,
    )