In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import copy
import functools
import gc
import itertools
import logging
import operator
import os
import pathlib
import re
import socket
import sys
import time
from collections import Counter
from dataclasses import asdict, dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from pprint import PrettyPrinter, pprint
from typing import *

In [None]:
%autoreload 2

import humanize
import matplotlib
import numpy as np
import pandas as pd
import scipy as sp
import tensorflow as tf
import yaml
from matplotlib import cm, patches, pyplot as plt
from numpy import ndarray
from numpy.random import RandomState
from progressbar import progressbar as pbar
from pymicro.file import file_utils
from sklearn import metrics, metrics as met, model_selection, preprocessing
from tensorflow import keras
from tensorflow.keras import (
    callbacks as keras_callbacks,
    layers,
    losses,
    metrics as keras_metrics,
    optimizers,
    utils,
)
from tqdm import tqdm
from yaml import YAMLObject

In [None]:
%autoreload 2

from tomo2seg import (
    callbacks as tomo2seg_callbacks,
    data as tomo2seg_data,
    losses as tomo2seg_losses,
    schedule as tomo2seg_schedule,
    slack,
    slackme,
    utils as tomo2seg_utils,
    viz as tomo2seg_viz,
    volume_sequence,
)
from tomo2seg.data import EstimationVolume, Volume
from tomo2seg.logger import add_file_handler, dict2str, logger
from tomo2seg.model import Model as Tomo2SegModel
from tomo2seg.analyse_gt import AnalyseGroundTruthMetaArgs as MetaArgs
from tomo2seg.analyse_gt import AnalyseGroundTruthOuputs as Outputs
from tomo2seg import analyse_gt
from tomo2seg.outputs import BaseOutputs, mkdir_ok

In [None]:
# Register `slackme.custom_exc` as IPython's custom handler for every `Exception`
# raised while running this notebook's cells.
# NOTE(review): presumably the handler forwards the error to Slack -- confirm in `slackme`.
get_ipython().set_custom_exc((Exception,), slackme.custom_exc)

# MetaArgs

In [None]:
# [manual-input]
# Pick the volume and the labels version to analyse by aliasing the corresponding
# constants; the rest of the notebook only refers to the aliases.
from tomo2seg.datasets import (
    VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
    VOLUME_COMPOSITE_V1_LABELS_REFINED3 as LABELS_VERSION,
)

meta_args = MetaArgs(
    
    # VOLUME_NAME_VERSION is indexed as (name, version)
    volume_name=VOLUME_NAME_VERSION[0],
    volume_version=VOLUME_NAME_VERSION[1],
    labels_version=LABELS_VERSION,
    
#     partitions_to_compute="...",  # default: all
    partitions_to_compute=("train", "val", "test"),
    
    script_name="analyse-gt-00.ipynb",
    
    host=None,  # None = auto
    runid=None,  # None = auto
    random_state_seed=42,  # None = auto
)

# `tomo2seg` objects 

In [None]:
# Resolve the volume selected in `meta_args`.
# NOTE(review): `with_check` presumably validates that the volume exists -- confirm.
volume = Volume.with_check(
    name=meta_args.volume_name, version=meta_args.volume_version
)

# Dump the volume's fields to the log (the prefix used to be misspelled "volum=").
logger.info(f"volume=\n{dict2str(asdict(volume))}")

# Args

In [None]:
# Partitions to analyse, after validation against the volume by `analyse_gt`.
partitions_to_compute = analyse_gt.validate_partitions_to_compute(
    meta_args.partitions_to_compute, volume
)
logger.info(f"{partitions_to_compute=}")

# Run-level arguments copied from the meta-arguments.
random_state_seed = meta_args.random_state_seed
random_state = np.random.RandomState(random_state_seed)
runid = meta_args.runid

volume_name = volume.fullname

# The reader used in the loading cell is given plain strings, hence the `str(...)`.
data_path = str(volume.data_path)
data_meta = dict(
    dtype=volume.metadata.dtype,
    dims=volume.metadata.dimensions,
)
labels_path = str(volume.versioned_labels_path(meta_args.labels_version))

# One indexing object per partition, applied to both the data and labels volumes.
partition_slices = {
    partition_name: analyse_gt.partition2slice(volume[partition_name])
    for partition_name in partitions_to_compute
}

# Label indices, their names (in the same order) and the index -> name mapping.
labels_idx = volume.metadata.labels
labels_names = [volume.metadata.labels_names[label] for label in labels_idx]
labels_idx_name = dict(zip(labels_idx, labels_names))
n_classes = len(labels_idx)

# Values of `nlayers` for which the adjacent layers correlation is computed.
adjacent_layers_correlation_nlayer_arg_vals = (1,)

# Everything this notebook saves goes under this directory.
outputs_dir = volume.dir / "ground-truth-analysis"  # todo move me to the volume obs
outputs_dir.mkdir(exist_ok=True)

# Outputs

In [None]:
# `Outputs` is the alias under which `AnalyseGroundTruthOuputs` is imported at the
# top of the notebook; the former spelling `Ouputs` is not defined anywhere and
# raised a NameError on a fresh kernel.
outputs = Outputs(outputs_dir)

# Setup

In [None]:
# Let the `logger.debug(...)` calls of the following cells through.
logger.setLevel(logging.DEBUG)

# Exec

## Load data

In [None]:
# Both volumes are read with the same dimensions; only the path and dtype differ.
logger.info("Loading data from disk.")
data_volume = file_utils.HST_read(
    data_path,  # a str: the reader does not accept path objects
    dims=data_meta["dims"],
    data_type=data_meta["dtype"],
    autoparse_filename=False,  # the file names do not follow the parseable format
    verbose=False,
)
logger.debug(f"{data_volume.shape=}")

logger.info("Loading labels from disk.")
labels_volume = file_utils.HST_read(
    labels_path,  # a str: the reader does not accept path objects
    dims=data_meta["dims"],
    data_type="uint8",
    autoparse_filename=False,  # the file names do not follow the parseable format
    verbose=False,
)
logger.debug(f"{labels_volume.shape=}")

In [None]:
def iterate_partitions() -> Iterator[Tuple[str, ndarray, ndarray]]:
    """Yield ``(partition alias, data, labels)`` for each partition to compute.

    The data and the labels are obtained by indexing the notebook-level
    ``data_volume`` and ``labels_volume`` with the partition's entry in
    ``partition_slices``. Partitions are produced one at a time, so the caller
    never has to hold the arrays of all partitions at once.

    Yields:
        A 3-tuple: the partition alias (an item of ``partitions_to_compute``,
        presumably a ``str`` such as ``"train"``), the partition's data and the
        partition's labels.
    """
    for part in partitions_to_compute:
        slice_ = partition_slices[part]
        yield part, data_volume[slice_], labels_volume[slice_]

## histogram per label

### [compute] histogram per label

In [None]:
logger.info(f"Computing value histograms per label on the partitions.")
# partition alias -> histogram per label; only item 0 of the helper's return value
# is kept here (item 1 is the edges array).
hists_per_label = {
    part_name: analyse_gt.get_hist_per_label(
        part_data.ravel(), part_labels.ravel(), labels_idx=labels_idx
    )[0]
    for part_name, part_data, part_labels in iterate_partitions()
}
logger.info("Done.")

In [None]:
logger.info(f"Computing value histograms per label on the whole volume.")
# Same helper as for the partitions, applied to the entire volume. The histograms are
# stored under the key `None` (i.e. "no partition") and the second returned item is
# kept as the bin edges.
hists_per_label[None], hist_bin_edges = analyse_gt.get_hist_per_label(
    data_volume.ravel(),
    labels_volume.ravel(),
    labels_idx=labels_idx,
)
logger.info("Done.")

In [None]:
# Every edge but the last one.
# NOTE(review): presumably the left edge of each bin, one per histogram value -- confirm.
hist_bins = hist_bin_edges[:-1]

### [save] histogram per label

In [None]:
logger.info(f"Saving value histogram per label for all partitions and the whole volume.")

# Bins: every edge except the last one.
logger.info("Saving bins.")
filepath = outputs.histogram_per_label_bins
logger.debug(f"bins ==> {filepath=}")
np.save(filepath, hist_bin_edges[:-1])

# Values: one file per partition alias (`None` being the whole volume).
logger.info("Saving values.")
for alias in hists_per_label:
    hist = hists_per_label[alias]
    filepath = outputs.histogram_per_label(alias)
    logger.debug(f"{alias=} ==> {filepath=}")
    np.save(filepath, hist)

logger.info("Done.")

### derived computations

#### class imbalance

In [None]:
# partition alias -> per-label totals (each label's histogram summed along axis 1).
class_imbalance = {
    alias: per_label_hist.sum(axis=1)
    for alias, per_label_hist in hists_per_label.items()
}

#### histograms

In [None]:
# partition alias -> histogram of all labels together (sum along axis 0).
hists = {
    alias: per_label_hist.sum(axis=0)
    for alias, per_label_hist in hists_per_label.items()
}

# Same histograms, each divided by its own total.
hists_norm = {
    alias: total_hist / total_hist.sum()
    for alias, total_hist in hists.items()
}

#### normalized histograms per label 

In [None]:
# Each label's histogram divided by that label's own total (row-wise normalization).
hists_per_label_norm = {
    alias: per_label_hist / per_label_hist.sum(axis=1, keepdims=True)
    for alias, per_label_hist in hists_per_label.items()
}

#### histograms per label globally normalized

In [None]:
# Each label's histogram divided by the grand total over all labels and bins.
hists_per_label_global_norm = {
    alias: per_label_hist / per_label_hist.sum()
    for alias, per_label_hist in hists_per_label.items()
}

### plots

In [None]:
def get_line_label(label_idx, with_nvoxels=False, class_counts=None):
    """Build the legend text of a label.

    Args:
        label_idx: position of the label in the notebook-level ``labels_names``
            (also used to index the counts).
        with_nvoxels: when true, append ``" (nvoxels: <count>)"`` to the name.
        class_counts: per-label counts, indexable by ``label_idx``. When omitted,
            the notebook-level ``class_imb`` is used, as before; note that it is
            whatever was last bound to that name, so pass the counts explicitly
            whenever a specific partition is meant.

    Returns:
        The label's name, optionally followed by its comma-formatted voxel count.
    """
    name = labels_names[label_idx]
    if not with_nvoxels:
        return name

    if class_counts is None:
        # Backward-compatible fallback on the notebook global.
        class_counts = class_imb

    return f"{name} (nvoxels: {humanize.intcomma(class_counts[label_idx])})"

#### class imbalance

In [None]:
# One subplot per partition plus one for the whole volume, laid out on 2 columns.
nplots = len(partitions_to_compute) + 1

nrows = int(np.ceil(nplots / 2))
ncols = 2
sz = 5
dpi = 100
fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * sz, nrows * sz), dpi=dpi)

# Bar color of each partition alias (`None` is the whole volume).
colors = dict(zip(
    [*partitions_to_compute, None],
    ["b", "r", "g", "orange"][:nplots],
))

# NOTE: the loop variable `class_imb` stays bound after the loop and is read by
# `get_line_label`, so it must keep this name.
for ax, (partition_alias, class_imb) in zip(axs.ravel(), class_imbalance.items()):

    display = tomo2seg_viz.ClassImbalanceDisplay(
        volume_name=(
            f"{volume_name}"
            if partition_alias is None
            else f"{volume_name}" + f"  --  partition={partition_alias}"
        ),
        labels_idx=labels_idx,
        labels_names=labels_names,
        labels_counts=class_imb,
    ).plot(
        ax=ax,
        barh_kwargs=dict(height=.6, color=colors[partition_alias]),
        count_fmt_func=lambda c: f"{humanize.intword(c)}",
        perc_fmt_func=lambda p: f"{p:.1%}",
    )

fig.savefig(fname=outputs.class_imbalance_plot, format="png")

#### value histogram

In [None]:
# One figure (and one saved png) per entry of `hists_norm`.
for partition_alias, hist_ in hists_norm.items():

    sz = 4
    dpi = 70
    fig, ax = plt.subplots(1, 1, figsize=(2 * sz, sz), dpi=dpi)

    # i want to get the vertical borders to show up
    display = tomo2seg_viz.VoxelValueHistogramDisplay(
        volume_name=(
            f"{volume_name}"
            if partition_alias is None
            else f"{volume_name}" + f"  --  partition={partition_alias}"
        ),
        bins=hist_bins.tolist(),
        values=hist_.tolist(),
    ).plot(ax)

    fig.savefig(fname=outputs.histogram_plot(partition_alias), format="png")

#### value histogram per label

In [None]:
# One two-panel figure (and one saved png) per partition alias.
for partition_alias in hists_per_label_norm.keys():
    
    hist_per_label_normed_global_ = hists_per_label_global_norm[partition_alias]
    hist_per_label_normed_ = hists_per_label_norm[partition_alias]
    
    # `get_line_label(..., with_nvoxels=True)` reads the notebook-level `class_imb`.
    # Bind it to the counts of the partition being plotted; otherwise every legend
    # shows the counts left over from the last iteration of the class imbalance
    # plot, whatever the partition.
    class_imb = class_imbalance[partition_alias]
    
    fig, axs = plt.subplots(
        nrows := 2, ncols := 1, figsize=(ncols * 1.75 * (sz := 6), nrows * sz), dpi=(dpi := 70),
        gridspec_kw=dict(hspace=sz / 15)
    )

    display = tomo2seg_viz.VoxelValueHistogramPerClassDisplay(
        
        volume_name=f"{volume_name}" + (
            "" if partition_alias is None else f"  --  partition={partition_alias}"
        ),

        bins=hist_bins.tolist(),
        
        values_per_label=hist_per_label_normed_.tolist(),
        values_per_label_global_proportion=hist_per_label_normed_global_.tolist(),
        
        labels_idx=labels_idx,
        line_labels={
            idx: get_line_label(idx, with_nvoxels=True) 
            for idx in labels_idx
        },
        
    ).plot(axs)
    
    # [manual-input]
    axs[0].set_ylim(top=.20)

    fig.savefig(
        fname=outputs.histogram_per_label_plot(partition_alias),
        format="png",
    )

## adjacent layers correlation

### [compute] adjacent layers correlation

In [None]:
@dataclass
class AdjacentLayerCorrelation:
    """One adjacent layers correlation series and the arguments it was computed with."""

    axis: int  # axis of the labels volume along which layers are compared
    nlayers: int
    label: Optional[int]  # `None` stands for "all labels" (see the plot titles)

    # the series itself, left out of the repr
    # NOTE(review): annotated as ints, but presumably holds Jaccard scores -- confirm.
    values: List[int] = field(repr=False)


# One series per (axis, nlayers, label) combination, `None` coming first among labels.
correlations = [
    AdjacentLayerCorrelation(
        axis=axis_,
        nlayers=nlayers_,
        label=label_,
        values=analyse_gt.adjacent_layers_correlation(
            labels_volume,
            axis_,
            nlayers_,
            partial(analyse_gt.jaccard, label=label_),
        ),
    )
    for axis_, nlayers_, label_ in pbar(list(itertools.product(
        range(3),
        adjacent_layers_correlation_nlayer_arg_vals,
        [None, *labels_idx],
    )))
]

### [save] adjacent layers correlation

In [None]:
logger.info("Saving adjacent layers correlation series.")

# One file per series; its path is derived from the series' own arguments.
for corr in pbar(correlations):
    filepath = outputs.layers_correlation(
        axis=corr.axis, nlayers=corr.nlayers, label=corr.label
    )
    np.save(file=filepath, arr=corr.values)

### [plot] adjacent layers correlation

In [None]:
# Grid: one row per class plus a last row for "all labels", one column per axis.
nrows = n_classes + 1
ncols = 3
sz = 8
fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * sz, nrows * sz), dpi=200)


def corr2ax(corr: AdjacentLayerCorrelation):
    """Return the subplot of a series: row = its label (last row for `None`), column = its axis."""
    if corr.label is None:
        return axs[-1, corr.axis]
    return axs[corr.label, corr.axis]


for corr in correlations:
    ax = corr2ax(corr)
    ax.plot(
        corr.values,
        linestyle=':',
        linewidth=.5,
        label=f"nlayers={corr.nlayers}",
    )

for ax in axs.ravel():
    ax.set_ylim(0, 1)

# Titles, visited in the same order as the original nested loops (axis outermost).
for axis, label in itertools.product(range(3), [*range(n_classes), None]):
    axs[-1 if label is None else label, axis].set_title(
        f"{axis=} label={'all' if label is None else label}"
    )

fig.suptitle(f"{volume_name} adjacent layer correlation")

fig.savefig(fname=outputs.layers_correlation_plot, format="png")
plt.close();

## class imbalance per layer

### [compute] class imbalance per layer

In [None]:
logger.info("Computing class imbalance per layer series.")


@dataclass
class LayerwiseLabelCount:
    """Voxel counts per class for every layer along one axis (notebook-only helper)."""

    axis: int
    values: ndarray = field(repr=False)  # nb of voxels per class per layer

    def __post_init__(self):
        """Check `values` against the notebook-level `labels_volume` and `n_classes`."""
        # one row per layer along `axis`, one column per class, int64 counts
        expected_shape = (labels_volume.shape[self.axis], n_classes)
        assert self.values.shape == expected_shape,  f"{self.values.shape=} {self.axis=} {labels_volume.shape=} {n_classes=}"
        assert self.values.dtype == np.int64, f"{self.values.dtype=}"

        # a layer has as many voxels as the product of the other dimensions
        layer_dims = list(labels_volume.shape)
        del layer_dims[self.axis]
        nvoxels_per_layer = functools.reduce(operator.mul, layer_dims)

        # the counts of each layer must add up to the layer's number of voxels
        for rowidx, row in enumerate(self.values):
            assert row.sum() == nvoxels_per_layer, f"{rowidx=} {row=} {nvoxels_per_layer}"


# axis -> validated per-layer counts, for each of the 3 axes.
layerwise_label_count = {
    axis_: LayerwiseLabelCount(
        axis=axis_,
        values=analyse_gt.class_counts_per_layer(labels_volume, axis_, n_classes),
    )
    for axis_ in range(3)
}

### [save] class imbalance per layer

In [None]:
logger.info("Saving class imbalance per layer series.")

# One file per axis.
for lab_count in pbar(layerwise_label_count.values()):
    filepath = outputs.layerwise_class_count(axis=lab_count.axis)
    np.save(file=filepath, arr=lab_count.values)

### [plot] class imbalance per layer

In [None]:
# Grid: one row per axis; left column = counts, right column = proportions.
nrows = 3
ncols = 2
sz = 8
fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * sz, nrows * sz), dpi=200)


def lab_count2axs(lab_count: LayerwiseLabelCount):
    """Return the (counts, proportions) subplots of the row matching the series' axis."""
    counts_ax, proportions_ax = axs[lab_count.axis]
    return counts_ax, proportions_ax


for lab_count in layerwise_label_count.values():

    ax_count, ax_proportion = lab_count2axs(lab_count)

    # share of each class within every layer
    proportions = lab_count.values / lab_count.values.sum(axis=1, keepdims=True)

    ax_count.plot(lab_count.values)
    ax_count.legend(labels_names)
    ax_count.set_title(f"counts axis={lab_count.axis}")

    ax_proportion.plot(proportions)
    ax_proportion.legend(labels_names)
    ax_proportion.set_ylim(0, 1)
    ax_proportion.set_title(f"proportions axis={lab_count.axis}")

fig.suptitle(f"{volume_name}: class imbalance per layer")

fig.savefig(fname=outputs.layerwise_class_count_plot, format="png")
plt.close();

# Physical metrics

In [None]:
# - voxel size
# - volume size
# - fiber length
# - fiber diameter
# - porosity diameter