In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

from functools import partial
import logging
import pathlib
from pathlib import Path
from pprint import pprint
import sys
from typing import *
import time
import yaml
from yaml import YAMLObject

import humanize
from matplotlib import pyplot as plt, cm
import numpy as np
import pandas as pd
from pymicro.file import file_utils
import tensorflow as tf
from numpy.random import RandomState
from progressbar import progressbar as pbar

from tensorflow import keras
from tensorflow.keras import utils
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras import losses

from tomo2seg import modular_unet
from tomo2seg.logger import logger
from tomo2seg import data, viz
from tomo2seg.data import Volume
from tomo2seg.metadata import Metadata
from tomo2seg.volume_sequence import VolumeCropSequence, MetaCrop3DGenerator, VSConstantEverywhere, GTConstantEverywhere, SequentialGridPosition, ET3DConstantEverywhere
from tomo2seg import volume_sequence
from tomo2seg.model import Model as Tomo2SegModel

In [3]:
# show DEBUG-level messages (and above) from the tomo2seg logger in this notebook
logger.setLevel(logging.DEBUG)

In [4]:
# seeded numpy RNG (seed 42) and a run id taken from the current unix time, in whole seconds
random_state = np.random.RandomState(42)
runid = int(time.time())
logger.info(f"{runid=}")

INFO::tomo2seg::{<ipython-input-4-cf972d05bc84>:<module>:004}::[2020-11-20::17:22:22.793]
runid=1605889342



In [5]:
# Environment sanity checks: log the TensorFlow version and the GPUs TensorFlow can see
# (the R790-TOMO machine is expected to expose 2 physical and 2 logical GPUs).
logger.debug(f"{tf.__version__=}")
logger.info(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}\nThis should be 2 on R790-TOMO.")
logger.debug(f"Both here should return 2 devices...\n{tf.config.list_physical_devices('GPU')=}\n{tf.config.list_logical_devices('GPU')=}")

# xla auto-clustering optimization (see: https://www.tensorflow.org/xla#auto-clustering)
# this seems to break the training, so JIT compilation is explicitly turned off
tf.config.optimizer.set_jit(False)

# get a distribution strategy to use both gpus (see https://www.tensorflow.org/guide/distributed_training)
# called without a `devices` argument, so it mirrors over every GPU visible to TensorFlow;
# the model is loaded inside `strategy.scope()` in the Model section
strategy = tf.distribute.MirroredStrategy()  
logger.debug(f"{strategy=}")

DEBUG::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:001}::[2020-11-20::17:22:22.850]
tf.__version__='2.2.0'

INFO::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:002}::[2020-11-20::17:22:22.973]
Num GPUs Available: 2
This should be 2 on R790-TOMO.

DEBUG::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:003}::[2020-11-20::17:22:23.346]
Both here should return 2 devices...
tf.config.list_physical_devices('GPU')=[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]
tf.config.list_logical_devices('GPU')=[LogicalDevice(name='/device:GPU:0', device_type='GPU'), LogicalDevice(name='/device:GPU:1', device_type='GPU')]

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
DEBUG::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:011}::[2020-11-20::17:22:23.356]
strategy=<tensorflow.python.distribute.mirrored_s

# Model

In [6]:
from cnn_segm import keras_custom_loss
from tensorflow.keras import layers

In [7]:
# identify the trained model from its full name (master name, version, fold, runid)
tomo2seg_model = Tomo2SegModel.build_from_model_name(
    "unet-2d.vanilla00.000.1605-801-777"
)
logger.info(f"{tomo2seg_model=}")

# load and re-compile the network inside the strategy scope so that its variables are created
# under the MirroredStrategy
with strategy.scope():
    # compile=False: the saved training configuration is not restored; the model is compiled again below
    model = tf.keras.models.load_model(
        tomo2seg_model.autosaved_model_path_str,
        compile=False
    )
    
    in_ = model.layers[0]
    in_shape = in_.input_shape[0]

    logger.debug(f"{(input_n_channels := in_shape[-1:])=}")

    # NOTE(review): this Input is NOT wired into `model`. Assigning it into `model.layers[0]` would be
    # a no-op, because `Model.layers` returns a freshly built list, so no such assignment is made here.
    # Actually replacing the input requires rebuilding the model on top of a new Input. Note also that
    # this shape has 3 spatial axes, while the 2-D model is fed (1, x, y, 1) batches later on.
    anysize_input = layers.Input(
        shape=[None, None, None] + list(input_n_channels),
        name="input_any_image_size"
    )
    
    # todo keep this somewhere instead of copying and pasting
    optimizer = optimizers.Adam()
    loss_func = keras_custom_loss.jaccard2_loss

    model.compile(loss=loss_func, optimizer=optimizer)


INFO::tomo2seg::{<ipython-input-7-d9d7ebfa9ca4>:<module>:004}::[2020-11-20::17:22:23.471]
tomo2seg_model=Model(master_name='unet-2d', version='vanilla00', fold=0, runid=1605801777, factory_function=None, factory_kwargs=None)

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 the

# Data

In [8]:
from tomo2seg.datasets import (
    VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
    VOLUME_COMPOSITE_V1_LABELS_REFINED3 as LABELS_VERSION
)

# unpack the dataset identifiers: VOLUME_NAME_VERSION is a (name, version) pair, LABELS_VERSION is used as is
(volume_name, volume_version), labels_version = VOLUME_NAME_VERSION, LABELS_VERSION

logger.info(f"{volume_name=} {volume_version=} {labels_version=}")

INFO::tomo2seg::{<ipython-input-8-e97f5958400b>:<module>:009}::[2020-11-20::17:22:27.016]
volume_name='PA66GF30' volume_version='v1' labels_version='refined3'



In [9]:
# Metadata/paths objects

## Volume
# build the Volume handle; judging by the cell output, `with_check` logs missing companion files
# (e.g. labels/weights) as errors and execution carries on
volume = Volume.with_check(
    name=volume_name, version=volume_version
)
logger.info(f"{volume=}")

def _read_raw(path_: Path, volume_: Volume): 
    """Read a raw volume file with pymicro's ``HST_read``.

    The dtype and dimensions come from ``volume_``'s metadata. The function previously
    read them from the global ``volume``, so the ``volume_`` argument was silently ignored.

    Args:
        path_: path of the raw file to read.
        volume_: volume whose ``metadata.dtype`` and ``metadata.dimensions`` describe the file.

    Returns:
        Whatever ``file_utils.HST_read`` returns (presumably a numpy array of shape
        ``volume_.metadata.dimensions``).
    """
    metadata = volume_.metadata
    return file_utils.HST_read(
        str(path_),  # it doesn't accept paths...
        # pre-loaded kwargs
        autoparse_filename=False,  # the file names are not properly formatted
        data_type=metadata.dtype,
        dims=metadata.dimensions,
        verbose=True,
    )

# bind the volume once so the loader only needs a file path
read_raw = partial(_read_raw, volume_=volume)

logger.info("Loading data from disk.")

## Data
# grey levels are divided by 255 to bring them into [0, 1]
voldata = read_raw(volume.data_path) / 255  # normalize
logger.debug(f"{voldata.shape=}")

voldata_train, voldata_val, voldata_test = (
    getattr(volume, f"{split}_partition").get_volume_partition(voldata)
    for split in ("train", "val", "test")
)

# drop the name of the full array (memory is only released if the partitions do not keep views on it)
del voldata
logger.debug(f"{voldata_train.shape=} {voldata_val.shape=} {voldata_test.shape=}")

## Labels
vollabels = read_raw(volume.versioned_labels_path(labels_version))
logger.debug(f"{vollabels.shape=}")

vollabels_train, vollabels_val, vollabels_test = (
    getattr(volume, f"{split}_partition").get_volume_partition(vollabels)
    for split in ("train", "val", "test")
)

del vollabels
logger.debug(f"{vollabels_train.shape=} {vollabels_val.shape=} {vollabels_test.shape=}")

DEBUG::tomo2seg::{data.py:with_check:220}::[2020-11-20::17:22:27.079]
vol=Volume(name='PA66GF30', version='v1', _metadata=None)

ERROR::tomo2seg::{data.py:with_check:238}::[2020-11-20::17:22:27.081]
Missing file: /home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.v1/PA66GF30.v1.labels.raw

Missing file: /home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.v1/PA66GF30.v1.weights.raw

DEBUG::tomo2seg::{data.py:metadata:177}::[2020-11-20::17:22:27.083]
Loading metadata from `/home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.v1/PA66GF30.v1.metadata.yml`.

INFO::tomo2seg::{<ipython-input-9-014320ca7357>:<module>:007}::[2020-11-20::17:22:27.088]
volume=Volume(name='PA66GF30', version='v1', _metadata=Volume.Metadata(dimensions=[1300, 1040, 1900], dtype='uint8', labels=[0, 1, 2], labels_names={0: 'matrix', 1: 'fiber', 2: 'porosity'}, set_partitions={'train': {'x_range': [0, 1300], 'y_range': [0, 1040], 'z_range': [0, 1300], 'alias': 'train'}, 'val': {'x_range': [0, 1300], 'y_rang

# Crop generator (not yet integrated)

In [10]:
from numpy import ndarray

# Processing

# Here

In [11]:
from enum import Enum

In [12]:
class AggregationStrategy(Enum):
    """Identifies how overlapping probabilities (voxels covered by more than one crop) are aggregated."""

    # presumably: average the class probabilities coming from every crop that covers a voxel
    average_probabilities = 0


# chosen by its integer value (0 == average_probabilities)
agg_strategy = AggregationStrategy(0)

In [13]:
import re

In [14]:
def get_seconds_2by2(now: float) -> str:
    """Format the integer part of a unix timestamp as dash-separated digit groups.

    The groups are the first 4 digits, the next 3, and the remainder,
    e.g. ``1605889360.5 -> "1605-889-360"``.
    """
    digits = str(int(now))
    head, middle, tail = digits[:4], digits[4:7], digits[7:]
    return "-".join((head, middle, tail))

exec_time = time.time()

In [15]:
# Which partition of the volume to process: "train", "val" or "test".
# A single switch drives the data, the execution id and the partition object together.
# Previously the alternatives were commented-out lines, and the train/val ones never updated
# `partition`, so switching splits would have mislabeled the outputs.
split_name = "test"

data_volume = {"train": voldata_train, "val": voldata_val, "test": voldata_test}[split_name]
execid = f"{split_name}-" + get_seconds_2by2(exec_time)
partition = getattr(volume, f"{split_name}_partition")

logger.info(f"{execid=}")

INFO::tomo2seg::{<ipython-input-15-82ef2028ef43>:<module>:011}::[2020-11-20::17:22:40.548]
execid='test-1605-889-360'



In [16]:
from tomo2seg.data import EstimationVolume

In [17]:
# handle on the outputs of this model applied to this partition of the volume
estimation_volume = EstimationVolume.from_objects(
    volume=volume,
    model=tomo2seg_model,
    set_partition=partition,
)

# record when (exec_time) and under which id (execid) this run happened, in the estimation's metadata file
run_metadata = {"exec_time": exec_time, "execid": execid}
for k, v in run_metadata.items():
    estimation_volume.write_metadata(k, v)

DEBUG::tomo2seg::{data.py:write_metadata:305}::[2020-11-20::17:22:40.682]
Writing to metadata file at `/home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.v1/vol=PA66GF30.v1.set=test.model=unet-2d.vanilla00.000.1605-801-777.metadata.yml`

DEBUG::tomo2seg::{data.py:write_metadata:305}::[2020-11-20::17:22:40.708]
Writing to metadata file at `/home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.v1/vol=PA66GF30.v1.set=test.model=unet-2d.vanilla00.000.1605-801-777.metadata.yml`



In [18]:
# this will later be useful when i transform this in python script
save_probas_by_class = True

debug__save_figs = True
debug__save_crops_coordinates = True
debug__materialize_crops = False
# figs_dir = volume.volume_processing_dir(execid)  # todo delete this method
figs_dir = estimation_volume.dir

if debug__save_figs:
    logger.debug(f"Creating debug figs directory: {figs_dir=}")
    figs_dir.mkdir(exist_ok=True)

probabilities_dtype = np.float16

DEBUG::tomo2seg::{<ipython-input-18-9f715a04b206>:<module>:011}::[2020-11-20::17:22:40.882]
Creating debug figs directory: figs_dir=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.v1.set=test.model=unet-2d.vanilla00.000.1605-801-777')



In [19]:
import functools

In [20]:
def crop_coord2data(coordinates: ndarray, data: ndarray) -> ndarray:
    """Cut the box described by ``coordinates`` out of ``data``.

    Args:
        coordinates: 3x2 array of ``(start, end)`` pairs for the x, y and z axes (end excluded).
        data: W x H x D array.

    Returns:
        ``data[x0:x1, y0:y1, z0:z1]`` (a view, as with any basic numpy slicing).
    """
    (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = coordinates
    window = (slice(x_lo, x_hi), slice(y_lo, y_hi), slice(z_lo, z_hi))
    return data[window]


# crop extractor bound to the partition currently being processed
crop_coord2data__data_loaded = partial(crop_coord2data, data=data_volume)

In [21]:
import copy

# Shapes and steps 

In [22]:
# Each crop is one xy-slice (z-thickness 1), trimmed in x and y to a multiple of 16,
# because of the 4 cascaded 2x2-strided 2x2-downsamplings in u-net.
# The trimming uses the shape of the partition actually processed (`data_volume`), not the full volume's
# metadata, so a crop can never be larger than the data it is cut from.
volume_shape = data_volume.shape

xy_dims_multiple_16 = [int(dim) // 16 * 16 for dim in volume_shape[:2]]
logger.debug(f"{xy_dims_multiple_16=}")
assert all(dim > 0 for dim in xy_dims_multiple_16), "the volume must be at least 16 voxels wide in x and y"

crop_shape = tuple(xy_dims_multiple_16 + [1])  # x-axis, y-axis, z-axis

logger.debug(f"{crop_shape=}   {volume_shape=}")

# number of crops needed along each axis to cover the whole volume (crops may overlap)
n_steps = tuple(
    int(np.ceil(vol_dim / crop_dim))
    for vol_dim, crop_dim in zip(volume_shape, crop_shape)
)
logger.debug(f"{n_steps=}")

def get_coordinates_iterator(n_steps_):
    """Iterate over every integer index triple of a 3-D grid of crops.

    Args:
        n_steps_: number of crops along each of the 3 axes.

    Returns:
        An ``itertools.product`` iterator of ``(a, b, c)`` tuples with ``0 <= a < n_steps_[0]``
        (and likewise for b and c). The last index varies fastest.

    Raises:
        AssertionError: if ``n_steps_`` does not have exactly 3 entries.
    """
    # Imported locally: the notebook's `import itertools` lives in a later cell, so relying on the global
    # only works if that cell happened to run before this function is first called.
    import itertools

    assert len(n_steps_) == 3
    return itertools.product(*(range(n_steps_[dim]) for dim in range(3)))

# zero-argument shortcuts over this volume's crop grid:
#   get_ijk_iterator() yields (i, j, k) with k varying fastest;
#   get_kji_iterator() yields (k, j, i) with i varying fastest.
get_ijk_iterator = functools.partial(get_coordinates_iterator, copy.copy(n_steps))
get_kji_iterator = functools.partial(get_coordinates_iterator, n_steps[::-1])

# coordinates (xs, ys, and zs) of the front upper left corners of the crops:
# per axis, n evenly spaced starts from 0 to (volume dim - crop dim), truncated to int,
# so the first and last crops touch the two borders of the volume
x0s, y0s, z0s = (
    tuple(int(start) for start in np.linspace(0, vol_dim - crop_dim, n))
    for vol_dim, crop_dim, n in zip(volume_shape, crop_shape, n_steps)
)
logger.debug(f"""{min(x0s)=}, {max(x0s)=}, {len(x0s)=}
{min(y0s)=}, {max(y0s)=}, {len(y0s)=}
{min(z0s)=}, {max(z0s)=}, {len(z0s)=}
""")

DEBUG::tomo2seg::{<ipython-input-22-0d701cd4dcc5>:<module>:003}::[2020-11-20::17:22:41.186]
xy_dims_multiple_16=[1296, 1040]

DEBUG::tomo2seg::{<ipython-input-22-0d701cd4dcc5>:<module>:008}::[2020-11-20::17:22:41.187]
crop_shape=(1296, 1040, 1)   volume_shape=(1300, 1040, 300)

DEBUG::tomo2seg::{<ipython-input-22-0d701cd4dcc5>:<module>:014}::[2020-11-20::17:22:41.189]
n_steps=(2, 1, 300)

DEBUG::tomo2seg::{<ipython-input-22-0d701cd4dcc5>:<module>:036}::[2020-11-20::17:22:41.191]
min(x0s)=0, max(x0s)=4, len(x0s)=2
min(y0s)=0, max(y0s)=0, len(y0s)=1
min(z0s)=0, max(z0s)=299, len(z0s)=300




# Orthogonal slices figs

In [23]:
# PNG text metadata meant to be attached to every saved debug figure
figs_common_metadata = {"Author": "joaopcbertoldo", "CreationTime": str(int(exec_time)), "Software": "tomo2seg"}
# `savefig` keyword arguments shared by every saved debug figure
figs_common_kwargs = {"bbox_inches": "tight", "format": "png"}

In [24]:
if debug__save_figs:
    
    logger.debug(f"Saving figure {(fig_name := 'whole-volume.orthogonal-slices.png')=}")
    fig, axs = plt.subplots(2, 2, figsize=(sz := 20, sz))
    viz.plot_orthogonal_slices(axs, data_volume, normalized_voxels=True);
    
    figpath = figs_dir / fig_name
    logger.info(f"{figpath=}")
    fig.savefig(
        fname=figpath,
        dpi=200,
        **figs_common_kwargs,
        # PNG text metadata: per-figure title + the shared author/creation-time/software entries
        # (the shared *metadata* dict belongs here, not the savefig kwargs)
        metadata={
            "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
            **figs_common_metadata,
        }
    )
    plt.close(fig)

    
if debug__save_figs:
    logger.debug(f"Saving figure {(fig_name := 'whole-volume.orthogonal-slices-with-(x0s, y0s, z0s).png')=}")
    fig, axs = plt.subplots(2, 2, figsize=(sz := 20, sz))
    viz.plot_orthogonal_slices(axs, data_volume, normalized_voxels=True)

    ax_xy, ax_yz, ax_xz = axs[0, 0], axs[0, 1], axs[1, 0]
    
    # overlay the crop start positions: x0s in green, y0s in red, z0s in blue
    # NOTE(review): each line's extent should be the size of the panel's *other* axis, but e.g. the x0 vlines
    # on the xy panel span volume_shape[0] (the x size) -- confirm against how viz.plot_orthogonal_slices
    # orients the slices
    for x_ in x0s:
        ax_xy.vlines(x_, 0, volume_shape[0] - 1, color='g', linewidth=1)
        ax_xz.vlines(x_, 0, volume_shape[2] - 1, color='g', linewidth=1)

    for y_ in y0s:
        ax_xy.hlines(y_, 0, volume_shape[1] - 1, color='r', linewidth=1)
        ax_yz.hlines(y_, 0, volume_shape[2] - 1, color='r', linewidth=1)

    for z_ in z0s:
        ax_yz.vlines(z_, 0, volume_shape[0] - 1, color='b', linewidth=0.2)    
        ax_xz.hlines(z_, 0, volume_shape[1] - 1, color='b', linewidth=0.2)
    
    figpath = figs_dir / fig_name
    logger.info(f"{figpath=}")
    fig.savefig(
        fname=figpath,
        dpi=200,
        **figs_common_kwargs,
        # per-figure title + the shared PNG metadata entries
        metadata={
            "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
            **figs_common_metadata,
        }
    )
    plt.close(fig)

DEBUG::tomo2seg::{<ipython-input-24-238f0b2e2bd3>:<module>:003}::[2020-11-20::17:22:41.319]
Saving figure (fig_name := 'whole-volume.orthogonal-slices.png')='whole-volume.orthogonal-slices.png'

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:016}::[2020-11-20::17:22:41.442]
volume.shape=(1300, 1040, 300)

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:020}::[2020-11-20::17:22:41.444]
vmin, vmax=(0, 1)

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:025}::[2020-11-20::17:22:41.445]
No label mask given.

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:037}::[2020-11-20::17:22:41.446]
xy_z_coord, yz_x_coord, xz_y_coord=(150, 520, 650)

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:044}::[2020-11-20::17:22:41.467]
xy_slice.shape, yz_slice.shape, xz_slice.shape=((1300, 1040), (1040, 300), (1300, 300))

INFO::tomo2seg::{<ipython-input-24-238f0b2e2bd3>:<module>:008}::[2020-11-20::17:22:41.469]
figpath=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.v1.set=test.mode

# Crops coordinates 

In [26]:
import itertools

In [28]:
logger.debug("Generating the crop coordinates.")

# one ((x0, x1), (y0, y1), (z0, z1)) box per crop, indexed by the crop's (i, j, k) position on the grid;
# itertools.product varies z fastest, which matches the C-order reshape below
crops_coordinates = np.array(
    [
        (
            (x0, x0 + crop_shape[0]), 
            (y0, y0 + crop_shape[1]),
            (z0, z0 + crop_shape[2]),
        )
        for x0, y0, z0 in itertools.product(x0s, y0s, z0s)
    ], 
    # build an integer array directly (dtype=tuple created an object array that then had to be cast back)
    dtype=int,
).reshape(len(x0s), len(y0s), len(z0s), 3, 2)  # 3 = nb of dimensions, 2 = (start, end)

logger.debug(f"{crops_coordinates.shape=}\n{crops_coordinates[0, 0, 0]=} ")

if debug__save_crops_coordinates:
    logger.debug(f"Saving crops coordinates at {(coords_fname := volume.dir / f'process-volume.execution={execid}.crops-coordinates.npy')=}")
    np.save(coords_fname, crops_coordinates)

# 'F' reshapes with x varying fastest and z slowest; the trailing (3, 2) axes are the slowest-varying ones
# in Fortran order and keep their sizes, so every (3, 2) coordinate box stays intact
crops_coordinates_sequential = crops_coordinates.reshape(-1, 3, 2, order='F')

logger.debug(f"{crops_coordinates_sequential.shape=}\n{crops_coordinates_sequential[0]=} ")


DEBUG::tomo2seg::{<ipython-input-28-090b16eecc98>:<module>:001}::[2020-11-20::17:25:06.745]
Generating the crop coordinates.

DEBUG::tomo2seg::{<ipython-input-28-090b16eecc98>:<module>:015}::[2020-11-20::17:25:06.751]
crops_coordinates.shape=(2, 1, 300, 3, 2)
crops_coordinates[0, 0, 0]=array([[   0, 1296],
       [   0, 1040],
       [   0,    1]]) 

DEBUG::tomo2seg::{<ipython-input-28-090b16eecc98>:<module>:018}::[2020-11-20::17:25:06.752]
Saving crops coordinates at (coords_fname := volume.dir / f'process-volume.execution={execid}.crops-coordinates.npy')=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/PA66GF30.v1/process-volume.execution=test-1605-889-360.crops-coordinates.npy')

DEBUG::tomo2seg::{<ipython-input-28-090b16eecc98>:<module>:023}::[2020-11-20::17:25:06.799]
crops_coordinates_sequential.shape=(600, 3, 2)
crops_coordinates_sequential[0]=array([[   0, 1296],
       [   0, 1040],
       [   0,    1]]) 



# one-z-slice-crops-locations.png

Not kept here; search for `one-z-slice-crops-locations.png` in `process-3d-crops-entire-2d-slice`.

# debug__materialize_crops

Not kept here either; search for `debug__materialize_crops` in `process-3d-crops-entire-2d-slice`.

# Examples of 3d crops

In [33]:
from matplotlib import patches

In [36]:
if debug__save_figs:

    logger.debug(f"Plotinng {(n_crop_plots := 3)=} examples of 3d crops.")
    
    # get_kji_iterator yields (k, j, i) with i varying fastest: the first crops along x at the lowest z
    for n, (k, j, i) in enumerate(get_kji_iterator()):
                
        if n >= n_crop_plots:
            break
            
        logger.debug(f"{(ijk := (i, j, k))=}")
        
        one_crop = crop_coord2data__data_loaded(crops_coordinates[i, j, k])
        logger.debug(f"{one_crop.shape=}")
        
        logger.debug(f"Saving figure {(fig_name := f'crop-{ijk}.orthogonal-slices.png')=}")    

        fig, axs = plt.subplots(
            nrows=2, ncols=2,
            figsize=(sz := 20, sz), 
            gridspec_kw={"wspace": (gridspace := .01), "hspace": .5 * gridspace}
        )
        for ax in axs.ravel():
            ax.axis("off")
            
        viz.plot_orthogonal_slices(axs, one_crop, normalized_voxels=True)
        
        figpath = figs_dir / fig_name
        logger.info(f"{figpath=}")
        fig.savefig(
            fname=figpath,
            dpi=200,
            **figs_common_kwargs,
            # per-figure title + the shared PNG metadata entries (not the savefig kwargs)
            metadata={
                "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
                **figs_common_metadata,
            }
        )
        # close this specific figure so the loop does not accumulate open figures
        plt.close(fig)

DEBUG::tomo2seg::{<ipython-input-36-7709108d4013>:<module>:003}::[2020-11-20::17:34:22.919]
Plotinng (n_crop_plots := 3)=3 examples of 3d crops.

DEBUG::tomo2seg::{<ipython-input-36-7709108d4013>:<module>:010}::[2020-11-20::17:34:22.920]
(ijk := (i, j, k))=(0, 0, 0)

DEBUG::tomo2seg::{<ipython-input-36-7709108d4013>:<module>:013}::[2020-11-20::17:34:22.921]
one_crop.shape=(1296, 1040, 1)

DEBUG::tomo2seg::{<ipython-input-36-7709108d4013>:<module>:015}::[2020-11-20::17:34:22.922]
Saving figure (fig_name := f'crop-{ijk}.orthogonal-slices.png')='crop-(0, 0, 0).orthogonal-slices.png'

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:016}::[2020-11-20::17:34:22.993]
volume.shape=(1296, 1040, 1)

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:020}::[2020-11-20::17:34:22.994]
vmin, vmax=(0, 1)

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:025}::[2020-11-20::17:34:22.995]
No label mask given.

DEBUG::tomo2seg::{viz.py:plot_orthogonal_slices:037}::[2020-11-20::17:34:22.996]
xy_z_coord, yz_x_

# Segment an example

In [51]:
logger.debug(f"Segmenting one crop for debug {(crop_ijk := (0, 0, 0))=}")

# indexing with the (i, j, k) tuple selects the crop's (3, 2) coordinate box
crop_coordinates = crops_coordinates[crop_ijk]
crop_data = crop_coord2data__data_loaded(crop_coordinates)
    
logger.debug(f"{crop_data.shape=}")

# [model] - a single crop is run first so that, if something goes wrong, the error
# shows up here instead of inside a loop

# modelin: a batch of one single-channel 2-D image, i.e. (1, x, y, 1)
modelin = crop_data.reshape((1, *crop_shape[:2], 1))
logger.debug(f"{modelin.shape=}")

# modelout
modelout = model.predict(modelin, batch_size=1)
logger.debug(f"{modelout.shape=}")

logger.debug(f"{(n_classes := modelout.shape[-1])=}")

# probas: back to the crop's (x, y, z) shape with one trailing axis per class
logger.debug(f"{(crop_probas_target_shape := list(crop_shape) + [n_classes])=}")
crop_probas = modelout.reshape(crop_probas_target_shape).astype(probabilities_dtype)

# preds: index of the most probable class for each voxel
crop_preds = np.argmax(crop_probas, axis=-1)

DEBUG::tomo2seg::{<ipython-input-51-8f0493f8090b>:<module>:001}::[2020-11-20::17:42:07.779]
Segmenting one crop for debug (crop_ijk := (0, 0, 0))=(0, 0, 0)

DEBUG::tomo2seg::{<ipython-input-51-8f0493f8090b>:<module>:006}::[2020-11-20::17:42:07.781]
crop_data.shape=(1296, 1040, 1)

DEBUG::tomo2seg::{<ipython-input-51-8f0493f8090b>:<module>:013}::[2020-11-20::17:42:07.782]
modelin.shape=(1, 1296, 1040, 1)

DEBUG::tomo2seg::{<ipython-input-51-8f0493f8090b>:<module>:017}::[2020-11-20::17:42:08.069]
modelout.shape=(1, 1296, 1040, 3)

DEBUG::tomo2seg::{<ipython-input-51-8f0493f8090b>:<module>:019}::[2020-11-20::17:42:08.070]
(n_classes := modelout.shape[-1])=3

DEBUG::tomo2seg::{<ipython-input-51-8f0493f8090b>:<module>:022}::[2020-11-20::17:42:08.071]
(crop_probas_target_shape := list(crop_shape) + [n_classes])=[1296, 1040, 1, 3]



In [53]:
if debug__save_figs:
    logger.debug(f"Saving figure {(fig_name := f'crop-{crop_ijk}.prediction.png')=}")    
    
    # input crop and predicted class map side by side
    fig, axs = plt.subplots(1, 2, figsize=(sz := 20, 2 * sz))
    fig.set_tight_layout(True)

    axs[0].imshow(crop_data, vmin=0, vmax=1, cmap=cm.gray)
    axs[0].set_title("data")
    
    axs[1].imshow(crop_preds, vmin=0, vmax=n_classes-1, cmap=cm.gray)
    axs[1].set_title("prediction (classes)")

    figpath = figs_dir / fig_name
    logger.info(f"{figpath=}")
    fig.savefig(
        fname=figpath,
        dpi=200,
        **figs_common_kwargs,
        # per-figure title + the shared PNG metadata entries (not the savefig kwargs)
        metadata={
            "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
            **figs_common_metadata,
        }
    )
    plt.close(fig)

DEBUG::tomo2seg::{<ipython-input-53-e9f4a7193199>:<module>:002}::[2020-11-20::17:42:25.318]
Saving figure (fig_name := f'crop-{crop_ijk}.prediction.png')='crop-(0, 0, 0).prediction.png'

INFO::tomo2seg::{<ipython-input-53-e9f4a7193199>:<module>:014}::[2020-11-20::17:42:25.370]
figpath=PosixPath('/home/users/jcasagrande/projects/tomo2seg/data/vol=PA66GF30.v1.set=test.model=unet-2d.vanilla00.000.1605-801-777/crop-(0, 0, 0).prediction.png')

