In [None]:
from utils import *

# SALSA training annotations, plus the distinct systems and annotation
# (edit-type) keys they contain.
data = load_data('../data/salsa_train.json')
systems, edit_types = (
    {sent['system'] for sent in data},
    set(data[0]['annotations'].keys()),
)

In [None]:
# Load both adjudicated splits with identical options; the figures in this
# section plot their union.
train_data, test_data = (
    load_data(split_path, preprocess=True, adjudicated=True)
    for split_path in (
        '../data/inspection_rating_annotated',
        '../data/test_set_inspection_rating_annotated',
    )
)
data = train_data + test_data

In [None]:
# Systems left out of the pooled train+test figures.
excluded_systems = [
    'new-wiki-1/GPT-3-zero-shot',
    'new-wiki-1/T5-3B',
    'new-wiki-4/Ctrl-T5-3b',
    'new-wiki-4/Turbo',
    'new-wiki-4/Vicuna-7b'
]

# Display names (used as x-tick labels) for systems that remain in the plots.
system_name_mapping.update([
    ('new-wiki-1/T5-11B', 'T5'),
    ('new-wiki-1/GPT-3-few-shot', 'GPT-3.5'),
    ('new-wiki-4/Ctrl-T5-11b', 'T5 Ctrl'),
])

data = [record for record in data if record['system'] not in excluded_systems]

In [None]:
# Rendering settings for the two 1x3 train+test figures below.
savefig = True            # write PDFs under ../paper/plot instead of showing inline
fig_dim = (12.1, 3)       # (width, height) in inches; a height of 3.5 was also tried

add_line = False              # draw a dashed vertical separator on every panel
line_location = [2.5, 2.5]    # separator x-coords: between the bars at x=2 and x=3

# Fonts are configured BEFORE the figure is created: each Axes builds its title
# text when it is constructed and keeps the font family active at that moment,
# so setting rcParams afterwards would not reach the panel titles.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'TW Cen MT'
# Serif alternative (note: the serif font list lives under 'font.serif'):
# plt.rcParams['font.family'] = 'serif'
# plt.rcParams['font.serif'] = 'Times New Roman'

font_size_legend = 11
font_size_x_labels = 12
font_size_y_label = 12
font_size_title = 15
legend_loc = (1, 0.5)    # legend anchor: just right of each panel, vertically centred

width = 0.7              # bar width
fig, ax_quality = plt.subplots(1, 3, figsize=fig_dim)

def fix_text(text, syntax_error=False):
    """Wrap a legend label onto several lines for the train+test figures.

    Every space becomes a newline, and a newline is inserted after each "t-"
    (so "Component-level" breaks after the hyphen). With ``syntax_error`` the
    label is prefixed with "Bad ". Labels containing "level" get " Reorder"
    appended, and a "Word-level " run is broken onto its own line. "Grammar"
    is expanded to "Grammar" / "Error" on two lines.

    NOTE: a later cell of this notebook redefines ``fix_text`` for the
    test-set figure; after that cell runs, this version is shadowed.
    """
    wrapped = text.replace(' ', '\n')
    wrapped = wrapped.replace('t-', 't-\n')
    if syntax_error:
        wrapped = 'Bad ' + wrapped
    if 'level' in wrapped:
        wrapped = (wrapped + ' Reorder').replace('Word-level ', 'Word-level\n')
    # str.replace is a no-op when "Grammar" does not occur.
    return wrapped.replace('Grammar', 'Grammar\nError')

# Figure: quality edits per system, one stacked-bar panel per edit family.
for plt_idx, family in enumerate(Family):
    ax = ax_quality[plt_idx]
    out = get_edits_by_family(data, family, combine_humans=False)
    # Keep the system order defined by system_name_mapping; systems missing from
    # this family's results are skipped.
    system_labels = [system for system in system_name_mapping if system in out.keys()]
    x = np.arange(len(system_labels))

    # Per-system quality-edit counts; which quality types are stacked depends on the family.
    quality_data = {system: stats['quality'] for system, stats in out.items()}
    if family == Family.CONTENT:
        quality_iterator = [Information.MORE, Information.LESS]
    elif family == Family.SYNTAX:
        quality_iterator = list(ReorderLevel) + [Edit.STRUCTURE, Edit.SPLIT]
    elif family == Family.LEXICAL:
        quality_iterator = [Information.SAME]
    else:
        # Otherwise an unexpected family would silently reuse the previous
        # iteration's quality types.
        raise ValueError(f'No quality edit types configured for family {family}')

    bottom = [0] * len(system_labels)
    for quality_type in quality_iterator:
        val = [quality_data[label][quality_type] for label in system_labels]
        # All-zero series are not drawn, so they get no legend entry.
        if sum(val) != 0:
            label = 'Paraphrase' if quality_type == Information.SAME else quality_type.value
            ax.bar(x, val, width, bottom=bottom, label=label, color=color_mapping[quality_type])
        bottom = [b + v for b, v in zip(bottom, val)]

    displayed_x_labels = [system_name_mapping[label] for label in system_labels]

    ax.tick_params(labelsize=font_size_x_labels)
    ax.set_xticks(x)
    ax.set_xticklabels(displayed_x_labels, rotation=30, ha='right')
    ax.spines[['right', 'top']].set_visible(False)

    if family == Family.CONTENT:
        ax.set_ylabel('# Quality Edits / Sent.', fontsize=font_size_y_label)

    if add_line:
        ax.plot(line_location, [0, ax.get_ylim()[-1]], ls='--', c='k')

    # Legend right of the panel; reversed so entries are listed top-down in
    # the same order as the stacked segments.
    handles, labels = ax.get_legend_handles_labels()
    leg = ax.legend(handles[::-1], labels[::-1], loc='center left', bbox_to_anchor=legend_loc,
        fancybox=True, ncol=1, borderaxespad=1., fontsize=font_size_legend,
        facecolor='white', edgecolor='black', framealpha=1, frameon=False,
        columnspacing=1, handlelength=1, handleheight=1, handletextpad=0.6,
        borderpad=0.2, alignment='left')

    # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and removed in 3.9.
    leg_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
    for patch in leg_handles:
        patch.set_edgecolor('black')

    # Shift the rotated, right-aligned tick labels 10 display units to the right.
    trans = mtrans.Affine2D().translate(10, 0)
    for t in ax.get_xticklabels():
        t.set_transform(t.get_transform() + trans)

    # Y ticks every 0.5 up to ~20% above the largest per-system total
    # (the total sums every quality type in the counts, plotted or not).
    max_quality = max(sum(counts.values()) for counts in quality_data.values())
    tick_range_quality = np.arange(0, max_quality * 1.2, step=0.5)
    ax.set_yticks(tick_range_quality)

    for text in leg.texts:
        text.set_text(fix_text(text.get_text()))

ax_quality[0].set_title('Conceptual Quality', fontsize=font_size_title)
ax_quality[1].set_title('Syntax Quality', fontsize=font_size_title)
ax_quality[2].set_title('Lexical Quality', fontsize=font_size_title)

# Operate on this figure explicitly rather than on pyplot's "current" figure.
fig.tight_layout()
if savefig:
    out_filename = '../paper/plot/quality-edits.pdf'
    fig.savefig(out_filename, format="pdf", bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)
else:
    plt.show()

# Figure: error edits per system, one stacked-bar panel per edit family.
fig, ax_error = plt.subplots(1, 3, figsize=fig_dim)

# Y-tick spacing for each family's panel.
step_dict = {
    Family.CONTENT: 0.2,
    Family.SYNTAX: 0.1,
    Family.LEXICAL: 0.2,
}

for plt_idx, family in enumerate(Family):
    ax = ax_error[plt_idx]
    out = get_edits_by_family(data, family, combine_humans=False)
    # Keep the system order defined by system_name_mapping; systems missing from
    # this family's results are skipped.
    system_labels = [system for system in system_name_mapping if system in out.keys()]
    x = np.arange(len(system_labels))

    error_data = {system: stats['error'] for system, stats in out.items()}
    # NOTE(review): Error.UNNECESSARY_INSERTION is plotted in no panel: it is
    # excluded from the content panel and is not among the lexical types.
    if family == Family.CONTENT:
        error_iterator = [e for e in Error if e != Error.UNNECESSARY_INSERTION]
    elif family == Family.SYNTAX:
        error_iterator = list(ReorderLevel) + [Edit.STRUCTURE, Edit.SPLIT]
    elif family == Family.LEXICAL:
        error_iterator = [Error.COMPLEX_WORDING, Quality.ERROR, Error.INFORMATION_REWRITE]
    else:
        raise ValueError(f'No error edit types configured for family {family}')

    bottom = [0] * len(system_labels)
    for error_type in error_iterator:
        val = [error_data[label][error_type] for label in system_labels]
        # All-zero series are not drawn, so they get no legend entry.
        if sum(val) != 0:
            label = error_type.value
            # In the syntax and lexical panels a bare 'Error' label is shown as
            # 'Grammar' (fix_text later expands it to 'Grammar' / 'Error').
            if family in (Family.SYNTAX, Family.LEXICAL) and label == 'Error':
                label = 'Grammar'

            if family == Family.SYNTAX:
                color = color_mapping_override[error_type]
            else:
                color = color_mapping[error_type]
            ax.bar(x, val, width, bottom=bottom, label=label, color=color)
        bottom = [b + v for b, v in zip(bottom, val)]

    displayed_x_labels = [system_name_mapping[label] for label in system_labels]

    ax.tick_params(labelsize=font_size_x_labels)
    ax.set_xticks(x)
    ax.set_xticklabels(displayed_x_labels, rotation=30, ha='right')
    ax.spines[['right', 'top']].set_visible(False)

    if add_line:
        ax.plot(line_location, [0, ax.get_ylim()[-1]], ls='--', c='k')

    if family == Family.CONTENT:
        ax.set_ylabel('# Error Edits / Sent.', fontsize=font_size_y_label)

    # Legend right of the panel; reversed so entries are listed top-down in
    # the same order as the stacked segments.
    handles, labels = ax.get_legend_handles_labels()
    leg = ax.legend(handles[::-1], labels[::-1], loc='center left', bbox_to_anchor=legend_loc,
        fancybox=True, ncol=1, borderaxespad=1., fontsize=font_size_legend,
        facecolor='white', edgecolor='black', framealpha=1, frameon=False,
        columnspacing=1, handlelength=1, handleheight=1, handletextpad=0.6,
        borderpad=0.2, alignment='left')

    # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and removed in 3.9.
    leg_handles = leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles
    for patch in leg_handles:
        patch.set_edgecolor('black')

    # Shift the rotated, right-aligned tick labels 10 display units to the right.
    trans = mtrans.Affine2D().translate(10, 0)
    for t in ax.get_xticklabels():
        t.set_transform(t.get_transform() + trans)

    # Syntax-panel labels get a "Bad " prefix.
    for text in leg.texts:
        text.set_text(fix_text(text.get_text(), syntax_error=(family == Family.SYNTAX)))

    # Y ticks up to ~20% above the largest per-system total, spacing per family.
    max_error = max(sum(counts.values()) for counts in error_data.values())
    tick_range_error = np.arange(0, max_error * 1.2, step=step_dict[family])
    ax.set_yticks(tick_range_error)

ax_error[0].set_title('Conceptual Error', fontsize=font_size_title)
ax_error[1].set_title('Syntax Error', fontsize=font_size_title)
ax_error[2].set_title('Lexical Error', fontsize=font_size_title)

fig.tight_layout()
if savefig:
    out_filename = '../paper/plot/error-edits.pdf'
    fig.savefig(out_filename, format="pdf", bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)
else:
    plt.show()

### Test Set Edit Frequency

In [None]:
from utils import *

test_data = load_data(
    '../data/test_set_inspection_rating_annotated',
    preprocess=True,
    adjudicated=True,
)
# Distinct systems and annotation (edit-type) keys present in the test split.
systems = {row['system'] for row in test_data}
edit_types = set(test_data[0]['annotations'].keys())

In [None]:
# Systems left out of the test-set figure: both Ctrl-T5 variants.
excluded_systems = [
    'new-wiki-4/Ctrl-T5-3b',
    'new-wiki-4/Ctrl-T5-11b'
]

data = [
    annotation
    for annotation in test_data
    if annotation['system'] not in excluded_systems
]

In [None]:
# Rendering settings for the two-panel test-set figure (quality | error).
savefig = True   # write the PDF under ../paper/plot instead of showing inline

# Figure size in inches. Sizes tried earlier: (9.5, 3), (6.5, 3), (6.5, 2.7)
# and (5, 3); (4, 6) when the two panels were stacked vertically.
fig_dim = (6, 3)

# Fonts are configured BEFORE the figure is created: each Axes builds its title
# text when it is constructed and keeps the font family active at that moment.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'TW Cen MT'

font_size_legend = 9
font_size_x_labels = 12
font_size_y_label = 12
font_size_title = 15
legend_loc = (1, 0.5)   # legend anchor: just right of the axes, vertically centred

width = 0.75            # bar width
fig, ax_quality = plt.subplots(1, 2, figsize=fig_dim)
# Alternative layout with a narrow spacer column between the panels:
# import matplotlib.gridspec as gridspec
# fig = plt.figure(figsize=fig_dim)
# gs = gridspec.GridSpec(1, 3, width_ratios=[1, 0.1, 1])
# ax_quality = [plt.subplot(gs[0]), plt.subplot(gs[2])]

# Synthetic plotting categories that aggregate several underlying error types;
# each value doubles as the bar's legend label.
SpecialCategory = Enum('SpecialCategory', [
    ('ELABORATION_ALL', 'Bad Elaboration'),
    ('REORDER', 'Reorder'),
    ('OTHER', 'Other'),
])

# Bar colours for the aggregated SpecialCategory buckets.
color_mapping[SpecialCategory.ELABORATION_ALL] = '#a2cf72'
color_mapping[SpecialCategory.OTHER] = '#d9d9d9'

def fix_text(text, syntax_error=False):
    """Wrap a legend label onto several lines for the two-panel test-set figure.

    Every space becomes a newline and a newline is inserted after each "t-".
    With ``syntax_error`` a "Bad" line is put in front. Any label that contains
    "level" after wrapping is replaced outright by "Bad" / "Reorder" on two
    lines. "Grammar" is expanded to "Grammar" / "Error".

    This redefines (and shadows) the earlier ``fix_text`` in this notebook.
    """
    label = text.replace(' ', '\n').replace('t-', 't-\n')
    if syntax_error:
        label = 'Bad\n' + label
    if 'level' in label:
        return 'Bad\nReorder'
    return label.replace('Grammar', 'Grammar\nError')

# Left panel: quality edits of every family stacked into one bar per system.
# Families are drawn in reverse enum order, so the first Family member ends up
# on top of each stack.
bottom = None
for family in reversed(list(Family)):
    out = get_edits_by_family(data, family, combine_humans=False)

    # Keep the system order defined by system_name_mapping.
    system_labels = [system for system in system_name_mapping if system in out.keys()]
    x = np.arange(len(system_labels))

    # Stack offsets are sized from the systems actually plotted; sizing them from
    # every distinct system in `data` breaks bar() when a system has no entry in
    # system_name_mapping.
    if bottom is None:
        bottom = [0] * len(system_labels)
    elif len(bottom) != len(system_labels):
        raise ValueError(
            f'{family} covers {len(system_labels)} systems but earlier families '
            f'covered {len(bottom)}; the stacked bars would be misaligned'
        )

    quality_data = {system: stats['quality'] for system, stats in out.items()}
    if family == Family.CONTENT:
        quality_iterator = [Information.MORE, Information.LESS]
    elif family == Family.SYNTAX:
        quality_iterator = [ReorderLevel.COMPONENT, Edit.STRUCTURE, Edit.SPLIT]
    elif family == Family.LEXICAL:
        quality_iterator = [Information.SAME]
    else:
        raise ValueError(f'No quality edit types configured for family {family}')

    for quality_type in quality_iterator:
        val = [quality_data[label][quality_type] for label in system_labels]
        # All-zero series are not drawn.
        if sum(val) != 0:
            label = 'Paraphrase' if quality_type == Information.SAME else quality_type.value
            ax_quality[0].bar(x, val, width, bottom=bottom, label=label, color=color_mapping[quality_type])
        bottom = [b + v for b, v in zip(bottom, val)]

displayed_x_labels = [system_name_mapping[label] for label in system_labels]

ax_quality[0].tick_params(labelsize=font_size_x_labels)
ax_quality[0].set_xticks(np.arange(len(system_labels)))
ax_quality[0].set_xticklabels(displayed_x_labels, rotation=30, ha="right")
ax_quality[0].spines[['right', 'top']].set_visible(False)

# Left-most panel, so it always carries the y-axis label (this no longer
# depends on which family the loop above happened to finish on).
ax_quality[0].set_ylabel('# Quality Edits / Sent.', fontsize=font_size_y_label)

# No legend on this panel: the error panel's legend labels its entries
# "Good/Bad ..." and stands for both panels.

# Shift the rotated, right-aligned tick labels 10 display units to the right.
trans = mtrans.Affine2D().translate(10, 0)
for t in ax_quality[0].get_xticklabels():
    t.set_transform(t.get_transform() + trans)

# Fixed y-range 0..10 with unit ticks.
tick_range_quality = np.arange(0, 11, step=1.0)
ax_quality[0].set_yticks(tick_range_quality)

# Right panel: error edits of every family stacked into one bar per system,
# families drawn in reverse enum order (same as the quality panel).
bottom = None
for family in reversed(list(Family)):
    out = get_edits_by_family(data, family, combine_humans=False)

    # Keep the system order defined by system_name_mapping.
    system_labels = [system for system in system_name_mapping if system in out.keys()]
    x = np.arange(len(system_labels))

    # Stack offsets are sized from this panel's own systems.
    if bottom is None:
        bottom = [0] * len(system_labels)
    elif len(bottom) != len(system_labels):
        raise ValueError(
            f'{family} covers {len(system_labels)} systems but earlier families '
            f'covered {len(bottom)}; the stacked bars would be misaligned'
        )

    error_data = {system: stats['error'] for system, stats in out.items()}

    if family == Family.CONTENT:
        # Fold four content error types into the single 'Bad Elaboration' bucket
        # (the per-system count dicts are updated in place).
        for counts in error_data.values():
            counts[SpecialCategory.ELABORATION_ALL] = sum([counts[e_type] for e_type in [
                Error.IRRELEVANT,
                Error.FACTUAL,
                Error.CONTRADICTION,
                Error.REPETITION
            ]])
        error_iterator = [SpecialCategory.ELABORATION_ALL, Error.BAD_DELETION]
    elif family == Family.SYNTAX:
        error_iterator = [ReorderLevel.COMPONENT, Edit.STRUCTURE, Edit.SPLIT]
    elif family == Family.LEXICAL:
        # 'Other' = INFORMATION_REWRITE + Quality.ERROR, read from the content
        # error counts after the syntax and then the lexical error counts have
        # been merged over them with dict.update (later families win on shared keys).
        content_out = get_edits_by_family(data, Family.CONTENT, combine_humans=False)
        merged_errors = {system: stats['error'] for system, stats in content_out.items()}
        for family_j in [Family.SYNTAX, Family.LEXICAL]:
            family_j_out = get_edits_by_family(data, family_j, combine_humans=False)
            for system, stats in family_j_out.items():
                merged_errors[system].update(stats['error'])

        for system, counts in merged_errors.items():
            error_data[system][SpecialCategory.OTHER] = sum([counts[e_type] for e_type in [
                # Error.BAD_SPLIT is already shown as Edit.SPLIT in the syntax family.
                Error.INFORMATION_REWRITE,
                Quality.ERROR
            ]])
        error_iterator = [SpecialCategory.OTHER, Error.COMPLEX_WORDING]
    else:
        raise ValueError(f'No error edit types configured for family {family}')

    for error_type in error_iterator:
        val = [error_data[label][error_type] for label in system_labels]
        # All-zero series are not drawn, so they get no legend entry.
        if sum(val) != 0:
            label = error_type.value
            # A bare 'Error' label in the syntax/lexical families is shown as 'Grammar'.
            if family in (Family.SYNTAX, Family.LEXICAL) and label == 'Error':
                label = 'Grammar'

            if family == Family.SYNTAX:
                color = color_mapping_override[error_type]
            else:
                color = color_mapping[error_type]
            ax_quality[1].bar(x, val, width, bottom=bottom, label=label, color=color)
        bottom = [b + v for b, v in zip(bottom, val)]

displayed_x_labels = [system_name_mapping[label] for label in system_labels]

ax_quality[1].tick_params(labelsize=font_size_x_labels)
ax_quality[1].set_xticks(np.arange(len(system_labels)))
ax_quality[1].set_xticklabels(displayed_x_labels, rotation=30, ha='right')
ax_quality[1].spines[['right', 'top']].set_visible(False)

# Set explicitly instead of depending on which family the loop finished on.
ax_quality[1].set_ylabel('# Error Edits / Sent.', fontsize=font_size_y_label)

from matplotlib.legend_handler import HandlerBase
from matplotlib.patches import Polygon

# Legend entries of the error panel, in the order the bars were drawn.
handles, labels = ax_quality[1].get_legend_handles_labels()

# Bar face colour (hex string as produced by rgb2hex) -> colour used for the
# other half of that entry's split legend swatch. Grey "Other" maps to itself.
pairwise_facecolor = dict([
    ('#eb6565', '#f24949'),  # Bad Deletion
    ('#a2cf72', '#86d95d'),  # Bad Elaboration
    ('#f0ce5d', '#F7CE46'),  # Bad Split
    ('#f5b078', '#ffa159'),  # Bad Structure
    ('#99d1b7', '#7db39a'),  # Bad Reorder
    ('#8e88f7', '#428cd6'),  # Complex Wording
    ('#d9d9d9', '#d9d9d9'),  # Other
])

class TwoTriangleHandler(HandlerBase):
    """Legend handler drawing a square swatch split along its diagonal.

    The lower-right triangle takes the bar's own face colour. The upper-left
    triangle takes the paired colour from the module-level ``pairwise_facecolor``
    dict, so one legend entry can stand for a pair of categories. A black frame
    outlines the swatch, and a thin grey diagonal divides the halves unless both
    halves share one colour.
    """

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """Return the artists (two triangles, frame, optional divider) for one entry.

        ``orig_handle`` is expected to be a container returned by ``Axes.bar``;
        its first patch supplies the face colour. Coordinates are in the handle
        box: x spans [0, width] and y spans [0, height].
        NOTE(review): xdescent/ydescent are not applied (unchanged from before);
        confirm the vertical alignment against the label text is as intended.
        """
        from matplotlib.colors import rgb2hex

        facecolor = orig_handle[0].get_facecolor()
        hex_color = rgb2hex(facecolor)
        # Colours without a registered pair get a plain single-colour swatch
        # instead of raising KeyError.
        pairwise_color = pairwise_facecolor.get(hex_color, hex_color)

        # x uses `width`, y uses `height` (they were swapped before, which only
        # worked because handlelength == handleheight).
        upper_left = plt.Polygon([[0, 0], [0, height], [width, height]], closed=True,
                                 facecolor=pairwise_color, edgecolor='black', linewidth=0)
        lower_right = plt.Polygon([[0, 0], [width, height], [width, 0]], closed=True,
                                  facecolor=facecolor, edgecolor='black', linewidth=0)
        frame = plt.Rectangle((0, 0), width, height, edgecolor='black', facecolor='none')

        # Both halves the same colour (e.g. the grey "Other" entry): no divider.
        if pairwise_color == hex_color:
            return [upper_left, lower_right, frame]

        inset = 0.9  # keep the divider's ends just inside the frame
        divider = plt.Line2D([inset, width - inset], [inset, height - inset],
                             color='#595959', linewidth=0.7)
        return [upper_left, lower_right, frame, divider]

# All handles come from ax.bar, so this maps their container type to the
# split-swatch handler.
handler_map = {type(h): TwoTriangleHandler() for h in handles}

# Shared legend to the right of the error panel; reversed so entries are listed
# top-down in the same order as the stacked segments.
ax_quality[1].legend(handles[::-1], labels[::-1], loc='center left', bbox_to_anchor=legend_loc,
    fancybox=True, ncol=1, borderaxespad=1., fontsize=font_size_legend,
    facecolor='white', edgecolor='black', framealpha=1, frameon=False,
    columnspacing=1, handlelength=1, handleheight=1, handletextpad=0.6,
    borderpad=0.2, alignment='left', handler_map=handler_map)

# Raw bar label -> legend text naming the paired quality/error category.
text_override = {
    'Bad Deletion': 'Good/Bad Generalization',
    'Bad Elaboration': 'Good/Bad Elaboration',
    'Split': 'Good/Bad Split',
    'Structure': 'Good/Bad Structure',
    'Component-level': 'Good/Bad Reorder',
    'Complex Wording': 'Good/Bad Paraphrase',
    'Other': 'Other Errors'
}

# Shift the rotated, right-aligned tick labels 10 display units to the right.
trans = mtrans.Affine2D().translate(10, 0)
for t in ax_quality[1].get_xticklabels():
    t.set_transform(t.get_transform() + trans)

# The legend mixes every family, so no per-family "Bad" prefix is applied
# (previously this depended on the leaked loop variable, which was never SYNTAX).
# Labels without an override are passed through unchanged instead of raising KeyError.
for text in ax_quality[1].legend_.texts:
    raw_label = text.get_text()
    text.set_text(fix_text(text_override.get(raw_label, raw_label), syntax_error=False))

# Fixed y-range 0..2 with ticks every 0.5.
tick_range_error = np.arange(0, 2.5, step=0.5)
ax_quality[1].set_yticks(tick_range_error)

ax_quality[0].set_title('Quality Edits', fontsize=font_size_title)
ax_quality[1].set_title('Error Edits', fontsize=font_size_title)

# Operate on this figure explicitly rather than on pyplot's "current" figure.
fig.tight_layout()
if savefig:
    out_filename = '../paper/plot/test-set-edits.pdf'
    fig.savefig(out_filename, format="pdf", bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)
else:
    plt.show()