##### GPT-4o "as an expert"

In [None]:
import pandas as pd, numpy as np, pickle, os, json, random
from tqdm import tqdm
from figurative_flute.utils import postprocess_output_few
from openai import OpenAI

# Read the key from the environment so it never has to be stored in the notebook;
# the placeholder is only a fallback and must be replaced if OPENAI_API_KEY is unset.
API_KEY = os.environ.get('OPENAI_API_KEY', 'MY_OPENAI_API_KEY')

client = OpenAI(api_key=API_KEY)

# Predictions (label + explanation) of the fine-tuned Gemma-2 model to be reviewed.
src_path = './outputs/finetuning/gemma2_5e-5/outputs/test/pred_ref_test.csv'

df_flute = pd.read_csv('./figurative_flute/data/outputs/flute/outputs.csv')
df_gemma2 = pd.read_csv(src_path)
# Empty CSV cells are read back as NaN; fill them so the prompt's str.replace always
# receives strings, and materialize labels as a list like every other column.
gemma2_labels = df_gemma2['pred_label'].fillna('').to_list()
gemma2_explanations = df_gemma2['pred_explanation'].fillna('').to_list()
premises, hypothesis, ref_labels, ref_explanations = df_flute['premise'].to_list(), df_flute['hypothesis'].to_list(), df_flute['ref_label'].to_list(), df_flute['ref_explanation'].to_list()
# Rows of the two files are paired by position, so their lengths must agree.
assert len(gemma2_labels) == len(gemma2_explanations) == len(ref_labels) == len(ref_explanations) == len(premises) == len(hypothesis)

def extract_expl_examples(k, path='./figurative_flute/data/train.json', seed=42):
    """Sample ``k`` reference explanations from a JSON training file.

    Explanations of 5 characters or fewer are discarded before sampling.
    Sampling uses a private RNG seeded with ``seed``, so repeated calls with the
    same arguments return the same examples and the global ``random`` state is
    left untouched.

    Args:
        k: number of explanations to return.
        path: JSON file holding a list of records, each with an 'explanation' key.
        seed: seed for the sampler (defaults to 42).

    Returns:
        A list of ``k`` explanation strings.

    Raises:
        ValueError: if fewer than ``k`` explanations survive the length filter
            (raised by ``random.Random.sample``).
    """
    # JSON text is UTF-8; the context manager guarantees the handle is closed.
    with open(path, 'r', encoding='utf-8') as fh:
        raw_data = json.load(fh)
    expl = [d['explanation'] for d in raw_data if len(d['explanation']) > 5]
    indices = random.Random(seed).sample(range(len(expl)), k)
    return [expl[i] for i in indices]

def _field_value(output, marker, next_marker=None):
    """Return the value written after the first ``marker`` in ``output``.

    The value is the rest of the marker's own line, stripped. When that is blank
    (the model put the value on the following line), the first non-blank line
    after it is used instead, unless that line starts with ``next_marker`` (the
    field was left empty). Only text before a second occurrence of ``marker`` is
    searched. Returns '' when no value is found.
    """
    first_line, *following = output.split(marker)[1].split('\n')
    if first_line.strip():
        return first_line.strip()
    for line in following:
        line = line.strip()
        if not line:
            continue
        if next_marker is not None and line.startswith(next_marker):
            return ''
        return line
    return ''


def process_output_expert(output):
    """Parse the expert model's reply into a ``(label, explanation)`` pair.

    The label is read from the ``Label:`` field (or, if absent, the first line)
    and normalized to 'Entailment', 'Contradiction', or 'Correct' (anything that
    mentions neither). The explanation is read from the ``Explanation:`` field,
    or, if absent, taken as the text after the last ':' in the reply.

    Args:
        output: raw text of the model's reply.

    Returns:
        Tuple ``(label, explanation)`` of strings.
    """
    if 'Label:' in output:
        label = _field_value(output, 'Label:', next_marker='Explanation:')
    else:
        label = output.split('\n')[0].strip()
    lowered = label.lower()
    label = 'Entailment' if 'entail' in lowered else 'Contradiction' if 'contradict' in lowered else 'Correct'
    if 'Explanation:' in output:
        explanation = _field_value(output, 'Explanation:')
    else:
        explanation = output[output.rfind(':') + 1:].strip()
    return label, explanation

# Prompt for the "expert" verification pass (template 1). The placeholders
# [PREMISE], [HYPOTHESIS], [PRED_LABEL] and [PRED_EXPL] are filled in with
# str.replace by the generation loop. When few-shot examples are enabled, they
# are spliced in just before the "Premise" line. The reply format ("Label:" /
# "Explanation:", or "Correct") is what process_output_expert parses.
template1 = """I will provide you with a pair of sentences consisting of a premise and a hypothesis containing figurative language. The task is to determine whether there is a contradiction or entailment between the premise and hypothesis, and provide an explanation for it.
I will also provide you with a model's prediction ("Entailment" or "Contradiction") and explanation.
Your task is to verify the correctness of the prediction and, if needed, improve the explanation. When modifying the explanation, do not explicitly mention "premise" or "hypothesis", and keep the same length and style of the model's generated one.

Premise: [PREMISE]
Hypothesis: [HYPOTHESIS]

Model's prediction
Label: [PRED_LABEL]
Explanation: [PRED_EXPL]

Answer with
Label:
Explanation:
If the label is the same as the model's prediction, write "Correct". If the explanation does not need improvement, write "Correct"."""

# Registry of available prompt templates, selected by `template_id`.
templates = {1: template1}
# Run configuration: how many training explanations to show as style examples
# (0 disables them) and which prompt template from `templates` to use.
expl_few_shot = 10
template_id = 1

# The output folder encodes the template id and, when used, the few-shot count.
folder_suffix = f'_k{expl_few_shot}/' if expl_few_shot > 0 else '/'
out_folder = f'./outputs/gpt_expert/template_{template_id}{folder_suffix}'
os.makedirs(out_folder, exist_ok=True)

# Record the run inputs next to the results.
with open(out_folder + 'args.txt', 'w') as args_file:
    args_file.write('src_path=' + src_path + '\ntemplate_id=' + str(template_id))

# Start each run with a fresh error log, since it is only ever appended to.
stale_error_log = out_folder + 'errors.txt'
if os.path.exists(stale_error_log):
    os.remove(stale_error_log)

# One entry per test example, kept aligned with `premises`.
full_outputs, expert_labels, expert_explanations = [], [], []

# Build the prompt skeleton once. The few-shot examples are sampled with a fixed
# seed, so they are the same for every item and need not be re-read per row.
# They are inserted just before the "Premise" line of the template.
prompt_template = templates[template_id]
if expl_few_shot > 0:
    expl_examples = [f'- {e}' for e in extract_expl_examples(expl_few_shot)]
    few_shot = 'In the following you can find some examples of explanations. Use the same style.\n' + '\n'.join(expl_examples)
    ix = prompt_template.find('Premise')
    prompt_template = prompt_template[:ix] + few_shot + '\n\n' + prompt_template[ix:]

for p, h, pred_label, pred_expl in tqdm(zip(premises, hypothesis, gemma2_labels, gemma2_explanations), total=len(gemma2_explanations)):
    prompt = prompt_template.replace('[PREMISE]', p).replace('[HYPOTHESIS]', h).replace('[PRED_LABEL]', pred_label).replace('[PRED_EXPL]', pred_expl)
    messages = [{"role": "user", "content": prompt}]
    # Reset for every item so a failed API call never reuses the previous row's output.
    output_text = 'API Error'
    try:
        completion = client.chat.completions.create(model='gpt-4o-2024-05-13', messages=messages, temperature=0.0, max_tokens=64, seed=42)
        output_text = completion.choices[0].message.content
        expert_label, expert_explanation = process_output_expert(output_text)
        # The expert repeated the model's own label: record it as a confirmation.
        if expert_label.lower() in pred_label.lower():
            expert_label = 'Correct'
    except Exception as err:
        # Best effort: log the failure and leave empty fields, which are later
        # replaced by the fine-tuned model's prediction.
        expert_label, expert_explanation = '', ''
        with open(f'{out_folder}errors.txt', 'a') as f:
            f.write(f'### Prompt ###\n{prompt}\n### Output ###\n{output_text}\n### Error ###\n{err!r}\n### Pred label ###\n{expert_label}\n### Pred explanation ###\n{expert_explanation}\n@@@@@@@@@@@@@@@@@@@@@@\n')
    full_outputs.append(output_text)
    expert_labels.append(expert_label)
    expert_explanations.append(expert_explanation)
    
# Wherever the expert call failed (recorded as empty strings), fall back to the
# fine-tuned model's own label / explanation.
expert_labels = [g_lab if e_lab == '' else e_lab for g_lab, e_lab in zip(gemma2_labels, expert_labels)]
expert_explanations = [g_exp if e_exp == '' else e_exp for g_exp, e_exp in zip(gemma2_explanations, expert_explanations)]

result_columns = {
    'original_label': gemma2_labels,
    'pred_label': expert_labels,
    'ref_label': ref_labels,
    'original_explanation': gemma2_explanations,
    'pred_explanation': expert_explanations,
    'ref_explanation': ref_explanations,
    'output': full_outputs,
}
df_pred_ref = pd.DataFrame(result_columns)
df_pred_ref.to_csv(out_folder + 'expert_pred_ref.csv', sep=',', header=True, index=False)

###### process and analyze output

In [None]:
import pandas as pd, numpy as np, pickle, os

template_id = 1
# Few-shot count of the expert run to load. The folder name is built with the
# same formula the generation cell uses, so the two always agree.
expl_few_shot = 10
out_folder = f'./outputs/gpt_expert/template_{template_id}' + (f'_k{expl_few_shot}/' if expl_few_shot > 0 else '/')
df_expert = pd.read_csv(f'{out_folder}expert_pred_ref.csv')
df_gemma2 = pd.read_csv('./outputs/finetuning/gemma2_5e-5/outputs/test/pred_ref_test.csv')

# Empty CSV cells come back as NaN; fill expert labels too, so `.lower()` below
# never receives a float.
expert_labels = df_expert['pred_label'].fillna('').to_list()
expert_explanations = df_expert['pred_explanation'].fillna('').to_list()
gemma2_labels, gemma2_explanations = df_gemma2['pred_label'].to_list(), df_gemma2['pred_explanation'].fillna('').to_list()
ref_labels, ref_explanations = df_gemma2['ref_label'].to_list(), df_gemma2['ref_explanation'].to_list()
assert len(expert_labels) == len(gemma2_labels) == len(expert_explanations) == len(gemma2_explanations) == len(ref_labels) == len(ref_explanations)

gemma2_expert_labels, gemma2_expert_explanations = [], []

for g_lab, g_exp, e_lab, e_exp in zip(gemma2_labels, gemma2_explanations, expert_labels, expert_explanations):
    if 'correct' in e_lab.lower() or g_lab == e_lab:
        # Expert confirmed the label: keep the fine-tuned model's output as-is.
        # NOTE(review): an improved explanation returned together with a confirmed
        # label is discarded here; confirm this is intended.
        gemma2_expert_labels.append(g_lab)
        gemma2_expert_explanations.append(g_exp)
    else:
        # Expert proposed a different label: take it, and take its explanation
        # unless it only answered "Correct..." for the explanation.
        gemma2_expert_labels.append(e_lab)
        if e_exp.lower().startswith('correct'):
            gemma2_expert_explanations.append(g_exp)
        else:
            gemma2_expert_explanations.append(e_exp)
assert len(gemma2_expert_labels) == len(gemma2_expert_explanations) == len(ref_explanations)

df_pred_ref = pd.DataFrame({'pred_label': gemma2_expert_labels, 'ref_label': ref_labels, 'pred_explanation': gemma2_expert_explanations, 'ref_explanation': ref_explanations})
# df_pred_ref.to_csv(f'{out_folder}gemma2_expert_pred_ref_v2.csv', sep=',', header=True, index=False)