In [1]:
import pandas as pd
import numpy as np
from utils import *

# Resolve project locations once; every data file below is read relative to DATA_PATH.
PROCESSED_SUBDIR = "data/processed"

config = load_config()
PROJECT_PATH = config.project_path
DATA_PATH = PROJECT_PATH.joinpath(PROCESSED_SUBDIR)

In [11]:
# Display label -> pickled C1 evaluation file, in table order.
# Keeping label and filename side by side is the single source of truth for
# `model_names` and `results`, so the two can no longer drift out of sync and
# silently mislabel rows in the results table.
# NOTE: pd.read_pickle can execute arbitrary code; only load trusted files.
C1_RESULT_FILES = {
    'gpt-4o-mini': 'gpt-4o-mini_c1.pkl',
    'llama3.2-3B': 'llama-3.2-3B_c1.pkl',
    'llama3.2-3B-sft': 'llama-3.2-3B-sft_c1.pkl',
    'llama3.2-3B-lora': 'llama-3.2-3B-lora_c1.pkl',
    'llama3.2-3B-lora-ppo': 'llama-3.2-3B-lora-ppo_c1.pkl',
    'llama3.2-3B-lora-grpo': 'llama-3.2-3B-lora-grpo_c1.pkl',
}

model_names = list(C1_RESULT_FILES)
results = [pd.read_pickle(DATA_PATH.joinpath(fname)) for fname in C1_RESULT_FILES.values()]

# Per-model handles, kept for cells that refer to a model by variable name.
(gpt_4o_mini, llama32, llama32_sft, llama32_lora,
 llama32_lora_ppo, llama32_lora_grpo) = results

In [12]:
def collect_bleu(model_data):
    """Return the model's BLEU score (``model_data['bleu']['bleu']``) rounded to 3 decimals."""
    bleu_results = model_data['bleu']
    return round(bleu_results['bleu'], ndigits=3)

def collect_rouge(model_data):
    """Return the model's ROUGE-L score (``model_data['rouge']['rougeL']``) rounded to 3 decimals."""
    rouge_l = model_data['rouge']['rougeL']
    return round(rouge_l, ndigits=3)

def collect_bertscore(model_data):
    """Return the mean BERTScore F1 for one model, rounded to 3 decimals.

    ``model_data['bertscore']['f1']`` is averaged with ``np.mean`` before
    rounding (presumably one F1 value per evaluated example).
    """
    f1_values = model_data['bertscore']['f1']
    mean_f1 = np.mean(f1_values)
    return round(mean_f1, ndigits=3)

def collect_readability(model_data) :
    """Average Flesch-Kincaid grade level across one model's readability entries.

    Parameters
    ----------
    model_data : mapping
        Evaluation results for one model. ``model_data['readability']`` is a
        mapping whose values each carry a ``'flesch_kincaid_grade'`` entry.

    Returns
    -------
    float
        Mean grade rounded to 3 decimals, or NaN when there are no entries.
    """
    grades = [entry['flesch_kincaid_grade'] for entry in model_data['readability'].values()]
    if not grades:
        # np.mean([]) also yields NaN, but with a "Mean of empty slice" RuntimeWarning.
        return float('nan')
    return round(np.mean(grades), 3)

In [13]:
def format_c1_results(model_name, model_data):
    """Build one results-table row for a model's C1 evaluation.

    Parameters
    ----------
    model_name : str
        Label stored under the ``"model"`` key.
    model_data : mapping
        The model's evaluation results, as consumed by the ``collect_*`` helpers.

    Returns
    -------
    dict
        Keys ``model``, ``bleu``, ``rouge``, ``bertscore``, ``readability``.
    """
    # A dict literal evaluates its values left to right, so the helpers run in
    # the same order as the keys appear.
    return {
        "model": model_name,
        "bleu": collect_bleu(model_data),
        "rouge": collect_rouge(model_data),
        "bertscore": collect_bertscore(model_data),
        "readability": collect_readability(model_data),
    }

In [14]:
# zip() silently stops at the shorter input, so a model present in one list but
# not the other would vanish from the table without any error.
if len(model_names) != len(results):
    raise ValueError(
        f"model_names has {len(model_names)} entries but results has {len(results)}"
    )

# One summary row (dict) per model, turned into a DataFrame in a single call.
all_result = []
for model_name, result in zip(model_names, results) :
    all_result.append(format_c1_results(model_name, result))

df = pd.DataFrame(all_result)
df

Unnamed: 0,model,bleu,rouge,bertscore,readability
0,gpt-4o-mini,0.02,0.119,0.853,10.672
1,llama3.2-3B,0.023,0.112,0.851,10.777
2,llama3.2-3B-sft,0.025,0.104,0.83,7.905
3,llama3.2-3B-lora,0.031,0.125,0.851,7.636
4,llama3.2-3B-lora-ppo,0.157,0.322,0.893,7.237
5,llama3.2-3B-lora-grpo,0.14,0.292,0.889,7.202
