
# Setup


## Imports


In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

from functools import partial
import logging
import pathlib
from pathlib import Path
from pprint import pprint
import sys
from typing import *
import time
import yaml
from yaml import YAMLObject

import humanize
from matplotlib import pyplot as plt, cm
import numpy as np
import pandas as pd
from pymicro.file import file_utils
import tensorflow as tf
from numpy.random import RandomState

from tensorflow import keras
from tensorflow.keras import utils
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks as keras_callbacks
from tensorflow.keras import losses

from tomo2seg import modular_unet
from tomo2seg.logger import logger
from tomo2seg import data, viz
from tomo2seg.data import Volume
from tomo2seg.metadata import Metadata
from tomo2seg.volume_sequence import (
    VolumeCropSequence, MetaCrop3DGenerator, ET3DUniformCuboidAlmostEverywhere, 
    UniformGridPosition, GTUniformEverywhere, ET3DConstantEverywhere, 
    VSConstantEverywhere, GTConstantEverywhere, SequentialGridPosition
)
from tomo2seg import volume_sequence
from tomo2seg.model import Model as Tomo2SegModel
from tomo2seg import callbacks as tomo2seg_callbacks

In [None]:
# Let the tomo2seg logger emit every record from DEBUG level upwards.
logger.setLevel(logging.DEBUG)

In [None]:
# NumPy random generator created from the fixed seed 42.
random_state = np.random.RandomState(42)

# The run id is the current Unix time truncated to whole seconds.
runid = int(time.time())
logger.info(f"{runid=}")

In [None]:
# Report the TensorFlow version and the GPUs that TensorFlow can see
# (physical and logical device lists).
logger.debug(f"{tf.__version__=}")
logger.info(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}\nThis should be 2 on R790-TOMO.")
logger.debug(f"Both here should return 2 devices...\n{tf.config.list_physical_devices('GPU')=}\n{tf.config.list_logical_devices('GPU')=}")

# XLA auto-clustering optimization (see: https://www.tensorflow.org/xla#auto-clustering)
# is explicitly switched off because enabling it seemed to break the training.
tf.config.optimizer.set_jit(False)

# Distribution strategy meant to use both GPUs (see https://www.tensorflow.org/guide/distributed_training).
# It is created without an explicit device list; the commented line below is the
# variant that takes one.
strategy = tf.distribute.MirroredStrategy()  
# strategy = tf.distribute.MirroredStrategy(devices=[""])  


# Data

In [None]:
# Choose the volume and the labels version used in this run.
# The constants are imported under fixed aliases so that switching dataset only
# requires changing which import line is (un)commented.
from tomo2seg.datasets import (
    VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_V1_REDUCED as VOLUME_NAME_VERSION,
    VOLUME_COMPOSITE_V1_LABELS_REFINED3 as LABELS_VERSION
)

# VOLUME_NAME_VERSION is unpacked into a (name, version) pair.
volume_name, volume_version = VOLUME_NAME_VERSION
labels_version = LABELS_VERSION

logger.info(f"{volume_name=} {volume_version=} {labels_version=}")

In [None]:
# Metadata/paths object of the volume, looked up by its name and version.
volume = Volume.with_check(name=volume_name, version=volume_version)
logger.info(f"{volume=}")

def _read_raw(path_: Path, volume_: Volume):
    """Read a raw volume file from disk with pymicro's ``HST_read``.

    The data type and the dimensions of the file are taken from the metadata
    of ``volume_`` (the argument), not from any module-level volume, so the
    same reader can be bound to different volumes with ``functools.partial``.

    Args:
        path_: location of the raw file.
        volume_: volume whose ``metadata.dtype`` and ``metadata.dimensions``
            describe how the file must be decoded.

    Returns:
        Whatever ``file_utils.HST_read`` returns for the given file
        (presumably a numpy array -- TODO confirm against pymicro).
    """
    metadata = volume_.metadata
    return file_utils.HST_read(
        str(path_),  # HST_read doesn't accept path objects
        # pre-loaded kwargs
        autoparse_filename=False,  # the file names are not properly formatted
        data_type=metadata.dtype,
        dims=metadata.dimensions,
        verbose=True,
    )

# Reader bound to this notebook's volume, so only the file path has to be given.
read_raw = partial(_read_raw, volume_=volume)

logger.info("Loading data from disk.")

## Data
# The raw values are divided by 255 to normalize them
# (assumes 8-bit intensities in [0, 255] -- TODO confirm).
voldata = read_raw(volume.data_path) / 255  # normalize
logger.debug(f"{voldata.shape=}")

# Extract the train and validation partitions defined on the volume.
voldata_train = volume.train_partition.get_volume_partition(voldata)
voldata_val = volume.val_partition.get_volume_partition(voldata)
logger.debug(f"{voldata_train.shape=} {voldata_val.shape=}")

# Drop the name of the full volume once the partitions are extracted.
del voldata

## Labels
# The labels file is read with the same reader, i.e. with the same dtype and
# dimensions (taken from the volume metadata) as the data file.
vollabels = read_raw(volume.versioned_labels_path(labels_version))
logger.debug(f"{vollabels.shape=}")

vollabels_train = volume.train_partition.get_volume_partition(vollabels)
vollabels_val = volume.val_partition.get_volume_partition(vollabels)
logger.debug(f"{vollabels_train.shape=} {vollabels_val.shape=}")

# Same as for the data: only the partitions are kept under a name.
del vollabels

# Data crop sequences

In [None]:
# The global batch size is the per-replica batch size times the number of
# replicas kept in sync by the distribution strategy.
batch_size_per_replica = 8
n_replicas = strategy.num_replicas_in_sync
batch_size = batch_size_per_replica * n_replicas

logger.info(f"{batch_size_per_replica=}\n{n_replicas=}\n{batch_size=}")

# Seed handed to the random generators of the crop sequences.
common_random_state = 143

# Crop sides are multiples of 16 (requirement of a 4-level u-net).
# Alternative that was used before: crop_shape = (256, 256, 1)
crop_shape = (320, 320, 1)

## Train

In [None]:
# Training partition.
# NOTE: rebinding `data` hides the `tomo2seg.data` module imported at the top
# of the notebook for every cell executed after this one.
data = voldata_train
labels = vollabels_train
volume_shape = data.shape
labels_list = volume.metadata.labels

# Crop position generator; it is also handed to every field of the meta crop
# generator below.
grid_pos_gen = UniformGridPosition.build_from_volume_crop_shapes(
    volume_shape=volume_shape,
    crop_shape=crop_shape,
    random_state=RandomState(common_random_state),
)

crop_seq_train = VolumeCropSequence(
    data_volume=data,
    labels_volume=labels,
    labels=labels_list,
    meta_crop_generator=MetaCrop3DGenerator(
        volume_shape=volume_shape,
        crop_shape=crop_shape,
        x0y0z0_generator=grid_pos_gen,
        et_field=ET3DConstantEverywhere.build_no_displacement(
            grid_position_generator_=grid_pos_gen,
        ),
        gt_field=GTUniformEverywhere.build_2d(
            random_state=RandomState(common_random_state),
            grid_position_generator_=grid_pos_gen,
        ),
        vs_field=VSConstantEverywhere.build_no_shift(
            grid_position_generator_=grid_pos_gen,
        ),
    ),
    batch_size=batch_size,
    # The crops are random, so the size of an epoch is an arbitrary choice.
    epoch_size=1,
    meta_crops_hist_path=None,  # TODO: add a new path to the model and save this
    debug__no_data_check=True,  # TODO: remove this debug flag
)

## Val

In [None]:
# Validation partition.
# NOTE: rebinding `data` hides the `tomo2seg.data` module imported at the top
# of the notebook for every cell executed after this one.
data = voldata_val
labels = vollabels_val
volume_shape = data.shape
labels_list = volume.metadata.labels

# Sequential crop positions with a fixed number of steps per axis.
# Alternative kept for reference:
#     SequentialGridPosition.build_min_overlap(
#         volume_shape=volume_shape, crop_shape=crop_shape,
#     )
grid_pos_gen = SequentialGridPosition.build_from_volume_crop_shapes(
    volume_shape=volume_shape,
    crop_shape=crop_shape,
    n_steps_x=2,
    n_steps_y=2,
    n_steps_z=200,
)

crop_seq_val = VolumeCropSequence(
    # data source
    data_volume=data,
    labels_volume=labels,
    labels=labels_list,

    # data augmentation (constant/identity fields only)
    meta_crop_generator=MetaCrop3DGenerator(
        volume_shape=volume_shape,
        crop_shape=crop_shape,
        x0y0z0_generator=grid_pos_gen,
        et_field=ET3DConstantEverywhere.build_no_displacement(
            grid_position_generator_=grid_pos_gen,
        ),
        gt_field=GTConstantEverywhere.build_gt2d_identity(
            grid_position_generator_=grid_pos_gen,
        ),
        vs_field=VSConstantEverywhere.build_no_shift(
            grid_position_generator_=grid_pos_gen,
        ),
    ),

    # others
    batch_size=batch_size,
    epoch_size=len(grid_pos_gen),  # go through all the crops in validation
    meta_crops_hist_path=None,  # TODO: add a new path to the model and save this
    debug__no_data_check=True,  # TODO: remove this debug flag
)

# Model

In [None]:
# Forget any model handle left over from a previous execution; deleting an
# undefined name raises NameError, which only means there is nothing to forget.
try:
    del tomo2seg_model
except NameError:
    print("already deleted (:")

In [None]:
from cnn_segm import keras_custom_loss

In [None]:
# Master name and version that identify the model.
model_master_name, model_version = "unet-2d-small", "vanilla00"

In [None]:
# Function that builds the Keras model and the keyword arguments it receives.
model_factory_function = modular_unet.u_net
model_factory_kwargs = dict(
    input_shape=crop_shape,
    nb_filters_0=16,
)

# Create the model handle only if it doesn't exist yet, so that re-running this
# cell doesn't silently replace it.
try:
    tomo2seg_model
except NameError:
    tomo2seg_model = Tomo2SegModel(
        model_master_name,
        model_version,
        runid=runid,
        factory_function=model_factory_function,
        factory_kwargs=model_factory_kwargs,
    )
else:
    logger.warning("The model is already defined. To create a new one: `del tomo2seg_model`")

# The rest is deliberately NOT in a `finally` clause: if creating the handle
# fails, its exception must propagate as is instead of being followed by a
# NameError on the undefined `tomo2seg_model`.
logger.info(f"{tomo2seg_model=}")

logger.info("Compiling model.")

with strategy.scope():
    if tomo2seg_model.autosaved_model_path.exists():
        logger.warning("An autosaved model already exists, loading it instead of creating a new one!")
        # Loaded without compiling because it is compiled right below.
        model = keras.models.load_model(tomo2seg_model.autosaved_model_path_str, compile=False)
    else:
        # Disabled check kept for reference:
        # assert not tomo2seg_model.model_path.exists(), f"Please delete '{tomo2seg_model.model_path}' to resave it if you wish to regenerate it."
        model = model_factory_function(
            output_channels=len(volume.metadata.labels),
            name=tomo2seg_model.name,
            **model_factory_kwargs
        )

    model.compile(
        loss=keras_custom_loss.jaccard2_loss,
        # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
        optimizer=optimizers.Adam(learning_rate=.003),
    )
    model.save(tomo2seg_model.model_path)

In [None]:
# Write the textual summary of the model, line by line, into its summary file.
with tomo2seg_model.summary_path.open("w") as f:
    def print_to_txt(line):
        f.write(line + "\n")
    model.summary(line_length=140, print_fn=print_to_txt)

# Draw the architecture (with tensor shapes) into its own file.
# The trailing `;` is presumably there to silence the notebook cell output.
utils.plot_model(
    model,
    to_file=tomo2seg_model.architecture_plot_path,
    show_shapes=True,
);

logger.info(f"Check the summary and the figure of the model in the following locations:\n{tomo2seg_model.summary_path}\n{tomo2seg_model.architecture_plot_path}")

# Callbacks

In [None]:
# Checkpoint written to the autosave path only when `val_loss` improves.
autosave_cb = keras_callbacks.ModelCheckpoint(
    filepath=tomo2seg_model.autosaved_model_path_str,
    monitor="val_loss",
    mode="auto",
    save_best_only=True,
    verbose=2,
)

# History callback of tomo2seg, given the model's optimizer and the csv path
# of the model.
# TODO: load the history if it already exists
history_cb = tomo2seg_callbacks.History(
    csv_path=tomo2seg_model.history_path,
    optimizer=model.optimizer,
    backup=5,
)

# Summary before training

Stuff that I use after the training but that I want to appear in the notebook before it.


## Metadata

todo put this back to work

## Volume slices

todo put this back to work

## Generator samples

todo put this back to work

# Learning rate range test

todo put this back to work


# Training


## Learning rate test

In [None]:
def log_schedule_factory(start_pow10, stop_pow10, n_per_scale, wait, offset_epoch=0):
    """Build a log-spaced learning rate schedule function.

    The schedule holds ``10 ** start_pow10`` for ``wait`` epochs and then goes
    from ``10 ** start_pow10`` until ``10 ** stop_pow10`` with ``n_per_scale``
    points between each scale of 10.

    Args:
        start_pow10: power of 10 of the first learning rate.
        stop_pow10: power of 10 of the last learning rate.
        n_per_scale: number of intermediate points between consecutive powers of 10.
        wait: number of epochs spent at the first value before the sweep starts.
        offset_epoch: epoch number that is mapped to the first schedule value.

    Returns:
        A function ``log_schedule(epoch, lr) -> float`` (``lr`` is ignored).
        Epochs before ``offset_epoch`` get the first value of the schedule and
        epochs past its end get the last one.
    """
    n = (n_per_scale + 1) * abs(stop_pow10 - start_pow10) + 1
    warmup = np.full(wait, 10 ** start_pow10, dtype=float)
    schedule = np.concatenate([warmup, np.logspace(start_pow10, stop_pow10, n)])
    logger.info(f"log schedule {n=} {wait=} {wait+n=}")
    last_index = schedule.shape[0] - 1

    def log_schedule(epoch, lr):
        # Clamp on both sides: without the lower bound a negative index would
        # silently wrap around to the END of the schedule (or raise IndexError).
        index = min(max(epoch - offset_epoch, 0), last_index)
        return float(schedule[index])

    return log_schedule

In [None]:
# model = tf.keras.models.load_model(str(model_paths.autosaved_model_path) + ".hdf5")

In [None]:
from tensorflow.keras import backend as K

# Manual alternative to the scheduler (disabled): set a fixed learning rate
# directly on the optimizer.
# lr = 0.001
# K.set_value(model.optimizer.learning_rate, lr)

# Learning rate schedule of this training phase.
# NOTE(review): `offset_epoch` is subtracted from the epoch number inside the
# schedule, so the hard-coded value must match the epoch at which this phase
# starts (`history_cb.last_epoch + 1` in the fit cell) -- confirm before running.
lr_schedule_cb = keras_callbacks.LearningRateScheduler(
#     schedule=log_schedule_factory(-6, -2, 9, 9),
    schedule=log_schedule_factory(-2, -1, 13, 0, offset_epoch=61),
    verbose=2,
)

# Overrides the epoch size that the training sequence was created with.
crop_seq_train.epoch_size = 10

# Callbacks passed to `model.fit` for this phase.
callbacks = [
    autosave_cb,
    history_cb,
    keras_callbacks.TerminateOnNaN(),
    lr_schedule_cb
]

In [None]:
n_epochs = 15

# Continue right after the last epoch recorded by the history callback
# (for some reason it is 0-starting and others 1-starting...).
# To start from scratch instead: initial_epoch=0, epochs=n_epochs.
first_epoch = history_cb.last_epoch + 1

model.fit(
    # data sequences
    x=crop_seq_train,
    validation_data=crop_seq_val,

    # epochs
    initial_epoch=first_epoch,
    epochs=first_epoch + n_epochs,

    # others
    callbacks=callbacks,
    verbose=2,
    use_multiprocessing=False,
);

## Decreasing learning rate

In [None]:
from tensorflow.keras import backend as K

# Learning rate schedule of this training phase.
# NOTE(review): `offset_epoch` is subtracted from the epoch number inside the
# schedule, so the hard-coded value must match the epoch at which this phase
# starts (`history_cb.last_epoch + 1` in the fit cell) -- confirm before running.
lr_schedule_cb = keras_callbacks.LearningRateScheduler(
#     schedule=log_schedule_factory(-2, -4, 11, 0, offset_epoch=76),
    schedule=log_schedule_factory(-4, -5, 31, 0, offset_epoch=101),
    verbose=2,
)

# Overrides the epoch size that the training sequence was created with.
crop_seq_train.epoch_size = 10

# Point the history callback at the optimizer currently attached to the model.
history_cb.optimizer = model.optimizer

# Callbacks passed to `model.fit` for this phase.
callbacks = [
    autosave_cb,
    history_cb,
    keras_callbacks.TerminateOnNaN(),
    lr_schedule_cb
]

In [None]:
n_epochs = 33

# Continue right after the last epoch recorded by the history callback
# (for some reason it is 0-starting and others 1-starting...).
# To start from scratch instead: initial_epoch=0, epochs=n_epochs.
first_epoch = history_cb.last_epoch + 1

model.fit(
    # data sequences
    x=crop_seq_train,
    validation_data=crop_seq_val,

    # epochs
    initial_epoch=first_epoch,
    epochs=first_epoch + n_epochs,

    # others
    callbacks=callbacks,
    verbose=2,
    use_multiprocessing=False,
);

# History

In [None]:
# Two stacked plots; each one is `2 * sz` wide and `sz` high.
nrows = 2
sz = 5
fig, axs = plt.subplots(nrows, 1, figsize=(2 * sz, nrows * sz), dpi=100)
fig.set_tight_layout(True)

# Draw the training history (learning rate included) on the axes.
hist_display = viz.TrainingHistoryDisplay(
    history_cb.history,
    model_name=tomo2seg_model.name,
    loss_name=model.loss.__name__,
)
hist_display = hist_display.plot(axs, with_lr=True)

# Log scale on the first and on the last plot.
for ax in (axs[0], axs[-1]):
    ax.set_yscale("log")

# Highlight the minimum of the training and of the validation loss.
viz.mark_min_values(hist_display.ax_loss_, hist_display.plots_["loss"][0])
viz.mark_min_values(
    hist_display.ax_loss_,
    hist_display.plots_["val_loss"][0],
    txt_kwargs=dict(rotation=0),
)

# The figure is saved under the model path, named after the plot title.
hist_display.fig_.savefig(
    tomo2seg_model.model_path / (hist_display.title + ".png"),
    format="png",
)

In [None]:
# Write the history dataframe (index included) to the csv path of the history callback.
history_cb.dataframe.to_csv(history_cb.csv_path, index=True)

In [None]:
# Save the model in its current (post-training) state at the model path.
model.save(tomo2seg_model.model_path)