In [2]:
from string import Template
import pandas as pd 
from openai import OpenAI
import re
import os

In [26]:


# Read the API key from the environment instead of hard-coding it in the
# notebook: an empty-string key cannot authenticate, and a real key pasted
# here would be leaked whenever the notebook is shared or committed.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Compiled once instead of on every call. Matches "answer: <text>" followed on
# a later line by "explanation: <text>", case-insensitively; DOTALL lets the
# explanation span multiple lines.
_ANSWER_EXPLANATION_RE = re.compile(
    r"answer:\s*(.*?)\s*\nexplanation:\s*(.*)", re.DOTALL | re.IGNORECASE
)


def extract_answer_and_explanation(response_text):
    """Split a model reply into its answer and explanation parts.

    Args:
        response_text: Raw reply text, expected to contain an ``Answer: ...``
            line followed by an ``Explanation: ...`` line. May be ``None``.

    Returns:
        A tuple ``(answer, explanation)`` of strings when the reply matches the
        expected layout, otherwise ``(None, None)``. Non-string input (for
        example ``None``) also yields ``(None, None)`` instead of raising.
    """
    # re.search raises TypeError on None; treat a missing reply as "no match".
    if not isinstance(response_text, str):
        return None, None

    matches = _ANSWER_EXPLANATION_RE.search(response_text)
    if matches is None:
        return None, None
    return matches.group(1), matches.group(2)

def query_openai(prompt: str) -> str:
    """Send ``prompt`` as a single user message and return the reply text.

    The request goes through the module-level ``client`` to the
    ``gpt-4o-mini`` chat model with fixed sampling settings
    (``temperature=1``, ``top_p=1``, ``max_tokens=150``, no penalties).

    Args:
        prompt: Full prompt text to send as the user message.

    Returns:
        The ``message.content`` of the first choice in the response.
    """
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}],
    }
    request_options = dict(
        model="gpt-4o-mini",
        messages=[user_message],
        temperature=1,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        response_format={"type": "text"},
    )
    completion = client.chat.completions.create(**request_options)
    first_choice = completion.choices[0]
    return first_choice.message.content


# One-shot prompt sent to the model for every question: an instruction, one
# worked example, then the question itself. The "$situation" placeholder is
# filled in with str.replace by the evaluation loop.
# NOTE(review): "the a brief explanation" below looks like a typo for "then a
# brief explanation". Left as-is because the text is sent verbatim to the
# model and changing it may change the outputs — confirm before editing.
prompt_template = """
For the provided situation choose which of five emotions is most likely to result from that situation. 
First provide the correct letter and the a brief explanation.

Example: 
Situation: 
Bob saw a cat and yawned.
A. Happy B. Angry C. Frightened D. Bored E. Hungry

Answer: D
Explanation: In the given situation, Bob yawning upon seeing a cat typically suggests a lack 
of interest or excitement, indicating boredom. 

Situation: 
$situation
"""

# Load the test items. The evaluation loop reads the columns 'situation',
# 'A'..'E' (the five answer options) and 'answer label' from each row.
questions = pd.read_csv("../data/steu-abilities-test.csv")

In [28]:
import tqdm.notebook as tqdm

# Collect one record per question in a plain list and build the DataFrame once
# at the end. The list has its own name so that `results` always refers to the
# DataFrame, never to the intermediate list.
records = []

for index, row in tqdm.tqdm(questions.iterrows(), total=len(questions)):
    # Question text followed by the five lettered options on the next line.
    situation = f"{row['situation']} \nA. {row['A']} B. {row['B']} C. {row['C']} D. {row['D']} E. {row['E']}"
    prompt = prompt_template.replace("$situation", situation)

    response = query_openai(prompt)
    # Guard against a missing reply so the parser always receives a string.
    answer, explanation = extract_answer_and_explanation(response or "")

    row_dict = row.to_dict()
    row_dict["pred"] = answer
    row_dict["explanation"] = explanation
    row_dict["response"] = response

    # The parser yields None when the reply does not follow the
    # "Answer: ... / Explanation: ..." layout. Score that as incorrect (0)
    # instead of crashing on None.lower() and losing all collected records.
    label = str(row_dict['answer label']).strip().lower()
    row_dict["is_correct"] = int(answer is not None and answer.strip().lower() == label)
    records.append(row_dict)

results = pd.DataFrame(records)
results.head()

  0%|          | 0/43 [00:00<?, ?it/s]

Unnamed: 0,Instructions,situation,A,B,C,D,E,answer label,answer,pred,explanation,response,is_correct
0,The following questions each describe a situat...,Clara receives a gift. Clara is most likely to...,Happy,Angry,Frightened,Bored,Hungry,A,Happy,A,Receiving a gift usually evokes feelings of ha...,Answer: A \nExplanation: Receiving a gift usu...,1
1,The following questions each describe a situat...,A pleasant experience ceases unexpectedly and ...,Ashamed,Distressed,Angry,Sad,Frustrated,D,Sad,D,"When a pleasant experience ends unexpectedly, ...",Answer: D \nExplanation: When a pleasant expe...,1
2,The following questions each describe a situat...,Xavier completes a difficult task on time and ...,Surprise,Pride,Relief,Hope,Joy,B,Pride,B,Completing a difficult task on time and under ...,Answer: B \nExplanation: Completing a difficu...,1
3,The following questions each describe a situat...,An irritating neighbour of Eve's moves to anot...,Regret,Hope,Relief,Sadness,Joy,C,Relief,C,"In this situation, Eve is most likely to feel ...","Answer: C \nExplanation: In this situation, E...",1
4,The following questions each describe a situat...,There is great weather on the day Jill is goin...,Pride,Joy,Relief,Guilt,Hope,B,Joy,B,Great weather on the day of an outdoor picnic ...,Answer: B \nExplanation: Great weather on the...,1


In [32]:
# Persist the per-question predictions. The target directory is created first
# because to_csv does not create missing parent directories and would raise.
output_path = "../results/gpt4o-mini/steu-results.csv"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
results.to_csv(output_path, index=False)

In [30]:
# Overall accuracy: `is_correct` is 1 when the prediction matched the answer
# label and 0 otherwise, so its mean is the fraction of correct answers.
results["is_correct"].mean()

0.7906976744186046

### Compiling STEU results for all models

For Gemma, some cases did not produce a prediction; when nothing was predicted, `is_correct` is set to 0.

In [11]:
FILE_NAME = 'steu-results.csv'
RESULTS_MODS = '../results/'
print("STEU ACCURACY RESULTS")
# os.listdir returns entries in arbitrary order; sort for a stable report.
for model_dir in sorted(os.listdir(RESULTS_MODS)):
    results_path = os.path.join(RESULTS_MODS, model_dir, FILE_NAME)
    # Skip stray files and model directories that have no STEU results file,
    # instead of letting read_csv raise and abort the whole report.
    if not os.path.isfile(results_path):
        continue
    df = pd.read_csv(results_path)
    # Accuracy is the mean of the 0/1 `is_correct` column.
    print(f"{model_dir:<30}: {df['is_correct'].mean():<15}")

STEU ACCURACY RESULTS
gemma-2-27b-it                : 0.5813953488372093
Mixtral-8x7B-Instruct-v0.1    : 0.5581395348837209
Meta-Llama-3.1-70B-Instruct   : 0.627906976744186
gpt4o-mini                    : 0.7906976744186046
