# KITTI Images Evaluation

Create evaluation plots comparing VAE, GAN and JPEG compression on the KITTI dataset, showing each quality metric (MS-SSIM, LPIPS, MSE, PSNR) as a function of bit-rate.

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
import pandas as pd
import os
import numpy as np

In [None]:
# Start from seaborn's default theme (context, palette, rc tweaks), then
# switch the axes to the plain "white" style.
sns.set()
sns.set_style(style="white")

In [None]:
# Combine every per-run evaluation CSV in the KITTI evaluation folder into one
# DataFrame (columns are unioned; the index is renumbered 0..n-1).
#
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0. The frames
# are therefore collected first and concatenated once, which also avoids
# re-copying the growing frame on every file.
# Files are read in sorted order so row order is reproducible (os.listdir gives
# no ordering guarantee), and non-CSV entries in the folder are skipped.
csv_files_folder = "./evaluation_files_KITTI"
csv_paths = sorted(
    os.path.join(csv_files_folder, csv_file)
    for csv_file in os.listdir(csv_files_folder)
    if csv_file.lower().endswith(".csv")
)
frames = [pd.read_csv(path) for path in csv_paths]
# pd.concat raises on an empty list; keep the old "empty DataFrame" result instead.
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

In [None]:
import matplotlib

# Select the PGF backend for LaTeX-ready figure output and configure text
# rendering through pdflatex with a serif font family. pgf.rcfonts=False stops
# the PGF backend from generating font setup from these rc settings.
# The small base font size is overridden per element in the plotting cell.
matplotlib.use("pgf")
matplotlib.rcParams.update(
    {
        "pgf.texsystem": "pdflatex",
        "font.family": "serif",
        "text.usetex": True,
        "pgf.rcfonts": False,
        "font.size": 4,
    }
)

In [None]:
# Preview the first five rows of the combined evaluation table.
df.head(n=5)

In [None]:
# Distinct quality-level labels, in order of first appearance.
pd.unique(df["quality_level"])

In [None]:
# Shift the GAN quality labels down one tier: "high" -> "medium" and
# "super_high" -> "high" (presumably to align GAN tiers with the other
# models' tiers -- confirm). Each label is looked up once, so a value just
# renamed to "high" is never renamed again. Other labels and models are untouched.
gan_rows = df["model"] == "GAN"
gan_relabel = {"high": "medium", "super_high": "high"}
df.loc[gan_rows, "quality_level"] = df.loc[gan_rows, "quality_level"].map(
    lambda level: gan_relabel.get(level, level)
)
pd.unique(df["quality_level"])

In [None]:
# Treat rows with no recorded input_shape as 256x256 inputs (presumably older
# CSVs that predate the column -- confirm). Shapes are stored as strings.
missing_shape = df["input_shape"].isna()
df["input_shape"] = df["input_shape"].where(~missing_shape, "(256, 256)")


In [None]:
# Restrict the evaluation to runs on 256x256 inputs and display the result.
is_256_square = df["input_shape"].eq("(256, 256)")
df_256_256 = df.loc[is_256_square]
df_256_256

In [None]:
# Per-model count of non-null values in each column, for runs at <= 1 bpp.
low_bpp = df_256_256["q_bpp"].le(1)
df_256_256.loc[low_bpp].groupby(by="model").count()

In [None]:
%matplotlib inline

import matplotlib.patches as mpatches

# One panel per quality metric on the 256x256 KITTI results. Each panel shows,
# per model, a regression of the metric on log(bit-rate) (lines only, no
# scatter points). The figure is saved as PGF for inclusion in LaTeX.
metrics = ["ms_ssim", "lpips", "mse", "psnr"]
m_labels = ["MS-SSIM", "LPIPS", "MSE", "PSNR [dB]"]

# Colours are passed explicitly to both the fitted lines and the legend so
# the two always match. Without this, the lines use the active colour cycle,
# which sns.set() replaces with seaborn's palette, while the legend used these
# matplotlib tab10 values.
model_colors = {"VAE": "#1f77b4", "GAN": "#ff7f0e", "JPEG": "#2ca02c"}

# Fit only on the bit-rate range of interest.
data = df_256_256[(df_256_256["q_bpp"] <= 1.5) & (df_256_256["q_bpp"] >= 0.15)]

fig, axes = plt.subplots(1, len(metrics), figsize=(5, 1.5))
for ax, metric, m_label in zip(axes, metrics, m_labels):
    for model, color in model_colors.items():
        sns.regplot(
            x="q_bpp",
            y=metric,
            data=data[data["model"] == model],
            logx=True,
            scatter=False,
            color=color,
            ax=ax,
        )
    ax.set_ylabel("", size=7)
    ax.set_title(m_label, size=7)
    ax.set_xlabel("Bit-rate [bpp]", size=7)
    ax.set_xlim([0.00, 1.01])
    ax.set_xticks(np.arange(0, 1.25, step=0.25))
    ax.tick_params(axis="x", labelsize=7, labelrotation=90)
    ax.tick_params(axis="y", labelsize=7)
    # Enable the grid explicitly; a bare grid() call toggles its current state.
    ax.grid(True)
fig.tight_layout()

legend_handles = [
    mpatches.Patch(color=color, label=model) for model, color in model_colors.items()
]
# The shared legend is anchored relative to the last panel and pushed below the row.
axes[-1].legend(
    handles=legend_handles, fontsize="x-small", bbox_to_anchor=(-0.8, -0.9), ncol=3
)

# savefig does not create missing directories.
output_dir = "./for_latex"
os.makedirs(output_dir, exist_ok=True)
fig.savefig(
    os.path.join(output_dir, "VAE_vs_GAN_vs_JPEG_line_flat.pgf"),
    dpi=400,
    bbox_inches="tight",
)
plt.show()