In [None]:
# ────────────────────────────────────────────────────────────
# Cell 1: Imports & Config
# ────────────────────────────────────────────────────────────
import os, csv, random
import torch, torchaudio, whisper
import numpy as np
from tqdm.notebook import tqdm
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

# ── Filesystem locations ───────────────────────────────────
# Inputs read by the later cells: label CSV, audio directory, transcript directory.
LABELS_CSV  = "/projects/pdd/IS2025_Podcast_Challenge/Labels/labels_consensus.csv"
AUDIO_DIR   = "/projects/pdd/IS2025_Podcast_Challenge/Audios"
TRANS_DIR   = "/projects/pdd/IS2025_Podcast_Challenge/Transcripts"
# Checkpoint directories handed to `from_pretrained` in the model-loading cell.
WAVLM_DIR   = "wavlm_finetuned_bs16_lr2e-05_wd0.05_ep50_ga2_Lilit"
DEBERTA_DIR = "saved_deberta_model/DeBERTa_v3_Large_Lr1e-6_gradient6_batch16_stopEpoch9"

# ── Runtime device & fusion weights ────────────────────────
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

ALPHA = 0.6        # weight applied to the audio-branch probabilities
BETA  = 1 - ALPHA  # the remainder is applied to the text-branch probabilities

# ── Label vocabulary ───────────────────────────────────────
# One-letter emotion code <-> integer class index (position in the string).
EMO2ID = {code: idx for idx, code in enumerate("ASHUFDCNOX")}
ID2EMO = dict(enumerate("ASHUFDCNOX"))


In [None]:
# ────────────────────────────────────────────────────────────
# Cell 2: Load Models (once)
# ────────────────────────────────────────────────────────────
print("Loading models...")

# Speech-to-text model; later cells call it when a clip has no transcript file.
whisper_m = whisper.load_model("base").to(DEVICE)
whisper_m.eval()

# Audio branch: feature extractor and classifier come from the same directory.
wavlm_fe = AutoFeatureExtractor.from_pretrained(WAVLM_DIR)
wavlm_m  = AutoModelForAudioClassification.from_pretrained(WAVLM_DIR).to(DEVICE)
wavlm_m.eval()

# Text branch: tokenizer and sequence classifier come from the same directory.
tok       = AutoTokenizer.from_pretrained(DEBERTA_DIR, use_fast=False)
deberta_m = AutoModelForSequenceClassification.from_pretrained(DEBERTA_DIR).to(DEVICE)
deberta_m.eval()

# Inference only: switch off gradient tracking on every parameter.
for expert in (whisper_m, wavlm_m, deberta_m):
    expert.requires_grad_(False)

# Shared softmax over the last (class) dimension of the logits.
softmax = torch.nn.Softmax(dim=-1)
print("✅ Models ready")


In [None]:
# ────────────────────────────────────────────────────────────
# Cell 3: Define helper functions
# ────────────────────────────────────────────────────────────
@torch.inference_mode()
def audio_probs(wav_path):
    """Return the audio classifier's softmax probabilities for one audio file.

    NOTE(review): a helper with the same name is defined again in the next
    cell; whichever cell ran last is the definition in effect.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        NumPy array of class probabilities with singleton dimensions squeezed
        out (one value per classifier output for a single clip).
    """
    wav, sr = torchaudio.load(wav_path)
    # Average over the leading (channel) axis so the feature extractor always
    # receives ONE 1-D waveform. With `squeeze()` a multi-channel file stayed
    # 2-D and no longer looked like a single clip; for a one-channel file the
    # mean is numerically identical to the old squeeze.
    wav = wav.mean(dim=0)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    feats = wavlm_fe(wav.numpy(), sampling_rate=16000, return_tensors="pt")
    logits = wavlm_m(**{k: v.to(DEVICE) for k, v in feats.items()}).logits
    return softmax(logits).cpu().numpy().squeeze()

@torch.inference_mode()
def text_probs(text):
    """Return the text classifier's softmax probabilities for one transcript.

    Runs under ``torch.inference_mode()`` like the other definitions of this
    helper in the notebook, so no autograd bookkeeping is done.

    NOTE(review): a helper with the same name is defined again in the next
    cell; whichever cell ran last is the definition in effect.

    Args:
        text: Transcript string to classify (truncated by the tokenizer).

    Returns:
        NumPy array of class probabilities with singleton dimensions squeezed out.
    """
    encoded = tok(text, return_tensors="pt", truncation=True, padding=True)
    on_device = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    logits = deberta_m(**on_device).logits
    return softmax(logits).cpu().numpy().squeeze()

def fuse(p_a, p_t, alpha=None):
    """Blend audio and text probabilities with a weighted sum.

    Args:
        p_a: Audio-branch probabilities.
        p_t: Text-branch probabilities (same shape as ``p_a``).
        alpha: Optional weight for ``p_a``; ``p_t`` then gets ``1 - alpha``.
            When omitted, the module-level ``ALPHA`` and ``BETA`` are used,
            exactly as before this parameter existed.

    Returns:
        ``w_audio * p_a + w_text * p_t``.
    """
    if alpha is None:
        w_audio, w_text = ALPHA, BETA
    else:
        w_audio, w_text = alpha, 1 - alpha
    return w_audio * p_a + w_text * p_t


In [None]:
# ────────────────────────────────────────────────────────────
# Cell 3 (repeat): Define helper functions — redefines the helpers from the previous cell
# ────────────────────────────────────────────────────────────
@torch.inference_mode()
def audio_probs(wav_path):
    """Return the audio classifier's softmax probabilities for one audio file.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        NumPy array of class probabilities with singleton dimensions squeezed
        out (one value per classifier output for a single clip).
    """
    wav, sr = torchaudio.load(wav_path)
    # Collapse the leading (channel) axis by averaging so exactly one 1-D
    # waveform is passed on. `squeeze()` only did that for one-channel files;
    # a multi-channel file stayed 2-D. For one channel the result is unchanged.
    wav = wav.mean(dim=0)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    feats = wavlm_fe(wav.numpy(), sampling_rate=16000, return_tensors="pt")
    logits = wavlm_m(**{k: v.to(DEVICE) for k, v in feats.items()}).logits
    return softmax(logits).cpu().numpy().squeeze()

@torch.inference_mode()
def text_probs(text):
    """Classify one transcript and return its softmax class probabilities.

    Args:
        text: Transcript string; the tokenizer truncates it if necessary.

    Returns:
        NumPy array of probabilities with singleton dimensions squeezed out.
    """
    encoded = tok(text, return_tensors="pt", truncation=True, padding=True)
    # Move every tokenizer output to the model's device before the forward pass.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(DEVICE)
    outputs = deberta_m(**on_device)
    probs = softmax(outputs.logits)
    return probs.cpu().numpy().squeeze()

def fuse(p_a, p_t, alpha=None):
    """Blend audio and text probabilities with a weighted sum.

    Args:
        p_a: Audio-branch probabilities.
        p_t: Text-branch probabilities (same shape as ``p_a``).
        alpha: Optional weight for ``p_a``; ``p_t`` then gets ``1 - alpha``.
            When omitted, the module-level ``ALPHA`` and ``BETA`` are used,
            so existing two-argument calls behave as they always did.

    Returns:
        ``w_audio * p_a + w_text * p_t``.
    """
    if alpha is None:
        w_audio, w_text = ALPHA, BETA
    else:
        w_audio, w_text = alpha, 1 - alpha
    return w_audio * p_a + w_text * p_t


In [None]:
# ────────────────────────────────────────────────────────────
# Cell 4: Load label rows (Train+Development)
# ────────────────────────────────────────────────────────────
# Fixed seed so the shuffled order (and therefore the 3-sample demo) is the
# same after every Restart & Run All.
SHUFFLE_SEED = 42

# newline="" is the csv module's documented way to open files, so quoted
# fields that contain line breaks are parsed correctly.
with open(LABELS_CSV, newline="") as f:
    # Keep only rows whose split is Train or Development.
    rows = [
        r for r in csv.DictReader(f)
        if r["Split_Set"].strip() in ("Train", "Development")
    ]

print(f"Found {len(rows)} total samples")
# A private Random instance keeps the global `random` state untouched.
random.Random(SHUFFLE_SEED).shuffle(rows)
# (optional) you can subset for a quicker demo, e.g. rows = rows[:500]


In [None]:
# ────────────────────────────────────────────────────────────
# Cell 5: Demo on first 3 samples
# ────────────────────────────────────────────────────────────
for i, row in enumerate(rows[:3], 1):
    fn  = row["FileName"]
    wav = os.path.join(AUDIO_DIR, fn)
    # Swap only the file extension. `fn.replace(".wav", ".txt")` returned the
    # name unchanged when the extension was not exactly ".wav", so the
    # "transcript" lookup could resolve to the audio file name itself.
    txt = os.path.join(TRANS_DIR, os.path.splitext(fn)[0] + ".txt")

    # Get the transcript: prefer the stored file, otherwise run Whisper.
    if os.path.exists(txt):
        # Context manager closes the handle instead of leaking it.
        with open(txt) as fh:
            transcription = fh.read().strip()
    else:
        transcription = whisper_m.transcribe(wav)["text"]

    # Compute per-branch and fused probabilities.
    p_a = audio_probs(wav)
    p_t = text_probs(transcription)
    p_f = fuse(p_a, p_t)

    # Report (label is stripped the same way the evaluation cell does).
    print(f"\n▶️ Sample {i}: {fn}")
    print("  • True label   :", row["EmoClass"].strip())
    print("  • Audio probs  :", np.round(p_a,3))
    print("  • Text probs   :", np.round(p_t,3))
    print("  • Fused probs  :", np.round(p_f,3))
    print("  • Predicted    :", ID2EMO[int(p_f.argmax())])


In [None]:
# ────────────────────────────────────────────────────────────
# Cell 6: Full evaluation with periodic progress reports
# ────────────────────────────────────────────────────────────
correct = 0
total   = len(rows)

for idx, row in enumerate(rows, 1):
    fn  = row["FileName"]
    wav = os.path.join(AUDIO_DIR, fn)
    # Swap only the extension so a non-".wav" name cannot map back to itself.
    txt = os.path.join(TRANS_DIR, os.path.splitext(fn)[0] + ".txt")

    if os.path.exists(txt):
        # Context manager closes each transcript file; the loop may open
        # thousands of them.
        with open(txt) as fh:
            text = fh.read().strip()
    else:
        # Use `whisper_m`, the Whisper model loaded in this section's Cell 2.
        # The previous name `whisper_model` is only defined further down the
        # notebook, so a fresh kernel raised NameError on the first clip
        # without a transcript file.
        text = whisper_m.transcribe(wav)["text"]

    p_a = audio_probs(wav)
    p_t = text_probs(text)
    p_f = fuse(p_a, p_t)  # same ALPHA/BETA weighted sum as the demo cell

    pred = int(np.argmax(p_f))
    true = EMO2ID[row["EmoClass"].strip()]
    correct += int(pred == true)

    # every 100 files (or at the end) print a mini‐report
    if idx % 100 == 0 or idx == total:
        acc_so_far = correct / idx
        print(f"→ Processed {idx}/{total} files — "
              f"current accuracy: {acc_so_far:.4f}")

# final accuracy (guarded so an empty `rows` list does not divide by zero)
final_acc = correct / total if total else 0.0
print(f"\n✅ Final Weighted Fusion Accuracy = {final_acc:.4f} "
      f"({correct}/{total})")



In [None]:
# Demonstration: learned MLP fusion head applied to one randomly chosen audio file

In [None]:
# ════════════════════════════════════════════════════════════════
# Cell 1: Imports & config
# ════════════════════════════════════════════════════════════════
import os, random, torch, torchaudio, numpy as np
import whisper
from IPython.display import Audio, display
import matplotlib.pyplot as plt

from transformers import (
    AutoFeatureExtractor, AutoModelForAudioClassification,
    AutoTokenizer,       AutoModelForSequenceClassification
)

# ── Data location, checkpoints & device ──────────────────────────
# NOTE(review): WAVLM_DIR and DEBERTA_DIR are rebound here to different
# checkpoint directories than the ones set in the first config cell.
AUDIO_DIR   = "/projects/pdd/IS2025_Podcast_Challenge/Audios"
WAVLM_DIR   = "wavlm_finetuned_bs16_lr1e-06_wd0.05_ep50_ga2_Lilit"
DEBERTA_DIR = "saved_deberta_model/FT7Large_epoch_9"
FUSION_PT   = "fusion_head_10000.pt"    # state dict loaded into the fusion head
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ── Display names, indexed by predicted class id ─────────────────
EMOTIONS = [
    "A-Anger",
    "S-Sad",
    "H-Happy",
    "U-Surprise",
    "F-Fear",
    "D-Disgust",
    "C-Contempt",
    "N-Neutral",
    "O-Other",
    "X-Unknown",
]


In [None]:
# ════════════════════════════════════════════════════════════════
# Cell 2: Load models + define helpers
# ════════════════════════════════════════════════════════════════
# 1) Load the three pretrained models: Whisper, WavLM & DeBERTa
print("Loading models…")

# Speech-to-text model used to transcribe the demo clip.
whisper_model = whisper.load_model("base").to(DEVICE)
whisper_model.eval()

# Audio branch: feature extractor and classifier from the same directory.
wavlm_fe  = AutoFeatureExtractor.from_pretrained(WAVLM_DIR)
wavlm_mod = AutoModelForAudioClassification.from_pretrained(WAVLM_DIR).to(DEVICE)
wavlm_mod.eval()

# Text branch: tokenizer and classifier from the same directory.
tok     = AutoTokenizer.from_pretrained(DEBERTA_DIR, use_fast=False)
deberta = AutoModelForSequenceClassification.from_pretrained(DEBERTA_DIR).to(DEVICE)
deberta.eval()

# Inference only: switch off gradient tracking on every parameter.
for expert in (whisper_model, wavlm_mod, deberta):
    expert.requires_grad_(False)
print("  ✓ experts ready")

# 2) FusionHead definition + load weights
import torch.nn as nn
class FusionHead(nn.Module):
    """Small MLP (Linear → ReLU → Linear) that turns a fused input vector into logits.

    Args:
        hidden: Width of the hidden layer.
        out_dim: Number of output logits.
        in_dim: Length of the input vector. Defaults to 20, the value that
            used to be hard-coded, so ``FusionHead()`` builds the same layers
            and existing checkpoints load unchanged.
    """

    def __init__(self, hidden: int = 32, out_dim: int = 10, in_dim: int = 20):
        super().__init__()
        # Layer order and the attribute name `net` are kept as they were so
        # state-dict keys ("net.0.*", "net.2.*") stay the same.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        """Return logits for input ``x`` whose last dimension is ``in_dim``."""
        return self.net(x)

# Build the fusion head on the target device, then restore its weights.
fusion = FusionHead()
fusion = fusion.to(DEVICE)
# NOTE(review): torch.load unpickles the file — only load FUSION_PT from a
# trusted source (consider weights_only=True where the installed PyTorch
# supports it).
fusion_state = torch.load(FUSION_PT, map_location=DEVICE)
fusion.load_state_dict(fusion_state)
fusion.eval()
print("  ✓ fusion head ready\n")

# Shared softmax over the last (class) dimension of the logits.
softmax = torch.nn.Softmax(dim=-1)

# 3) Probability helpers
@torch.inference_mode()
def audio_probs(path):
    """Return the audio classifier's softmax probabilities for one audio file.

    Args:
        path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        NumPy array of class probabilities with singleton dimensions squeezed
        out (one value per classifier output for a single clip).
    """
    wav, sr = torchaudio.load(path)
    # Average over the leading (channel) axis so one 1-D waveform is always
    # passed on. `squeeze()` achieved that only for one-channel files and left
    # multi-channel audio 2-D; for one channel the values are unchanged.
    wav = wav.mean(dim=0)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    feats  = wavlm_fe(wav.numpy(), sampling_rate=16000,
                      return_tensors="pt")
    logits = wavlm_mod(**{k: v.to(DEVICE) for k, v in feats.items()}).logits
    return softmax(logits).cpu().squeeze().numpy()

@torch.inference_mode()
def text_probs(text):
    """Classify one transcript and return its softmax class probabilities.

    Args:
        text: Transcript string; the tokenizer truncates it if necessary.

    Returns:
        NumPy array of probabilities with singleton dimensions squeezed out.
    """
    encoded = tok(text, return_tensors="pt", truncation=True, padding=True)
    # Move every tokenizer output to the model's device before the forward pass.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(DEVICE)
    outputs = deberta(**on_device)
    probs = softmax(outputs.logits)
    return probs.cpu().squeeze().numpy()


In [None]:
# ════════════════════════════════════════════════════════════════
# Cell 3: Demonstration on a random file
# ════════════════════════════════════════════════════════════════
def analyse_random():
    """Run the full demo on one randomly chosen ``.wav`` file from AUDIO_DIR.

    Plays the clip, plots its first channel, then prints the transcript and
    the top class plus probabilities from WavLM, DeBERTa and the fusion head.
    Prints a message and returns early when the directory has no ``.wav`` files.
    """
    # --- choose a file -----------------------------------------------------
    candidates = []
    for name in os.listdir(AUDIO_DIR):
        if name.endswith(".wav"):
            candidates.append(name)
    if not candidates:
        print("No .wav files found!")
        return
    fn = random.choice(candidates)
    path = os.path.join(AUDIO_DIR, fn)

    # --- playback widget + waveform plot (at 16 kHz) -----------------------
    wav, sr = torchaudio.load(path)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
        sr = 16000

    print(f"\n🎧 Listening to: {fn}")
    display(Audio(wav.numpy(), rate=sr))
    plt.figure(figsize=(10,3))
    plt.plot(wav[0].numpy())
    plt.title("Waveform")
    plt.show()

    # --- per-branch probabilities ------------------------------------------
    p_a = audio_probs(path)
    transcript = whisper_model.transcribe(path)["text"]
    p_t = text_probs(transcript)

    # --- learned fusion: concatenate both vectors and feed the MLP ---------
    with torch.inference_mode():
        stacked = np.concatenate([p_a, p_t])
        x = torch.tensor(stacked, dtype=torch.float32)
        x = x.unsqueeze(0).to(DEVICE)
        fused_logits = fusion(x)
        p_fused = softmax(fused_logits).cpu().squeeze().numpy()

    # --- report -------------------------------------------------------------
    print(f"Transcript : {transcript[:200]}{'…' if len(transcript)>200 else ''}\n")
    print(f"WavLM   → {EMOTIONS[p_a.argmax()]}   |  probs: {np.round(p_a,4)}")
    print(f"DeBERTa → {EMOTIONS[p_t.argmax()]}   |  probs: {np.round(p_t,4)}")
    print(f"FUSED   → {EMOTIONS[p_fused.argmax()]}   |  probs: {np.round(p_fused,4)}")

# Run the demo once when this cell executes; the file is chosen with
# random.choice, so repeated runs may analyse different clips.
analyse_random()
