# CellDiscoveryNet training notebook

This notebook assumes you have already run the `make_CellDiscoveryNet_input` notebook. It trains a CellDiscoveryNet model on the data generated by that notebook.

In [None]:
import os

# Restrict which GPU(s) this process can see. This is set before TensorFlow is
# imported below -- presumably so the value is already in place when
# TensorFlow initialises CUDA (confirm before moving this line).
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # adjust which GPU you want to use here

from deepreg.util import build_dataset
from deepreg.registry import REGISTRY
from deepreg.model import layer, layer_util
import deepreg.loss as loss
# from wormalign.network import DDFNetworkTrainer, DDFNetworkTester
import deepreg.model.optimizer as opt
from deepreg.callback import build_checkpoint_callback

import deepreg.predict as predict
import deepreg.train as train
import tensorflow as tf
import h5py
import numpy as np
import matplotlib.pyplot as plt
import math

import nrrd
import nibabel as nib

from deepreg.loss import image as img_loss

import pickle

In [None]:
def normalize_batched_image(batched_image, eps=1e-7):
    """
    Rescale a batch of images to the [0, 1] range, image by image.

    The minimum and maximum are reduced over axes 1, 2 and 3 (keeping those
    axes as size 1), so the statistics are computed separately for every
    index along the remaining axes (axis 0, plus any axes beyond 3).

    :param batched_image: Tensor with at least four axes.
    :param eps: Lower bound applied to ``max - min`` before dividing, which
                avoids a division by zero for constant images.
    :return: Tensor of the same shape as ``batched_image``, shifted by the
             per-image minimum and divided by the per-image value range.
    """
    reduce_axes = [1, 2, 3]
    lowest = tf.math.reduce_min(batched_image, axis=reduce_axes, keepdims=True)
    highest = tf.math.reduce_max(batched_image, axis=reduce_axes, keepdims=True)
    # Clamp the range so a flat image maps to zeros instead of NaN/inf.
    value_range = tf.maximum(highest - lowest, eps)
    return (batched_image - lowest) / value_range

def compute_centroids_3d(image, max_val):
    """
    Compute the centroid of the voxels carrying each label value in a 3D image.

    :param image: A 3D numpy array representing the image with dimensions (x, y, z).
    :param max_val: Largest label value to look for. Values 1..max_val are
                    processed; every other value (including 0) is ignored.
    :return: A (max_val, 3) float32 numpy array. Row ``i`` holds the centroid
             (x, y, z), in array-index coordinates, of the voxels equal to
             ``i + 1``. Rows for values that do not occur in the image are
             filled with -1.
    """
    # -1 marks "value not present"; real centroids are means of array indices
    # and therefore never negative.
    centroids = np.full((max_val, 3), -1, dtype=np.float32)
    for val in range(1, max_val + 1):
        # One row of indices per voxel equal to the current value.
        indices = np.argwhere(image == val)
        if len(indices) == 0:
            continue
        # Average the first three index columns in a single reduction.
        centroids[val - 1] = indices[:, [0, 1, 2]].mean(axis=0)
    return centroids

In [None]:
physical_devices = tf.config.list_physical_devices('GPU')
# Enable on-demand memory allocation on every visible GPU. Looping (instead of
# indexing [0]) avoids an IndexError when no GPU is visible and keeps the
# setting identical across devices when more than one GPU is exposed.
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

## Initialize network configuration settings

Before running the next cell, edit the `dir` settings for the `train`, `valid`, and `test` splits in the YAML configuration file so that they point to the location of your data. Each `dir` should point to a directory containing the `fixed_images.h5` and `moving_images.h5` files output by the `make_CellDiscoveryNet_input` notebook. Also set `config_path` and `log_dir` in the cell below to paths on your own machine.

In [None]:
# Path to the YAML training configuration file.
config_path = "/store1/PublishedData/Data/prj_register/CellDiscoveryNet/train_config.yaml"
# Placeholder -- replace with your own log directory.
log_dir = "/path/to/your/log_dir"
# Empty string: presumably means "no checkpoint to resume from" -- confirm
# against DeepReg's `build_config`.
ckpt_path = ""
exp_name = "multicolor_gncc"
max_epochs = 600

# `build_config` returns the configuration together with `log_dir` and
# `ckpt_path`; the returned values replace the ones assigned above.
config, log_dir, ckpt_path = train.build_config(
    config_path=config_path,
    log_dir=log_dir,
    exp_name=exp_name,
    ckpt_path=ckpt_path,
    max_epochs=max_epochs,
)

# Batch size is read from the training preprocessing section of the config.
batch_size = config["train"]["preprocess"]["batch_size"]

# Bare expression so the notebook displays the value.
batch_size


## Load data

In [None]:
### build datasets

# Keyword arguments that are identical for the training and validation
# datasets; both use the training preprocessing settings and repeat.
shared_dataset_kwargs = dict(
    dataset_config=config["dataset"],
    preprocess_config=config["train"]["preprocess"],
    repeat=True,
)

# Training split.
data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
    split="train",
    training=True,
    **shared_dataset_kwargs,
)

# Validation split.
data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
    split="valid",
    training=False,
    **shared_dataset_kwargs,
)

## Build model

In [None]:
### build model

# Image shapes and the number of indices come from the training data loader.
# NOTE(review): the (200, 3) label sizes are hard-coded -- presumably up to
# 200 points with 3 coordinates each; confirm against the input data.
model_config = {
    "name": config["train"]["method"],
    "moving_image_size": data_loader_train.moving_image_shape,
    "fixed_image_size": data_loader_train.fixed_image_shape,
    "moving_label_size": (200, 3),
    "fixed_label_size": (200, 3),
    "index_size": data_loader_train.num_indices,
    "labeled": config["dataset"]["train"]["labeled"],
    "batch_size": batch_size,
    "config": config["train"],
}

model: tf.keras.Model = REGISTRY.build_model(config=model_config)

## Build optimizer and callbacks

In [None]:
# Build the optimizer from the "optimizer" section of the training config.
optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
# Only the optimizer is passed to compile(); presumably the model registers
# its own losses and metrics internally -- confirm in the model definition.
model.compile(optimizer=optimizer)
# Presumably writes a diagram of the model under `log_dir` -- confirm.
model.plot_model(output_dir=log_dir)

In [None]:
# Both callbacks are driven by the same save period from the training config.
train_config = config["train"]
save_period = train_config["save_period"]

# TensorBoard logging; the update frequency falls back to "epoch" when the
# config does not specify one.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=save_period,
    update_freq=train_config.get("update_freq", "epoch"),
)

# Checkpointing; the call also returns the epoch to start training from.
ckpt_callback, initial_epoch = build_checkpoint_callback(
    model=model,
    dataset=dataset_train,
    log_dir=log_dir,
    save_period=save_period,
    ckpt_path=ckpt_path,
)

callbacks = [tensorboard_callback, ckpt_callback]

## Fit model

In [None]:
# Start from the epoch returned by `build_checkpoint_callback` instead of a
# hard-coded value, so the starting epoch always matches `ckpt_path`: a
# hard-coded 300 would skip epochs 0-299 when training a fresh model.
# Training runs until the epoch count stored in the training config.
history = model.fit(
    x=dataset_train,
    steps_per_epoch=steps_per_epoch_train,
    initial_epoch=initial_epoch,
    epochs=config["train"]["epochs"],
    validation_data=dataset_val,
    validation_steps=steps_per_epoch_val,
    callbacks=callbacks,
)

In [None]:
# Persist the per-epoch metrics dictionary (`history.history`) rather than
# the History object itself. Replace the placeholder path with your own.
history_path = "/path/to/your/history.pkl"
with open(history_path, "wb") as history_file:
    pickle.dump(history.history, history_file)