# Template for Accuracy Calculation for FLAN Zero-Shot and Fine-Tuned Models

This notebook provides functions to calculate accuracy scores for the predictions of fine-tuned and zero-shot models. Zero-shot predictions in particular may need to be re-mapped to the labels of the original data before accurate metrics can be computed, because the model may add extra words or rephrase its answer. If your task is not one of the tasks provided here, you may need to write the mapping function yourself by inspecting your predictions (e.g. print all unique combinations of model output and original label to see which categories the model produces).

In [1]:
import os
import re
import glob

import pandas as pd 
from IPython.core.display import Markdown

from label_utils import task_num_to_task_name, dataset_num_to_dataset_name, plot_count_and_normalized_confusion_matrix, \
    task_to_display_labels, load_train_and_eval_sets, load_dataset_task_prompt_mappings, map_label_to_completion

In [2]:
import numpy as np

In [4]:
def map_outputs_task_1(output):
    """Map a model completion for task 1 to its answer letter.

    Matching is heuristic and case-insensitive. Patterns are tried in order
    and the first match wins.

    Args:
        output: The (pre-processed) model completion. ``None`` or a float
            NaN (a missing value) is accepted.

    Returns:
        'A' for "relevant"-style answers, 'B' for "not relevant" /
        "irrelevant"-style answers, or "" when the completion is missing or
        cannot be mapped (unmapped values are printed).
    """
    # NaN never compares equal to np.nan, and a float has no .lower(), so
    # missing values must be handled before any string processing.
    if output is None or (isinstance(output, float) and np.isnan(output)):
        return ""

    text = output.lower().strip()
    patterns = (
        # The (?<!not ) lookbehind keeps "not relevant" from being captured
        # by the bare "relevant" alternative; it falls through to 'B'.
        ('A', r'^(answer:){0,1}(\s)*a(\s)*$|(a(\.|:|\)))|(\s|^|\')(?<!not )relev(a|e)nt|aelevant'),
        ('B', r'^(answer:){0,1}(\s)*b(\s)*$|b(\.|:|\))|not relevant|irrelevant|ielevant|\s+b$|brrelevant'),
    )
    for label, pattern in patterns:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return ""

    print(f'Weird value: {text}')
    return ""


def map_outputs_task_2(output):
    """Map a model completion for task 2 to its answer letter.

    Matching is heuristic and case-insensitive. Patterns are tried in order
    and the first match wins.

    Args:
        output: The (pre-processed) model completion. ``None`` or a float
            NaN (a missing value) is accepted.

    Returns:
        'A' for problem/challenge answers, 'B' for solution answers, 'C' for
        neither/neutral answers, or "" when the completion is missing or
        cannot be mapped (unmapped values are printed).
    """
    # NaN never compares equal to np.nan, and a float has no .lower(), so
    # missing values must be handled before any string processing.
    if output is None or (isinstance(output, float) and np.isnan(output)):
        return ""

    text = output.lower().strip()
    patterns = (
        ('A', r'^(answer:){0,1}(\s)*a(\s)*$|a(\.|:|\))|challnge|problem|\bpro\b|blem'),
        ('B', r'^(answer:){0,1}(\s)*b(\s)*$|b(\.|:|\))|solution|\blution\b'),
        ('C', r'^(answer:){0,1}(\s)*c(\s)*$|c(\.|:|\))|neither|neutral|(\s)+c$'),
    )
    for label, pattern in patterns:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return ""

    print(f'Weird value: {text}')
    # Return "" like every other task mapper; the previous literal string
    # "np.nan" was an inconsistent pseudo-label.
    return ""


def map_outputs_task_3(output):
    """Map a model completion for task 3 (frame categories) to its letter.

    Matching is heuristic and case-insensitive. Patterns are tried in order
    and the first match wins.

    Args:
        output: The (pre-processed) model completion. ``None`` or a float
            NaN (a missing value) is accepted.

    Returns:
        One of 'A' (economic), 'B' (morality), 'C' (fairness and equality),
        'D' (policy prescription and evaluation), 'E' (law and order / crime
        and justice), 'F' (security and defense), 'G' (health and safety),
        'H' (quality of life), 'I' (political), 'J' (external regulation and
        reputation), 'K' (other), or "" when the completion is missing or
        cannot be mapped (unmapped values are printed).
    """
    # NaN never compares equal to np.nan, and a float has no .lower(), so
    # missing values must be handled before any string processing.
    if output is None or (isinstance(output, float) and np.isnan(output)):
        return ""

    text = output.lower().strip()
    patterns = (
        ('A', r'^(answer:){0,1}(\s)*a(\s)*$|a(\.|:|\))|economic|economy|aconomy'),
        ('B', r'^(answer:){0,1}(\s)*b(\s)*$|b(\.|:|\))|morality|rality'),
        ('C', r'^(answer:){0,1}(\s)*c(\s)*$|c(\.|:|\))|fairness and equality|irness and equality'),
        ('D', r'^(answer:){0,1}(\s)*d(\s)*$|d(\.|:|\))|policy prescription and evaluation|prescription and evaluation|licy prescription'),
        ('E', r'^(answer:){0,1}(\s)*e(\s)*$|e(\.|:|\))|law and order|crime and justice|law enforcement|w and order'),
        ('F', r'^(answer:){0,1}(\s)*f(\s)*$|f(\.|:|\))|security and defense|curity and defense'),
        ('G', r'^(answer:){0,1}(\s)*g(\s)*$|g(\.|:|\))|health and safety|alth and safety'),
        ('H', r'^(answer:){0,1}(\s)*h(\s)*$|h(\.|:|\))|quality of life|ality of life'),
        ('I', r'^(answer:){0,1}(\s)*i(\s)*$|i(\.|:|\))|political|litical'),
        ('J', r'^(answer:){0,1}(\s)*j(\s)*$|j(\.|:|\))|external (regulation|region) and reputation|external regulation|regulation and reputation'),
        ('K', r'^(answer:){0,1}(\s)*k(\s)*$|(k|n|w)(\.|:|\))|other|climate change|leadership and executive responsibility|'
              r'expansion of service opportunities|access to higher ed|potential'),
    )
    for label, pattern in patterns:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return ""

    print(f'Weird value: {text}')
    return ""


def map_outputs_task_4(output):
    """Map a model completion for task 4 (stance) to its answer letter.

    Matching is heuristic and case-insensitive. Patterns are tried in order
    and the first match wins.

    Args:
        output: The (pre-processed) model completion. ``None`` or a float
            NaN (a missing value) is accepted.

    Returns:
        'A' for positive / in-favor answers, 'B' for negative / against
        answers, 'C' for neutral answers, or "" when the completion is
        missing or cannot be mapped (unmapped values are printed).
    """
    # NaN never compares equal to np.nan, and a float has no .lower(), so
    # missing values must be handled before any string processing.
    if output is None or (isinstance(output, float) and np.isnan(output)):
        return ""

    text = output.lower().strip()
    patterns = (
        ('A', r'^(answer:){0,1}(\s)*a(\s)*$|a(\.|:|\))|positive|postive stance|in favor|in advantage of|aast|a favor of a'),
        ('B', r'^(answer:){0,1}(\s)*b(\s)*$|b(\.|:|\))|negative|negative stance|against|aggainst|bast'),
        ('C', r'^(answer:){0,1}(\s)*c(\s)*$|c(\.|:|\))|neutral|neutral stance|cast'),
    )
    for label, pattern in patterns:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return ""

    print(f'Weird value: {text}')
    return ""


def map_outputs_task_5(output):
    """Map a model completion for task 5 (topic) to its answer letter.

    Matching is heuristic and case-insensitive. Patterns are tried in order
    and the first match wins.

    Args:
        output: The (pre-processed) model completion. ``None`` or a float
            NaN (a missing value) is accepted.

    Returns:
        One of 'A' (section 230), 'B' (Trump ban), 'C' (Twitter support),
        'D' (platform policies), 'E' (complaints), 'F' (other), or "" when
        the completion is missing or cannot be mapped (unmapped values are
        printed).
    """
    # NaN never compares equal to np.nan, and a float has no .lower(), so
    # missing values must be handled before any string processing.
    if output is None or (isinstance(output, float) and np.isnan(output)):
        return ""

    text = output.lower().strip()
    patterns = (
        ('A', r'^(answer:){0,1}(\s)*a(\s)*$|a(\.|:|\))|section 230|230'),
        ('B', r'^(answer:){0,1}(\s)*b(\s)*$|b(\.|:|\))|trump ban|ban donald trump|ban(ning){0,1} trump|tr ban'),
        ('C', r'^(answer:){0,1}(\s)*c(\s)*$|c(\.|:|\))|twitter support'),
        ('D', r'^(answer:){0,1}(\s)*d(\s)*$|d(\.|:|\))|platform policies|policies'),
        ('E', r'^(answer:){0,1}(\s)*e(\s)*$|e(\.|:|\))|complaint(s)+'),
        # Raw string: the previous plain literal relied on invalid escapes
        # like "\s" (a SyntaxWarning on Python 3.12+); the value is unchanged.
        ('F', r'^(answer:){0,1}(\s)*f(\s)*$|f(\.|:|\))|other'),
    )
    for label, pattern in patterns:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return ""

    print(f'Weird value: {text}')
    return ""
    
def map_outputs_task_6(output):
    """Map a model completion for task 6 (frame) to its answer letter.

    Matching is heuristic and case-insensitive. Patterns are tried in order
    and the first match wins.

    Args:
        output: The (pre-processed) model completion. ``None`` or a float
            NaN (a missing value) is accepted.

    Returns:
        One of 'A' (policy prescription and regulation), 'B' (morality),
        'C' (economics), 'D' (other), or "" when the completion is missing
        or cannot be mapped (unmapped values are printed).
    """
    # NaN never compares equal to np.nan, and a float has no .lower(), so
    # missing values must be handled before any string processing.
    if output is None or (isinstance(output, float) and np.isnan(output)):
        return ""

    text = output.lower().strip()
    patterns = (
        ('A', r'^(answer:){0,1}(\s)*a(\s)*$|a(\.|:|\))|policy prescription|policy prescription and regulation|licy and regulation|alicy'),
        ('B', r'^(answer:){0,1}(\s)*b(\s)*$|b(\.|:|\))|morality|rality'),
        ('C', r'^(answer:){0,1}(\s)*c(\s)*$|c(\.|:|\))|economics|econom|onomics'),
        ('D', r'^(answer:){0,1}(\s)*d(\s)*$|d(\.|:|\))|other'),
    )
    for label, pattern in patterns:
        if re.search(pattern, text):
            return label

    if output == 'nan':
        return ""

    print(f'Weird value: {text}')
    return ""

In [5]:
def process_output_completed(completion: str, task: int) -> str:
    """Clean a raw model completion and map it to a label letter.

    Filler words are removed from the completion, which is then passed to
    the task-specific ``map_outputs_task_<N>`` function.

    Args:
        completion: Raw generated text. ``None`` or a float NaN (e.g. an
            empty cell read by ``pandas.read_csv``) is treated as missing.
        task: Task number selecting the mapping function (currently 1-6).

    Returns:
        The mapped label letter, or "" when the completion is missing or
        cannot be mapped.

    Raises:
        ValueError: If ``task`` has no mapping function. Previously this
            silently returned None, which corrupted downstream metrics.
    """
    mappers = {
        1: map_outputs_task_1,
        2: map_outputs_task_2,
        3: map_outputs_task_3,
        4: map_outputs_task_4,
        5: map_outputs_task_5,
        6: map_outputs_task_6,
        # YOU MIGHT NEED TO ADD YOUR TASK HERE
    }
    if task not in mappers:
        raise ValueError(f'No output mapping defined for task {task!r}; add one to process_output_completed.')

    # re.sub raises TypeError on a float, so handle missing values up front.
    if completion is None or (isinstance(completion, float) and np.isnan(completion)):
        return ""

    # NOTE(review): with (?i) and no word boundaries, 'IN' deletes every "in"
    # substring (e.g. "against" -> "agast", "complaints" -> "complats",
    # "in favor" -> " favor"), so some patterns in the task mappers can never
    # match. Left unchanged because the other patterns appear tuned to this
    # output; confirm before adding word boundaries.
    completion = re.sub(r'(?i)Answer|folks|Plain|River|IN', '', completion)
    return mappers[task](completion)

In [7]:
# Location of dataset_task_mappings.csv, two directories above the working directory.
dataset_task_mappings_fp = os.path.join(os.pardir, os.pardir, 'dataset_task_mappings.csv')

# Load data

In [8]:
# All FLAN-T5-XL prediction CSVs, sorted so files are processed in a stable order.
predictions_fp = sorted(
    glob.glob(
        os.path.join('..', '..', 'predictions', 'google_flan-t5-xl__w_generate', 'google_flan-t5-xl_*.csv')
    )
)

In [None]:
# Display the sorted list of matched prediction files as the cell output.
predictions_fp

In [None]:
# NOTE(review): `prediction_fp`, `ds_num` and `task_num` are not defined in the
# cells shown here; set them (e.g. pick one entry of `predictions_fp`) before
# running this cell.
df = pd.read_csv(prediction_fp)

# Get the expected labelset
dataset_idx, dataset_task_mappings = load_dataset_task_prompt_mappings(
    dataset_num=ds_num, task_num=task_num, dataset_task_mappings_fp=dataset_task_mappings_fp)
label_column = dataset_task_mappings.loc[dataset_idx, "label_column"]
# Both label lists are stored as comma-separated strings. Strip each entry so
# "A, B" yields ["A", "B"]; the full descriptions were previously left with
# leading spaces, unlike the short labels.
labelset = [label.strip() for label in dataset_task_mappings.loc[dataset_idx, "labelset"].split(",")]
labelset_full_description = [
    label.strip() for label in dataset_task_mappings.loc[dataset_idx, "labelset_fullword"].split(",")
]

In [None]:
# Get predictions
y_pred = df.prediction_ds.map(lambda x: process_output_completed(x, task_num))
#assert df['pred_label'].map(lambda pred: pred not in labelset).sum() == 0, 'Prediction not in expected labelset'

# Get ground truth in same format
y_true = df[label_column].map(lambda label: map_label_to_completion(
    label=label, task_num=task_num, full_label=False))
assert y_true.map(lambda pred: pred not in labelset).sum() == 0, 'Ground truth not in expected labelset'
    
# Get accuracy
labels = labelset
display_labels = labelset_full_description
cm_plot, classification_report, metrics = plot_count_and_normalized_confusion_matrix(
    y_true, y_pred, display_labels, labels, xticks_rotation='horizontal')

# Get accuracy
accuracy_summary_list.append({
    'exp_name': os.path.basename(prediction_fp),
    'dataset': ds_num,
    'task': task_num,
    'sample_size': sample_size,
    'accuracy': metrics['accuracy'],
    'f1-macro': metrics['f1'],
    'precision': metrics['precision'],
    'recall': metrics['recall']
})