# Load Fine-Tuned Model

In [None]:
import nbimporter
from finetune_mllm import EmpatheticMLLM

In [None]:
import torch

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = EmpatheticMLLM()
# map_location remaps stored tensors onto `device`, so a checkpoint that was
# saved from a GPU run still loads on a CPU-only machine.
state_dict = torch.load("model.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Switch all submodules to inference mode (e.g. dropout disabled).
model.eval()

# Generate Responses for Test Set

In [None]:
from finetune_mllm import prepare_split, MultimodalMELD
from torch.utils.data import DataLoader

# NOTE(review): this loads the 'dev' split even though the variables are named
# test_*; confirm which split is intended for the reported metrics.
test_data = prepare_split('dev')
test_dataset = MultimodalMELD(test_data)
# Evaluation does not need shuffling. A fixed order keeps generated outputs
# comparable from run to run.
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
from tqdm import tqdm
from transformers import GenerationConfig

# Sampling settings are identical for every batch, so build the config once.
generation_config = GenerationConfig(
    max_new_tokens=100,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)

# do_sample=True makes generation random. Fix the seed so the metrics
# computed from these responses are reproducible across runs.
torch.manual_seed(0)

# Flat lists with one entry per example; generated_responses[i] pairs with
# target_responses[i].
generated_responses = []
target_responses = []
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # NOTE(review): inputs are passed exactly as collated by the
        # DataLoader (not moved to `device`); presumably
        # EmpatheticMLLM.generate handles device placement - confirm.
        x = {
            'text': batch['text'],
            'audio': batch['audio'],
            'video': batch['video'],
        }
        # Assumes generate returns one decoded string per example in the
        # batch - verify against EmpatheticMLLM.generate.
        responses = model.generate(x, generation_config)
        # extend (not append) so each list holds strings rather than
        # per-batch lists. BERTScore, PPL and Dist-n all expect flat lists.
        generated_responses.extend(responses)
        target_responses.extend(batch['target_response'])

if len(generated_responses) != len(target_responses):
    raise RuntimeError(
        f"Got {len(generated_responses)} generated responses for "
        f"{len(target_responses)} targets; they must pair one-to-one."
    )

# Calculate BERTScore

In [None]:
from bert_score import score

# Score each generated response against its reference with BERTScore.
# Expects two equal-length lists of strings; P, R and F1 are per-pair
# tensors, averaged below for reporting.
P, R, F1 = score(generated_responses, target_responses, lang="en")

for label, values in (("PBERT", P), ("RBERT", R), ("FBERT", F1)):
    print(f"{label}: {values.mean():.4f}")

# Calculate Perplexity

In [None]:
# Calculate the fine-tuned LLM's perplexity on the ground-truth target
# responses, reported as the mean of the per-response perplexities.
import math

from transformers import GPT2Tokenizer

# NOTE(review): this assumes model.llm uses the GPT-2 vocabulary. If the
# underlying LLM has a different tokenizer, load that one instead, or the
# token ids fed to model.llm are meaningless.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

ppl_list = []
for response in target_responses:
    input_ids = tokenizer.encode(response, return_tensors='pt').to(device)
    # Presumably model.llm is a causal LM that shifts labels internally. With
    # fewer than two tokens there is nothing to predict and the loss is NaN,
    # so skip such responses.
    if input_ids.size(1) < 2:
        continue
    with torch.no_grad():
        outputs = model.llm(input_ids, labels=input_ids)
    loss = outputs.loss
    ppl = math.exp(loss.item())
    ppl_list.append(ppl)

if ppl_list:
    print(f"PPL: {sum(ppl_list) / len(ppl_list):.4f}")
else:
    print("PPL: n/a (no target response had at least two tokens)")

# Calculate Diversity

In [None]:
from nltk import ngrams

def compute_dist_n(responses, n):
    """Return the Dist-n diversity score of *responses*.

    Dist-n is the number of distinct n-grams divided by the total number of
    n-grams, pooled over all responses. Responses are tokenized on
    whitespace (``str.split``), and an n-gram never spans two responses.

    Args:
        responses: Iterable of response strings.
        n: N-gram order; must be >= 1.

    Returns:
        A float in [0, 1]. It is 0.0 when there are no n-grams at all (empty
        input, or every response shorter than ``n`` tokens).

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    all_ngrams = []
    for response in responses:
        tokens = response.split()
        # Sliding window of width n: zip the tokens against their n-1 shifted
        # copies. This yields the same tuples as nltk.ngrams(tokens, n).
        all_ngrams.extend(zip(*(tokens[i:] for i in range(n))))
    total = len(all_ngrams)
    if total == 0:
        return 0.0
    return len(set(all_ngrams)) / total

# Unigram and bigram diversity of the generated responses: compute both
# scores first, then report them.
dist1, dist2 = (compute_dist_n(generated_responses, order) for order in (1, 2))

for order, value in ((1, dist1), (2, dist2)):
    print(f"Dist-{order}: {value:.4f}")