# Imports


In [2]:
%load_ext autoreload

In [25]:
%autoreload 2

from dataclasses import dataclass, asdict, field
from enum import Enum
import functools
import operator
from functools import partial
import logging
import pathlib
from pathlib import Path
from pprint import pprint, PrettyPrinter
import sys
from typing import *
import time
import yaml
from yaml import YAMLObject
import socket

import humanize
from matplotlib import pyplot as plt, cm
import numpy as np
import pandas as pd
from pymicro.file import file_utils
import tensorflow as tf
from numpy.random import RandomState

from tensorflow import keras
from tensorflow.keras import utils
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks as keras_callbacks
from tensorflow.keras import losses
from tensorflow.keras import metrics as keras_metrics

from tomo2seg import slack
from tomo2seg import modular_unet
from tomo2seg.logger import logger
from tomo2seg import data, viz
from tomo2seg.data import Volume
from tomo2seg.metadata import Metadata
from tomo2seg.volume_sequence import (
    MetaCrop3DGenerator, VolumeCropSequence,
    UniformGridPosition, SequentialGridPosition,
    ET3DUniformCuboidAlmostEverywhere, ET3DConstantEverywhere, 
    GTUniformEverywhere, GTConstantEverywhere, 
    VSConstantEverywhere, VSUniformEverywhere
)
from tomo2seg import volume_sequence
from tomo2seg.model import Model as Tomo2SegModel
from tomo2seg import callbacks as tomo2seg_callbacks
from tomo2seg import losses as tomo2seg_losses
from tomo2seg import schedule as tomo2seg_schedule
from tomo2seg import utils as tomo2seg_utils
from tomo2seg import slackme

In [4]:
# Register a custom exception handler for the whole current notebook: any
# `Exception` raised by a cell is handed to `slackme.custom_exc`.
# NOTE(review): presumably the handler reports the error on Slack -- confirm in `tomo2seg.slackme`.
get_ipython().set_custom_exc((Exception,), slackme.custom_exc)

# Args

In [16]:
@dataclass
class Args:
    """Run-level arguments of this training notebook.

    Attributes:
        early_stop_mode: which early stopping strategy to set up.
        random_state_seed: seed for the random number generators.
        runid: identifier of the run; when left as `None` it is replaced by
            the current unix timestamp (in seconds).
        is_continuation: whether an existing run is being resumed; this
            requires an explicit `runid`.

    Raises:
        ValueError: if `is_continuation` is set without a `runid`.
    """

    class EarlyStopMode(Enum):
        """Available early stopping strategies."""
        no_early_stop = 0

    early_stop_mode: EarlyStopMode
    random_state_seed: int = 42

    # `None` means "new run": `__post_init__` fills in a timestamp-based id.
    runid: Optional[int] = None
    is_continuation: bool = False

    def __post_init__(self):
        """Validate the argument combination and fill in a default `runid`."""

        # An explicit raise (instead of `assert`) keeps the check active
        # even when Python runs with optimizations (`-O`).
        if self.is_continuation and self.runid is None:
            raise ValueError(f"Incompatible args {self.runid=} {self.is_continuation=}")

        if self.runid is None:
            self.runid = int(time.time())

# these are estimates based on things i've seen fit in the GPU
MAX_INTERNAL_NVOXELS = max(
    # seen cases
    4 * (8 * 6) * (96**3),
    8 * (16 * 6) * (320**2),
    3 * (16 * 6) * (800 * 928),
)

# a smaller gpu on other pcs...
# (integer arithmetic keeps the voxel count an `int`; the value is divisible by 8)
MAX_INTERNAL_NVOXELS = MAX_INTERNAL_NVOXELS * 5 // 8

logger.info(f"{MAX_INTERNAL_NVOXELS=} ({humanize.intcomma(MAX_INTERNAL_NVOXELS)})")

# override_batch_size = None
# doing this to reproduce the same conditions...
override_batch_size_per_gpu = None

# Which saved model a continuation (`args.is_continuation`) resumes from.
# It has to be defined here because the model-loading cell reads it.
# None: continue from the latest model
# 1: continue from model.autosaved_model_path
# 2: continue from model.autosaved2_model_path
continue_from_autosave: Optional[int] = None

args = Args(
    early_stop_mode = Args.EarlyStopMode.no_early_stop,
#     random_state_seed=30,  # I'll change it so we don't repeat the same crops from the begining
    runid = 1607698009,
#     is_continuation=True,
)

logger.info(f"args\n{PrettyPrinter(indent=4, compact=False).pformat(asdict(args))}")

INFO::tomo2seg::{<ipython-input-16-9c8f50b4cdd9>:<module>:031}::[2020-12-11::16:10:29.904]
MAX_INTERNAL_NVOXELS=133632000.0 (133,632,000.0)

INFO::tomo2seg::{<ipython-input-16-9c8f50b4cdd9>:<module>:049}::[2020-12-11::16:10:29.906]
args
{   'early_stop_mode': <EarlyStopMode.no_early_stop: 0>,
    'is_continuation': False,
    'random_state_seed': 42,
    'runid': 1607698009}




# Setup


In [7]:
# Verbose logging for the whole run.
logger.setLevel(logging.DEBUG)

# Notebook-level random generator, seeded from the run arguments.
random_state = np.random.RandomState(args.random_state_seed)

n_gpus = len(tf.config.list_physical_devices('GPU'))

tf_version = tf.__version__
logger.info(f"{tf_version=}")

hostname = socket.gethostname()
logger.info(
    f"Hostname: {hostname}\nNum GPUs Available: {n_gpus}\nThis should be:\n\t" + '\n\t'.join(['2 on R790-TOMO', '1 on akela', '1 on hathi', '1 on krilin'])
)

# One line per device, physical devices first and logical devices after.
logger.debug(
    "physical GPU devices:\n\t" + "\n\t".join(map(str, tf.config.list_physical_devices('GPU'))) + "\n" +
    "logical GPU devices:\n\t" + "\n\t".join(map(str, tf.config.list_logical_devices('GPU')))
)

# xla auto-clustering optimization (see: https://www.tensorflow.org/xla#auto-clustering)
# this seems to break the training
tf.config.optimizer.set_jit(False)

# get a distribution strategy to use both gpus (see https://www.tensorflow.org/guide/distributed_training)
gpu_strategy = tf.distribute.MirroredStrategy()
logger.debug(f"{gpu_strategy=}")

INFO::tomo2seg::{<ipython-input-7-050df0cd2614>:<module>:007}::[2020-12-11::16:06:22.938]
tf_version='2.2.0'

INFO::tomo2seg::{<ipython-input-7-050df0cd2614>:<module>:010}::[2020-12-11::16:06:22.940]
Hostname: R7920-tomo
Num GPUs Available: 2
This should be:
	2 on R790-TOMO
	1 on akela
	1 on hathi
	1 on krilin

DEBUG::tomo2seg::{<ipython-input-7-050df0cd2614>:<module>:014}::[2020-12-11::16:06:22.941]
physical GPU devices:
	PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
	PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')
logical GPU devices:
	LogicalDevice(name='/device:GPU:0', device_type='GPU')
	LogicalDevice(name='/device:GPU:1', device_type='GPU')

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
DEBUG::tomo2seg::{<ipython-input-7-050df0cd2614>:<module>:025}::[2020-12-11::16:06:22.945]
gpu_strategy=<tensorflow.python.distribute.mirrored_strategy.Mirrore

# Data

In [8]:
# Select the dataset of this run: the volume (name + version) and the version
# of its labels. The commented-out imports are alternative choices.
from tomo2seg.datasets import (
#     VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_V1_REDUCED as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_V1_LABELS_REFINED3 as LABELS_VERSION,
    VOLUME_FRACTURE00_SEGMENTED00 as VOLUME_NAME_VERSION,
    VOLUME_FRACTURE00_SEGMENTED00_LABELS_REFINED3 as LABELS_VERSION,
)

# `VOLUME_NAME_VERSION` unpacks into a (name, version) pair.
volume_name, volume_version = VOLUME_NAME_VERSION
labels_version = LABELS_VERSION

logger.info(f"{volume_name=}")
logger.info(f"{volume_version=}")
logger.info(f"{labels_version=}")

INFO::tomo2seg::{<ipython-input-8-151513cc1605>:<module>:012}::[2020-12-11::16:06:26.060]
volume_name='fracture00'

INFO::tomo2seg::{<ipython-input-8-151513cc1605>:<module>:013}::[2020-12-11::16:06:26.062]
volume_version='segmented00'

INFO::tomo2seg::{<ipython-input-8-151513cc1605>:<module>:014}::[2020-12-11::16:06:26.063]
labels_version='jordan'



In [24]:
# Metadata/paths objects

## Volume
volume = Volume.with_check(
    name=volume_name, version=volume_version
)

# Labelled "volume" (not "args"): the dump below is the volume's metadata.
logger.info(f"volume\n{PrettyPrinter(indent=4, compact=False).pformat(asdict(volume))}")

n_classes = len(volume.metadata.labels)

def _read_raw(path_: Path, volume_: Volume):
    """Placeholder for a raw-file reader: not implemented, always returns `None`."""
    # from pymicro
    return

# NOTE(review): `read_raw` is not called anywhere in this cell -- looks like a leftover stub.
read_raw = partial(_read_raw, volume_=volume)

logger.info("Loading data from disk.")

# Divisor that depends on the dtype of the data as stored on disk.
normalization_factor = volume_sequence.NORMALIZE_FACTORS[volume.metadata.dtype]

logger.debug(f"{normalization_factor=}")

## Data
voldata = file_utils.HST_read(
    str(volume.data_path),  # it doesn't accept paths...

    autoparse_filename=False,  # the file names are not properly formatted
    data_type=volume.metadata.dtype,
    dims=volume.metadata.dimensions,
    verbose=True,
) / normalization_factor # normalize

logger.debug(f"{voldata.shape=}")

voldata_train = volume.train_partition.get_volume_partition(voldata)
voldata_val = volume.val_partition.get_volume_partition(voldata)

logger.debug(f"{voldata_train.shape=}")
logger.debug(f"{voldata_val.shape=}")

# Only the train/val partitions are used from here on.
del voldata

## Labels

vollabels = file_utils.HST_read(
    str(volume.versioned_labels_path(labels_version)),

    autoparse_filename=False,
    data_type="uint8",
    dims=volume.metadata.dimensions,
    verbose=True,
)
logger.debug(f"{vollabels.shape=}")

vollabels_train = volume.train_partition.get_volume_partition(vollabels)
vollabels_val = volume.val_partition.get_volume_partition(vollabels)

logger.debug(f"{vollabels_train.shape=}")
logger.debug(f"{vollabels_val.shape=}")

# Only the train/val partitions are used from here on.
del vollabels

DEBUG::tomo2seg::{data.py:with_check:258}::[2020-12-11::16:19:09.170]
vol=Volume(name='fracture00', version='segmented00', _metadata=None)

DEBUG::tomo2seg::{data.py:metadata:195}::[2020-12-11::16:19:09.173]
Loading metadata from `/home/users/jcasagrande/projects/tomo2seg/data/fracture00.segmented00/fracture00.segmented00.metadata.yml`.

INFO::tomo2seg::{<ipython-input-24-c989932cb435>:<module>:008}::[2020-12-11::16:19:09.198]
args
{   '_metadata': {   'dimensions': [731, 301, 200],
                     'dtype': 'uint16',
                     'labels': [0, 1, 2, 3, 4],
                     'labels_names': {   0: 'exterior',
                                         1: 'inside',
                                         2: 'defect',
                                         3: 'porosity',
                                         4: 'crack'},
                     'set_partitions': {   'test': {   'alias': 'test',
                                                       'x_range': [450, 731],


# Model

In [10]:
# Forget any `tomo2seg_model` left over from a previous execution so that the
# next cell builds a fresh one; say so when there is nothing to delete.
if "tomo2seg_model" in globals():
    del tomo2seg_model
else:
    print("already deleted (:")

already deleted (:


In [11]:
# 176/288 is the biggest that covers the whole validation shape
# it's kind of stupid that i'm limited by that...
crop_shape = (176, 288, 1)  # multiple of 16 (requirement of a 4-level u-net)

model_master_name = "unet2d-crack"
model_version = "00"

model_is_2halfd = False
model_is_2d = True

# The factory and its kwargs are stored in the `Tomo2SegModel` and are also
# what instantiates the Keras model later on.
model_factory_function = modular_unet.u_net
model_factory_kwargs = {
    **modular_unet.kwargs_vanilla03,
    **dict(
        convlayer=modular_unet.ConvLayer.conv2d,
        input_shape = crop_shape,
        output_channels=n_classes,
#         nb_filters_0 = 2,
#         nb_filters_0 = 4,
#         nb_filters_0 = 8,
#         nb_filters_0 = 12,
        nb_filters_0 = 16,
#         nb_filters_0 = 32,
    ),
}

# Create the model description only once per kernel session.
try:
    tomo2seg_model

except NameError:
    logger.info("Creating a Tomo2SegModel.")

    tomo2seg_model = Tomo2SegModel(
        model_master_name,
        model_version,
        runid=args.runid,
        factory_function=model_factory_function,
        factory_kwargs=model_factory_kwargs,
    )

else:
    logger.warning("The model is already defined. To create a new one: `del tomo2seg_model`")

# Deliberately outside of a `finally` clause: if the construction above fails,
# `tomo2seg_model` does not exist and logging it from `finally` would raise a
# NameError that hides the real error.
logger.info(f"tomo2seg_model\n{PrettyPrinter(indent=4, compact=False).pformat(asdict(tomo2seg_model))}")
logger.info(f"{tomo2seg_model.name=}")

INFO::tomo2seg::{<ipython-input-11-d16e5795bcb5>:<module>:031}::[2020-12-11::16:06:51.503]
Creating a Tomo2SegModel.

INFO::tomo2seg::{<ipython-input-11-d16e5795bcb5>:<module>:045}::[2020-12-11::16:06:51.505]
args
{   'factory_function': 'tomo2seg.modular_unet.u_net',
    'factory_kwargs': {   'convlayer': <ConvLayer.conv2d: 0>,
                          'depth': 4,
                          'input_shape': (176, 288, 1),
                          'nb_filters_0': 16,
                          'output_channels': 5,
                          'sigma_noise': 0,
                          'unet_block_kwargs': {   'batch_norm': True,
                                                   'dropout': 0,
                                                   'kernel_size': 3,
                                                   'res': True},
                          'unet_down_kwargs': {'batchnorm': True},
                          'unet_up_kwargs': {'batchnorm': True},
                          'updown_c

In [13]:
logger.info("Creating the Keras model.")

# Everything that creates variables (model and optimizer) happens inside the
# distribution strategy scope.
with gpu_strategy.scope():

    if args.is_continuation:
        logger.warning("Training continuation: a model will be loaded.")

        # Pick the saved model to resume from.
        if continue_from_autosave is None:
            logger.info("Using the LATEST model to continue the training.")
            load_model_path = tomo2seg_model.model_path

        elif continue_from_autosave == 1:
            logger.info("Using the AUTOSAVED model to continue the training.")
            load_model_path = tomo2seg_model.autosaved_model_path

        elif continue_from_autosave == 2:
            logger.info("Using the (best) AUTOSAVED2 model to continue the training.")
            load_model_path = tomo2seg_model.autosaved2_best_model_path

        else:
            raise ValueError(f"{continue_from_autosave=}")

        assert load_model_path.exists(), f"Inconsistent arguments {args.is_continuation=} {load_model_path=}."

        logger.info(f"Loading model {load_model_path.name}")

        # compile=False: the model is (re)compiled below in both cases.
        model = keras.models.load_model(str(load_model_path), compile=False)

        assert model.name == tomo2seg_model.name, f"{model.name=} {tomo2seg_model.name=}"

    else:
        # Refuse to silently overwrite a model that was already saved.
        if (
            tomo2seg_model.model_path.exists() or
            tomo2seg_model.autosaved_model_path.exists()
            # todo uncomment me when implemented
#             or tomo2seg_model.autosaved2_best_model_path.exists()
        ):
            logger.error("The model seems to already exist but this is not a continuation. Please, make sure the arguments are correct.")
            raise ValueError(f"{args.is_continuation=} {tomo2seg_model.name=}")

        logger.info("A new model will be instantiated!")

        logger.info(f"Instantiating a new model with model_factory_function={model_factory_function.__name__}.")

        model = model_factory_function(
            name=tomo2seg_model.name,
            **model_factory_kwargs
        )

    logger.info("Compiling the model.")

    # using the avg jaccard is dangerous if one of the classes is too
    # underrepresented because its jaccard will be unstable
    loss = tomo2seg_losses.jaccard2_flat
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias.
    optimizer = optimizers.Adam(learning_rate=.003)
    metrics = [
#         tomo2seg_losses.jaccard2_macro_avg,
#         keras_metrics.Accuracy(),
#     ] + [
#         tomo2seg_losses.Jaccard2(class_idx)
#         for class_idx in range(n_classes)
    ]

    logger.debug(f"{loss=}")
    logger.debug(f"{optimizer=}")
    logger.debug(f"{metrics=}")

    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    

INFO::tomo2seg::{<ipython-input-13-c489e97597c1>:<module>:001}::[2020-12-11::16:08:19.565]
Creating the Keras model.

INFO::tomo2seg::{<ipython-input-13-c489e97597c1>:<module>:033}::[2020-12-11::16:08:19.606]
A new model will be instantiated!

INFO::tomo2seg::{<ipython-input-13-c489e97597c1>:<module>:048}::[2020-12-11::16:08:19.608]
Instantiating a new model with model_factory_function=u_net.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/re

In [14]:
# A brand new model is written to disk right away, together with a text
# summary and a picture of its architecture.
if not args.is_continuation:

    logger.info(f"Saving the model at {tomo2seg_model.model_path=}.")
    model.save(tomo2seg_model.model_path)

    logger.info(f"Writing the model summary at {tomo2seg_model.summary_path=}.")
    with tomo2seg_model.summary_path.open("w") as summary_file:
        # each summary line goes to the text file instead of stdout
        model.summary(
            print_fn=lambda line: summary_file.write(line + "\n"),
            line_length=140,
        )

    logger.info(f"Printing an image of the architecture at {tomo2seg_model.architecture_plot_path=}.")
    utils.plot_model(
        model,
        to_file=tomo2seg_model.architecture_plot_path,
        show_shapes=True,
    );

INFO::tomo2seg::{<ipython-input-14-220d06205e36>:<module>:003}::[2020-12-11::16:08:32.002]
Saving the model at tomo2seg_model.model_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-crack/unet2d-crack.00.fold000.1607-698-009').

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: /home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-crack/unet2d-crack.00.fold000.1607-698-009/assets
INFO::tomo2seg::{<ipython-input-14-220d06205e36>:<module>:007}::[2020-12-11::16:08:50.102]
Writing the model summary at tomo2seg_model.summary_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-crack/unet2d-crack.00.fold000.1607-698-009/summary.txt').

INFO::tomo2seg::{<ipython-input-14-220d06205e36>:<module>:014}::[2020-12-11::16:08:50.195]
Printing an image of the architecture at tomo2seg_model.architecture_plot_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/models/une

# Data crop sequences

## Batch size

In [17]:
# Ratio reported by the helper for this model; it converts the internal voxel
# budget into a budget of input voxels per batch.
model_internal_nvoxel_factor = tomo2seg_utils.get_model_internal_nvoxel_factor(model)

logger.debug(f"{model_internal_nvoxel_factor=}")

max_batch_nvoxels = int(np.floor(MAX_INTERNAL_NVOXELS / model_internal_nvoxel_factor))

logger.debug(f"{max_batch_nvoxels=} ({humanize.intcomma(max_batch_nvoxels)})")

crop_nvoxels = functools.reduce(operator.mul, crop_shape)

logger.debug(f"{crop_shape=} ==> {crop_nvoxels=}")

# Largest number of crops that fits in the voxel budget of one GPU.
max_batch_size_per_gpu = batch_size_per_gpu = int(np.floor(max_batch_nvoxels / crop_nvoxels))

logger.info(f"{batch_size_per_gpu=}")

if override_batch_size_per_gpu is not None:

    assert override_batch_size_per_gpu > 0, f"{override_batch_size_per_gpu=}"

    # Log BEFORE reassigning so that the message shows the value being replaced.
    logger.warning(f"{override_batch_size_per_gpu=} given ==> replacing {batch_size_per_gpu=}")

    batch_size_per_gpu = override_batch_size_per_gpu

logger.info(f"{n_gpus=}")

# max(1, ...) keeps a usable batch size on a machine without GPU.
batch_size = batch_size_per_gpu * max(1, n_gpus)

logger.info(f"{batch_size=}")

# NOTE(review): not referenced by the crop sequence cells visible in this
# notebook (they use `args.random_state_seed`) -- confirm whether it is still needed.
common_random_state = 143

DEBUG::tomo2seg::{utils.py:get_model_internal_nvoxel_factor:023}::[2020-12-11::16:10:37.474]
input_layer=<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f06bc974a00>

DEBUG::tomo2seg::{utils.py:get_model_internal_nvoxel_factor:029}::[2020-12-11::16:10:37.475]
input_nvoxels=50688

DEBUG::tomo2seg::{utils.py:get_model_internal_nvoxel_factor:041}::[2020-12-11::16:10:37.477]
max_internal_nvoxels=4866048 (4,866,048)

DEBUG::tomo2seg::{<ipython-input-17-46fe188ba465>:<module>:003}::[2020-12-11::16:10:37.478]
model_internal_nvoxel_factor=96

DEBUG::tomo2seg::{<ipython-input-17-46fe188ba465>:<module>:007}::[2020-12-11::16:10:37.480]
max_batch_nvoxels=1392000 (1,392,000)

DEBUG::tomo2seg::{<ipython-input-17-46fe188ba465>:<module>:011}::[2020-12-11::16:10:37.481]
crop_shape=(176, 288, 1) ==> crop_nvoxels=50688

INFO::tomo2seg::{<ipython-input-17-46fe188ba465>:<module>:015}::[2020-12-11::16:10:37.482]
batch_size_per_gpu=27

INFO::tomo2seg::{<ipython-input-17-46fe188ba465>:<mod

## Common kwargs

In [26]:
# Keyword arguments shared by the train and validation meta crop generators.
metacrop_gen_common_kwargs = {
    "crop_shape": crop_shape,
    "common_random_state_seed": args.random_state_seed,
    "is_2halfd": model_is_2halfd,
    # the ground truth type follows the dimensionality of the model
    "gt_type": volume_sequence.GT2D if model_is_2d else volume_sequence.GT3D,
}

logger.debug(f"{metacrop_gen_common_kwargs=}")

# Keyword arguments shared by the train and validation crop sequences.
vol_crop_seq_common_kwargs = {
    "output_as_2d": model_is_2d,
    "output_as_2halfd": model_is_2halfd,
    "labels": volume.metadata.labels,

    # not automated...
#     "debug__no_data_check": True,
}

logger.debug(f"{vol_crop_seq_common_kwargs=}")

DEBUG::tomo2seg::{<ipython-input-26-1625f0973e32>:<module>:008}::[2020-12-11::16:20:35.843]
metacrop_gen_common_kwargs={'crop_shape': (176, 288, 1), 'common_random_state_seed': 42, 'is_2halfd': False, 'gt_type': <enum 'GT2D'>}

DEBUG::tomo2seg::{<ipython-input-26-1625f0973e32>:<module>:019}::[2020-12-11::16:20:35.845]
vol_crop_seq_common_kwargs={'output_as_2d': True, 'output_as_2halfd': False, 'labels': [0, 1, 2, 3, 4]}



## Train

In [27]:
# Crop sequence feeding the training. The train partition is referenced
# directly (no generic `data` alias) so that the imported `tomo2seg.data`
# module is not shadowed.
crop_seq_train = VolumeCropSequence(
    data_volume=voldata_train,
    labels_volume=vollabels_train,

    batch_size=batch_size,

    meta_crop_generator=MetaCrop3DGenerator.build_setup_train00(
        volume_shape=voldata_train.shape,
        **metacrop_gen_common_kwargs,
        data_original_dtype=volume.metadata.dtype,
    ),

    # this volume cropper only returns random crops,
    # so the number of crops per epoch/batch is w/e i want
    epoch_size=10,

    **vol_crop_seq_common_kwargs,
)

INFO::tomo2seg::{volume_sequence.py:build_from_volume_crop_shapes:442}::[2020-12-11::16:20:37.028]
Built UniformGridPosition from volume_shape=(260, 301, 200) and crop_shape=(176, 288, 1) ==> {'x_range': (0, 85), 'y_range': (0, 14), 'z_range': (0, 200)}

DEBUG::tomo2seg::{volume_sequence.py:__post_init__:404}::[2020-12-11::16:20:37.029]
UniformGridPosition ==> npositions=238000 (238,000)

Initializing ET3DConstantEverywhere with a UniformGridPosition.
The {x, y, z}_range values will be overwritten.

Initializing GTUniformEverywhere with a UniformGridPosition.
The {x, y, z}_range values will be overwritten.

Initializing VSUniformEverywhere with a UniformGridPosition.
The {x, y, z}_range values will be overwritten.

DEBUG::tomo2seg::{volume_sequence.py:__post_init__:1364}::[2020-12-11::16:20:37.034]
Initializing VolumeCropSequence.

DEBUG::tomo2seg::{volume_sequence.py:__post_init__:1374}::[2020-12-11::16:20:37.035]
Checking values and labels consistency, this might be a bit slow.

DEBU

## Val

In [28]:
# the validation has no reproducibility issues
# so let's push the GPUs (:
# max(1, ...) mirrors the training batch size: without it the validation
# batch size would be 0 on a machine without GPU.
val_batch_size = max_batch_size_per_gpu * max(1, n_gpus)

logger.debug(f"{val_batch_size=}")

# Fixed grid of crop positions over the validation partition. The partition
# is referenced directly (no generic `data` alias) so that the imported
# `tomo2seg.data` module is not shadowed.
grid_pos_gen = SequentialGridPosition.build_min_overlap(
    volume_shape=voldata_val.shape,
    crop_shape=crop_shape,
    # reduce the total number of crops
#         n_steps_x=11,
#         n_steps_y=11,
#         n_steps_z=8,
)

crop_seq_val = VolumeCropSequence(
    data_volume=voldata_val,
    labels_volume=vollabels_val,

    batch_size=val_batch_size,

    # go through all the crops in validation
    epoch_size=len(grid_pos_gen),

    # crop generator built with the validation setup, on the grid above
    meta_crop_generator=MetaCrop3DGenerator.build_setup_val00(
        volume_shape=voldata_val.shape,
        grid_pos_gen=grid_pos_gen,
        **metacrop_gen_common_kwargs,
    ),

    **vol_crop_seq_common_kwargs,
)

DEBUG::tomo2seg::{<ipython-input-28-b42244853bf8>:<module>:010}::[2020-12-11::16:21:14.441]
val_batch_size=54

INFO::tomo2seg::{volume_sequence.py:build_min_overlap:510}::[2020-12-11::16:21:14.443]
Building SequentialGridPosition with minimal overlap (smallest n_steps in each directions) n_steps={'n_steps_x': 2, 'n_steps_y': 2, 'n_steps_z': 200}.

INFO::tomo2seg::{volume_sequence.py:build_from_volume_crop_shapes:442}::[2020-12-11::16:21:14.444]
Built SequentialGridPosition from volume_shape=(190, 301, 200) and crop_shape=(176, 288, 1) ==> {'x_range': (0, 15), 'y_range': (0, 14), 'z_range': (0, 200)}

INFO::tomo2seg::{volume_sequence.py:__post_init__:490}::[2020-12-11::16:21:14.448]
The SequentialGridPosition has len(self.positions)=800 different positions (therefore crops).

Initializing ET3DConstantEverywhere with a SequentialGridPosition.
The {x, y, z}_range values will be overwritten.

Initializing GTConstantEverywhere with a SequentialGridPosition.
The {x, y, z}_range values will b

# Callbacks

In [29]:
# Checkpoint callback: writes the model whenever the validation loss reaches
# a new minimum.
autosave_cb = keras_callbacks.ModelCheckpoint(
    filepath=tomo2seg_model.autosaved2_model_path_str,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    verbose=1,
)

logger.debug(f"{autosave_cb=}")

DEBUG::tomo2seg::{<ipython-input-29-be8d862a0547>:<module>:009}::[2020-12-11::16:21:58.934]
autosave_cb=<tensorflow.python.keras.callbacks.ModelCheckpoint object at 0x7f05847ba250>



In [30]:
# this is important because sometimes i update things in the notebook
# so i need to make sure that the objects in the history cb are updated
try:
    history_cb

except NameError:
    logger.info("Creating a new history callback.")

    history_cb = tomo2seg_callbacks.History(
        optimizer=model.optimizer,
        crop_seq_train=crop_seq_train,
        crop_seq_val=crop_seq_val,
        backup=1,
        csv_path=tomo2seg_model.history_path,
    )

else:
    logger.warning("The history callback already exists!")

    history_df = history_cb.dataframe

    try:
        history_df_temp = pd.read_csv(tomo2seg_model.history_path)

    except (FileNotFoundError, pd.errors.EmptyDataError):
        # missing file and empty file mean the same thing here
        logger.info("History hasn't been saved yet.")

    else:
        # keep the longest one (the in-memory history wins ties)
        if history_df_temp.shape[0] > history_df.shape[0]:
            history_df = history_df_temp
        del history_df_temp

    # NOTE(review): `history_df` is not written back to `history_cb` in this
    # cell -- confirm whether the longest history is meant to be restored.

# make sure the correct objects are linked
# (deliberately not in a `finally` clause: if creating the callback fails,
# `history_cb` does not exist and relinking would raise a NameError that
# hides the real error)
history_cb.optimizer = model.optimizer
history_cb.crop_seq_train = crop_seq_train
history_cb.crop_seq_val = crop_seq_val

logger.debug(f"{history_cb=}")

INFO::tomo2seg::{<ipython-input-30-e64fcd7ac3fc>:<module>:007}::[2020-12-11::16:22:00.843]
Creating a new history callback.

INFO::tomo2seg::{callbacks.py:__init__:051}::[2020-12-11::16:22:00.877]
Loading history from csv self.csv_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-crack/unet2d-crack.00.fold000.1607-698-009/history.csv').

DEBUG::tomo2seg::{callbacks.py:__init__:071}::[2020-12-11::16:22:00.880]
History hasn't been saved yet.

DEBUG::tomo2seg::{<ipython-input-30-e64fcd7ac3fc>:<module>:040}::[2020-12-11::16:22:00.881]
history_cb=<tomo2seg.callbacks.History object at 0x7f05b0170a90>



In [31]:
# Sanity check of the history callback before training: number of rows
# already recorded and the last epoch it knows about.
logger.debug(f"{history_cb.dataframe.index.size=}")
logger.debug(f"{history_cb.last_epoch=}")

DEBUG::tomo2seg::{<ipython-input-31-81847b52d74e>:<module>:001}::[2020-12-11::16:22:06.924]
history_cb.dataframe.index.size=0

DEBUG::tomo2seg::{<ipython-input-31-81847b52d74e>:<module>:002}::[2020-12-11::16:22:06.926]
history_cb.last_epoch=0



In [32]:
# Callback built on top of the history callback, given the path of the
# work-in-progress training history plot.
# NOTE(review): presumably it redraws and saves the plot during training -- confirm in `tomo2seg.callbacks`.
history_plot_cb = tomo2seg_callbacks.HistoryPlot(
    history_callback=history_cb,
    save_path=tomo2seg_model.train_history_plot_wip_path
)
logger.debug(f"{history_plot_cb=}")

DEBUG::tomo2seg::{<ipython-input-32-6edf00a82883>:<module>:005}::[2020-12-11::16:22:18.154]
history_plot_cb=HistoryPlot(history_callback=<tomo2seg.callbacks.History object at 0x7f05b0170a90>, save_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-crack/unet2d-crack.00.fold000.1607-698-009/train-hist-plot-wip.png'))



In [33]:
logger.info(f"Setting up early stop with {args.early_stop_mode=}")

# Only the "no early stop" mode exists so far: it needs no callback, and any
# other value is rejected.
if args.early_stop_mode != Args.EarlyStopMode.no_early_stop:
    raise NotImplementedError(f"{args.early_stop_mode=}")
#     # todo modify the early stopping to take more conditions (don't stop too early before it doesnt break the jaccard2=.32)
#     early_stop_cb = keras_callbacks.EarlyStopping(  
#         monitor='val_loss', 
#         min_delta=.1 / 100, 
#         patience=50,
#         verbose=2, 
#         mode='auto',
#         baseline=.71,  # 0th-order classifier
#         restore_best_weights=False,
#     )

INFO::tomo2seg::{<ipython-input-33-1b2efaf987e7>:<module>:001}::[2020-12-11::16:22:28.714]
Setting up early stop with args.early_stop_mode=<EarlyStopMode.no_early_stop: 0>



# Summary before training

Stuff that I use after the training but that I want to appear in the summary before the training.


## Metadata

TODO: get this working again.

## Volume slices

TODO: do this in a separate notebook.

## Generator samples

TODO: do this in a separate notebook.


# Training


## Teeth log lr schedule

In [41]:
# Learning rate schedule object, kept under its own name (`schedule`) and
# wrapped in the Keras scheduler callback.
schedule = tomo2seg_schedule.LogSpaceSchedule(
    offset_epoch=0,
    wait=0,
    start=-3,
    stop=-5,
    n_between_scales=50,
)

lr_schedule_cb = keras_callbacks.LearningRateScheduler(schedule=schedule, verbose=2)

logger.info(f"{lr_schedule_cb.schedule.range=}")

INFO::tomo2seg::{schedule.py:__post_init__:071}::[2020-12-11::16:25:59.911]
LogSpaceSchedule ==> self.n=103

INFO::tomo2seg::{<ipython-input-41-b9dd8e0ae7a7>:<module>:008}::[2020-12-11::16:25:59.912]
lr_schedule_cb.schedule.range=(0, 103)



In [42]:
# Callbacks handed to `model.fit`, in this order.
callbacks = [
    keras_callbacks.TerminateOnNaN(),
    autosave_cb,
    history_cb,
    history_plot_cb,
    lr_schedule_cb,
]

# The early stopping callback only exists when an early stop mode created it.
if "early_stop_cb" in globals():
    callbacks.append(early_stop_cb)

for callback in callbacks:
    logger.debug(f"using callback {type(callback).__name__}")

DEBUG::tomo2seg::{<ipython-input-42-c16af9d5fdc3>:<module>:019}::[2020-12-11::16:26:04.384]
using callback TerminateOnNaN

DEBUG::tomo2seg::{<ipython-input-42-c16af9d5fdc3>:<module>:019}::[2020-12-11::16:26:04.385]
using callback ModelCheckpoint

DEBUG::tomo2seg::{<ipython-input-42-c16af9d5fdc3>:<module>:019}::[2020-12-11::16:26:04.387]
using callback History

DEBUG::tomo2seg::{<ipython-input-42-c16af9d5fdc3>:<module>:019}::[2020-12-11::16:26:04.388]
using callback HistoryPlot

DEBUG::tomo2seg::{<ipython-input-42-c16af9d5fdc3>:<module>:019}::[2020-12-11::16:26:04.389]
using callback LearningRateScheduler



In [None]:
# Number of epochs to train for.
# NOTE(review): 103 equals the upper bound logged for `lr_schedule_cb.schedule.range`
# (0, 103) -- presumably both must be kept in sync; confirm.
n_epochs = 103

# Train on the random training crops and validate on the validation grid.
model.fit(
    
    # data sequences
    x=crop_seq_train,
    validation_data=crop_seq_val,

    # epochs
    initial_epoch=0,
    epochs=n_epochs,
#     initial_epoch=history_cb.last_epoch + 1,  # for some reason it is 0-starting and others 1-starting...
#         epochs=history_cb.last_epoch + 1 + n_epochs,  

    # others
    callbacks=callbacks,  
    verbose=2,

    # todo change the volume sequence to dynamically load the volume
    # because it would allow me to pass just a path string therefore
    # making it serializable ==> i will be able to multithread (:
    use_multiprocessing=False,   
);

# Runs only once `model.fit` has returned.
# NOTE(review): presumably sends a Slack notification -- confirm in `tomo2seg.slack`.
slack.notify_finished()


Epoch 00001: LearningRateScheduler reducing learning rate to 0.001.
Epoch 1/103
INFO:tensorflow:batch_all_reduce: 142 all-reduces with algorithm = nccl, num_packs = 1

Epoch 00001: val_loss improved from inf to 0.65024, saving model to /home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-crack/unet2d-crack.00.fold000.1607-698-009/unet2d-crack.00.fold000.1607-698-009.autosaved.001-0.650245.hdf5
INFO::tomo2seg::{callbacks.py:on_epoch_end:110}::[2020-12-11::16:34:45.309]
Saving backup of the training history epoch=0 self.csv_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/models/unet2d-crack/unet2d-crack.00.fold000.1607-698-009/history.csv')

DEBUG::tomo2seg::{callbacks.py:on_epoch_end:128}::[2020-12-11::16:34:45.348]
epoch=0 is too early to plot something.

ERROR::tomo2seg::{callbacks.py:on_epoch_end:169}::[2020-12-11::16:34:45.444]
AssertionError occurred while trying to plot the history.
Traceback (most recent call last):
  File "/home/users/jcasagrande/projects

# History

In [None]:
# Figure layout: `nrows` x `ncols` panels, each `sz` high and 2.5 * `sz` wide.
nrows, ncols, sz = 2, 1, 5

fig, axs = plt.subplots(nrows, ncols, figsize=(2.5 * sz, nrows * sz), dpi=100)
fig.set_tight_layout(True)

# Plot the loss (and the learning rate) recorded by the history callback.
hist_display = viz.TrainingHistoryDisplay(
    history_cb.history,
    model_name=tomo2seg_model.name,
    loss_name=model.loss.__name__,
    x_axis_mode=("epoch", "batch", "crop", "voxel", "time"),
).plot(
    axs,
    with_lr=True,
    metrics=("loss",),
)

# first and last panels use a log scale
for ax in (axs[0], axs[-1]):
    ax.set_yscale("log")

# highlight the minimum of the training and of the validation loss curves
viz.mark_min_values(hist_display.axs_metrics_[0], hist_display.plots_["loss"][0])
viz.mark_min_values(hist_display.axs_metrics_[0], hist_display.plots_["val_loss"][0], txt_kwargs=dict(rotation=0))

# the figure is stored with the model files, named after the plot title
hist_display.fig_.savefig(
    tomo2seg_model.model_path / (hist_display.title + ".png"),
    format='png',
)
# plt.close()

In [None]:
# Write the training history to the callback's CSV path (index column included).
history_cb.dataframe.to_csv(history_cb.csv_path, index=True)

In [None]:
# Save the model in its current state at the model path.
model.save(tomo2seg_model.model_path)

In [None]:
import os
import subprocess

# Name of this notebook file; it must match the file on disk for the export to work.
this_nb_name = "train-04-tomo88.ipynb"
this_dir = os.getcwd()
logger.warning(f"{this_nb_name=} {this_dir=}")

# Export this notebook as HTML into the model directory. The command is given
# as an argument list (no shell), so paths with spaces or shell
# metacharacters are passed through untouched.
nbconvert_cmd = [
    "jupyter", "nbconvert", f"{this_dir}/{this_nb_name}",
    "--output-dir", str(tomo2seg_model.model_path),
    "--to", "html",
]
nbconvert_result = subprocess.run(nbconvert_cmd)

# Best effort: a failed export is reported but does not raise.
if nbconvert_result.returncode != 0:
    logger.error(f"nbconvert failed {nbconvert_result.returncode=} {nbconvert_cmd=}")