In [2]:
import json
from tqdm import tqdm
import numpy as np
import evaluate
from rouge_score import rouge_scorer

In [3]:
# Shared scorer configured for the 'rouge1' metric with stemming enabled.
# eval_attributes_matching_score looks it up through the global name `rouge`.
rouge = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)

In [4]:
def eval_attributes_matching_score( flag, data_path, scorer=None ):
    """Compare generated citation attributes with their labels and format one table row.

    The file at ``data_path`` is read as JSON Lines (one JSON object per line;
    blank lines are skipped). For every example, the first entry of
    ``generated_citations`` supplies the generated ``keywords`` and
    ``citation_intent``. They are compared against the example's own
    ``keywords`` (joined with "; ") and ``citation_intent``.

    Args:
        flag: Row label placed in the first column of the returned row.
        data_path: Path to the JSONL file to evaluate.
        scorer: Optional object exposing ``score(target, prediction)`` that
            returns a mapping whose ``"rouge1"`` entry has ``precision``,
            ``recall`` and ``fmeasure`` attributes. When omitted, the
            module-level ``rouge`` scorer is used.

    Returns:
        str: ``flag`` followed by the mean ROUGE-1 precision, recall and F1
        (as percentages, two decimals) and the fraction of exactly matching
        intents (four decimals), separated by " & " and terminated by a LaTeX
        row break.

    Raises:
        ValueError: If the file contains no examples, since the averages
            would otherwise silently come out as NaN.
    """
    if scorer is None:
        # Fall back to the notebook-level scorer so two-argument calls keep working.
        scorer = rouge

    precisions = []
    recalls = []
    f1_scores = []
    intent_matches = []

    # Stream the file inside a context manager so the handle is always closed,
    # even when a line fails to parse.
    with open( data_path, encoding="utf-8" ) as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. an extra newline at the end of the file).
                continue
            example = json.loads( line )

            # Only the first generated citation of each example is evaluated.
            gen_info = example["generated_citations"][0]["generation"]
            label_keywords = "; ".join( example["keywords"] )

            intent_matches.append( int( example["citation_intent"] == gen_info["citation_intent"] ) )

            # Label is passed as the target, the generation as the prediction.
            r_score = scorer.score( label_keywords, gen_info["keywords"] )["rouge1"]
            precisions.append( r_score.precision )
            recalls.append( r_score.recall )
            f1_scores.append( r_score.fmeasure )

    if not intent_matches:
        raise ValueError( "no examples found in {!r}".format( data_path ) )

    cells = [
        "%.2f" % np.round( np.mean( precisions ) * 100, 2 ),
        "%.2f" % np.round( np.mean( recalls ) * 100, 2 ),
        "%.2f" % np.round( np.mean( f1_scores ) * 100, 2 ),
        "%.4f" % np.round( np.mean( intent_matches ), 4 ),
    ]
    return flag + " & " + " & ".join( cells ) + " \\\\"

In [5]:
# (row label, path to that model's JSONL output) for every system in the table.
model_runs = [
    ( "BART-base-140M", "../results/sft_model/bart-base/test_with_citations.jsonl" ),
    ( "BART-large-400M", "../results/sft_model/bart-large/test_with_citations.jsonl" ),
    ( "GPT-Neo-125M", "../results/sft_model/gpt-neo-125m-hf/test_with_citations.jsonl" ),
    ( "GPT-Neo-1.3B", "../results/sft_model/gpt-neo-1.3b-hf/test_with_citations.jsonl" ),
    ( "Galactica-125M", "../results/sft_model/galactica-125m-ct2/test_with_citations.jsonl" ),
    ( "Galactica-125M-PPO", "../results/ppo_model/galactica-125m-ct2/test_with_citations.jsonl" ),
    ( "Galactica-1.3B", "../results/sft_model/galactica-1.3b-ct2/test_with_citations.jsonl" ),
    ( "Galactica-6.7B", "../results/sft_model/galactica-6.7b-ct2/test_with_citations.jsonl" ),
    ( "Galactica-6.7B-PPO", "../results/ppo_model/galactica-6.7b-ct2/test_with_citations.jsonl" ),
    ( "LLaMa-7B", "../results/sft_model/llama-7b-ct2/test_with_citations.jsonl" ),
    ( "LLaMa-7B-PPO", "../results/ppo_model/llama-7b-ct2/test_with_citations.jsonl" ),
]

# Collect one formatted table row per model, in the order listed above.
eval_results = []
for flag, data_path in model_runs:
    eval_results += [ eval_attributes_matching_score( flag, data_path ) ]

# Emit the rows one per line, ready to paste into a LaTeX table body.
print( *eval_results, sep="\n" )

BART-base-140M & 22.05 & 16.70 & 17.62 & 0.6083 \\
BART-large-400M & 24.92 & 18.47 & 19.68 & 0.6454 \\
GPT-Neo-125M & 21.10 & 16.36 & 17.13 & 0.5861 \\
GPT-Neo-1.3B & 28.00 & 23.18 & 23.58 & 0.6352 \\
Galactica-125M & 26.15 & 21.86 & 22.11 & 0.6204 \\
Galactica-125M-PPO & 24.80 & 20.84 & 21.06 & 0.6296 \\
Galactica-1.3B & 29.89 & 25.53 & 25.86 & 0.6602 \\
Galactica-6.7B & 29.49 & 24.78 & 25.10 & 0.6380 \\
Galactica-6.7B-PPO & 30.03 & 25.92 & 25.93 & 0.6407 \\
LLaMa-7B & 28.13 & 22.78 & 23.40 & 0.6352 \\
LLaMa-7B-PPO & 28.57 & 23.39 & 23.97 & 0.6315 \\
