In [None]:
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Dataset to evaluate. The value is used below as the results directory, the
# anchor sub-directory, the sheet name of the VTM spreadsheet and the key into
# the `titles` mapping of the plots.
dataset = 'classc'

In [None]:
# Read our results

import json

def combine_dict(dict1, dict2):
    """Recursively merge ``dict2`` into ``dict1`` and return ``dict1``.

    ``dict1`` is modified in place. Keys present in both inputs are merged
    recursively when both values are dictionaries; in every other case the
    value from ``dict2`` replaces the one in ``dict1``.

    Args:
        dict1: Destination dictionary (mutated).
        dict2: Dictionary whose entries are merged into ``dict1``.

    Returns:
        The same object as ``dict1``.
    """
    for key, value in dict2.items():
        if key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict):
            dict1[key] = combine_dict(dict1[key], value)
        else:
            # New key, or a leaf value on either side: recursing would call
            # `.keys()` on a non-dict, so the entry from dict2 simply wins.
            dict1[key] = value
    return dict1

def combine_dicts(dicts):
    """Fold an iterable of dictionaries into one with ``combine_dict``.

    The first dictionary becomes the accumulator and every following one is
    merged into it. Returns ``None`` when ``dicts`` yields nothing.
    """
    merged = None
    for current in dicts:
        merged = current if merged is None else combine_dict(merged, current)
    return merged

In [None]:
import site

# NOTE(review): hard-coded absolute path to a local project checkout; it has to
# be adjusted (or made configurable) before running on another machine.
site.addsitedir("/home/xyhang/projects/VCIP2023-grand-challenge/")

# BD-rate / BD-PSNR helpers from the project's `tools` package -- presumably
# resolved through the directory registered above.
from tools.bdrate import BD_RATE, BD_PSNR

In [None]:
from dataclasses import dataclass
from typing_extensions import TypeAlias, List, Dict
import numpy as np

@dataclass
class ResultUnit:
    """One measured operating point of a codec on a single image."""
    r: float  # rate; filled from the "bpp" field of the result files
    d: float  # quality; filled from the "PSNR" / "RGB psnr" field
    t: float  # time; filled from the "t_dec" / "Dec Time" field

# All operating points recorded for one image (one ResultUnit per tested setting).
ImageResults: TypeAlias = List[ResultUnit]

class DatasetResults:
    """Per-image rate / quality / time measurements of one codec configuration.

    ``img_results`` maps an image name to the list of ``ResultUnit`` points
    recorded for that image.
    """

    def __init__(self):
        self.img_results: Dict[str, ImageResults] = {}

    def update_image_result(self, img_name, r, d, t):
        """Append one ``(r, d, t)`` point to the list stored for ``img_name``."""
        self.img_results.setdefault(img_name, []).append(ResultUnit(r=r, d=d, t=t))

    def sort(self):
        """Sort every image's points in place by ascending rate ``r``."""
        for units in self.img_results.values():
            units.sort(key=lambda x: x.r)

    def _avg_field(self, field: str) -> List[float]:
        """Average attribute ``field`` across images, one value per operating point.

        Points are sorted by rate first; the i-th value is the mean over the
        i-th point of every image. ``zip`` stops at the shortest per-image
        list, so surplus points of longer lists are ignored.
        """
        self.sort()
        return [
            np.mean([getattr(x, field) for x in rs])
            for rs in zip(*self.img_results.values())
        ]

    @property
    def avg_r(self) -> List[float]:
        """Mean rate over all images for each operating point (a list, not a scalar)."""
        return self._avg_field("r")

    @property
    def avg_d(self) -> List[float]:
        """Mean quality over all images for each operating point (a list, not a scalar)."""
        return self._avg_field("d")

    @property
    def avg_t(self) -> List[float]:
        """Mean time over all images for each operating point (a list, not a scalar)."""
        return self._avg_field("t")

    @staticmethod
    def _anchor_units(anchor, filename):
        """Return the anchor's points for ``filename`` or raise ``ValueError``."""
        if filename not in anchor.img_results:
            raise ValueError(f"filename {filename} not in anchor")
        return anchor.img_results[filename]

    @staticmethod
    def _rd_arrays(units):
        """Return ``(rates, qualities)`` of ``units`` as sorted, de-duplicated arrays."""
        # NOTE(review): rates and qualities are de-duplicated independently, so
        # the two arrays stay paired only when quality increases with rate and
        # no value repeats on one axis alone -- confirm this holds for the data.
        return np.unique([x.r for x in units]), np.unique([x.d for x in units])

    def _bd_rate_imgwise(self, anchor, min_int=None, max_int=None):
        """Compute ``BD_RATE`` against ``anchor`` separately for every image.

        Args:
            anchor: ``DatasetResults`` holding the reference points; it must
                contain every image name stored in ``self``.
            min_int, max_int: Optional per-image mappings forwarded to
                ``BD_RATE`` as ``min_int`` / ``max_int``. They are used only
                when both are given.

        Returns:
            Dict mapping image name to the value returned by ``BD_RATE``.

        Raises:
            ValueError: If an image of ``self`` is missing from ``anchor``.
        """
        self.sort()
        anchor.sort()
        bounded = min_int is not None and max_int is not None
        ans = {}
        for filename, data in self.img_results.items():
            R1, D1 = self._rd_arrays(self._anchor_units(anchor, filename))
            R2, D2 = self._rd_arrays(data)
            if bounded:
                ans[filename] = BD_RATE(
                    R1, D1, R2, D2, min_int=min_int[filename], max_int=max_int[filename]
                )
            else:
                ans[filename] = BD_RATE(R1, D1, R2, D2)
        return ans

    def bd_rate(self, anchor, min_int=None, max_int=None):
        """Return the mean of the per-image ``BD_RATE`` values against ``anchor``."""
        ans = self._bd_rate_imgwise(anchor, min_int, max_int)
        return np.mean(list(ans.values()))

    def bd_psnr(self, anchor):
        """Return the mean of the per-image ``BD_PSNR`` values against ``anchor``.

        Raises:
            ValueError: If an image of ``self`` is missing from ``anchor``.
        """
        self.sort()
        anchor.sort()
        ans = []
        for filename, data in self.img_results.items():
            R1, D1 = self._rd_arrays(self._anchor_units(anchor, filename))
            R2, D2 = self._rd_arrays(data)
            ans.append(BD_PSNR(R1, D1, R2, D2))
        return np.mean(ans)

    def time_saving(self, anchor):
        """Unfinished: only sorts both result sets and returns ``None``.

        TODO: implement the actual time comparison against ``anchor``.
        """
        self.sort()
        anchor.sort()

In [None]:
def read_json(filename):
    """Load and return the JSON document stored at ``filename``.

    The file is decoded as UTF-8 explicitly so that the result does not depend
    on the platform's default locale encoding.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
import glob

# Collect every JSON result file of the selected dataset. glob returns paths in
# arbitrary (file-system) order, so the list is sorted to make the merge order
# reproducible.
results = sorted(glob.glob(f"{dataset}/*.json"))
if not results:
    # Fail here with a clear message instead of leaving `data_ours` as None.
    raise FileNotFoundError(f"no result files match {dataset}/*.json")

# Merge all result files into one nested dictionary.
data_ours = combine_dicts([read_json(r) for r in results])

In [None]:
# Alternative speed-up sweep kept for reference:
# speedups = [0.01, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0]
# Speed-up settings to load; each value selects the "speedup=<value>" entries.
speedups = [0.01,0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.0]
# Quality scales to load; each value selects the "qscale=<value>" entries.
qscales = [0.1, 0.3, 0.5, 0.7, 0.9]
# Alternative quality-scale sweep kept for reference:
# qscales = [0.3,0.4,0.5,0.6,0.7]

In [None]:
# One DatasetResults per speed-up setting, filled with the per-image
# (bpp, PSNR, decoding time) points of every quality scale.
ours_results = {k: DatasetResults() for k in speedups}

for speedup in speedups:
    for qscale in qscales:
        glob_results = data_ours[f"qscale={qscale}"][f"speedup={speedup}"]
        for filename, result in glob_results.items():
            # Entries whose name starts with "avg" are not per-image results.
            if filename.startswith("avg"):
                continue
            ours_results[speedup].update_image_result(
                filename,
                result["bpp"],
                result["PSNR"],
                result["t_dec"],
            )

In [None]:
# CBANet anchor results are loaded for the 'classd' dataset only.
if dataset == 'classd':
    cbanet_levels = [1, 2, 3]
    cbanet_results = {k: DatasetResults() for k in cbanet_levels}
    data_cbanet = read_json(f"../../anchors/{dataset}/results_cbanet.json")

    for level in cbanet_levels:
        for bpp in range(1, 5):
            glob_results = data_cbanet[f"{bpp}"][f"width={level}"]
            for filename, result in glob_results.items():
                # Entries whose name starts with "avg" are not per-image results.
                if filename.startswith("avg"):
                    continue
                cbanet_results[level].update_image_result(
                    filename,
                    result["bpp"],
                    result["PSNR"],
                    result["t_dec"] + 0.05,  # ANS time extra
                )

In [None]:
import pandas as pd

def read_xls(file_path, sheet_name):
    """Read sheet ``sheet_name`` of the spreadsheet at ``file_path`` via ``pd.read_excel``."""
    return pd.read_excel(file_path, sheet_name=sheet_name)

In [None]:
# Anchor codecs to load; one DatasetResults is created per name.
anchors = ['bpg', 'evc', 'jpeg', 'mlic', 'qarv', 'tcm', 'webp', 'vtm']

anchor_results: Dict[str, DatasetResults] = {k: DatasetResults() for k in anchors}

# NOTE(review): these three dicts are never filled or read in the visible cells.
anchors_r = {}
anchors_d = {}
anchors_t = {}

for anchor in anchors:
    if anchor == 'vtm':
        # VTM results come from a spreadsheet whose sheet is named after the dataset.
        xls_data: pd.DataFrame = read_xls(f"VTM.xls", dataset)
        filename = None
        for index, x in xls_data.iterrows():
            # Rows with an empty (NaN) Filename cell keep the name of the most
            # recent row that had one; numeric names are converted to strings.
            if isinstance(x.Filename, (str, int)) or (isinstance(x.Filename, float) and not np.isnan(x.Filename)):
                if isinstance(x.Filename, float):
                    filename = str(int(x.Filename))
                else:
                    filename = str(x.Filename)
            # NOTE(review): "Dec Time" is scaled by 6.734; the origin of this
            # factor is not shown here -- confirm.
            anchor_results[anchor].update_image_result(
                filename, x.bpp, x["RGB psnr"], x["Dec Time"] * 6.734
            )
    else:
        anchor_data = read_json(f"../../anchors/{dataset}/{anchor}/results.json")
        if anchor == 'bpg' and dataset == 'classa':
            # This one file is traversed with one more level of nesting than the others.
            for t1 in anchor_data.values():
                for t2 in t1.values():
                    for filename, result in t2.items():
                        # Entries whose name starts with "avg" are skipped.
                        if filename[:3] != "avg":
                            anchor_results[anchor].update_image_result(
                                filename, result["bpp"], result["PSNR"], result["t_dec"]
                            )
        else:
            for t2 in anchor_data.values():
                for filename, result in t2.items():
                    # Entries whose name starts with "avg" are skipped.
                    if filename[:3] != "avg":
                        anchor_results[anchor].update_image_result(
                            filename, result["bpp"], result["PSNR"], result["t_dec"]
                        )

In [None]:
import scipy


def interpolator(x, y, num=100):
    """Resample the curve ``(x, y)`` on an even grid with PCHIP interpolation.

    Args:
        x: Sample positions in any order; they are sorted internally.
        y: Values belonging to ``x`` (same length).
        num: Number of evenly spaced output samples between ``min(x)`` and
            ``max(x)``. Defaults to 100.

    Returns:
        Tuple ``(samples, v)``: the grid positions and the interpolated values.
    """
    # Imported locally so the function does not depend on a bare `import scipy`
    # having loaded the `scipy.interpolate` submodule.
    from scipy.interpolate import pchip_interpolate

    x = np.asarray(x)
    y = np.asarray(y)
    order = np.argsort(x)  # one permutation keeps x and y paired
    samples = np.linspace(x.min(), x.max(), num=num)
    v = pchip_interpolate(x[order], y[order], samples)

    return samples, v

In [None]:
# New figure for the rate-distortion plot drawn by the rest of this cell.
plt.figure(figsize=(6.4, 4.8))

# Legend labels for the anchor codecs, keyed by anchor name. Labels containing
# a LaTeX command are raw strings so the backslash is never read as a Python
# escape sequence ("\d" is an invalid escape and triggers a warning).
name_mapping = {
    "evc": r"EVC (LL) $^\dag$",
    "mlic": "MLIC++ $^*$",
    "tcm": "LIC-TCM $^*$",
    "vtm": "VTM 22.0 Intra",
    "qarv": r"QARV $^\dag$",
    "bpg": "BPG",
    "webp": "WebP",
    "jpeg": "JPEG",
}

# Our method: one dashed rate-distortion curve per speed-up setting, built from
# the per-operating-point averages and resampled with `interpolator`.
for speedup in speedups:
    r = ours_results[speedup].avg_r
    d = ours_results[speedup].avg_d
    r, d = interpolator(r, d)

    # The legend shows the reciprocal of the speed-up; values above 10 are
    # displayed as infinity.
    dt = 1.0 / speedup
    label = (
        f"Ours ($\\alpha_t={dt:.2f}$)" if dt <= 10 else f"Ours ($\\alpha_t=\\infty$)"
    )

    plt.plot(
        r,
        d,
        linestyle="--",
        label=label,
    )

# Anchor codecs: one solid curve each, labelled through `name_mapping`.
for k, v in anchor_results.items():
    r = v.avg_r
    d = v.avg_d
    r, d = interpolator(r, d)
    plt.plot(r, d, label=name_mapping[k])
plt.legend(ncol=2)
# Fixed axis ranges; curves outside this window are clipped.
plt.xlim(0.0, 1.2)
plt.ylim(30, 42)

plt.xlabel("Bit-rate [bpp]")
plt.ylabel("PSNR [dB]")
plt.minorticks_on()
plt.grid(which="major", linestyle="-")
plt.grid(which="minor", linestyle=":")
# Figure title per dataset key; raises KeyError for a dataset not listed here.
titles = {
    "classa": "Class A",
    "classb": "Class B",
    "classc": "Class C",
    "classd": "Kodak",
}
plt.title(titles[dataset])

# Save the figure as PDF and PNG in the working directory.
plt.savefig(f"{dataset}_rd.pdf", bbox_inches="tight")
plt.savefig(f"{dataset}_rd.png", dpi=600, bbox_inches="tight")
plt.show()

In [None]:
# Reference codec for all BD-rate figures below.
anchor = anchor_results['vtm']

# For every image, the quality interval covered by *all* methods (anchors and
# ours): the largest of the per-method minima up to the smallest of the
# per-method maxima.
_all_results = list(anchor_results.values()) + list(ours_results.values())
min_int = {}
max_int = {}

for filename in anchor.img_results.keys():
    qualities = [[x.d for x in data.img_results[filename]] for data in _all_results]
    min_int[filename] = max(min(q) for q in qualities)
    max_int[filename] = min(max(q) for q in qualities)

print(min_int)
print(max_int)

# New figure: BD-rate (relative to `anchor`) versus mean decoding time.
plt.figure(figsize=(5, 3))

# Mean decoding time and BD-rate of every anchor codec, keyed by anchor name.
ts = {}
bds = {}

for k, v in anchor_results.items():
    bd = v.bd_rate(anchor, min_int, max_int)
    t = np.mean(v.avg_t)
    ts[k] = t
    bds[k] = bd

for k, v in anchor_results.items():
    if k == 'evc':
        # NOTE(review): EVC's measured time is replaced by QARV's time divided
        # by 3.8; the origin of this factor is not shown here -- confirm.
        ts[k] = ts['qarv'] / 3.8
    # JPEG, BPG and WebP are not drawn, but their numbers are still printed.
    if k not in ['jpeg', 'bpg', 'webp']:
        marker = 'o'
        plt.scatter(ts[k], bds[k], label=name_mapping[k], zorder=10, marker=marker)
    print(k, ts[k], bds[k])

# Extra comparison curves that are drawn for the 'classd' dataset only.
if dataset == "classd":
    # Hard-coded (time, BD-rate) points for "Gao et.al".
    # NOTE(review): the source of these numbers and of the division by 2.0 is
    # not shown here -- confirm.
    gao_t = np.array([0.2, 1.7, 3.1, 4.4, 7.3]) / 2.0
    gao_bd = [8.7, 6.0, 5.0, 4.0, 3.0]
    t, bd = interpolator(gao_t, gao_bd)
    plt.plot(t, bd, color="orange")
    plt.scatter(gao_t, gao_bd, marker="o", color="orange")
    # Single point at x = -1 that only serves to create a combined
    # line-and-marker legend entry.
    plt.plot([-1], [0], color="orange", marker="o", label="Gao et.al")

    # CBANet: BD-rate and mean decoding time for each loaded level.
    cba_bd = []
    cba_t = []

    for k, v in cbanet_results.items():
        bd = v.bd_rate(anchor, min_int, max_int)
        t = np.mean(v.avg_t)
        cba_bd.append(bd)
        cba_t.append(t)
    
    plt.plot(cba_t, cba_bd, color='teal')
    plt.scatter(cba_t, cba_bd, marker='o', color='teal', zorder=10)
    # Legend-only point, same trick as above.
    plt.plot([-1], [0], marker="o", color="teal", label="CBANet")

    # plt.title(titles[dataset])
    # plt.legend()

    # plt.savefig(f"{dataset}_bd_gao.pdf", bbox_inches="tight")
    # plt.savefig(f"{dataset}_bd_gao.png", dpi=600, bbox_inches="tight")

# BD-rate and mean decoding time of our method for every speed-up setting.
ours_bd = []
ours_t = []

for k, v in ours_results.items():
    bd = v.bd_rate(anchor, min_int, max_int)
    t = np.mean(v.avg_t)
    # Cap the measured time at QARV's time divided by the speed-up factor k.
    t = min(t, ts['qarv'] / k)
    ours_bd.append(bd)
    ours_t.append(t)

print(ours_t, ours_bd)

ours_t = np.asarray(ours_t)
ours_bd = np.asarray(ours_bd)
# Reorder both arrays by ascending BD-rate. The right-hand side is evaluated
# before the assignment, so argsort still sees the unsorted `ours_bd`.
ours_bd, ours_t = np.sort(ours_bd), ours_t[np.argsort(ours_bd)]
t, bd = interpolator(ours_t, ours_bd)

plt.plot(t, bd, color='blue')
plt.scatter(ours_t, ours_bd, marker='*', color='blue', zorder=10)
# Single point at x = -1 (outside the x-range set below) that only serves to
# create a combined line-and-marker legend entry.
plt.plot([-1], [0], marker="*", color="blue", label="Ours")
plt.xlabel("Dec. Time [s]")
plt.xlim(0, 3.0)
plt.ylabel("BD-rate over VTM 22.0 Intra [%]")
plt.minorticks_on()
plt.grid(zorder=0, which='major', color='#999999', linestyle='-')
plt.grid(zorder=0, which='minor', color='#999999', linestyle=':')
# Figure title per dataset key (same mapping as in the rate-distortion cell).
titles = {
    "classa": "Class A",
    "classb": "Class B",
    "classc": "Class C",
    "classd": "Kodak",
}
plt.title(titles[dataset])

plt.legend()

# Save the figure as PDF and PNG in the working directory.
plt.savefig(f"{dataset}_bd_c.pdf", bbox_inches='tight')
plt.savefig(f"{dataset}_bd_c.png", dpi=600, bbox_inches="tight")

In [None]:
# Display our BD-rate values and the matching decoding times for inspection.
ours_bd, ours_t

In [None]:
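# Acceleration effect (time): relative change in decoding time [%] of our
# method, read off our BD-rate/time curve at each anchor's BD-rate, versus that
# anchor's own decoding time.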

import scipy.interpolate


# For every anchor: interpolate our decoding time at the anchor's BD-rate and
# print the relative difference to the anchor's own time in percent.
for method, t in ts.items():
    bd = bds[method]

    t_est = scipy.interpolate.pchip_interpolate(ours_bd, ours_t, bd)
    print(method, 100*(t_est-t)/t)

In [None]:
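# Acceleration effect (BD-rate): BD-rate of our method, read off our
# time/BD-rate curve at each anchor's decoding time, minus that anchor's BD-rate.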

import scipy.interpolate

ours_t = np.asarray(ours_t)
ours_bd = np.asarray(ours_bd)

print(ours_t, ours_bd, flush=True)

# Our curve ordered by decoding time; it does not change inside the loop.
X = np.sort(ours_t)
Y = ours_bd[np.argsort(ours_t)]

for method, t in ts.items():
    bd = bds[method]
    if t > max(ours_t):
        # Anchor is slower than our slowest setting: use our best BD-rate.
        bd_est = min(ours_bd)
    else:
        bd_est = scipy.interpolate.pchip_interpolate(X, Y, t)
    print(method, bd_est - bd)

In [None]:
import matplotlib as mpl

def plot_img_results(results: DatasetResults, anchor: DatasetResults, rng_anchor: DatasetResults = None):
    """Draw a bar chart of the per-image BD-rate of ``results`` against ``anchor``.

    The bars are sorted by ascending BD-rate and coloured with the 'coolwarm'
    colormap, scaled symmetrically around zero. The interpolation bounds are
    taken from the module-level ``min_int`` / ``max_int`` mappings.

    Args:
        results: Results to evaluate.
        anchor: Reference results for the BD-rate.
        rng_anchor: Unused; kept so existing calls keep working.
    """
    ans_results = sorted(
        results._bd_rate_imgwise(anchor, min_int, max_int).items(),
        key=lambda x: x[1],
    )
    print(ans_results)
    plt.figure(figsize=(10, 3))

    cm = mpl.colormaps['coolwarm']
    labels, values = zip(*ans_results)
    colors = np.asarray(values)
    peak = np.abs(colors).max()
    if peak > 0:
        # Scale to [-1, 1]; skipped when every value is 0 to avoid 0/0 (NaN).
        colors = colors / peak
    # Map [-1, 1] to the colormap's [0, 1] range, with 0 in the middle.
    colors = colors / 2 + 0.5
    colors = [cm(k) for k in colors]
    plt.bar(labels, values, color=colors, zorder=10)
    plt.ylabel("BD-rate (%)")
    plt.grid(axis='y', zorder=0)
    plt.xticks(rotation=90)

In [None]:
import matplotlib as mpl


def plot_img_results_comparison(
    results1: DatasetResults, results2, anchor: DatasetResults
):
    """Draw a bar chart of the per-image BD-rate difference between two methods.

    For every image of ``results2`` the bar shows the BD-rate of ``results1``
    minus the BD-rate of ``results2``, both measured against ``anchor`` without
    explicit interpolation bounds. Bars are sorted by ascending difference and
    coloured with the 'coolwarm' colormap, scaled symmetrically around zero.

    Raises:
        KeyError: If an image of ``results2`` has no entry in ``results1``.
    """
    ans_results1 = results1._bd_rate_imgwise(anchor)
    ans_results2 = results2._bd_rate_imgwise(anchor)
    ans_results2 = sorted(
        ((k, ans_results1[k] - v) for k, v in ans_results2.items()),
        key=lambda x: x[1],
    )
    print(ans_results2)
    plt.figure(figsize=(10, 3))

    cm = mpl.colormaps["coolwarm"]
    labels, values = zip(*ans_results2)
    colors = np.asarray(values)
    peak = np.abs(colors).max()
    if peak > 0:
        # Scale to [-1, 1]; skipped when every value is 0 to avoid 0/0 (NaN).
        colors = colors / peak
    # Map [-1, 1] to the colormap's [0, 1] range, with 0 in the middle.
    colors = colors / 2 + 0.5
    colors = [cm(k) for k in colors]
    plt.bar(labels, values, color=colors, zorder=10)
    plt.ylabel("BD-rate (%)")
    plt.grid(axis="y", zorder=0)
    plt.xticks(rotation=90)

In [None]:
# Per-image BD-rate of our speedup=3.0 results, measured against the EVC anchor.
plot_img_results(ours_results[3.0], anchor_results["evc"])

In [None]:
# Per-image BD-rate difference (our speedup=0.01 results minus the "tcm"
# anchor), both measured against VTM; the figure is saved as PDF and PNG.
plot_img_results_comparison(ours_results[0.01], anchor_results["tcm"], anchor_results["vtm"])
plt.savefig("ours_tcm_imgwise.pdf", bbox_inches="tight")
plt.savefig("ours_tcm_imgwise.png", dpi=600, bbox_inches="tight")

In [None]:
# Per-image BD-rate difference (our speedup=1.0 results minus the "qarv"
# anchor), both measured against VTM; the figure is saved as PDF and PNG.
plot_img_results_comparison(
    ours_results[1.0], anchor_results["qarv"], anchor_results["vtm"]
)
plt.savefig("ours_qarv_imgwise.pdf", bbox_inches='tight')
plt.savefig("ours_qarv_imgwise.png", dpi=600, bbox_inches="tight")

In [None]:
# Per-image BD-rate of our speedup=2.0 results, measured against the EVC anchor.
plot_img_results(
    ours_results[2.0], anchor_results["evc"]
)

In [None]:
def f(imgresult, label):
    """Plot one image's points (``r`` on x, ``d`` on y) as a labelled line with circle markers."""
    rates = [x.r for x in imgresult]
    qualities = [x.d for x in imgresult]
    plt.plot(rates, qualities, label=label, marker='o')

# Rate-quality curves of the single image "DSC08902" for VTM, EVC and our
# speedup=3.0 results.
f(anchor_results["vtm"].img_results["DSC08902"], "vtm")
f(anchor_results["evc"].img_results["DSC08902"], "evc")
# f(anchor_results["mlic"].img_results["DSC05885"], "mlic")
f(ours_results[3.0].img_results["DSC08902"], "ours")
plt.legend()

In [None]:
from src.fileio import FileIO

# Bitstreams of image DSC08927 found under the speedup-2.0 results.
# NOTE(review): the "classc" directory is hard-coded here instead of using
# `dataset` -- confirm this is intended.
glb3 = glob.glob("classc/PSNR/speedup-2.0/*/DSC08927.bin")

# Print the `method_id` of every bitstream found.
# NOTE(review): the meaning of the `False` and `512` arguments of FileIO.load
# is not visible here -- confirm against src.fileio.
for filename in glb3:
    fileio = FileIO.load(filename, False, 512)
    print(filename, fileio.method_id)