In [1]:
import os
from pathlib import Path

# When the kernel is started inside notebooks/, hop up one level so the
# relative paths used later (e.g. "EmoBox/data/", "./") resolve from the
# repository root instead.
cwd = Path.cwd()
launched_from_notebooks = cwd.name == "notebooks"
if launched_from_notebooks:
    os.chdir(cwd.parent)

#### Load Dataset

In [2]:
from EmoBox.EmoBox import EmoDataset, EmoEval

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run configuration. Supported model names:
# qwen2-audio-instruct, audio-flamingo-3, voxtral-mini, salmonn-7b
model_name = "salmonn-7b"
dataset = "meld"
fold = 1  # the number of folds differs per dataset; see EmoBox/data/<dataset>/fold_*
user_data_dir = "./"
meta_data_dir = "EmoBox/data/"

# Audio representation requested from EmoDataset for each model; models not
# listed here receive raw arrays.
_AUDIO_FORMAT_BY_MODEL = {
    "voxtral-mini": "base64",
    "salmonn-7b": "path",
}
audio_format = _AUDIO_FORMAT_BY_MODEL.get(model_name, "array")

# Build both splits with identical dataset/fold/audio settings; only `split`
# differs. The generator is consumed in order, so "train" is loaded first.
train, test = (
    EmoDataset(
        dataset,
        user_data_dir,
        meta_data_dir,
        fold=fold,
        split=split_name,
        audio_format=audio_format,
    )
    for split_name in ("train", "test")
)

using official valid data in EmoBox/data/meld/fold_1/meld_valid_fold_1.jsonl
load in 9988 samples, only 9988 exists in data dir EmoBox/data/
load in 2610 samples, only 2156 exists in data dir EmoBox/data/
load in 1108 samples, only 952 exists in data dir EmoBox/data/
Num. training samples 9988
Num. valid samples 952
Num. test samples 2156
Using label_map {'neutral': 'Neutral', 'joy': 'Happy', 'sadness': 'Sad', 'surprise': 'Surprise', 'disgust': 'Disgust', 'anger': 'Angry', 'fear': 'Fear'}
using official valid data in EmoBox/data/meld/fold_1/meld_valid_fold_1.jsonl
load in 9988 samples, only 9988 exists in data dir EmoBox/data/
load in 2610 samples, only 2156 exists in data dir EmoBox/data/
load in 1108 samples, only 952 exists in data dir EmoBox/data/
Num. training samples 9988
Num. valid samples 952
Num. test samples 2156
Using label_map {'neutral': 'Neutral', 'joy': 'Happy', 'sadness': 'Sad', 'surprise': 'Surprise', 'disgust': 'Disgust', 'anger': 'Angry', 'fear': 'Fear'}


In [4]:
# Peek at one test item. Judging by the displayed output, items are dicts with
# key, audio (a file path when audio_format == "path"), label and speaker metadata.
sample = test[0]
sample

{'key': 'meld-dia0_utt1-test',
 'audio': 'downloads/meld/output_repeated_splits_test/dia0_utt1.mp4',
 'label': 'Angry',
 'gender': 'Female',
 'age': 27,
 'language': 'English'}

In [5]:
# Canonical class names that the raw dataset labels are mapped onto.
test.label_map.values()

dict_values(['Neutral', 'Happy', 'Sad', 'Surprise', 'Disgust', 'Angry', 'Fear'])

In [6]:
from collections import Counter
# Ground-truth class distribution of the test split. A skewed split means
# macro and weighted metrics can differ substantially.
labels = [item["label"] for item in test]
Counter(labels)

Counter({'Neutral': 994,
         'Happy': 337,
         'Angry': 303,
         'Surprise': 238,
         'Sad': 173,
         'Disgust': 65,
         'Fear': 46})

#### Load Model

In [7]:
import torch
from mllm_emotion_classifier.models import ModelFactory

device = "cuda" if torch.cuda.is_available() else "cpu"

# Decoding below is stochastic (do_sample=True), so seed torch's RNG to make
# predictions repeatable across Restart & Run All. torch.manual_seed seeds the
# RNG on all devices. CUDA kernels may still add some nondeterminism.
SEED = 42
torch.manual_seed(SEED)

model = ModelFactory.create(
    name=model_name,
    do_sample=True,
    temperature=1.0,
    top_p=0.9,
    prompt_name="user_labels",
    # set() de-duplicates labels that several raw labels map onto.
    # NOTE(review): a set of str has no stable iteration order across
    # interpreter runs (hash randomisation). If the prompt lists labels in
    # iteration order, the prompt text can change between runs. Confirm how
    # ModelFactory consumes class_labels.
    class_labels=set(train.label_map.values()),
    device=device,
    # Optional extra kwargs, currently disabled:
    # low_resource=True,
    # lora_alpha=28,
)

  beats_checkpoint = torch.load(self.beats_ckpt, map_location='cpu')
  WeightNorm.apply(module, name, dim)
  return torch.load(checkpoint_file, map_location="cpu")
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████| 2/2 [00:11<00:00,  5.74s/it]
  ckpt_dict = torch.load(ckpt)['model']


In [8]:
# Test-split loader. shuffle defaults to False, so batches follow dataset order.
data_loader = torch.utils.data.DataLoader(
    dataset=test,
    batch_size=4,
    num_workers=4,
    # Page-locked host memory only speeds up host->GPU copies. On a CPU-only
    # machine PyTorch ignores it (warning in recent versions), so tie it to the
    # device chosen when the model was created.
    pin_memory=(device == "cuda"),
    drop_last=False,  # keep the final partial batch so every test item is scored
    collate_fn=model.collate_fn,
)

In [9]:
# Pull a single collated batch to sanity-check what model.collate_fn produces
# (e.g. inspect inputs.keys()); the label half of the pair is discarded.
batch = (inputs, _) = next(iter(data_loader))

In [10]:
# inputs['attention_mask'].shape

In [11]:
from itertools import islice

from tqdm.auto import tqdm

# Quick smoke run over the first MAX_BATCHES batches. Set MAX_BATCHES = None to
# score the whole test loader. The Evaluator cell performs its own inference pass.
MAX_BATCHES = 100
num_batches = (
    len(data_loader) if MAX_BATCHES is None else min(MAX_BATCHES, len(data_loader))
)

predictions, labels = [], []
# Bounding the iterator with islice keeps tqdm's total equal to the number of
# batches actually processed.
for inputs, lbl in tqdm(islice(data_loader, num_batches), total=num_batches):
    # The collated batch is passed as-is. Device placement is presumably
    # handled inside model.predict; confirm.
    preds = model.predict(inputs)
    predictions.extend(preds)
    labels.extend(lbl)

  0%|                                                                                                                  | 0/539 [00:00<?, ?it/s]

  8%|████████▌                                                                                                | 44/539 [03:35<40:19,  4.89s/it]


KeyboardInterrupt: 

#### Evaluation

In [12]:
from mllm_emotion_classifier.evaluate import Evaluator

# Run the library evaluator over the test loader; it performs its own
# inference pass and stores the metrics on evaluator.results.
# NOTE(review): the saved output of this cell reports dataset "iemocap" and
# model "salmonn", which does not match the configuration (meld / salmonn-7b).
# It is stale output from a different run; re-run before trusting the numbers.
# NOTE(review): the warning printed by this cell shows Evaluator filtering with
# `p is not "Unknown"`, an identity test on a str literal. It likely should be
# `p != "Unknown"`, otherwise "Unknown" predictions may be counted as valid.
evaluator = Evaluator()
evaluator.evaluate(
    model,
    data_loader,
)

  valid_indices = [i for i, p in enumerate(self.y_pred) if p is not "Unknown"]



Evaluating salmonn on iemocap


Inference: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [13:07<00:00,  2.89s/it]


{'timestamp': '2026-01-25 14:00:07',
 'dataset': 'iemocap',
 'model_name': 'salmonn',
 'fold': None,
 'num_samples': 1085,
 'valid_rate': 1.0,
 'class_labels': ['Happy', 'Angry', 'Sad', 'Neutral'],
 'metrics': {'global': {'f1_macro': 0.151,
   'f1_weighted': 0.1431,
   'accuracy_unweighted': 0.2749,
   'accuracy_weighted': 0.2756,
   'precision_macro': 0.5031,
   'precision_weighted': 0.4097,
   'recall_macro': 0.2749,
   'recall_weighted': 0.2756},
  'classwise': {'accuracy': {'Angry': 0.7982,
    'Happy': 0.2793,
    'Neutral': 0.6452,
    'Sad': 0.8286},
   'false_positive_rate': {'Angry': 0.0012,
    'Happy': 0.969,
    'Neutral': 0.0014,
    'Sad': 0.0022},
   'false_negative_rate': {'Angry': 0.952,
    'Happy': 0.0,
    'Neutral': 1.0,
    'Sad': 0.9485},
   'true_positive_rate': {'Angry': 0.048,
    'Happy': 1.0,
    'Neutral': 0.0,
    'Sad': 0.0515},
   'true_negative_rate': {'Angry': 0.9988,
    'Happy': 0.031,
    'Neutral': 0.9986,
    'Sad': 0.9978},
   'positive_predictiv

In [13]:
# Headline (global) metrics from the last evaluate() call.
evaluator.results["metrics"]["global"]

{'f1_macro': 0.151,
 'f1_weighted': 0.1431,
 'accuracy_unweighted': 0.2749,
 'accuracy_weighted': 0.2756,
 'precision_macro': 0.5031,
 'precision_weighted': 0.4097,
 'recall_macro': 0.2749,
 'recall_weighted': 0.2756}