In [None]:
import sys
import os

from __future__ import annotations
from typing import Dict, List, Tuple, Union, Optional
import os, sys, warnings, copy, random
from collections import 
from pathlib import Path
from tqdm.auto import tqdm
from functools import lru_cache
from itertools import combinations

import json
import yaml
from omegaconf import OmegaConf, DictConfig

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torchaudio
from scipy.signal import medfilt


# ignore torchaudio warnings
warnings.filterwarnings("ignore", message=".*torchaudio.load_with_torchcodec.*")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
warnings.filterwarnings("ignore", message=".*deprecated.*")

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to `random`, `numpy.random` and `torch`
            (and to the CUDA generators when CUDA is available).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    # Seed the current CUDA device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

sys.path.append(os.path.abspath(".."))

from personal_VAD.utils import load_model as load_pvad_model
from word_count_estimator.utils import load_model as load_wce_model

In [None]:
#@title utils
SEED = 42
DEFAULT_SR = 16000


def second2sample(second, sr=DEFAULT_SR):
    """Convert a time in seconds to a sample index at rate `sr` (truncated by int())."""
    n_samples = second * sr
    return int(n_samples)


def sample2second(sample, sr=DEFAULT_SR):
    """Convert a sample index to seconds at rate `sr`.

    The result goes through int(), so any fractional part of a second is dropped.
    """
    elapsed = sample / sr
    return int(elapsed)


# Short aliases for the two converters.
s2sam = second2sample
sam2s = sample2second

def mono_resample(wav: torch.Tensor, sr: int=16000, target_sr: int = DEFAULT_SR) -> torch.Tensor:
    """
    Down-mix a waveform to one channel and resample it when the rates differ.

    [C,T] or [T] -> [1,T]

    Args:
        wav: Waveform tensor; a 1-D input gets a leading channel axis, a 2-D
            input with more than one row is averaged over its first axis.
        sr: Sample rate of `wav`.
        target_sr: Desired sample rate.

    Returns:
        The (possibly resampled) waveform.
    """
    n_dims = wav.dim()
    if n_dims == 1:
        wav = wav.unsqueeze(0)
    elif n_dims == 2 and wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr == target_sr:
        return wav
    return torchaudio.functional.resample(wav, sr, target_sr)

@lru_cache(maxsize=40)
def _load_full_wav(paths):
    """Decode every file in `paths` (a tuple of str) and concatenate them in time.

    Only this deterministic step is cached; cropping happens in `load_wav`.
    """
    wavs = [mono_resample(*torchaudio.load(p)) for p in paths]
    return torch.cat(wavs, dim=1)


def load_wav(paths, s=None, e=None, dur=None):
    """Load one or more audio files as a single mono waveform, optionally cropped.

    Args:
        paths: A single path (str or path-like) or an iterable of paths
            (list or tuple). Multiple files are concatenated along time.
        s: Optional crop start in seconds.
        e: Optional crop end in seconds; only used when `s` is given.
            When `s` is given and `e` is None the crop runs to the end.
        dur: Optional length in seconds of a random crop taken after the
            `s`/`e` crop. A fresh offset is drawn from `random` on every call.

    Returns:
        Tensor of shape [1, T]. When no crop is requested this is the cached
        tensor itself (crops are views of it), so treat the result as read-only.

    Raises:
        ValueError: If `dur` is longer than the available audio.
    """
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]
    # Normalise to a hashable tuple of str so lists are accepted by the cache too.
    wav = _load_full_wav(tuple(os.fspath(p) for p in paths))

    if s is not None:
        end = None if e is None else s2sam(e)
        wav = wav[:, s2sam(s):end]

    if dur is not None:
        n_crop = s2sam(dur)
        max_offset = wav.size(1) - n_crop
        if max_offset < 0:
            raise ValueError(
                f"Requested a {dur}s crop but only {wav.size(1)} samples are available"
            )
        # Drawn outside the cache so repeated calls really get different crops.
        offset = random.randint(0, max_offset)
        wav = wav[:, offset:offset + n_crop]
    return wav


# Keep the cache-management hooks that the previously decorated function exposed.
load_wav.cache_clear = _load_full_wav.cache_clear
load_wav.cache_info = _load_full_wav.cache_info

class Extractor:
    """Log-mel spectrogram front end with optional mean/variance normalisation.

    Wraps `torchaudio.transforms.MelSpectrogram`; the transform is created on
    the default (CPU) device and is not moved afterwards.
    """

    def __init__(self, apply_cmvn=True, no_batch_time_first=False, **feature_args):
        """
        Args:
            apply_cmvn: Normalise each mel bin to zero mean / unit std over time.
            no_batch_time_first: Return (T, F) instead of the transform's own layout.
            **feature_args: Overrides for the MelSpectrogram arguments below.
        """
        basic_args = {'sample_rate':16000, 'n_fft':400, 'n_mels':24, 'win_length':400, 'hop_length':160}
        basic_args.update(feature_args)
        # Remembered so that audio loaded from disk can be brought to this rate.
        self.sample_rate = basic_args['sample_rate']
        self.extractor = torchaudio.transforms.MelSpectrogram(**basic_args)
        self.apply_cmvn = apply_cmvn
        self.no_batch_time_first = no_batch_time_first

    def __call__(self, wav):
        """Alias for `extract`."""
        return self.extract(wav)

    def extract(self, wav: Union[torch.Tensor, str]):
        """Compute log10 mel features for a waveform tensor or an audio file.

        Args:
            wav: Waveform tensor, or a path to an audio file. A file is loaded,
                down-mixed to mono and resampled to the configured sample rate.

        Returns:
            Log-mel tensor; (1, F, T) for a [1, T] input, or (T, F) when
            `no_batch_time_first` is set.
        """
        if isinstance(wav, (str, os.PathLike)):
            wav, sr = torchaudio.load(os.fspath(wav))
            # Honour the file's real sample rate / channel count instead of discarding them.
            wav = mono_resample(wav, sr, self.sample_rate)

        spec = self.extractor(wav)  # (1,F,T) for a [1,T] waveform
        spec = torch.log10(spec + 1e-6)
        if self.apply_cmvn:
            # Time is the last axis, so dim=-1 also covers un-batched (F,T) input.
            mean = spec.mean(dim=-1, keepdim=True)
            std = spec.std(dim=-1, keepdim=True)
            spec = (spec - mean) / (std + 1e-9)
        if self.no_batch_time_first:
            spec = spec.squeeze(0).transpose(0,1)  # (T,F)
        return spec


# -----------------------------------------------------------------------------
# Evaluation Metrics & Helper Utils
# -----------------------------------------------------------------------------
def compute_metrics(y_true, y_pred):
    """Compute binary confusion counts and derived rates.

    Args:
        y_true: Array-like of 0/1 reference labels.
        y_pred: Array-like of 0/1 predicted labels with the same number of
            elements as `y_true`.

    Returns:
        dict with the raw counts ("TPf", "FPf", "FNf", "TNf") and the ratios
        "precision", "recall", "f1", "iou", "miss_rate" and "fa_rate". A small
        epsilon in each denominator makes the ratios 0.0 instead of raising
        when a denominator is empty.

    Raises:
        ValueError: If the two inputs do not contain the same number of elements.
    """
    # Coerce to flat arrays: with plain lists `y == 1` is a single False and
    # every count would silently come out as zero.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred differ in size: {y_true.size} vs {y_pred.size}"
        )

    true_pos, true_neg = (y_true == 1), (y_true == 0)
    pred_pos, pred_neg = (y_pred == 1), (y_pred == 0)

    tp = int(np.sum(true_pos & pred_pos))
    fp = int(np.sum(true_neg & pred_pos))
    fn = int(np.sum(true_pos & pred_neg))
    tn = int(np.sum(true_neg & pred_neg))

    eps = 1e-12
    prec   = tp / (tp + fp + eps)
    rec    = tp / (tp + fn + eps)
    f1     = 2 * prec * rec / (prec + rec + eps)
    iou    = tp / (tp + fp + fn + eps)
    miss   = fn / (tp + fn + eps)
    fa     = fp / (tn + fp + eps)

    return {
        "TPf": tp, "FPf": fp, "FNf": fn, "TNf": tn,
        "precision": prec, "recall": rec, "f1": f1, "iou": iou,
        "miss_rate": miss, "fa_rate": fa,
    }


def choose_enroll(df, min_dur, random_state=None):
    """Pick enrollment audio totalling at least `min_dur` from candidate rows.

    If any single row has `duration >= min_dur`, one such row is sampled at
    random. Otherwise rows are taken from longest to shortest until their
    summed duration reaches `min_dur` (or the rows run out) and merged into
    one row.

    Args:
        df: DataFrame with columns "audio_path", "start_time", "end_time"
            and "duration". It is not modified.
        min_dur: Required enrollment duration, in the units of "duration".
        random_state: Optional seed / RandomState forwarded to
            `DataFrame.sample` for a reproducible pick. None keeps the
            default behaviour of `DataFrame.sample`.

    Returns:
        One-row DataFrame with two extra columns:
        "_st_et_list" (list of (start_time, end_time) tuples) and
        "_original_idxs" (the source index label for a single sampled row,
        a list of index labels for a merged row). For a merged row
        "audio_path" is a tuple of paths and "duration" the summed duration;
        the remaining columns come from the longest selected row.

    Raises:
        IndexError: If `df` has no rows.
    """
    if df.empty:
        raise IndexError("choose_enroll: `df` has no rows to choose an enrollment from")

    long_enough = df[df["duration"] >= min_dur]
    if not long_enough.empty:
        # .copy() so the new columns are written to an independent frame.
        sampled = long_enough.sample(1, random_state=random_state).copy()
        sampled["_original_idxs"] = sampled.index.tolist()
        sampled["_st_et_list"] = pd.Series(
            [[(st, et)] for st, et in zip(sampled["start_time"], sampled["end_time"])],
            index=sampled.index,
            dtype=object,
        )
        return sampled

    # No single row is long enough: accumulate the longest rows first.
    total_dur = 0
    selected_idxs = []
    st_et_list = []
    for idx, row in df.sort_values("duration", ascending=False).iterrows():
        selected_idxs.append(idx)
        st_et_list.append((row["start_time"], row["end_time"]))
        total_dur += row["duration"]
        if total_dur >= min_dur:
            break

    picked = df.loc[selected_idxs]
    combined_row = picked.iloc[0].copy()
    combined_row["audio_path"] = tuple(picked["audio_path"].tolist())
    combined_row["duration"] = total_dur
    combined_row["_st_et_list"] = st_et_list
    combined_row["_original_idxs"] = selected_idxs
    return pd.DataFrame([combined_row])


def segs2label(segs, max_t):
    """Rasterise (start, end) index pairs into a 0/1 integer vector.

    Args:
        segs: Iterable of (start, end) pairs; each marks label[start:end].
        max_t: Length of the returned vector.

    Returns:
        Integer numpy array of length `max_t`, 1 inside any segment, else 0.
    """
    label = np.zeros(max_t, dtype=int)
    for begin, finish in segs:
        span = slice(begin, finish)
        label[span] = 1
    return label


In [None]:
# -----------------------------------------------------------------------------
# PVAD (Personal Voice Activity Detection)
# -----------------------------------------------------------------------------

def _mask_to_segments_ms(mask, frame_ms=10):
    """Turn a 0/1 frame mask into [[start_ms, end_ms], ...] for each run of ones."""
    padded = np.concatenate(([0], mask, [0]))
    # Zero padding guarantees the edges come in (rise, fall) pairs.
    edges = np.where(np.diff(padded) != 0)[0]
    return [[int(rise) * frame_ms, int(fall) * frame_ms]
            for rise, fall in zip(edges[0::2], edges[1::2])]


def _merge_close_segments(segments, min_gap):
    """Join consecutive segments whose separating gap is shorter than `min_gap`."""
    merged = []
    for start, end in segments:
        if merged and start - merged[-1][1] < min_gap:
            merged[-1][1] = end
        else:
            merged.append([start, end])
    return merged


def PVAD(key_wav, target_wav, pvad_model,
         threshold=0.5,
         median_filter_size=5,
         ns_min_length=300,
         s_min_length=200,
         apply_cmvn=False,
         debug_mode=False
         ):
    """Run the personal-VAD model and return detected segments as sample indices.

    Args:
        key_wav: Enrollment waveform (tensor, or a path accepted by `Extractor`).
        target_wav: Waveform to analyse (same accepted types).
        pvad_model: Model called as `pvad_model(key_feat, target_feat)` whose
            first output is a logit per frame.
        threshold: Frames with sigmoid score >= threshold count as speech.
        median_filter_size: Median-filter kernel over the frame mask. Even
            values are rounded up to the next odd size; values <= 1 disable
            the filter.
        ns_min_length: Gaps shorter than this (ms) between segments are bridged.
        s_min_length: Segments must be strictly longer than this (ms) to be kept.
        apply_cmvn: Forwarded to `Extractor`.
        debug_mode: Plot the target features with the scores and print the
            intermediate segment lists.

    Returns:
        List of [start_sample, end_sample] pairs. The conversion assumes one
        score per 10 ms frame and 16 samples per ms (16 kHz audio, 160-sample
        hop — the `Extractor` defaults).
    """
    frame_ms = 10        # one frame per 160-sample hop at 16 kHz
    samples_per_ms = 16  # 16 kHz

    device = next(pvad_model.parameters()).device

    def _on_cpu(x):
        # Extractor's mel transform lives on CPU; a CUDA waveform would raise
        # a device-mismatch error inside the STFT.
        return x.cpu() if isinstance(x, torch.Tensor) else x

    extractor = Extractor(apply_cmvn=apply_cmvn)
    key_feat = extractor(_on_cpu(key_wav)).to(device)
    target_feat = extractor(_on_cpu(target_wav)).to(device)

    with torch.no_grad():
        logit, _, _ = pvad_model(key_feat, target_feat)
        scores = torch.sigmoid(logit.squeeze(0)).cpu().detach().numpy()
    # medfilt filters along every axis, so force one score per frame in a 1-D array.
    # NOTE(review): assumes the model emits a single logit per frame — confirm.
    scores = scores.reshape(-1)

    if debug_mode:
        feat_img = target_feat.squeeze(0).cpu().detach().numpy()
        n_mels = feat_img.shape[0]
        plt.figure(figsize=(20, 4))
        plt.imshow(feat_img, origin='lower', aspect='auto', cmap='magma')
        plt.colorbar(label='Amplitude (dB)')
        plt.xlabel('Time Frame')
        plt.ylabel('Mel Bin')
        # Scale the [0, 1] score to the mel axis so it overlays the image.
        plt.plot(np.arange(len(scores)), scores * n_mels, color='cyan', linewidth=2, label='Score')
        plt.legend(loc='upper right')
        plt.title('MelSpectrogram with Score Overlay')
        plt.show()

    if scores.size == 0:
        return []

    binary_mask = (scores >= threshold).astype(int)

    # median filtering (scipy requires an odd kernel size)
    kernel = int(median_filter_size)
    if kernel > 1:
        if kernel % 2 == 0:
            kernel += 1
        binary_mask = medfilt(binary_mask, kernel_size=kernel)

    # get segments (in ms)
    segments = _mask_to_segments_ms(binary_mask, frame_ms)
    if debug_mode:
        print("raw segments:",segments)

    # delete short non-speech intervals.
    merged_segments = _merge_close_segments(segments, ns_min_length)
    if debug_mode:
        print("merged segments", merged_segments)

    # delete short speech segments, then convert ms to sample idx
    return [[start * samples_per_ms, end * samples_per_ms]
            for start, end in merged_segments
            if end - start > s_min_length]

# -----------------------------------------------------------------------------
# WCE (Word Count Estimator) Wrapper
# -----------------------------------------------------------------------------

class WCE:
    """Inference wrapper around a word-count-estimator network.

    Turns a waveform into log-mel features with `Extractor` and converts the
    model output into an integer count.
    """

    def __init__(
        self,
        wce_model,
        apply_trim: bool = True,
        apply_cmvn: bool = True,
        model_type: str = "wce",
    ):
        """
        Args:
            wce_model: Network called as `model(features, lengths)`.
            apply_trim: Trim leading/trailing low-energy audio before inference.
            apply_cmvn: Forwarded to `Extractor`.
            model_type: "wce" (output is a dict with a "count" entry) or
                "sylnet" (output is a logit tensor); selects the decoding in
                `compute_pred_count`.
        """
        self.apply_cmvn = bool(apply_cmvn)
        self.apply_trim = bool(apply_trim)
        # compute_pred_count branches on this, so it must always be set.
        self.model_type = str(model_type)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = wce_model.to(self.device).eval()
        self.extractor = Extractor(apply_cmvn=self.apply_cmvn, no_batch_time_first=True)

    @torch.no_grad()
    def _to_mono_1ch_16k(self, wav: torch.Tensor, sr: int=16000, target_sr: int = 16000) -> torch.Tensor:
        """
        Input : wav [C,T] or [T], sr (sample rate)
        Output: [1,T] @ target_sr (16 kHz by default)

        Raises:
            ValueError: If `wav` has more than two dimensions.
        """
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)  # [1,T]
        elif wav.dim() == 2:
            if wav.size(0) > 1:
                wav = wav.mean(dim=0, keepdim=True)  # [1,T]
        else:
            raise ValueError(f"Unexpected wav shape: {tuple(wav.shape)}")
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)

        return wav

    @torch.no_grad()
    def trim_silence(self, wav: torch.Tensor, trim_db: float = 20,
                     frame_length: int = 400, hop_length: int = 160) -> torch.Tensor:
        """Cut leading and trailing low-energy audio from a waveform.

        Frame power is measured over windows of `frame_length` samples every
        `hop_length` samples; the result spans from the first to the last
        frame whose power is within `trim_db` dB of the loudest frame.

        Args:
            wav: Waveform [1,T] (or [C,T] / [T]).
            trim_db: Threshold in dB below the loudest frame.
            frame_length: Analysis window in samples.
            hop_length: Step between windows in samples.

        Returns:
            The trimmed waveform (a slice along the last axis). Input shorter
            than one frame, or with no energy at all, is returned unchanged.
        """
        n_samples = wav.size(-1)
        if n_samples < frame_length:
            return wav

        frames = wav.reshape(-1, n_samples).unfold(1, frame_length, hop_length)  # [C,N,L]
        power = frames.pow(2).mean(dim=-1).amax(dim=0)                            # [N]
        peak = power.max()
        if peak <= 0:
            return wav

        floor = peak * (10.0 ** (-float(trim_db) / 10.0))
        active = torch.nonzero(power >= floor).flatten()
        start = int(active[0]) * hop_length
        end = min(int(active[-1]) * hop_length + frame_length, n_samples)
        return wav[..., start:end]

    @torch.no_grad()
    def compute_pred_count(self, output) -> int:
        """
        sylnet: logits -> sigmoid -> count of units at or above 0.5
        wce   : uses the 'count' entry of the dict output (adjust to the model implementation)

        The value is converted with int(), i.e. truncated rather than rounded.
        """
        # Instances created without model_type fall back to the dict-style "wce" output.
        model_type = getattr(self, "model_type", "wce")
        if model_type == "sylnet":
            # expected shape: logits [B, U]
            logits = output
            pred_cnt = (torch.sigmoid(logits) >= 0.5).float().sum(dim=1)  # [B]
            return int(pred_cnt.item())

        # wce — expected: {"count": Tensor([B, 1]) or [B]}
        cnt = output["count"]
        if isinstance(cnt, (list, tuple)):
            cnt = torch.as_tensor(cnt)
        if cnt.dim() > 1:
            cnt = cnt.squeeze()
        return int(cnt[0].item() if cnt.numel() > 1 else cnt.item())

    @torch.no_grad()
    def word_extract(self, wav: Union[str, Path, torch.Tensor]):
        """
        Single wav -> Log-Mel(+CMVN) -> model inference -> count (int)

        Args:
            wav: Path to an audio file, or a waveform tensor that the caller
                guarantees to be sampled at 16 kHz.

        Raises:
            TypeError: If `wav` is neither a path nor a tensor.
        """
        if isinstance(wav, (str, Path)):
            wav, sr = torchaudio.load(str(wav))  # [C,T]
        elif isinstance(wav, torch.Tensor):
            sr = 16000  # for tensors the caller is assumed to guarantee the sample rate
        else:
            raise TypeError(f"Unsupported type: {type(wav)}")

        wav = self._to_mono_1ch_16k(wav, sr)      # [1,T]
        if self.apply_trim:
            wav = self.trim_silence(wav, trim_db=20)

        # Extractor's mel transform lives on CPU; features move to the model device afterwards.
        spec = self.extractor(wav.cpu())       # [T, n_mels]

        x_pad = spec.unsqueeze(0).to(self.device)        # [1, T, n_mels]
        lengths = torch.tensor([x_pad.size(1)], device=self.device, dtype=torch.long)

        output = self.model(x_pad, lengths)

        return self.compute_pred_count(output)


In [None]:
# Main Execution and Evaluation Function
def run_evaluation_pipeline(key_audio_path, target_audio_path, 
                            pvad_model_config, pvad_post_processing,
                            wce_model_config, gt_word_count=None):
    """Detect the enrolled speaker's speech in a recording and estimate its word count.

    Args:
        key_audio_path: Enrollment audio of the target speaker.
        target_audio_path: Recording to analyse.
        pvad_model_config: Mapping with "model" and "model_args" for the PVAD model.
        pvad_post_processing: Mapping whose optional "PVAD_post_args" entry
            holds keyword arguments for `PVAD`.
        wce_model_config: Mapping with "model" and "model_args" for the WCE model.
        gt_word_count: Optional reference word count; when given, the relative
            error is printed.

    Returns:
        Sum of the estimated word counts over all detected segments.
    """
    # Shortest segment handed to the WCE front end. The 400-point STFT pads
    # 200 samples by reflection on each side, which needs more than 200 input
    # samples.
    min_segment_samples = 200

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- [Step 1] Load PVAD Model ---
    pvad_model = load_pvad_model(
        model_name=pvad_model_config["model"],
        model_args=pvad_model_config["model_args"],
        requires_grad=False,
        device=device
    )
    pvad_model.eval()

    # --- [Step 2] Load WCE Model ---
    wce_net = load_wce_model(
        model_name=wce_model_config["model"],
        model_args=wce_model_config["model_args"],
        requires_grad=False,
        device=device
    )
    wce_model = WCE(wce_net, apply_trim=False, apply_cmvn=True)

    # --- [Step 3] Load Audio and Enroll Speaker ---
    # Waveforms stay on CPU: feature extraction runs on CPU, and PVAD / WCE
    # move the resulting features to the model device themselves.
    key_wav = load_wav(key_audio_path)        # Enrollment voice
    target_wav = load_wav(target_audio_path)  # Analysis target audio

    # --- [Step 4] Run PVAD (Extract Specific Speaker Segments) ---
    # Result: segments_sampleidx = [[start, end], ...]
    post_args = pvad_post_processing.get("PVAD_post_args", None) or {}
    pred_segments = PVAD(
        key_wav,
        target_wav,
        pvad_model,
        **post_args
    )

    # --- [Step 5] Run WCE and Evaluate ---
    total_pred_words = 0

    for s, e in pred_segments:
        seg_wav = target_wav[:, s:e].squeeze(0) # [T]
        if seg_wav.size(0) > min_segment_samples: # Minimum length check
            # Estimate word count using WCE class method
            total_pred_words += wce_model.word_extract(seg_wav)

    # Print Results and Evaluate
    print("--- Analysis Results ---")
    print(f"Detected Segments: {len(pred_segments)}")
    print(f"Estimated Total Word Count: {total_pred_words:.2f}")

    if gt_word_count is not None:
        # max(1, ...) avoids dividing by zero for an empty reference.
        error = abs(gt_word_count - total_pred_words) / max(1, gt_word_count)
        print(f"Ground Truth Word Count: {gt_word_count}")
        print(f"Word Count Estimation Error Rate: {error:.4%}")

    return total_pred_words

In [None]:
# PVAD model config file for each available variant.
pvad_model_path = dict(
    libri='../personal_VAD/conf/libri.yaml',
    chime='../personal_VAD/conf/chime.yaml',
    ami='../personal_VAD/conf/ami.yaml',
)

pvad_post_processing_config = './conf/pvad_post_processing.yaml'
wce_model_paths = '../word_count_estimator/conf/wce_frame_onset.yaml'

# Placeholders: set these to real audio file paths before running this cell.
key_audio_path = ...
target_audio_path = ...

# Load the three YAML configs (the 'ami' PVAD variant is the one used here).
pvad_model_config = OmegaConf.load(pvad_model_path['ami'])
pvad_post_processing = OmegaConf.load(pvad_post_processing_config)
wce_model_config = OmegaConf.load(wce_model_paths)

total_pred_words = run_evaluation_pipeline(
    key_audio_path=key_audio_path,
    target_audio_path=target_audio_path,
    pvad_model_config=pvad_model_config,
    pvad_post_processing=pvad_post_processing,
    wce_model_config=wce_model_config,
)