# Imports and GPU setup


In [1]:
# Load IPython's autoreload extension. The next cell sets `%autoreload 2`, so edits to
# imported modules (e.g. the project's `tomo2seg` package) are picked up without
# restarting the kernel.
%load_ext autoreload

In [19]:
%autoreload 2

import functools
import logging
import operator
import socket
from typing import *

import humanize
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers, utils

In [4]:
# `logger` is otherwise first imported in the "Model" section further down, so on a
# fresh kernel (Restart & Run All) every `logger.*` call in this cell raised NameError.
from tomo2seg.logger import logger

# Query the physical GPUs once; both the count and the debug listing below use it.
physical_gpus = tf.config.list_physical_devices('GPU')
n_gpus = len(physical_gpus)

tf_version = tf.__version__
logger.info(f"{tf_version=}")

# Sanity check: compare the detected GPU count with what each known machine has.
hostname = socket.gethostname()
logger.info(
    f"Hostname: {hostname}\nNum GPUs Available: {n_gpus}\nThis should be:\n\t"
    + '\n\t'.join(['2 on R790-TOMO', '1 on akela', '1 on hathi', '1 on krilin'])
)

logger.debug(
    "physical GPU devices:\n\t"
    + "\n\t".join(map(str, physical_gpus)) + "\n"
    + "logical GPU devices:\n\t"
    + "\n\t".join(map(str, tf.config.list_logical_devices('GPU')))
)

# XLA auto-clustering (see: https://www.tensorflow.org/xla#auto-clustering) is kept
# off: enabling it seemed to break the training.
tf.config.optimizer.set_jit(False)

# Distribution strategy to use all visible GPUs
# (see https://www.tensorflow.org/guide/distributed_training).
# NOTE: variables are only mirrored across GPUs if the model is built and compiled
# inside `gpu_strategy.scope()`.
gpu_strategy = tf.distribute.MirroredStrategy()
logger.debug(f"{gpu_strategy=}")

INFO::tomo2seg::{<ipython-input-4-70ddb701656b>:<module>:004}::[2020-12-13::14:15:18.906]
tf_version='2.2.0'

INFO::tomo2seg::{<ipython-input-4-70ddb701656b>:<module>:007}::[2020-12-13::14:15:18.907]
Hostname: akela.materiaux.ensmp.fr
Num GPUs Available: 1
This should be:
	2 on R790-TOMO
	1 on akela
	1 on hathi
	1 on krilin

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


# Model

In [21]:
# From here on, emit DEBUG-level messages from the project logger.
# In the recorded run the setup cell's `logger.debug(...)` output did not appear, so
# it ran at a less verbose level; re-run that cell after this one to see it.
logger.setLevel(logging.DEBUG)

In [93]:
# Build and compile the segmentation model.
# (`%autoreload 2` is already active: it is set in the imports cell.)
from tomo2seg import (
    losses as tomo2seg_losses,
    modular_unet,
    utils as tomo2seg_utils,
)
from tomo2seg.logger import logger

# Input crop. The model uses 2D convolutions (see `convlayer` below), so the last axis
# (5) is consumed as input channels -- presumably a stack of neighbouring slices for
# the 2.5D model. The two spatial axes must be multiples of 16 (requirement of a
# 4-level U-Net).
# Earlier note: 176/288 was the largest crop that covers the whole validation shape.
crop_shape = (48, 48, 5)
assert all(dim % 16 == 0 for dim in crop_shape[:2]), (
    f"spatial crop dims must be multiples of 16, got {crop_shape[:2]}"
)

model_master_name = "unet2halfd"
model_version = "II-enc-scratch"

model_is_2halfd = True
model_is_2d = False

n_classes = 3

model_factory_function = modular_unet.u_net
model_factory_kwargs = {
    # Base configuration first; the keys below override it for this experiment.
    **modular_unet.kwargs_IIenc03,
    "convlayer": modular_unet.ConvLayer.conv2d,
    "input_shape": crop_shape,
    "output_channels": n_classes,
    # Filters in the first block (values tried: 2, 4, 8, 10, 12, 16, 32).
    "nb_filters_0": 8,
}

logger.info("Creating the Keras model.")

# NOTE(review): the model is built outside `gpu_strategy.scope()`, so its variables
# are not mirrored by the MirroredStrategy created in the setup cell. Wrap the build
# and compile below in `with gpu_strategy.scope():` to train on every GPU.
model = model_factory_function(
    name="unet2halfd-better",
    **model_factory_kwargs
)

loss = tomo2seg_losses.jaccard2_flat
# `learning_rate` is the documented Adam argument; `lr` is a deprecated alias that
# newer Keras optimizers reject.
optimizer = optimizers.Adam(learning_rate=0.003)

model.compile(loss=loss, optimizer=optimizer)

# Render the architecture graph; the trailing `;` suppresses the returned object's repr.
utils.plot_model(model);

INFO::tomo2seg::{<ipython-input-93-0f0da792325c>:<module>:041}::[2020-12-13::16:07:43.479]
Creating the Keras model.



In [94]:
# Layer-by-layer Keras summary: output shapes, parameter counts and connections.
model.summary()

Model: "unet2halfd-better"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 48, 48, 5)]  0                                            
__________________________________________________________________________________________________
enc-block-0-conv1 (Conv2D)      (None, 48, 48, 8)    368         input[0][0]                      
__________________________________________________________________________________________________
enc-block-0-conv1-bn (BatchNorm (None, 48, 48, 8)    32          enc-block-0-conv1[0][0]          
__________________________________________________________________________________________________
enc-block-0-conv1-relu (Activat (None, 48, 48, 8)    0           enc-block-0-conv1-bn[0][0]       
__________________________________________________________________________________

In [95]:
# Ratio between the size of the model's largest internal tensor and its input size.
# In the recorded run, the input counts 48*48*5 = 11,520 elements (channel axis
# included), the largest internal tensor 110,592 (ratio ~9.6), and the reported factor
# is 10. The rounding rule lives inside the helper and is not visible here.
# NOTE(review): presumably used later to budget batch size / GPU memory -- confirm.
model_internal_nvoxel_factor = tomo2seg_utils.get_model_internal_nvoxel_factor(model)

logger.debug(f"{model_internal_nvoxel_factor=}")

DEBUG::tomo2seg::{utils.py:get_model_internal_nvoxel_factor:023}::[2020-12-13::16:07:46.925]
input_layer=<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7f78b92d9610>

DEBUG::tomo2seg::{utils.py:get_model_internal_nvoxel_factor:029}::[2020-12-13::16:07:46.926]
input_nvoxels=11520

DEBUG::tomo2seg::{utils.py:get_model_internal_nvoxel_factor:041}::[2020-12-13::16:07:46.927]
max_internal_nvoxels=110592 (110,592)

DEBUG::tomo2seg::{<ipython-input-95-5282fa28b51e>:<module>:003}::[2020-12-13::16:07:46.928]
model_internal_nvoxel_factor=10

