In [49]:
from datetime import datetime
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sn

import bee_utils as bee

pd.set_option("display.precision", 2)

matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
    'font.family': 'serif',
    'font.size' : 10,
    'text.usetex': True,
    'pgf.rcfonts': False,
})


In [50]:
PN_BASE_DIR = Path("../Pointnet_Pointnet2_pytorch/log/classification/")
PRED_FILE = PN_BASE_DIR / "msg_cls5C_e40_bs8_pts4096_split7030_ds5fps/logs/pred_per_sample_7030.csv"

exp_name = PRED_FILE.parent.parent.name

df = pd.read_csv(PRED_FILE, header="infer")

display(df)

Unnamed: 0,sample_path,target_name,pred_name,bee,butterfly,dragonfly,wasp,bumblebee
0,butterfly/hn-but-2_42_28.csv,butterfly,butterfly,-4.74,-0.06,-3.40e+00,-5.70,-5.61
1,butterfly/hn-but-2_42_39.csv,butterfly,butterfly,-6.88,-0.14,-2.76e+00,-6.76,-6.32
2,butterfly/hn-but-2_42_44.csv,butterfly,butterfly,-9.39,-0.04,-3.27e+00,-9.49,-8.74
3,butterfly/hn-but-2_40_1.csv,butterfly,dragonfly,-4.19,-1.03,-8.36e-01,-4.69,-1.78
4,butterfly/hn-but-2_42_29.csv,butterfly,butterfly,-6.22,-0.27,-1.56e+00,-6.47,-7.32
...,...,...,...,...,...,...,...,...
538,dragonfly/mb-dra2-1_33_2.csv,dragonfly,dragonfly,-9.07,-10.14,-3.30e-03,-5.98,-7.71
539,dragonfly/mb-dra1-1_64_4.csv,dragonfly,dragonfly,-7.82,-9.54,-2.40e-03,-6.85,-7.35
540,dragonfly/mb-dra1-1_5_1.csv,dragonfly,dragonfly,-1.93,-4.13,-1.97e-01,-5.83,-5.03
541,dragonfly/mb-dra1-1_63_14.csv,dragonfly,dragonfly,-6.70,-9.58,-5.40e-03,-5.95,-6.63


In [51]:
def abbr_class_name(cl, mapp=None):
    if mapp is not None and cl in mapp:
        return mapp[cl]
    if cl in bee.CLASSES and len(cl)>3:
        return cl[:3]+"."
    else:
        return cl
    
mapp1 = {
    "bee":"bee","bumblebee":"bumbleb.","wasp":"wasp","dragonfly":"dragonfly","butterfly":"butterfly","insect":"insect"
}

In [52]:


all_classes_ordered = ["bee","bumblebee","wasp","dragonfly","butterfly","insect"]
classes_ordered = [cl for cl in all_classes_ordered if cl in df.columns]

df_reordered = df[classes_ordered]

# TODO change "bublebee" to "bumbleb." otherwise the name is too long for the plot.
# TODO smaller font

# Build confusion matrix
cf_matrix = confusion_matrix(df["target_name"], df["pred_name"], labels=classes_ordered)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index = [abbr_class_name(i,mapp1) for i in classes_ordered],
                     columns = [abbr_class_name(i,mapp1) for i in classes_ordered])
plt.figure(figsize = (5,3.2))
ax = sn.heatmap(df_cm, annot=True, cmap="Blues")
ax.set_ylabel("Ground-Truth")
ax.set_xlabel("Klassifikation")

ax.get_figure().show()


  ax.get_figure().show()


In [53]:
# save for latex


fig = ax.get_figure()
fig.tight_layout()
# fig.set_size_inches(6.3,5)
datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = Path(f'output/from_notebooks/cm/{datetime_str}_{exp_name}')
output_dir.mkdir(exist_ok=True, parents=True)
fig.savefig(str(output_dir / f"cm_{exp_name}.png"))
fig.savefig(str(output_dir / f"cm_{exp_name}.pgf"))


################### WICHTIG! #####################
# Auch generiertes ...-img0.png muss kopiert werden!
# In PGF Datei muss Pfad von ...-img0.png angepasst werden:
# Davor muss das Verzeichnispfad im Latex-res-Dir angegeben werden!
# ZB: {res/pgf/cm_msg_cls5C_e40_bs8_pts4096_split7030_ds5fps-img0.png}
##################################################