In [21]:
import torch
from torch.utils.data import Dataset, DataLoader
import random
from transformers import AdamW
import torch.nn as nn

from transformers import *
from tokenizers import *

from bertPhoneme import BertEmbeddingsV2, BertModelV2, BertForMaskedLMV2, BertConfigV2, MaskedLMWithProsodyOutput

In [None]:
# Define your phoneme-to-ID mapping
phoneme_vocab = { "AA": 0, "AE": 1, "AH": 2, "AO": 3, "AW": 4, "AY": 5, 
                  "B": 6, "CH": 7, "D": 8, "DH": 9, "EH": 10, "ER": 11, "EY": 12, 
                  "F": 13, "G": 14, "H": 15, "IH": 16, "IY": 17, "JH": 18, "K": 19, 
                  "L": 20, "M": 21, "N": 22, "NG": 23, "OW": 24, "OY": 25, "P": 26, 
                  "R": 27, "S": 28, "SH": 29, "T": 30, "TH": 31, "UH": 32, "UW": 33, 
                  "V": 34, "W": 35, "Y": 36, "Z": 37, "ZH": 38, "PAUSE": 39, "SIL": 40 }

phoneme_vocab_size = len(phoneme_vocab)  # e.g., 41 phonemes
mask_token_id = phoneme_vocab["SIL"]  # Use SIL as [MASK]

# Example dataset
dataset = [
    (["DH", "AH", "S", "IH", "Z", "AH", "T", "EH", "S", "T"], [0, 1, 1, 2, 2, 1, 1, 0, 1, 1]),
    (["B", "AH", "T", "ER", "IH", "S", "TH", "AA", "N"], [2, 2, 1, 1, 1, 2, 0, 1, 1]),
    (["DH", "AH", "ER", "IH", "Z", "AH", "T", "EH", "S", "T"], [0, 1, 0, 2, 2, 1, 1, 0, 1, 1]),
    (["S", "P", "IY", "CH", "IH", "Z", "K", "L", "EH", "R"], [1, 1, 2, 2, 1, 0, 0, 1, 1, 1]),
    (["TH", "AE", "V", "Y", "UW", "S", "T", "IH", "CH"], [0, 0, 1, 1, 2, 2, 1, 1, 1]),
    (["K", "AO", "L", "D", "S", "T", "AA", "R", "T", "IH", "NG"], [1, 1, 2, 2, 0, 1, 1, 2, 1, 1, 1]),
    (["W", "EH", "N", "D", "IH", "Z", "DH", "AH", "K", "EY", "S"], [2, 1, 1, 2, 1, 0, 0, 1, 1, 1, 1]),
    (["N", "OW", "Y", "UW", "K", "AE", "N", "S", "T", "AA", "P", "M", "IY"], [0, 1, 1, 2, 2, 1, 0, 0, 1, 1, 2, 2, 1]),
    (["IH", "T", "W", "AA", "Z", "AH", "K", "L", "EH", "R", "D", "EY"], [1, 1, 2, 2, 1, 0, 0, 1, 1, 2, 2, 1]),
    (["TH", "AW", "K", "AE", "N", "W", "IY", "G", "IH", "V", "DH", "AH", "CH", "AE", "N", "S"], [0, 0, 1, 1, 2, 2, 1, 0, 1, 1, 2, 2, 1, 0, 1, 1]),
    (["AY", "W", "AA", "N", "T", "T", "UW", "G", "OW"], [1, 1, 2, 2, 0, 0, 1, 1, 1]),
    (["SH", "IY", "S", "EH", "D", "DH", "AH", "T", "UW", "TH"], [1, 2, 2, 1, 1, 0, 0, 2, 2, 1]),
    (["Y", "UW", "K", "AE", "N", "N", "AA", "T", "B", "IY", "S", "IH", "R", "IY", "UH", "S"], [1, 1, 2, 2, 2, 1, 0, 0, 0, 1, 1, 2, 1, 1, 2, 2]),
    (["K", "AE", "N", "Y", "UW", "R", "IY", "P", "IY", "T", "DH", "AE", "T"], [1, 1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 1, 1]),
    (["B", "IH", "G", "CH", "EY", "N", "JH", "IH", "Z", "K", "AH", "M", "IH", "NG"], [0, 1, 1, 2, 2, 2, 1, 1, 0, 0, 0, 1, 1, 1]),
    (["DH", "AH", "B", "EH", "S", "T", "W", "EY", "T", "T", "UW", "D", "UW", "IH", "T"], [0, 1, 1, 2, 2, 1, 0, 0, 1, 1, 2, 2, 1, 1, 1]),
    (["AY", "K", "AE", "N", "TH", "EH", "L", "P", "Y", "UW"], [1, 2, 2, 1, 1, 0, 0, 1, 1, 2]),
    (["IH", "IY", "D", "IH", "D", "N", "AA", "T", "K", "AA", "L"], [0, 1, 1, 2, 2, 2, 1, 1, 0, 0, 1]),
    (["IH", "F", "Y", "UW", "K", "AE", "N", "R", "IY", "D", "DH", "IH", "S"], [1, 1, 2, 2, 2, 1, 0, 0, 1, 1, 2, 2, 1]),
    (["AY", "IH", "OW", "P", "Y", "UW", "L", "AY", "K", "IH", "T"], [1, 1, 2, 2, 1, 0, 0, 1, 1, 2, 2])
]


class PhonemeProsodyDataset(Dataset):
    def __init__(self, data, vocab, mask_prob=0.25, max_length=20):
        self.data = data
        self.vocab = vocab
        self.mask_prob = mask_prob
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        phonemes, prosody_ids = self.data[idx]

        # Convert phonemes to IDs
        input_ids = [self.vocab[p] for p in phonemes]
        prosody_ids = prosody_ids[:self.max_length]

        # Apply MLM (random masking)
        labels = input_ids.copy()
        for i in range(len(input_ids)):
            if random.random() < self.mask_prob:
                labels[i] = input_ids[i]  # Keep the original label
                input_ids[i] = mask_token_id  # Replace with mask token

        # Padding
        pad_length = self.max_length - len(input_ids)
        input_ids.extend([0] * pad_length)
        labels.extend([-100] * pad_length)  # -100 for ignored loss computation
        prosody_ids.extend([0] * pad_length)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "prosody_ids": torch.tensor(prosody_ids, dtype=torch.long),
        }

# Create DataLoader
train_dataset = PhonemeProsodyDataset(dataset, phoneme_vocab)

In [None]:
train_dataset[0]

In [None]:
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

In [None]:
for batch in train_loader:
    print(batch)
    break

In [None]:
vocab_size = len(phoneme_vocab.keys())
max_length = 20


model_config = BertConfigV2(
    vocab_size=vocab_size,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=512,
    max_position_embeddings=max_length,
    prosody_cluster_size=3
)

model = BertForMaskedLMV2(config=model_config)

# BERT-Base	768	12	12	3072
# BERT-Small 512	4	8	2048
# BERT-Mini	256	4	4	1024
# BERT-Tiny	128	2	2	512

In [None]:
model

In [89]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=3e-5)

# Define loss functions
mlm_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # MLM Loss
prosody_loss_fn = nn.CrossEntropyLoss()  # Prosody Classification Loss

num_epochs = 10
model.train()

for epoch in range(num_epochs):
    total_loss = 0

    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        prosody_ids = batch["prosody_ids"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, prosody_ids=prosody_ids)

        # Compute losses
        mlm_loss = mlm_loss_fn(outputs.logits.view(-1, phoneme_vocab_size), labels.view(-1))
        prosody_loss = prosody_loss_fn(outputs.prosody_logits.view(-1, model.config.prosody_cluster_size), prosody_ids.view(-1))

        # Combine losses
        total_batch_loss = mlm_loss + prosody_loss
        total_batch_loss.backward()
        optimizer.step()

        total_loss += total_batch_loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader)}")

Epoch 1: Loss = 4.800920820236206
Epoch 2: Loss = 4.656748056411743
Epoch 3: Loss = 4.541573143005371
Epoch 4: Loss = 4.443735122680664
Epoch 5: Loss = 4.368757009506226
Epoch 6: Loss = 4.298747396469116
Epoch 7: Loss = 4.222767400741577
Epoch 8: Loss = 4.152597379684448
Epoch 9: Loss = 4.097935938835144
Epoch 10: Loss = 4.034598350524902


In [90]:
test_dataset = [
    (["DH", "S", "IH", "SIL", "AH", "T", "EH", "S", "T"], [0, 1, 2, 2, 1, 1, 0, 1, 1])]

In [91]:
test_dataset = PhonemeProsodyDataset(test_dataset, phoneme_vocab)
test_dataset[0]

{'input_ids': tensor([ 9, 28, 16, 37,  2, 40, 10, 28, 30,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]),
 'labels': tensor([   9,   28,   16,   37,    2,   30,   10,   28,   30, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100]),
 'prosody_ids': tensor([0, 1, 2, 2, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

In [92]:
model.eval()

input_ids = test_dataset[0]["input_ids"].unsqueeze(0).to(device)     # shape: [1, seq_len]
labels = test_dataset[0]["labels"].unsqueeze(0).to(device)           # shape: [1, seq_len]
prosody_ids = test_dataset[0]["prosody_ids"].unsqueeze(0).to(device) # shape: [1, seq_len]

with torch.no_grad():
    outputs = model(input_ids=input_ids, prosody_ids=prosody_ids)

In [93]:
# Assume you have access to phoneme_vocab and its inverse mapping
id2phoneme = {v: k for k, v in phoneme_vocab.items()}
phoneme_logits = outputs.logits[0]  # shape: [seq_len, vocab_size]
prosody_logits = outputs.prosody_logits[0]  # shape: [seq_len, prosody_cluster_size]

print("Top 5 predictions per token:\n")

for i in range(input_ids.shape[1]):  # loop over tokens in sequence
    print(f"Token {i + 1} (Input Phoneme: {id2phoneme[input_ids[0, i].item()]}):")

    # ==== Phoneme Prediction ====
    phoneme_probs = torch.softmax(phoneme_logits[i], dim=-1)
    top5_phoneme = torch.topk(phoneme_probs, 5)
    print("  Top 5 Phoneme Predictions:")
    for j in range(5):
        pid = top5_phoneme.indices[j].item()
        prob = top5_phoneme.values[j].item()
        print(f"    {id2phoneme[pid]}: {prob:.4f}")

    # ==== Prosody Prediction ====
    prosody_probs = torch.softmax(prosody_logits[i], dim=-1)
    top5_prosody = torch.topk(prosody_probs, min(5, prosody_probs.size(-1)))
    print("  Top Prosody Predictions:")
    for j in range(top5_prosody.indices.size(0)):
        pid = top5_prosody.indices[j].item()
        prob = top5_prosody.values[j].item()
        print(f"    Cluster {pid}: {prob:.4f}")
    
    print("-" * 40)


Top 5 predictions per token:

Token 1 (Input Phoneme: DH):
  Top 5 Phoneme Predictions:
    T: 0.0503
    S: 0.0421
    IH: 0.0413
    AE: 0.0385
    IY: 0.0364
  Top Prosody Predictions:
    Cluster 0: 0.4575
    Cluster 1: 0.3541
    Cluster 2: 0.1884
----------------------------------------
Token 2 (Input Phoneme: S):
  Top 5 Phoneme Predictions:
    S: 0.0527
    T: 0.0488
    N: 0.0478
    AH: 0.0431
    IY: 0.0377
  Top Prosody Predictions:
    Cluster 1: 0.4344
    Cluster 0: 0.3952
    Cluster 2: 0.1704
----------------------------------------
Token 3 (Input Phoneme: SIL):
  Top 5 Phoneme Predictions:
    T: 0.0445
    S: 0.0441
    K: 0.0408
    UW: 0.0389
    IH: 0.0382
  Top Prosody Predictions:
    Cluster 0: 0.5339
    Cluster 1: 0.2542
    Cluster 2: 0.2119
----------------------------------------
Token 4 (Input Phoneme: Z):
  Top 5 Phoneme Predictions:
    IH: 0.0521
    N: 0.0435
    UW: 0.0421
    K: 0.0375
    T: 0.0361
  Top Prosody Predictions:
    Cluster 1: 0.4500