In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

from functools import partial
import logging
import pathlib
from pathlib import Path
import pprint as pprint_module
from pprint import pprint
import sys
from typing import *
import time
import yaml
from yaml import YAMLObject
import copy
import functools
import itertools
import os

import humanize
from matplotlib import pyplot as plt, cm
import numpy as np
from numpy import ndarray
import pandas as pd
from pymicro.file import file_utils
import tensorflow as tf
from numpy.random import RandomState
from progressbar import progressbar as pbar
from enum import Enum
import re
from enum import Enum
from matplotlib import patches
import seaborn as sns
from typing import Type
from dataclasses import dataclass
import dataclasses

from tensorflow import keras
from tensorflow.keras import utils
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras import losses
from tensorflow.keras import layers

from cnn_segm import keras_custom_loss

from tomo2seg import modular_unet
from tomo2seg.logger import logger
from tomo2seg import data, viz
from tomo2seg.data import Volume
from tomo2seg.metadata import Metadata
from tomo2seg.volume_sequence import (
    VolumeCropSequence, MetaCrop3DGenerator, VSConstantEverywhere, 
    GTConstantEverywhere, SequentialGridPosition, ET3DConstantEverywhere
)
from tomo2seg import volume_sequence
from tomo2seg.model import Model as Tomo2SegModel
from tomo2seg.data import EstimationVolume
from tomo2seg import viz
from tomo2seg import process

# Setup

In [17]:
logger.setLevel(logging.DEBUG)

random_state = 42
random_state = np.random.RandomState(random_state)

n_gpus = len(tf.config.list_physical_devices('GPU'))
    
logger.debug(f"{tf.__version__=}")
logger.info(f"Num GPUs Available: {n_gpus}\nThis should be 2 on R790-TOMO.")
logger.debug(f"Should return 2 devices...\n{tf.config.list_physical_devices('GPU')=}")
logger.debug(f"Should return 2 devices...\n{tf.config.list_logical_devices('GPU')=}")

# xla auto-clustering optimization (see: https://www.tensorflow.org/xla#auto-clustering)
# this seems to break the training
tf.config.optimizer.set_jit(False)

DEBUG::tomo2seg::{<ipython-input-17-bba693eaf462>:<module>:008}::[2020-12-08::15:14:30.411]
tf.__version__='2.2.0'

INFO::tomo2seg::{<ipython-input-17-bba693eaf462>:<module>:009}::[2020-12-08::15:14:30.413]
Num GPUs Available: 2
This should be 2 on R790-TOMO.

DEBUG::tomo2seg::{<ipython-input-17-bba693eaf462>:<module>:010}::[2020-12-08::15:14:30.414]
Should return 2 devices...
tf.config.list_physical_devices('GPU')=[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]

DEBUG::tomo2seg::{<ipython-input-17-bba693eaf462>:<module>:011}::[2020-12-08::15:14:30.416]
Should return 2 devices...
tf.config.list_logical_devices('GPU')=[LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU')]



# Args

In [74]:
from tomo2seg.datasets import (
#     VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_V1_REDUCED as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_NEIGHBOUR as VOLUME_NAME_VERSION,    
#     VOLUME_COMPOSITE_FLEX as VOLUME_NAME_VERSION,    
    VOLUME_COMPOSITE_BIAXE as VOLUME_NAME_VERSION,    
)

# runid = 1607343593
try:
    runid
except NameError:
    runid = int(time.time())

args = process.ProcessVolumeArgs(
    model_name="unet2d.vanilla03-f16.fold000.1606-505-109",
    model_type=process.ModelType.input2d, 
    
    volume_name=VOLUME_NAME_VERSION[0], 
    volume_version=VOLUME_NAME_VERSION[1], 
    
    partition_alias=None,
    
    cropping_strategy=process.CroppingStrategy.maximum_size_reduced_overlap, 
    aggregation_strategy=process.AggregationStrategy.average_probabilities, 
    
    runid=runid,
    probabilities_dtype = np.float16,
    
    opts=process.ProcessVolumeOpts(
        save_probas_by_class = False,
        debug__save_figs = True,
    ), 
)

In [5]:
tomo2seg_model = Tomo2SegModel.build_from_model_name(args.model_name)

volume = Volume.with_check(
    name=args.volume_name, version=args.volume_version
)

partition = volume[args.partition_alias] if args.partition_alias is not None else None

estimation_volume = EstimationVolume.from_objects(
    volume=volume, 
    model=tomo2seg_model, 
    set_partition=partition,
    runid=runid,
)

# this is informal metadata for human use
estimation_volume["aggregation_strategy"] = args.aggregation_strategy.name
estimation_volume["cropping_strategy"] = args.cropping_strategy.name
estimation_volume["probabilities_dtype"] = args.probabilities_dtype.__name__

DEBUG::tomo2seg::{data.py:with_check:258}::[2020-12-08::15:00:57.195]
vol=Volume(name='PA66GF30', version='biaxe', _metadata=None)

DEBUG::tomo2seg::{data.py:metadata:195}::[2020-12-08::15:00:57.197]
Loading metadata from `/home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.biaxe/PA66GF30.biaxe.metadata.yml`.

DEBUG::tomo2seg::{data.py:metadata_path:328}::[2020-12-08::15:00:57.214]
Creating metadata file /home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.biaxe.set=whole-volume.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-436-057/vol=PA66GF30.biaxe.set=whole-volume.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-436-057.metadata.yml.

DEBUG::tomo2seg::{data.py:__setitem__:336}::[2020-12-08::15:00:57.239]
Writing to file self.metadata_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.biaxe.set=whole-volume.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-436-057/vol=PA66GF30.biaxe.set=whole-volume.model=unet2d.vani

show inputs

In [6]:
logger.info(f"args\n{pprint_module.PrettyPrinter(indent=4, compact=False).pformat(dataclasses.asdict(args))}")
logger.info(f"{estimation_volume=}")
            
logger.debug(f"{volume=}")
logger.debug(f"{partition=}")
logger.debug(f"{tomo2seg_model=}")

if args.model_type == process.ModelType.input2halfd:
    raise NotImplementedError(f"{args.model_type=}")

INFO::tomo2seg::{<ipython-input-6-ce63d89898fb>:<module>:001}::[2020-12-08::15:00:57.382]
args
{   'aggregation_strategy': <AggregationStrategy.average_probabilities: 0>,
    'cropping_strategy': <CroppingStrategy.minimum_overlap: 1>,
    'model_name': 'unet2d.vanilla03-f16.fold000.1606-505-109',
    'model_type': <ModelType.input2d: 0>,
    'opts': {'debug__save_figs': True, 'save_probas_by_class': False},
    'partition_alias': None,
    'probabilities_dtype': <class 'numpy.float16'>,
    'runid': 1607436057,
    'volume_name': 'PA66GF30',
    'volume_version': 'biaxe'}

INFO::tomo2seg::{<ipython-input-6-ce63d89898fb>:<module>:002}::[2020-12-08::15:00:57.384]
estimation_volume=EstimationVolume(volume_fullname='PA66GF30.biaxe', model_name='unet2d.vanilla03-f16.fold000.1606-505-109', runid=1607436057, partition=None)

DEBUG::tomo2seg::{<ipython-input-6-ce63d89898fb>:<module>:004}::[2020-12-08::15:00:57.385]
volume=Volume(name='PA66GF30', version='biaxe', _metadata=Volume.Metadata(dimen

# Load

##### gpu distribution strategy

In [19]:
# get a distribution strategy to use both gpus (see https://www.tensorflow.org/guide/distributed_training)
# strategy = tf.distribute.MirroredStrategy()  

# there is a bug with MirroredStrategy when you model.predict() with batch_size=1
# https://docs.google.com/document/d/17X1CUvGtlio3pkbKFemSGbF2Qnn0vWAZfCLsgFPoOqg/edit?usp=sharing
one_device = tf.distribute.OneDeviceStrategy(device="/gpu:0")
# logger.info(f"Because {args.model_type=}, MirroredStrategy cannot be used. Switched to {strategy.__class__.__name__}.")
    
logger.debug(f"{one_device=}")

DEBUG::tomo2seg::{<ipython-input-19-64511e033ae3>:<module>:009}::[2020-12-08::15:14:43.177]
one_device=<tensorflow.python.distribute.one_device_strategy.OneDeviceStrategy object at 0x7f1978116ac0>



##### model

In [101]:
def get_model():
    
    logger.info(f"Loading model from autosaved file: {tomo2seg_model.autosaved_model_path.name}")
    
    model = tf.keras.models.load_model(
        tomo2seg_model.autosaved_model_path_str,
        compile=False
    )
    
    logger.debug("Changing the model's input type to accept any size of crop.")
    
    in_ = model.layers[0]
    in_shape = in_.input_shape[0]
    input_n_channels = in_shape[-1:]

    logger.debug(f"{input_n_channels=}")
    
    # make it capable of getting any dimension in the input
    anysize_input = layers.Input(
        shape=[None, None, None] + list(input_n_channels),
        name="input_any_image_size"
    )
    
    logger.debug(f"{anysize_input=}")
    
    model.layers[0] = anysize_input
    
    # todo keep this somewhere instead of copying and pasting
    optimizer = optimizers.Adam()
    loss_func = keras_custom_loss.jaccard2_loss

    model.compile(loss=loss_func, optimizer=optimizer)
    
    return model

with one_device.scope():
    logger.info(f"Loading model with {one_device.__class__.__name__}.")
    model = get_model()

INFO::tomo2seg::{<ipython-input-101-7454cecbba19>:<module>:037}::[2020-12-08::17:27:44.413]
Loading model with OneDeviceStrategy.

INFO::tomo2seg::{<ipython-input-101-7454cecbba19>:get_model:003}::[2020-12-08::17:27:44.415]
Loading model from autosaved file: unet2d.vanilla03-f16.fold000.1606-505-109.autosaved.hdf5

DEBUG::tomo2seg::{<ipython-input-101-7454cecbba19>:get_model:010}::[2020-12-08::17:27:47.706]
Changing the model's input type to accept any size of crop.

DEBUG::tomo2seg::{<ipython-input-101-7454cecbba19>:get_model:016}::[2020-12-08::17:27:47.708]
input_n_channels=(1,)

DEBUG::tomo2seg::{<ipython-input-101-7454cecbba19>:get_model:024}::[2020-12-08::17:27:47.711]
anysize_input=<tf.Tensor 'input_any_image_size_1:0' shape=(None, None, None, None, 1) dtype=float32>



##### data

In [25]:
def _read_raw(path_: Path, volume_: Volume): 
    # from pymicro
    return file_utils.HST_read(
        str(path_),  # it doesn't accept paths...
        # pre-loaded kwargs
        autoparse_filename=False,  # the file names are not properly formatted
        data_type=volume.metadata.dtype,
        dims=volume.metadata.dimensions,
        verbose=True,
    )

read_raw = partial(_read_raw, volume_=volume)

logger.info(f"Loading data from disk at file: {volume.data_path.name}")

voldata = read_raw(volume.data_path) / 255  # normalize

logger.debug(f"{voldata.shape=}")

if partition is not None:
    logger.debug(f"Cutting data with {partition.alias=}")
    data_volume = partition.get_volume_partition(voldata)

else:
    logger.debug(f"No partition.")
    data_volume = voldata

del voldata

logger.debug(f"{data_volume.shape=}")
logger.debug(f"{data_volume.size=}  ({humanize.intword(data_volume.size)})")

INFO::tomo2seg::{<ipython-input-25-777069a45709>:<module>:014}::[2020-12-08::15:18:35.862]
Loading data from disk.

data type is uint8
volume size is 1579 x 1845 x 2002
reading volume... from byte 0
DEBUG::tomo2seg::{<ipython-input-25-777069a45709>:<module>:018}::[2020-12-08::15:19:20.623]
voldata.shape=(1579, 1845, 2002)

DEBUG::tomo2seg::{<ipython-input-25-777069a45709>:<module>:025}::[2020-12-08::15:19:20.629]
No partition.

DEBUG::tomo2seg::{<ipython-input-25-777069a45709>:<module>:030}::[2020-12-08::15:19:20.631]
data_volume.shape=(1579, 1845, 2002)

DEBUG::tomo2seg::{<ipython-input-25-777069a45709>:<module>:031}::[2020-12-08::15:19:20.632]
data_volume.size=5832336510



# Processing

In [77]:
if args.opts.debug__save_figs:
    figs_dir = estimation_volume.dir
    
    logger.debug(f"{figs_dir=}")
    figs_dir.mkdir(exist_ok=True)
    
volume_shape = data_volume.shape

logger.info(f"{volume_shape=}")

DEBUG::tomo2seg::{<ipython-input-77-789dccd181b8>:<module>:004}::[2020-12-08::16:51:51.707]
figs_dir=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.biaxe.set=whole-volume.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-436-057')

INFO::tomo2seg::{<ipython-input-77-789dccd181b8>:<module>:009}::[2020-12-08::16:51:51.708]
volume_shape=(1579, 1845, 2002)



## Shapes

In [85]:
MAX_N_VOXELS = 6 * 8 * 4 * (96 ** 3)  # estimation from the number of voxels i know that can fit in the GPU from training unet3d.crop96
MULTIPLE_REQUIREMENT = 16

logger.info(f"{MAX_N_VOXELS=} ({humanize.intcomma(MAX_N_VOXELS)})")
logger.info(f"{MULTIPLE_REQUIREMENT=}")

if args.cropping_strategy == process.CroppingStrategy.maximum_size:
    crop_dims_multiple_16 = process.get_largest_crop_multiple(
        volume_shape, 
        multiple_of=MULTIPLE_REQUIREMENT
    )

elif args.cropping_strategy == process.CroppingStrategy.maximum_size_reduced_overlap:
    # it's not necessarily the real minimum, just an easy way to get a big crop with less overlap
    # get the largest multiple of the requirement above the dimension size / 2
    # that will give a max overlap of 2 * MULTIPLE_REQUIREMENT - 1
    # e.g. with MULTIPLE_REQUIREMENT = 16, the maximum overlap is 31
    crop_dims_multiple_16 = tuple(
        (1 + int((dim / 2) // MULTIPLE_REQUIREMENT)) * MULTIPLE_REQUIREMENT if dim % MULTIPLE_REQUIREMENT != 0 else
        dim
        for dim in volume_shape
    )
    
    logger.info(f"the max overlap in each direction will be {tuple(int(2 * MULTIPLE_REQUIREMENT - s % MULTIPLE_REQUIREMENT) for s in volume_shape)}")
    
else:
    raise ValueError(f"{args.cropping_strategy=}")

logger.debug(f"{crop_dims_multiple_16=} using {args.cropping_strategy=}")

# it has to be multiple of 16 because of the 4 cascaded 2x2-strided 2x2-downsamplings in u-net
if args.model_type == process.ModelType.input2d:
    crop_shape = (
        crop_dims_multiple_16[0],
        crop_dims_multiple_16[1],
        1,
    )

elif args.model_type == process.ModelType.input2halfd:
    raise NotImplemented()
    
elif args.model_type == process.ModelType.input3d:
    crop_shape = crop_dims_multiple_16

logger.debug(f"ideal {crop_shape=} for {args.model_type=} now let's see if the maximum number of voxels is ok...")

crop_shape = process.reduce_dimensions(
    crop_shape,
    max_nvoxels=MAX_N_VOXELS,
    multiple_of=MULTIPLE_REQUIREMENT,
)
    
logger.info(f"{crop_shape=}")

INFO::tomo2seg::{<ipython-input-85-887b344415fa>:<module>:004}::[2020-12-08::16:57:59.069]
MAX_N_VOXELS=169869312 (169,869,312)

INFO::tomo2seg::{<ipython-input-85-887b344415fa>:<module>:005}::[2020-12-08::16:57:59.071]
MULTIPLE_REQUIREMENT=16

INFO::tomo2seg::{<ipython-input-85-887b344415fa>:<module>:024}::[2020-12-08::16:57:59.072]
the max overlap in each direction will be (21, 27, 30)

DEBUG::tomo2seg::{<ipython-input-85-887b344415fa>:<module>:029}::[2020-12-08::16:57:59.073]
crop_dims_multiple_16=(800, 928, 1008) using args.cropping_strategy=<CroppingStrategy.maximum_size_reduced_overlap: 1>

DEBUG::tomo2seg::{<ipython-input-85-887b344415fa>:<module>:045}::[2020-12-08::16:57:59.075]
ideal crop_shape=(800, 928, 1) for args.model_type=<ModelType.input2d: 0> now let's see if the maximum number of voxels is ok...

INFO::tomo2seg::{<ipython-input-85-887b344415fa>:<module>:053}::[2020-12-08::16:57:59.076]
crop_shape=(800, 928, 1)



## Steps and coordinates

In [87]:
n_steps = tuple(
    int(np.ceil(vol_dim / crop_dim))
    for vol_dim, crop_dim in zip(volume_shape, crop_shape)
)

logger.debug(f"{n_steps=}")

def get_coordinates_iterator(n_steps_):
    assert len(n_steps_) == 3
    return itertools.product(*(range(n_steps_[dim]) for dim in range(3)))

get_ijk_iterator = functools.partial(
    get_coordinates_iterator, copy.copy(n_steps)
)

get_kji_iterator = functools.partial(
    get_coordinates_iterator, tuple(reversed(n_steps))
)

# coordinates (xs, ys, and zs) of the front upper left corners of the crops
x0s, y0s, z0s = tuple(
    tuple(map(
        int, 
        np.linspace(0, vol_dim - crop_dim, n)
    ))
    for vol_dim, crop_dim, n in zip(volume_shape, crop_shape, n_steps)
)
logger.debug(f"{min(x0s)=}, {max(x0s)=}, {len(x0s)=}")
logger.debug(f"{min(y0s)=}, {max(y0s)=}, {len(y0s)=}")
logger.debug(f"{min(z0s)=}, {max(z0s)=}, {len(z0s)=}")

DEBUG::tomo2seg::{<ipython-input-87-9002bd14f40b>:<module>:006}::[2020-12-08::16:59:18.107]
n_steps=(2, 2, 2002)

DEBUG::tomo2seg::{<ipython-input-87-9002bd14f40b>:<module>:028}::[2020-12-08::16:59:18.112]
min(x0s)=0, max(x0s)=779, len(x0s)=2

DEBUG::tomo2seg::{<ipython-input-87-9002bd14f40b>:<module>:029}::[2020-12-08::16:59:18.113]
min(y0s)=0, max(y0s)=917, len(y0s)=2

DEBUG::tomo2seg::{<ipython-input-87-9002bd14f40b>:<module>:030}::[2020-12-08::16:59:18.115]
min(z0s)=0, max(z0s)=2001, len(z0s)=2002



### crops coordinates 

In [90]:
logger.debug("Generating the crop coordinates.")

crops_coordinates = np.array(
    [
        (
            (x0, x0 + crop_shape[0]), 
            (y0, y0 + crop_shape[1]),
            (z0, z0 + crop_shape[2]),
        )
        for x0, y0, z0 in itertools.product(x0s, y0s, z0s)
    ], 
    dtype=tuple
).reshape(len(x0s), len(y0s), len(z0s), 3, 2).astype(int)  # 3 = nb of dimenstions, 2 = (start, end)

logger.debug(f"{crops_coordinates.shape=}")

# 'F' reshapes with x varying fastest and z slowest
crops_coordinates_sequential = crops_coordinates.reshape(-1, 3, 2, order='F')  

logger.debug(f"{crops_coordinates_sequential.shape=}")

DEBUG::tomo2seg::{<ipython-input-90-81c28bea32a8>:<module>:001}::[2020-12-08::17:02:36.493]
Generating the crop coordinates.

DEBUG::tomo2seg::{<ipython-input-90-81c28bea32a8>:<module>:015}::[2020-12-08::17:02:36.567]
crops_coordinates.shape=(2, 2, 2002, 3, 2)

DEBUG::tomo2seg::{<ipython-input-90-81c28bea32a8>:<module>:020}::[2020-12-08::17:02:36.569]
crops_coordinates_sequential.shape=(8008, 3, 2)



## debug

### orthogonal slices plot

In [89]:
if args.opts.debug__save_figs:
    
    fig, axs = plt.subplots(2, 2, figsize=(sz := 15, sz), dpi=120)
    fig.set_tight_layout(True)
    
    display = viz.OrthogonalSlicesDisplay(
        volume=data_volume,
        volume_name=volume.fullname,
    ).plot(axs=axs,)
    
    logger.info(f"Saving figure {(figname := display.title + '.png')=}")
    display.fig_.savefig(
        fname=figs_dir / figname,
        dpi=200, format="png",
        metadata=display.metadata,
    )
    plt.close()

INFO::tomo2seg::{<ipython-input-89-ae231ce72243>:<module>:011}::[2020-12-08::17:01:07.396]
Saving figure (figname := display.title + '.png')='PA66GF30.biaxe.orthogonal-slices-display.x=789-y=922-z=1001.png'



### Segment an example

In [93]:
crop_ijk = (0, 0, 0)
i, j, k = crop_ijk
crop_coords = crops_coordinates[i, j, k]

logger.info(f"Segmenting one crop for debug {crop_ijk=}")

slice3d = tuple(slice(*coords_) for coords_ in crop_coords)
crop_data = data_volume[slice3d]
del slice3d
    
logger.debug(f"{crop_data.shape=}")

# [model] - i call it with a first crop bc if something goes wrong then the error
# will appear here instead of in a loop

# modelin
modelin_target_shape = (1, crop_shape[0], crop_shape[1], crop_shape[2], 1)

logger.debug(f"{modelin_target_shape=}")

modelin = crop_data.reshape(modelin_target_shape) 

# modelout
modelout = model.predict(
    modelin, 
    batch_size=1,
    steps=1,
    verbose=2,
)

logger.debug(f"{modelout.shape=}")

n_classes = modelout.shape[-1]

assert n_classes == len(volume.metadata.labels), f"{n_classes=} {len(volume.metadata.labels)=}"

# probas
crop_probas_target_shape = list(crop_shape) + [n_classes]

logger.debug(f"{crop_probas_target_shape=}")

crop_probas = modelout.reshape(crop_probas_target_shape).astype(args.probabilities_dtype)

logger.debug(f"{crop_probas.shape=}")
logger.debug(f"{crop_probas.dtype=}")

# preds
crop_preds = crop_probas.argmax(axis=-1).astype(np.int8)

logger.debug(f"{crop_preds.shape=}")
logger.debug(f"{crop_preds.dtype=}")

INFO::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:005}::[2020-12-08::17:11:38.336]
Segmenting one crop for debug crop_ijk=(0, 0, 0)

DEBUG::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:011}::[2020-12-08::17:11:38.338]
crop_data.shape=(800, 928, 1)

DEBUG::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:019}::[2020-12-08::17:11:38.339]
modelin_target_shape=(1, 800, 928, 1, 1)

1/1 - 0s
DEBUG::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:031}::[2020-12-08::17:11:38.626]
modelout.shape=(1, 800, 928, 3)

DEBUG::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:035}::[2020-12-08::17:11:38.627]
n_classes=3

DEBUG::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:040}::[2020-12-08::17:11:38.628]
crop_probas_target_shape=[800, 928, 1, 3]

DEBUG::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:044}::[2020-12-08::17:11:38.644]
crop_probas.shape=(800, 928, 1, 3)

DEBUG::tomo2seg::{<ipython-input-93-1c353b778d0a>:<module>:045}::[2020-12-08::17:11:38.666]
crop_pro

In [95]:
if args.opts.debug__save_figs:
    fig, axs = plt.subplots(
        nrows=3, ncols=2,
        figsize=(2 * (sz := 20), sz), 
        dpi=120,
    )

    display = viz.OrthogonalSlicesPredictionDisplay(
        volume_data=crop_data,
        volume_prediction=crop_preds,
        n_classes=n_classes,
        volume_name=volume.fullname + f".debug.crop-{crop_ijk=}",
    ).plot(axs=axs,)

    logger.info(f"Saving figure {(figname := display.title + '.png')=}")
    display.fig_.savefig(
        fname=figs_dir / figname,
        format="png",
        metadata=display.metadata,
    )       
    plt.close()

INFO::tomo2seg::{<ipython-input-95-ff0f199fac4c>:<module>:015}::[2020-12-08::17:13:05.176]
Saving figure (figname := display.title + '.png')='PA66GF30.biaxe.debug.crop-crop_ijk=(0, 0, 0).orthogonal-slices-display.x=400-y=464-z=0.png'



### Segment a batch

In [96]:
logger.info("Segmenting a batch for debug.")

INFO::tomo2seg::{<ipython-input-96-63e5c06b8645>:<module>:001}::[2020-12-08::17:23:34.504]
Segmenting a batch for debug.



In [103]:
batch_size = n_gpus

logger.debug(f"{batch_size=}")

DEBUG::tomo2seg::{<ipython-input-103-3ec5bded21f3>:<module>:003}::[2020-12-08::17:29:10.268]
batch_size=2



In [102]:
mirror = tf.distribute.MirroredStrategy()

with mirror.scope():
    logger.info(f"Loading model with {mirror.__class__.__name__}.")
    model = get_model()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
INFO::tomo2seg::{<ipython-input-102-6e906eab95a7>:<module>:004}::[2020-12-08::17:28:27.651]
Loading model with MirroredStrategy.

INFO::tomo2seg::{<ipython-input-101-7454cecbba19>:get_model:003}::[2020-12-08::17:28:27.652]
Loading model from autosaved file: unet2d.vanilla03-f16.fold000.1606-505-109.autosaved.hdf5

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0

In [110]:
batch_coords = crops_coordinates_sequential[:batch_size]

logger.debug(f"{batch_coords.shape=}")

batch_slices = [
    tuple(slice(*coords_) for coords_ in crop_coords)
    for crop_coords in batch_coords
]

logger.debug(f"{batch_slices=}")

batch_data = np.stack([
    data_volume[slice_]
    for slice_ in batch_slices
], axis=0)

logger.debug(f"{batch_data.shape=}")

# [model] - now i call it with a first the mirror strategy to make sure it wont break

# modelin
modelin_target_shape = (batch_size, crop_shape[0], crop_shape[1], crop_shape[2], 1)  # adjust nb. channels

logger.debug(f"{modelin_target_shape=}")

modelin = batch_data.reshape(modelin_target_shape) 

# modelout
modelout = model.predict(
    modelin, 
    batch_size=batch_size,
    steps=1,
    verbose=2,
)

logger.debug(f"{modelout.shape=}")

DEBUG::tomo2seg::{<ipython-input-110-e239b59957c3>:<module>:003}::[2020-12-08::17:36:42.275]
batch_coords.shape=(2, 3, 2)

DEBUG::tomo2seg::{<ipython-input-110-e239b59957c3>:<module>:010}::[2020-12-08::17:36:42.278]
batch_slices=[(slice(0, 800, None), slice(0, 928, None), slice(0, 1, None)), (slice(779, 1579, None), slice(0, 928, None), slice(0, 1, None))]

DEBUG::tomo2seg::{<ipython-input-110-e239b59957c3>:<module>:017}::[2020-12-08::17:36:42.288]
batch_data.shape=(2, 800, 928, 1)

DEBUG::tomo2seg::{<ipython-input-110-e239b59957c3>:<module>:024}::[2020-12-08::17:36:42.289]
modelin_target_shape=(2, 800, 928, 1, 1)

1/1 - 0s
DEBUG::tomo2seg::{<ipython-input-110-e239b59957c3>:<module>:036}::[2020-12-08::17:43:47.775]
modelout.shape=(2, 800, 928, 3)



In [113]:
# probas
batch_probas_target_shape = [batch_size] + list(crop_shape) + [n_classes]

logger.debug(f"{batch_probas_target_shape=}")

batch_probas = modelout.reshape(batch_probas_target_shape).astype(args.probabilities_dtype)

logger.debug(f"{batch_probas.shape=}")
logger.debug(f"{batch_probas.dtype=}")

# preds
batch_preds = batch_probas.argmax(axis=-1).astype(np.int8)

logger.debug(f"{batch_preds.shape=}")
logger.debug(f"{batch_preds.dtype=}")

DEBUG::tomo2seg::{<ipython-input-113-d17f03134596>:<module>:004}::[2020-12-08::17:48:34.574]
batch_probas_target_shape=[2, 800, 928, 1, 3]

DEBUG::tomo2seg::{<ipython-input-113-d17f03134596>:<module>:008}::[2020-12-08::17:48:34.648]
batch_probas.shape=(2, 800, 928, 1, 3)

DEBUG::tomo2seg::{<ipython-input-113-d17f03134596>:<module>:009}::[2020-12-08::17:48:34.650]
batch_probas.dtype=dtype('float16')

DEBUG::tomo2seg::{<ipython-input-113-d17f03134596>:<module>:014}::[2020-12-08::17:48:34.718]
batch_preds.shape=(2, 800, 928, 1)

DEBUG::tomo2seg::{<ipython-input-113-d17f03134596>:<module>:015}::[2020-12-08::17:48:34.719]
batch_preds.dtype=dtype('int8')



### segment the largest `batch_size`

In [None]:
MAX_N_VOXELS

In [None]:
batch_size_increased = batch_size + 2 * n_gpus

logger.debug(f"{batch_size_increased=}")

batch_coords = crops_coordinates_sequential[:batch_size_increased]
batch_slices = [
    tuple(slice(*coords_) for coords_ in crop_coords)
    for crop_coords in batch_coords
]
batch_data = np.stack([data_volume[slice_] for slice_ in batch_slices], axis=0)
# [model]
# modelin
modelin_target_shape = (batch_size_increased, crop_shape[0], crop_shape[1], crop_shape[2], 1)  # adjust nb. channels
modelin = batch_data.reshape(modelin_target_shape) 
# modelout
modelout = model.predict(
    modelin, 
    batch_size=batch_size_increased,
    steps=1,
    verbose=2,
)

DEBUG::tomo2seg::{<ipython-input-118-154069748d84>:<module>:003}::[2020-12-08::18:08:27.485]
batch_size_increased=6



# Rebuild the volume

In [None]:
proba_volume_target_shape = list(volume_shape) + [n_classes]

logger.debug(f"{proba_volume_target_shape=}")

proba_volume = np.zeros(proba_volume_target_shape, dtype=opts.probabilities_dtype)

logger.debug(f"{proba_volume.shape=}")

redundancies_count_target_shape = volume_shape

logger.debug(f"{redundancies_count_target_shape=}")

redundancies_count = np.zeros(redundancies_count_target_shape, dtype=np.int8)  # only one channel

logger.debug(f"{redundancies_count.shape=}")

n_iterations = n_steps[0] * n_steps[1] * n_steps[2]

logger.debug(f"{n_iterations=}")

logger.debug("Predicting and summing up the crops' probabilities.")
for coord in pbar(crops_coordinates_sequential, prefix="predict-and-sum-probas", max_value=n_iterations):
    # [model]
    slice3d = tuple(slice(*coords_) for coords_ in coord)
    crop_data = data_volume[slice3d]
    modelin = crop_data.reshape(modelin_target_shape)
    modelout = model.predict(modelin, batch_size=1, steps=1)
    proba_volume[slice3d] += modelout.astype(opts.probabilities_dtype).reshape(crop_probas_target_shape)
    redundancies_count[slice3d] += np.ones(crop_shape, dtype=np.int)

In [None]:
del voldata

In [None]:
del data_volume

##### sanity checks

In [None]:
# check that the min and max probas are coherent with the min/max redundancy
min_proba_sum = proba_volume.min(axis=0).min(axis=0).min(axis=0)
max_proba_sum = proba_volume.max(axis=0).max(axis=0).max(axis=0)
min_redundancy = np.min(redundancies_count)
max_redundancy = np.max(redundancies_count)

In [None]:
assert min_redundancy >= 1, f"{min_redundancy=}"
assert np.all(min_proba_sum >= 0), f"{min_proba_sum=}"
assert np.all(max_proba_sum <= max_redundancy), f"{max_proba_sum=} {max_redundancy=}"

## Normalize probas

In [None]:
# divide each probability channel by the number of times it was summed (avg proba)
logger.debug(f"Dividing probability redundancies.")
for klass_idx in pbar(range(n_classes), max_value=n_classes, prefix="redundancies-per-class"):
    proba_volume[:, :, :, klass_idx] = proba_volume[:, :, :, klass_idx] / redundancies_count

In [None]:
del redundancies_count

In [None]:
import gc

In [None]:
gc.collect()

In [None]:
# this makes it more stable so that the sum is 1
proba_volume[:, :, :] /= proba_volume[:, :, :].sum(axis=-1, keepdims=True) 

##### sanity checks

In [None]:
# check that proba distribs sum to 1
min_proba = proba_volume.min(axis=0).min(axis=0).min(axis=0)
max_proba = proba_volume.max(axis=0).max(axis=0).max(axis=0)

In [None]:
assert np.all(min_proba >= 0), f"{min_proba=}"
assert np.all(max_proba <= 1), f"{max_proba=}"

In [None]:
min_distrib_proba_sum = proba_volume.sum(axis=-1).min()
max_distrib_proba_sum = proba_volume.sum(axis=-1).max()

In [None]:
assert np.isclose(min_distrib_proba_sum, 1, atol=.001), f"{min_distrib_proba_sum=}"
assert np.isclose(max_distrib_proba_sum, 1, atol=.001), f"{max_distrib_proba_sum=}"

# proba 2 pred

In [None]:
gc.collect()

In [None]:
pred_volume = np.empty(proba_volume.shape[:3], dtype="uint8")

In [None]:
np.argmax(proba_volume, axis=-1, out=pred_volume)

logger.debug(f"{pred_volume.shape=}")
logger.debug(f"{pred_volume.min()=}")
logger.debug(f"{pred_volume.max()=}")

# Save

In [None]:
logger.debug(f"Writing probabilities on disk at `{estimation_volume.probabilities_path}`")
np.save(estimation_volume.probabilities_path, proba_volume)

In [None]:
if opts.save_probas_by_class:
    for klass_idx in volume.metadata.labels:
        logger.debug(f"Writing probabilities of class `{klass_idx}` on disk at `{(str_path := str(estimation_volume.get_class_probability_path(klass_idx)))=}`")
        file_utils.HST_write(proba_volume[:, :, :, klass_idx], str_path)

In [None]:
logger.debug(f"Writing predictions on disk at `{(str_path := str(estimation_volume.predictions_path))}`")
file_utils.HST_write(pred_volume, str_path)

#### one-z-slice-crops-locations.png

not kept, search fro `one-z-slice-crops-locations.png` in `process-3d-crops-entire-2d-slice`

#### debug__materialize_crops

same for
`debug__materialize_crops`

# Save notebook

In [None]:
this_nb_name = "process-volume-01.ipynb"
this_dir = os.getcwd()
save_nb_dir = str(estimation_volume.dir)

logger.warning(f"{this_nb_name=}")
logger.warning(f"{this_dir=}")
logger.warning(f"{save_nb_dir=}")

command = f"jupyter nbconvert {this_dir}/{this_nb_name} --output-dir {save_nb_dir} --to html"
os.system(command)