In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from bokeh.embed import json_item
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure
from bokeh.transform import dodge
from omegaconf import OmegaConf
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

sys.path.append("..")
from src.dataset.vtt import CATEGORIES, TOPICS, VTTDataset  # noqa: E402
from src.utils.datatool import read_json, read_jsonlines  # noqa: E402
from src.utils.plottool import matplotlib_header  # noqa: E402

# Dataset used by every cell below. "test" is presumably the split name and
# normalisation is switched off via transform_cfg -- confirm against VTTDataset.
dataset = VTTDataset(
    "test", return_raw_text=True, transform_cfg={"normalize": False}
)


# Root directory scanned by get_exp_ids() for experiment run folders.
LOG_ROOT = "/log/exp/vtt/"
# exp name -> records read from that run's detail.jsonl (filled lazily by get_text_table).
details_cache = {}
# exp name -> run directory name (filled by get_exp_ids).
exp_ids_cache = {}
# exp name -> content of that run's wandb-summary.json (filled by get_exp_ids).
results_cache = {}

def get_exp_ids():
    """Discover experiment runs under LOG_ROOT and register them in the caches.

    A run directory is considered only if it contains ``detail.jsonl``. Its
    display name is ``config.name`` from ``config.yaml``; when that name is
    already registered for a different run directory, ``_1``, ``_2``, ... is
    appended until the name is free (or already belongs to this run).

    Side effects:
        * ``exp_ids_cache[name]`` is set to the run directory name.
        * ``results_cache[name]`` is set to the parsed wandb summary when
          ``wandb/latest-run/files/wandb-summary.json`` exists.

    Returns:
        list[str]: experiment names, sorted in descending order of the last
        dot-separated component of the run directory name (presumably a
        timestamp -- confirm against the logging convention).
    """
    found = []
    for run_dir in sorted(Path(LOG_ROOT).glob("*")):
        if not (run_dir / "detail.jsonl").exists():
            continue

        config = OmegaConf.load(run_dir / "config.yaml")
        name = config.name
        run_id = run_dir.name
        run_time = run_id.split(".")[-1]

        # Resolve name clashes between different runs sharing one config name.
        suffix = 1
        while name in exp_ids_cache and exp_ids_cache[name] != run_id:
            name = f"{config.name}_{suffix}"
            suffix += 1

        exp_ids_cache[name] = run_id
        found.append((name, run_time))

        summary_path = run_dir.joinpath(
            "wandb", "latest-run", "files", "wandb-summary.json"
        )
        if summary_path.exists():
            results_cache[name] = read_json(summary_path)

    found.sort(key=lambda pair: pair[1], reverse=True)
    return [name for name, _ in found]


def index2result(index, exp_names=[]):
    """Collect everything needed to display one test sample.

    ``index`` is converted with ``int`` and clamped into
    ``[0, len(dataset) - 1]``.

    Returns:
        tuple: (clamped index, category name, topic name, path of the cached
        sample image, text table, per-sample metrics figure, per-sample
        metrics table, classification table, overall metrics figure,
        overall metrics table).
    """
    index = int(index)
    last = len(dataset) - 1
    if index > last:
        index = last
    if index < 0:
        index = 0

    data = dataset[index]

    # get_text_table must run first: it loads details_cache for the
    # per-sample helpers called below.
    text_table = get_text_table(index, data, exp_names)
    metrics_plot = get_metrics_pyplot(index, exp_names)
    metrics_table = get_metrics_table(index, exp_names)
    pred_classification_table = get_classification_table(index, exp_names)
    overall_metrics_plot = get_overall_metrics_pyplot(exp_names)
    overall_metrics_table = get_overall_metrics_table(exp_names)

    category = CATEGORIES[data["category"]]
    topic = TOPICS[data["topic"]]
    image_path = cache_test_image(data)

    return (
        index,
        category,
        topic,
        image_path,
        text_table,
        metrics_plot,
        metrics_table,
        pred_classification_table,
        overall_metrics_plot,
        overall_metrics_table,
    )


def random_result(exp_names=[]):
    """Return ``index2result`` for a uniformly drawn sample index."""
    n_samples = len(dataset)
    return index2result(np.random.randint(n_samples), exp_names)


def cache_test_image(data, cache_dir="/data/vtt/cache/"):
    """Return the path of the cached image strip for ``data``, creating it if needed.

    The file is ``<cache_dir>/test_<data['index']>.png``. When it does not
    exist yet, the entries of ``data["states"]`` selected by
    ``data["states_mask"]`` are written side by side (one row) with
    ``save_image``.

    Returns:
        str: path of the cached PNG.
    """
    root = Path(cache_dir)
    root.mkdir(parents=True, exist_ok=True)
    target = root / f"test_{data['index']}.png"
    if target.exists():
        return str(target)

    frames = data["states"][data["states_mask"]]
    save_image(frames, str(target), nrow=frames.size(0), pad_value=1.0)
    return str(target)


def get_text_table(index, data, exp_names):
    """Build a table of ground-truth and predicted sentences for one sample.

    Columns are ``NO.`` (1-based row number), ``GT`` (``data["text"]``) and one
    column per experiment holding ``preds`` of record ``index`` from that
    run's ``detail.jsonl``. Records are read once per experiment and kept in
    ``details_cache``.

    Returns:
        pandas.DataFrame
    """
    columns = {
        "NO.": [n + 1 for n in range(len(data["text"]))],
        "GT": data["text"],
    }
    for name in exp_names:
        if name in details_cache:
            records = details_cache[name]
        else:
            run_id = exp_ids_cache[name]
            records = read_jsonlines(f"{LOG_ROOT}/{run_id}/detail.jsonl")
            details_cache[name] = records
        columns[name] = records[index]["preds"]
    return pd.DataFrame(columns)


def get_metrics_pyplot(index, exp_names):
    """Draw a grouped bar chart of per-sample metrics, one bar per experiment.

    For every metric the mean of ``details_cache[exp][index][metric]`` is
    plotted. ``details_cache`` must already contain every experiment in
    ``exp_names`` (``get_text_table`` loads it).

    Args:
        index: position of the sample in each experiment's detail records.
        exp_names: experiment names to compare; may be empty, in which case
            an empty chart with labelled ticks is returned.

    Returns:
        matplotlib.figure.Figure: a new figure; the caller owns it and is
        responsible for closing it.
    """
    matplotlib_header(1 / 3)
    plt.rcParams["legend.fontsize"] = 12

    metrics = ["BLEU_4", "METEOR", "ROUGE", "CIDEr", "BERTScore"]
    fig, ax = plt.subplots()

    x = np.arange(len(metrics))
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)

    n_exp = len(exp_names)
    if n_exp == 0:
        # Nothing to plot; also avoids dividing by zero below.
        return fig

    width = min((1 - 0.1) / n_exp, 0.2)
    for i, exp_name in enumerate(exp_names):
        record = details_cache[exp_name][index]
        ax.bar(
            # Centre the group of n_exp bars on each metric tick.
            x + width * (i - n_exp / 2 + 0.5),
            [np.mean(record[key]) for key in metrics],
            width=width,
            label=exp_name,
        )

    ax.legend()
    return fig


def get_metrics_bokeh(index, exp_names):
    """Build a grouped Bokeh bar chart of per-sample metrics.

    Experiments are the categories on the x axis; within each category there
    is one bar per metric showing the mean of
    ``details_cache[exp][index][metric]``. ``details_cache`` must already
    contain every experiment in ``exp_names``.

    Args:
        index: position of the sample in each experiment's detail records.
        exp_names: experiment names used as x-axis categories.

    Returns:
        The ``json_item`` representation of the figure.
    """
    metrics = ["BLEU_4", "METEOR", "ROUGE", "CIDEr", "BERTScore"]
    results = {
        "exp": exp_names,
    }
    results.update(
        {
            key: [
                np.mean(details_cache[exp_name][index][key])
                for exp_name in exp_names
            ]
            for key in metrics
        }
    )
    df = pd.DataFrame(results)
    source = ColumnDataSource(data=df)

    # NOTE(review): y_range is fixed to (0, 1); any metric whose mean exceeds 1
    # would be clipped -- confirm this range is intended.
    p = figure(
        x_range=exp_names,
        y_range=(0, 1),
        title="",
        height=350,
        toolbar_location=None,
        tools="",
    )

    n_metrics = len(metrics)
    width = min((1 - 0.1) / n_metrics, 0.2)
    for i, metric in enumerate(metrics):
        p.vbar(
            # Offsets are symmetric around 0 so the bar group is centred on
            # its category (the former `i - n_metrics / 2` shifted the whole
            # group half a bar to the left).
            x=dodge("exp", (i - (n_metrics - 1) / 2) * width, range=p.x_range),
            top=metric,
            width=width,
            source=source,
            legend_label=metric,
        )

    p.x_range.range_padding = 0.1
    p.xgrid.grid_line_color = None
    p.legend.location = "top_left"
    p.legend.orientation = "horizontal"
    return json_item(p)


def get_classification_table(index, exp_names):
    """Tabulate the ``category_pred`` / ``topic_pred`` fields of one sample.

    One row per experiment. A field missing from an experiment's record is
    shown as a single space. ``details_cache`` must already contain every
    experiment in ``exp_names``.

    Returns:
        pandas.DataFrame with columns ``Exp``, ``category_pred``, ``topic_pred``.
    """
    table = {"Exp": exp_names}
    for field in ("category_pred", "topic_pred"):
        column = []
        for name in exp_names:
            record = details_cache[name][index]
            column.append(record[field] if field in record else " ")
        table[field] = column
    return pd.DataFrame(table)


def get_metrics_table(index, exp_names):
    """Tabulate the mean per-sample metrics of each experiment.

    One row per experiment; each metric column holds ``np.mean`` of
    ``details_cache[exp][index][metric]``. ``details_cache`` must already
    contain every experiment in ``exp_names``.

    Returns:
        pandas.DataFrame with columns ``Exp`` plus the five metric names.
    """
    table = {"Exp": exp_names}
    for metric in ("BLEU_4", "METEOR", "ROUGE", "CIDEr", "BERTScore"):
        table[metric] = [
            np.mean(details_cache[name][index][metric]) for name in exp_names
        ]
    return pd.DataFrame(table)


def get_overall_metrics_table(exp_names):
    """Tabulate the ``test/<metric>`` values stored in ``results_cache``.

    One row per experiment. A value is 0.0 when the experiment has no entry
    in ``results_cache`` or the entry lacks that metric.

    Returns:
        pandas.DataFrame with columns ``Exp`` plus the five metric names.
    """

    def lookup(name, metric):
        summary_key = f"test/{metric}"
        if name in results_cache and summary_key in results_cache[name]:
            return results_cache[name][summary_key]
        return 0.0

    table = {"Exp": exp_names}
    for metric in ("BLEU_4", "METEOR", "ROUGE", "CIDEr", "BERTScore"):
        table[metric] = [lookup(name, metric) for name in exp_names]
    return pd.DataFrame(table)


def get_overall_metrics_pyplot(exp_names):
    """Draw a grouped bar chart of the ``test/<metric>`` values per experiment.

    Values come from ``results_cache``; a bar is 0.0 when the experiment has
    no entry there or the entry lacks that metric.

    Args:
        exp_names: experiment names to compare; may be empty, in which case
            an empty chart with labelled ticks is returned.

    Returns:
        matplotlib.figure.Figure: a new figure; the caller owns it and is
        responsible for closing it.
    """
    matplotlib_header(1 / 3)
    plt.rcParams["legend.fontsize"] = 12

    metrics = ["BLEU_4", "METEOR", "ROUGE", "CIDEr", "BERTScore"]
    fig, ax = plt.subplots()

    x = np.arange(len(metrics))
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)

    n_exp = len(exp_names)
    if n_exp == 0:
        # Nothing to plot; also avoids dividing by zero below.
        return fig

    width = min((1 - 0.1) / n_exp, 0.2)
    for i, exp_name in enumerate(exp_names):
        summary = results_cache[exp_name] if exp_name in results_cache else {}
        ax.bar(
            # Centre the group of n_exp bars on each metric tick.
            x + width * (i - n_exp / 2 + 0.5),
            [
                summary[f"test/{key}"] if f"test/{key}" in summary else 0.0
                for key in metrics
            ],
            width=width,
            label=exp_name,
        )

    ax.legend()
    return fig

In [None]:
# Scan LOG_ROOT once; besides returning the experiment names this fills
# exp_ids_cache and results_cache, which the cells below rely on.
choices = get_exp_ids()

In [None]:
# Experiments compared in the per-sample figures. Every name must be a key of
# exp_ids_cache (i.e. one of the names returned by get_exp_ids()).
exp_names = [
    'baseline_cst_clip',
    'baseline_glacnet_ViT-L/14',
    'baseline_densecap_norm_zero',
    'ttnet_sota_v5_base',
    'ttnet_sota_v5_0.15_0.5_zero_wclass_0.25_wcat_0.1',
]

In [None]:
# Show which run directory each selected experiment name resolves to
# (raises KeyError for a name that get_exp_ids() did not register).
for exp_name in exp_names:
    print(f"{exp_name}: {exp_ids_cache[exp_name]}")

In [None]:
# Experiment name -> short display name used for table columns and legends.
RENAME = {
        "baseline_cst_clip": "CST", "baseline_glacnet_ViT-L/14": "GLACNet", "baseline_densecap_norm_zero": "DenseCap",
        "ttnet_sota_v5_base": "TTNet base", "ttnet_sota_v5_0.15_0.5_zero_wclass_0.25_wcat_0.1": "TTNet"
}
def show_result(index):
    """Render one sample as a four-row figure and save it as a JPEG.

    Rows: (0) grouped bars of the mean per-sample metrics of every experiment
    in the global ``exp_names``, (1) the cached image strip of the sample,
    (2) a table with the ground truth and the CST / GLACNet predictions,
    (3) a table with the DenseCap / TTNet base / TTNet predictions.

    The figure is written to
    ``sample/vtt_cases/example_<index>_<category>_<topic>.jpg`` (the directory
    is created if missing) and then closed.

    Args:
        index: position of the sample in ``dataset``.

    Returns:
        pandas.DataFrame: the text table with experiment columns renamed
        through ``RENAME``.

    Raises:
        KeyError: if the global ``exp_names`` lacks one of the five
            experiments whose renamed columns the tables select.
    """
    data = dataset[index]

    # Also fills details_cache for every experiment, which the bars need.
    df = get_text_table(index, data, exp_names)
    df = df.rename(columns=RENAME)

    # rcParams must be set BEFORE the figure is created; otherwise the first
    # figure produced in a fresh kernel still uses the previous figsize.
    plt.rcParams["figure.figsize"] = (16, 16)
    plt.rcParams["legend.fontsize"] = 12
    fig, ax = plt.subplots(nrows=4, ncols=1)

    try:
        metrics = ["BLEU_4", "METEOR", "ROUGE", "CIDEr", "BERTScore"]
        n_exp = len(exp_names)
        width = min((1 - 0.1) / n_exp, 0.2)

        x = np.arange(len(metrics))
        for i, exp_name in enumerate(exp_names):
            record = details_cache[exp_name][index]
            ax[0].bar(
                # Centre the group of n_exp bars on each metric tick.
                x + width * (i - n_exp / 2 + 0.5),
                [np.mean(record[key]) for key in metrics],
                width=width,
                label=RENAME[exp_name],
            )
        ax[0].set_xticks(x, metrics)
        ax[0].legend()

        img = mpimg.imread(cache_test_image(data))
        ax[1].imshow(img)

        df_1 = df.loc[:, ["NO.", "GT", "CST", "GLACNet"]]
        ax[2].table(cellText=df_1.values, colLabels=df_1.columns, loc='center', bbox=[0,0,1,1], colWidths=[0.1,0.3,0.3,0.3])
        df_2 = df.loc[:, ["NO.", "DenseCap", "TTNet base", "TTNet"]]
        ax[3].table(cellText=df_2.values, colLabels=df_2.columns, loc='center', bbox=[0,0,1,1], colWidths=[0.1,0.3,0.3,0.3])

        # Image and table panels carry no meaningful axes.
        for panel in ax[1:]:
            panel.set_xticks([])
            panel.set_yticks([])

        Path("sample/vtt_cases").mkdir(parents=True, exist_ok=True)
        fig.savefig(f"sample/vtt_cases/example_{index}_{CATEGORIES[data['category']]}_{TOPICS[data['topic']]}.jpg")
    finally:
        # Always release the figure, even when rendering or saving fails;
        # this function is called once per dataset sample.
        plt.close(fig)
    return df

In [None]:
# Smoke test on a single sample before rendering the whole dataset.
df = show_result(1)

In [None]:
# Render and save the comparison figure for every sample in the dataset.
for i in range(len(dataset)):
    show_result(i)

In [None]:
def get_metrics_table(index, exp_names):
    """Tabulate the mean per-sample metrics, labelling rows with display names.

    This redefinition shadows the earlier ``get_metrics_table`` in this
    notebook; the only difference is that the ``Exp`` column shows the
    ``RENAME`` display name. Experiments without an entry in ``RENAME`` keep
    their raw name instead of raising ``KeyError``, so callers that pass
    arbitrary experiment names (e.g. ``index2result``) keep working.

    ``details_cache`` must already contain every experiment in ``exp_names``.

    Args:
        index: position of the sample in each experiment's detail records.
        exp_names: experiment names, one table row each.

    Returns:
        pandas.DataFrame with columns ``Exp`` plus the five metric names.
    """
    metrics = ["BLEU_4", "METEOR", "ROUGE", "CIDEr", "BERTScore"]
    results = {
        "Exp": [RENAME.get(exp_name, exp_name) for exp_name in exp_names],
    }
    results.update(
        {
            key: [
                np.mean(details_cache[exp_name][index][key])
                for exp_name in exp_names
            ]
            for key in metrics
        }
    )
    return pd.DataFrame(results)

In [None]:
# Rebinds the global exp_names to four experiments for the case mining below;
# the last entry (row 3 of the metrics table) is the one RENAME maps to "TTNet".
# NOTE(review): show_result() selects a "TTNet base" column, so it will raise
# KeyError if re-run after this cell.
exp_names = [
    'baseline_cst_clip',
    'baseline_glacnet_ViT-L/14',
    'baseline_densecap_norm_zero',
    'ttnet_sota_v5_0.15_0.5_zero_wclass_0.25_wcat_0.1',
]

In [None]:
import os

In [None]:
def good_cases():
    """Symlink the figures of samples where row 3 wins on CIDEr with a mean above 6.

    For every sample the metrics table of the global ``exp_names`` is built.
    A sample is selected when the highest CIDEr belongs to row 3 and that
    value exceeds 6. Row 3 is the last experiment only for a four-entry
    ``exp_names``; with fewer rows nothing is ever selected.

    For each selected sample the figure name is printed and a relative
    symlink ``sample/vtt_good_cases/<name> -> ../vtt_cases/<name>`` is
    created. The link directory is created if missing and existing links are
    left untouched, so the function can be re-run safely.

    Returns:
        pandas.DataFrame | None: the metrics table of the last sample, or
        ``None`` when the dataset is empty.
    """
    link_dir = os.path.join("sample", "vtt_good_cases")
    os.makedirs(link_dir, exist_ok=True)

    df = None
    for index in range(len(dataset)):
        df = get_metrics_table(index, exp_names)
        if df.CIDEr.argmax() == 3 and df.CIDEr[3] > 6:
            # Load the sample only when needed for the file name.
            data = dataset[index]
            name = f"example_{index}_{CATEGORIES[data['category']]}_{TOPICS[data['topic']]}.jpg"
            print(name)
            link_path = os.path.join(link_dir, name)
            # lexists also detects dangling links left by a previous run.
            if not os.path.lexists(link_path):
                os.symlink(os.path.join("..", "vtt_cases", name), link_path)
    return df

In [None]:
# Create the good-case symlinks; df is the metrics table of the last sample visited.
df = good_cases()

In [None]:
# CIDEr value at row label 3 of the table returned by good_cases().
df.CIDEr[3]

In [None]:
# Position of the row with the highest CIDEr in that table.
df.CIDEr.argmax()

In [None]:
# "test/<metric>" values from results_cache for the selected experiments
# (0.0 where a value is missing).
get_overall_metrics_table(exp_names)

In [None]:
def bad_cases():
    """Symlink the figures of samples where row 3 wins on CIDEr with a mean below 6.

    Mirror of ``good_cases`` with the threshold inverted: a sample is selected
    when the highest CIDEr belongs to row 3 of the metrics table of the
    global ``exp_names`` and that value is below 6.

    For each selected sample the figure name is printed and a relative
    symlink ``sample/vtt_bad_cases/<name> -> ../vtt_cases/<name>`` is created.
    The links previously went into ``sample/vtt_good_cases``, a copy-paste
    slip that mixed bad cases into the good ones. The link directory is
    created if missing and existing links are left untouched.

    NOTE(review): the selection still requires row 3 to be the best one;
    confirm whether "bad" was meant to be ``argmax() != 3`` instead.

    Returns:
        pandas.DataFrame | None: the metrics table of the last sample, or
        ``None`` when the dataset is empty.
    """
    link_dir = os.path.join("sample", "vtt_bad_cases")
    os.makedirs(link_dir, exist_ok=True)

    df = None
    for index in range(len(dataset)):
        df = get_metrics_table(index, exp_names)
        if df.CIDEr.argmax() == 3 and df.CIDEr[3] < 6:
            # Load the sample only when needed for the file name.
            data = dataset[index]
            name = f"example_{index}_{CATEGORIES[data['category']]}_{TOPICS[data['topic']]}.jpg"
            print(name)
            link_path = os.path.join(link_dir, name)
            # lexists also detects dangling links left by a previous run.
            if not os.path.lexists(link_path):
                os.symlink(os.path.join("..", "vtt_cases", name), link_path)
    return df

In [None]:
# Rebinds the global exp_names to the three experiments shown in the LaTeX export below.
exp_names = [
    'baseline_glacnet_ViT-L/14',
    'baseline_densecap_norm_zero',
    'ttnet_sota_v5_0.15_0.5_zero_wclass_0.25_wcat_0.1',
]

In [None]:
# Rebinds RENAME: same experiment display names as before, plus "GT" ->
# "Groundtruth" so the ground-truth column of the text table is renamed too.
RENAME = {
        "baseline_cst_clip": "CST", "baseline_glacnet_ViT-L/14": "GLACNet", "baseline_densecap_norm_zero": "DenseCap",
        "ttnet_sota_v5_base": "TTNet base", "ttnet_sota_v5_0.15_0.5_zero_wclass_0.25_wcat_0.1": "TTNet", "GT": "Groundtruth"
}

In [None]:
# Hand-picked dataset indices exported as LaTeX in the loop below.
great_cases = [143, 220, 262, 295, 308, 358, 412, 1359, 197, 208, 304]

In [None]:
# Removed a stray `model` expression: the name is only bound by the export
# loop in a later cell, so evaluating it here raised NameError on a fresh
# kernel (Restart & Run All).

In [None]:
# Print one LaTeX snippet per hand-picked sample: four side-by-side minipages
# (one per entry of `models`) listing the numbered sentences. A sentence that
# differs from the ground-truth sentence in the same row is wrapped in
# \textcolor{red}{...}.
models = ["DenseCap", "GLACNet", "TTNet", "Groundtruth"]
for index in great_cases:
    data = dataset[index]
    df = get_text_table(index, data, exp_names)
    df = df.rename(columns=RENAME)
    text = ""
    text += "\\vspace{5pt}\n"
    text += "\\begin{tiny}\n\n"
    print()
    print(index)
    for model in models:
        # Backslashes are doubled throughout: "\l" and "\e" are invalid escape
        # sequences and trigger a SyntaxWarning on Python 3.12+.
        text += "\\begin{minipage}[c]{0.245\\linewidth}\n\n"
        # Citation suffix for the column header; left empty. Previously used
        # keys: kimGLACNetGLocal2019 (GLACNet),
        # johnsonDenseCapFullyConvolutional2016a (DenseCap).
        cite = ""
        text += "\\textbf{" + model + cite + ":} \n\n"
        for i, s in enumerate(df[model]):
            sentence = f"{i+1}. {s.capitalize()}."
            if s == df["Groundtruth"][i]:
                text += sentence + " "
            else:
                text += "\\textcolor{red}{" + sentence + "} "
            text += "\n\n"
        text += "\\end{minipage}\n"
    text += "\\end{tiny}\n"
    text += "\\vspace{5pt}\n"
    print(text)

In [None]:
# Print the current `df` (the text table of the last exported sample) as a
# LaTeX table without the index column.
print(df.style.hide(axis="index").to_latex())

In [None]:
# Rich display of the same table for a visual check.
df