In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

import sys
from pathlib import Path
import matplotlib.pyplot as plt
import tikzplotlib

# Make the parent of the current working directory importable; presumably this
# is what lets the `external` and `cqp` imports below resolve (confirm against
# the repository layout). Guarded so that re-running the cell does not append
# the same entry again. The former `except IndexError` was unreachable:
# nothing in this statement indexes a sequence.
_parent_dir = str(Path.cwd().parent)
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

from external.bjontegaard_metrics.bj_delta import *
from cqp.bitrate_model.util import *

In [None]:
# Directory holding the measurement CSVs, relative to the parent of the
# working directory.
BASE_PATH = Path.cwd().parent / "util" / "bd-gop-1"
# Name of the quality column plotted and compared by the helpers below.
METRIC = "vmaf"

# Upper x-axis limit used by the RD-curve plots (axis label unit: kbit/s).
BITRATE_MAX_LIM = 8000

# glob yields files in arbitrary, file-system dependent order; sort so that
# every run builds the concatenated frames from the same file sequence.
all_files = sorted(BASE_PATH.glob("*.csv"))

# Reference measurements that every filtered variant is compared against.
df_baseline = pd.read_csv(BASE_PATH / "quality-none-1-05.csv")

In [None]:
def fix_label_string(label: str) -> str:
    """Drop the trailing "-05" marker from labels that are not Gaussian.

    Labels containing "gauss" are returned unchanged. For every other label
    only a *trailing* "-05" is removed; an identical substring elsewhere in
    the label (for example an earlier parameter that also reads "05") is left
    alone, which a blanket ``str.replace`` would have stripped as well.

    Args:
        label: Legend label, e.g. "jpeg-10-05".

    Returns:
        The label without a trailing "-05", or the unchanged label.
    """
    suffix = "-05"
    if "gauss" in label or not label.endswith(suffix):
        return label
    return label[: -len(suffix)]


def calc_avg_over_multiple_videos(df: pd.DataFrame) -> pd.DataFrame:
    """Average all numeric columns per (QP, kernel size, sigma) combination.

    Args:
        df: Measurements containing the KEYS.QP, KEYS.KSIZE and KEYS.SIGMA
            columns.

    Returns:
        A new frame with one row per key combination; the grouping keys are
        restored as regular columns.
    """
    group_keys = [KEYS.QP, KEYS.KSIZE, KEYS.SIGMA]
    # numeric_only=True: a non-numeric column (e.g. a file or video name, if
    # the CSVs carry one) makes a plain .mean() raise on pandas >= 2.0.
    return df.groupby(by=group_keys).mean(numeric_only=True).reset_index()


def plot_rd_curve(df, label):
    """Plot a single rate-distortion curve on the current pyplot axes.

    The data is averaged per (QP, kernel size, sigma) first, then drawn as a
    line with a scatter overlay of METRIC over bitrate.

    Args:
        df: Measurements with `bitrate`, `qp` and METRIC columns.
        label: Legend label; passed through fix_label_string before use.
    """
    label = fix_label_string(label)
    df = calc_avg_over_multiple_videos(df)
    plt.plot(df.bitrate, df[METRIC], label=label)
    plt.scatter(df.bitrate, df[METRIC])
    plt.ylabel(METRIC.upper())
    # Raw f-string: "\i" is not a valid Python escape sequence and triggers a
    # warning in a normal literal; the rendered text is unchanged.
    plt.xlabel(rf"Bitrate [kbit/s] (QP $\in$ [{int(df.qp.max())}, ..., {int(df.qp.min())}])")
    plt.xlim(0, BITRATE_MAX_LIM)


def plot_rd_curves(df_files):
    """Draw one RD curve per CSV file and add a legend.

    Args:
        df_files: File names relative to BASE_PATH; each file's legend label
            is derived from its name.
    """
    for file_name in df_files:
        measurements = pd.read_csv(BASE_PATH / file_name)
        curve_label = get_label_from_file_name(file_name)
        plot_rd_curve(measurements, curve_label)
    plt.legend()


def plot_rd_plane(df, filter_param):
    """Scatter the averaged measurements in a new 3D figure.

    Rows are sorted by (sigma, kernel size, QP) and reshaped into one series
    per filter setting, so every filter setting must provide the same number
    of QP values; otherwise the reshape raises ValueError.

    Args:
        df: Measurements of one filter family (optionally incl. baseline).
        filter_param: Text inserted into the y-axis label.
    """
    df = calc_avg_over_multiple_videos(df)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    df = df.sort_values(by=[KEYS.SIGMA, KEYS.KSIZE, KEYS.QP])
    qp_len = len(pd.unique(df.qp))
    # Integer division; the row count is computed once and reused below.
    filter_param_len = df.shape[0] // qp_len
    grid_shape = (filter_param_len, qp_len)
    bitrate = np.reshape(df.bitrate.to_numpy(), grid_shape)
    quality = np.reshape(df[METRIC].to_numpy(), grid_shape)
    for i in range(filter_param_len):
        # NOTE(review): x and y receive the same bitrate series, so all points
        # lie on the x == y diagonal. Kept as in the original; confirm whether
        # a different y series was intended.
        ax.scatter3D(bitrate[i], bitrate[i], quality[i])
    # Raw f-string avoids the invalid "\i" escape; rendered text is unchanged.
    plt.xlabel(rf"Bitrate [kbit/s] (QP $\in$ [{int(df.qp.max())}, ..., {int(df.qp.min())}])")
    plt.ylabel(f"Bitrate [kbit/s] ({filter_param})")
    ax.set_zlabel(METRIC.upper(), rotation=90)


def plot_rd_curve_filter(df):
    """Plot one RD curve per QP, connecting the different filter settings.

    The data is averaged per (QP, kernel size, sigma); for every QP the
    points are ordered by bitrate and drawn as a line with scatter overlay.

    Args:
        df: Measurements of one filter family.
    """
    df = calc_avg_over_multiple_videos(df)
    # Group by the scalar key and iterate directly: grouping by a length-1
    # list and calling get_group() with a scalar key is deprecated in
    # pandas 2.x and rejected by later versions.
    for _qp, group in df.groupby(KEYS.QP):
        group = group.sort_values(by=[KEYS.BITRATE])
        plt.scatter(group.bitrate, group[METRIC])
        plt.plot(group.bitrate, group[METRIC])
    plt.xlim(0, BITRATE_MAX_LIM)


def extend_data_by_quality_cost_and_rate_savings(df, df_baseline):
    """Add `savings` and `cost` columns relative to the baseline.

    Both frames are averaged per (QP, kernel size, sigma) first. For every
    row, `savings` is the baseline bitrate minus the row's bitrate and `cost`
    is the baseline METRIC minus the row's METRIC, each taken at the row's QP.

    The baseline is looked up by QP value. The previous code sorted both
    frames by QP before subtracting, but pandas aligns Series on their index,
    so the sort had no effect and rows were paired by position only.

    Args:
        df: Measurements of a filtered variant.
        df_baseline: Measurements of the unfiltered reference.

    Returns:
        The averaged copy of `df` with the two extra columns. QPs missing
        from the baseline yield NaN.
    """
    df = calc_avg_over_multiple_videos(df)
    df_baseline = calc_avg_over_multiple_videos(df_baseline)
    # One baseline row per QP, indexed by QP, for the lookups below.
    baseline_by_qp = df_baseline.groupby(KEYS.QP)[["bitrate", METRIC]].mean()
    qp = df[KEYS.QP]
    df["savings"] = qp.map(baseline_by_qp["bitrate"]) - df["bitrate"]
    df["cost"] = qp.map(baseline_by_qp[METRIC]) - df[METRIC]
    return df


def extend_all_by_cost_and_savings(df_files, df_baseline):
    """Read several CSV files, extend each by cost/savings and concatenate.

    Args:
        df_files: File names relative to BASE_PATH.
        df_baseline: Reference measurements handed to the per-file extension.

    Returns:
        Concatenation of the extended frames (original row indices are kept).
    """
    extended_frames = []
    for file_name in df_files:
        measurements = pd.read_csv(BASE_PATH / file_name)
        extended_frames.append(
            extend_data_by_quality_cost_and_rate_savings(measurements, df_baseline)
        )
    return pd.concat(extended_frames)


def plot_quality_cost_over_rate_savings(df, label):
    """Scatter quality cost over bitrate savings on the current axes.

    Args:
        df: Frame with `savings`, `cost` and `qp` columns.
        label: Legend label; passed through fix_label_string before use.
    """
    label = fix_label_string(label)
    plt.scatter(df.savings, df.cost, label=label)
    # Raw strings: "\D" is not a valid Python escape sequence; the rendered
    # text is unchanged. (An unused list of QP labels was removed.)
    plt.xlabel(rf"$\Delta$ Bitrate [kbit/s] (QP $\in$ [{int(df.qp.min())}, ..., {int(df.qp.max())}])")
    plt.ylabel(rf"$\Delta$ {METRIC.upper()}")
    plt.xlim(-250, 5000)
    plt.legend()


def plot_cost_savings(df_files, df_baseline):
    """Plot cost over savings for the baseline and for every given file.

    The baseline is compared against itself and drawn with the label "none";
    each file is then read from BASE_PATH, extended and drawn with a label
    derived from its file name.

    Args:
        df_files: File names relative to BASE_PATH.
        df_baseline: Reference measurements.
    """
    baseline_points = extend_data_by_quality_cost_and_rate_savings(df_baseline, df_baseline)
    plot_quality_cost_over_rate_savings(baseline_points, "none")
    for file_name in df_files:
        variant_points = extend_data_by_quality_cost_and_rate_savings(
            pd.read_csv(BASE_PATH / file_name), df_baseline
        )
        plot_quality_cost_over_rate_savings(variant_points, get_label_from_file_name(file_name))


def plot_cost_savings_filter(df_files, df_baseline):
    """Plot cost over savings with one curve per QP across all given files.

    Args:
        df_files: File names relative to BASE_PATH.
        df_baseline: Reference measurements.
    """
    df = extend_all_by_cost_and_savings(df_files, df_baseline)
    # Group by the scalar key and iterate directly: grouping by a length-1
    # list and calling get_group() with a scalar key is deprecated in
    # pandas 2.x and rejected by later versions.
    for _qp, group in df.groupby(KEYS.QP):
        group = group.sort_values(by=["savings"])
        plt.scatter(group.savings, group.cost)
        plt.plot(group.savings, group.cost)


def plot_bd_over_qp(df, key_rate=KEYS.BITRATE, key_quality=METRIC):
    """Plot both Bjontegaard delta series over QP and print their averages.

    Args:
        df: Frame with a `qp` column plus the `key_quality` and `key_rate`
            columns, as produced by calc_avg_bd_per_qp.
        key_rate: Column holding the rate delta.
        key_quality: Column holding the quality delta; also used in the
            legend and in the printed summary.
    """
    quality_name = key_quality.upper()
    plt.scatter(df.qp, df[key_quality], label=f"BD-{quality_name}")
    plt.plot(df.qp, df[key_quality])
    # Raw string: "\%" is not a valid Python escape sequence; the rendered
    # text is unchanged.
    plt.scatter(df.qp, df[key_rate], label=r"BD-Rate [$\%$]")
    plt.plot(df.qp, df[key_rate])
    plt.axhline(0, color="gray")
    plt.xlabel("QP")
    plt.ylabel("BD")
    plt.legend()
    print(
        f"BD-{quality_name} and BD-Rate over QP\n"
        f"Avg.: BD-{quality_name}: {df[key_quality].mean():.2f}, BD-Rate: {df[key_rate].mean():.2f}"
    )


def print_bj_delta(df_one, df_two):
    """Print bj_delta results (mode 0, then mode 1) for PSNR, SSIM and VMAF.

    Args:
        df_one: First data set with `bitrate`, `psnr`, `ssim`, `vmaf`.
        df_two: Second data set with the same attributes.
    """
    for metric in ("psnr", "ssim", "vmaf"):
        rd_points = (
            df_one.bitrate,
            getattr(df_one, metric),
            df_two.bitrate,
            getattr(df_two, metric),
        )
        print(f"BD-{metric.upper()}: ", bj_delta(*rd_points, mode=0))
        print("BD-RATE: ", bj_delta(*rd_points, mode=1))


def calc_avg_bd_per_qp(df_one, df_two, key_rate=KEYS.BITRATE, key_quality=METRIC):
    """Compute bj_delta between two data sets separately for every QP.

    For each QP present in BOTH frames, the rows of each frame are ordered by
    `key_rate` and handed to bj_delta. Groups are matched by QP value; the
    previous positional zip() silently compared different QPs whenever the
    two frames did not contain exactly the same QP set.

    Args:
        df_one: First data set.
        df_two: Second data set.
        key_rate: Column used as rate axis.
        key_quality: Column used as quality axis.

    Returns:
        Frame with columns "qp", `key_quality` (bj_delta mode 0) and
        `key_rate` (bj_delta mode 1), one row per shared QP in ascending order.
    """
    # Scalar grouping key: iteration then yields plain QP values as names.
    groups_one = {qp: group for qp, group in df_one.groupby(KEYS.QP)}
    groups_two = {qp: group for qp, group in df_two.groupby(KEYS.QP)}
    rows = []
    for qp in sorted(groups_one.keys() & groups_two.keys()):
        group_one = groups_one[qp].sort_values(by=[key_rate])
        group_two = groups_two[qp].sort_values(by=[key_rate])
        rd_points = (
            group_one[key_rate].values,
            group_one[key_quality].values,
            group_two[key_rate].values,
            group_two[key_quality].values,
        )
        rows.append({
            "qp": qp,
            key_quality: bj_delta(*rd_points, mode=0),
            key_rate: bj_delta(*rd_points, mode=1),
        })
    # Explicit columns keep the layout stable even when no QP is shared.
    return pd.DataFrame(rows, columns=["qp", key_quality, key_rate])


def get_label_from_file_name(name):
    """Return the file stem without its first hyphen-separated component.

    Example: "quality-jpeg-10-05.csv" -> "jpeg-10-05".

    Raises:
        IndexError: If the stem contains no hyphen.
    """
    stem = Path(name).stem
    pieces = stem.split("-", maxsplit=1)
    return pieces[1]


def save_plot(name):
    """Save the current figure below BASE_PATH as PNG, PDF and TikZ (.tex).

    Args:
        name: Base file name; the suffix is set per output format.
    """
    target = BASE_PATH / name
    for image_suffix in (".png", ".pdf"):
        plt.savefig(target.with_suffix(image_suffix))
    tikzplotlib.save(target.with_suffix(".tex"))


def calculate_mscr(df_cost_savings):
    """Return log10 of the mean max-savings / max-cost ratio per filter setting.

    Rows are grouped by (kernel size, sigma). The column-wise maxima are taken
    per group, so the largest savings and the largest cost of a group need not
    come from the same row.

    Args:
        df_cost_savings: Frame with `savings` and `cost` columns.
    """
    per_filter_max = df_cost_savings.groupby([KEYS.KSIZE, KEYS.SIGMA]).max()
    ratios = per_filter_max.savings / per_filter_max.cost
    return np.log10(ratios.mean())

In [None]:
# RD curves: baseline vs. the JPEG files (file-name parameter 10/20/40/60).
files_jpeg = [f"quality-jpeg-{param}-05.csv" for param in (10, 20, 40, 60)]
plot_rd_curve(df_baseline, "none")
plot_rd_curves(files_jpeg)
save_plot("rd-jpeg")

In [None]:
# RD curves: baseline vs. the "gauss-3" files with suffixes 05 ... 15.
files_gauss_3 = [
    f"quality-gauss-3-{suffix}.csv"
    for suffix in ("05", "06", "07", "08", "10", "15")
]
plot_rd_curve(df_baseline, "none")
plot_rd_curves(files_gauss_3)
save_plot("rd-gauss-3")

In [None]:
# RD curves: baseline vs. the "gauss-5" files with suffixes 05 ... 15.
files_gauss_5 = [
    f"quality-gauss-5-{suffix}.csv"
    for suffix in ("05", "06", "07", "08", "10", "15")
]
plot_rd_curve(df_baseline, "none")
plot_rd_curves(files_gauss_5)
save_plot("rd-gauss-5")

In [None]:
# RD curves: baseline vs. the median files (file-name parameter 3/5/7/9).
files_median = [f"quality-median-{param}-05.csv" for param in (3, 5, 7, 9)]
plot_rd_curve(df_baseline, "none")
plot_rd_curves(files_median)
save_plot("rd-median")

In [None]:
# Build one frame per filter family from the globbed CSV files.
# NOTE(review): read_df_by_keys is defined outside this notebook; presumably it
# reads and concatenates every file whose name contains one of the given keys
# (confirm against its implementation).
df_jpeg = read_df_by_keys(all_files, ["quality-jpeg"])
df_gauss_3 = read_df_by_keys(all_files, ["quality-gauss-3"])
df_gauss_5 = read_df_by_keys(all_files, ["quality-gauss-5"])
df_median = read_df_by_keys(all_files, ["quality-median"])

In [None]:
# Baseline RD curve plus one curve per QP across the JPEG data; saved as "rd-jpeg-qp".
plot_rd_curve(df_baseline, "none")
plot_rd_curve_filter(df_jpeg)
save_plot("rd-jpeg-qp")

In [None]:
# 3D scatter of baseline + JPEG data. Raw string: "\i" is not a valid Python
# escape sequence; the rendered label is unchanged.
plot_rd_plane(pd.concat([df_baseline, df_jpeg]), r"$Q\in$[10,20,40,60]")
save_plot("rdp-jpeg")

In [None]:
# Baseline RD curve plus one curve per QP across the gauss-3 data; saved as "rd-gauss-3-qp".
plot_rd_curve(df_baseline, "none")
plot_rd_curve_filter(df_gauss_3)
save_plot("rd-gauss-3-qp")

In [None]:
# 3D scatter of baseline + gauss-3 data. Raw string: "\s" and "\i" are not
# valid Python escape sequences; the rendered label is unchanged.
plot_rd_plane(pd.concat([df_baseline, df_gauss_3]), r"$\sigma\in$[0.5,...,1.5]")
save_plot("rdp-gauss-3")

In [None]:
# Baseline RD curve plus one curve per QP across the gauss-5 data; saved as "rd-gauss-5-qp".
plot_rd_curve(df_baseline, "none")
plot_rd_curve_filter(df_gauss_5)
save_plot("rd-gauss-5-qp")

In [None]:
# 3D scatter of baseline + gauss-5 data. Raw string: "\s" and "\i" are not
# valid Python escape sequences; the rendered label is unchanged.
gauss_5_plane = pd.concat([df_baseline, df_gauss_5])
plot_rd_plane(gauss_5_plane, r"$\sigma\in$[0.5,...,1.5]")
save_plot("rdp-gauss-5")

In [None]:
# Baseline RD curve plus one curve per QP across the median data; saved as "rd-median-qp".
plot_rd_curve(df_baseline, "none")
plot_rd_curve_filter(df_median)
save_plot("rd-median-qp")

In [None]:
# 3D scatter of baseline + median data. Raw string: "\i" is not a valid
# Python escape sequence; the rendered label is unchanged.
median_plane = pd.concat([df_baseline, df_median])
plot_rd_plane(median_plane, r"$k\in$[3,5,7,9]")
save_plot("rdp-median")

In [None]:
# Quality cost over bitrate savings, one scatter series per JPEG file (plus baseline).
plot_cost_savings(files_jpeg, df_baseline)
save_plot("cost-jpeg")

In [None]:
# Quality cost over bitrate savings, one curve per QP across the JPEG files.
plot_cost_savings_filter(files_jpeg, df_baseline)
save_plot("cost-jpeg-qp")

In [None]:
# Quality cost over bitrate savings, one scatter series per gauss-3 file (plus baseline).
plot_cost_savings(files_gauss_3, df_baseline)
save_plot("cost-gauss-3")

In [None]:
# Quality cost over bitrate savings, one curve per QP across the gauss-3 files.
plot_cost_savings_filter(files_gauss_3, df_baseline)
save_plot("cost-gauss-3-qp")

In [None]:
# Quality cost over bitrate savings, one scatter series per gauss-5 file (plus baseline).
plot_cost_savings(files_gauss_5, df_baseline)
save_plot("cost-gauss-5")

In [None]:
# Quality cost over bitrate savings, one curve per QP across the gauss-5 files.
plot_cost_savings_filter(files_gauss_5, df_baseline)
save_plot("cost-gauss-5-qp")

In [None]:
# Quality cost over bitrate savings, one scatter series per median file (plus baseline).
plot_cost_savings(files_median, df_baseline)
save_plot("cost-median")

In [None]:
# Quality cost over bitrate savings, one curve per QP across the median files.
plot_cost_savings_filter(files_median, df_baseline)
save_plot("cost-median-qp")

In [None]:
# Per filter family: re-read the listed files, extend each by `savings`/`cost`
# relative to the baseline and concatenate them. Used by the cost/savings BD
# comparisons and the MSCR report further down.
df_jpeg_cost_savings = extend_all_by_cost_and_savings(files_jpeg, df_baseline)
df_gauss_3_cost_savings = extend_all_by_cost_and_savings(files_gauss_3, df_baseline)
df_gauss_5_cost_savings = extend_all_by_cost_and_savings(files_gauss_5, df_baseline)
df_median_cost_savings = extend_all_by_cost_and_savings(files_median, df_baseline)

In [None]:
# Per-QP bj_delta of gauss-3 vs. JPEG.
bd_gauss_3_jpeg = calc_avg_bd_per_qp(df_gauss_3, df_jpeg)
# Fix the y-range before plotting; per the original author's note, the strong
# overshoot that would otherwise dominate the axis comes from the JPEG data.
plt.ylim(-40, 60)  # Hard overshooting comes from jpeg
plot_bd_over_qp(bd_gauss_3_jpeg)
save_plot("bd-gauss-3-jpeg")

In [None]:
# Per-QP bj_delta of gauss-3 vs. gauss-5.
bd_gauss_3_gauss_5 = calc_avg_bd_per_qp(df_gauss_3, df_gauss_5)
plot_bd_over_qp(bd_gauss_3_gauss_5)
save_plot("bd-gauss-3-gauss-5")

In [None]:
# Same comparison on the cost/savings columns instead of quality/bitrate.
# Note: this figure is only displayed, it is not written via save_plot.
bd_gauss_3_gauss_5_cost_savings = calc_avg_bd_per_qp(df_gauss_3_cost_savings, df_gauss_5_cost_savings, key_rate="savings", key_quality="cost")
plot_bd_over_qp(bd_gauss_3_gauss_5_cost_savings, key_rate="savings", key_quality="cost")

In [None]:
# Per-QP bj_delta of gauss-3 vs. median.
bd_gauss_3_median = calc_avg_bd_per_qp(df_gauss_3, df_median)
plot_bd_over_qp(bd_gauss_3_median)
save_plot("bd-gauss-3-median")

In [None]:
# Same comparison on the cost/savings columns instead of quality/bitrate.
# Note: this figure is only displayed, it is not written via save_plot.
bd_gauss_3_median_cost_savings = calc_avg_bd_per_qp(df_gauss_3_cost_savings, df_median_cost_savings, key_rate="savings", key_quality="cost")
plot_bd_over_qp(bd_gauss_3_median_cost_savings, key_rate="savings", key_quality="cost")

In [None]:
# Print the MSCR value of every filter family, aligned in one column.
mscr_inputs = (
    ("JPEG", df_jpeg_cost_savings),
    ("Gauss-3", df_gauss_3_cost_savings),
    ("Gauss-5", df_gauss_5_cost_savings),
    ("Median", df_median_cost_savings),
)
for family, cost_savings in mscr_inputs:
    heading = f"{family}:"
    print(f"MSCR {heading:<9}{calculate_mscr(cost_savings):.2f}")