In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["HF_DATASETS_CACHE"] = "/shared/3/projects/bangzhao/.hf_cache"

In [2]:
import json
import random
from tqdm import tqdm
from datasets import load_dataset

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn as nn

from bertPhoneme import BertEmbeddingsV2, BertModelV2, BertForMaskedLMV2, BertConfigV2, MaskedLMWithProsodyOutput

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm
2025-07-22 18:38:02.731687: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753209482.745303 3183293 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753209482.749490 3183293 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753209482.761473 3183293 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753209482.761483 3183293 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753209482.761485 3183293

In [3]:
with open("/shared/3/projects/bangzhao/prosodic_embeddings/bert_train/phoneme_vocab.json", "r") as f:
    phoneme_vocab = json.load(f)

In [4]:
cluster_size = 100

phoneme_vocab_size = len(phoneme_vocab)
mask_token_id = phoneme_vocab["SIL"]
pad_token_id = 72
pad_cluster_id = cluster_size + 1
mask_prosody_id = cluster_size


class HuggingFacePhonemeDataset(Dataset):
    def __init__(self, hf_dataset, vocab, mask_prob=0.15, max_length=512):
        self.dataset = hf_dataset
        self.vocab = vocab
        self.mask_prob = mask_prob
        self.max_length = max_length

        # NEW: Build (row_idx, chunk_start) mapping
        self.index_map = []
        print("Indexing chunks from dataset...")
        for row_idx, sample in tqdm(enumerate(self.dataset), total=len(self.dataset), desc="Chunking"):
            length = len(sample["phoneme"])
            for start in range(0, length, max_length):
                self.index_map.append((row_idx, start))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        row_idx, start = self.index_map[idx]
        sample = self.dataset[row_idx]

        phonemes = sample["phoneme"][start:start + self.max_length]
        prosody_ids = sample["prosody_id_100"][start:start + self.max_length]  # change cluster size here

        # Tokenize
        input_ids = [self.vocab.get(p, self.vocab["UNK"]) for p in phonemes]
        
        labels = input_ids.copy() # new
        prosody_labels = prosody_ids.copy()

        # Mask prosody
        for i in range(len(prosody_ids)):
            if random.random() < self.mask_prob:
                labels[i] = input_ids[i] # new
                input_ids[i] = mask_token_id # changed 
                prosody_labels[i] = prosody_ids[i]
                prosody_ids[i] = mask_prosody_id
            else:
                labels[i] = -100 # new 
                prosody_labels[i] = -100

        # Padding
        pad_length = self.max_length - len(input_ids)
        input_ids.extend([pad_token_id] * pad_length)
        labels.extend([-100] * pad_length) # new
        prosody_ids.extend([pad_cluster_id] * pad_length)
        prosody_labels.extend([-100] * pad_length)
        attention_mask = [1] * (len(input_ids) - pad_length) + [0] * pad_length

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long), # new
            "prosody_ids": torch.tensor(prosody_ids, dtype=torch.long),
            "prosody_labels": torch.tensor(prosody_labels, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

In [5]:
hf_dataset = load_dataset("json", data_files="/shared/3/projects/bangzhao/prosodic_embeddings/merge/training_data_6features/output_part_1_20kSample.jsonl", split="train")
hf_dataset_test = hf_dataset.select(range(19800, 20000))
test_dataset = HuggingFacePhonemeDataset(hf_dataset_test, phoneme_vocab)

Generating train split: 20000 examples [05:41, 58.59 examples/s] 


Indexing chunks from dataset...


Chunking: 100%|█████████████████████████████████████████████████████| 200/200 [00:07<00:00, 26.04it/s]


In [7]:
hf_dataset_test

Dataset({
    features: ['name', 'phoneme', 'prosody_id_10', 'prosody_id_20', 'prosody_id_50', 'prosody_id_100', 'prosody_id_200', 'prosody_id_500', 'prosody_id_1000'],
    num_rows: 200
})

In [8]:
save_dir = "/shared/3/projects/bangzhao/prosodic_embeddings/bert_train/mlm_prosody&phoneme_random_position_20kSample_6features_100clu/"
checkpoint_path = os.path.join(save_dir, "mlm_prosody_step100.pt")

checkpoint = torch.load(checkpoint_path, map_location=device)

# Reconstruct config
config_dict = checkpoint["config"]
model_config = BertConfigV2(**config_dict)

# Load model
model = BertForMaskedLMV2(config=model_config)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

BertForMaskedLMV2(
  (bert): BertModelV2(
    (embeddings): BertEmbeddingsV2(
      (word_embeddings): Embedding(73, 512, padding_idx=72)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (prosody_embeddings): Embedding(102, 512, padding_idx=101)
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput

In [9]:
sample = test_dataset[6]

input_ids = sample["input_ids"].unsqueeze(0).to(model.device)
labels = sample["labels"].unsqueeze(0).to(model.device)
prosody_ids = sample["prosody_ids"].unsqueeze(0).to(model.device)
prosody_labels = sample["prosody_labels"].unsqueeze(0).to(model.device)
attention_mask = sample["attention_mask"].unsqueeze(0).to(model.device)

In [10]:
model.eval()
with torch.no_grad():
    outputs = model(input_ids=input_ids, prosody_ids=prosody_ids, attention_mask=attention_mask)

pred_phonemes = torch.argmax(outputs.logits, dim=-1)[0]         # [seq_len]
pred_prosody = torch.argmax(outputs.prosody_logits, dim=-1)[0]  # [seq_len]

# Evaluation
id2phoneme = {v: k for k, v in phoneme_vocab.items()}
mask_positions_phoneme = (labels != -100).nonzero(as_tuple=True)[1].tolist()
mask_positions_prosody = (prosody_labels != -100).nonzero(as_tuple=True)[1].tolist()

results = []
correct_phoneme = 0
correct_prosody = 0

# ----- Evaluate phoneme prediction -----
for i in mask_positions_phoneme:
    gt_ph_id = labels[0, i].item()
    pred_ph_id = pred_phonemes[i].item()
    phoneme_correct = gt_ph_id == pred_ph_id
    if phoneme_correct:
        correct_phoneme += 1

    results.append({
        "type": "phoneme",
        "position": i,
        "phoneme_correct": phoneme_correct,
        "phoneme_gt": id2phoneme.get(gt_ph_id, "UNK"),
        "phoneme_pred": id2phoneme.get(pred_ph_id, "UNK"),
        "phoneme_gt_id": gt_ph_id,
        "phoneme_pred_id": pred_ph_id,
    })

# ----- Evaluate prosody prediction -----
for i in mask_positions_prosody:
    gt_pr_id = prosody_labels[0, i].item()
    pred_pr_id = pred_prosody[i].item()
    prosody_correct = gt_pr_id == pred_pr_id
    if prosody_correct:
        correct_prosody += 1

    results.append({
        "type": "prosody",
        "position": i,
        "prosody_correct": prosody_correct,
        "prosody_gt_id": gt_pr_id,
        "prosody_pred_id": pred_pr_id,
    })

# ----- Compute Accuracy -----
phoneme_total = len(mask_positions_phoneme)
prosody_total = len(mask_positions_prosody)

phoneme_accuracy = correct_phoneme / phoneme_total if phoneme_total > 0 else None
prosody_accuracy = correct_prosody / prosody_total if prosody_total > 0 else None

# ----- Print Accuracy -----
print(f"\nPhoneme Accuracy: {phoneme_accuracy:.2%}" if phoneme_accuracy is not None else "No masked phoneme positions.")
print(f"Prosody Accuracy: {prosody_accuracy:.2%}" if prosody_accuracy is not None else "No masked prosody positions.")

# ----- Print Results -----
for r in results:
    if r["type"] == "phoneme":
        print(f"Pos {r['position']:>3} | Phoneme: Pred={r['phoneme_pred']} GT={r['phoneme_gt']} → {'True' if r['phoneme_correct'] else 'False'}")
    elif r["type"] == "prosody":
        print(f"Pos {r['position']:>3} | Prosody: Pred={r['prosody_pred_id']} GT={r['prosody_gt_id']} → {'True' if r['prosody_correct'] else 'False'}")




Phoneme Accuracy: 42.25%
Prosody Accuracy: 19.72%
Pos   2 | Phoneme: Pred=AE1 GT=UH1 → False
Pos   6 | Phoneme: Pred=L GT=L → True
Pos   7 | Phoneme: Pred=AH0 GT=K → False
Pos   9 | Phoneme: Pred=AY1 GT=AH1 → False
Pos  10 | Phoneme: Pred=T GT=B → False
Pos  16 | Phoneme: Pred=IH0 GT=IH0 → True
Pos  21 | Phoneme: Pred=S GT=M → False
Pos  33 | Phoneme: Pred=K GT=K → True
Pos  38 | Phoneme: Pred=T GT=IY0 → False
Pos  43 | Phoneme: Pred=L GT=N → False
Pos  44 | Phoneme: Pred=AH0 GT=D → False
Pos  52 | Phoneme: Pred=IY0 GT=IY0 → True
Pos  63 | Phoneme: Pred=S GT=K → False
Pos  74 | Phoneme: Pred=B GT=W → False
Pos  78 | Phoneme: Pred=AH1 GT=EH1 → False
Pos  90 | Phoneme: Pred=Z GT=D → False
Pos  94 | Phoneme: Pred=AH0 GT=AH0 → True
Pos 100 | Phoneme: Pred=T GT=IH1 → False
Pos 101 | Phoneme: Pred=S GT=S → True
Pos 105 | Phoneme: Pred=AY1 GT=AY1 → True
Pos 108 | Phoneme: Pred=Z GT=N → False
Pos 122 | Phoneme: Pred=S GT=M → False
Pos 127 | Phoneme: Pred=Y GT=Y → True
Pos 131 | Phoneme: Pred=

In [11]:
from tqdm import tqdm

total_phoneme_correct = 0
total_prosody_correct = 0
total_phoneme_masked = 0
total_prosody_masked = 0

model.eval()

for idx in tqdm(range(len(test_dataset)), desc="Evaluating samples"):
    sample = test_dataset[idx]

    input_ids = sample["input_ids"].unsqueeze(0).to(model.device)
    phoneme_labels = sample["labels"].unsqueeze(0).to(model.device)
    prosody_ids = sample["prosody_ids"].unsqueeze(0).to(model.device)
    prosody_labels = sample["prosody_labels"].unsqueeze(0).to(model.device)
    attention_mask = sample["attention_mask"].unsqueeze(0).to(model.device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, prosody_ids=prosody_ids, attention_mask=attention_mask)

    pred_phonemes = torch.argmax(outputs.logits, dim=-1)[0]
    pred_prosody = torch.argmax(outputs.prosody_logits, dim=-1)[0]

    mask_positions_phoneme = (phoneme_labels != -100).nonzero(as_tuple=True)[1].tolist()
    mask_positions_prosody = (prosody_labels != -100).nonzero(as_tuple=True)[1].tolist()

    # Evaluate phoneme predictions
    for i in mask_positions_phoneme:
        gt_ph_id = phoneme_labels[0, i].item()
        pred_ph_id = pred_phonemes[i].item()
        total_phoneme_masked += 1
        if gt_ph_id == pred_ph_id:
            total_phoneme_correct += 1

    # Evaluate prosody predictions
    for i in mask_positions_prosody:
        gt_pr_id = prosody_labels[0, i].item()
        pred_pr_id = pred_prosody[i].item()
        total_prosody_masked += 1
        if gt_pr_id == pred_pr_id:
            total_prosody_correct += 1

# === Final accuracy calculation ===
avg_phoneme_acc = (total_phoneme_correct / total_phoneme_masked) if total_phoneme_masked > 0 else None
avg_prosody_acc = (total_prosody_correct / total_prosody_masked) if total_prosody_masked > 0 else None

print(f"\n✅ Evaluated {total_phoneme_masked} phoneme-masked and {total_prosody_masked} prosody-masked positions over {len(test_dataset)} samples")
print(f"Phoneme Accuracy: {avg_phoneme_acc:.2%}" if avg_phoneme_acc is not None else "No masked phoneme positions.")
print(f"Prosody Accuracy: {avg_prosody_acc:.2%}" if avg_prosody_acc is not None else "No masked prosody positions.")

Evaluating samples: 100%|█████████████████████████████████████████| 4768/4768 [06:09<00:00, 12.92it/s]


✅ Evaluated 358362 phoneme-masked and 358362 prosody-masked positions over 4768 samples
Phoneme Accuracy: 30.22%
Prosody Accuracy: 16.20%





In [19]:
# from sklearn.metrics import f1_score
# import torch
# from tqdm import tqdm
# import numpy as np

# y_true = []
# y_pred = []

# model.eval()

# for idx in tqdm(range(200), desc="Collecting predictions"):
#     sample = test_dataset[idx]

#     input_ids = sample["input_ids"].unsqueeze(0).to(model.device)
#     prosody_labels = sample["prosody_labels"].unsqueeze(0).to(model.device)
#     prosody_ids = sample["prosody_ids"].unsqueeze(0).to(model.device)
#     attention_mask = sample["attention_mask"].unsqueeze(0).to(model.device)

#     with torch.no_grad():
#         outputs = model(input_ids=input_ids, prosody_ids=prosody_ids, attention_mask=attention_mask)

#     pred_prosody = torch.argmax(outputs.prosody_logits, dim=-1)[0]
#     mask_positions = (prosody_labels != -100).nonzero(as_tuple=True)[1].tolist()

#     for i in mask_positions:
#         gt_pr_id = prosody_labels[0, i].item()
#         pred_pr_id = pred_prosody[i].item()
#         y_true.append(gt_pr_id)
#         y_pred.append(pred_pr_id)

Collecting predictions: 100%|█████████████████████████████████████| 200/200 [00:13<00:00, 14.59it/s]


## Further qualitative investigation

In [21]:
y_true = np.array(y_true)
y_pred = np.array(y_pred)

class_f1_scores = {}

for cls in range(100):
    binary_true = (y_true == cls).astype(int)
    binary_pred = (y_pred == cls).astype(int)
    score = f1_score(binary_true, binary_pred, zero_division=0)
    class_f1_scores[cls] = score

# Print top classes sorted by F1
for cls, score in sorted(class_f1_scores.items(), key=lambda x: -x[1])[:10]:
    print(f"Class {cls}: F1 = {score:.4f}")

Class 33: F1 = 0.5256
Class 71: F1 = 0.5116
Class 90: F1 = 0.4615
Class 25: F1 = 0.4327
Class 26: F1 = 0.4222
Class 68: F1 = 0.4218
Class 15: F1 = 0.3837
Class 51: F1 = 0.3733
Class 45: F1 = 0.3700
Class 56: F1 = 0.3537


In [22]:
for cls, score in sorted(class_f1_scores.items(), key=lambda x: -x[1])[-10:]:
    print(f"Class {cls}: F1 = {score:.4f}")

Class 64: F1 = 0.1099
Class 50: F1 = 0.1071
Class 73: F1 = 0.1065
Class 58: F1 = 0.1053
Class 12: F1 = 0.0947
Class 0: F1 = 0.0935
Class 27: F1 = 0.0877
Class 8: F1 = 0.0600
Class 16: F1 = 0.0462
Class 30: F1 = 0.0000


In [49]:
from collections import defaultdict, Counter
from tqdm import tqdm

# For each prosody class, store a Counter of input_ids
prosody_class_to_input_ids = defaultdict(Counter)
prosody_class_to_input_ids_prediction = defaultdict(Counter)

model.eval()

for idx in tqdm(range(len(test_dataset)//2), desc="Counting input_ids per prosody class"):
    sample = test_dataset[idx]

    input_ids = sample["input_ids"].unsqueeze(0).to(model.device)
    prosody_labels = sample["prosody_labels"].unsqueeze(0).to(model.device)
    prosody_ids = sample["prosody_ids"].unsqueeze(0).to(model.device)
    attention_mask = sample["attention_mask"].unsqueeze(0).to(model.device)

    mask_positions = (prosody_labels[0] != -100).nonzero(as_tuple=True)[0].tolist()
    
    with torch.no_grad():
        outputs = model(input_ids=input_ids, prosody_ids=prosody_ids, attention_mask=attention_mask)
    
    pred_prosody = torch.argmax(outputs.prosody_logits, dim=-1)[0]
    
    for i in mask_positions:
        prosody_class = prosody_labels[0, i].item()              # Ground truth label
        prosody_prediction = pred_prosody[i].item()              # Model prediction
        token_id = input_ids[0, i].item()                        # Corresponding input token at position
    
        prosody_class_to_input_ids[prosody_class][token_id] += 1
        prosody_class_to_input_ids_prediction[prosody_prediction][token_id] += 1

Counting input_ids per prosody class: 100%|█████████████████████| 2384/2384 [03:16<00:00, 12.12it/s]


In [52]:
id_to_phoneme = {v: k for k, v in phoneme_vocab.items()}

In [53]:
# Result: maps prosody class → Counter of phonemes (instead of input_ids)
prosody_class_to_phonemes = defaultdict(Counter)

for prosody_class, input_id_counter in prosody_class_to_input_ids.items():
    for input_id, count in input_id_counter.items():
        phoneme = id_to_phoneme.get(input_id, f"<UNK_{input_id}>")
        prosody_class_to_phonemes[prosody_class][phoneme] += count

In [54]:
prosody_class_to_phonemes_prediction = defaultdict(Counter)

for prosody_class, input_id_counter in prosody_class_to_input_ids_prediction.items():
    for input_id, count in input_id_counter.items():
        phoneme = id_to_phoneme.get(input_id, f"<UNK_{input_id}>")
        prosody_class_to_phonemes_prediction[prosody_class][phoneme] += count

In [55]:
prosody_class_to_phonemes

defaultdict(collections.Counter,
            {27: Counter({'D': 168,
                      'AH0': 148,
                      'G': 109,
                      'T': 103,
                      'N': 100,
                      'Y': 93,
                      'JH': 88,
                      'IY1': 83,
                      'UW1': 74,
                      'IH0': 71,
                      'K': 66,
                      'DH': 59,
                      'IY0': 57,
                      'V': 49,
                      'Z': 47,
                      'HH': 47,
                      'B': 45,
                      'IH1': 45,
                      'M': 44,
                      'P': 33,
                      'R': 30,
                      'spn': 26,
                      'F': 26,
                      'S': 23,
                      'NG': 23,
                      'UH1': 21,
                      'L': 20,
                      'CH': 19,
                      'EY1': 17,
                      'W': 15,
     

In [56]:
prosody_class_to_phonemes_prediction

defaultdict(collections.Counter,
            {50: Counter({'IY1': 204,
                      'Y': 154,
                      'IY0': 115,
                      'D': 109,
                      'IH0': 102,
                      'AH0': 86,
                      'G': 76,
                      'UW1': 72,
                      'N': 54,
                      'T': 51,
                      'IH1': 43,
                      'HH': 38,
                      'DH': 36,
                      'Z': 28,
                      'EY1': 26,
                      'NG': 25,
                      'V': 21,
                      'L': 19,
                      'JH': 18,
                      'B': 16,
                      'EH1': 13,
                      'R': 12,
                      'K': 12,
                      'S': 7,
                      'ER0': 5,
                      'AH1': 5,
                      'UW0': 4,
                      'M': 4,
                      'F': 4,
                      'W': 4,
         

In [77]:
prosody_class_to_phonemes[30]

Counter({'S': 83,
         'T': 45,
         'Z': 20,
         'SH': 15,
         'AH0': 11,
         'N': 7,
         'IH1': 6,
         'K': 6,
         'JH': 6,
         'P': 5,
         'DH': 5,
         'IH0': 4,
         'B': 4,
         'D': 4,
         'AY1': 3,
         'EH1': 3,
         'V': 3,
         'CH': 3,
         'EY1': 3,
         'L': 3,
         'TH': 3,
         'R': 3,
         'F': 2,
         'AE1': 2,
         'W': 2,
         'HH': 2,
         'UW1': 2,
         'AA1': 2,
         'M': 1,
         'AH1': 1,
         'OW1': 1,
         'AO1': 1,
         'NG': 1,
         'UH1': 1,
         'IH2': 1,
         'ER0': 1})

In [78]:
prosody_class_to_phonemes_prediction[30]

Counter({'S': 48,
         'T': 10,
         'Z': 6,
         'JH': 5,
         'K': 2,
         'NG': 2,
         'D': 1,
         'IY0': 1,
         'HH': 1})

In [67]:
from collections import defaultdict, Counter

input_id_to_prosody_class = defaultdict(Counter)
for cls, input_id_counter in prosody_class_to_input_ids.items():
    for input_id, count in input_id_counter.items():
        input_id_to_prosody_class[input_id][cls] += count

input_id_to_predicted_class = defaultdict(Counter)
for cls, input_id_counter in prosody_class_to_input_ids_prediction.items():
    for input_id, count in input_id_counter.items():
        input_id_to_predicted_class[input_id][cls] += count

In [69]:
phoneme_to_prosody_class = defaultdict(Counter)

for input_id, class_counter in input_id_to_prosody_class.items():
    phoneme = id_to_phoneme.get(input_id, f"<UNK_{input_id}>")
    phoneme_to_prosody_class[phoneme] = class_counter

phoneme_to_predicted_class = defaultdict(Counter)

for input_id, class_counter in input_id_to_predicted_class.items():
    phoneme = id_to_phoneme.get(input_id, f"<UNK_{input_id}>")
    phoneme_to_predicted_class[phoneme] = class_counter

In [72]:
phoneme_to_predicted_class['IY0']

Counter({71: 457,
         6: 305,
         93: 147,
         67: 147,
         20: 138,
         29: 127,
         50: 115,
         59: 102,
         33: 100,
         80: 96,
         96: 91,
         14: 73,
         3: 69,
         69: 67,
         94: 51,
         88: 50,
         65: 47,
         66: 44,
         1: 40,
         86: 37,
         26: 33,
         95: 33,
         46: 32,
         43: 31,
         38: 29,
         92: 29,
         11: 26,
         79: 25,
         90: 24,
         85: 24,
         91: 24,
         39: 23,
         17: 23,
         32: 23,
         83: 22,
         49: 21,
         64: 21,
         52: 20,
         42: 20,
         72: 19,
         23: 18,
         5: 18,
         60: 17,
         31: 16,
         81: 16,
         18: 15,
         55: 15,
         37: 15,
         51: 15,
         57: 13,
         36: 13,
         45: 13,
         4: 13,
         62: 13,
         25: 13,
         13: 12,
         2: 12,
         70: 11,
         22

In [73]:
phoneme_to_prosody_class['IY0']

Counter({71: 322,
         6: 235,
         93: 161,
         50: 125,
         67: 120,
         80: 97,
         29: 96,
         69: 93,
         59: 81,
         96: 80,
         20: 76,
         65: 68,
         88: 68,
         14: 62,
         27: 57,
         94: 56,
         33: 53,
         46: 51,
         95: 42,
         52: 38,
         3: 37,
         86: 34,
         38: 33,
         18: 32,
         17: 29,
         43: 29,
         1: 29,
         92: 28,
         39: 28,
         91: 26,
         83: 25,
         31: 25,
         4: 25,
         75: 24,
         81: 24,
         49: 24,
         66: 23,
         62: 23,
         79: 23,
         26: 23,
         84: 23,
         99: 22,
         11: 21,
         36: 21,
         72: 21,
         23: 20,
         63: 20,
         51: 20,
         60: 19,
         16: 19,
         48: 18,
         55: 17,
         37: 17,
         85: 17,
         87: 17,
         90: 17,
         9: 17,
         25: 17,
         73: 1

## Baseline calculation

In [13]:
import torch
from collections import Counter
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

# Prepare dataloader
test_loader = DataLoader(test_dataset, batch_size=1)

# Collect all true masked prosody labels
all_true_labels = []
for batch in tqdm(test_loader, desc="Collecting True Prosody Labels"):
    labels = batch["prosody_labels"][0]  # shape: (512,)
    mask = labels != -100
    all_true_labels.extend(labels[mask].tolist())

# Get prosody cluster IDs excluding pad and mask tokens
prosody_ids = list(set(all_true_labels))
prosody_ids = [p for p in prosody_ids if p not in {pad_cluster_id, mask_prosody_id}]

########################################
# 1. Majority Class Baseline
########################################
majority_class = Counter(all_true_labels).most_common(1)[0][0]
majority_preds = [majority_class] * len(all_true_labels)
majority_acc = np.mean(np.array(majority_preds) == np.array(all_true_labels))
print(f"Majority Class Accuracy: {majority_acc:.4f}")

########################################
# 2. Sampling from Label Distribution
########################################
label_counter = Counter(all_true_labels)
label_total = sum(label_counter.values())
label_probs = [label_counter[i] / label_total for i in prosody_ids]
label_dist_preds = np.random.choice(prosody_ids, size=len(all_true_labels), p=label_probs)
label_dist_acc = np.mean(label_dist_preds == np.array(all_true_labels))
print(f"Sampled from Label Distribution Accuracy: {label_dist_acc:.4f}")

########################################
# 3. Uniform Sampling over All Classes
########################################
uniform_preds = np.random.choice(prosody_ids, size=len(all_true_labels))
uniform_acc = np.mean(uniform_preds == np.array(all_true_labels))
print(f"Uniform Sampling Accuracy: {uniform_acc:.4f}")

Collecting True Prosody Labels: 100%|███████████| 4768/4768 [05:36<00:00, 14.17it/s]

Majority Class Accuracy: 0.0222
Sampled from Label Distribution Accuracy: 0.0112
Uniform Sampling Accuracy: 0.0099



