In [55]:
import os
from pathlib import Path

# Remember where the kernel was started; if that is the `notebooks` folder,
# step up one level so the relative paths used further down are resolved
# against its parent directory.
cwd = Path.cwd()
_started_in_notebooks = cwd.name == "notebooks"
if _started_in_notebooks:
    os.chdir(cwd.parent)

In [56]:
import pandas as pd
from mllm_emotion_classifier.utils import add_fairness_metrics_to_df
from EmoBox.EmoBox import EmoDataset

In [57]:
# Dataset name -> list of sensitive attributes analysed for that dataset.
# The attribute names are used verbatim as column-name prefixes
# (f"{attr}_{metric}") when selecting fairness columns from the results.
sensitive_attr_dict = {
    "iemocap": ["gender"],
    "cremad": ["gender", "age", "ethnicity", "race"],
    "emovdb": ["gender"],
    "tess": ["agegroup"],
    "ravdess": ["gender"],
    "esd": ["gender"],
    "meld": ["gender"],
}

In [58]:
# --- Experiment selection ---------------------------------------------------
# Which sweep to analyse; this picks both the results directory and the
# column the runs are grouped by further down.
hparam = 'temperature' # or 'top_p'
assert hparam in ['temperature', 'top_p'], "hparam must be either 'temperature' or 'top_p'"

dataset = 'emovdb' # must be a key of sensitive_attr_dict
fold = None # Set to an integer fold number if needed, else None to aggregate all folds
sensitive_attrs = sensitive_attr_dict[dataset] # e.g. gender, age, ethnicity, race
model = 'salmonn-7b' # qwen2-audio-instruct, audio-flamingo-3, voxtral-mini, salmonn-7b

# --- Paths --------------------------------------------------------------------
metadata_dir = Path('EmoBox/data/')
dataset_path = metadata_dir / dataset
# The number of folds is inferred from the `fold_*` sub-directories of the dataset.
n_folds = len([d for d in dataset_path.iterdir() if d.is_dir() and d.name.startswith("fold_")])
out_dir = Path('outputs-2') / ("temperature_runs" if hparam == 'temperature' else "topp_runs")
results_dir = out_dir / model / dataset

# --- Label set ----------------------------------------------------------------
# NOTE(review): the test split is always loaded from fold 1, even when `fold`
# selects a different fold. In this cell it is only used to read `label_map`;
# confirm that the label map is identical across folds.
test = EmoDataset(dataset, './', metadata_dir, fold=1, split="test")
emotions = set(test.label_map.values())

# --- Load per-fold result CSVs ------------------------------------------------
# fold=None -> every fold 1..n_folds is read and stacked; otherwise only `fold`.
fold_ids = list(range(1, n_folds + 1)) if fold is None else [fold]
if len(fold_ids) == 0:
    # Without this guard pd.concat([]) fails with the much less helpful
    # "No objects to concatenate".
    raise FileNotFoundError(f"No 'fold_*' directories found under {dataset_path}")

dfs = []
for f in fold_ids:
    results_csv = results_dir / f'fold_{f}.csv'
    dfs.append(pd.read_csv(results_csv))
df = pd.concat(dfs, ignore_index=True)

print(len(df), "rows")
df.head(5)

since there is no official valid data, use random split for train valid split, with a ratio of [80, 20]
load in 5168 samples, only 5168 exists in data dir EmoBox/data
load in 1719 samples, only 1719 exists in data dir EmoBox/data
Num. training samples 5168
Num. valid samples 0
Num. test samples 1719
Using label_map {'Amused': 'Amused', 'Sleepy': 'Sleepy', 'Angry': 'Angry', 'Disgust': 'Disgust', 'Neutral': 'Neutral'}
10 rows


Unnamed: 0,run,dataset,fold,model,prompt,temperature,valid_rate,global_f1_macro,global_f1_weighted,global_accuracy_unweighted,...,language_Neutral_statistical_parity,language_Sleepy_statistical_parity,language_statistical_parity,language_Amused_equal_opportunity,language_Angry_equal_opportunity,language_Disgust_equal_opportunity,language_Neutral_equal_opportunity,language_Sleepy_equal_opportunity,language_equal_opportunity,language_overall_accuracy_equality
0,0,emovdb,1,salmonn-7b,user_labels,1.0,1.0,0.2426,0.2598,0.2687,...,0.0,0.0,0.0,,,,,,,
1,1,emovdb,1,salmonn-7b,user_labels,1.0,1.0,0.1864,0.1895,0.2608,...,0.0,0.0,0.0,,,,,,,
2,2,emovdb,1,salmonn-7b,user_labels,1.0,1.0,0.0776,0.0643,0.206,...,0.0,0.0,0.0,,,,,,,
3,3,emovdb,1,salmonn-7b,user_labels,1.0,1.0,0.0951,0.0825,0.2104,...,0.0,0.0,0.0,,,,,,,
4,4,emovdb,1,salmonn-7b,user_labels,1.0,1.0,0.189,0.1934,0.2701,...,0.0,0.0,0.0,,,,,,,


In [59]:
# Look at the first item of the test split to see which fields a sample exposes.
test[0]

{'key': 'emovdb-sam-Amused-0384',
 'audio': array([-0.00119863,  0.00034247,  0.0015411 , ...,  0.00291096,
         0.00308219,  0.00308219], shape=(131361,), dtype=float32),
 'label': 'Amused',
 'gender': 'Male',
 'language': 'English'}

In [60]:
# Columns of interest: the swept hyper-parameter, two global scores, and the
# three fairness metrics for every sensitive attribute of the dataset.
fairness_metrics = ['statistical_parity', 'equal_opportunity', 'overall_accuracy_equality']
metric_cols = ['global_f1_macro', 'global_accuracy_unweighted']
for attr in sensitive_attrs:
    metric_cols.extend(f"{attr}_{metric}" for metric in fairness_metrics)
cols = [hparam] + metric_cols

# One row per hyper-parameter value, with (column, 'mean') / (column, 'std')
# pairs as columns.
grouped_stats = df[cols].groupby([hparam]).agg(['mean', 'std']).reset_index()


def _as_percent_text(values):
    """Scale a Series to percent, round to 2 decimals and render it as strings."""
    return (values * 100).round(2).astype(str)


# Human-readable table: every metric becomes a "mean ± std" string (in percent).
grouped = grouped_stats[[hparam]].copy()
for col in metric_cols:
    mean_vals = _as_percent_text(grouped_stats[(col, 'mean')])
    std_vals = _as_percent_text(grouped_stats[(col, 'std')])
    grouped[col] = mean_vals + ' ± ' + std_vals

# The best setting is chosen on the numeric mean macro-F1, not on the
# formatted strings.
best_idx = grouped_stats[('global_f1_macro', 'mean')].idxmax()
best_row = grouped.loc[best_idx]

display(best_row)

temperature                                    1.0
global_f1_macro                       15.69 ± 6.24
global_accuracy_unweighted            24.22 ± 3.14
gender_statistical_parity              1.33 ± 1.07
gender_equal_opportunity               4.59 ± 4.28
gender_overall_accuracy_equality       7.89 ± 6.27
Name: 0, dtype: object