In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Hex colour constants for the plots. red, yellow and blue are the ones
# collected into the `colors` list used for the bars; green is defined
# alongside them but not referenced by the visible plotting code.
red, yellow = "#D81B60", "#FFC107"
blue, green = "#1E88E5", "#004D40"

from IPython.display import display, HTML
# Inject two CSS overrides into the notebook page: the first sets the
# `.container` width to 80%, the second lifts the max-width of
# `.output_result` to 100%.
# NOTE(review): these selectors presumably target the classic Notebook UI;
# confirm they still have an effect in the front end being used.
display(
    HTML("<style>.container { width:80% !important; }</style>"),
    HTML("<style>.output_result { max-width:100% !important; }</style>"),
)

In [None]:
# Load the three result tables. Each one is joined on its "measure" column
# below, so every CSV must provide that column.
# The first file stores its scores under the column name "strugritte"; it is
# renamed to "strudel" so it does not clash with the third table in the merge.
strudel_df = (
    pd.read_csv('../results/rowclass/strugritte_results_row_False_line_False.csv')
    .rename(columns={"strugritte": "strudel"})
)
renemb_df = pd.read_csv('../results/rowclass/renemb_results.csv')
strugritte_df = pd.read_csv('../results/rowclass/strugritte_results_row_False_line_True.csv')

# Default (inner) joins: only measures present in all three tables survive.
df = strudel_df.merge(renemb_df, on="measure").merge(strugritte_df, on="measure")

# Scale every value by 100 (presumably fractions -> percentages; confirm
# against the CSV contents) and flip the table so that measures become
# columns and each source column becomes a row.
res = (df.set_index("measure") * 100).T

display(res)

In [None]:
from matplotlib.patches import Patch

# Base font size for tick labels, bar value labels and the legend;
# plot_bar uses labelsize + 2 for the x-axis label.
labelsize = 20
# NOTE(review): assumes a font family named 'Rasa' is available to
# matplotlib on this machine — confirm, otherwise a fallback font is used.
plt.rcParams['font.family'] = 'Rasa'

# Fill colour and hatch pattern for the bars inside one group; the same
# pairs are reused to build the legend patches.
colors, patterns = [red, yellow, blue], ['/', 'o', 'x']

def plot_bar(ax, title, bars):
  """Draw grouped horizontal bars, each annotated with its value, on ``ax``.

  Parameters
  ----------
  ax : axes-like object
      Receives every drawing call (``grid``, ``barh``, ``text``, ...).
  title : str
      Shown as the x-axis label of the panel.
  bars : sequence of sequences of numbers
      One inner sequence per group. The length of the first group fixes
      the size of every group, so all groups are assumed to be equally
      long.

  Notes
  -----
  Reads the module-level ``colors``, ``patterns`` and ``labelsize``.
  Groups are stacked along y with one empty slot between them; y ticks are
  removed. The x axis is fixed to 0-130 with ticks every 20 up to 100
  (presumably the extra range leaves room for the printed values).
  """
  label_x_offset = .02
  label_y_offset = 0
  tick_values = [0, 20, 40, 60, 80, 100]

  ax.grid(zorder=-10)
  ax.set_axisbelow(True)

  group_size = len(bars[0])
  stride = group_size + 1  # one empty slot separates consecutive groups
  slots = np.arange(group_size)

  # First pass: the bars themselves, one barh call per group.
  for group_idx, values in enumerate(bars):
    ax.barh(slots + stride * group_idx, values, color=colors, hatch=patterns)

  # Second pass: print each value (rounded to two decimals) just past the
  # end of its bar.
  for group_idx, values in enumerate(bars):
    group_start = group_idx * stride
    for slot, value in enumerate(values):
      ax.text(
        value + label_x_offset,
        group_start + slot + label_y_offset,
        str(round(value, 2)),
        color='black',
        size=labelsize,
        va='center',
      )

  ax.set_yticks([])
  ax.set_xlabel(title, size=labelsize + 2)
  ax.set_xticks(ticks=tick_values, labels=tick_values, size=labelsize)
  ax.set_xlim(0, 130)

# Datasets become panels (left to right); measures become bar groups
# (top to bottom) inside every panel. Columns of `res` are looked up as
# "<dataset>_<measure>_f1".
datasets = ["saus", "cius", "deex","govuk","mendeley","troy"]
measures = ["data","metadata", "group", "notes", "header", "derived"]

# One panel per dataset in a single row, all sharing the y axis.
fig = plt.figure(figsize=(20, 5))
gs = fig.add_gridspec(nrows=1, ncols=len(datasets), wspace=0.05)
axs = gs.subplots(sharey=True)

for idx, ds in enumerate(datasets):
  bars = [res[f"{ds}_{m}_f1"] for m in measures]
  plot_bar(axs[idx], ds.title(), bars)

# The y axis is shared, so inverting it on the first panel flips all panels
# and puts the first measure at the top.
axs[0].invert_yaxis()

# Every bar group holds one bar per row of `res`. Taking the size from `res`
# avoids depending on the loop variable `bars` leaking out of the loop above.
n_rows = len(res.index)
# Vertical centre of a group that occupies slots 0 .. n_rows-1; the former
# literal 1 was only correct for groups of exactly three bars.
group_centre = (n_rows - 1) / 2
xpoint = -2  # x position of the row labels: just left of the first panel's axis
for idx, m in enumerate(measures):
  axs[0].text(xpoint, group_centre + (n_rows + 1) * idx, m.title() + " F1",
              ha='right', va='center', size=labelsize)

# add a legend: one proxy patch per colour/hatch pair used for the bars
custom_bars = [Patch(facecolor=c, hatch=h) for c, h in zip(colors, patterns)]

# NOTE(review): the labels are hard-coded and assume the rows of `res` come
# in the order strudel, renemb, strugritte — confirm against the merge order.
axs[1].legend(handles=custom_bars,
              labels=["Strudel", "RenEMB", "Hybrid"],
              loc='center',
              bbox_to_anchor=(2.1, 1.07),  # in axs[1] axes coordinates: above the panels
              ncols=len(custom_bars),
              fontsize=labelsize)

# Save through the figure object rather than pyplot's implicit current figure.
fig.savefig('lineclass_new_results.png', dpi=300, bbox_inches='tight')