In [None]:
# import os
# Avoid cuBLAS reproducibility errors (e.g. when deterministic algorithms are
# enforced) by setting CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

from experiments import run_all

# run_all(["predict"])

In [None]:
# (OLD) Collect Results

import itertools
import math
from pathlib import Path

import bjontegaard as bd
import pandas as pd
import numpy as np
import plotly.express as px

from experiments.results import (
    collect_results,
    get_result_rows,
    MODEL_ORDER,
    MODEL_UNITS,
    DEFAULT_ID_COLS,
)
from models.model import plot_channels

# prevent line break in dataframe printing
pd.set_option("display.expand_frame_repr", False)

# Run directories whose results are collected. Entries are relative to
# experiments_hupu/ and experiments_tupu/ respectively (see outdirs_* below).
dirs_hupu = [
    "run2_keep",
    "run4_keep",
    "run10_keep",
    "runs/run03_keep",
    "runs/run05_keep",
    "runs/run16_keep",
    "runs/run18_keep",
    "runs/run19_keep",
    "runs/run20_keep",
    "runs/run21_keep",
    "runs/run22_keep",
    "runs/run28_keep",
    "runs/run29_keep",
    "runs/run30_keep",
    "runs/run31_keep",
    "runs/run32_keep",
    "runs/run33_keep",
    "runs/run34_keep",
    "runs/run35_keep",
    "runs/run36_keep",
    "runs/run37_keep",
    "runs/run38_keep",
    "runs/run39_keep",
    "runs/run40_keep",
    "runs/run41_keep",
    # "runs/run54_keep",
    "runs/run55_keep",
    "runs/run56_keep",
]

dirs_tupu = [
    "runs/run02_keep",
    "runs/run03_keep",
    "runs/run04_keep",
    "runs/run05_keep",
    "runs/run06_keep",
    "runs/run07_keep",
    "runs/run08_keep",
    "runs/run13_keep",
    "runs/run17_keep",
    "runs/run20_keep",
    "runs/run58",
]

outdirs_hupu = [Path(f"experiments_hupu/{d}") for d in dirs_hupu]
outdirs_tupu = [Path(f"experiments_tupu/{d}") for d in dirs_tupu]
# All run directories, hupu first, then tupu.
# NOTE(review): runs/run55_keep, runs/run56_keep (hupu) and runs/run58 (tupu)
# are also listed in outdirs_jpeg further below, so they are collected twice.
outdirs = outdirs_hupu + outdirs_tupu

# Load results of all run directories.
# NOTE(review): the meaning of partial=True is defined in
# experiments.results.collect_results (presumably tolerates unfinished runs).
probe_results_list, lvc_results_list = collect_results(outdirs, partial=True)

# Per-model reference scores taken from res["orig"]:
#   baseline0[model_id] -> score of the model's first unit (MODEL_UNITS[model_id][0])
#   baseline1[model_id] -> score of its second unit (MODEL_UNITS[model_id][1])
# When several runs report the same model, the last run processed wins.
baseline0 = {}
baseline1 = {}

# Assumes collect_results returns one entry per outdir, in the same order.
for probe_results, outdir in zip(probe_results_list, outdirs):
    print(f"--- {outdir} ---")

    for model_id, res in probe_results.items():
        # One gradient-norm channel plot per chunk count, saved into the run dir.
        for nchunks, grads_norm in res["grads_norm"].items():
            plot_channels(
                grads_norm.numpy(),
                f"grads_norm ({model_id}, {nchunks})",
                show=False,
                save=outdir / f"grad_norm_{model_id}_{nchunks}.png",
            )

        # w_yuv = res["w_yuv"]

        # Print every reported (unit, score) pair for this model.
        scores = [
            {"model_id": model_id, "score": score, "unit": unit}
            for unit, score in res["orig"].items()
        ]
        print(scores)

        for unit, score in res["orig"].items():
            if len(MODEL_UNITS[model_id]) > 0:
                if unit == MODEL_UNITS[model_id][0]:
                    baseline0[model_id] = score

            if len(MODEL_UNITS[model_id]) > 1:
                if unit == MODEL_UNITS[model_id][1]:
                    baseline1[model_id] = score

# Flatten the LVC results into a table (row layout defined by
# experiments.results.get_result_rows).
rows = get_result_rows(lvc_results_list, outdirs)
df_full = pd.DataFrame(rows)

# sort
# NOTE(review): the second sort defines the primary order; the model-order
# sort only survives as a tie-breaker if the second sort is stable (pandas
# multi-column sorts are; single-column sorts default to quicksort).
df_full = df_full.sort_values(
    by=["probe_model_id", "model_id"], key=lambda x: x.map(MODEL_ORDER)
)
df_full = df_full.sort_values(by=DEFAULT_ID_COLS[2:])

# calculate differences (grad-optimized "_g" score minus default LVC score)
df_full["diff0_lvc_g"] = df_full["score0_lvc_g"] - df_full["score0_lvc"]
df_full["diff1_lvc_g"] = df_full["score1_lvc_g"] - df_full["score1_lvc"]

# Re-probe score columns are optional and skipped when missing. A KeyError
# part-way through leaves the diff columns computed before it in place.
try:
    df_full["diff0_lvc_reprobe"] = df_full["score0_lvc_reprobe"] - df_full["score0_lvc"]
    df_full["diff0_lvc_g_reprobe"] = (
        df_full["score0_lvc_g_reprobe"] - df_full["score0_lvc_g"]
    )
    df_full["diff1_lvc_reprobe"] = df_full["score1_lvc_reprobe"] - df_full["score1_lvc"]
    df_full["diff1_lvc_g_reprobe"] = (
        df_full["score1_lvc_g_reprobe"] - df_full["score1_lvc_g"]
    )
except KeyError:
    pass

# calculate distortion: squared deviation of score1_lvc_g from the model's
# baseline1 score. Models without a baseline1 entry are skipped, so their
# rows are left unset (NaN) in "loss_dist".
for model_id in MODEL_ORDER.keys():
    try:
        ref_loss = baseline1[model_id]
    except KeyError:
        continue

    df_full.loc[df_full["model_id"] == model_id, "loss_dist"] = (
        df_full[df_full["model_id"] == model_id]["score1_lvc_g"] - ref_loss
    ) ** 2

# split off results by run directory. The "outdir" column is compared against
# Path objects, matching the Paths that were passed to get_result_rows.
df_ablation = df_full[
    df_full["outdir"].isin(
        [
            Path("experiments_hupu/run2_keep"),
            Path("experiments_hupu/run4_keep"),
        ]
    )
]
df_abs = df_full[df_full["outdir"] == Path("experiments_hupu/runs/run03_keep")]
df_sionna = df_full[
    df_full["outdir"].isin(
        [
            Path("experiments_tupu/runs/run02_keep"),
            Path("experiments_tupu/runs/run03_keep"),
            Path("experiments_tupu/runs/run04_keep"),
            Path("experiments_tupu/runs/run05_keep"),
            Path("experiments_tupu/runs/run06_keep"),
            Path("experiments_tupu/runs/run07_keep"),
            Path("experiments_tupu/runs/run08_keep"),
            Path("experiments_tupu/runs/run13_keep"),
            Path("experiments_tupu/runs/run17_keep"),
            Path("experiments_tupu/runs/run20_keep"),
        ]
    )
]
df_sionna_fss = df_full[df_full["outdir"] == Path("experiments_tupu/runs/run02_keep")]
df_full_precise = df_full[df_full["outdir"] == Path("experiments_hupu/runs/run05_keep")]
# df_full is narrowed last: the selections above must read the unfiltered table.
df_full = df_full[
    df_full["outdir"].isin(
        [
            Path("experiments_hupu/run10_keep"),
            Path("experiments_hupu/runs/run16_keep"),
            Path("experiments_hupu/runs/run32_keep"),
        ]
    )
]

# JPEG data (older variant, disabled)
# id_cols_jpeg = DEFAULT_ID_COLS + ["codec", "param"]
# dirs_jpeg = ["runs/run42_keep"]
# outdirs_jpeg = [Path(f"experiments_hupu/{d}") for d in dirs_jpeg]
# _, lvc_results_list_jpeg = collect_results(outdirs_jpeg, partial=True)
# rows_jpeg = get_result_rows(lvc_results_list_jpeg, outdirs_jpeg)
# df_full_jpeg = pd.DataFrame(rows_jpeg)
# df_full_jpeg = df_full_jpeg.sort_values(
#     by=["model_id"], key=lambda x: x.map(MODEL_ORDER)
# )
# df_full_jpeg = df_full_jpeg.sort_values(by=id_cols_jpeg[1:])

# JPEG + GRACE (+ Sionna)
# Codec-comparison runs; their rows carry the extra id columns "codec" and
# "param". Only the LVC results are used here (probe results are discarded).
id_cols_jpeg = DEFAULT_ID_COLS + ["codec", "param"]
outdirs_jpeg = [
    # Path("experiments_tupu/runs/run19_keep"),
    Path("experiments_hupu/runs/run52_keep"),
    Path("experiments_hupu/runs/run53_keep"),
    # Path("experiments_hupu/runs/run54_keep"),
    Path("experiments_hupu/runs/run55_keep"),
    Path("experiments_hupu/runs/run56_keep"),
    Path("experiments_tupu/runs/run58"),
]
_, lvc_results_list_jpeg = collect_results(outdirs_jpeg, partial=True)
rows_jpeg = get_result_rows(lvc_results_list_jpeg, outdirs_jpeg)
df_full_jpeg = pd.DataFrame(rows_jpeg)
# Primary order: model order; ties keep the id-column order from the first
# sort (pandas multi-column sorts are stable).
df_full_jpeg = df_full_jpeg.sort_values(by=id_cols_jpeg[2:]).sort_values(
    by=["probe_model_id", "model_id"], key=lambda x: x.map(MODEL_ORDER)
)

# Disparity results (single run directory, LVC results only)
outdirs_disparity = [Path("experiments_tupu/runs/run18_keep")]
_, lvc_results_list_disparity = collect_results(outdirs_disparity, partial=True)
rows_disparity = get_result_rows(lvc_results_list_disparity, outdirs_disparity)
df_disparity = pd.DataFrame(rows_disparity)

print("=== Disparity data:")
print(df_disparity)

print("=== JPEG + GRACE + Sionna data:")
print(df_full_jpeg)

print("=== Ablation data:")
print(df_ablation)

print("=== Abs. dist data:")
print(df_abs)

# detect & filter out duplicates: `duplicates` keeps every row of a duplicated
# DEFAULT_ID_COLS group (keep=False) for inspection, while drop_duplicates
# keeps only the first occurrence of each group.
duplicates = df_full[df_full.duplicated(subset=DEFAULT_ID_COLS, keep=False)]
df_full = df_full.drop_duplicates(subset=DEFAULT_ID_COLS)

print("=== Full data:")
print(df_full.columns)
print(df_full)
print(f"=== Duplicates ({len(duplicates)}):")
print(duplicates)

# detect & filter out duplicates in Sionna (same scheme as above), then order
# by model (primary) with the id columns as tie-breaker.
duplicates_sionna = df_sionna[df_sionna.duplicated(subset=DEFAULT_ID_COLS, keep=False)]
df_sionna = df_sionna.drop_duplicates(subset=DEFAULT_ID_COLS)
df_sionna = df_sionna.sort_values(by=DEFAULT_ID_COLS[2:])
df_sionna = df_sionna.sort_values(
    by=["probe_model_id", "model_id"], key=lambda x: x.map(MODEL_ORDER)
)

print("=== Sionna Full data:")
print(df_sionna.columns)
print(df_sionna)
print(f"=== Sionna Duplicates ({len(duplicates_sionna)}):")
print(duplicates_sionna)


# BD metrics
def get_bd(df, *, interp_method: str = "pchip") -> pd.DataFrame:
    """Calculate BD metrics of grad-optimized vs. default LVC from a data frame.

    For every combination of model (keys of ``MODEL_ORDER``), estimator, mode,
    number of chunks, CSNR (except ``"inf"``) and block-DCT setting found in
    ``df``, the curve ``cr -> score0_lvc`` (anchor) is compared with
    ``cr -> score0_lvc_g`` (test). Only rows with ``grad_w``, ``grad_sel`` and
    ``grad_alloc`` all set are used.

    Args:
        df: Result table with the columns referenced below.
        interp_method: Interpolation method passed to ``bd.bd_rate`` and
            ``bd.bd_psnr`` (e.g. ``"pchip"`` or ``"akima"``).

    Returns:
        One row per combination with columns ``model_id``, ``estimator``,
        ``mode``, ``nchunks``, ``csnr_db``, ``block_dct``, ``bdrate`` and
        ``bdacc`` (``bd.bd_psnr`` of the scores multiplied by 100). Metrics the
        bjontegaard package cannot compute for a combination are NaN.
    """
    estimators = pd.unique(df["estimator"])
    modes = pd.unique(df["mode"])
    nchunks = pd.unique(df["nchunks"])
    csnr_dbs = [val for val in pd.unique(df["csnr_db"]) if val != "inf"]
    block_dcts = pd.unique(df["block_dct"])
    # Only fully grad-optimized configurations take part in the comparison.
    grad_all = df["grad_w"] & df["grad_sel"] & df["grad_alloc"]

    def _bd_or_nan(metric, rate_anchor, dist_anchor, rate_test, dist_test):
        """Evaluate a bjontegaard metric, returning NaN when it fails."""
        # TODO: Fix non-monotonic curves causing errors
        try:
            return metric(
                rate_anchor, dist_anchor, rate_test, dist_test, method=interp_method
            )
        except (IndexError, ValueError, AssertionError):
            return np.nan

    rows = []
    for model_id, estimator, mode, nc, csnr_db, block_dct in itertools.product(
        MODEL_ORDER.keys(), estimators, modes, nchunks, csnr_dbs, block_dcts
    ):
        df_test = df[
            (df["csnr_db"] == csnr_db)
            & (df["model_id"] == model_id)
            & (df["estimator"] == estimator)
            # mode must be part of the selection, otherwise points of
            # different modes end up on the same curve
            & (df["mode"] == mode)
            & (df["nchunks"] == nc)
            & (df["block_dct"] == block_dct)
            & grad_all
        ]

        cr = df_test["cr"].to_numpy()
        score_lvc = df_test["score0_lvc"].to_numpy()
        score_lvc_g = df_test["score0_lvc_g"].to_numpy()

        bdrate = _bd_or_nan(bd.bd_rate, cr, score_lvc, cr, score_lvc_g)
        bdacc = _bd_or_nan(bd.bd_psnr, cr, score_lvc, cr, score_lvc_g)

        rows.append(
            {
                "model_id": model_id,
                "estimator": estimator,
                "mode": mode,
                "nchunks": nc,
                "csnr_db": csnr_db,
                "block_dct": block_dct,
                "bdrate": bdrate,
                "bdacc": bdacc * 100,
            }
        )

    return pd.DataFrame(rows)


def get_bd_default(
    df,
    *,
    interp_method: str = "pchip",
    default_estimator: str = "zf",
    default_nchunks: int = 256,
    default_block_dct: bool = False,
) -> pd.DataFrame:
    """Calculate BD metrics of each configuration against default settings.

    For every combination of model (keys of ``MODEL_ORDER``), estimator, mode,
    number of chunks, CSNR (except ``"inf"``) and block-DCT setting found in
    ``df``, six ``cr -> score`` curve comparisons are evaluated. Only rows with
    ``grad_w``, ``grad_sel`` and ``grad_alloc`` all set are used. The
    ``result`` column names the comparison:

    - ``"lvc"``: default config (``score0_lvc``) vs. this config (``score0_lvc``)
    - ``"lvc_g"``: this config, ``score0_lvc`` vs. ``score0_lvc_g``
    - ``"total"``: default config (``score0_lvc``) vs. this config
      (``score0_lvc_g``)
    - ``"lvc_g_256"``: this config with ``default_nchunks`` chunks vs. this
      config, both ``score0_lvc_g``
    - ``"lvc_g_ff"``: this config with ``default_block_dct`` vs. this config,
      both ``score0_lvc_g``
    - ``"lvc_g_zf"``: this config with ``default_estimator`` vs. this config,
      both ``score0_lvc_g``

    The "default config" uses ``default_estimator``, ``default_nchunks`` and
    ``default_block_dct`` with the same model, mode and CSNR. The result labels
    are fixed strings and do not follow changed defaults.

    Args:
        df: Result table with the columns referenced below.
        interp_method: Interpolation method passed to ``bd.bd_rate`` and
            ``bd.bd_psnr`` (e.g. ``"pchip"`` or ``"akima"``).
        default_estimator: Estimator of the default configuration.
        default_nchunks: Number of chunks of the default configuration.
        default_block_dct: Block-DCT setting of the default configuration.

    Returns:
        Six rows per combination with columns ``model_id``, ``estimator``,
        ``mode``, ``nchunks``, ``csnr_db``, ``block_dct``, ``bdrate``,
        ``bdacc`` (``bd.bd_psnr`` multiplied by 100) and ``result``. Metrics
        the bjontegaard package cannot compute are NaN.
    """
    estimators = pd.unique(df["estimator"])
    modes = pd.unique(df["mode"])
    nchunks = pd.unique(df["nchunks"])
    csnr_dbs = [val for val in pd.unique(df["csnr_db"]) if val != "inf"]
    block_dcts = pd.unique(df["block_dct"])
    # Only fully grad-optimized configurations take part in the comparison.
    grad_all = df["grad_w"] & df["grad_sel"] & df["grad_alloc"]

    def _select(csnr_db, model_id, mode, estimator, nc, block_dct):
        """Return the rows of one configuration (mode included, so modes are not mixed)."""
        return df[
            (df["csnr_db"] == csnr_db)
            & (df["model_id"] == model_id)
            & (df["mode"] == mode)
            & (df["estimator"] == estimator)
            & (df["nchunks"] == nc)
            & (df["block_dct"] == block_dct)
            & grad_all
        ]

    def _bd_pair(anchor, anchor_col, test, test_col):
        """Return (BD-rate, BD-"PSNR") of ``test`` vs. ``anchor``; NaN on failure."""
        curves = (
            anchor["cr"].to_numpy(),
            anchor[anchor_col].to_numpy(),
            test["cr"].to_numpy(),
            test[test_col].to_numpy(),
        )
        values = []
        for metric in (bd.bd_rate, bd.bd_psnr):
            # TODO: Fix non-monotonic curves causing errors
            try:
                values.append(metric(*curves, method=interp_method))
            except (IndexError, ValueError, AssertionError):
                values.append(np.nan)
        return values[0], values[1]

    rows = []
    for model_id, estimator, mode, nc, csnr_db, block_dct in itertools.product(
        MODEL_ORDER.keys(), estimators, modes, nchunks, csnr_dbs, block_dcts
    ):
        config = (csnr_db, model_id, mode)
        df_test = _select(*config, estimator, nc, block_dct)
        df_default = _select(
            *config, default_estimator, default_nchunks, default_block_dct
        )
        df_default_nchunks = _select(*config, estimator, default_nchunks, block_dct)
        df_default_block_dct = _select(*config, estimator, nc, default_block_dct)
        df_default_estimator = _select(*config, default_estimator, nc, block_dct)

        # (result label, anchor rows, anchor score column, test rows, test score column)
        comparisons = (
            ("lvc", df_default, "score0_lvc", df_test, "score0_lvc"),
            ("lvc_g", df_test, "score0_lvc", df_test, "score0_lvc_g"),
            ("total", df_default, "score0_lvc", df_test, "score0_lvc_g"),
            ("lvc_g_256", df_default_nchunks, "score0_lvc_g", df_test, "score0_lvc_g"),
            ("lvc_g_ff", df_default_block_dct, "score0_lvc_g", df_test, "score0_lvc_g"),
            ("lvc_g_zf", df_default_estimator, "score0_lvc_g", df_test, "score0_lvc_g"),
        )

        for result, anchor, anchor_col, test, test_col in comparisons:
            bdrate, bdacc = _bd_pair(anchor, anchor_col, test, test_col)
            rows.append(
                {
                    "model_id": model_id,
                    "estimator": estimator,
                    "mode": mode,
                    "nchunks": nc,
                    "csnr_db": csnr_db,
                    "block_dct": block_dct,
                    "bdrate": bdrate,
                    "bdacc": bdacc * 100,
                    "result": result,
                }
            )

    return pd.DataFrame(rows)


df_bd = get_bd(df_full)
df_bd_precise = get_bd(df_full_precise)
df_bd_default = get_bd_default(df_full)
df_bd_sionna = get_bd(df_sionna)
df_bd_sionna_default = get_bd_default(df_sionna)

# print("=== BD metrics:")
# print(df_bd)

# print("=== BD metrics (default):")
# print(df_bd_default)

# print("=== BD metrics (Sionna):")
# print(df_bd_sionna)

# print("=== BD metrics (default Sionna):")
# print(df_bd_sionna_default)

# Export every result table as JSON records (one object per row). Values that
# are not JSON-serializable (e.g. Path objects) are converted with str().
json_outdir = Path("experiments/json")
json_outdir.mkdir(exist_ok=True, parents=True)

all_dfs = {
    "df_ablation": df_ablation,
    "df_abs": df_abs,
    "df_sionna": df_sionna,
    "df_sionna_fss": df_sionna_fss,
    "df_full": df_full,
    "df_full_precise": df_full_precise,
    "df_full_jpeg": df_full_jpeg,
    "df_bd": df_bd,
    "df_bd_precise": df_bd_precise,
    # exported like its Sionna counterpart df_bd_sionna_default
    "df_bd_default": df_bd_default,
    "df_bd_sionna": df_bd_sionna,
    "df_bd_sionna_default": df_bd_sionna_default,
}

for name, df in all_dfs.items():
    json_file = (json_outdir / name).with_suffix(".json")
    print(f"Saving {json_file}")
    df.to_json(json_file, orient="records", default_handler=str)

In [None]:
# (OLD) Generate figures for the paper (requires above cell)

import itertools
from typing import List

import torch
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Output directory for the generated figures (.png and .html).
outdir = Path("experiments/plots")
outdir.mkdir(exist_ok=True)

# Figure width/height passed to fig.update_layout, and the scale factor
# passed to fig.write_image.
W = 1000
W2 = 400  # NOTE(review): not referenced in the visible part of this notebook
H = 300
SCALE = 4.0

################################################################################
# Helper functions


def add_baseline(
    fig,
    baseline,
    model_ids=("fastseg_small", "fastseg_large", "yolov8_n", "yolov8_s", "yolov8_l"),
    col_major=False,
    pos="top right",
):
    """Draw each model's baseline score as a dotted horizontal line.

    The i-th model of ``model_ids`` (counting from 1) is drawn in column i of
    the first subplot row, or in row i of the first subplot column when
    ``col_major`` is True. Each line is annotated with the score times 100,
    formatted as a percentage. Models missing from ``baseline`` are skipped.

    Args:
        fig: plotly figure with a subplot grid (e.g. from a faceted px call).
        baseline: Mapping model_id -> baseline score.
        model_ids: Models in subplot order.
        col_major: Lay the models out down the rows instead of across columns.
        pos: ``annotation_position`` passed to ``fig.add_hline``.
    """
    for i, model_id in enumerate(model_ids, start=1):
        if model_id not in baseline:
            continue

        # plotly subplot rows/cols are 1-based, so the fixed axis uses 1.
        row, col = (i, 1) if col_major else (1, i)
        score = baseline[model_id]
        fig.add_hline(
            y=score,
            line_dash="dot",
            annotation_text=f"{score * 100:.2f}%",
            annotation_position=pos,
            annotation_xshift=-15,
            row=row,
            col=col,
        )


def add_zero_line(fig, label: str = ""):
    """Mark y = 0 on ``fig`` with a dotted horizontal line.

    ``label`` becomes the line's annotation, placed at the bottom right
    (empty by default).
    """
    hline_style = {
        "line_dash": "dot",
        "annotation_text": label,
        "annotation_position": "bottom right",
    }
    fig.add_hline(y=0, **hline_style)


def filter_df(df, estimator, mode, block_dct, nchunks=None):
    """Select the fully grad-optimized rows of one LVC configuration.

    Rows must match ``estimator``, ``mode`` and ``block_dct``, and, when
    ``nchunks`` is not None, also ``nchunks``. Additionally ``grad_w``,
    ``grad_sel`` and ``grad_alloc`` must all be set.
    """
    keep = (
        (df["estimator"] == estimator)
        & (df["mode"] == mode)
        & (df["block_dct"] == block_dct)
    )
    if nchunks is not None:
        keep = keep & (df["nchunks"] == nchunks)
    keep = keep & df["grad_w"] & df["grad_sel"] & df["grad_alloc"]
    return df[keep]


def heatmaps(
    image: np.ndarray,
    rows: int,
    cols: int,
    label: str | None,
    show: bool = True,
    save: str | Path | None = None,
):
    """Plot every channel of ``image`` as a log10-scaled heatmap on a subplot grid.

    Args:
        image: Array whose first axis indexes channels. Each ``image[i]`` is
            drawn as one heatmap (presumably 2-D - confirm). It must contain
            exactly ``rows * cols`` channels. Zeros and NaNs are left blank.
            The caller's array is not modified.
        rows: Number of subplot rows.
        cols: Number of subplot columns; channels fill the grid row by row.
        label: Optional figure title.
        show: Display the figure.
        save: If given, write the image there (scaled by ``SCALE``) and an
            ``.html`` version with the same stem.

    Raises:
        ValueError: If the channel count differs from ``rows * cols``, if
            ``image`` has no non-zero, non-NaN value, or if it contains
            negative values (which have no logarithm).
    """
    # Private float copy: the caller's array stays untouched and NaN can be stored.
    image = np.array(image, dtype=float)
    if image.shape[0] != (rows * cols):
        raise ValueError("Number of channels does not correspond to number of subplots")

    # Zeros have no logarithm and NaNs carry no data: blank both out.
    # (`image == np.nan` is always False, so NaNs must be found via np.isnan.)
    blank = (image == 0.0) | np.isnan(image)
    image[blank] = np.nan
    values = image[~blank]
    if values.size == 0:
        raise ValueError("image has no non-zero values to plot")
    if values.min() < 0:
        raise ValueError("log-scaled heatmaps require non-negative values")

    # One shared colour range so that all subplots are comparable.
    zmin = np.log10(values.min())
    zmax = np.log10(values.max())

    fig = make_subplots(rows=rows, cols=cols)
    cells = itertools.product(range(1, rows + 1), range(1, cols + 1))
    for ch, (r, c) in zip(image, cells):
        fig.add_trace(
            go.Heatmap(
                z=np.log10(ch),
                customdata=ch,  # hover shows the linear value
                hovertemplate="x: %{x} <br>" + "y: %{y} <br>" + "z: %{customdata:.2e}",
                colorbar=dict(title="10^"),
                zmin=zmin,
                zmax=zmax,
            ),
            row=r,
            col=c,
        )

    # Square pixels; image row 0 at the top.
    fig.update_yaxes(
        autorange="reversed", scaleanchor="x", scaleratio=1, constrain="domain"
    )
    fig.update_xaxes(scaleanchor="y", scaleratio=1, constrain="domain")

    if label is not None:
        # NOTE(review): with the top margin of 0 set below the title may be clipped.
        fig.update_layout(title=label)

    fig.update_layout(
        width=W,
        height=H,
        autosize=False,
        margin=dict(l=0, r=0, t=0, b=0),
        font=dict(size=25),
    )

    if save:
        fig.write_image(save, scale=SCALE)
        fig.write_html(Path(save).with_suffix(".html"))

    if show:
        fig.show()


################################################################################
# Default + Grad-optimized LVC (RD)

estimator = "zf"
mode = 444
block_dct = False
nchunks = 256

msg = f"{estimator}, {mode}, {'bb' if block_dct else 'ff'}"
# Only finite channel SNRs are plotted (csnr_db == "inf" rows are dropped).
df = df_full[df_full["csnr_db"] != "inf"]
df = filter_df(df, estimator, mode, block_dct, nchunks)[
    ["model_id", "nchunks", "cr", "csnr_db", "score0_lvc", "score0_lvc_g"]
]
# Long format: a common "accuracy" column plus a "scheme" label
# (score0_lvc_g -> "CV-Cast", score0_lvc -> "LCT").
df_g = df.copy().drop("score0_lvc", axis=1).rename(columns={"score0_lvc_g": "accuracy"})
df_g["scheme"] = "CV-Cast"
df = df.drop("score0_lvc_g", axis=1).rename(columns={"score0_lvc": "accuracy"})
df["scheme"] = "LCT"

df = pd.concat([df, df_g])
# Append each model's metric to its name (used as facet title).
df = df.replace(
    {
        "fastseg_small": "fastseg_small (mIoU)",
        "fastseg_large": "fastseg_large (mIoU)",
        "yolov8_n": "yolov8_n (mAP)",
        "yolov8_s": "yolov8_s (mAP)",
        "yolov8_l": "yolov8_l (mAP)",
    }
)
df = df.rename(columns={"cr": "CR"})

# NOTE: the yaxis1..yaxis5 ranges below assume the facets appear in the order
# fastseg_small, fastseg_large, yolov8_n, yolov8_s, yolov8_l.
fig = px.line(
    df,
    x="CR",
    y="accuracy",
    color="csnr_db",
    facet_col="model_id",
    facet_col_spacing=0.025,
    # facet_row="nchunks",
    symbol="scheme",
    line_dash="scheme",
    line_dash_sequence=["dash", "solid"],
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
    # log_x=True,
    range_y=[0.0, None],
)

add_baseline(fig, baseline0)

fig.update_xaxes(
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
    tickvals=[0.0, 0.03125, 0.0625, 0.125, 0.25, 0.5, 1.0],
    ticktext=[0, "", "", 0.125, 0.25, 0.5, 1.0],
    # minor=dict(showgrid=True),
    title_standoff=0,
)
fig.update_yaxes(
    tickformat=".0%",
    matches=None,
    showticklabels=True,
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="auto",
    # tick0=0.0,
    # dtick=0.1,
    minor=dict(showgrid=True),
)
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
    yaxis1=dict(range=[0.0, baseline0["fastseg_small"] * 1.1]),
    yaxis2=dict(range=[0.0, baseline0["fastseg_large"] * 1.1]),
    yaxis3=dict(range=[0.0, baseline0["yolov8_n"] * 1.1]),
    yaxis4=dict(range=[0.0, baseline0["yolov8_s"] * 1.1]),
    yaxis5=dict(range=[0.0, baseline0["yolov8_l"] * 1.1]),
)
# Facet titles read "model_id=<name>"; keep only "<name>".
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/score_rd_zf_444_ff"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()


################################################################################
# Default + Grad-optimized LVC (acc vs CSNR)

# Same long-format data as the RD figure (df), plotted over the channel SNR.
fig = px.line(
    df,
    x="csnr_db",
    y="accuracy",
    color="CR",
    facet_col="model_id",
    facet_col_spacing=0.025,
    # facet_row="nchunks",
    symbol="scheme",
    line_dash="scheme",
    line_dash_sequence=["dash", "solid"],
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
    range_y=[0.0, None],
)

add_baseline(fig, baseline0)

fig.update_xaxes(
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
    tickvals=[0, 5, 10, 15, 20, 30],
)
fig.update_yaxes(
    tickformat=".0%",
    zerolinecolor="grey",
    zeroline=True,
    matches=None,
    showticklabels=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
    yaxis1=dict(range=[0.0, baseline0["fastseg_small"] * 1.1]),
    yaxis2=dict(range=[0.0, baseline0["fastseg_large"] * 1.1]),
    yaxis3=dict(range=[0.0, baseline0["yolov8_n"] * 1.1]),
    yaxis4=dict(range=[0.0, baseline0["yolov8_s"] * 1.1]),
    yaxis5=dict(range=[0.0, baseline0["yolov8_l"] * 1.1]),
)
# Facet titles read "model_id=<name>"; keep only "<name>".
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/score_csnr_zf_444_ff"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# JPEG + GRACE + Sionna (acc vs CSNR)

# Finite-CSNR results of the GRACE and JPEG codecs. `.copy()` gives an
# independent frame so that adding columns below does not trigger pandas'
# SettingWithCopyWarning.
df_jpeg_grace = (
    df_full_jpeg[
        [
            "model_id",
            "nchunks",
            "cr",
            "csnr_db",
            "score0_lvc",
            "score0_lvc_g",
            "codec",
            "param",
            "nbits_per_sym",
            "eq_cr",
        ]
    ]
    .query("csnr_db != 'inf' and codec in ['grace', 'jpeg']")
    .copy()
)

df_jpeg_grace["modulation"] = [f"{2**nb}-QAM" for nb in df_jpeg_grace["nbits_per_sym"]]
df_jpeg_grace = df_jpeg_grace.round({"eq_cr": 3})

# GRACE accuracy is read from score0_lvc_g, JPEG accuracy from score0_lvc.
df_grace = (
    df_jpeg_grace.query("codec == 'grace'")
    .copy()
    .drop("score0_lvc", axis=1)
    .rename(columns={"score0_lvc_g": "accuracy"})
)
df_grace["scheme"] = "GRACE"

df_jpeg = (
    df_jpeg_grace.query("codec == 'jpeg'")
    .copy()
    .drop("score0_lvc_g", axis=1)
    .rename(columns={"score0_lvc": "accuracy"})
)
df_jpeg["scheme"] = "JPEG"

fig = px.line(
    pd.concat([df_grace, df_jpeg]),
    x="csnr_db",
    y="accuracy",
    # color="param",
    color="scheme",
    facet_col="model_id",
    facet_col_spacing=0.025,
    # facet_row="nchunks",
    symbol="eq_cr",
    line_dash="modulation",
    line_dash_sequence=["solid", "dash", "dot"],
    symbol_sequence=["cross-thin-open", "x-thin-open", "y-down"],
    markers=True,
    range_y=[0.0, None],
)

# add_baseline(fig, baseline0)

fig.update_xaxes(
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
    tickvals=[0, 5, 10, 15, 20, 30],
)
fig.update_yaxes(
    tickformat=".0%",
    zerolinecolor="grey",
    zeroline=True,
    matches=None,
    showticklabels=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
    yaxis1=dict(range=[0.0, baseline0["fastseg_small"] * 1.1]),
    yaxis2=dict(range=[0.0, baseline0["fastseg_large"] * 1.1]),
    yaxis3=dict(range=[0.0, baseline0["yolov8_n"] * 1.1]),
    yaxis4=dict(range=[0.0, baseline0["yolov8_s"] * 1.1]),
    yaxis5=dict(range=[0.0, baseline0["yolov8_l"] * 1.1]),
)
# Facet titles read "model_id=<name>"; keep only "<name>".
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/score_grace_sionna"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Comparison JPEG vs GRACE
#
# Error-free (csnr_db == "inf") accuracy vs. encoded size for the two digital
# codecs, one facet per model.

df_jpeg_grace_comp = df_full_jpeg[
    [
        "model_id",
        "nchunks",
        "cr",
        "csnr_db",
        "score0_lvc",
        "score0_lvc_g",
        "codec",
        "param",
        "nbits_per_sym",
        "enc_size",
    ]
].query("csnr_db == 'inf' and codec in ['grace', 'jpeg']")
# Rows colliding on (model_id, codec, enc_size); kept only for inspection.
duplicates_jpeg_comp = df_jpeg_grace_comp[
    df_jpeg_grace_comp.duplicated(subset=["model_id", "codec", "enc_size"], keep=False)
]
# Deduplicate, then order by model (MODEL_ORDER), codec and size in ONE sort.
# Chaining single-key sort_values calls would only keep the earlier ordering
# if each sort were stable, which pandas' default quicksort is not.
df_jpeg_grace_comp = df_jpeg_grace_comp.drop_duplicates(
    subset=["model_id", "codec", "enc_size"]
).sort_values(
    by=["model_id", "codec", "enc_size"],
    key=lambda col: col.map(MODEL_ORDER) if col.name == "model_id" else col,
)

# GRACE rows use "score0_lvc_g" and JPEG rows use "score0_lvc"; both are
# exposed under a common "accuracy" column with a "scheme" label.
df_grace_comp = (
    df_jpeg_grace_comp.query("codec == 'grace'")
    .drop(columns="score0_lvc")
    .rename(columns={"score0_lvc_g": "accuracy"})
    .assign(scheme="GRACE")
)
df_jpeg_comp = (
    df_jpeg_grace_comp.query("codec == 'jpeg'")
    .drop(columns="score0_lvc_g")
    .rename(columns={"score0_lvc": "accuracy"})
    .assign(scheme="JPEG")
)

df_jpeg_grace_comp = pd.concat([df_grace_comp, df_jpeg_comp])

print(df_jpeg_grace_comp)

fig = px.scatter(
    df_jpeg_grace_comp,
    x="enc_size",
    y="accuracy",
    color="scheme",
    facet_col="model_id",
    facet_col_spacing=0.025,
    range_y=[0.0, None],
    title="JPEG vs GRACE",
)

# add_baseline(fig, baseline0)

fig.update_xaxes(
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
    matches=None,  # independent x-range per facet
)
fig.update_yaxes(
    tickformat=".0%",
    zerolinecolor="grey",
    zeroline=True,
    matches=None,
    showticklabels=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
# Per-facet y-range: 10% headroom above each model's baseline; assumes the
# facets appear in MODEL_ORDER (guaranteed by the sort above).
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
    yaxis1=dict(range=[0.0, baseline0["fastseg_small"] * 1.1]),
    yaxis2=dict(range=[0.0, baseline0["fastseg_large"] * 1.1]),
    yaxis3=dict(range=[0.0, baseline0["yolov8_n"] * 1.1]),
    yaxis4=dict(range=[0.0, baseline0["yolov8_s"] * 1.1]),
    yaxis5=dict(range=[0.0, baseline0["yolov8_l"] * 1.1]),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/jpeg_grace_comparison"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Sionna Channel results (RD)

# Facet (column) order for the model panels. The per-facet y-ranges below
# (yaxis1..yaxis5) assume exactly this order, so it is enforced through
# ``category_orders`` instead of relying on the row order of the data.
model_order_units = {
    "fastseg_small (mIoU)": 0,
    "fastseg_large (mIoU)": 1,
    "yolov8_n (mAP)": 2,
    "yolov8_s (mAP)": 3,
    "yolov8_l (mAP)": 4,
}

msg = f"{estimator}, {mode}, {'bb' if block_dct else 'ff'}"
df = df_sionna[(df_sionna["csnr_db"] != "inf")]
df = filter_df(df, estimator, mode, block_dct, nchunks)[
    ["model_id", "nchunks", "cr", "csnr_db", "score0_lvc", "score0_lvc_g"]
]
# Long format: one row per (config, scheme) with a shared "accuracy" column.
df_g = df.copy().drop("score0_lvc", axis=1).rename(columns={"score0_lvc_g": "accuracy"})
df_g["scheme"] = "CV-Cast"
df = df.drop("score0_lvc_g", axis=1).rename(columns={"score0_lvc": "accuracy"})
df["scheme"] = "LCT"

df = pd.concat([df, df_g])
# Append each model's metric to its name (shown as the facet title).
df = df.replace(
    {
        "model_id": {
            "fastseg_small": "fastseg_small (mIoU)",
            "fastseg_large": "fastseg_large (mIoU)",
            "yolov8_n": "yolov8_n (mAP)",
            "yolov8_s": "yolov8_s (mAP)",
            "yolov8_l": "yolov8_l (mAP)",
        }
    }
)
df = df.rename(columns={"cr": "CR", "csnr_db": "Eb/N0 (dB)"})

fig = px.line(
    df,
    x="CR",
    y="accuracy",
    color="Eb/N0 (dB)",
    facet_col="model_id",
    facet_col_spacing=0.025,
    category_orders={
        "model_id": sorted(model_order_units, key=model_order_units.get)
    },
    symbol="scheme",
    labels="scheme",
    line_dash="scheme",
    line_dash_sequence=["dash", "solid"],
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
    range_y=[0.0, None],
)

add_baseline(
    fig,
    baseline0,
    # model_ids=["fastseg_large", "fastseg_small"],
    # col_major=True,
)

fig.update_xaxes(
    title="CR",
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
    # The two smallest ratios get ticks without labels to avoid overlap.
    tickvals=[0.0, 0.03125, 0.0625, 0.125, 0.25, 0.5, 1.0],
    ticktext=[0, "", "", 0.125, 0.25, 0.5, 1.0],
)
fig.update_yaxes(
    tickformat=".0%",
    zerolinecolor="grey",
    zeroline=True,
    matches=None,
    showticklabels=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
# yaxisN follows the facet order fixed by category_orders above.
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.3),
    margin=dict(l=20, r=20, t=20, b=0),
    yaxis1=dict(range=[0.0, baseline0["fastseg_small"] * 1.1]),
    yaxis2=dict(range=[0.0, baseline0["fastseg_large"] * 1.1]),
    yaxis3=dict(range=[0.0, baseline0["yolov8_n"] * 1.1]),
    yaxis4=dict(range=[0.0, baseline0["yolov8_s"] * 1.1]),
    yaxis5=dict(range=[0.0, baseline0["yolov8_l"] * 1.1]),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/score_sionna_rd_zf_444_ff"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Sionna Channel results (acc vs CSNR)
#
# Reuses the long-format ``df`` built in the "Sionna Channel results (RD)"
# section (columns "CR", "Eb/N0 (dB)", "accuracy", "scheme", "model_id").

fig = px.line(
    df,
    x="Eb/N0 (dB)",
    y="accuracy",
    color="CR",
    facet_col="model_id",
    facet_col_spacing=0.025,
    # Explicit facet order: the yaxis1..yaxis5 ranges below depend on it.
    category_orders={
        "model_id": [
            "fastseg_small (mIoU)",
            "fastseg_large (mIoU)",
            "yolov8_n (mAP)",
            "yolov8_s (mAP)",
            "yolov8_l (mAP)",
        ]
    },
    symbol="scheme",
    labels="scheme",
    line_dash="scheme",
    line_dash_sequence=["dash", "solid"],
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
    range_y=[0.0, None],
)

add_baseline(
    fig,
    baseline0,
    # model_ids=["fastseg_large", "fastseg_small"],
    # col_major=True,
)

fig.update_xaxes(
    title="Eb/N0 (dB)",
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
)
fig.update_yaxes(
    tickformat=".0%",
    zerolinecolor="grey",
    zeroline=True,
    matches=None,
    showticklabels=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
    yaxis1=dict(range=[0.0, baseline0["fastseg_small"] * 1.1]),
    yaxis2=dict(range=[0.0, baseline0["fastseg_large"] * 1.1]),
    yaxis3=dict(range=[0.0, baseline0["yolov8_n"] * 1.1]),
    yaxis4=dict(range=[0.0, baseline0["yolov8_s"] * 1.1]),
    yaxis5=dict(range=[0.0, baseline0["yolov8_l"] * 1.1]),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/score_sionna_csnr_zf_444_ff"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Sionna Channel 420 results (acc vs CSNR) -- for VCIP, generate CSV files
#
# Fixed configuration: estimator "zf", mode 420, full-frame DCT, 64 chunks.
# Writes one CSV per compression ratio, then plots the LCT accuracy and its
# relative drop w.r.t. baseline0["fastseg_small"] against Eb/N0.

# NOTE: msg reflects the *global* estimator/mode/block_dct settings, not the
# explicit filter below; it is not used in this section.
msg = f"{estimator}, {mode}, {'bb' if block_dct else 'ff'}"
df = df_sionna_fss[df_sionna_fss["csnr_db"] != "inf"]
# .copy() so the column assignment below targets an independent frame rather
# than a possible view (avoids SettingWithCopyWarning).
df = filter_df(df, estimator="zf", mode=420, block_dct=False, nchunks=64)[
    ["model_id", "nchunks", "cr", "csnr_db", "score0_lvc", "score0_lvc_g"]
].copy()

# Relative accuracy drop (percent) of the LCT score vs. the baseline.
base = baseline0["fastseg_small"]
df["dist_perc"] = 100 * (base - df["score0_lvc"]) / base

for cr in [0.1, 0.25, 0.5, 0.75, 1.0]:
    fname = f"{outdir}/fss_sionna_cr{str(cr).replace('.', '_')}.csv"
    print(fname)
    # Tolerant float match: an exact ``==`` would silently write an empty CSV
    # if the stored ratio differs from the literal by rounding error.
    df[np.isclose(df["cr"].astype(float), cr)].to_csv(fname, index=False)

# Long format: one row per (config, scheme). dist_perc above is derived from
# the LCT score only; the CV-Cast rows are filtered out before it is plotted.
df_g = df.copy().drop("score0_lvc", axis=1).rename(columns={"score0_lvc_g": "accuracy"})
df_g["scheme"] = "CV-Cast"
df = df.drop("score0_lvc_g", axis=1).rename(columns={"score0_lvc": "accuracy"})
df["scheme"] = "LCT"

df = pd.concat([df, df_g])
df = df.replace("fastseg_small", "fastseg_small (mIoU)")
df = df.replace("fastseg_large", "fastseg_large (mIoU)")
df = df.replace("yolov8_n", "yolov8_n (mAP)")
df = df.replace("yolov8_s", "yolov8_s (mAP)")
df = df.replace("yolov8_l", "yolov8_l (mAP)")
df = df.rename(columns={"cr": "CR"})

df = df[(df["scheme"] == "LCT") & (df["model_id"] == "fastseg_small (mIoU)")]

# Accuracy vs. Eb/N0, one line per compression ratio.
fig = px.line(
    df,
    x="csnr_db",
    y="accuracy",
    color="CR",
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
    range_y=[0.0, None],
    title="Sionna FSS 420 64 ZFE",
)

add_baseline(
    fig,
    baseline0,
    model_ids=["fastseg_small"],
    col_major=True,
)

fig.update_xaxes(
    title="Eb/N0 (dB)",
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
)
fig.update_yaxes(
    tickformat=".0%",
    zerolinecolor="grey",
    zeroline=True,
    matches=None,
    showticklabels=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    plot_bgcolor="white",
    width=W2,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/score_sionna_csnr_zf_420_ff_fss"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

# Relative accuracy drop (%) vs. Eb/N0, one line per compression ratio.
fig = px.line(
    df,
    x="csnr_db",
    y="dist_perc",
    color="CR",
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
    range_y=[0.0, None],
    title="Sionna FSS 420 64 ZFE",
)

# No add_baseline() here: baseline0 holds raw accuracy values (it is used as
# an accuracy-axis limit elsewhere), which do not share a scale with the
# percentage drop on this y-axis. The baseline itself is dist_perc == 0.

fig.update_xaxes(
    title="Eb/N0 (dB)",
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    tickmode="array",
)
fig.update_yaxes(
    zerolinecolor="grey",
    zeroline=True,
    matches=None,
    showticklabels=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    plot_bgcolor="white",
    width=W2,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/dist_sionna_csnr_zf_420_ff_fss"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Disparity results
#
# At csnr_db == 10 and cr == 0.25, pivot score0_lvc_g into a probe-model x
# eval-model table. The diff table subtracts, per eval model (column), the
# score obtained when probe == eval (the diagonal), scaled to percentage
# points and rounded to 2 decimals.

print("Disparity (10 dB, 0.25):")

mode_renamer = {"probe_model_id": "probe", "model_id": "eval"}
model_renamer = {
    "fastseg_small": "fss",
    "fastseg_large": "fsl",
    "yolov8_n": "y8n",
    "yolov8_s": "y8s",
    "yolov8_l": "y8l",
}

df_disp = (
    df_disparity.query("csnr_db == 10 and cr == 0.25")
    .loc[:, ["probe_model_id", "model_id", "score0_lvc", "score0_lvc_g"]]
    .sort_values(by=["probe_model_id", "model_id"], key=lambda s: s.map(MODEL_ORDER))
)
# Diagonal entries: each model evaluated with its own probe.
gt = df_disp[df_disp["probe_model_id"] == df_disp["model_id"]].set_index(
    "probe_model_id"
)

df_disp_table = (
    df_disp.pivot(index="probe_model_id", columns="model_id", values="score0_lvc_g")
    .sort_index(axis="index", key=lambda labels: labels.map(MODEL_ORDER))
    .sort_index(axis="columns", key=lambda labels: labels.map(MODEL_ORDER))
)

df_disp_table_diff = (
    (df_disp_table.sub(gt["score0_lvc_g"], axis="columns") * 100)
    .round(2)
    .rename(columns=model_renamer, index=model_renamer)
    .rename_axis(index="probe", columns="eval")
)

print(df_disp_table)
print(df_disp_table_diff)

fig = px.imshow(df_disp_table_diff, text_auto=True)
fig.update_xaxes(side="top")
fig.update_layout(
    plot_bgcolor="white",
    width=W2,
    height=W2,  # square heatmap
    coloraxis_showscale=False,
    margin=dict(l=20, r=5, t=20, b=0),
)
fname = f"{outdir}/disparity_10db_cr0_25"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# BD metrics default

# Sort rank of each "result" value (lower sorts first).
result_order = {
    "lvc": 0,
    "lvc_g_256": 1,
    "lvc_g_ff": 1,
    "lvc_g_zf": 1,
    "lvc_g": 2,
}

print("BD metrics default")
df = df_bd_default[
    (df_bd_default["csnr_db"] == 10)
    & (df_bd_default["nchunks"] == 256)
    & (df_bd_default["result"] != "total")
]
# Human-readable labels, each replacement scoped to its own column. A
# frame-wide replace(True, ...) / replace(False, ...) would also rewrite any
# object-dtype cell equal to 1 or 0 elsewhere (since 1 == True, 0 == False).
df = df.replace(
    {
        "block_dct": {True: "block-based", False: "full-frame"},
        "estimator": {"llse": "LLSE", "zf": "ZFE"},
    }
)
df = df.rename(
    columns={"block_dct": "DCT", "bdrate": "BD-Rate", "bdacc": "BD-Accuracy"}
)

# Estimator comparison: full-frame DCT rows with estimator LLSE.
df1 = df[
    (df["DCT"] == "full-frame")
    & (df["estimator"] == "LLSE")
    & (df["result"] != "lvc_g_256")
    & (df["result"] != "lvc_g_ff")
]
# Order by model, then result, then estimator (descending) in ONE multi-key
# sort; chained single-key sorts would depend on sort stability, which
# pandas' default quicksort does not guarantee. Row order fixes the bar/trace
# order, and with it the pattern assigned to each "anchor / target".
df1 = df1.sort_values(
    by=["model_id", "result", "estimator"],
    ascending=[True, True, False],
    key=lambda col: (
        col.map(MODEL_ORDER)
        if col.name == "model_id"
        else col.map(result_order)
        if col.name == "result"
        else col
    ),
)

df1 = df1.rename(columns={"result": "anchor / target"})
df1 = df1.replace("lvc", "LCT (ZFE) / LCT (LLSE)")
df1 = df1.replace("lvc_g", "LCT (LLSE) / CV-Cast (LLSE)")
df1 = df1.replace("total", "LCT (ZFE) / CV-Cast (LLSE)")
df1 = df1.replace("lvc_g_ff", "CV-Cast (ZFE) / CV-Cast (bb)")
df1 = df1.replace("lvc_g_zf", "CV-Cast (ZFE) / CV-Cast (LLSE)")

# BD-Rate chart (no legend of its own). It must use the same pattern sequence
# as the BD-Accuracy chart below, whose legend is shared with it; with the
# plotly default sequence ("", "/", ...) the patterns would not match it.
fig = px.bar(
    df1,
    y="model_id",
    x="BD-Rate",
    color="anchor / target",
    pattern_shape="anchor / target",
    pattern_shape_sequence=["/", "\\"],
    barmode="group",
    orientation="h",
)

fig.update_yaxes(autorange="reversed", title=None)
fig.update_xaxes(
    title=None,
    zerolinecolor="grey",
    zeroline=True,
    autorange="reversed",
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    showlegend=False,
    plot_bgcolor="white",
    width=W2,
    height=H * 0.8,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/bdrate_estimators_2"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

# BD-Accuracy chart; carries the legend for both charts.
fig = px.bar(
    df1,
    x="BD-Accuracy",
    y="model_id",
    color="anchor / target",
    pattern_shape="anchor / target",
    pattern_shape_sequence=["/", "\\"],
    barmode="group",
    orientation="h",
)

fig.update_yaxes(autorange="reversed", title=None)
fig.update_xaxes(
    title=None,
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    showlegend=True,
    plot_bgcolor="white",
    width=W2,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/bdacc_estimators_2"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

# DCT comparison: block-based DCT rows with estimator ZFE.
df2 = df[
    (df["estimator"] == "ZFE")
    & (df["DCT"] == "block-based")
    & (df["result"] != "lvc_g_256")
    & (df["result"] != "lvc_g_zf")
]
# Order by model, then result, then DCT (descending) in ONE multi-key sort;
# chained single-key sorts would depend on sort stability, which pandas'
# default quicksort does not guarantee.
df2 = df2.sort_values(
    by=["model_id", "result", "DCT"],
    ascending=[True, True, False],
    key=lambda col: (
        col.map(MODEL_ORDER)
        if col.name == "model_id"
        else col.map(result_order)
        if col.name == "result"
        else col
    ),
)

df2 = df2.rename(columns={"result": "anchor / target"})
df2 = df2.replace("lvc", "LCT (ff) / LCT (bb)")
df2 = df2.replace("lvc_g", "LCT (bb) / CV-Cast (bb)")
df2 = df2.replace("total", "LCT (ff) / CV-Cast (bb)")
df2 = df2.replace("lvc_g_ff", "CV-Cast (ff) / CV-Cast (bb)")
df2 = df2.replace("lvc_g_zf", "CV-Cast (ff) / CV-Cast (LLSE)")

# BD-Rate chart (legend hidden; shares patterns with the BD-Accuracy chart).
fig = px.bar(
    df2,
    y="model_id",
    x="BD-Rate",
    color="anchor / target",
    pattern_shape="anchor / target",
    pattern_shape_sequence=["/", "\\"],
    barmode="group",
    orientation="h",
)

fig.update_yaxes(autorange="reversed", title=None)
fig.update_xaxes(
    title=None,
    zerolinecolor="grey",
    zeroline=True,
    autorange="reversed",
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    showlegend=False,
    plot_bgcolor="white",
    width=W2,
    height=H * 0.8,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/bdrate_dcts_2"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

# BD-Accuracy chart; carries the legend for both charts.
fig = px.bar(
    df2,
    x="BD-Accuracy",
    y="model_id",
    color="anchor / target",
    pattern_shape="anchor / target",
    pattern_shape_sequence=["/", "\\"],
    barmode="group",
    orientation="h",
)

fig.update_yaxes(autorange="reversed", title=None)
fig.update_xaxes(
    title=None,
    zerolinecolor="grey",
    zeroline=True,
    gridcolor="grey",
    griddash="dot",
    minor=dict(showgrid=True),
)
fig.update_layout(
    showlegend=True,
    plot_bgcolor="white",
    width=W2,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fname = f"{outdir}/bdacc_dcts_2"
fig.write_image(f"{fname}.png", scale=SCALE)
fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Print out 64 vs 1024 chunks table
#
# All tables below: estimator "zf", full-frame DCT (block_dct False); the
# constant filter columns are dropped before printing.

print("BD metrics chunks comparison")

df = (
    df_bd.loc[
        df_bd["csnr_db"].eq(10)
        & df_bd["estimator"].eq("zf")
        & df_bd["block_dct"].eq(False)
    ]
    .drop(columns=["estimator", "mode", "csnr_db", "block_dct"])
    .round(2)
)
print(df)

print("BD metrics default chunks comparison")

# Everything except 256 chunks, and only the per-stage results.
df = (
    df_bd_default.loc[
        df_bd_default["csnr_db"].eq(10)
        & df_bd_default["nchunks"].ne(256)
        & df_bd_default["estimator"].eq("zf")
        & df_bd_default["block_dct"].eq(False)
        & ~df_bd_default["result"].isin(["lvc_g_ff", "lvc_g_zf", "total"])
    ]
    .drop(columns=["estimator", "mode", "csnr_db", "block_dct"])
    .round(2)
)
print(df)

print("BD metrics chunks comparison (precise)")

# Only the BD columns are rounded here.
df = (
    df_bd_precise.loc[
        df_bd_precise["csnr_db"].eq(10)
        & df_bd_precise["estimator"].eq("zf")
        & df_bd_precise["block_dct"].eq(False)
    ]
    .drop(columns=["estimator", "mode", "csnr_db", "block_dct"])
    .round({"bdrate": 2, "bdacc": 2})
)
print(df)

################################################################################
# Print out Sionna BD metrics

print("Sionna BD metrics")

# All Eb/N0 values are kept (csnr_db is not filtered or dropped).
df = (
    df_bd_sionna.loc[
        df_bd_sionna["estimator"].eq("zf")
        & df_bd_sionna["block_dct"].eq(False)
        & df_bd_sionna["nchunks"].eq(256)
    ]
    .drop(columns=["estimator", "mode", "block_dct"])
    .round(2)
)
print(df)

################################################################################
# FastSeg small var / grad / var + grad map
#
# For 64/256/1024 chunks, build a 3-panel image from channel 0 (the "y"
# channel) of the probe data: DCT variance, gradient norm, and squared
# gradient norm times variance.

print("FastSeg small")
probe_data = torch.load("experiments_hupu/runs/run14_keep/probe_results.pt")[
    "fastseg_small"
]
print(list(probe_data.keys()))

mode = 444
grad_key = "grads_norm" if mode != 420 else "grads_norm_420"

# Channel-0 maps, each with a leading singleton axis so they can be stacked.
dct_var_y_64, dct_var_y_256, dct_var_y_1024 = (
    probe_data["dct_var"][nc][0].unsqueeze(dim=0).numpy() for nc in (64, 256, 1024)
)
gnorm_y_64, gnorm_y_256, gnorm_y_1024 = (
    probe_data[grad_key][nc][0].unsqueeze(dim=0).numpy() for nc in (64, 256, 1024)
)
gnorm_sq_y_64, gnorm_sq_y_256, gnorm_sq_y_1024 = (
    probe_data[grad_key][nc][0].square().unsqueeze(dim=0).numpy()
    for nc in (64, 256, 1024)
)

# Combined map: squared gradient norm weighted by DCT variance.
prod_y_64, prod_y_256, prod_y_1024 = (
    gnorm_sq_y_64 * dct_var_y_64,
    gnorm_sq_y_256 * dct_var_y_256,
    gnorm_sq_y_1024 * dct_var_y_1024,
)

img_64 = np.concatenate([dct_var_y_64, gnorm_y_64, prod_y_64])
img_256 = np.concatenate([dct_var_y_256, gnorm_y_256, prod_y_256])
img_1024 = np.concatenate([dct_var_y_1024, gnorm_y_1024, prod_y_1024])

heatmaps(img_64, 1, 3, None, show=True, save=f"{outdir}/maps_y_fastseg_small_64.png")
heatmaps(img_256, 1, 3, None, show=True, save=f"{outdir}/maps_y_fastseg_small_256.png")
# heatmaps(img_1024, 1, 3, None, show=True, save=None)

################################################################################
# YOLOv8s small var / grad / var + grad map
#
# Same three-panel layout as the FastSeg maps: channel-0 DCT variance,
# gradient norm, and squared gradient norm times variance.

print("YOLOv8s")
probe_data = torch.load("experiments_hupu/runs/run14_keep/probe_results.pt")["yolov8_s"]

mode = 444
grad_key = "grads_norm" if mode != 420 else "grads_norm_420"

# Channel-0 maps, each with a leading singleton axis so they can be stacked.
dct_var_y_64, dct_var_y_256, dct_var_y_1024 = (
    probe_data["dct_var"][nc][0].unsqueeze(dim=0).numpy() for nc in (64, 256, 1024)
)
gnorm_y_64, gnorm_y_256, gnorm_y_1024 = (
    probe_data[grad_key][nc][0].unsqueeze(dim=0).numpy() for nc in (64, 256, 1024)
)
gnorm_sq_y_64, gnorm_sq_y_256, gnorm_sq_y_1024 = (
    probe_data[grad_key][nc][0].square().unsqueeze(dim=0).numpy()
    for nc in (64, 256, 1024)
)

# Combined map: squared gradient norm weighted by DCT variance.
prod_y_64, prod_y_256, prod_y_1024 = (
    gnorm_sq_y_64 * dct_var_y_64,
    gnorm_sq_y_256 * dct_var_y_256,
    gnorm_sq_y_1024 * dct_var_y_1024,
)

img_64 = np.concatenate([dct_var_y_64, gnorm_y_64, prod_y_64])
img_256 = np.concatenate([dct_var_y_256, gnorm_y_256, prod_y_256])
img_1024 = np.concatenate([dct_var_y_1024, gnorm_y_1024, prod_y_1024])

# Only the 256-chunk map is rendered and saved.
# heatmaps(img_64, 1, 3, None, show=True, save=None)
heatmaps(img_256, 1, 3, None, show=True, save=f"{outdir}/maps_y_yolov8s.png")
# heatmaps(img_1024, 1, 3, None, show=True, save=None)

In [None]:
# Generate new figures for paper (doesn't require anything)

import warnings
from pathlib import Path

from experiments.plots import *

# Figure geometry in pixels, and the raster export scale.
W, W2, W3 = 1000, 400, 500
H, H2, H3 = 300, 225, 250
SCALE = 4.0

outdir = Path("experiments/plots")
outdir.mkdir(exist_ok=True)

# Accuracy vs. CSNR (AWGN and Sionna), then the dummy legends that
# apply_dummy_legends combines with both plots.
fname_plot = acc_vs_csnr(outdir, W, H, scale=SCALE)
fname_plot_sionna = acc_vs_csnr_sionna(outdir, W, H, scale=SCALE)
fname_scheme, fname_cr = dummy_legends(outdir, W, H, scale=SCALE, do_show=True)
apply_dummy_legends(
    outdir,
    W,
    H,
    [fname_plot, fname_plot_sionna],
    fname_scheme,
    fname_cr,
    scale=SCALE,
)

# Cross-model disparity figures.
disparity_10db_0_25(outdir, W2, W2, scale=SCALE)
disparity_10db_bd(outdir, W2, W2, H2, scale=SCALE)

# Digital-codec comparisons.
jpeg_grace(outdir, W, H, scale=SCALE)
jpeg_grace_sionna(outdir, W3, H3, scale=SCALE, codec="jpeg_grace")
jpeg_grace_sionna(outdir, W3, H3, scale=SCALE, codec="tcm")

bd_default2(outdir, W2, H, scale=SCALE)

# These helpers run with all warnings suppressed.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    bd_default(outdir, W2, H, scale=SCALE)
    bd_table_64_vs_1024(outdir)
    bd_table_sionna_metrics(outdir)

# Variance / gradient maps.
maps_fss(outdir, W, H, SCALE)
maps_y8s(outdir, W, H, SCALE)

grace_quant_table(outdir, W, H, SCALE)

tcm_cr()

In [None]:
# Other plots (requires the cells above: the results-collection cell for
# df_full / filter_df / baseline0, and the figure-settings cell for W, H and
# outdir)

estimators = ["zf", "llse"]
modes = [444]
block_dcts = [False, True]


################################################################################
# LLSE vs ZF comparison

mode = 444
block_dct = False
csnr_db = 10
msg = f"{mode}, {'bb' if block_dct else 'ff'}, {csnr_db} dB"
nchunks = 256

df = df_full[df_full["csnr_db"] == csnr_db]

# For each estimator, split into the default ("lvc") and gradient-optimized
# ("lvc g") variants, both exposing their score as "score0".
df_zf = filter_df(df, "zf", mode, block_dct, nchunks)[
    ["model_id", "estimator", "nchunks", "cr", "csnr_db", "score0_lvc", "score0_lvc_g"]
]
df_zf_g = (
    df_zf.drop(columns="score0_lvc")
    .rename(columns={"score0_lvc_g": "score0"})
    .assign(scheme="lvc g")
)
df_zf = (
    df_zf.drop(columns="score0_lvc_g")
    .rename(columns={"score0_lvc": "score0"})
    .assign(scheme="lvc")
)

df_llse = filter_df(df, "llse", mode, block_dct, nchunks)[
    ["model_id", "estimator", "nchunks", "cr", "csnr_db", "score0_lvc", "score0_lvc_g"]
]
df_llse_g = (
    df_llse.drop(columns="score0_lvc")
    .rename(columns={"score0_lvc_g": "score0"})
    .assign(scheme="lvc g")
)
df_llse = (
    df_llse.drop(columns="score0_lvc_g")
    .rename(columns={"score0_lvc": "score0"})
    .assign(scheme="lvc")
)

df = pd.concat([df_zf, df_zf_g, df_llse, df_llse_g])

fig = px.line(
    df,
    x="cr",
    y="score0",
    color="estimator",
    facet_col="model_id",
    symbol="scheme",
    labels="scheme",
    line_dash="scheme",
    line_dash_sequence=["dash", "solid"],
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
)

add_baseline(fig, baseline0)

fig.update_yaxes(matches=None, showticklabels=True)
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
fname = f"{outdir}/score_rd_444_ff_estimator_comparison"
# Export is disabled; the figure is only displayed.
# fig.write_image(f"{fname}.png", scale=2.0)
# fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Full-frame vs block-based comparison

mode = 444
estimator = "zf"
csnr_db = 10
msg = f"{estimator}, {mode}, {csnr_db} dB"
nchunks = 256

df = df_full[df_full["csnr_db"] == csnr_db]

# For each DCT layout, split into the default ("lvc") and gradient-optimized
# ("lvc g") variants, both exposing their score as "score0".
df_ff = filter_df(df, estimator, mode, False, nchunks)[
    ["model_id", "nchunks", "cr", "csnr_db", "block_dct", "score0_lvc", "score0_lvc_g"]
]
df_ff_g = (
    df_ff.drop(columns="score0_lvc")
    .rename(columns={"score0_lvc_g": "score0"})
    .assign(scheme="lvc g")
)
df_ff = (
    df_ff.drop(columns="score0_lvc_g")
    .rename(columns={"score0_lvc": "score0"})
    .assign(scheme="lvc")
)

df_bb = filter_df(df, estimator, mode, True, nchunks)[
    ["model_id", "nchunks", "cr", "csnr_db", "block_dct", "score0_lvc", "score0_lvc_g"]
]
df_bb_g = (
    df_bb.drop(columns="score0_lvc")
    .rename(columns={"score0_lvc_g": "score0"})
    .assign(scheme="lvc g")
)
df_bb = (
    df_bb.drop(columns="score0_lvc_g")
    .rename(columns={"score0_lvc": "score0"})
    .assign(scheme="lvc")
)

df = pd.concat([df_ff, df_ff_g, df_bb, df_bb_g])

fig = px.line(
    df,
    x="cr",
    y="score0",
    color="block_dct",
    facet_col="model_id",
    symbol="scheme",
    labels="scheme",
    line_dash="scheme",
    line_dash_sequence=["dash", "solid"],
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
)

add_baseline(fig, baseline0)

fig.update_yaxes(matches=None, showticklabels=True)
fig.update_layout(
    plot_bgcolor="white",
    width=W,
    height=H,
    legend=dict(orientation="h", x=0.0, y=-0.2),
    margin=dict(l=20, r=20, t=20, b=0),
)
fig.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
fname = f"{outdir}/score_rd_zf_444_block_dct_comparison"
# Export is disabled; the figure is only displayed.
# fig.write_image(f"{fname}.png", scale=2.0)
# fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# Loss Distortion
#
# loss_dist vs. CR for the two FastSeg models (noisy channels only), one line
# per Eb/N0 and one facet row per chunk count, on a log y-axis.

print("Loss Distortion")
estimator = "zf"
mode = 444
block_dct = False

msg = f"{estimator}, {mode}, {'bb' if block_dct else 'ff'}"
df = df_full[
    df_full["csnr_db"].ne("inf")
    & df_full["model_id"].isin(["fastseg_small", "fastseg_large"])
]
df = filter_df(df, estimator, mode, block_dct)[
    ["model_id", "nchunks", "cr", "csnr_db", "score0_lvc_g", "loss_dist"]
]

fig = px.line(
    df,
    x="cr",
    y="loss_dist",
    color="csnr_db",
    facet_col="model_id",
    facet_row="nchunks",
    symbol_sequence=["cross-thin-open", "x-thin-open"],
    markers=True,
    log_y=True,
    range_y=(10**-9, 100),
)

add_zero_line(fig)

fig.update_layout(
    title={
        "text": f"Loss Distortion ({msg})",
        "x": 0.5,
        "xanchor": "center",
    },
    width=1000,
    height=600,
    legend=dict(orientation="h"),
    margin=dict(l=30, r=30, t=30, b=30),
)
fig.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
fname = f"{outdir}/loss_dist_zf_444_ff"
# Export is disabled; the figure is only displayed.
# fig.write_image(f"{fname}.png", scale=2.0)
# fig.write_html(f"{fname}.html")
fig.show()

# score vs. distortion
#
# loss_dist against score0_lvc_g over every configuration of the two FastSeg
# models (noisy channels only). A shaded band per model marks scores within
# `ratio` of that model's baseline.

df = df_full[
    df_full["csnr_db"].ne("inf")
    & df_full["model_id"].isin(["fastseg_small", "fastseg_large"])
]
df = df.sort_values(by="score0_lvc_g")
msg = "all configs combined"

fig = px.line(
    df,
    x="score0_lvc_g",
    y="loss_dist",
    facet_col="model_id",
    facet_row="nchunks",
    symbol_sequence=["cross-thin-open"],
    markers=True,
    log_y=True,
    range_y=(10**-9, 100),
)

# Facet columns follow the order of the model list below (col is 1-based).
for i, model_id in enumerate(["fastseg_small", "fastseg_large"], start=1):
    ratio = 0.1
    base = baseline0[model_id]
    print(base)
    fig.add_vrect(
        x0=base - ratio * base,
        x1=base,
        col=i,
        annotation_text=f"{ratio*100}%",
        annotation_position="top left",
        fillcolor="green",
        opacity=0.25,
        line_width=0,
    )

fig.update_layout(
    title={
        "text": f"accuracy vs. loss distortion ({msg})",
        "x": 0.5,
        "xanchor": "center",
    },
    width=1000,
    height=600,
    legend=dict(orientation="h"),
    margin=dict(l=30, r=30, t=30, b=30),
)
fig.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
fname = f"{outdir}/acc_vs_loss_dist_zf_444_ff"
# Export is disabled; the figure is only displayed.
# fig.write_image(f"{fname}.png", scale=2.0)
# fig.write_html(f"{fname}.html")
fig.show()

################################################################################
# BD metrics
#
# Grouped bar charts of the Bjontegaard metrics in `df_bd` (rate and accuracy),
# excluding the "inf" CSNR rows. Both charts share one layout; only the y
# column -- and the export file name, which matches it -- differs.

print("BD metrics")
df = df_bd[df_bd["csnr_db"] != "inf"]
# df = df_bd[df_bd["csnr_db"] == 10]

for bd_metric in ("bdrate", "bdacc"):
    fig = px.bar(
        df,
        x="csnr_db",
        y=bd_metric,
        color="estimator",
        facet_col="model_id",
        facet_row="nchunks",
        pattern_shape="block_dct",
        barmode="group",
    )

    add_zero_line(fig)

    # fig.update_yaxes(matches=None, showticklabels=True)
    fig.update_layout(
        width=1000,
        height=600,
        legend=dict(orientation="h"),
        margin=dict(l=20, r=20, t=20, b=20),
    )
    # Keep only the text after the last "=" in facet annotations.
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    fname = f"{outdir}/{bd_metric}"
    # fig.write_image(f"{fname}.png", scale=2.0)
    # fig.write_html(f"{fname}.html")
    fig.show()

### Other

# For every (estimator, mode, block_dct) configuration: overlay default LVC
# scores (dashed, "cross" markers) and grad-optimized LVC scores ("x" markers)
# against CSNR, one line per compression ratio.
for estimator, mode, block_dct in itertools.product(estimators, modes, block_dcts):
    dct_tag = "bb" if block_dct else "ff"
    msg = f"{estimator}, {mode}, {dct_tag}"
    df_zf = filter_df(df_full, estimator, mode, block_dct)
    # Copy the filtered frame: the "cr" column is rewritten below, and writing
    # into a boolean-indexed slice triggers pandas' SettingWithCopyWarning
    # (copying just the column does not prevent it).
    df_zf = df_zf[df_zf["csnr_db"] != "inf"].copy()

    # Default LVC
    fig = px.line(
        df_zf,
        x="csnr_db",
        y="score0_lvc",
        color="cr",
        facet_col="model_id",
        facet_row="nchunks",
        markers=True,
    )
    fig.update_traces(
        line={"dash": "dash"},
        marker=dict(size=5, symbol="cross-thin-open", line=dict(width=1.25)),
    )

    # Grad-optimized LVC: suffix the color labels with "G" so these traces get
    # their own legend entries once merged into `fig`.
    df_zf["cr"] = df_zf["cr"].astype(str) + "G"
    fig2 = px.line(
        df_zf,
        x="csnr_db",
        y="score0_lvc_g",
        color="cr",
        facet_col="model_id",
        facet_row="nchunks",
        markers=True,
    )
    fig2.update_traces(marker=dict(size=5, symbol="x-thin-open", line=dict(width=1.25)))

    # Both figures use the same facet arguments on the same rows, so the
    # copied traces keep pointing at the matching subplots.
    fig.add_traces(list(fig2.select_traces()))
    add_baseline(fig, baseline0)

    fig.update_yaxes(matches=None, showticklabels=True)
    fig.update_layout(
        title={
            "text": f"Default + Grad-optimized LVC ({msg})",
            "x": 0.5,
            "xanchor": "center",
        },
        height=600,
    )
    # Keep only the text after the last "=" in facet annotations.
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    # Include the configuration so iterations do not overwrite each other's
    # exports.
    fname = f"{outdir}/score0_lvc_g_{estimator}_{mode}_{dct_tag}"
    # fig.write_image(f"{fname}.png", scale=2.0)
    # fig.write_html(f"{fname}.html")
    fig.show()

# For every (estimator, mode, block_dct) configuration: rate-distortion view --
# default LVC (dashed, "cross" markers) and grad-optimized LVC ("x" markers)
# scores against compression ratio, one line per CSNR.
for estimator, mode, block_dct in itertools.product(estimators, modes, block_dcts):
    dct_tag = "bb" if block_dct else "ff"
    msg = f"{estimator}, {mode}, {dct_tag}"
    df_zf = filter_df(df_full, estimator, mode, block_dct)
    # Copy the filtered frame: the "csnr_db" column is rewritten below, and
    # writing into a boolean-indexed slice triggers pandas'
    # SettingWithCopyWarning (copying just the column does not prevent it).
    df_zf = df_zf[df_zf["csnr_db"] != "inf"].copy()

    # Default LVC (RD)
    fig = px.line(
        df_zf,
        x="cr",
        y="score0_lvc",
        color="csnr_db",
        facet_col="model_id",
        facet_row="nchunks",
        markers=True,
    )
    fig.update_traces(
        line={"dash": "dash"},
        marker=dict(size=5, symbol="cross-thin-open", line=dict(width=1.25)),
    )

    # Grad-optimized LVC: suffix the color labels with "G" so these traces get
    # their own legend entries once merged into `fig`.
    df_zf["csnr_db"] = df_zf["csnr_db"].astype(str) + "G"
    fig2 = px.line(
        df_zf,
        x="cr",
        y="score0_lvc_g",
        color="csnr_db",
        facet_col="model_id",
        facet_row="nchunks",
        markers=True,
    )
    fig2.update_traces(marker=dict(size=5, symbol="x-thin-open", line=dict(width=1.25)))

    # Both figures use the same facet arguments on the same rows, so the
    # copied traces keep pointing at the matching subplots.
    fig.add_traces(list(fig2.select_traces()))
    add_baseline(fig, baseline0)

    fig.update_yaxes(matches=None, showticklabels=True)
    fig.update_layout(
        title={
            "text": f"Default + Grad-optimized LVC RD ({msg})",
            "x": 0.5,
            "xanchor": "center",
        },
        height=600,
    )
    # Keep only the text after the last "=" in facet annotations.
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    # Distinct "_rd" name plus the configuration, so neither other iterations
    # nor the per-CSNR score figures overwrite this export.
    fname = f"{outdir}/score0_lvc_g_rd_{estimator}_{mode}_{dct_tag}"
    # fig.write_image(f"{fname}.png", scale=2.0)
    # fig.write_html(f"{fname}.html")
    fig.show()

# For every (estimator, mode, block_dct) configuration: plot `diff0_lvc_g`
# (difference between grad-optimized and original LVC) against CSNR, one line
# per compression ratio, with a zero reference line.
# NOTE(review): unlike the two loops above, "inf" CSNR rows are not filtered
# out here -- confirm whether that is intentional.
for estimator, mode, block_dct in itertools.product(estimators, modes, block_dcts):
    dct_tag = "bb" if block_dct else "ff"
    msg = f"{estimator}, {mode}, {dct_tag}"
    df_zf = filter_df(df_full, estimator, mode, block_dct)

    # Difference grad-optimized - default LVC
    fig = px.line(
        df_zf,
        x="csnr_db",
        y="diff0_lvc_g",
        color="cr",
        facet_col="model_id",
        facet_row="nchunks",
        markers=True,
    )
    fig.update_traces(
        marker=dict(size=5, symbol="cross-thin-open", line=dict(width=1.25))
    )
    add_zero_line(fig)
    fig.update_layout(
        title={
            "text": f"Difference between gradient-optimized and original LVC ({msg})",
            "x": 0.5,
            "xanchor": "center",
        },
        height=600,
    )
    # Keep only the text after the last "=" in facet annotations.
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    # Include the configuration so iterations do not overwrite each other's
    # exports.
    fname = f"{outdir}/diff0_lvc_g_{estimator}_{mode}_{dct_tag}"
    # fig.write_image(f"{fname}.png", scale=2.0)
    # fig.write_html(f"{fname}.html")
    fig.show()

# Ablation studies
#
# Encode which LVC stages used gradients as a single number per row so every
# combination gets its own line: +1 w_yuv, +10 select, +100 allocate (the same
# legend key spelled out in the figure titles).
df_ablation["grad_task"] = (
    df_ablation["grad_w"]
    + 10 * df_ablation["grad_sel"]
    + 100 * df_ablation["grad_alloc"]
)

# Marker style shared by every ablation trace.
ablation_marker = dict(size=5, symbol="cross-thin-open", line=dict(width=1.25))

# One figure per chunk count: `diff0_lvc_g` vs. CSNR, faceted by model and
# compression ratio, one line per gradient combination.
for nchunks in [64, 256]:
    df_zf = df_ablation.loc[df_ablation["nchunks"] == nchunks]
    fig = px.line(
        df_zf,
        x="csnr_db",
        y="diff0_lvc_g",
        color="grad_task",
        facet_col="model_id",
        facet_row="cr",
        markers=True,
    )
    fig.update_traces(marker=ablation_marker)
    add_zero_line(fig)
    # fig.update_yaxes(matches=None, showticklabels=True)
    title_text = f"Ablation study ({nchunks} chunks), used gradients for: 1 - w_yuv, 10 - select, 100 - allocate"
    fig.update_layout(
        title=dict(text=title_text, x=0.5, xanchor="center"),
        height=600,
    )
    # Keep only the text after the last "=" in facet annotations.
    fig.for_each_annotation(lambda ann: ann.update(text=ann.text.split("=")[-1]))
    fname = f"{outdir}/ablation0_lvc_g_{nchunks}"
    # fig.write_image(f"{fname}.png", scale=2.0)
    # fig.write_html(f"{fname}.html")
    fig.show()

# Ablation study on the validation-loss difference `diff1_lvc_g`: one figure
# per chunk count, faceted by model and compression ratio, one line per
# gradient combination (`grad_task`, computed in the previous ablation cell).
for nchunks in [64, 256]:
    # Drop only rows missing a plotted value. A bare dropna() would also
    # discard rows whose plotted values are complete but where some unrelated
    # column of df_ablation happens to be NaN.
    df_zf = df_ablation[df_ablation["nchunks"] == nchunks].dropna(
        subset=["csnr_db", "diff1_lvc_g", "grad_task", "model_id", "cr"]
    )
    fig = px.line(
        df_zf,
        x="csnr_db",
        y="diff1_lvc_g",
        color="grad_task",
        facet_col="model_id",
        facet_row="cr",
        markers=True,
    )
    fig.update_traces(
        marker=dict(size=5, symbol="cross-thin-open", line=dict(width=1.25))
    )
    add_zero_line(fig)
    # fig.update_yaxes(matches=None, showticklabels=True)
    fig.update_layout(
        title={
            "text": f"Ablation study ({nchunks} chunks), used gradients for: 1 - w_yuv, 10 - select, 100 - allocate (val loss)",
            "x": 0.5,
            "xanchor": "center",
        },
        height=600,
    )
    # Keep only the text after the last "=" in facet annotations.
    fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
    fname = f"{outdir}/ablation1_lvc_g_{nchunks}"
    # fig.write_image(f"{fname}.png", scale=2.0)
    # fig.write_html(f"{fname}.html")
    fig.show()

# Reprobing
# for nchunks in [64, 256]:
#     df = df_full[df_full["nchunks"] == nchunks]
#     df = df[df["grad_w"] & df["grad_sel"] & df["grad_alloc"]]
#     fig = px.line(
#         df,
#         x="csnr_db",
#         y=["diff0_lvc_g", "diff0_lvc_reprobe", "diff0_lvc_g_reprobe"],
#         # color="cr",
#         facet_col="model_id",
#         facet_row="cr",
#         markers=True,
#     )
#     add_zero_line(fig)
#     # fig.update_yaxes(matches=None, showticklabels=True)
#     fig.update_layout(
#         title={
#             "text": f"Reprobe results ({nchunks} chunks) (score)",
#             "x": 0.5,
#             "xanchor": "center",
#         }
#     )
#     fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
#     fname = f"{outdir}/reprobe0_{nchunks}"
#     fig.write_image(f"{fname}.png", scale=2.0)
#     fig.write_html(f"{fname}.html")
#     fig.show()
#
# Reprobing (val loss)
# for nchunks in [64, 256]:
#     df = df_full[df_full["nchunks"] == nchunks]
#     df = df[df["grad_w"] & df["grad_sel"] & df["grad_alloc"]]
#     df = df.dropna()
#     fig = px.line(
#         df,
#         x="csnr_db",
#         y=["diff1_lvc_g", "diff1_lvc_reprobe", "diff1_lvc_g_reprobe"],
#         # color="cr",
#         facet_col="model_id",
#         facet_row="cr",
#         markers=True,
#     )
#     add_zero_line(fig)
#     # fig.update_yaxes(matches=None, showticklabels=True)
#     fig.update_layout(
#         title={
#             "text": f"Reprobe results ({nchunks} chunks) (val loss)",
#             "x": 0.5,
#             "xanchor": "center",
#         }
#     )
#     fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
#     fname = f"{outdir}/reprobe1_{nchunks}"
#     fig.write_image(f"{fname}.png", scale=2.0)
#     fig.write_html(f"{fname}.html")
#     fig.show()