In [None]:
# Author: Malik Altakrori, PhD
# IBM Research
# malik.altakrori@ibm.com

In [None]:
# Run this notebook only AFTER completing ALL the experiments from notebook 2.

In [None]:
import json
import os
import shlex

In [None]:
from primeqa.mitqa.metrics.evaluate import compute_f1


In [None]:
# Absolute path to the repository root.
root_folder = "<Provide the abs path to the repo>"

# Choose which experiment's results to process. This must be a single folder name (a string, not a list):
#   "Results"            -> original questions
#   "Results_Translated" -> translated dialectal questions
# Re-run the notebook once per folder if you need both.
results_folder = "Results"

In [None]:
# This function removes the unmatched quote punctuation as sometimes a generative model would generate one and stop (it is selecting a span)
# We remove this extra quote to prevent it from messing up csv files that we will use
def remove_extra_quote(pred, verbose=False):
    if verbose:
        print(f"before: {pred}")
    processed_prediction = pred
    last_loc = 0
    even = 1 # odd and even (* -1)
    for i, c in enumerate(processed_prediction):
        if c == "\"":
            even *= -1
            last_loc = i

    if even == -1 : # one does not match
        processed_prediction = processed_prediction[0:last_loc] + processed_prediction[last_loc+1:]
        
    if verbose:
        print(f"After: {processed_prediction}")
    return processed_prediction

# Sanity check on a known-bad prediction: it has three quotes, so the trailing unmatched one must be dropped.
pred = "\"Troy\"\"."
if pred.count("\"") % 2 == 1:
    print("Caught a case")
    assert remove_extra_quote(pred, True) == "\"Troy\".", "remove_extra_quote did not drop the stray quote"

## Generate reference/prediction pairs for each setting and language

In [None]:
# we extract the refernce answer and the generated answer for each experiment, and calculate the unnormalized f1scores 
for setting in ["All", "NoBB"]:
    os.makedirs(os.path.join(root_folder, results_folder, "normalizedScore", setting, "RefPredPairs"), exist_ok=True)

    for lang in ["EN", "MSA"]:
        full_path = os.path.join(root_folder, results_folder, setting, f"Results_{lang}")
        for p in os.listdir(full_path):
            if p.startswith(".DS"):
                continue
            print(f"{full_path}/{p}")    
            with open(f"{full_path}/{p}/eval_predictions_processed.json", "r") as predsFile:
                predsList = json.load(predsFile)            

            with open(f"{full_path}/{p}/eval_references.json", "r") as refsFile:
                refsList = json.load(refsFile)

            with open(os.path.join(root_folder, results_folder, "normalizedScore",setting, "RefPredPairs",f"{p}.tsv"), "w") as outFile:
                outFile.write("Reference\tPredection\tF1_score\n")
                for r, p in zip(refsList, predsList):            
                    f1 = compute_f1(a_gold=r['answers']['text'][0], a_pred=p['prediction_text'])

                    processed_prediction = p['prediction_text']
                    if processed_prediction.count("\"") % 2 == 1:
                        # print("Caught a case")
                        processed_prediction = remove_extra_quote(processed_prediction)

                    outFile.write(f"{r['answers']['text'][0]}\t{processed_prediction}\t{f1}\n")
        #             break 

        #     break
        # break


## Generate a shell script that computes the normalized F1 score for each experiment

In [None]:
eval_file_loc = os.path.join(root_folder, "Notebooks", "f1_em_eval.py")
eval_file_loc

In [None]:
# Emit a bash script with one `f1_em_eval.py` invocation per RefPredPairs TSV. The script goes to
# Scripts/original/ or Scripts/translated/ depending on which results folder was processed.
out_ = "translated" if results_folder.endswith("Translated") else "original"

script_path = os.path.join(root_folder, "Scripts", out_, "run_new_metric_all.sh")
# The Scripts/<out_> folder may not exist yet on a fresh checkout.
os.makedirs(os.path.dirname(script_path), exist_ok=True)

with open(script_path, "w", encoding="utf-8") as script_file:
    script_file.write("#!/usr/bin/env bash\n")
    script_file.write("echo \"dont forget to activate your env\" \n\n")

    for setting in ["All", "NoBB"]:
        input_folder = os.path.join(root_folder, results_folder, "normalizedScore", setting, "RefPredPairs")
        output_folder = os.path.join(root_folder, results_folder, "normalizedScore", setting, "RefPredPairs_Processed")
        os.makedirs(output_folder, exist_ok=True)

        for p in sorted(os.listdir(input_folder)):
            # Only the TSVs produced by the pairing step; skips e.g. .DS_Store.
            if not p.endswith(".tsv"):
                continue
            input_file_path = os.path.join(input_folder, p)
            output_file_path = os.path.join(output_folder, p)
            # Assumes experiment names start with the language tag, e.g. "MSA-..." — TODO confirm naming.
            lang_switch = p.split('-')[0]
            lang = "arabic" if lang_switch == "MSA" else "english"
            # Quote paths so spaces or shell metacharacters in root_folder cannot break the script.
            cmd = (
                f"python {shlex.quote(eval_file_loc)} --input_file {shlex.quote(input_file_path)} "
                f"--output_file {shlex.quote(output_file_path)} --language {lang}\n"
            )
            print(cmd, end="")
            script_file.write(cmd)