In [1]:
import sys, os
sys.path
from os.path import join, abspath
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import tensor
import json
import random
import yaml
from addict import Dict
import ast
from utils import ExpandStrArray

In [None]:
def Dashboard(train_report: pd.DataFrame, eval_report: pd.DataFrame, window_size: int=100):
    """Plot a 2x3 dashboard of training and evaluation metrics, then show it.

    Panels (row, col):
        (0, 0) per-block loss with a moving average;
        (0, 1) per-block cancer F1 / TPR / TNR with moving averages;
        (0, 2) (FPR, TPR) scatter for train blocks and eval rows;
        (1, 0) learning rate per block;
        (1, 1) per-epoch F1 scores (eval, balanced eval, train);
        (1, 2) per-epoch TPR / TNR (eval and train).

    Args:
        train_report: training log. Before processing it is passed to
            ``CreateStats`` and ``StrListColumnToArray`` (column
            ``f1_scores``). After ``ArrayToColumns`` the body reads
            ``loss``, ``learning_rate``, ``tpr``, ``fpr``, ``precision``,
            ``tnr``, ``epoch``, ``block`` and the per-task F1 columns.
            NOTE(review): the original annotation was ``str``, but the
            processed report is used as a DataFrame. Presumably callers pass
            DataFrames — confirm.
        eval_report: evaluation log, processed the same way. The body reads
            ``predictions`` and ``truths`` (parsed with
            ``ast.literal_eval``), plus ``tpr``, ``fpr``, ``tnr``, ``epoch``
            and the per-task F1 columns.
        window_size: window length passed to ``MovingAvg`` for the smoothed
            per-block curves.

    Returns:
        None. The figure is displayed via ``plt.show()``.

    Notes:
        Depends on helpers defined outside this cell: ``CreateStats``,
        ``StrListColumnToArray``, ``ArrayToColumns``, ``MovingAvg`` and
        ``BalancedF1Score``.
    """
    f1_cols = ["cancer_f1", "lat_f1", "view_f1", "implant_f1"]

    # CreateStats' return value is discarded, so it is called for its effect
    # on the report objects.
    CreateStats(train_report, ["predictions", "truths"], 0)
    CreateStats(eval_report, ["predictions", "truths"], 0)

    # Split the serialized per-task F1 list into one column per task.
    train_f1 = StrListColumnToArray(train_report, "f1_scores")
    eval_f1 = StrListColumnToArray(eval_report, "f1_scores")
    train_report = ArrayToColumns(train_report, train_f1, f1_cols)
    eval_report = ArrayToColumns(eval_report, eval_f1, f1_cols)

    # Smoothed per-block curves for the top-row panels.
    loss_ma = MovingAvg(train_report.loss, window_size)
    bf1_ma = MovingAvg(train_report.cancer_f1, window_size)
    tpr_ma = MovingAvg(train_report.tpr, window_size)
    tnr_ma = MovingAvg(train_report.tnr, window_size)

    sel_cols = ["learning_rate", "tpr", "fpr", "precision", "tnr"] + f1_cols
    train_epoch_stats = train_report.groupby("epoch").mean(numeric_only=True)[sel_cols]

    # Balanced F1 per eval row. Prediction/truth column 0 is plotted as
    # "cancer" and column 1 as "lat". The columns are iterated positionally:
    # a label lookup `eval_report.loc[i, ...]` with `i in range(len(...))`
    # assumes a 0..n-1 RangeIndex. It fails, or reads the wrong rows, on a
    # report concatenated from several files with repeated index labels.
    balf1scores = []
    balf1latscores = []
    for preds_str, truths_str in zip(eval_report["predictions"], eval_report["truths"]):
        preds = np.array(ast.literal_eval(preds_str))
        gt = np.array(ast.literal_eval(truths_str))
        balf1scores.append(BalancedF1Score(preds[:, 0], gt[:, 0]))
        balf1latscores.append(BalancedF1Score(preds[:, 1], gt[:, 1]))

    lim = len(eval_report.index)
    epoch_ticks = eval_report.epoch.astype(np.int16)
    # Per-epoch training curves are paired row-by-row with the eval epoch
    # ticks. Clip both to the shorter length so that a run with fewer trained
    # epochs than eval rows still plots instead of raising a length-mismatch
    # error. `.iloc` keeps the slice positional (plain `[:n]` is label-based
    # on a float index).
    n_train = min(lim, len(train_epoch_stats.index))
    train_ticks = epoch_ticks.iloc[:n_train]
    train_epoch_slice = train_epoch_stats.iloc[:n_train]

    fig, axs = plt.subplots(2, 3, figsize=(12,8))
    fig.suptitle("Training Log")

    # Per-block loss and its moving average.
    ax = axs[0, 0]
    ax.set_title("Loss")
    ax.set_xlabel("Block")
    ax.set_ylabel("Loss")
    ax.scatter(train_report.index, train_report.loss, s=0.3, c="red", label="value")
    ax.plot(train_report.index, loss_ma, "-b", label="moving avg")
    ax.legend()

    # Per-block cancer F1 / TPR / TNR with moving averages.
    ax = axs[0, 1]
    ax.set_title("Block Statistics")
    ax.set_xlabel("Block")
    ax.set_ylabel("Score")
    ax.set_ylim([-0.05, 1.05])
    ax.scatter(train_report.index, train_report.cancer_f1, s=0.5, c="red", label="cancer_f1_score")
    ax.scatter(train_report.index, train_report.tpr, marker="*", s=0.5, c="blue", label="true_pos_rate")
    ax.scatter(train_report.index, train_report.tnr, marker="^", s=0.5, c="green", label="true_neg_rate")
    ax.plot(train_report.index, bf1_ma, "-c")
    ax.plot(train_report.index, tpr_ma, "-y")
    ax.plot(train_report.index, tnr_ma, "-m")
    ax.legend()

    # (FPR, TPR) points for train blocks and eval rows, with the chance diagonal.
    ax = axs[0, 2]
    ax.set_title("Training AUC")
    ax.set_xlabel("1 - Specificity (FPR)")
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])
    ax.set_ylabel("Sensitivity (TPR)")
    ax.scatter(train_report.fpr, train_report.tpr, s=1., c="red", label="train")
    ax.scatter(eval_report.fpr, eval_report.tpr, marker="*", s=10.,
               c="blue", label="eval")
    ax.plot([0., 1.], [0., 1.], "-k", linewidth=0.5)
    ax.legend()

    # Learning-rate schedule per block.
    ax = axs[1, 0]
    ax.set_title("Learning Rate")
    ax.set_xlabel("Block")
    ax.set_ylabel("Learning Rate")
    ax.plot(train_report.block, train_report.learning_rate, "-r", linewidth=1.)

    # Per-epoch F1 scores.
    ax = axs[1, 1]
    ax.set_title("F1 Scores")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Score")
    ax.set_ylim([-0.05, 1.05])
    for col, fmt, label in (("cancer_f1", "-r", "eval_cancer_f1"),
                            ("lat_f1", "-b", "eval_lat_f1"),
                            ("view_f1", "-g", "eval_view_f1"),
                            ("implant_f1", "-y", "eval_imp_f1")):
        ax.plot(epoch_ticks, eval_report[col], fmt, linewidth=1.5, label=label)
    ax.plot(epoch_ticks, balf1scores, ".r",
            linewidth=1.5, label="bal_cancer_eval_f1")
    ax.plot(epoch_ticks, balf1latscores, ".b",
            linewidth=1.5, label="bal_lat_eval_f1")
    ax.plot(train_ticks, train_epoch_slice.cancer_f1, "--r",
            linewidth=1.5, label="train_cancer_f1")
    ax.plot(train_ticks, train_epoch_slice.lat_f1, "--b",
            linewidth=1.5, label="train_lat_f1")
    ax.legend()

    # Per-epoch TPR and TNR, eval vs train.
    ax = axs[1, 2]
    ax.set_title("Other Statistics")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Score")
    ax.set_ylim([-0.05, 1.05])
    ax.plot(epoch_ticks, eval_report.tpr, "-g",
            linewidth=1.5, label="eval_tpr")
    ax.plot(train_ticks, train_epoch_slice.tpr, "--g",
            linewidth=1., label="train_tpr")
    ax.plot(epoch_ticks, eval_report.tnr, "-y",
            linewidth=1.5, label="eval_tnr")
    ax.plot(train_ticks, train_epoch_slice.tnr, "--y",
            linewidth=1., label="train_tnr")
    ax.legend()

    plt.tight_layout()
    plt.show()

In [80]:
# Load the per-rank training/eval CSV reports (one `rank{i}.csv` per rank)
# and stack them into single frames.
# NOTE(review): absolute, user-specific paths — consider making these configurable.
train_rep_dir = "/home/isaiah/TotalSegmentator/results/20230323/nnunet_train_reps_02/"
eval_rep_dir = "/home/isaiah/TotalSegmentator/results/20230323/nnunet_eval_reps_02/"
N_RANKS = 5  # number of rank{i}.csv files expected in each directory

tdfs = [pd.read_csv(join(train_rep_dir, f"rank{i}.csv")) for i in range(N_RANKS)]
edfs = [pd.read_csv(join(eval_rep_dir, f"rank{i}.csv")) for i in range(N_RANKS)]

# ignore_index=True gives the stacked frame a fresh 0..n-1 index. Without it,
# every rank's rows repeat the same labels, which breaks label-based lookups
# such as `.loc[i, col]`.
tdf = pd.concat(tdfs, axis=0, ignore_index=True)
edf = pd.concat(edfs, axis=0, ignore_index=True)

# numeric_only=True: the reports appear to carry string-encoded list columns
# (Dashboard parses `predictions`/`truths`/`f1_scores` from them). A plain
# .mean() raises TypeError on such columns in pandas >= 2.0.
# NOTE(review): uint8 wraps around for epochs > 255.
dfg_train = tdf.groupby("block").mean(numeric_only=True).astype({"epoch": np.uint8})

# NOTE(review): `ExpandStrArrayColumns` is not defined or imported in the visible
# cells (only `ExpandStrArray` is imported from utils) — confirm it exists.
efg_train = ExpandStrArrayColumns(edf, "epoch")