# Imports

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import copy
import functools
import gc
import itertools
import logging
import operator
import os
import pathlib
import re
import socket
import sys
import time
from collections import Counter
from dataclasses import asdict, dataclass, field
from enum import Enum
from functools import partial
from pathlib import Path
from pprint import PrettyPrinter, pprint
from typing import *

In [None]:
%autoreload 2

import humanize
import matplotlib
import numpy as np
import pandas as pd
import scipy as sp
import tensorflow as tf
import yaml
from matplotlib import cm, patches, pyplot as plt
from numpy import ndarray
from numpy.random import RandomState
from progressbar import progressbar as pbar
from pymicro.file import file_utils
from sklearn import metrics, metrics as met, model_selection, preprocessing
from tensorflow import keras
from tensorflow.keras import (
    callbacks as keras_callbacks,
    layers,
    losses,
    metrics as keras_metrics,
    optimizers,
    utils,
)
from tqdm import tqdm
from yaml import YAMLObject

In [None]:
%autoreload 2

from tomo2seg import (
    callbacks as tomo2seg_callbacks,
    data as tomo2seg_data,
    losses as tomo2seg_losses,
    schedule as tomo2seg_schedule,
    slack,
    slackme,
    utils as tomo2seg_utils,
    viz as tomo2seg_viz,
    volume_sequence,
)
from tomo2seg.data import EstimationVolume, Volume
from tomo2seg.logger import add_file_handler, dict2str, logger
from tomo2seg.model import Model as Tomo2SegModel

In [None]:
# this registers a custom exception handler for the whole current notebook:
# any `Exception` raised by a cell is handed to `slackme.custom_exc`
# (presumably it reports the error on Slack -- TODO confirm in tomo2seg.slackme)
get_ipython().set_custom_exc((Exception,), slackme.custom_exc)

# Args

In [None]:
# [manual-input]
# Select the volume / labels to analyse by (un)commenting ONE pair of aliases below.
from tomo2seg.datasets import (
#     VOLUME_COMPOSITE_V1 as VOLUME_NAME_VERSION,
#     VOLUME_COMPOSITE_V1_LABELS_REFINED3 as LABELS_VERSION,
    VOLUME_FRACTURE00_SEGMENTED00 as VOLUME_NAME_VERSION,
    VOLUME_FRACTURE00_SEGMENTED00_LABELS_REFINED3 as LABELS_VERSION,
)

# VOLUME_NAME_VERSION unpacks into a (name, version) pair
volume_name, volume_version = VOLUME_NAME_VERSION
labels_version = LABELS_VERSION

random_state_seed = 42
# a new run id (unix timestamp truncated to seconds) is generated at every execution;
# to reuse the id of a previous run, hard-code it like in the commented line below
runid = int(time.time())
# runid = 1607944057

# None == all
partitions_to_compute_aliases = None

# log the arguments so that they are recorded with the run
logger.info(f"{volume_name=}")
logger.info(f"{volume_version=}")
logger.info(f"{labels_version=}")
logger.info(f"{partitions_to_compute_aliases=}")

# Setup

In [None]:
# show debug messages from here on
logger.setLevel(logging.DEBUG)
# random number generator seeded with `random_state_seed` (defined in the Args section)
random_state = np.random.RandomState(random_state_seed)

In [None]:
# Resolve the volume to analyse from its name and version.
volume = Volume.with_check(
    name=volume_name, version=volume_version
)

logger.debug(f"volume=\n{dict2str(asdict(volume))}")

# aliases of every partition declared in the volume's metadata
volume_partitions_aliases = tuple(volume.metadata.set_partitions.keys())

if partitions_to_compute_aliases is None:

    logger.info("Using all available partitions.")

    partitions_to_compute_aliases = volume_partitions_aliases

else:
    # an empty selection would silently produce no per-partition result at all
    # (the previous `>= 0` check could never fail)
    if len(partitions_to_compute_aliases) == 0:
        raise ValueError(
            f"No volume partition selected. {volume.fullname=} {volume_partitions_aliases=} {partitions_to_compute_aliases=}"
        )

    for part_alias in partitions_to_compute_aliases:

        try:
            # only the lookup matters here: it raises KeyError on an unknown alias
            volume[part_alias]

        except KeyError as ex:
            logger.exception(ex)
            # chain the original error so that the traceback shows which alias failed
            raise ValueError(
                f"Invalid volume partition. {volume.fullname=} {volume_partitions_aliases=} {partitions_to_compute_aliases=}"
            ) from ex

logger.info(f"{partitions_to_compute_aliases=}")

In [None]:
# Each run writes its outputs to its own directory, created inside the volume's
# directory and named after the volume and the run id.
exec_name = f"{volume.fullname}.ground-truth-analysis.runid={tomo2seg_utils.fmt_runid(runid)}"
exec_dir = volume.dir / exec_name
# figures are saved in the same directory as the other outputs
figs_dir = exec_dir

logger.info(f"{exec_name=}")
logger.info(f"{exec_dir=}")

# `exist_ok=True`: re-running with the same runid reuses the directory instead of failing
exec_dir.mkdir(exist_ok=True)

# Load data

In [None]:
logger.info("Loading data from disk.")

# Both raw files are read with the same options; only the path and the dtype differ.
_read_raw_volume = partial(
    file_utils.HST_read,
    autoparse_filename=False,  # the file names are not properly formatted
    dims=volume.metadata.dimensions,
    verbose=False,
)

# the reader doesn't accept paths, hence the `str(...)` conversions below
data_volume = _read_raw_volume(
    str(volume.data_path),
    data_type=volume.metadata.dtype,
)

logger.debug(f"{data_volume.shape=}")

logger.info("Loading labels from disk.")

# labels are always read as uint8, whatever the dtype of the data
labels_volume = _read_raw_volume(
    str(volume.versioned_labels_path(labels_version)),
    data_type="uint8",
)

logger.debug(f"{labels_volume.shape=}")

def iterate_partitions() -> Iterator[Tuple[str, ndarray, ndarray]]:
    """Lazily yield ``(alias, data, labels)`` for every partition to compute.

    The partitions are produced one at a time instead of being all extracted
    at once, so that only one partition has to be held by the consumer at any
    moment (less memory).

    Reads the module-level ``partitions_to_compute_aliases``, ``volume``,
    ``data_volume`` and ``labels_volume``.

    Yields:
        A 3-tuple with the partition alias, the result of
        ``get_volume_partition`` applied to ``data_volume`` and the result of
        ``get_volume_partition`` applied to ``labels_volume``.
        (The aliases are assumed to be strings -- TODO confirm.)
    """
    for partition_alias in partitions_to_compute_aliases:
        # look the partition up once and use it for both volumes
        partition = volume[partition_alias]
        yield (
            partition_alias,
            partition.get_volume_partition(data_volume),
            partition.get_volume_partition(labels_volume),
        )

# Useful variables

In [None]:
# label indices as declared in the volume's metadata
labels_idx = volume.metadata.labels
# label names, in the same order as `labels_idx`
labels_names = [volume.metadata.labels_names[idx] for idx in labels_idx]

# (index, name) pairs
labels_idx_name = list(zip(labels_idx, labels_names))

n_classes = len(labels_idx)

logger.debug(f"{n_classes=}")
logger.debug(f"{labels_idx=}")
logger.debug(f"{labels_names=}")

# [compute] value histogram per label

In [None]:
# last histogram bin edge for each supported dtype name (2**8 and 2**16);
# any other dtype raises a KeyError in the lookup below
MAX_BIN_EDGE = {
    "uint8": 256,
    "uint16": 65536,
}

max_bin_edge = MAX_BIN_EDGE[volume.metadata.dtype]

logger.debug(f"{max_bin_edge=}")

n_bins = 256

logger.debug(f"{n_bins=}")

# `n_bins + 1` evenly spaced edges from 0 to `max_bin_edge` (both included), cast to integers
hist_bin_edges = np.linspace(0, max_bin_edge, n_bins + 1).astype(int)

def get_hist_per_label(data_seq, labels_seq):
    """Compute one value histogram per label.

    Uses the module-level ``n_classes``, ``n_bins``, ``labels_idx`` and
    ``hist_bin_edges``.

    Args:
        data_seq: 1-dimensional array with the values to count.
        labels_seq: array with the label of each value; ``labels_seq == label``
            is used as a boolean mask on ``data_seq``.

    Returns:
        An int64 array of shape ``(n_classes, n_bins)``; the row of index
        ``label_idx`` has the counts of the values labeled ``label_idx``.
    """
    tensor_order = len(data_seq.shape)
    assert tensor_order == 1, f"{tensor_order}"

    # int64 is important to not overflow
    counts_by_label = np.zeros((n_classes, n_bins), dtype=np.int64)

    for label_idx in labels_idx:

        logger.debug(f"Computing histogram for {label_idx=}")

        values_of_label = data_seq[labels_seq == label_idx]

        # the returned bin edges are the ones given as input, so they are discarded
        label_counts, _ = np.histogram(
            values_of_label,
            bins=hist_bin_edges,
            density=False,
        )

        # the label index is used as the row index
        counts_by_label[label_idx] = label_counts

    return counts_by_label

logger.info(f"Computing value histograms per label on the partitions.")
# partition alias -> histogram per label; the arrays are flattened (`ravel`)
# because `get_hist_per_label` only accepts 1-dimensional inputs
hists_per_label = {
    partition_alias: get_hist_per_label(part_data.ravel(), part_labels.ravel())
    for partition_alias, part_data, part_labels in iterate_partitions()
}

logger.info(f"Computing value histograms per label on the whole volume.")
# the histograms of the whole volume are stored under the key `None`
hists_per_label[None] = get_hist_per_label(
    data_volume.ravel(),
    labels_volume.ravel(),
)

# [save] value histogram per label

In [None]:
def get_filename_value_hist_per_label(partition_: Optional[str]) -> str:
    """Return the file name under which a histogram per label is saved.

    Args:
        partition_: alias of the partition, or ``None`` for the whole volume.

    Returns:
        The bare file name (a ``str``, without any directory):
        ``value-histogram-per-label.npy`` for the whole volume, or
        ``value-histogram-per-label.partition=<alias>.npy`` for a partition.
    """
    partition_suffix = "" if partition_ is None else f".partition={partition_}"
    return f"value-histogram-per-label{partition_suffix}.npy"

logger.info(f"Saving value histogram per label for all partitions and the whole volume.")

# one .npy file per entry; the key `None` is the whole volume
for partition_alias, histogram_per_label in hists_per_label.items():
    
    filename = get_filename_value_hist_per_label(partition_alias)
    
    logger.debug(f"Saving {partition_alias=} ==> {filename=}")
    
    # everything is saved in the directory of the current run
    filepath = exec_dir / filename
    
    logger.debug(f"{filepath=}")

    np.save(
        file=filepath,
        arr=histogram_per_label,
    )

In [None]:
# drop the references to the full volumes: only the histograms are used from here on
del data_volume, labels_volume

In [None]:
# run a full garbage collection pass right after dropping the volumes
gc.collect()

In [None]:
logger.info("Saving bins.")

# one value per histogram bin: every edge but the last one
hist_bins = hist_bin_edges[:-1]

# bare file name: it is joined to `exec_dir` exactly once, right below
# (joining an already-joined path duplicates the directory when `exec_dir` is relative)
filename = "value-histogram-per-label.bins.npy"

logger.debug(f"{filename=}")

filepath = exec_dir / filename

logger.debug(f"{filepath=}")

np.save(
    file=filepath,
    arr=hist_bins,
)

# derived computations

## class imbalance

In [None]:
# number of voxels per label (sum of each label's histogram), for each partition
class_imbalance = {
    alias: counts_per_label.sum(axis=1)
    for alias, counts_per_label in hists_per_label.items()
}

## value histograms

In [None]:
# value histogram regardless of the label (sum over the labels), for each partition
hists = {
    alias: counts_per_label.sum(axis=0)
    for alias, counts_per_label in hists_per_label.items()
}

# same histograms, each one divided by its total count
hists_norm = {
    alias: counts / counts.sum()
    for alias, counts in hists.items()
}

## value histograms NORMED per label 

In [None]:
# each label's histogram divided by that label's own total count
# (`keepdims=True` keeps the totals as a column so that the division is row-wise)
hists_per_label_norm = {
    alias: counts_per_label / counts_per_label.sum(axis=1, keepdims=True)
    for alias, counts_per_label in hists_per_label.items()
}

## value histograms per label GLOBAL NORMED

In [None]:
# each label's histogram divided by the total count over all the labels
hists_per_label_global_norm = {
    alias: counts_per_label / counts_per_label.sum()
    for alias, counts_per_label in hists_per_label.items()
}

# plots

In [None]:
def get_line_label_simple(label_idx):
    """Return the plot label of a class: its name, looked up in `labels_names`."""
    label_name = labels_names[label_idx]
    return label_name

def get_line_label_with_nvoxels(label_idx, class_counts=None):
    """Return the plot label of a class: its name followed by its voxel count.

    Args:
        label_idx: index of the label, used to index both `labels_names` and
            the counts.
        class_counts: voxel count per label. When omitted, the module-level
            `class_imb` is used (backward-compatible behavior); that global is
            whatever the last loop binding it left behind, so passing the
            counts explicitly is safer.

    Returns:
        A string like ``"<name> (nvoxels: 1,234)"``.
    """
    if class_counts is None:
        # fallback on the global left behind by a loop over `class_imbalance`
        class_counts = class_imb
    return f"{labels_names[label_idx]} (nvoxels: {humanize.intcomma(class_counts[label_idx])})"

## class imbalance

In [None]:
# One class imbalance figure per entry of `class_imbalance` (the key `None` is the whole volume).
# NOTE: the loop variable `class_imb` stays defined at module level after the loop,
# and `get_line_label_with_nvoxels` reads it.
for partition_alias, class_imb in class_imbalance.items():

    # square figure; `sz` and `dpi` are bound inline by the walrus operators
    fig, ax = plt.subplots(1, 1, figsize=(sz := 7, sz), dpi=(dpi := 120))

    # NOTE(review): this rebinds the name `display`, which IPython also provides as a function -- confirm it is not needed later
    display = tomo2seg_viz.ClassImbalanceDisplay(
        volume_name=f"{volume.fullname}" + ("" if partition_alias is None else f"  --  partition={partition_alias}"),
        labels_idx=labels_idx,
        labels_names=labels_names,
        labels_counts=class_imb,
    ).plot(ax)

    # the figure's file name is the title of the display + ".png"
    logger.info(f"Saving figure {(figname := display.title + '.png')=}")
    
    display.fig_.savefig(
        fname=figs_dir / figname,
        format="png",
        metadata=display.metadata,
    )

## value histogram

In [None]:
# One (normalized) value histogram figure per entry of `hists_norm` (the key `None` is the whole volume).
for partition_alias, hist_ in hists_norm.items():

    # figure twice as wide as it is tall; `sz` and `dpi` are bound inline by the walrus operators
    fig, ax = plt.subplots(1, 1, figsize=(2 * (sz := 8), sz), dpi=(dpi := 120))

    # i want to get the vertical borders to show up
    display = tomo2seg_viz.VoxelValueHistogramDisplay(
        volume_name=f"{volume.fullname}" + ("" if partition_alias is None else f"  --  partition={partition_alias}"),
        bins=hist_bins.tolist(),
        values=hist_.tolist(),
    ).plot(ax)

    # the figure's file name is the title of the display + ".png"
    logger.info(f"Saving figure {(figname := display.title + '.png')=}")

    display.fig_.savefig(
        fname=figs_dir / figname,
        format="png",
        metadata=display.metadata,
    )

## value histogram per label

In [None]:
# One figure (two stacked axes) per partition, plus the whole volume (key `None`).
for partition_alias in hists_per_label_norm:

    hist_per_label_normed_global_ = hists_per_label_global_norm[partition_alias]
    hist_per_label_normed_ = hists_per_label_norm[partition_alias]

    # `get_line_label_with_nvoxels` reads the module-level `class_imb`: bind it to
    # the counts of THIS partition, otherwise every figure would show the counts
    # left behind by the last iteration of an earlier loop.
    class_imb = class_imbalance[partition_alias]

    nrows, ncols = 2, 1
    sz = 8
    dpi = 120

    fig, axs = plt.subplots(
        nrows,
        ncols,
        figsize=(ncols * 1.75 * sz, nrows * sz),
        dpi=dpi,
        gridspec_kw=dict(hspace=sz / 15),
    )

    display = tomo2seg_viz.VoxelValueHistogramPerClassDisplay(
        volume_name=f"{volume.fullname}" + ("" if partition_alias is None else f"  --  partition={partition_alias}"),
        bins=hist_bins.tolist(),
        values_per_label=hist_per_label_normed_.tolist(),
        values_per_label_global_proportion=hist_per_label_normed_global_.tolist(),
        labels_idx=labels_idx,
        line_labels={
            idx: get_line_label_with_nvoxels(idx) for idx in labels_idx
        },
    ).plot(axs)

    # [manual-input]
    # upper limit of the y axis of the first (top) axes
    axs[0].set_ylim(top=.20)

    # the figure's file name is the title of the display + ".png"
    logger.info(f"Saving figure {(figname := display.title + '.png')=}")
    display.fig_.savefig(
        fname=figs_dir / figname,
        format="png",
        metadata=display.metadata,
    )

# Physical metrics

# Save notebook

In [None]:
# Export this notebook as HTML into the directory of the current run.
this_nb_name = "analyse-ground-truth-00.ipynb"

import os
import subprocess

# the notebook is expected to be in the current working directory
this_dir = os.getcwd()
logger.warning(f"{this_nb_name=} {this_dir=}")

# argument list (no shell): paths with spaces or shell metacharacters are passed as they are
nbconvert_cmd = [
    "jupyter",
    "nbconvert",
    f"{this_dir}/{this_nb_name}",
    "--output-dir",
    str(exec_dir),
    "--to",
    "html",
]

try:
    nbconvert_result = subprocess.run(nbconvert_cmd, check=False)

except FileNotFoundError:
    # best effort, like before: a missing `jupyter` executable must not crash the notebook
    logger.exception("Could not export the notebook: the `jupyter` executable was not found.")

else:
    if nbconvert_result.returncode != 0:
        logger.error(f"The notebook export failed with exit code {nbconvert_result.returncode}.")