In [15]:
# semantic_codec=raw_audio -> s-tokens

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
import numpy as np
from transformers import AutoProcessor, AutoModel, WhisperProcessor, WhisperForConditionalGeneration
import os
import glob
import librosa
import gc
import logging
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm


### Loading Data and Extracting Features

In [2]:
# Set up logging
logging.basicConfig(level=logging.INFO, filename='training.log', filemode='w')

# Set random seed for reproducibility
torch.manual_seed(42)

# Define the dataset class
class Arabic_Processed_audios(Dataset):
    def __init__(self, audio_path, max_length=160000):
        self.audio_path = audio_path
        self.audio_files = glob.glob(os.path.join(self.audio_path, "*.mp3"))
        self.max_length = max_length

    def __getitem__(self, idx):
        file = self.audio_files[idx]
        try:
            audio, _ = librosa.load(file, sr=16000)
            audio = librosa.util.normalize(audio)
            
            if len(audio) < self.max_length:
                audio = np.pad(audio, (0, self.max_length - len(audio)))
            else:
                audio = audio[:self.max_length]
            
            return audio, os.path.basename(file)  # Return filename for saving features
        except Exception as e:
            logging.error(f"Error loading {file}: {str(e)}")
            return None, None

    def __len__(self):
        return len(self.audio_files)

# Load audio files
Audio_files = Arabic_Processed_audios('phase2_data/subset_80k_audio')
print(f"Number of audio files: {len(Audio_files)}")

# Load Whisper processor and model
processor = WhisperProcessor.from_pretrained("openai/whisper-base", cache_dir='./processor_cache')
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base", cache_dir='./model_cache')
device = torch.device("cpu")  # Explicitly set to CPU since MPS is unsupported
model.to(device)

# Enable mixed precision if possible

def extract_whisper_features(model, audio, processor, layer=-1, max_length=1500):
    try:
        inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            encoder_outputs = model.get_encoder()(inputs["input_features"])
            features = encoder_outputs.last_hidden_state if layer == -1 else encoder_outputs.hidden_states[layer]
        # Pad or truncate to max_length
        if features.shape[1] > max_length:
            features = features[:, :max_length, :]
        elif features.shape[1] < max_length:
            padding = (0, 0, 0, max_length - features.shape[1])
            features = torch.nn.functional.pad(features, padding)
        return features.cpu()  # Move to CPU to save memory
    except Exception as e:
        logging.error(f"Error processing audio: {str(e)}")
        return None

# Create output directory for features
output_dir = "features_output"
os.makedirs(output_dir, exist_ok=True)

# Use DataLoader for batch processing
dataloader = DataLoader(Audio_files, batch_size=16, shuffle=False, num_workers=0)  # Adjust batch_size as needed

print("Processing Audio Files")
for batch in tqdm(dataloader, desc="Processing Audio Files"):
    audios, filenames = batch
    for audio, filename in zip(audios, filenames):
        if audio is None:
            continue  # Skip problematic files
        features = extract_whisper_features(model, audio.numpy(), processor, layer=-1)
        if features is not None:
            # Save features to disk
            feature_path = os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_features.npy")
            np.save(feature_path, features.numpy())
        
        # Clear memory
        del features, audio
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

print("Feature extraction complete. Features saved to:", output_dir)

Number of audio files: 45000
Processing Audio Files


Processing Audio Files: 100%|██████████| 2813/2813 [2:53:59<00:00,  3.71s/it]  

Feature extraction complete. Features saved to: features_output





In [2]:
class FeaturesDataset(Dataset):
    def __init__(self, features_dir):
        self.features_dir = features_dir
        self.feature_files = glob.glob(os.path.join(features_dir, "*_features.npy"))
        
    def __getitem__(self, idx):
        feature_file = self.feature_files[idx]
        features = np.load(feature_file)
        return torch.FloatTensor(features)
    
    def __len__(self):
        return len(self.feature_files)

In [3]:
# Load features dataset
features_dataset = FeaturesDataset('features_output')
print(f"Number of feature files: {len(features_dataset)}")

# Create train/test split
train_size = int(0.8 * len(features_dataset))
test_size = len(features_dataset) - train_size
train_dataset, test_dataset = random_split(features_dataset, [train_size, test_size])

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Number of feature files: 45000


##### Conv Block

In [4]:
class ConvNextBlock(nn.Module):
    def __init__(self, dim, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)  # Depthwise
        self.norm = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.act = nn.GELU()

    def forward(self, x):
        # x: (batch, dim, seq_len)
        residual = x
        x = self.conv(x)
        # Transpose for LayerNorm
        x = x.transpose(1, 2)  # (batch, seq_len, dim)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        # Transpose back
        x = x.transpose(1, 2)  # (batch, dim, seq_len)
        return x + residual

##### Vector Quantization Layer

In [5]:
# Vector Quantization Layer
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=8192, embedding_dim=8, commitment_cost=0.25):
        
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.register_buffer('ema_count', torch.zeros(num_embeddings))
        self.register_buffer('ema_weight', self.embeddings.clone())

    def forward(self, x):

        flat_x = x.reshape(-1, self.embedding_dim)
        distances = torch.cdist(flat_x, self.embeddings)
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embeddings[encoding_indices].reshape(x.shape)
        codebook_loss = F.mse_loss(quantized.detach(), x)
        commitment_loss = self.commitment_cost * F.mse_loss(quantized, x.detach())
        loss = codebook_loss + commitment_loss
        quantized = x + (quantized - x).detach()
        
        if self.training:
            with torch.no_grad():
                one_hot = F.one_hot(encoding_indices, self.num_embeddings).float()
                self.ema_count = 0.999 * self.ema_count + 0.001 * torch.sum(one_hot, dim=0)
                n = torch.sum(self.ema_count)
                self.ema_count = (self.ema_count + 1e-8) / (n + self.num_embeddings * 1e-8) * n
                dw = torch.matmul(one_hot.transpose(0, 1), flat_x)
                self.ema_weight = 0.999 * self.ema_weight + 0.001 * dw
                self.embeddings.data = (self.ema_weight / (self.ema_count.unsqueeze(-1) + 1e-8))
        
        return quantized, loss, encoding_indices

#### Semantic Codec model

In [7]:
# VQ-VAE Model
# Custom Lambda module for applying arbitrary functions (e.g., transpose)
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
class VQVAE(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=384, codebook_size=8192, codebook_dim=8):
        super().__init__()
        self.encoder = nn.Sequential(
                    nn.Conv1d(input_dim, hidden_dim, kernel_size=7, padding=3),
                    Lambda(lambda x: x.transpose(1, 2)),  # To (batch_size, sequence_length, hidden_dim)
                    nn.LayerNorm(hidden_dim),
                    Lambda(lambda x: x.transpose(1, 2)),  # Back to (batch_size, hidden_dim, sequence_length)
                    # Assume ConvNextBlock is defined and works with (batch_size, hidden_dim, sequence_length)
                    *[ConvNextBlock(hidden_dim) for _ in range(6)],
                    nn.Conv1d(hidden_dim, codebook_dim, kernel_size=1)
                )
        self.quantizer = VectorQuantizer(num_embeddings=codebook_size, embedding_dim=codebook_dim)
        self.decoder = nn.Sequential(
            nn.Conv1d(codebook_dim, hidden_dim, kernel_size=7, padding=3),
            *[ConvNextBlock(hidden_dim) for _ in range(6)],
            nn.Conv1d(hidden_dim, input_dim, kernel_size=1)
        )

    def forward(self, x):
            x = x.transpose(1, 2)  # From (batch_size, sequence_length, input_dim) to (batch_size, input_dim, sequence_length)
            z = self.encoder(x)  # Output: (batch_size, codebook_dim, sequence_length)
            z = z.transpose(1, 2)  # To (batch_size, sequence_length, codebook_dim) for quantizer
            quantized, vq_loss, indices = self.quantizer(z)  # quantized: (batch_size, sequence_length, codebook_dim)
            quantized = quantized.transpose(1, 2)  # To (batch_size, codebook_dim, sequence_length)
            recon = self.decoder(quantized)  # Input to decoder: (batch_size, codebook_dim, sequence_length)
            return recon, quantized, vq_loss, indices

#### Dataloader

In [8]:
# Load features dataset
features_dataset = FeaturesDataset('features_output')
print(f"Number of feature files: {len(features_dataset)}")

# Create train/test split
train_size = int(0.8 * len(features_dataset))
test_size = len(features_dataset) - train_size
train_dataset, test_dataset = random_split(features_dataset, [train_size, test_size])

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Number of feature files: 45000


In [10]:
# Initialize VQVAE model
device='mps'
model = VQVAE(input_dim=512, hidden_dim=384, codebook_size=8192, codebook_dim=8)
model = model.to(device)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#### Training and evaluation of semantic codec

In [24]:
def train_model(model, train_dataloader, test_dataloader, num_epochs=10):   
    model.train()
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        epoch_loss = 0
        epoch_recon_loss = 0
        epoch_vq_loss = 0
        
        for batch in tqdm(train_dataloader, desc="Training", leave=False):
            features = batch[0].to(device)
            recon, quantized, vq_loss, indices = model(features)
            recon = recon.transpose(1, 2)  # Back to (batch_size, 1500, 512)
            # Normalize the loss by batch size
            recon_loss = torch.nn.functional.mse_loss(recon, batch[0].to(device), reduction='mean')
            loss = recon_loss + vq_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Track individual losses
            epoch_recon_loss += recon_loss.item()
            epoch_vq_loss += vq_loss.item()
            epoch_loss += loss.item()

        # Print epoch statistics once per epoch
        avg_recon_loss = epoch_recon_loss / len(train_dataloader)
        avg_vq_loss = epoch_vq_loss / len(train_dataloader)
        avg_total_loss = epoch_loss / len(train_dataloader)
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"  Training - Recon Loss: {avg_recon_loss:.4f}, VQ Loss: {avg_vq_loss:.4f}, Total Loss: {avg_total_loss:.4f}")

        # Testing loop
        model.eval()
        with torch.no_grad():
            test_recon_loss = 0
            test_vq_loss = 0
            test_total_loss = 0
            
            for batch in tqdm(test_dataloader, desc="Testing"):
                features = batch[0].to(device)
                recon, quantized, vq_loss, indices = model(features)
                recon = recon.transpose(1, 2)
                # Use the same loss calculation as training
                recon_loss = torch.nn.functional.mse_loss(recon, batch[0].to(device), reduction='mean')
                total_loss = recon_loss + vq_loss
                
                test_recon_loss += recon_loss.item()
                test_vq_loss += vq_loss.item()
                test_total_loss += total_loss.item()
            
            # Print test statistics
            avg_test_recon_loss = test_recon_loss / len(test_dataloader)
            avg_test_vq_loss = test_vq_loss / len(test_dataloader)
            avg_test_total_loss = test_total_loss / len(test_dataloader)
            print(f"  Testing - Recon Loss: {avg_test_recon_loss:.4f}, VQ Loss: {avg_test_vq_loss:.4f}, Total Loss: {avg_test_total_loss:.4f}\n")
    
    return model

In [25]:
# Start training
model = train_model(model, train_dataloader, test_dataloader, num_epochs=15)

# Save the trained model
torch.save(model.state_dict(), 'vqvae_model.pth')

Epochs:   0%|          | 0/15 [00:00<?, ?it/s]


Epoch 1/15
  Training - Recon Loss: 2.5877, VQ Loss: 0.4990, Total Loss: 3.0867


Testing: 100%|██████████| 563/563 [00:22<00:00, 24.80it/s]
Epochs:   7%|▋         | 1/15 [03:26<48:15, 206.82s/it]

  Testing - Recon Loss: 1.4096, VQ Loss: 0.2361, Total Loss: 1.6457






Epoch 2/15
  Training - Recon Loss: 2.6228, VQ Loss: 0.2166, Total Loss: 2.8394


Testing: 100%|██████████| 563/563 [00:22<00:00, 24.69it/s]
Epochs:  13%|█▎        | 2/15 [06:41<43:15, 199.67s/it]

  Testing - Recon Loss: 1.8586, VQ Loss: 0.1596, Total Loss: 2.0182






Epoch 3/15
  Training - Recon Loss: 2.1301, VQ Loss: 0.1619, Total Loss: 2.2920


Testing: 100%|██████████| 563/563 [00:23<00:00, 24.01it/s]
Epochs:  20%|██        | 3/15 [09:58<39:41, 198.50s/it]

  Testing - Recon Loss: 1.4287, VQ Loss: 0.1768, Total Loss: 1.6055






Epoch 4/15
  Training - Recon Loss: 1.7337, VQ Loss: 0.1154, Total Loss: 1.8491


Testing: 100%|██████████| 563/563 [00:23<00:00, 24.07it/s]
Epochs:  27%|██▋       | 4/15 [13:16<36:20, 198.21s/it]

  Testing - Recon Loss: 1.7483, VQ Loss: 0.1132, Total Loss: 1.8615






Epoch 5/15
  Training - Recon Loss: 1.4932, VQ Loss: 0.0950, Total Loss: 1.5883


Testing: 100%|██████████| 563/563 [00:23<00:00, 24.38it/s]
Epochs:  33%|███▎      | 5/15 [16:33<32:56, 197.66s/it]

  Testing - Recon Loss: 1.3937, VQ Loss: 0.0913, Total Loss: 1.4850






Epoch 6/15
  Training - Recon Loss: 1.3745, VQ Loss: 0.1139, Total Loss: 1.4885


Testing: 100%|██████████| 563/563 [00:22<00:00, 24.62it/s]
Epochs:  40%|████      | 6/15 [19:49<29:34, 197.15s/it]

  Testing - Recon Loss: 1.3656, VQ Loss: 0.0275, Total Loss: 1.3931






Epoch 7/15
  Training - Recon Loss: 99.6538, VQ Loss: 434.5145, Total Loss: 534.1683


Testing: 100%|██████████| 563/563 [00:23<00:00, 24.09it/s]
Epochs:  47%|████▋     | 7/15 [23:04<26:13, 196.69s/it]

  Testing - Recon Loss: 1.4037, VQ Loss: 0.9941, Total Loss: 2.3978






Epoch 8/15
  Training - Recon Loss: 1.3277, VQ Loss: 0.6148, Total Loss: 1.9425


Testing: 100%|██████████| 563/563 [00:23<00:00, 24.04it/s]
Epochs:  53%|█████▎    | 8/15 [26:22<22:58, 196.89s/it]

  Testing - Recon Loss: 1.2825, VQ Loss: 0.3829, Total Loss: 1.6653






Epoch 9/15
  Training - Recon Loss: 1.7769, VQ Loss: 0.3563, Total Loss: 2.1332


Testing: 100%|██████████| 563/563 [00:23<00:00, 23.81it/s]
Epochs:  60%|██████    | 9/15 [29:40<19:43, 197.31s/it]

  Testing - Recon Loss: 1.1900, VQ Loss: 0.2856, Total Loss: 1.4756






Epoch 10/15
  Training - Recon Loss: 2.0055, VQ Loss: 0.2589, Total Loss: 2.2645


Testing: 100%|██████████| 563/563 [00:23<00:00, 23.77it/s]
Epochs:  67%|██████▋   | 10/15 [32:59<16:28, 197.70s/it]

  Testing - Recon Loss: 1.1931, VQ Loss: 0.1805, Total Loss: 1.3736






Epoch 11/15
  Training - Recon Loss: 1.5074, VQ Loss: 0.2493, Total Loss: 1.7567


Testing: 100%|██████████| 563/563 [00:23<00:00, 23.80it/s]
Epochs:  73%|███████▎  | 11/15 [36:17<13:11, 197.90s/it]

  Testing - Recon Loss: 1.2073, VQ Loss: 0.1684, Total Loss: 1.3757






Epoch 12/15
  Training - Recon Loss: 418.8888, VQ Loss: 23665.5423, Total Loss: 24084.4307


Testing: 100%|██████████| 563/563 [00:23<00:00, 24.16it/s]
Epochs:  80%|████████  | 12/15 [39:33<09:52, 197.39s/it]

  Testing - Recon Loss: 1.4731, VQ Loss: 32.0831, Total Loss: 33.5562






Epoch 13/15
  Training - Recon Loss: 1.3003, VQ Loss: 5.3324, Total Loss: 6.6327


Testing: 100%|██████████| 563/563 [00:23<00:00, 23.69it/s]
Epochs:  87%|████████▋ | 13/15 [42:51<06:35, 197.61s/it]

  Testing - Recon Loss: 1.2372, VQ Loss: 2.3025, Total Loss: 3.5397






Epoch 14/15
  Training - Recon Loss: 1.8435, VQ Loss: 1.5969, Total Loss: 3.4404


Testing: 100%|██████████| 563/563 [00:23<00:00, 23.86it/s]
Epochs:  93%|█████████▎| 14/15 [46:10<03:17, 197.84s/it]

  Testing - Recon Loss: 1.2850, VQ Loss: 0.9127, Total Loss: 2.1977






Epoch 15/15
  Training - Recon Loss: 3.1348, VQ Loss: 0.7053, Total Loss: 3.8400


Testing: 100%|██████████| 563/563 [00:23<00:00, 23.82it/s]
Epochs: 100%|██████████| 15/15 [49:28<00:00, 197.92s/it]


  Testing - Recon Loss: 18.3410, VQ Loss: 0.7024, Total Loss: 19.0434



In [28]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# Assuming you have a validation dataset (val_dataset) and corresponding dataloader (val_dataloader)
# If not, you'll need to create one.  This example assumes you have one.

def evaluate_model(model, dataloader, device):
    model.eval()  # Set the model to evaluation mode
    total_recon_loss = 0.0
    total_vq_loss = 0.0
    total_loss = 0.0
    num_batches = len(dataloader)

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for batch in tqdm(dataloader, desc="Evaluating"):
            features = batch[0].to(device)  # Assuming your data loader returns a list, and the first element is your feature tensor
            
            recon, quantized, vq_loss, indices = model(features)
            recon = recon.transpose(1, 2)
            recon_loss = torch.nn.functional.mse_loss(recon, features, reduction='mean')
            loss = recon_loss + vq_loss

            total_recon_loss += recon_loss.item()
            total_vq_loss += vq_loss.item()
            total_loss += loss.item()

    avg_recon_loss = total_recon_loss / num_batches
    avg_vq_loss = total_vq_loss / num_batches
    avg_total_loss = total_loss / num_batches

    print(f"Evaluation Results:")
    print(f"  Reconstruction Loss: {avg_recon_loss:.4f}")
    print(f"  VQ Loss: {avg_vq_loss:.4f}")
    print(f"  Total Loss: {avg_total_loss:.4f}")

    return avg_recon_loss, avg_vq_loss, avg_total_loss

# Example usage (assuming you have a validation dataloader named 'val_dataloader')
# Make sure your model is on the correct device before evaluating
device = 'mps'
model = model.to(device)  # Ensure model is on the correct device

recon_loss, vq_loss, total_loss = evaluate_model(model, test_dataloader, device)

# You can also evaluate on the training data to check for overfitting:
train_recon_loss, train_vq_loss, train_total_loss = evaluate_model(model, train_dataloader, device)



Evaluating: 100%|██████████| 563/563 [00:23<00:00, 23.92it/s]


Evaluation Results:
  Reconstruction Loss: 18.3410
  VQ Loss: 0.7024
  Total Loss: 19.0434


Evaluating: 100%|██████████| 2250/2250 [01:33<00:00, 24.13it/s]

Evaluation Results:
  Reconstruction Loss: 18.3371
  VQ Loss: 0.7100
  Total Loss: 19.0471
Training Evaluation:
  Reconstruction Loss: 18.3371
  VQ Loss: 0.7100
  Total Loss: 19.0471





In [26]:
torch.save(model.state_dict(), 'semantic_codec_final.pth') #saved with state

### T2s

In [None]:
class MaskGCTT2S(nn.Module):
    def __init__(
        self,
        vocab_size,
        semantic_vocab_size=8192,
        n_layers=16,
        d_model=1024,
        d_ff=4096,
        n_heads=16,
        max_seq_len=2048,
        dropout=0.1,
    ):
        super().__init__()
        # Text embeddings
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        
        # Semantic token embeddings (for masked tokens)
        self.semantic_embeddings = nn.Embedding(semantic_vocab_size, d_model)
        
        # Position embeddings (RoPE)
        self.pos_encoding = RotaryPositionalEmbedding(d_model // n_heads, max_seq_len)
        
        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(
                d_model=d_model, 
                d_ff=d_ff, 
                n_heads=n_heads,
                dropout=dropout,
                bidirectional=True,  # Important: use bidirectional attention
                with_rope=True
            ) for _ in range(n_layers)
        ])
        
        # Output projection
        self.output_proj = nn.Linear(d_model, semantic_vocab_size)
        
        # Adaptive RMSNorm
        self.norm = AdaptiveRMSNorm(d_model)
        
    def forward(self, text_tokens, semantic_tokens=None, masks=None, timestep=None):
        """
        Forward pass with masked semantic tokens
        Args:
            text_tokens: Input text tokens [B, T_text]
            semantic_tokens: Target semantic tokens [B, T_semantic]
            masks: Masking tensor for semantic tokens [B, T_semantic]
            timestep: Current diffusion timestep
        """
        # Get batch size
        batch_size = text_tokens.size(0)
        
        # Text embeddings
        text_emb = self.token_embeddings(text_tokens)
        
        if semantic_tokens is not None:
            # Semantic embeddings
            semantic_emb = self.semantic_embeddings(semantic_tokens)
            
            # Apply masking
            if masks is not None:
                # Replace masked tokens with learned mask embedding
                semantic_emb = semantic_emb * (1 - masks.unsqueeze(-1))
            
            # Concatenate text and semantic embeddings
            x = torch.cat([text_emb, semantic_emb], dim=1)
        else:
            x = text_emb
            
        # Process through transformer layers
        for layer in self.layers:
            x = layer(x, timestep=timestep)
            
        # Normalize output
        x = self.norm(x, timestep)
        
        # Get only the semantic part predictions
        if semantic_tokens is not None:
            semantic_len = semantic_tokens.size(1)
            semantic_preds = x[:, -semantic_len:]
        else:
            semantic_preds = x[:, text_tokens.size(1):]
            
        # Project to semantic vocabulary
        logits = self.output_proj(semantic_preds)
        
        return logits


In [None]:
class DurationPredictor(nn.Module):
    def __init__(
        self,
        vocab_size,
        n_layers=12,
        d_model=768,
        n_heads=12,
        max_seq_len=2048,
    ):
        super().__init__()
        # Similar architecture to main model but smaller
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = RotaryPositionalEmbedding(d_model // n_heads, max_seq_len)
        
        # Transformer layers with bidirectional attention
        self.layers = nn.ModuleList([
            TransformerLayer(
                d_model=d_model, 
                d_ff=d_model*4, 
                n_heads=n_heads,
                bidirectional=True,
                with_rope=True
            ) for _ in range(n_layers)
        ])
        
        # Output projection for duration prediction
        self.output_proj = nn.Linear(d_model, 1)
        
    def forward(self, phoneme_tokens, prompt_phonemes=None, prompt_durations=None, timestep=None):
        """Flow matching based duration prediction"""
        # Implementation of flow matching for duration prediction
        # See Section A.5 in the paper


In [None]:
def train_t2s_model(model, data_loader, optimizer, device, mask_ratio=0.75):
    model.train()
    total_loss = 0
    
    for batch in data_loader:
        # Unpack batch
        text_tokens = batch["text_tokens"].to(device)
        semantic_tokens = batch["semantic_tokens"].to(device)
        
        # Create random masks for semantic tokens
        masks = torch.bernoulli(
            torch.ones_like(semantic_tokens) * mask_ratio
        ).to(device)
        
        # Random timestep for adaptive layer norm
        timestep = torch.rand(text_tokens.size(0), device=device)
        
        # Forward pass
        logits = model(text_tokens, semantic_tokens, masks, timestep)
        
        # Compute loss only on masked tokens
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            semantic_tokens.view(-1),
            reduction="none"
        )
        
        # Apply masking to loss
        loss = (loss * masks.view(-1)).sum() / (masks.sum() + 1e-8)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(data_loader)


In [None]:
def generate_semantic_tokens(
    model, 
    text_tokens, 
    duration_predictor,
    num_steps=25,  # Paper shows 25 steps is optimal
    device="cuda",
    guidance_scale=2.5,
    rescale_weight=0.75,
):
    model.eval()
    batch_size = text_tokens.size(0)
    
    # Get text length
    text_len = text_tokens.size(1)
    
    # Predict duration using duration predictor
    durations = duration_predictor(text_tokens)
    total_duration = durations.sum().int().item()
    
    # Initialize semantic tokens randomly
    semantic_tokens = torch.randint(
        0, model.semantic_vocab_size, (batch_size, total_duration),
        device=device
    )
    
    # Initialize with all tokens masked
    masks = torch.ones_like(semantic_tokens, device=device)
    
    # Progressive generation with 25 inference steps
    for step in range(num_steps):
        # Forward pass with current state
        with torch.no_grad():
            # Get conditional output
            logits_cond = model(text_tokens, semantic_tokens, masks)
            
            # Classifier-free guidance: Get unconditional output
            if guidance_scale > 1.0:
                # Drop prompt with probability 0.15 during training
                logits_uncond = model(torch.zeros_like(text_tokens), semantic_tokens, masks)
                
                # Apply classifier-free guidance
                logits = logits_cond + guidance_scale * (logits_cond - logits_uncond)
                
                # Apply rescaling
                std_cond = torch.std(logits_cond, dim=-1, keepdim=True)
                std_guided = torch.std(logits, dim=-1, keepdim=True)
                logits_rescaled = logits * (std_cond / std_guided)
                
                # Final logits
                logits = rescale_weight * logits_rescaled + (1 - rescale_weight) * logits
            else:
                logits = logits_cond
        
        # Get probabilities
        probs = F.softmax(logits, dim=-1)
        
        # Sample new tokens
        new_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(batch_size, -1)
        
        # Update semantic tokens at masked positions
        semantic_tokens = semantic_tokens * (1 - masks) + new_tokens * masks
        
        # Update masks for next iteration (progressively unmask)
        if step < num_steps - 1:
            # Create new random masks for remaining steps
            unmasked_ratio = (step + 1) / num_steps
            masks = torch.bernoulli(
                torch.ones_like(semantic_tokens) * (1 - unmasked_ratio)
            ).to(device)
    
    return semantic_tokens


In [None]:
def calculate_metrics(generated_speech, reference_speech, asr_model):
    """Calculate SIM and WER metrics"""
    # SIM: Speaker similarity using embeddings
    sim_score = calculate_similarity(generated_speech, reference_speech)
    
    # WER: Word Error Rate using ASR
    generated_text = asr_model.transcribe(generated_speech)
    reference_text = asr_model.transcribe(reference_speech)
    wer_score = calculate_wer(generated_text, reference_text)
    
    return sim_score, wer_score

In [None]:
text_tokenizer = create_tokenizer()  # G2P for English, BPE+jieba+pypinyin for Chinese
semantic_codec = SemanticCodec().to(device)
t2s_model = MaskGCTT2S(...).to(device)
duration_predictor = DurationPredictor(...).to(device)

# Tokenize input text
text = "Hello, this is a test."
text_tokens = text_tokenizer.tokenize(text)

# Generate semantic tokens
semantic_tokens = generate_semantic_tokens(
    t2s_model, 
    text_tokens, 
    duration_predictor,
    num_steps=25  # As recommended in paper
)
