In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

from functools import partial
import logging
import pathlib
from pathlib import Path
import pprint as pprint_module
from pprint import pprint
import sys
from typing import *
import time
import yaml
from yaml import YAMLObject
import copy
import functools
import itertools
import os
import gc
import operator

import humanize
from matplotlib import pyplot as plt, cm
import numpy as np
from numpy import ndarray
import pandas as pd
from pymicro.file import file_utils
import tensorflow as tf
from numpy.random import RandomState
from progressbar import progressbar as pbar
from enum import Enum
import re
from enum import Enum
from matplotlib import patches
import seaborn as sns
from typing import Type
from dataclasses import dataclass
import dataclasses

from tensorflow import keras
from tensorflow.keras import utils
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras import losses
from tensorflow.keras import layers

from cnn_segm import keras_custom_loss

from tomo2seg import modular_unet
from tomo2seg.logger import logger
from tomo2seg import data, viz
from tomo2seg.data import Volume
from tomo2seg.metadata import Metadata
from tomo2seg.volume_sequence import (
    VolumeCropSequence, MetaCrop3DGenerator, VSConstantEverywhere, 
    GTConstantEverywhere, SequentialGridPosition, ET3DConstantEverywhere
)
from tomo2seg import volume_sequence
from tomo2seg.model import Model as Tomo2SegModel
from tomo2seg.data import EstimationVolume
from tomo2seg import viz
from tomo2seg import process

# Setup

In [3]:
logger.setLevel(logging.DEBUG)

random_state = 42
random_state = np.random.RandomState(random_state)

n_gpus = len(tf.config.list_physical_devices('GPU'))
    
logger.debug(f"{tf.__version__=}")
logger.info(f"Num GPUs Available: {n_gpus}\nThis should be 2 on R790-TOMO.")
logger.debug(f"Should return 2 devices...\n{tf.config.list_physical_devices('GPU')=}")
logger.debug(f"Should return 2 devices...\n{tf.config.list_logical_devices('GPU')=}")

# xla auto-clustering optimization (see: https://www.tensorflow.org/xla#auto-clustering)
# this seems to break the training
tf.config.optimizer.set_jit(False)

DEBUG::tomo2seg::{<ipython-input-3-bba693eaf462>:<module>:008}::[2020-12-09::13:19:22.113]
tf.__version__='2.2.0'

INFO::tomo2seg::{<ipython-input-3-bba693eaf462>:<module>:009}::[2020-12-09::13:19:22.114]
Num GPUs Available: 0
This should be 2 on R790-TOMO.

DEBUG::tomo2seg::{<ipython-input-3-bba693eaf462>:<module>:010}::[2020-12-09::13:19:22.115]
Should return 2 devices...
tf.config.list_physical_devices('GPU')=[]

DEBUG::tomo2seg::{<ipython-input-3-bba693eaf462>:<module>:011}::[2020-12-09::13:19:22.131]
Should return 2 devices...
tf.config.list_logical_devices('GPU')=[]



# Args

In [4]:
from tomo2seg.datasets import (
    VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_V1_REDUCED as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_NEIGHBOUR as VOLUME_NAME_VERSION,    
#     VOLUME_COMPOSITE_FLEX as VOLUME_NAME_VERSION,    
#     VOLUME_COMPOSITE_BIAXE as VOLUME_NAME_VERSION,    
)

# runid = 1607343593
try:
    runid
except NameError:
    runid = int(time.time())

args = process.ProcessVolumeArgs(
    model_name="unet2d.vanilla03-f16.fold000.1606-505-109",
    model_type=process.ModelType.input2d, 
    
    volume_name=VOLUME_NAME_VERSION[0], 
    volume_version=VOLUME_NAME_VERSION[1], 
    
#     partition_alias=None,
    partition_alias="test",
    
    cropping_strategy=process.CroppingStrategy.maximum_size_reduced_overlap, 
    aggregation_strategy=process.AggregationStrategy.average_probabilities, 
    
    runid=runid,
    probabilities_dtype = np.float16,
    
    opts=process.ProcessVolumeOpts(
        save_probas_by_class = False,
        debug__save_figs = True,
        override_batch_size = 6,
    ), 
)

In [5]:
tomo2seg_model = Tomo2SegModel.build_from_model_name(args.model_name)

volume = Volume.with_check(
    name=args.volume_name, version=args.volume_version
)

partition = volume[args.partition_alias] if args.partition_alias is not None else None

estimation_volume = EstimationVolume.from_objects(
    volume=volume, 
    model=tomo2seg_model, 
    set_partition=partition,
    runid=runid,
)

# this is informal metadata for human use
estimation_volume["aggregation_strategy"] = args.aggregation_strategy.name
estimation_volume["cropping_strategy"] = args.cropping_strategy.name
estimation_volume["probabilities_dtype"] = args.probabilities_dtype.__name__

DEBUG::tomo2seg::{data.py:with_check:258}::[2020-12-09::13:19:22.480]
vol=Volume(name='PA66GF30', version='v1', _metadata=None)

DEBUG::tomo2seg::{data.py:metadata:195}::[2020-12-09::13:19:22.482]
Loading metadata from `/home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.v1/PA66GF30.v1.metadata.yml`.

DEBUG::tomo2seg::{data.py:metadata_path:328}::[2020-12-09::13:19:22.500]
Creating metadata file /home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.v1.set=test.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-516-362/vol=PA66GF30.v1.set=test.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-516-362.metadata.yml.

DEBUG::tomo2seg::{data.py:__setitem__:336}::[2020-12-09::13:19:22.538]
Writing to file self.metadata_path=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.v1.set=test.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-516-362/vol=PA66GF30.v1.set=test.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-516-362.met

show inputs

In [6]:
logger.info(f"args\n{pprint_module.PrettyPrinter(indent=4, compact=False).pformat(dataclasses.asdict(args))}")
logger.info(f"{estimation_volume=}")
            
logger.debug(f"{volume=}")
logger.debug(f"{partition=}")
logger.debug(f"{tomo2seg_model=}")

if args.model_type == process.ModelType.input2halfd:
    raise NotImplementedError(f"{args.model_type=}")

INFO::tomo2seg::{<ipython-input-6-ce63d89898fb>:<module>:001}::[2020-12-09::13:19:22.682]
args
{   'aggregation_strategy': <AggregationStrategy.average_probabilities: 0>,
    'cropping_strategy': <CroppingStrategy.maximum_size_reduced_overlap: 1>,
    'model_name': 'unet2d.vanilla03-f16.fold000.1606-505-109',
    'model_type': <ModelType.input2d: 0>,
    'opts': {   'debug__save_figs': True,
                'override_batch_size': 6,
                'save_probas_by_class': False},
    'partition_alias': 'test',
    'probabilities_dtype': <class 'numpy.float16'>,
    'runid': 1607516362,
    'volume_name': 'PA66GF30',
    'volume_version': 'v1'}

INFO::tomo2seg::{<ipython-input-6-ce63d89898fb>:<module>:002}::[2020-12-09::13:19:22.683]
estimation_volume=EstimationVolume(volume_fullname='PA66GF30.v1', model_name='unet2d.vanilla03-f16.fold000.1606-505-109', runid=1607516362, partition=SetPartition(x_range=(0, 1300), y_range=(0, 1040), z_range=(1300, 1600), alias='test'))

DEBUG::tomo2seg::{

# Load

##### gpu distribution strategy

In [7]:
# get a distribution strategy to use both gpus (see https://www.tensorflow.org/guide/distributed_training)
# strategy = tf.distribute.MirroredStrategy()  

# there is a bug with MirroredStrategy when you model.predict() with batch_size=1
# https://docs.google.com/document/d/17X1CUvGtlio3pkbKFemSGbF2Qnn0vWAZfCLsgFPoOqg/edit?usp=sharing
one_device = tf.distribute.OneDeviceStrategy(device="/gpu:0" if n_gpus > 0 else "/cpu:0")
# logger.info(f"Because {args.model_type=}, MirroredStrategy cannot be used. Switched to {strategy.__class__.__name__}.")
    
logger.debug(f"{one_device=}")

DEBUG::tomo2seg::{<ipython-input-7-b0dafcd8e311>:<module>:009}::[2020-12-09::13:19:22.749]
one_device=<tensorflow.python.distribute.one_device_strategy.OneDeviceStrategy object at 0x7ff848409850>



##### model

In [8]:
def get_model():
    
    logger.info(f"Loading model from autosaved file: {tomo2seg_model.autosaved_model_path.name}")
    
    model = tf.keras.models.load_model(
        tomo2seg_model.autosaved_model_path_str,
        compile=False
    )
    
    logger.debug("Changing the model's input type to accept any size of crop.")
    
    in_ = model.layers[0]
    in_shape = in_.input_shape[0]
    input_n_channels = in_shape[-1:]

    logger.debug(f"{input_n_channels=}")
    
    # make it capable of getting any dimension in the input
    anysize_input = layers.Input(
        shape=[None, None, None] + list(input_n_channels),
        name="input_any_image_size"
    )
    
    logger.debug(f"{anysize_input=}")
    
    model.layers[0] = anysize_input
    
    # todo keep this somewhere instead of copying and pasting
    optimizer = optimizers.Adam()
    loss_func = keras_custom_loss.jaccard2_loss

    model.compile(loss=loss_func, optimizer=optimizer)
    
    return model

with one_device.scope():
    logger.info(f"Loading model with {one_device.__class__.__name__}.")
    model = get_model()

INFO::tomo2seg::{<ipython-input-8-7454cecbba19>:<module>:037}::[2020-12-09::13:19:22.819]
Loading model with OneDeviceStrategy.

INFO::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:003}::[2020-12-09::13:19:22.819]
Loading model from autosaved file: unet2d.vanilla03-f16.fold000.1606-505-109.autosaved.hdf5

DEBUG::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:010}::[2020-12-09::13:19:26.232]
Changing the model's input type to accept any size of crop.

DEBUG::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:016}::[2020-12-09::13:19:26.233]
input_n_channels=(1,)

DEBUG::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:024}::[2020-12-09::13:19:26.235]
anysize_input=<tf.Tensor 'input_any_image_size:0' shape=(None, None, None, None, 1) dtype=float32>



##### data

In [9]:
def _read_raw(path_: Path, volume_: Volume): 
    # from pymicro
    return file_utils.HST_read(
        str(path_),  # it doesn't accept paths...
        # pre-loaded kwargs
        autoparse_filename=False,  # the file names are not properly formatted
        data_type=volume.metadata.dtype,
        dims=volume.metadata.dimensions,
        verbose=True,
    )

read_raw = partial(_read_raw, volume_=volume)

logger.info(f"Loading data from disk at file: {volume.data_path.name}")

voldata = read_raw(volume.data_path) / 255  # normalize

logger.debug(f"{voldata.shape=}")

if partition is not None:
    logger.debug(f"Cutting data with {partition.alias=}")
    data_volume = partition.get_volume_partition(voldata)

else:
    logger.debug(f"No partition.")
    data_volume = voldata

del voldata

logger.debug(f"{data_volume.shape=}")
logger.debug(f"{data_volume.size=}  ({humanize.intword(data_volume.size)})")

INFO::tomo2seg::{<ipython-input-9-276a1e1a6bd3>:<module>:014}::[2020-12-09::13:19:26.320]
Loading data from disk at file: PA66GF30.v1.raw

data type is uint8
volume size is 1300 x 1040 x 1900
reading volume... from byte 0
DEBUG::tomo2seg::{<ipython-input-9-276a1e1a6bd3>:<module>:018}::[2020-12-09::13:19:52.349]
voldata.shape=(1300, 1040, 1900)

DEBUG::tomo2seg::{<ipython-input-9-276a1e1a6bd3>:<module>:021}::[2020-12-09::13:19:52.350]
Cutting data with partition.alias='test'

DEBUG::tomo2seg::{<ipython-input-9-276a1e1a6bd3>:<module>:030}::[2020-12-09::13:19:52.351]
data_volume.shape=(1300, 1040, 300)

DEBUG::tomo2seg::{<ipython-input-9-276a1e1a6bd3>:<module>:031}::[2020-12-09::13:19:52.351]
data_volume.size=405600000  (405.6 million)



# Processing

In [10]:
if args.opts.debug__save_figs:
    figs_dir = estimation_volume.dir
    
    logger.debug(f"{figs_dir=}")
    figs_dir.mkdir(exist_ok=True)
    
volume_shape = data_volume.shape

logger.info(f"{volume_shape=}")

DEBUG::tomo2seg::{<ipython-input-10-789dccd181b8>:<module>:004}::[2020-12-09::13:19:52.418]
figs_dir=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.v1.set=test.model=unet2d.vanilla03-f16.fold000.1606-505-109.runid=1607-516-362')

INFO::tomo2seg::{<ipython-input-10-789dccd181b8>:<module>:009}::[2020-12-09::13:19:52.419]
volume_shape=(1300, 1040, 300)



## Shapes

In [11]:
MULTIPLE_REQUIREMENT = 16
logger.info(f"{MULTIPLE_REQUIREMENT=}")

MAX_INTERNAL_NVOXELS = max(
    # seen cases
    4 * (8 * 6) * (96**3),
    8 * (16 * 6) * (320**2),  
    3 * (16 * 6) * (800 * 928),
)

logger.info(f"{MAX_INTERNAL_NVOXELS=} ({humanize.intcomma(MAX_INTERNAL_NVOXELS)})")

input_layer = model.layers[0]

logger.debug(f"{input_layer}")

assert (input_layer_class := input_layer.__class__) == tf.keras.layers.InputLayer, f"{input_layer_class=}"

input_nvoxels = functools.reduce(operator.mul, (x for x in input_layer.input_shape[0][1:]))

logger.debug(f"{input_nvoxels=}")


def get_layer_nvoxels(layer) -> int:
    return functools.reduce(operator.mul, (x for x in layer.output_shape[1:]))


internal_nvoxels = [
    get_layer_nvoxels(l)
    for l in model.layers[1:]
]

max_internal_nvoxels = max(internal_nvoxels)

logger.debug(f"{max_internal_nvoxels=} ({humanize.intcomma(max_internal_nvoxels)})")

internal_nvoxel_factor = max_internal_nvoxels / input_nvoxels

logger.debug(f"{internal_nvoxel_factor=}")

assert internal_nvoxel_factor == int(internal_nvoxel_factor), f"{internal_nvoxel_factor=}"

internal_nvoxel_factor = int(internal_nvoxel_factor)

logger.debug(f"{internal_nvoxel_factor=}")

max_batch_nvoxels = int(np.floor(MAX_INTERNAL_NVOXELS / internal_nvoxel_factor))

logger.info(f"{max_batch_nvoxels=} ({humanize.intcomma(max_batch_nvoxels)})")

if args.cropping_strategy == process.CroppingStrategy.maximum_size:
    crop_dims_multiple_16 = process.get_largest_crop_multiple(
        volume_shape, 
        multiple_of=MULTIPLE_REQUIREMENT
    )

elif args.cropping_strategy == process.CroppingStrategy.maximum_size_reduced_overlap:
    # it's not necessarily the real minimum, just an easy way to get a big crop with less overlap
    # get the largest multiple of the requirement above the dimension size / 2
    # that will give a max overlap of 2 * MULTIPLE_REQUIREMENT - 1
    # e.g. with MULTIPLE_REQUIREMENT = 16, the maximum overlap is 31
    crop_dims_multiple_16 = tuple(
        (1 + int((dim / 2) // MULTIPLE_REQUIREMENT)) * MULTIPLE_REQUIREMENT if dim % MULTIPLE_REQUIREMENT != 0 else
        dim
        for dim in volume_shape
    )
    
    logger.info(f"the max overlap in each direction will be {tuple(int(2 * MULTIPLE_REQUIREMENT - s % MULTIPLE_REQUIREMENT) for s in volume_shape)}")
    
else:
    raise ValueError(f"{args.cropping_strategy=}")

logger.debug(f"{crop_dims_multiple_16=} using {args.cropping_strategy=}")

# it has to be multiple of 16 because of the 4 cascaded 2x2-strided 2x2-downsamplings in u-net
if args.model_type == process.ModelType.input2d:
    crop_shape = (
        crop_dims_multiple_16[0],
        crop_dims_multiple_16[1],
        1,
    )

elif args.model_type == process.ModelType.input2halfd:
    raise NotImplemented()
    
elif args.model_type == process.ModelType.input3d:
    crop_shape = crop_dims_multiple_16

logger.debug(f"ideal {crop_shape=} for {args.model_type=} now let's see if the maximum number of voxels is ok...")

crop_shape = process.reduce_dimensions(
    crop_shape,
    max_nvoxels=max_batch_nvoxels,
    multiple_of=MULTIPLE_REQUIREMENT,
)
    
logger.info(f"{crop_shape=}")

crop_nvoxels = functools.reduce(operator.mul, crop_shape)

logger.info(f"{crop_nvoxels=}")

max_batch_size_per_gpu = int(np.floor(max_batch_nvoxels / crop_nvoxels))

logger.info(f"{max_batch_size_per_gpu=}")

INFO::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:002}::[2020-12-09::13:19:52.494]
MULTIPLE_REQUIREMENT=16

INFO::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:011}::[2020-12-09::13:19:52.495]
MAX_INTERNAL_NVOXELS=213811200 (213,811,200)

DEBUG::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:015}::[2020-12-09::13:19:52.496]
<tensorflow.python.keras.engine.input_layer.InputLayer object at 0x7ff90c699b50>

DEBUG::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:021}::[2020-12-09::13:19:52.497]
input_nvoxels=102400

DEBUG::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:035}::[2020-12-09::13:19:52.498]
max_internal_nvoxels=9830400 (9,830,400)

DEBUG::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:039}::[2020-12-09::13:19:52.500]
internal_nvoxel_factor=96.0

DEBUG::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:045}::[2020-12-09::13:19:52.500]
internal_nvoxel_factor=96

INFO::tomo2seg::{<ipython-input-11-7eb2ef5a5ab3>:<module>:049}::[2020-12-09::13:19:52

## Steps and coordinates

In [12]:
n_steps = tuple(
    int(np.ceil(vol_dim / crop_dim))
    for vol_dim, crop_dim in zip(volume_shape, crop_shape)
)

logger.debug(f"{n_steps=}")

def get_coordinates_iterator(n_steps_):
    assert len(n_steps_) == 3
    return itertools.product(*(range(n_steps_[dim]) for dim in range(3)))

get_ijk_iterator = functools.partial(
    get_coordinates_iterator, copy.copy(n_steps)
)

get_kji_iterator = functools.partial(
    get_coordinates_iterator, tuple(reversed(n_steps))
)

# coordinates (xs, ys, and zs) of the front upper left corners of the crops
x0s, y0s, z0s = tuple(
    tuple(map(
        int, 
        np.linspace(0, vol_dim - crop_dim, n)
    ))
    for vol_dim, crop_dim, n in zip(volume_shape, crop_shape, n_steps)
)
logger.debug(f"{min(x0s)=}, {max(x0s)=}, {len(x0s)=}")
logger.debug(f"{min(y0s)=}, {max(y0s)=}, {len(y0s)=}")
logger.debug(f"{min(z0s)=}, {max(z0s)=}, {len(z0s)=}")

DEBUG::tomo2seg::{<ipython-input-12-9002bd14f40b>:<module>:006}::[2020-12-09::13:19:52.573]
n_steps=(2, 1, 300)

DEBUG::tomo2seg::{<ipython-input-12-9002bd14f40b>:<module>:028}::[2020-12-09::13:19:52.575]
min(x0s)=0, max(x0s)=644, len(x0s)=2

DEBUG::tomo2seg::{<ipython-input-12-9002bd14f40b>:<module>:029}::[2020-12-09::13:19:52.576]
min(y0s)=0, max(y0s)=0, len(y0s)=1

DEBUG::tomo2seg::{<ipython-input-12-9002bd14f40b>:<module>:030}::[2020-12-09::13:19:52.576]
min(z0s)=0, max(z0s)=299, len(z0s)=300



### crops coordinates 

In [13]:
logger.debug("Generating the crop coordinates.")

crops_coordinates = np.array(
    [
        (
            (x0, x0 + crop_shape[0]), 
            (y0, y0 + crop_shape[1]),
            (z0, z0 + crop_shape[2]),
        )
        for x0, y0, z0 in itertools.product(x0s, y0s, z0s)
    ], 
    dtype=tuple
).reshape(len(x0s), len(y0s), len(z0s), 3, 2).astype(int)  # 3 = nb of dimenstions, 2 = (start, end)

logger.debug(f"{crops_coordinates.shape=}")

# 'F' reshapes with x varying fastest and z slowest
crops_coordinates_sequential = crops_coordinates.reshape(-1, 3, 2, order='F')  

logger.debug(f"{crops_coordinates_sequential.shape=}")

DEBUG::tomo2seg::{<ipython-input-13-81c28bea32a8>:<module>:001}::[2020-12-09::13:19:52.645]
Generating the crop coordinates.

DEBUG::tomo2seg::{<ipython-input-13-81c28bea32a8>:<module>:015}::[2020-12-09::13:19:52.650]
crops_coordinates.shape=(2, 1, 300, 3, 2)

DEBUG::tomo2seg::{<ipython-input-13-81c28bea32a8>:<module>:020}::[2020-12-09::13:19:52.651]
crops_coordinates_sequential.shape=(600, 3, 2)



## debug

### orthogonal slices plot

In [14]:
if args.opts.debug__save_figs:
    
    fig, axs = plt.subplots(2, 2, figsize=(sz := 15, sz), dpi=120)
    fig.set_tight_layout(True)
    
    display = viz.OrthogonalSlicesDisplay(
        volume=data_volume,
        volume_name=volume.fullname,
    ).plot(axs=axs,)
    
    logger.info(f"Saving figure {(figname := display.title + '.png')=}")
    display.fig_.savefig(
        fname=figs_dir / figname,
        dpi=200, format="png",
        metadata=display.metadata,
    )
    plt.close()

INFO::tomo2seg::{<ipython-input-14-ae231ce72243>:<module>:011}::[2020-12-09::13:19:52.860]
Saving figure (figname := display.title + '.png')='PA66GF30.v1.orthogonal-slices-display.x=650-y=520-z=150.png'



### Segment an example

In [15]:
crop_ijk = (0, 0, 0)
i, j, k = crop_ijk
crop_coords = crops_coordinates[i, j, k]

logger.info(f"Segmenting one crop for debug {crop_ijk=}")

slice3d = tuple(slice(*coords_) for coords_ in crop_coords)
crop_data = data_volume[slice3d]
del slice3d
    
logger.debug(f"{crop_data.shape=}")

# [model] - i call it with a first crop bc if something goes wrong then the error
# will appear here instead of in a loop

# modelin
modelin_target_shape = (1, crop_shape[0], crop_shape[1], crop_shape[2], 1)

logger.debug(f"{modelin_target_shape=}")

modelin = crop_data.reshape(modelin_target_shape) 

# modelout
modelout = model.predict(
    modelin, 
    batch_size=1,
    steps=1,
    verbose=2,
)

logger.debug(f"{modelout.shape=}")

n_classes = modelout.shape[-1]

assert n_classes == len(volume.metadata.labels), f"{n_classes=} {len(volume.metadata.labels)=}"

# probas
crop_probas_target_shape = list(crop_shape) + [n_classes]

logger.debug(f"{crop_probas_target_shape=}")

crop_probas = modelout.reshape(crop_probas_target_shape).astype(args.probabilities_dtype)

logger.debug(f"{crop_probas.shape=}")
logger.debug(f"{crop_probas.dtype=}")

# preds
crop_preds = crop_probas.argmax(axis=-1).astype(np.int8)

logger.debug(f"{crop_preds.shape=}")
logger.debug(f"{crop_preds.dtype=}")

INFO::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:005}::[2020-12-09::13:19:54.935]
Segmenting one crop for debug crop_ijk=(0, 0, 0)

DEBUG::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:011}::[2020-12-09::13:19:54.936]
crop_data.shape=(656, 1040, 1)

DEBUG::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:019}::[2020-12-09::13:19:54.936]
modelin_target_shape=(1, 656, 1040, 1, 1)

1/1 - 0s
DEBUG::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:031}::[2020-12-09::13:20:10.702]
modelout.shape=(1, 656, 1040, 3)

DEBUG::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:040}::[2020-12-09::13:20:10.703]
crop_probas_target_shape=[656, 1040, 1, 3]

DEBUG::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:044}::[2020-12-09::13:20:10.729]
crop_probas.shape=(656, 1040, 1, 3)

DEBUG::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:045}::[2020-12-09::13:20:10.730]
crop_probas.dtype=dtype('float16')

DEBUG::tomo2seg::{<ipython-input-15-c7dea0ab7ff2>:<module>:050}::[2020-1

In [16]:
if args.opts.debug__save_figs:
    fig, axs = plt.subplots(
        nrows=3, ncols=2,
        figsize=(2 * (sz := 20), sz), 
        dpi=120,
    )

    display = viz.OrthogonalSlicesPredictionDisplay(
        volume_data=crop_data,
        volume_prediction=crop_preds,
        n_classes=n_classes,
        volume_name=volume.fullname + f".debug.crop-{crop_ijk=}",
    ).plot(axs=axs,)

    logger.info(f"Saving figure {(figname := display.title + '.png')=}")
    display.fig_.savefig(
        fname=figs_dir / figname,
        format="png",
        metadata=display.metadata,
    )       
    plt.close()

INFO::tomo2seg::{<ipython-input-16-ff0f199fac4c>:<module>:015}::[2020-12-09::13:20:11.249]
Saving figure (figname := display.title + '.png')='PA66GF30.v1.debug.crop-crop_ijk=(0, 0, 0).orthogonal-slices-display.x=328-y=520-z=0.png'



### Segment a batch with `batch_size=n_gpus` (1 per device)

In [17]:
logger.info("Segmenting a batch for debug.")

INFO::tomo2seg::{<ipython-input-17-63e5c06b8645>:<module>:001}::[2020-12-09::13:20:12.496]
Segmenting a batch for debug.



In [18]:
batch_size = max(1, n_gpus)

logger.debug(f"{batch_size=}")

DEBUG::tomo2seg::{<ipython-input-18-1363c31e67ff>:<module>:003}::[2020-12-09::13:20:12.574]
batch_size=1



In [19]:
mirror = tf.distribute.MirroredStrategy()

with mirror.scope():
    logger.info(f"Loading model with {mirror.__class__.__name__}.")
    model = get_model()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
INFO::tomo2seg::{<ipython-input-19-6e906eab95a7>:<module>:004}::[2020-12-09::13:20:12.650]
Loading model with MirroredStrategy.

INFO::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:003}::[2020-12-09::13:20:12.651]
Loading model from autosaved file: unet2d.vanilla03-f16.fold000.1606-505-109.autosaved.hdf5

DEBUG::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:010}::[2020-12-09::13:20:15.818]
Changing the model's input type to accept any size of crop.

DEBUG::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:016}::[2020-12-09::13:20:15.820]
input_n_channels=(1,)

DEBUG::tomo2seg::{<ipython-input-8-7454cecbba19>:get_model:024}::[2020-12-09::13:20:15.822]
anysize_input=<tf.Tensor 'input_any_image_size_1:0' shape=(None, None, None, None, 1) dtype=float32>



In [20]:
batch_coords = crops_coordinates_sequential[:batch_size]

logger.debug(f"{batch_coords.shape=}")

batch_slices = [
    tuple(slice(*coords_) for coords_ in crop_coords)
    for crop_coords in batch_coords
]

logger.debug(f"{batch_slices=}")

batch_data = np.stack([
    data_volume[slice_]
    for slice_ in batch_slices
], axis=0)

logger.debug(f"{batch_data.shape=}")

# [model] - now i call it with a first the mirror strategy to make sure it wont break

# modelin
modelin_target_shape = (batch_size, crop_shape[0], crop_shape[1], crop_shape[2], 1)  # adjust nb. channels

logger.debug(f"{modelin_target_shape=}")

modelin = batch_data.reshape(modelin_target_shape) 

# modelout
modelout = model.predict(
    modelin, 
    batch_size=batch_size,
    steps=1,
    verbose=2,
)

logger.debug(f"{modelout.shape=}")

DEBUG::tomo2seg::{<ipython-input-20-e239b59957c3>:<module>:003}::[2020-12-09::13:20:16.146]
batch_coords.shape=(1, 3, 2)

DEBUG::tomo2seg::{<ipython-input-20-e239b59957c3>:<module>:010}::[2020-12-09::13:20:16.147]
batch_slices=[(slice(0, 656, None), slice(0, 1040, None), slice(0, 1, None))]

DEBUG::tomo2seg::{<ipython-input-20-e239b59957c3>:<module>:017}::[2020-12-09::13:20:16.149]
batch_data.shape=(1, 656, 1040, 1)

DEBUG::tomo2seg::{<ipython-input-20-e239b59957c3>:<module>:024}::[2020-12-09::13:20:16.150]
modelin_target_shape=(1, 656, 1040, 1, 1)

1/1 - 0s
DEBUG::tomo2seg::{<ipython-input-20-e239b59957c3>:<module>:036}::[2020-12-09::13:20:32.342]
modelout.shape=(1, 656, 1040, 3)



In [21]:
# probas
batch_probas_target_shape = [batch_size] + list(crop_shape) + [n_classes]

logger.debug(f"{batch_probas_target_shape=}")

batch_probas = modelout.reshape(batch_probas_target_shape).astype(args.probabilities_dtype)

logger.debug(f"{batch_probas.shape=}")
logger.debug(f"{batch_probas.dtype=}")

# preds
batch_preds = batch_probas.argmax(axis=-1).astype(np.int8)

logger.debug(f"{batch_preds.shape=}")
logger.debug(f"{batch_preds.dtype=}")

DEBUG::tomo2seg::{<ipython-input-21-d17f03134596>:<module>:004}::[2020-12-09::13:20:32.724]
batch_probas_target_shape=[1, 656, 1040, 1, 3]

DEBUG::tomo2seg::{<ipython-input-21-d17f03134596>:<module>:008}::[2020-12-09::13:20:32.750]
batch_probas.shape=(1, 656, 1040, 1, 3)

DEBUG::tomo2seg::{<ipython-input-21-d17f03134596>:<module>:009}::[2020-12-09::13:20:32.751]
batch_probas.dtype=dtype('float16')

DEBUG::tomo2seg::{<ipython-input-21-d17f03134596>:<module>:014}::[2020-12-09::13:20:32.773]
batch_preds.shape=(1, 656, 1040, 1)

DEBUG::tomo2seg::{<ipython-input-21-d17f03134596>:<module>:015}::[2020-12-09::13:20:32.774]
batch_preds.dtype=dtype('int8')



### segment batch with `batch_size = n_gpus * max_batch_size_per_gpu`

In [22]:
batch_size = max(1, n_gpus) * max_batch_size_per_gpu

logger.debug(f"{batch_size=}")

if args.opts.override_batch_size is not None:
    batch_size = args.opts.override_batch_size
    logger.info(f"{args.opts.override_batch_size=} give ==> replacing the {batch_size=}")

batch_coords = crops_coordinates_sequential[:batch_size]
batch_slices = [
    tuple(slice(*coords_) for coords_ in crop_coords)
    for crop_coords in batch_coords
]
batch_data = np.stack([data_volume[slice_] for slice_ in batch_slices], axis=0)

# [model]
# modelin
modelin_target_shape = (batch_size, crop_shape[0], crop_shape[1], crop_shape[2], 1)  # adjust nb. channels
modelin = batch_data.reshape(modelin_target_shape) 
# modelout
modelout = model.predict(
    modelin, 
    batch_size=batch_size,
    steps=1,
    verbose=2,
)

logger.debug(f"{modelout.shape=}")

DEBUG::tomo2seg::{<ipython-input-22-077fe178a744>:<module>:003}::[2020-12-09::13:20:32.846]
batch_size=3

INFO::tomo2seg::{<ipython-input-22-077fe178a744>:<module>:007}::[2020-12-09::13:20:32.847]
args.opts.override_batch_size=6 give ==> replacing the batch_size=6

1/1 - 0s
DEBUG::tomo2seg::{<ipython-input-22-077fe178a744>:<module>:028}::[2020-12-09::13:22:03.699]
modelout.shape=(6, 656, 1040, 3)



In [23]:
if args.opts.debug__save_figs:
    
    batch_modelout = modelout
    
    for idx, (crop_data__, crop_probas__) in enumerate(zip(batch_data, batch_modelout)):
        
        fig, axs = plt.subplots(
            nrows=3, ncols=2,
            figsize=(2 * (sz := 20), sz), 
            dpi=120,
        )

        display = viz.OrthogonalSlicesPredictionDisplay(
            volume_data=crop_data__,
            volume_prediction=crop_probas__.argmax(axis=-1).reshape(crop_data__.shape),
            n_classes=n_classes,
            volume_name=volume.fullname + f".debug.batch-segm.{idx=}",
        ).plot(axs=axs,)

        logger.info(f"Saving figure {(figname := display.title + '.png')=}")
        display.fig_.savefig(
            fname=figs_dir / figname,
            format="png",
            metadata=display.metadata,
        )       
        plt.close()

INFO::tomo2seg::{<ipython-input-23-18f87ba131ed>:<module>:020}::[2020-12-09::13:22:04.904]
Saving figure (figname := display.title + '.png')='PA66GF30.v1.debug.batch-segm.idx=0.orthogonal-slices-display.x=328-y=520-z=0.png'

INFO::tomo2seg::{<ipython-input-23-18f87ba131ed>:<module>:020}::[2020-12-09::13:22:06.136]
Saving figure (figname := display.title + '.png')='PA66GF30.v1.debug.batch-segm.idx=1.orthogonal-slices-display.x=328-y=520-z=0.png'

INFO::tomo2seg::{<ipython-input-23-18f87ba131ed>:<module>:020}::[2020-12-09::13:22:07.581]
Saving figure (figname := display.title + '.png')='PA66GF30.v1.debug.batch-segm.idx=2.orthogonal-slices-display.x=328-y=520-z=0.png'

INFO::tomo2seg::{<ipython-input-23-18f87ba131ed>:<module>:020}::[2020-12-09::13:22:08.802]
Saving figure (figname := display.title + '.png')='PA66GF30.v1.debug.batch-segm.idx=3.orthogonal-slices-display.x=328-y=520-z=0.png'

INFO::tomo2seg::{<ipython-input-23-18f87ba131ed>:<module>:020}::[2020-12-09::13:22:10.013]
Saving fi

# Rebuild the volume

In [26]:
n_crops = crops_coordinates_sequential.shape[0] 

logger.debug(f"{n_crops=}")

last_batch_size = n_crops % batch_size

if n_gpus > 1:
    assert last_batch_size % n_gpus == 0, f"{last_batch_size=}"

logger.debug(f"{last_batch_size=}")
    
niterations = int(np.floor(crops_coordinates_sequential.shape[0] / batch_size)) 

logger.debug(f"{niterations=}")

DEBUG::tomo2seg::{<ipython-input-26-6d190f1be318>:<module>:003}::[2020-12-09::13:23:47.589]
n_crops=600

DEBUG::tomo2seg::{<ipython-input-26-6d190f1be318>:<module>:010}::[2020-12-09::13:23:47.590]
last_batch_size=0

DEBUG::tomo2seg::{<ipython-input-26-6d190f1be318>:<module>:014}::[2020-12-09::13:23:47.590]
niterations=100



In [None]:
proba_volume_target_shape = list(volume_shape) + [n_classes]

logger.debug(f"{proba_volume_target_shape=}")

proba_volume = np.zeros(proba_volume_target_shape, dtype=args.probabilities_dtype)

logger.debug(f"{proba_volume.shape=}")

redundancies_count_target_shape = volume_shape

logger.debug(f"{redundancies_count_target_shape=}")

redundancies_count = np.zeros(redundancies_count_target_shape, dtype=np.int8)  # only one channel

logger.debug(f"{redundancies_count.shape=}")

def process_batch(batch_start_, batch_size_):
    batch_end = batch_start_ + batch_size_
    
    batch_coords = crops_coordinates_sequential[batch_start_:batch_end]
    batch_slices = [
        tuple(slice(*coords_) for coords_ in crop_coords)
        for crop_coords in batch_coords
    ]
    batch_data = np.stack([data_volume[slice_] for slice_ in batch_slices], axis=0)
    
    # [model]
    modelin_target_shape = (batch_size_, crop_shape[0], crop_shape[1], crop_shape[2], 1)  # adjust nb. channels
    batch_probas = model.predict(
        batch_data.reshape(modelin_target_shape), 
        batch_size=batch_size_,
        steps=1,
    ).astype(args.probabilities_dtype)

    for slice_, crop_proba in zip(batch_slices, batch_probas):
        proba_volume[slice_] += crop_proba.reshape(crop_probas_target_shape)
        redundancies_count[slice_] += np.ones(crop_shape, dtype=np.int)
        
logger.debug("Predicting and summing up the crops' probabilities.")
for batch_idx in pbar(
    range(niterations), 
    prefix="predict-and-sum-probas", 
    max_value=niterations
):
    batch_start = batch_idx * batch_size
    process_batch(batch_start, batch_size)

if last_batch_size > 0:
    logger.info("Segmenting the last batch")
    batch_start = niterations * batch_size
    process_batch(batch_start, last_batch_size)

DEBUG::tomo2seg::{<ipython-input-27-85d369c9652b>:<module>:003}::[2020-12-09::13:23:50.698]
proba_volume_target_shape=[1300, 1040, 300, 3]

DEBUG::tomo2seg::{<ipython-input-27-85d369c9652b>:<module>:007}::[2020-12-09::13:23:50.699]
proba_volume.shape=(1300, 1040, 300, 3)

DEBUG::tomo2seg::{<ipython-input-27-85d369c9652b>:<module>:011}::[2020-12-09::13:23:50.700]
redundancies_count_target_shape=(1300, 1040, 300)

DEBUG::tomo2seg::{<ipython-input-27-85d369c9652b>:<module>:015}::[2020-12-09::13:23:50.700]
redundancies_count.shape=(1300, 1040, 300)

DEBUG::tomo2seg::{<ipython-input-27-85d369c9652b>:<module>:039}::[2020-12-09::13:23:50.701]
Predicting and summing up the crops' probabilities.



predict-and-sum-probas 56% (56 of 100) | | Elapsed Time: 1:25:26 ETA:   1:07:18

In [None]:
del data_volume

In [None]:
gc.collect()

##### sanity checks

In [None]:
# check that the min and max probas are coherent with the min/max redundancy
min_proba_sum = proba_volume.min(axis=0).min(axis=0).min(axis=0)
max_proba_sum = proba_volume.max(axis=0).max(axis=0).max(axis=0)
min_redundancy = np.min(redundancies_count)
max_redundancy = np.max(redundancies_count)

In [None]:
assert min_redundancy >= 1, f"{min_redundancy=}"
assert np.all(min_proba_sum >= 0), f"{min_proba_sum=}"
assert np.all(max_proba_sum <= max_redundancy), f"{max_proba_sum=} {max_redundancy=}"

## Normalize probas

In [None]:
# divide each probability channel by the number of times it was summed (avg proba)
logger.debug(f"Dividing probability redundancies.")
for klass_idx in pbar(range(n_classes), max_value=n_classes, prefix="redundancies-per-class"):
    proba_volume[:, :, :, klass_idx] = proba_volume[:, :, :, klass_idx] / redundancies_count

In [None]:
del redundancies_count

In [None]:
gc.collect()

In [None]:
# this makes it more stable so that the sum is 1
proba_volume[:, :, :] /= proba_volume[:, :, :].sum(axis=-1, keepdims=True) 

##### sanity checks

In [None]:
# check that proba distribs sum to 1
min_proba = proba_volume.min(axis=0).min(axis=0).min(axis=0)
max_proba = proba_volume.max(axis=0).max(axis=0).max(axis=0)

In [None]:
assert np.all(min_proba >= 0), f"{min_proba=}"
assert np.all(max_proba <= 1), f"{max_proba=}"

In [None]:
min_distrib_proba_sum = proba_volume.sum(axis=-1).min()
max_distrib_proba_sum = proba_volume.sum(axis=-1).max()

In [None]:
assert np.isclose(min_distrib_proba_sum, 1, atol=.001), f"{min_distrib_proba_sum=}"
assert np.isclose(max_distrib_proba_sum, 1, atol=.001), f"{max_distrib_proba_sum=}"

In [None]:
gc.collect()

# proba 2 pred

In [None]:
pred_volume = np.empty(proba_volume.shape[:3], dtype="uint8")

In [None]:
np.argmax(proba_volume, axis=-1, out=pred_volume)

logger.debug(f"{pred_volume.shape=}")
logger.debug(f"{pred_volume.min()=}")
logger.debug(f"{pred_volume.max()=}")

# Save

In [None]:
logger.debug(f"Writing probabilities on disk at `{estimation_volume.probabilities_path}`")
np.save(estimation_volume.probabilities_path, proba_volume)

In [None]:
if args.opts.save_probas_by_class:
    for klass_idx in volume.metadata.labels:
        logger.debug(f"Writing probabilities of class `{klass_idx}` on disk at `{(str_path := str(estimation_volume.get_class_probability_path(klass_idx)))=}`")
        file_utils.HST_write(proba_volume[:, :, :, klass_idx], str_path)

In [None]:
logger.debug(f"Writing predictions on disk at `{(str_path := str(estimation_volume.predictions_path))}`")
file_utils.HST_write(pred_volume, str_path)

#### one-z-slice-crops-locations.png

not kept, search fro `one-z-slice-crops-locations.png` in `process-3d-crops-entire-2d-slice`

#### debug__materialize_crops

same for
`debug__materialize_crops`

# Save notebook

In [None]:
this_nb_name = "process-volume-02.ipynb"
this_dir = os.getcwd()
save_nb_dir = str(estimation_volume.dir)

logger.warning(f"{this_nb_name=}")
logger.warning(f"{this_dir=}")
logger.warning(f"{save_nb_dir=}")

command = f"jupyter nbconvert {this_dir}/{this_nb_name} --output-dir {save_nb_dir} --to html"
os.system(command)