In [1]:
%load_ext autoreload

In [15]:
%autoreload 2

from functools import partial
import logging
import pathlib
from pathlib import Path
from pprint import pprint
import sys
from typing import *
import time
import yaml
from yaml import YAMLObject
import copy
import functools

import humanize
from matplotlib import pyplot as plt, cm
import numpy as np
from numpy import ndarray
import pandas as pd
from pymicro.file import file_utils
import tensorflow as tf
from numpy.random import RandomState
from progressbar import progressbar as pbar
from enum import Enum
import re
from enum import Enum

from tensorflow import keras
from tensorflow.keras import utils
from tensorflow.keras import optimizers
from tensorflow.keras import callbacks
from tensorflow.keras import losses
from tensorflow.keras import layers

from cnn_segm import keras_custom_loss

from tomo2seg import modular_unet
from tomo2seg.logger import logger
from tomo2seg import data, viz
from tomo2seg.data import Volume
from tomo2seg.metadata import Metadata
from tomo2seg.volume_sequence import (
    VolumeCropSequence, MetaCrop3DGenerator, VSConstantEverywhere, 
    GTConstantEverywhere, SequentialGridPosition, ET3DConstantEverywhere
)
from tomo2seg import volume_sequence
from tomo2seg.model import Model as Tomo2SegModel
from tomo2seg.data import EstimationVolume
from tomo2seg import AggregationStrategy

# Setup

In [3]:
logger.setLevel(logging.DEBUG)

In [4]:
random_state = 42
random_state = np.random.RandomState(random_state)
runid = int(time.time())
logger.info(f"{runid=}")

INFO::tomo2seg::{<ipython-input-4-cf972d05bc84>:<module>:004}::[2020-11-22::15:47:53.304]
runid=1606056473



In [5]:
logger.debug(f"{tf.__version__=}")
logger.info(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}\nThis should be 2 on R790-TOMO.")
logger.debug(f"Both here should return 2 devices...\n{tf.config.list_physical_devices('GPU')=}\n{tf.config.list_logical_devices('GPU')=}")

# xla auto-clustering optimization (see: https://www.tensorflow.org/xla#auto-clustering)
# this seems to break the training
tf.config.optimizer.set_jit(False)

# get a distribution strategy to use both gpus (see https://www.tensorflow.org/guide/distributed_training)
strategy = tf.distribute.MirroredStrategy()  
logger.debug(f"{strategy=}")

DEBUG::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:001}::[2020-11-22::15:47:53.353]
tf.__version__='2.2.0'

INFO::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:002}::[2020-11-22::15:47:53.355]
Num GPUs Available: 0
This should be 2 on R790-TOMO.

DEBUG::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:003}::[2020-11-22::15:47:53.384]
Both here should return 2 devices...
tf.config.list_physical_devices('GPU')=[]
tf.config.list_logical_devices('GPU')=[]

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
DEBUG::tomo2seg::{<ipython-input-5-9df7dd5953d5>:<module>:011}::[2020-11-22::15:47:53.389]
strategy=<tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f21ae84cee0>



# Options

In [13]:
# this will later be useful when i transform this in python script
save_probas_by_class = True

debug__save_figs = True
debug__materialize_crops = False
probabilities_dtype = np.float16

# Model

In [6]:
ls ../data/models | grep unet-2d-small

unet-2d-small.vanilla00.000.1605-971-456.autosaved.hdf5
[01;34munet-2d-small.vanilla00.000.1605-972-712[0m/
unet-2d-small.vanilla00.000.1605-972-712.autosaved.hdf5
[01;34munet-2d-small.vanilla00.000.1605-982-196[0m/
unet-2d-small.vanilla00.000.1605-982-196.autosaved.hdf5


In [7]:
tomo2seg_model = Tomo2SegModel.build_from_model_name(
    "unet-2d-small.vanilla00.000.1605-982-196"
)
logger.info(f"{tomo2seg_model=}")

with strategy.scope():
    model = tf.keras.models.load_model(
        tomo2seg_model.autosaved_model_path_str,
        compile=False
    )
    
    in_ = model.layers[0]
    in_shape = in_.input_shape[0]
    input_n_channels = in_shape[-1:]

    logger.debug(f"{input_n_channels=}")
    
    # make it capable of getting any dimension in the input
    anysize_input = layers.Input(
        shape=[None, None, None] + list(input_n_channels),
        name="input_any_image_size"
    )
    
    logger.debug(f"{anysize_input=}")
    
    model.layers[0] = anysize_input
    
    # todo keep this somewhere instead of copying and pasting
    optimizer = optimizers.Adam()
    loss_func = keras_custom_loss.jaccard2_loss

    model.compile(loss=loss_func, optimizer=optimizer)


INFO::tomo2seg::{<ipython-input-7-d7942cfb89c4>:<module>:004}::[2020-11-22::15:47:53.658]
tomo2seg_model=Model(master_name='unet-2d-small', version='vanilla00', fold=0, runid=1605982196, factory_function=None, factory_kwargs=None)

DEBUG::tomo2seg::{<ipython-input-7-d7942cfb89c4>:<module>:016}::[2020-11-22::15:47:55.357]
input_n_channels=(1,)

DEBUG::tomo2seg::{<ipython-input-7-d7942cfb89c4>:<module>:024}::[2020-11-22::15:47:55.360]
anysize_input=<tf.Tensor 'input_any_image_size:0' shape=(None, None, None, None, 1) dtype=float32>



# Data

In [8]:
from tomo2seg.datasets import (
#     VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
    VOLUME_COMPOSITE_V1_REDUCED as VOLUME_NAME_VERSION,
    VOLUME_COMPOSITE_V1_LABELS_REFINED3 as LABELS_VERSION
)

volume_name, volume_version = VOLUME_NAME_VERSION
labels_version = LABELS_VERSION

logger.info(f"{volume_name=} {volume_version=} {labels_version=}")

INFO::tomo2seg::{<ipython-input-8-cdb1781c8ebd>:<module>:010}::[2020-11-22::15:47:55.422]
volume_name='PA66GF30' volume_version='v1-reduced' labels_version='refined3'



In [9]:
# Metadata/paths objects

## Volume
volume = Volume.with_check(
    name=volume_name, version=volume_version
)
logger.info(f"{volume=}")

def _read_raw(path_: Path, volume_: Volume): 
    # from pymicro
    return file_utils.HST_read(
        str(path_),  # it doesn't accept paths...
        # pre-loaded kwargs
        autoparse_filename=False,  # the file names are not properly formatted
        data_type=volume.metadata.dtype,
        dims=volume.metadata.dimensions,
        verbose=True,
    )

read_raw = partial(_read_raw, volume_=volume)

logger.info("Loading data from disk.")

## Data
voldata = read_raw(volume.data_path) / 255  # normalize
logger.debug(f"{voldata.shape=}")

voldata_train = volume.train_partition.get_volume_partition(voldata)
voldata_val = volume.val_partition.get_volume_partition(voldata)
voldata_test = volume.test_partition.get_volume_partition(voldata)

del voldata
logger.debug(f"{voldata_train.shape=} {voldata_val.shape=} {voldata_test.shape=}")

DEBUG::tomo2seg::{data.py:with_check:227}::[2020-11-22::15:47:55.491]
vol=Volume(name='PA66GF30', version='v1-reduced', _metadata=None)

ERROR::tomo2seg::{data.py:with_check:245}::[2020-11-22::15:47:55.492]
Missing file: /home/joaopcbertoldo/projects/tomo2seg/data/PA66GF30.v1-reduced/PA66GF30.v1-reduced.labels.raw

Missing file: /home/joaopcbertoldo/projects/tomo2seg/data/PA66GF30.v1-reduced/PA66GF30.v1-reduced.weights.raw

DEBUG::tomo2seg::{data.py:metadata:184}::[2020-11-22::15:47:55.495]
Loading metadata from `/home/joaopcbertoldo/projects/tomo2seg/data/PA66GF30.v1-reduced/PA66GF30.v1-reduced.metadata.yml`.

INFO::tomo2seg::{<ipython-input-9-127f893db5f4>:<module>:007}::[2020-11-22::15:47:55.503]
volume=Volume(name='PA66GF30', version='v1-reduced', _metadata=Volume.Metadata(dimensions=[256, 256, 256], dtype='uint8', labels=[0, 1, 2], labels_names={0: 'matrix', 1: 'fiber', 2: 'porosity'}, set_partitions={'train': {'x_range': [0, 256], 'y_range': [0, 256], 'z_range': [0, 128], 'alias'

# Crop generator (not yet integrated)

# Estimation Volume

In [10]:
# data_volume = voldata_train
# partition = volume.train_partition

# data_volume = voldata_val
# partition = volume.val_partition

data_volume = voldata_test
partition = volume.test_partition

agg_strategy = AggregationStrategy.average_probabilities

logger.debug(f"{data_volume.shape=} {partition=} {agg_strategy=} {runid=}")

DEBUG::tomo2seg::{<ipython-input-10-a4286ac1c483>:<module>:012}::[2020-11-22::15:47:55.677]
data_volume.shape=(256, 256, 64) partition=SetPartition(x_range=(0, 256), y_range=(0, 256), z_range=(192, 256), alias='test') agg_strategy=<AggregationStrategy.average_probabilities: 0> runid=1606056473



In [11]:
estimation_volume = EstimationVolume.from_objects(
    volume=volume, 
    model=tomo2seg_model, 
    set_partition=partition,
    runid=runid,
)
estimation_volume["aggregation_strategy"] = agg_strategy.name

logger.info(f"{estimation_volume=}")

DEBUG::tomo2seg::{data.py:metadata_path:290}::[2020-11-22::15:47:55.723]
Creating metadata file /home/joaopcbertoldo/projects/tomo2seg/data/vol=PA66GF30.v1-reduced.set=test.model=unet-2d-small.vanilla00.000.1605-982-196.runid=1606-056-473/vol=PA66GF30.v1-reduced.set=test.model=unet-2d-small.vanilla00.000.1605-982-196.runid=1606-056-473.metadata.yml.

DEBUG::tomo2seg::{data.py:__setitem__:298}::[2020-11-22::15:47:55.724]
Writing to file self.metadata_path=PosixPath('/home/joaopcbertoldo/projects/tomo2seg/data/vol=PA66GF30.v1-reduced.set=test.model=unet-2d-small.vanilla00.000.1605-982-196.runid=1606-056-473/vol=PA66GF30.v1-reduced.set=test.model=unet-2d-small.vanilla00.000.1605-982-196.runid=1606-056-473.metadata.yml').

INFO::tomo2seg::{<ipython-input-11-ccaae396c38c>:<module>:009}::[2020-11-22::15:47:55.727]
estimation_volume=EstimationVolume(volume_fullname='PA66GF30.v1-reduced', model_name='unet-2d-small.vanilla00.000.1605-982-196', runid=1606056473, partition=SetPartition(x_range=(0

# Processing

In [14]:
if debug__save_figs:
    figs_dir = estimation_volume.dir
    logger.debug(f"{figs_dir=}")
    figs_dir.mkdir(exist_ok=True)

DEBUG::tomo2seg::{<ipython-input-14-18204e9a3117>:<module>:003}::[2020-11-22::15:51:43.999]
figs_dir=PosixPath('/home/joaopcbertoldo/projects/tomo2seg/data/vol=PA66GF30.v1-reduced.set=test.model=unet-2d-small.vanilla00.000.1605-982-196.runid=1606-056-473')



# Shapes and steps 

In [16]:
# it has to be multiple of 16 because of the 4 cascaded 2x2-strided 2x2-downsamplings in u-net
xy_dims_multiple_16 = [int(16 * np.floor(dim / 16)) for dim in volume.metadata.dimensions[:2]]
logger.debug(f"{xy_dims_multiple_16=}")

crop_shape = tuple(xy_dims_multiple_16 + [1])  # x-axis, y-axis, z-axis
volume_shape = data_volume.shape

logger.debug(f"{crop_shape=}   {volume_shape=}")

n_steps = tuple(
    int(np.ceil(vol_dim / crop_dim))
    for vol_dim, crop_dim in zip(volume_shape, crop_shape)
)
logger.debug(f"{n_steps=}")

def get_coordinates_iterator(n_steps_):
    assert len(n_steps_) == 3
    return itertools.product(*(range(n_steps_[dim]) for dim in range(3)))

get_ijk_iterator = functools.partial(
    get_coordinates_iterator, copy.copy(n_steps)
)

get_kji_iterator = functools.partial(
    get_coordinates_iterator, tuple(reversed(n_steps))
)

# coordinates (xs, ys, and zs) of the front upper left corners of the crops
x0s, y0s, z0s = tuple(
    tuple(map(
        int, 
        np.linspace(0, vol_dim - crop_dim, n)
    ))
    for vol_dim, crop_dim, n in zip(volume_shape, crop_shape, n_steps)
)
logger.debug(f"""{min(x0s)=}, {max(x0s)=}, {len(x0s)=}
{min(y0s)=}, {max(y0s)=}, {len(y0s)=}
{min(z0s)=}, {max(z0s)=}, {len(z0s)=}
""")

DEBUG::tomo2seg::{<ipython-input-16-0d701cd4dcc5>:<module>:003}::[2020-11-22::15:54:23.530]
xy_dims_multiple_16=[256, 256]

DEBUG::tomo2seg::{<ipython-input-16-0d701cd4dcc5>:<module>:008}::[2020-11-22::15:54:23.532]
crop_shape=(256, 256, 1)   volume_shape=(256, 256, 64)

DEBUG::tomo2seg::{<ipython-input-16-0d701cd4dcc5>:<module>:014}::[2020-11-22::15:54:23.533]
n_steps=(1, 1, 64)

DEBUG::tomo2seg::{<ipython-input-16-0d701cd4dcc5>:<module>:036}::[2020-11-22::15:54:23.535]
min(x0s)=0, max(x0s)=0, len(x0s)=1
min(y0s)=0, max(y0s)=0, len(y0s)=1
min(z0s)=0, max(z0s)=63, len(z0s)=64




# Orthogonal slices figs

In [None]:
figs_common_kwargs = dict(bbox_inches="tight", format="png")

In [None]:
if debug__save_figs:
    
    logger.debug(f"Saving figure {(fig_name := 'whole-volume.orthogonal-slices.png')=}")
    fig, axs = plt.subplots(2, 2, figsize=(sz := 20, sz))
    viz.plot_orthogonal_slices(axs, data_volume, normalized_voxels=True);
    
    figpath = figs_dir / fig_name
    logger.info(f"{figpath=}")
    fig.savefig(
        fname=figpath,
        dpi=200,
        **figs_common_kwargs,
        metadata={
            "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
            **figs_common_kwargs
        }
    )
    plt.close()

    
if debug__save_figs:
    logger.debug(f"Saving figure {(fig_name := 'whole-volume.orthogonal-slices-with-(x0s, y0s, z0s).png')=}")
    fig, axs = plt.subplots(2, 2, figsize=(sz := 20, sz))
    viz.plot_orthogonal_slices(axs, data_volume, normalized_voxels=True)

    ax_xy, ax_yz, ax_xz = axs[0, 0], axs[0, 1], axs[1, 0]
    
    for x_ in x0s:
        ax_xy.vlines(x_, 0, volume_shape[0] - 1, color='g', linewidth=1)
        ax_xz.vlines(x_, 0, volume_shape[2] - 1, color='g', linewidth=1)

    for y_ in y0s:
        ax_xy.hlines(y_, 0, volume_shape[1] - 1, color='r', linewidth=1)
        ax_yz.hlines(y_, 0, volume_shape[2] - 1, color='r', linewidth=1)

    for z_ in z0s:
        ax_yz.vlines(z_, 0, volume_shape[0] - 1, color='b', linewidth=0.2)    
        ax_xz.hlines(z_, 0, volume_shape[1] - 1, color='b', linewidth=0.2)
    
    figpath = figs_dir / fig_name
    logger.info(f"{figpath=}")
    fig.savefig(
        fname=figpath,
        dpi=200,
        **figs_common_kwargs,
        metadata={
            "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
            **figs_common_metadata
        }
    )
    plt.close()

# Crops coordinates 

In [None]:
import itertools

In [None]:
logger.debug("Generating the crop coordinates.")

crops_coordinates = np.array(
    [
        (
            (x0, x0 + crop_shape[0]), 
            (y0, y0 + crop_shape[1]),
            (z0, z0 + crop_shape[2]),
        )
        for x0, y0, z0 in itertools.product(x0s, y0s, z0s)
    ], 
    dtype=tuple
).reshape(len(x0s), len(y0s), len(z0s), 3, 2).astype(int)  # 3 = nb of dimenstions, 2 = (start, end)

logger.debug(f"{crops_coordinates.shape=}\n{crops_coordinates[0, 0, 0]=} ")

if debug__save_crops_coordinates:
    logger.debug(f"Saving crops coordinates at {(coords_fname := volume.dir / f'process-volume.execution={execid}.crops-coordinates.npy')=}")
    np.save(coords_fname, crops_coordinates)

crops_coordinates_sequential = crops_coordinates.reshape(-1, 3, 2, order='F')  # 'F' reshapes with x varying fastest and z slowest

logger.debug(f"{crops_coordinates_sequential.shape=}\n{crops_coordinates_sequential[0]=} ")


# one-z-slice-crops-locations.png

not kept, search fro `one-z-slice-crops-locations.png` in `process-3d-crops-entire-2d-slice`

# debug__materialize_crops

same for
`debug__materialize_crops`

# Examples of 3d crops

In [None]:
from matplotlib import patches

In [None]:
if debug__save_figs:

    logger.debug(f"Plotinng {(n_crop_plots := 3)=} examples of 3d crops.")
    
    for n, (k, j, i) in enumerate(get_kji_iterator()):
                
        if n >= n_crop_plots:
            break
            
        logger.debug(f"{(ijk := (i, j, k))=}")
        
        one_crop = crop_coord2data__data_loaded(crops_coordinates[i, j, k])
        logger.debug(f"{one_crop.shape=}")
        
        logger.debug(f"Saving figure {(fig_name := f'crop-{ijk}.orthogonal-slices.png')=}")    

        fig, axs = plt.subplots(
            nrows=2, ncols=2,
            figsize=(sz := 20, sz), 
            gridspec_kw={"wspace": (gridspace := .01), "hspace": .5 * gridspace}
        )
        for ax in axs.ravel():
            ax.axis("off")
            
        viz.plot_orthogonal_slices(axs, one_crop, normalized_voxels=True)
        
        figpath = figs_dir / fig_name
        logger.info(f"{figpath=}")
        fig.savefig(
            fname=figpath,
            dpi=200,
            **figs_common_kwargs,
            metadata={
                "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
                **figs_common_kwargs
            }
        )
        plt.close()

# Segment an example

In [None]:
logger.debug(f"Segmenting one crop for debug {(crop_ijk := (0, 0, 0))=}")

crop_coordinates = crops_coordinates[crop_ijk[0], crop_ijk[1], crop_ijk[2]]
crop_data = crop_coord2data__data_loaded(crop_coordinates)
    
logger.debug(f"{crop_data.shape=}")

# [model] - i call it with a first crop bc if something goes wrong then the error
# will appear here instead of in a loop

# modelin
modelin = crop_data.reshape(1, crop_shape[0], crop_shape[1], 1) 
logger.debug(f"{modelin.shape=}")

# modelout
modelout = model.predict(modelin, batch_size=1)
logger.debug(f"{modelout.shape=}")

logger.debug(f"{(n_classes := modelout.shape[-1])=}")

# probas
logger.debug(f"{(crop_probas_target_shape := list(crop_shape) + [n_classes])=}")
crop_probas = modelout.reshape(crop_probas_target_shape).astype(probabilities_dtype)

# preds
crop_preds = crop_probas.argmax(axis=-1)

In [None]:
if debug__save_figs:
    logger.debug(f"Saving figure {(fig_name := f'crop-{crop_ijk}.prediction.png')=}")    
    
    fig, axs = plt.subplots(1, 2, figsize=(sz := 20, 2 * sz))
    fig.set_tight_layout(True)

    axs[0].imshow(crop_data, vmin=0, vmax=1, cmap=cm.gray)
    axs[0].set_title("data")
    
    axs[1].imshow(crop_preds, vmin=0, vmax=n_classes-1, cmap=cm.gray)
    axs[1].set_title("prediction (classes)")

    figpath = figs_dir / fig_name
    logger.info(f"{figpath=}")
    fig.savefig(
        fname=figpath,
        dpi=200,
        **figs_common_kwargs,
        metadata={
            "Title": f"vol={volume.fullname}::debug-fig::{fig_name}",
            **figs_common_kwargs
        }
    )
    plt.close()