In [67]:
import torch
from torch.utils.data import Dataset, DataLoader
import random
from transformers import AdamW
import torch.nn as nn

from transformers import *
from tokenizers import *

from bertPhoneme import BertEmbeddingsV2, BertModelV2, BertForMaskedLMV2, BertConfigV2, MaskedLMWithProsodyOutput

In [68]:
import json
import random
from datasets import load_dataset

In [26]:
file_path = "/shared/3/projects/bangzhao/prosodic_embeddings/merge/training_data/output_part_1.jsonl"
phonemes = set()

with open(file_path, 'r') as f:
    for i, line in enumerate(f):
        if i >= 10000:
            break
        phonemes.update(set(json.loads(line)['phoneme']))
    print(len(phonemes))

70


In [27]:
phonemes.update(['UNK', 'SIL'])
print(list(phonemes))

['Z', 'AE1', 'W', 'G', 'AA0', 'AA2', 'R', 'AO1', 'D', 'IH2', 'AO2', 'Y', 'NG', 'EY2', 'DH', 'K', 'V', 'ER2', 'UW1', 'EY1', 'AA1', 'EH0', 'OY2', 'AY2', 'OY1', 'AE2', 'UH1', 'AH1', 'HH', 'JH', 'SH', 'SIL', 'UW2', 'ZH', 'AY1', 'AW0', 'N', 'UW0', 'IY1', 'EH2', 'OW1', 'M', 'ER0', 'OW0', 'TH', 'EY0', 'UH0', 'AW1', 'AH0', 'L', 'T', 'AW2', 'EH1', 'AY0', 'B', 'IY0', 'F', 'P', 'ER1', 'IH1', 'IH0', 'AO0', 'OW2', 'spn', 'OY0', 'UNK', 'UH2', 'S', 'IY2', 'CH', 'AH2', 'AE0']


In [28]:
phoneme_vocab = {p: i for i, p in enumerate(phonemes)}

In [29]:
print(phoneme_vocab)

{'Z': 0, 'AE1': 1, 'W': 2, 'G': 3, 'AA0': 4, 'AA2': 5, 'R': 6, 'AO1': 7, 'D': 8, 'IH2': 9, 'AO2': 10, 'Y': 11, 'NG': 12, 'EY2': 13, 'DH': 14, 'K': 15, 'V': 16, 'ER2': 17, 'UW1': 18, 'EY1': 19, 'AA1': 20, 'EH0': 21, 'OY2': 22, 'AY2': 23, 'OY1': 24, 'AE2': 25, 'UH1': 26, 'AH1': 27, 'HH': 28, 'JH': 29, 'SH': 30, 'SIL': 31, 'UW2': 32, 'ZH': 33, 'AY1': 34, 'AW0': 35, 'N': 36, 'UW0': 37, 'IY1': 38, 'EH2': 39, 'OW1': 40, 'M': 41, 'ER0': 42, 'OW0': 43, 'TH': 44, 'EY0': 45, 'UH0': 46, 'AW1': 47, 'AH0': 48, 'L': 49, 'T': 50, 'AW2': 51, 'EH1': 52, 'AY0': 53, 'B': 54, 'IY0': 55, 'F': 56, 'P': 57, 'ER1': 58, 'IH1': 59, 'IH0': 60, 'AO0': 61, 'OW2': 62, 'spn': 63, 'OY0': 64, 'UNK': 65, 'UH2': 66, 'S': 67, 'IY2': 68, 'CH': 69, 'AH2': 70, 'AE0': 71}


In [56]:
phoneme_vocab_size = len(phoneme_vocab) 
mask_token_id = phoneme_vocab["SIL"]
pad_token_id = 72
pad_cluster_id = 200

class HuggingFacePhonemeDataset(Dataset):
    def __init__(self, hf_dataset, vocab, mask_prob=0.2, max_length=128):
        self.dataset = hf_dataset
        self.vocab = vocab
        self.mask_prob = mask_prob
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]

        phonemes = sample["phoneme"]
        prosody_ids = sample["prosody_id_200"]

        input_ids = [self.vocab.get(p, self.vocab["UNK"]) for p in phonemes][:self.max_length]
        prosody_ids = prosody_ids[:self.max_length]

        labels = input_ids.copy()
        for i in range(len(input_ids)):
            if random.random() < self.mask_prob:
                labels[i] = input_ids[i]
                input_ids[i] = self.vocab["SIL"]

        pad_length = self.max_length - len(input_ids)
        input_ids.extend([pad_token_id] * pad_length)
        labels.extend([-100] * pad_length)
        prosody_ids.extend([pad_cluster_id] * pad_length)
        attention_mask = [1] * (self.max_length - pad_length) + [0] * pad_length
        
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "prosody_ids": torch.tensor(prosody_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

In [36]:
hf_dataset = load_dataset("json", data_files="/shared/3/projects/bangzhao/prosodic_embeddings/merge/training_data/output_part_1.jsonl", split="train")

Generating train split: 0 examples [00:00, ? examples/s]

Loading dataset shards:   0%|          | 0/60 [00:00<?, ?it/s]

In [57]:
train_dataset = HuggingFacePhonemeDataset(hf_dataset, phoneme_vocab)

In [58]:
train_dataset[0]

{'input_ids': tensor([34, 41, 50, 59, 56,  1, 49, 48, 36,  2, 60, 14, 14, 48, 63, 63,  0, 31,
         60, 36, 49, 34, 50, 48, 31, 14, 27,  6, 31, 67, 48, 36, 31, 67, 48, 57,
          6, 38, 41, 15,  7,  6, 50,  8, 60, 67, 59, 33, 48, 36, 57,  6, 43, 28,
         59, 54, 48, 31, 60, 12, 60, 41, 57, 49, 31, 41, 48, 36, 31,  8, 60, 67,
         31,  6,  9, 41, 48, 31, 19, 31, 31, 36, 48,  3, 52, 36, 67, 50, 31, 62,
         41, 43, 67, 52, 15, 31, 31, 48, 49,  0, 48, 31,  8, 63, 67, 52, 31, 48,
         31, 42, 31, 20, 30, 31,  7, 31, 55, 48, 31, 31, 48, 31, 26,  6, 55, 31,
         36, 50]),
 'labels': tensor([34, 41, 50, 59, 56,  1, 49, 48, 36,  2, 60, 14, 14, 48, 63, 63,  0, 63,
         60, 36, 49, 34, 50, 48, 16, 14, 27,  6, 38, 67, 48, 36, 50, 67, 48, 57,
          6, 38, 41, 15,  7,  6, 50,  8, 60, 67, 59, 33, 48, 36, 57,  6, 43, 28,
         59, 54, 48, 50, 60, 12, 60, 41, 57, 49, 24, 41, 48, 36, 50,  8, 60, 67,
         15,  6,  9, 41, 48, 36, 19, 30, 48, 36, 48,  3, 52, 36, 67

In [59]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [60]:
for batch in train_loader:
    print(batch)
    break

{'input_ids': tensor([[60, 50, 67,  ...,  6, 68, 48],
        [ 3, 52, 31,  ..., 31, 36,  8],
        [28, 21, 31,  ..., 31, 56,  7],
        ...,
        [28, 48, 49,  ..., 50, 31, 31],
        [28, 31, 14,  ..., 31, 15, 30],
        [63,  6, 68,  ..., 36, 14, 48]]), 'labels': tensor([[60, 50, 67,  ...,  6, 68, 48],
        [ 3, 52, 50,  ...,  1, 36,  8],
        [28, 21, 49,  ..., 49, 56,  7],
        ...,
        [28, 48, 49,  ..., 50, 41, 42],
        [28, 34, 14,  ...,  1, 15, 30],
        [63,  6, 68,  ..., 36, 14, 48]]), 'prosody_ids': tensor([[  7,  60, 153,  ..., 138, 154,  68],
        [153, 153,  77,  ...,   7,  46, 186],
        [ 76,  41, 153,  ..., 174,   7,  49],
        ...,
        [179, 153, 153,  ...,  29, 131, 116],
        [195, 111,  99,  ...,   7, 174,  23],
        [  7, 179,   0,  ...,  46, 133,  94]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 

In [61]:
vocab_size = len(phoneme_vocab.keys())
max_length = 128


model_config = BertConfigV2(
    vocab_size=vocab_size,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=512,
    max_position_embeddings=max_length,
    prosody_cluster_size=3
)

model = BertForMaskedLMV2(config=model_config)

# BERT-Base	768	12	12	3072
# BERT-Small 512	4	8	2048
# BERT-Mini	256	4	4	1024
# BERT-Tiny	128	2	2	512

In [62]:
model

BertForMaskedLMV2(
  (bert): BertModelV2(
    (embeddings): BertEmbeddingsV2(
      (word_embeddings): Embedding(72, 128, padding_idx=0)
      (position_embeddings): Embedding(128, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (prosody_embeddings): Embedding(3, 128)
      (conv): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (den

In [72]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-3)

# Define loss functions
mlm_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # MLM Loss
prosody_loss_fn = nn.CrossEntropyLoss()  # Prosody Classification Loss

num_epochs = 5
model.train()

for epoch in range(num_epochs):
    total_loss = 0

    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        prosody_ids = batch["prosody_ids"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, prosody_ids=prosody_ids)

        # Compute losses
        mlm_loss = mlm_loss_fn(outputs.logits.view(-1, phoneme_vocab_size), labels.view(-1))
        prosody_loss = prosody_loss_fn(outputs.prosody_logits.view(-1, model.config.prosody_cluster_size), prosody_ids.view(-1))

        # Combine losses
        total_batch_loss = mlm_loss + prosody_loss
        total_batch_loss.backward()
        optimizer.step()

        total_loss += total_batch_loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader)}")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [36]:
test_dataset = [
    (["DH", "S", "IH", "Z", "AH", "T", "EH", "S", "T"], [0, 1, 2, 2, 1, 1, 0, 1, 1])]

In [37]:
test_dataset = PhonemeProsodyDataset(test_dataset, phoneme_vocab)
test_dataset[0]

{'input_ids': tensor([ 9, 28, 16, 40,  2, 30, 10, 40, 30,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]),
 'labels': tensor([   9,   28,   16,   37,    2,   30,   10,   28,   30, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100]),
 'prosody_ids': tensor([0, 1, 2, 2, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

In [38]:
model.eval()

input_ids = test_dataset[0]["input_ids"].unsqueeze(0).to(device)     # shape: [1, seq_len]
labels = test_dataset[0]["labels"].unsqueeze(0).to(device)           # shape: [1, seq_len]
prosody_ids = test_dataset[0]["prosody_ids"].unsqueeze(0).to(device) # shape: [1, seq_len]

with torch.no_grad():
    outputs = model(input_ids=input_ids, prosody_ids=prosody_ids)

In [39]:
# Assume you have access to phoneme_vocab and its inverse mapping
id2phoneme = {v: k for k, v in phoneme_vocab.items()}
phoneme_logits = outputs.logits[0]  # shape: [seq_len, vocab_size]
prosody_logits = outputs.prosody_logits[0]  # shape: [seq_len, prosody_cluster_size]

print("Top 5 predictions for masked tokens (SIL):\n")

for i in range(input_ids.shape[1]):  # loop over tokens in sequence
    if input_ids[0, i].item() != phoneme_vocab["SIL"]:
        continue  # Only predict for masked tokens

    print(f"Token {i + 1} (Masked Position):")

    # ==== Phoneme Prediction ====
    phoneme_probs = torch.softmax(phoneme_logits[i], dim=-1)
    top5_phoneme = torch.topk(phoneme_probs, 5)
    print("  Top 5 Phoneme Predictions:")
    for j in range(5):
        pid = top5_phoneme.indices[j].item()
        prob = top5_phoneme.values[j].item()
        print(f"    {id2phoneme[pid]}: {prob:.4f}")

    # ==== Prosody Prediction ====
    prosody_probs = torch.softmax(prosody_logits[i], dim=-1)
    top5_prosody = torch.topk(prosody_probs, min(5, prosody_probs.size(-1)))
    print("  Top Prosody Predictions:")
    for j in range(top5_prosody.indices.size(0)):
        pid = top5_prosody.indices[j].item()
        prob = top5_prosody.values[j].item()
        print(f"    Cluster {pid}: {prob:.4f}")
    
    print("-" * 40)


Top 5 predictions for masked tokens (SIL):

Token 2 (Masked Position):
  Top 5 Phoneme Predictions:
    AH: 0.4596
    AE: 0.0711
    OW: 0.0620
    AO: 0.0358
    NG: 0.0338
  Top Prosody Predictions:
    Cluster 1: 0.9979
    Cluster 0: 0.0017
    Cluster 2: 0.0005
----------------------------------------
Token 7 (Masked Position):
  Top 5 Phoneme Predictions:
    EH: 0.2366
    W: 0.2217
    IH: 0.0439
    B: 0.0438
    EY: 0.0413
  Top Prosody Predictions:
    Cluster 0: 0.9802
    Cluster 2: 0.0134
    Cluster 1: 0.0065
----------------------------------------
