# Evaluation Method

In [14]:
from enum import Enum
import json
import os
from typing import Optional, Dict, Any, List
import yaml

import outlines
from outlines import models, generate
import pandas as pd
import pprint
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm


In [6]:
class RelationshipStatus(str, Enum):
    na = "na"
    low = "low"
    medium = "medium"
    high = "high"
    
class EmotionLabel(BaseModel):
    joy: RelationshipStatus
    trust: RelationshipStatus
    fear: RelationshipStatus
    surprise: RelationshipStatus
    sadness: RelationshipStatus
    disgust: RelationshipStatus
    anger: RelationshipStatus
    anticipation: RelationshipStatus
    
    # class Config:
    #     extra = Extra.forbid
    #     use_enum_values = True
        
class EntryResult(BaseModel):
    emotion: EmotionLabel
    reason: str

In [7]:
prompt_name = "try1"
with open(f"prompts/{prompt_name}.yaml", "r") as f:
    prompt_dict = yaml.load(f, Loader=yaml.FullLoader)
SYSTEM_MESSAGE = prompt_dict['system']
USER_TEMPLATE = prompt_dict['user']

# Load Data

In [8]:
llm_model = "gpt-4.1-mini-2025-04-14"
df = pd.read_csv(f"data/comet/test_{llm_model}.tsv", sep="\t")
df = df.sample(10)
print(df.shape, df.columns)

(10, 16) Index(['uid', 'original_idx', 'original_src', 'original_relation',
       'original_tgt', 'source', 'character', 'joy', 'trust', 'fear',
       'surprise', 'sadness', 'disgust', 'anger', 'anticipation', 'reason'],
      dtype='object')


In [10]:
def make_messages(row):
    user_message = USER_TEMPLATE.format(
        source=row['source'],
        character=row['character']
    )
    assistant_message = json.dumps(
        {
            "emotion": {
                "joy": row['joy'],
                "trust": row['trust'],
                "fear": row['fear'],
                "surprise": row['surprise'],
                "sadness": row['sadness'],
                "disgust": row['disgust'],
                "anger": row['anger'],
                "anticipation": row['anticipation']
            },
            "reason": row['reason']
        }
    )

    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_message}
    ]
    return messages

In [12]:
all_messages = [
    make_messages(df.iloc[i]) for i in range(df.shape[0])
]

# Load Model

In [16]:
## Load Pretrained Model
run_name = "250421-01-qwen2_5-3b-mini-try1"
model_dir = f"weights/{run_name}/best"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16
)
model.eval()
print("BASE MODEL LOADED")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

BASE MODEL LOADED


In [17]:
outlines_model = models.Transformers(model, tokenizer,)
generator = outlines.generate.json(outlines_model, EntryResult)

# Predict

In [None]:
batch_size = 4
predictions = []
for i in tqdm(range(0, len(all_messages), batch_size)):
    batch = all_messages[i:i+batch_size]
    text = tokenizer.apply_chat_template(
        batch,
        tokenize=False,
        add_generation_prompt=True,
    )
    prediction = generator(text)
    predictions.extend(prediction)

100%|██████████| 3/3 [05:04<00:00, 101.62s/it]


In [21]:
predictions

[EntryResult(emotion=EmotionLabel(joy=<RelationshipStatus.na: 'na'>, trust=<RelationshipStatus.na: 'na'>, fear=<RelationshipStatus.na: 'na'>, surprise=<RelationshipStatus.medium: 'medium'>, sadness=<RelationshipStatus.low: 'low'>, disgust=<RelationshipStatus.na: 'na'>, anger=<RelationshipStatus.high: 'high'>, anticipation=<RelationshipStatus.na: 'na'>), reason='Rachel feels angry because confronting the boyfriend revealed the extent of his disrespect towards her friend.'),
 EntryResult(emotion=EmotionLabel(joy=<RelationshipStatus.na: 'na'>, trust=<RelationshipStatus.na: 'na'>, fear=<RelationshipStatus.na: 'na'>, surprise=<RelationshipStatus.na: 'na'>, sadness=<RelationshipStatus.na: 'na'>, disgust=<RelationshipStatus.na: 'na'>, anger=<RelationshipStatus.na: 'na'>, anticipation=<RelationshipStatus.na: 'na'>), reason='Other children show no strong emotional reaction to seeing the animals because it is part of their routine day school activity.'),
 EntryResult(emotion=EmotionLabel(joy=<Re

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

In [None]:
labels = ["na", "low", "medium", "high"]
emotion_cols = ["joy", "trust", "fear", "surprise", "sadness", "disgust", "anger", "anticipation"]

reports = {}
for col in emotion_cols:
    y_true = df[col].tolist()
    y_pred = [getattr(pred.emotion, col).value for pred in predictions]
    reports[col] = classification_report(
        y_true, 
        y_pred,
        labels=labels,
        output_dict=True,
        zero_division=0
    )
reports

In [None]:
# labels = ["na", "low", "medium", "high"]
# for column in ["joy", "trust", "fear", "surprise", "sadness", "disgust", "anger", "anticipation"]:
#     y_true = df[column].tolist()
#     print(df[column].unique())
#     y_pred = [getattr(pred.emotion, column).value for pred in predictions]
#     print(f"Classification report for {column}:")
#     print(classification_report(y_true, y_pred, labels=labels))

['na' 'low' 'medium' 'high']
{'low', 'medium', 'na', 'high'}
Classification report for joy:
              precision    recall  f1-score   support

          na       0.80      1.00      0.89         4
         low       1.00      0.50      0.67         2
      medium       1.00      0.67      0.80         3
        high       0.50      1.00      0.67         1

    accuracy                           0.80        10
   macro avg       0.82      0.79      0.76        10
weighted avg       0.87      0.80      0.80        10

['na' 'medium' 'high']
{'medium', 'na', 'high'}
Classification report for trust:
              precision    recall  f1-score   support

          na       1.00      1.00      1.00         6
         low       0.00      0.00      0.00         0
      medium       1.00      1.00      1.00         2
        high       1.00      1.00      1.00         2

    accuracy                           1.00        10
   macro avg       0.75      0.75      0.75        10
weighted avg

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize