In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from IPython.display import display, HTML

# Named hex color codes for the figures.
red, yellow, blue, green = "#D81B60", "#FFC107", "#1E88E5", "#004D40"

# Inject CSS overrides into the notebook page: widen the main container to 80%
# and remove the max-width cap on output results.
for _style_tag in (
    "<style>.container { width:80% !important; }</style>",
    "<style>.output_result { max-width:100% !important; }</style>",
):
    display(HTML(_style_tag))

In [None]:
import os

import sklearn.metrics

# Build the results directory with os.path.join so the notebook also runs on
# POSIX systems. The trailing "" keeps the trailing separator, so `rootpath`
# is still a plain string that a file name can be appended to.
rootpath = os.path.join("..", "results", "dialect", "")
df_dicts = []
result_path = "_test_original_results.csv"
n_runs = 5  # timing runs per system, stored as "<name>.csv_0" ... "<name>.csv_4"

# results_df maps each system name to the result table of run 0, extended with
# one "prediction_time_<run>" column per run and their row-wise "mean_time".
results_df = {}
for system in ["renemb", "clevercs"]:  # not named `sys`: that would shadow the stdlib module
    runs_df = pd.read_csv(f"{rootpath}{system}{result_path}_0")
    runs_df = runs_df.rename(columns={"prediction_time": "prediction_time_0"})
    for run in range(1, n_runs):
        time_col = f"prediction_time_{run}"
        run_df = pd.read_csv(f"{rootpath}{system}{result_path}_{run}", sep=",", quotechar='"')
        run_df = run_df.rename(columns={"prediction_time": time_col})
        # Only the timing column of the later runs is kept; rows are matched on
        # the index. NOTE(review): this assumes every run lists the files in
        # the same order - confirm, or merge on a file identifier instead.
        runs_df = pd.merge(runs_df, run_df[[time_col]], left_index=True, right_index=True, how="outer")
        del run_df
    runs_df["mean_time"] = runs_df[[f"prediction_time_{run}" for run in range(n_runs)]].mean(axis=1)
    results_df[system] = runs_df
    del runs_df

In [None]:
# Box plot of the per-file mean prediction time of both systems (log-scaled y axis).
plt.rcParams['font.family'] = 'Rasa'
labelsize = 20

# The load cell stores the frames under "clevercs" and "renemb"; the previous
# key "r" raised a KeyError. Missing means are dropped so the boxes are computed
# from the same values as the (NaN-skipping) medians printed below.
runtime_samples = [results_df[key]["mean_time"].dropna() for key in ("clevercs", "renemb")]
plt.boxplot(runtime_samples)

# boxplot draws the boxes at x = 1..N; naming them through the ticks avoids the
# `labels=` keyword, which newer Matplotlib releases deprecate.
plt.xticks([1, 2], ["CleverCSV", "RenEMB"], fontsize=labelsize)
plt.yticks(fontsize=labelsize)
plt.yscale("log")
plt.title("Average file-wise runtime", size=labelsize)
plt.ylabel("Runtime (s)", size=labelsize)
plt.grid()

print("CleverCSV: ", results_df["clevercs"]["mean_time"].median())
print("RenEMB: ", results_df["renemb"]["mean_time"].median())

plt.show()




In [None]:
from matplotlib.patches import Patch

# Shared styling for the grouped bar charts drawn below.
labelsize = 20
plt.rcParams.update({'font.family': 'Rasa'})

# One fill color and one hatch pattern per bar, in bar order within each group.
colors, patterns = [red, yellow, blue], ['/', 'o', 'x']

def plot_bar(ax, title, bar1, bar2, bar3, bar4):
    """Draw four groups of three horizontal bars on ``ax``.

    Each ``barN`` supplies the three bar lengths of one group. Groups are
    placed at y positions 0-2, 4-6, 8-10 and 12-14 (one empty slot between
    groups) with the y axis inverted, so the first group ends up on top.
    Every bar is annotated with its value rounded to two decimals.

    Besides its arguments the function reads these notebook-level names:
    ``colors`` and ``patterns`` (fill color / hatch per bar), ``labelsize``
    (font size) and ``std`` (error bars, looked up as
    ``std["renemb"][barN.name]``).

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    title : str, shown as the x-axis label of the panel.
    bar1, bar2, bar3, bar4 : sequences of three numbers exposing a ``.name``
        attribute that is a key of ``std["renemb"]``.
    """
    group_stride = 4  # three bars plus one empty slot per group
    groups = [bar1, bar2, bar3, bar4]

    ax.grid(zorder=-10)
    ax.set_axisbelow(True)

    for g, values in enumerate(groups):
        first = g * group_stride
        ax.barh([first, first + 1, first + 2], values,
                color=colors, hatch=patterns, xerr=std["renemb"][values.name])

    # Force the inverted orientation instead of toggling it: invert_yaxis()
    # flips the current state, so on panels that share their y axis an even
    # number of plot_bar calls would turn the chart upside down again.
    ax.yaxis.set_inverted(True)

    # Write each value just to the right of the end of its bar.
    xpad = .02
    ypad = 0
    for g, values in enumerate(groups):
        for slot, v in enumerate(values):
            ax.text(v + xpad, (g * group_stride) + slot + ypad, str(round(v, 2)),
                    color='black', size=labelsize, va='center')

    # No y ticks here: the caller writes the group names next to the panels.
    ax.set_yticks([])

    ax.set_xlabel(title, size=labelsize + 2)
    ticks = [0, 20, 40, 60, 80, 100]
    ax.set_xticks(ticks=ticks, labels=ticks, size=labelsize)
    # The x range extends past 100 to leave room for the value annotations.
    ax.set_xlim(0, 130)

# Three side-by-side panels that share the y axis, one per evaluation split.
fig = plt.figure(figsize=(10, 6))
gs = fig.add_gridspec(nrows=1, ncols=3, wspace=0.05)
axs = gs.subplots(sharey=True)

# (column prefix in `res`, x-axis caption) in left-to-right panel order.
# `res` is not defined in this cell; it has to exist in the notebook namespace.
splits = [("dev", 'Validation set'), ("test", 'Test set'), ("weird", 'Difficult set')]
for panel, (prefix, caption) in zip(axs, splits):
    # Three per-component F1 columns followed by the accuracy column.
    bars = [res[f"{prefix}_f1_{x}"] for x in ["delimiter", "quotechar", "escapechar"]] + [res[f"{prefix}_accuracy"]]
    plot_bar(panel, caption, *bars)

# Name each bar group once, left of the first panel, centered on the group's
# middle bar (y = 1, 5, 9, 13).
xpoint = -30
for y_center, group_name in [(1, 'Delimiter'), (5, 'Quote'), (9, 'Escape'), (13, 'Dialect\nAccuracy')]:
    axs[0].text(xpoint, y_center, group_name, ha='center', va='center', size=labelsize)

# Legend built from proxy patches that repeat the bar colors and hatches,
# anchored just above the middle panel.
custom_bars = [Patch(facecolor=face, hatch=hatch) for face, hatch in zip(colors, patterns)]

axs[1].legend(
    handles=custom_bars,
    labels=["CleverCSV", "RenEMB", "Hybrid"],
    loc='center',
    bbox_to_anchor=(0.5, 1.07),
    ncols=3,
    fontsize=labelsize,
)

plt.savefig('dialect_results.png', dpi=300, bbox_inches='tight')