In [None]:
%load_ext autoreload

In [None]:
import os
import copy
import dataclasses
from dataclasses import asdict, dataclass
import functools
import gc
import itertools
import logging
import operator
import pprint as pprint_module
import time
from functools import partial
from pathlib import Path
import sys

In [None]:
import humanize
import numpy as np
import tensorflow as tf
from cnn_segm import keras_custom_loss
from matplotlib import pyplot as plt
from numpy.random import RandomState
from progressbar import progressbar as pbar
from pymicro.file import file_utils
import socket
import pandas as pd
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.models import Model as KerasModel

In [None]:
%autoreload 2

from tomo2seg.process import reduce_dimensions 
from tomo2seg.args import ProcessVolumeArgs as Args
from tomo2seg import viz as t2s_viz
from tomo2seg.data import EstimationVolume
from tomo2seg.data import Volume
from tomo2seg.logger import add_file_handler as logger_add_file_handler
from tomo2seg.logger import dict2str
from tomo2seg.logger import logger
from tomo2seg.model import Model as Tomo2SegModel
from tomo2seg import utils as tomo2seg_utils
from tomo2seg import slackme
from tomo2seg import slack
from tomo2seg import volume_sequence
from tomo2seg import hosts as t2s_hosts
from tomo2seg import datasets as t2s_datasets 

In [None]:
# Route every exception raised in this notebook's cells through slackme's handler
# (IPython `set_custom_exc`); presumably it reports the failure to Slack.
get_ipython().set_custom_exc(exc_tuple=(Exception,), handler=slackme.custom_exc)

# Args

In [None]:
# [manual-input]
volume_name_version = t2s_datasets.VOLUME_COMPOSITE_V1
model_name = "paper-unet-2d.f16-stripping-layernorm.fold000.1612-341-593"

# leave as None to use the default partition names below
train_partition_name = None
val_partition_name = None

# [derived-input]
if not train_partition_name:
    train_partition_name = 'train'
if not val_partition_name:
    val_partition_name = 'val'

# the dataset constant is indexed as (name, version)
volume_name = volume_name_version[0]
volume_version = volume_name_version[1]

# `tomo2seg` objects 

In [None]:
# Resolve the model from its name, then the volume from (name, version).
tomo2seg_model = Tomo2SegModel.build_from_model_name(model_name)
volume = Volume.with_check(name=volume_name, version=volume_version)

# Partitions are looked up by name on the volume.
train_partition = volume[train_partition_name]
val_partition = volume[val_partition_name]

# Vars

In [None]:
# Figures and archives are written under the model's own directory.
output_dir = tomo2seg_model.model_path
n_classes = volume.nclasses

# Output

In [None]:
def mkdir_ok(property_):
    """Decorate a path-returning getter so the directory it names exists.

    The value returned by ``property_(self)`` is passed to
    ``Path.mkdir(exist_ok=True)`` (parent directories are NOT created) on every
    call, then returned unchanged. ``functools.wraps`` keeps the getter's name
    and docstring.
    """

    def ensure_dir(self) -> Path:
        directory = property_(self)
        directory.mkdir(exist_ok=True)
        return directory

    return functools.wraps(property_)(ensure_dir)


@dataclass
class OutputFiles:
    """Output directory layout for the training snapshots, rooted at ``root_dir``.

    Each property below builds a path from ``root_dir`` and, through ``mkdir_ok``,
    creates that directory whenever it is accessed.
    """

    root_dir: Path

    def __post_init__(self):
        # the root must already exist; only the sub-directories are created on demand
        assert self.root_dir.is_dir()

    @property
    @mkdir_ok
    def snapshots_root(self) -> Path:
        """Parent of all snapshot directories."""
        return self.root_dir.joinpath("snapshots-during-training")

    @property
    @mkdir_ok
    def snapshots_dir(self) -> Path:
        """Snapshots made on each epoch's own crops."""
        return self.snapshots_root.joinpath("snapshots")

    @property
    @mkdir_ok
    def snapshots_single_crop_dir(self) -> Path:
        """Snapshots made on one fixed crop."""
        return self.snapshots_root.joinpath("snapshots_single_crop_dir")

    @property
    @mkdir_ok
    def snapshots_single_crop_from_each_dataset_dir(self) -> Path:
        """Snapshots made on one recorded crop from each partition."""
        return self.snapshots_root.joinpath("snapshots_single_crop_from_each_dataset_dir")

In [None]:
output_files = OutputFiles(root_dir=output_dir)  # root must already exist

# Setup

In [None]:
# run on a single device -- the CPU -- via OneDeviceStrategy, which places the
# variables created in its scope on that device
# (see https://www.tensorflow.org/guide/distributed_training)
tf_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
logger.debug(f"{tf_strategy=}")

##### models

In [None]:
# Number of autosaved model files found for this model (shown as the cell output);
# each of them is loaded with `tf.keras.models.load_model` further down.
len(tomo2seg_model.autosaved2_all())

# Exec

In [None]:
def get_keras_model(model_path):
    """Load a saved Keras model, give it a size-agnostic input if possible, and compile it.

    The replacement input keeps only the channel count of the original input and
    sets every other non-batch dimension to ``None``, so the network can be applied
    to crops whose size differs from the training crops.

    Args:
        model_path: location of the saved model (converted with ``str`` for
            ``tf.keras.models.load_model``).

    Returns:
        The compiled Keras model. The optimizer/loss are placeholders: this
        notebook never fits the model, it only calls ``predict``.
    """
    model = tf.keras.models.load_model(str(model_path), compile=False)

    in_shape = model.layers[0].input_shape[0]
    input_n_channels = in_shape[-1]
    # make it capable of getting any dimension in the input
    # "-2" = 1 for the batch size, 1 for the nb.channels
    anysize_target_shape = (len(in_shape) - 2) * [None] + [input_n_channels]
    anysize_input = layers.Input(
        shape=anysize_target_shape,
        name="input_any_image_size"
    )

    # `model.layers` returns a newly built list, so assigning into it (as this
    # function used to do) never changed the model. Rebuild the graph on the new
    # input instead and copy the trained weights over.
    try:
        anysize_model = tf.keras.models.clone_model(model, input_tensors=anysize_input)
        anysize_model.set_weights(model.get_weights())
    except Exception as ex:  # best effort, e.g. layers that cannot be re-created from config
        logger.warning(
            f"Could not give {model_path} a size-agnostic input ({ex!r}); "
            "using the model with its original input shape."
        )
    else:
        model = anysize_model

    # this doesn't really matter bc this script will not fit the model
    model.compile(loss=keras_custom_loss.jaccard2_loss, optimizer=optimizers.Adam())
    return model

In [None]:
# todo: move me up
# materialize the autosaved model paths once so they can be zipped with the loaded models later
models_path_list = list(tomo2seg_model.autosaved2_all())

In [None]:
with tf_strategy.scope():
    logger.info(f"Loading models with {tf_strategy.__class__.__name__}.")

    # one Keras model per autosaved path, in the same order as `models_path_list`
    keras_models = list(map(get_keras_model, pbar(models_path_list)))

    logger.info("done")

In [None]:
# todo: move me up
from typing import Tuple

In [None]:
@dataclass
class ModelSnapshot:
    """An autosaved checkpoint: its file name and the loaded Keras model.

    Instances sort by ``filename``.
    """

    filename: str
    keras: "KerasModel"

    def __lt__(self, other):
        return self.filename < other.filename

    @property
    def epoch_valloss(self) -> Tuple[int, float]:
        """Parse ``(epoch, val_loss)`` from ``filename``.

        The last three dot-separated pieces are expected to be
        ``"<epoch>-<loss integer part>"``, ``"<loss decimals>"`` and the extension,
        e.g. ``"model.autosaved2.005-0.1234.h5"`` -> ``(5, 0.1234)`` -- presumably
        written by an ``"{epoch:03d}-{val_loss:.4f}"`` template; TODO confirm.
        """
        epoch_field, loss_decimals = self.filename.split(".")[-3:-1]
        epoch_str, _, loss_integer = epoch_field.partition("-")
        epoch = int(epoch_str)
        # keep the integer part of the loss (it used to be hard-coded to "0", which
        # silently turned e.g. 1.5 into 0.5); fall back to "0" if it is not numeric
        integer_part = loss_integer if loss_integer.isdigit() else "0"
        loss = float(f"{integer_part}.{loss_decimals}")
        return epoch, loss

    @property
    def epoch(self) -> int:
        """Epoch number parsed from the file name."""
        return self.epoch_valloss[0]

    @property
    def val_loss(self) -> float:
        """Validation loss parsed from the file name."""
        return self.epoch_valloss[1]

In [None]:
# Pair each loaded Keras model with its file name; sorted by file name (ModelSnapshot.__lt__).
models = sorted(
    ModelSnapshot(filename=path.name, keras=k_model)
    for k_model, path in zip(keras_models, models_path_list)
)

##### data

In [None]:
# todo: move me up
# raw-data location and element type, both taken from the volume's metadata
data_dtype = volume.metadata.dtype
data_path = volume.data_path

In [None]:
logger.info(f"Loading data from disk at file: {data_path.name}")
logger.debug(f"{data_path=}")

normalization_factor = volume_sequence.NORMALIZE_FACTORS[data_dtype]  # todo move to utils
logger.debug(f"{normalization_factor=}")

# Read with the same path/dtype that were logged above; they used to be re-read
# from `volume`, which could silently diverge from `data_path`/`data_dtype`.
data_volume = file_utils.HST_read(
    str(data_path),  # it doesn't accept paths...
    autoparse_filename=False,  # the file names are not properly formatted
    data_type=data_dtype,
    dims=volume.metadata.dimensions,
    verbose=True,
) / normalization_factor  # normalize

logger.debug(f"{data_volume.shape=}")

In [None]:
# Split the full (normalized) volume into its train and val partitions.
logger.info(f"Cutting data with {train_partition.alias=} and {val_partition.alias=}")  # todo: disentangle me, move me to utils
logger.debug(f"{train_partition=}")
logger.debug(f"{val_partition=}")

# `get_volume_partition` is project code; presumably it slices `data_volume` to the partition's bounds -- TODO confirm
data_volume_train = train_partition.get_volume_partition(data_volume)
data_volume_val = val_partition.get_volume_partition(data_volume)
    
logger.info("done")

In [None]:
# the crops are not being serialized correctly...
# this cells fixes it
import csv

# todo: move me up
# Paths of the (malformed) metacrop histories written during training.
train_metacrop_history_path, val_metacrop_history_path = (
    tomo2seg_model.train_metacrop_history_path,
    tomo2seg_model.val_metacrop_history_path,
)


# todo: move to utils
def modify_filename(filepath: Path, prefix: str = "", suffix: str = "") -> Path:
    """Return ``filepath`` with ``prefix``/``suffix`` added around the file stem.

    The extension (as split by ``os.path.splitext``) and the parent directory
    are kept, e.g. ``a/hist.csv`` + ``suffix=".fixed"`` -> ``a/hist.fixed.csv``.
    """
    parent, basename = os.path.split(filepath)
    stem, extension = os.path.splitext(basename)
    return Path(parent).joinpath(prefix + stem + suffix + extension)


def fix_history_csv(in_filepath, out_filepath):
    """Rewrite a metacrop-history CSV whose crop columns were split on inner commas.

    Each input row is read with ``;`` as delimiter (presumably absent from the
    file, so the whole row arrives as one field) and then split on ``,``. For data
    rows (more than 9 pieces), pieces 1-3, 4-6 and 7-9 are re-joined with ``,``
    into one column each; rows with at most 9 pieces (the header) are kept as
    split. Output is written with ``;`` as delimiter so the re-joined columns
    stay intact. Blank rows are skipped.

    Args:
        in_filepath: ``Path`` of the malformed CSV.
        out_filepath: ``Path`` of the repaired CSV (overwritten).
    """
    # newline="" is how the csv module requires its files to be opened
    with in_filepath.open("r", newline="") as infile, out_filepath.open("w", newline="") as outfile:

        reader = csv.reader(infile, delimiter=";")  # it is not the ";"
        writer = csv.writer(outfile, delimiter=";")  # now it'll be

        for row in pbar(reader, prefix=in_filepath.name):

            if not row:
                # csv.reader yields [] for blank lines; row[0] would raise IndexError
                continue

            pieces = row[0].split(",")

            if len(pieces) > 9:  # not the header
                pieces = [
                    pieces[0],
                    ",".join(pieces[1:4]),
                    ",".join(pieces[4:7]),
                    ",".join(pieces[7:10]),
                    *pieces[10:],
                ]

            writer.writerow(pieces)

            
# Repair both histories, then point the path variables at the repaired copies.
fixed_paths = []
for original_path in (train_metacrop_history_path, val_metacrop_history_path):
    fixed_path = modify_filename(original_path, suffix=".fixed")
    fix_history_csv(original_path, fixed_path)
    fixed_paths.append(fixed_path)

train_metacrop_history_path, val_metacrop_history_path = fixed_paths

In [None]:
batch_size = 10


def _keep_first_sample_of_first_batch(metacrop_hist, batch_size):
    """Keep the rows with ``batch_idx == 0``, then every ``batch_size``-th of those.

    With ``batch_size`` rows per batch this leaves the first sample of the first
    batch of each epoch.
    """
    first_batches = metacrop_hist.loc[metacrop_hist["batch_idx"] == 0]
    return first_batches.iloc[::batch_size]


train_metacrop_hist = _keep_first_sample_of_first_batch(
    pd.read_csv(train_metacrop_history_path, sep=";"), batch_size
)
val_metacrop_hist = _keep_first_sample_of_first_batch(
    pd.read_csv(val_metacrop_history_path, sep=";"), batch_size
)

In [None]:
# Example row (first kept training record) as a plain dict -- handy for trying
# `csv_line2obj` interactively.
line = train_metacrop_hist.iloc[0].to_dict()

def parse_slice_repr(text):
    """Parse a ``"slice(start, stop, step)"`` repr back into a ``slice``.

    The comma-separated arguments between the first ``(`` and the following
    ``)`` are stripped; ``"None"`` becomes ``None``, anything else goes through ``int``.
    """
    args = [arg.strip() for arg in text.split("(")[1].split(")")[0].split(",")]
    return slice(*[None if arg == "None" else int(arg) for arg in args])


def csv_line2obj(line):
    """Convert one row of a repaired metacrop-history CSV into a ``MetaCrop3D``.

    ``batch_idx`` and ``gt_type`` are dropped; ``gt`` is looked up by name in
    ``volume_sequence.GT2D``; ``et`` goes through ``ast.literal_eval``; ``x``,
    ``y`` and ``z`` are parsed from their ``slice(...)`` reprs. The remaining
    columns are passed to ``MetaCrop3D`` unchanged as keyword arguments.

    Args:
        line: mapping column -> value (a ``dict`` or a pandas row ``Series``).
            It is copied, not modified in place.
    """
    import ast

    # work on a copy: the previous version deleted/overwrote keys of the caller's row
    fields = dict(line)
    del fields['batch_idx']
    del fields['gt_type']

    fields['gt'] = volume_sequence.GT2D[fields['gt']]
    fields['et'] = ast.literal_eval(fields['et'])

    for kw in ('x', 'y', 'z'):
        fields[kw] = parse_slice_repr(fields[kw])

    return volume_sequence.MetaCrop3D(**fields)

In [None]:
# one MetaCrop3D per kept history row, in row order
train_meta_crops = [csv_line2obj(row) for _idx, row in train_metacrop_hist.iterrows()]
val_meta_crops = [csv_line2obj(row) for _idx, row in val_metacrop_hist.iterrows()]

In [None]:
# crop extractor for the data volume: spline interpolation, not a label volume
meta2crop = functools.partial(
    volume_sequence.meta2crop,
    interpolation='spline',
    is_label=False,
)


def process_metacrop(metacrop, datavol, keras_model):
    """Cut the crop described by ``metacrop`` out of ``datavol`` and segment it.

    Returns:
        ``(crop_data, crop_segm)``: the crop with an extra leading axis of size 1
        (the batch fed to ``predict``), and the argmax over the last axis of the
        squeezed prediction -- presumably the per-voxel class index; TODO confirm
        the model's output layout.
    """
    batch = np.expand_dims(meta2crop(metacrop, volume=datavol), axis=0)
    prediction = keras_model.predict(batch)
    return batch, prediction.squeeze().argmax(axis=-1)

# Snapshots

In [None]:
@dataclass
class Snapshot:
    """Predictions of one checkpoint on one train crop and one val crop."""

    # checkpoint that produced the predictions
    model: ModelSnapshot

    # train crop as fed to the model, and the predicted labels
    train_crop_data: np.ndarray
    train_crop_segm: np.ndarray

    # val crop as fed to the model, and the predicted labels
    val_crop_data: np.ndarray
    val_crop_segm: np.ndarray

In [None]:
# One snapshot per checkpoint, each predicting on the crops recorded for its epoch.
# NOTE(review): the crop lists are indexed by the epoch parsed from the file name;
# if that number is 1-based while the lists start at the first epoch, this is
# off by one -- confirm against how the checkpoints are named.
snapshots = []

for model in pbar(models):
    epoch = model.epoch

    train_meta_crop = train_meta_crops[epoch]
    train_crop_data, train_crop_segm = process_metacrop(
        train_meta_crop, data_volume_train, model.keras,
    )

    val_meta_crop = val_meta_crops[epoch]
    val_crop_data, val_crop_segm = process_metacrop(
        val_meta_crop, data_volume_val, model.keras,
    )

    snapshots.append(
        Snapshot(
            model=model,
            train_crop_data=train_crop_data,
            train_crop_segm=train_crop_segm,
            val_crop_data=val_crop_data,
            val_crop_segm=val_crop_segm,
        )
    )

In [None]:
# Every checkpoint predicts on the same fixed 256x256 window (first z slice,
# identity geometric transform), so the snapshots differ only by the model.
snapshots_single_crop = []

val_meta_crop = train_meta_crop = volume_sequence.MetaCrop3D(
    x=slice(512, 768),
    y=slice(512, 768),
    z=slice(0, 1),
    et=None,
    gt=volume_sequence.GT2D.identity,
    vs=0,
    is_2halfd=False,
)

for model in pbar(models):
    train_crop_data, train_crop_segm = process_metacrop(train_meta_crop, data_volume_train, model.keras)
    val_crop_data, val_crop_segm = process_metacrop(val_meta_crop, data_volume_val, model.keras)

    snapshots_single_crop.append(
        Snapshot(
            model=model,
            train_crop_data=train_crop_data,
            train_crop_segm=train_crop_segm,
            val_crop_data=val_crop_data,
            val_crop_segm=val_crop_segm,
        )
    )

In [None]:
# Every checkpoint predicts on the first recorded crop of each partition.
snapshots_single_crop_from_each_dataset = []

train_meta_crop, val_meta_crop = train_meta_crops[0], val_meta_crops[0]

for model in pbar(models):
    train_crop_data, train_crop_segm = process_metacrop(train_meta_crop, data_volume_train, model.keras)
    val_crop_data, val_crop_segm = process_metacrop(val_meta_crop, data_volume_val, model.keras)

    snapshots_single_crop_from_each_dataset.append(
        Snapshot(
            model=model,
            train_crop_data=train_crop_data,
            train_crop_segm=train_crop_segm,
            val_crop_data=val_crop_data,
            val_crop_segm=val_crop_segm,
        )
    )

# Plot

In [None]:
# todo: move me up
# training history read from the model's history file; feeds the history panel of each figure
hist_df = pd.read_csv(filepath_or_buffer=tomo2seg_model.history_path)

In [None]:
# todo: move me up
model_alias = "2D (f16)"  # short label shown in the figure titles
model_name = tomo2seg_model.name  # full name, used in the saved figure file names

In [None]:
from enum import Enum
from typing import *
from dataclasses import *
from tomo2seg.viz import Axes, check_matplotlib_support
from numpy import ndarray

@dataclass
class TrainingDisplay(t2s_viz.Display):
    """Plot training-history metrics against one or more x-axis scales.

    Structured inspired in `sklearn.metrics.RocCurveDisplay`: the constructor
    validates ``history`` and precomputes the curves (``xs_``/``ys_``), then
    :meth:`plot` draws them on a matplotlib ``Axes``.

    Attributes:
        history: column name -> list of per-epoch values (e.g. ``DataFrame.to_dict('list')``).
        x_axis_mode: one ``XAxisMode`` or a tuple of them. The first mode positions
            the curves and ticks; every mode adds one line to each tick label.
        metrics: history key(s) to plot; keys starting with ``val_`` are labelled
            "val", the others "train".
        model_name: optional name appended to the title.
        xs_: mode name -> x values. ys_: metric -> y values. ax_: last axes drawn on.
    """

    class XAxisMode(Enum):
        epoch = 0
        batch = 1
        crop = 2
        voxel = 3
        time = 4

    history: Dict[str, List]

    x_axis_mode: Union[XAxisMode, Tuple[XAxisMode]] = (XAxisMode.epoch,)
    metrics: Tuple[str] = ("loss", "val_loss")

    model_name: Optional[str] = None

    # not arguments
    xs_: dict = field(init=False)
    ys_: dict = field(init=False)
    ax_: Axes = field(init=False)

    def safe_get_from_history(self, key, assertion_types):
        """Return ``self.history[key]`` after checking it is a list of the expected type.

        Args:
            key: history column to fetch.
            assertion_types: a type or tuple of types; the first element of the
                list (if any) must be an instance of one of them.

        Raises:
            KeyError: ``key`` is missing (logged, then re-raised; the epoch
                fallback in ``_compute_x`` catches it).
            AssertionError: the value is not a list, or its first element has an
                unexpected type.
        """
        if not isinstance(assertion_types, tuple):
            assertion_types = (assertion_types,)

        try:
            should_be_list = self.history[key]
        except KeyError:
            msg = f"The history dict given to {self.__class__.__name__} does not have {key=}."
            logger.error(msg)
            raise

        assert isinstance(should_be_list, list), f"{type(should_be_list)=}"
        # an empty list has nothing to type-check; lengths are validated by the callers
        if should_be_list:
            assert isinstance(should_be_list[0], assertion_types), \
                f"{type(should_be_list[0])=} not in {assertion_types=}"

        return should_be_list

    def _compute_x(self, mod) -> np.ndarray:
        """Return the x values (one per epoch) for a single axis mode."""
        if mod == self.XAxisMode.epoch:
            try:
                x = self.safe_get_from_history("epoch", int)
            except KeyError as ex:
                n_epochs = len(self.history["loss"])
                logger.warning(
                    f"The history given to {self.__class__.__name__} has no {ex.args[0]!r} signal.\n"
                    f"Using a default sequence (0, 1, ..., {n_epochs - 1=})"
                )
                x = np.arange(n_epochs)  # was `np.range`, which does not exist
            x = np.array(x)

        elif mod == self.XAxisMode.batch:
            epoch_size = np.array(self.safe_get_from_history("train.epoch_size", int))
            x = np.cumsum(epoch_size)

        elif mod == self.XAxisMode.crop:
            epoch_size = np.array(self.safe_get_from_history("train.epoch_size", int))
            batch_size = np.array(self.safe_get_from_history("train.batch_size", int))
            x = np.cumsum(epoch_size * batch_size)

        elif mod == self.XAxisMode.voxel:
            epoch_size = np.array(self.safe_get_from_history("train.epoch_size", int))
            batch_size = np.array(self.safe_get_from_history("train.batch_size", int))
            train_crop_shape = self.safe_get_from_history("train.train_crop_shape", tuple)
            n_voxels = np.array([shape[0] * shape[1] * shape[2] for shape in train_crop_shape])
            x = np.cumsum(epoch_size * batch_size * n_voxels)

        elif mod == self.XAxisMode.time:
            seconds = np.array(self.safe_get_from_history("seconds", int))
            x = np.cumsum(seconds)

        else:
            raise NotImplementedError(f"{self.x_axis_mode=}")

        assert len(x) > 1, "You don't have enough epochs to plot. Go to the gym and call me later."
        return x

    def __post_init__(self):
        """Normalize scalar arguments to tuples, validate them and compute ``xs_``/``ys_``."""
        if not isinstance(self.x_axis_mode, tuple):
            self.x_axis_mode = (self.x_axis_mode,)

        for mod in self.x_axis_mode:
            assert isinstance(mod, self.XAxisMode), f"{type(mod)=} in {self.x_axis_mode=}"

        self.xs_ = {mod.name: self._compute_x(mod) for mod in self.x_axis_mode}

        if not isinstance(self.metrics, tuple):
            self.metrics = (self.metrics,)

        for met in self.metrics:
            assert isinstance(met, str), f"{type(met)=} in {self.metrics=}"

        self.ys_ = {
            met: np.array(self.safe_get_from_history(met, (int, float)))
            for met in self.metrics
        }

    @property
    def title(self) -> str:
        return (self.model_name or "") + ".training-plot"

    def _format_tick(self, val, mod) -> str:
        """Format one tick value for the given axis mode."""
        if mod in (self.XAxisMode.epoch, self.XAxisMode.batch):
            return str(int(val))
        if mod == self.XAxisMode.crop:
            return str(int(val / 1000)) + "k"
        if mod == self.XAxisMode.voxel:
            return humanize.intword(int(float(f"{val:.2g}")))
        if mod == self.XAxisMode.time:
            return humanize.time.naturaldelta(val, minimum_unit="seconds")
        return "err"

    def plot(
        self,
        ax: Axes,
        metric_kwargs: dict = None,
        val_metric_kwargs: dict = None,
        n_xticks: int = 11,
    ) -> "TrainingDisplay":
        """Draw the metric curves on ``ax`` and return ``self``.

        Args:
            ax: target matplotlib ``Axes``.
            metric_kwargs: extra ``ax.plot`` kwargs for training metrics.
            val_metric_kwargs: extra ``ax.plot`` kwargs for ``val_*`` metrics.
            n_xticks: number of evenly spaced x ticks.
        """
        check_matplotlib_support(f"{self.__class__.__name__}.plot")

        assert isinstance(ax, Axes), f"{type(ax)=}"

        # same attributes as sklearn's *Display objects
        self.ax_ = ax
        self.fig_ = ax.figure
        # metric -> list of Line2D; initialized here because nothing else creates it
        self.plots_ = {}

        # the first mode positions the curves; the other modes only annotate the ticks
        x = self.xs_[self.x_axis_mode[0].name]

        for metric_name in self.metrics:
            split = "val" if metric_name.startswith("val_") else "train"
            user_kwargs = (metric_kwargs if split == 'train' else val_metric_kwargs) or dict()
            effective_kwargs = {**dict(label=split), **user_kwargs}

            # noinspection PyArgumentList
            self.plots_[metric_name] = ax.plot(x, self.ys_[metric_name], **effective_kwargs)

        tick_locator = plt.LinearLocator(numticks=n_xticks)
        x_tickss = []
        for mod in self.x_axis_mode:
            mode_x = self.xs_[mod.name]
            x_tickss.append(tick_locator.tick_values(vmin=min(mode_x), vmax=max(mode_x)))

        ax.set_xticks(x_tickss[0])

        # one label line per mode, stacked with "\n"
        labels_per_mode = [
            [self._format_tick(val, mod) for val in ticks]
            for ticks, mod in zip(x_tickss, self.x_axis_mode)
        ]
        ax.set_xticklabels(["\n".join(strs) for strs in zip(*labels_per_mode)])

        # the old `f': {name}' or ''` was never empty, giving "training history: None"
        title_suffix = f": {self.model_name}" if self.model_name else ""
        ax.set_title(f"training history{title_suffix}")

        # was `{', '.join(...)}`, a set literal passed as the label
        ax.set_ylabel(", ".join(self.metrics))
        ax.set_xlabel("/".join(mod.name for mod in self.x_axis_mode))

        # losses tend to go down, so this should be a good position
        # notice that using default loc=None is slower
        ax.legend(loc="upper right")

        return self


In [None]:
def plot(snapshot, dir_):
    """Save a summary figure of ``snapshot`` as ``dir_ / f"{model_name}.epoch{epoch:03d}.png"``.

    Layout (3 rows x 2 columns):
      - rows 0 and 1: the train and val crops with the model's prediction
        (``t2s_viz.SliceDataPredictionDisplay``);
      - row 2: merged into one wide axes with the training history (log scale)
        and a dashed vertical line at the snapshot's epoch.

    Reads the notebook globals ``hist_df``, ``model_alias``, ``model_name`` and
    ``n_classes``.

    Args:
        snapshot: a ``Snapshot`` (checkpoint + crop data and predictions).
        dir_: output directory (``Path``).
    """
    nrows, ncols, sz = 3, 2, 5
    fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * sz, nrows * sz), dpi=100)

    try:
        # merge the bottom row into a single axes spanning both columns
        hist_gs = axs[2, 0].get_gridspec()
        for ax in axs[2, :]:
            ax.remove()
        hist_ax = fig.add_subplot(hist_gs[2, :])

        TrainingDisplay(
            history=hist_df.to_dict('list'),
            model_name=model_alias,
        ).plot(ax=hist_ax)

        hist_ax.set_yscale("log")
        # axvline spans the full axes height; the former vlines(ymin=0, ymax=1) was in
        # data coordinates, so on the log axis it could stretch the y-limits and did
        # not reach losses above 1
        hist_ax.axvline(snapshot.model.epoch, linestyle='--', color='gray')

        t2s_viz.SliceDataPredictionDisplay(
            slice_data=snapshot.train_crop_data.squeeze(),  # todo move the squeeze to the processing
            slice_prediction=snapshot.train_crop_segm,
            slice_name="train",
            n_classes=n_classes,
        ).plot(axs[0, :])
        axs[0, 1].set_title(f"train (epoch={snapshot.model.epoch})")

        t2s_viz.SliceDataPredictionDisplay(
            slice_data=snapshot.val_crop_data.squeeze(),  # todo move the squeeze to the processing
            slice_prediction=snapshot.val_crop_segm,
            slice_name="val",
            n_classes=n_classes,
        ).plot(axs[1, :])
        axs[1, 1].set_title(f"val (epoch={snapshot.model.epoch})")

        fig.suptitle(f"{model_alias} (epoch={snapshot.model.epoch})")

        filepath = dir_ / f"{model_name}.epoch{snapshot.model.epoch:03d}.png"
        fig.savefig(filepath, format="png")
    finally:
        # close this figure explicitly (also on error) so looping over many
        # snapshots does not accumulate open figures
        plt.close(fig)


In [None]:
# one figure per checkpoint, on the crops of its own epoch
for snap in snapshots:
    plot(snapshot=snap, dir_=output_files.snapshots_dir)

In [None]:
# one figure per checkpoint, all on the same fixed crop
for snap in snapshots_single_crop:
    plot(snapshot=snap, dir_=output_files.snapshots_single_crop_dir)

In [None]:
# one figure per checkpoint, on the first recorded crop of each partition
for snap in snapshots_single_crop_from_each_dataset:
    plot(snapshot=snap, dir_=output_files.snapshots_single_crop_from_each_dataset_dir)

In [None]:
import os
import zipfile

In [None]:
def zipdir(path, ziph):
    """Add every file below ``path`` to the open ``zipfile.ZipFile`` ``ziph``.

    Archive names are relative to the parent of ``path``, so they start with
    the directory's own name (e.g. ``snapshots/x.png``).
    """
    anchor = os.path.join(path, '..')
    for dirpath, _subdirs, filenames in os.walk(path):
        for filename in filenames:
            full_path = os.path.join(dirpath, filename)
            ziph.write(full_path, os.path.relpath(full_path, anchor))
  

In [None]:
# Archive each snapshot directory next to itself as `<dir name>.zip`.
for dir_ in [
    output_files.snapshots_dir,
    output_files.snapshots_single_crop_dir,
    output_files.snapshots_single_crop_from_each_dataset_dir,
]:
    zip_ = dir_.parent / (dir_.name + ".zip")
    # the context manager closes (finalizes) the archive even if zipdir raises;
    # before, an exception left the handle open and the archive unfinished
    with zipfile.ZipFile(zip_, 'w', zipfile.ZIP_DEFLATED) as zipf:
        zipdir(dir_, zipf)