In [1]:
import joblib
import json
import librosa
import math
import numpy as np
import os
import pandas as pd

from collections import defaultdict, Counter
from praatio import textgrid
from sklearn.cluster import KMeans


In [2]:
# Audio and transcriptions.
audio_root   = "./data/LibriSpeech/train-clean-100"
tg_root      = "data/LibriSpeech/librispeech_alignments/train-clean-100"

# Syllabel-based manifests.
mfcc_syll_manifest    = "data/syllabert_clean100/clustering/labeled_manifest.jsonl"
hidden_syll_manifest  = "data/syllabert_clean100/hidden_clusters/labeled_manifest.jsonl"

# Frame-based manifests.
mfcc_frame_manifest = "data/hubert_clean100/clustering/frame_labeled_manifest.jsonl"
hidden_frame_manifest = "data/hubert_clean100/hidden_clusters/hidden_frame_labeled_manifest_k500.jsonl"

In [3]:
# Load manifests keyed by utt_id
def load_manifest(path):
    utt2e = defaultdict(list)
    with open(path) as f:
        for line in f:
            ent = json.loads(line)
            utt2e[ent["utterance_id"]].append(ent)
    return utt2e

mfcc_syll_m  = load_manifest(mfcc_syll_manifest)
hidden_syll_m= load_manifest(hidden_syll_manifest)
mfcc_frame_m = load_manifest(mfcc_frame_manifest)
hidden_frame_m = load_manifest(hidden_frame_manifest)


In [4]:
def load_manifest(manifest_path):
    utt2entries = defaultdict(list)
    with open(manifest_path, 'r') as f:
        for line in f:
            e = json.loads(line)
            utt2entries[e['utterance_id']].append(e)
    return utt2entries

def load_textgrid_labels(utt_id):
    for root, _, files in os.walk(tg_root):
        for fn in files:
            if fn.startswith(utt_id) and fn.endswith("_syllabified.TextGrid"):
                tg = textgrid.openTextgrid(os.path.join(root, fn), includeEmptyIntervals=True)
                tier = tg.tierNameList[2]
                return [(s, e, lab) for s,e,lab in tg.getTier(tier).entries if lab.strip()]
    return []

def assign_gt_label(start, end, gt_sylls):
    peak = (start + end) / 2
    for s,e,lab in gt_sylls:
        if s <= peak <= e:
            return lab
    return None

def compute_purity(utt2entries):
    pairs = []
    for utt, ents in utt2entries.items():
        gt = load_textgrid_labels(utt)
        if not gt: continue
        for ent in ents:
            pred = ent['cluster_id']
            gt_lab = assign_gt_label(ent['segment_start'], ent['segment_end'], gt)
            if gt_lab is not None:
                pairs.append((pred, gt_lab))
    tot = len(pairs)
    c2g = defaultdict(Counter)
    for p,g in pairs:
        c2g[p][g] += 1
    correct = sum(cnt.most_common(1)[0][1] for cnt in c2g.values())
    return correct / tot if tot else 0, tot



In [5]:
# cell: metric definitions
def compute_cluster_phone_pnmi(pred_clusters, gt_phones):
    """
    Compute:
      - cluster purity:    sum_k max_j n(k,j) / N
      - phone purity:      sum_j max_k n(k,j) / N
      - PNMI:              I(C;P) / H(P)
    
    Args:
      pred_clusters: 1-D iterable of cluster IDs (ints or strings), length N
      gt_phones:      1-D iterable of phone labels (ints or strings), length N
    
    Returns:
      (cluster_purity, phone_purity, pnmi)
    """
    N = len(pred_clusters)
    assert N == len(gt_phones), "Lengths must match"
    
    # build contingency matrix n[c][p]
    n_cp = defaultdict(lambda: defaultdict(int))
    for c, p in zip(pred_clusters, gt_phones):
        n_cp[c][p] += 1
    
    # cluster purity = sum_k max_j n(k,j) / N
    cluster_purity = sum(max(p_counts.values()) for p_counts in n_cp.values()) / N
    
    # phone purity = sum_j max_k n(k,j) / N
    # first invert counts by phone
    n_pc = defaultdict(lambda: defaultdict(int))
    for c, p in zip(pred_clusters, gt_phones):
        n_pc[p][c] += 1
    phone_purity = sum(max(c_counts.values()) for c_counts in n_pc.values()) / N
    
    # compute mutual information I(C;P)
    # P(c,p) = n(c,p)/N, P(c) = n_c/N, P(p) = n_p/N
    # I = sum_{c,p} P(c,p) * log2( P(c,p) / (P(c) P(p)) )
    # H(P) = - sum_p P(p) log2 P(p)
    # PNMI = I / H(P)
    n_c = {c: sum(p_counts.values()) for c, p_counts in n_cp.items()}
    n_p = {p: sum(c_counts.values()) for p, c_counts in n_pc.items()}
    
    I = 0.0
    for c, p_counts in n_cp.items():
        for p, n in p_counts.items():
            # joint probability
            P_cp = n / N
            P_c  = n_c[c] / N
            P_p  = n_p[p] / N
            I += P_cp * math.log2(P_cp / (P_c * P_p))
    
    H_p = - sum((count / N) * math.log2(count / N) for count in n_p.values() if count > 0)
    pnmi = I / H_p if H_p > 0 else 0.0
    
    return cluster_purity, phone_purity, pnmi

In [6]:
# --- Replace your existing tier loader & assign function with this:

from collections import defaultdict, Counter
import os
#import textgrid  # or from praatio import textgrid

def load_textgrid_tier(utt_id: str,
                       tg_root: str,
                       tier_index: int) -> list:
    """
    Load the i-th tier (0-based) from the TextGrid for this utt_id,
    filter out empty or '_unknown' labels.
    Returns list of (start, end, label).
    """
    for root, _, files in os.walk(tg_root):
        for fn in files:
            if fn.startswith(utt_id) and fn.endswith("_syllabified.TextGrid"):
                tg = textgrid.openTextgrid(
                    os.path.join(root, fn),
                    includeEmptyIntervals=True
                )
                tier_name = tg.tierNames[tier_index]
                raw = tg.getTier(tier_name).entries
                # filter out empty & unknown
                return [
                    (s, e, lab)
                    for s, e, lab in raw
                    if lab.strip() and not lab.endswith("_unknown")
                ]
    return []  # no grid found

def assign_gt_label(start: float,
                    end: float,
                    gt_intervals: list,
                    method: str = "peak",
                    overlap_thresh: float = 0.5) -> str:
    """
    Try to assign a GT label to the segment [start,end].
    * method='peak': use midpoint in region
    * method='overlap': use intersection/union >= overlap_thresh
    Returns the label string or None.
    """
    # 1) midpoint strategy
    mid = 0.5 * (start + end)
    for s, e, lab in gt_intervals:
        if s <= mid <= e:
            return lab

    if method == "overlap":
        # 2) overlap strategy
        seg_len = end - start
        for s, e, lab in gt_intervals:
            inter = max(0.0, min(end, e) - max(start, s))
            union = max(end, e) - min(start, s)
            if union>0 and (inter/union) >= overlap_thresh:
                return lab

    return None

In [7]:
def flatten_clusters(manifest_dict, tg_root, tier_index,
                     method="peak", overlap_thresh=0.5):
    """
    Given a manifest mapping utterance_id -> list of entries
    (each with 'segment_start', 'segment_end', 'cluster_id'),
    and a TextGrid root, flatten into two parallel lists:
      preds = [cluster_id_1, cluster_id_2, …]
      gts   = [gt_label_1,   gt_label_2,   …]
    by aligning each segment to the GT interval in tier_index (0-based).

    method="peak"   → match by midpoint in region
    method="overlap"→ match if IOU ≥ overlap_thresh
    """
    preds, gts = [], []

    for utt_id, entries in manifest_dict.items():
        # load GT intervals once
        gt_intervals = load_textgrid_tier(utt_id, tg_root, tier_index)
        if not gt_intervals:
            continue

        # sort manifest entries by start time
        entries_sorted = sorted(entries, key=lambda e: e["segment_start"])
        ptr = 0
        n_gt = len(gt_intervals)

        for ent in entries_sorted:
            st, ed = ent["segment_start"], ent["segment_end"]
            # advance pointer past any GT intervals ending before this segment
            while ptr < n_gt and gt_intervals[ptr][1] < st:
                ptr += 1

            # try to find a matching GT
            label = None
            mid = 0.5 * (st + ed)
            for j in range(ptr, n_gt):
                s_gt, e_gt, lab_gt = gt_intervals[j]
                if s_gt > ed:
                    break  # no further overlaps possible

                if method == "peak":
                    if s_gt <= mid <= e_gt:
                        label = lab_gt
                        break
                else:  # overlap
                    inter = max(0.0, min(ed, e_gt) - max(st, s_gt))
                    union = max(ed, e_gt) - min(st, s_gt)
                    if union > 0 and (inter / union) >= overlap_thresh:
                        label = lab_gt
                        break

            if label is None:
                continue  # skip unaligned

            preds.append(ent["cluster_id"])
            gts.append(label)

    return preds, gts

In [8]:
# Tier 3 (syllable) for syllable‐based clusters:
pred1, gt1 = flatten_clusters(mfcc_syll_m,   tg_root, tier_index=2)
pred2, gt2 = flatten_clusters(hidden_syll_m, tg_root, tier_index=2)

# Tier 2 (phoneme) for frame‐based clusters:
pred3, gt3 = flatten_clusters(mfcc_frame_m,   tg_root, tier_index=1)
pred4, gt4 = flatten_clusters(hidden_frame_m, tg_root, tier_index=1)

# compute metrics:
systems = [
    ("MFCC–syllable",        pred1, gt1),
    ("SyllaBERT–syllable",   pred2, gt2),
    ("MFCC–frame",           pred3, gt3),
    ("HuBERT hidden–frame",  pred4, gt4),
]

rows = []
for name, pred, gt in systems:
    cp, pp, pnmi = compute_cluster_phone_pnmi(pred, gt)
    rows.append({
        "System":          name,
        "Cluster Purity":    cp,
        "Segment Purity":    pp,
        "SNMI":             pnmi,
        "num_segments":    len(pred),
        "unique_segments": len(set(gt)),
        "k-clusters":    len(set(pred)),
    })

purity_results = pd.DataFrame(rows)
# Display in the notebook
purity_results = purity_results.style.format({
        "Cluster Purity":       "{:.3f}",
        "Segment Purity": "{:.3f}",
        "SNMI":         "{:.3f}"
    })
purity_results

Unnamed: 0,System,Cluster Purity,Segment Purity,SNMI,num_segments,unique_segments,k-clusters
0,MFCC–syllable,0.055,0.138,0.14,1334917,9604,100
1,SyllaBERT–syllable,0.049,0.035,0.181,1334917,9604,500
2,MFCC–frame,0.341,0.123,0.312,18086721,72,100
3,HuBERT hidden–frame,0.659,0.105,0.665,18086721,72,500


In [9]:
print(purity_results.to_latex(
    #"data/cluster_purity_results.tex",
    #index=False,
    column_format="lcccccc",
    #float_format="%.3f",
    #escape=True,
    hrules=True,
    label="tab:cluster_purity_results",
    caption="Purity and Segment-Normalized Mutual Information (SNMI) results for the four clustering systems."
))

\begin{table}
\caption{Purity and Segment-Normalized Mutual Information (SNMI) results for the four clustering systems.}
\label{tab:cluster_purity_results}
\begin{tabular}{lcccccc}
\toprule
 & System & Cluster Purity & Segment Purity & SNMI & num_segments & unique_segments & k-clusters \\
\midrule
0 & MFCC–syllable & 0.055 & 0.138 & 0.140 & 1334917 & 9604 & 100 \\
1 & SyllaBERT–syllable & 0.049 & 0.035 & 0.181 & 1334917 & 9604 & 500 \\
2 & MFCC–frame & 0.341 & 0.123 & 0.312 & 18086721 & 72 & 100 \\
3 & HuBERT hidden–frame & 0.659 & 0.105 & 0.665 & 18086721 & 72 & 500 \\
\bottomrule
\end{tabular}
\end{table}

