In [1]:
import shutil
shutil.copy("/kaggle/input/tts-dataset/Vocab.py", "/kaggle/working")
shutil.copy("/kaggle/input/tts-dataset/extract_semantics.py", "/kaggle/working")
shutil.copy("/kaggle/input/tts-dataset/load_audio_features.py", "/kaggle/working")
shutil.copy("/kaggle/input/tts-dataset/semantic_codec_final_20k_2.pth", "/kaggle/working")

import sys
sys.path.append('/kaggle/working')

## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List
from dataclasses import dataclass
from Vocab import Vocab
import pandas as pd
import librosa
from extract_semantics import load_semantic_extractor , extract_semantics
from load_audio_features import get_all_features

2025-05-25 15:45:03.627661: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748187904.106693      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748187904.241373      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Data Preparation

In [3]:
# 2. Input Processing
# The model processes three types of sequences concatenated together:
# [TEXT_TOKENS] + [SEMANTIC_PROMPT] + [SEMANTIC_TARGET]
# Sequence Structure

# Text Tokens: Arabic text converted to token IDs
# Semantic Prompt: A prefix of semantic tokens from reference audio
# Semantic Target: The full semantic token sequence (masked during training)

In [4]:
df=pd.read_csv("/kaggle/input/tts-dataset/45k_embeddings.csv")
df.head(3)

Unnamed: 0,audio_file,clean_text,text_embedding
0,processed_fJ2vuI_700.mp3,مش بس المجاميع والناس اللي ورا اللي,"[0.00993357878178358, 0.004596792161464691, -0..."
1,processed_htNK0t_9.mp3,والضغط والمذاكره ومشاريع التخرج وقرفها,"[0.06607607007026672, 0.04131579399108887, -0...."
2,processed_0rMASI_336.mp3,بتاعه دواء يعني السكيزوفرينيا او الفصام,"[0.021152175962924957, -0.040665335953235626, ..."


In [5]:
tokenizer=Vocab()
df['tokenized_prompts']=df['clean_text'].apply(tokenizer.tokenize)

In [6]:
features = get_all_features() 

Validating audio files: 100%|██████████| 8000/8000 [02:16<00:00, 58.53it/s]


✓ valid audio files found: 8000


preprocessor_config.json:   0%|          | 0.00/185k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/836k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.19k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.98k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/290M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/3.81k [00:00<?, ?B/s]

Extracting Whisper features: 100%|██████████| 4000/4000 [36:52<00:00,  1.81it/s]


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Transpose(nn.Module):
    def __init__(self, dim0=1, dim1=2):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)


class SemanticExtractor(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=384, codebook_size=8192, codebook_dim=8):
        super().__init__()
        # Only keep the encoder and quantizer parts
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=7, padding=3),
            Transpose(),                             # (B, 384, L) → (B, L, 384)
            nn.LayerNorm(hidden_dim),               # now normalizes the last dim
            Transpose(),                             # (B, L, 384) → (B, 384, L)
            *[ConvNextBlock(hidden_dim) for _ in range(6)],
            nn.Conv1d(hidden_dim, codebook_dim, kernel_size=1)
        )

        self.quantizer = VectorQuantizer(num_embeddings=codebook_size, embedding_dim=codebook_dim)

    def forward(self, x):
        # x shape: (batch_size, sequence_length, input_dim)
        x = x.transpose(1, 2)  # To (batch_size, input_dim, sequence_length)
        z = self.encoder(x)  # Get encoded features
        z = z.transpose(1, 2)  # To (batch_size, sequence_length, codebook_dim)
        _, _, indices = self.quantizer(z)  # Get semantic tokens
        return indices

class ConvNextBlock(nn.Module):
    def __init__(self, dim, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)  # Depthwise
        self.norm = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.act = nn.GELU()

    def forward(self, x):
        # x: (batch, dim, seq_len)
        residual = x
        x = self.conv(x)
        # Transpose for LayerNorm
        x = x.transpose(1, 2)  # (batch, seq_len, dim)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        # Transpose back
        x = x.transpose(1, 2)  # (batch, dim, seq_len)
        return x + residual

# Vector Quantization Layer
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings=8192, embedding_dim=8, commitment_cost=0.25):
        
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.register_buffer('ema_count', torch.zeros(num_embeddings))
        self.register_buffer('ema_weight', self.embeddings.clone())

    def forward(self, x):

        flat_x = x.reshape(-1, self.embedding_dim)
        distances = torch.cdist(flat_x, self.embeddings)
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embeddings[encoding_indices].reshape(x.shape)
        codebook_loss = F.mse_loss(quantized.detach(), x)
        commitment_loss = self.commitment_cost * F.mse_loss(quantized, x.detach())
        loss = codebook_loss + commitment_loss
        quantized = x + (quantized - x).detach()
        
        if self.training:
            with torch.no_grad():
                one_hot = F.one_hot(encoding_indices, self.num_embeddings).float()
                self.ema_count = 0.999 * self.ema_count + 0.001 * torch.sum(one_hot, dim=0)
                n = torch.sum(self.ema_count)
                self.ema_count = (self.ema_count + 1e-8) / (n + self.num_embeddings * 1e-8) * n
                dw = torch.matmul(one_hot.transpose(0, 1), flat_x)
                self.ema_weight = 0.999 * self.ema_weight + 0.001 * dw
                self.embeddings.data = (self.ema_weight / (self.ema_count.unsqueeze(-1) + 1e-8))
        
        return quantized, loss, encoding_indices

def load_semantic_extractor(model_path='/kaggle/working/semantic_codec_final_20k_2.pth', device='cuda'):
    print(device)
    model = SemanticExtractor(input_dim=512, hidden_dim=384, codebook_size=8192, codebook_dim=8)
    state_dict = torch.load(model_path, map_location=device)
 

    # Load encoder weights
    encoder_state = {k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")}
    model.encoder.load_state_dict(encoder_state, strict=False)  # <=== set strict=False

    # Load quantizer weights
    quantizer_state = {k.replace("quantizer.", ""): v for k, v in state_dict.items() if k.startswith("quantizer.")}
    model.quantizer.load_state_dict(quantizer_state, strict=False)  

    model.to(device)
    model.eval()
    return model


def extract_semantics(audio_features, model):
    """
    Extract semantic tokens from audio features
    Args:
        audio_features: Tensor of shape (batch_size, sequence_length, 512)
        model: Loaded SemanticExtractor model
    Returns:
        semantic_tokens: Tensor of shape (batch_size, sequence_length)
    """
    with torch.no_grad():
        semantic_tokens = model(audio_features)
    return semantic_tokens





In [None]:
# from extract_semantics import load_semantic_extractor , extract_semantics
model = load_semantic_extractor(device='cuda')
semantic_tokens = []     
audio_files=[]       # collect results
for name, feat in features:
    if len(feat.shape) == 2:
        feat = feat.unsqueeze(0)  # Add batch dimension if missing
    audio_files.append(name)
    toks = extract_semantics(feat.to('cuda'), model)   # <-- adds batch dim already (1,…)
    semantic_tokens.append(toks.squeeze()) #|Remove batch dimension before appending



In [None]:

df=df[df['audio_file'].isin(audio_files)]
print(df.shape)

In [None]:
semantic_tokens

## T2S Arch

In [None]:
@dataclass
class MaskGCTConfig:
    vocab_size_text: int = 40000  # Text vocabulary size for Arabic
    vocab_size_semantic: int = 1024  # Semantic token vocabulary size
    max_seq_len: int = 2048
    n_layers: int = 12
    n_heads: int = 4
    d_model: int = 512
    d_ff: int = 1408  # 2.75 * d_model for GLU
    dropout: float = 0.1
    eps: float = 1e-5
    theta: float = 10000.0  # RoPE theta
    max_time_steps: int = 1000  # For diffusion scheduling
    max_position_embeddings=1152



In [None]:
class AdaptiveRMSNorm(nn.Module):
    """Adaptive RMSNorm that accepts time step as condition"""
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))
        # Time conditioning MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model * 2)  # scale and shift
        )
        
    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        # time_emb: (batch, d_model)
        
        # Get time-dependent scale and shift
        time_out = self.time_mlp(time_emb)  # (batch, d_model * 2)
        scale, shift = time_out.chunk(2, dim=-1)  # Each: (batch, d_model)
        
        # Apply RMSNorm
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        
        # Apply time-dependent transformation
        # Expand scale and shift to match x dimensions
        scale = scale.unsqueeze(1)  # (batch, 1, d_model)
        shift = shift.unsqueeze(1)  # (batch, 1, d_model)
        
        return norm * self.weight * (1 + scale) + shift


In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)"""
    def __init__(self, d_model: int,max_seq_len=2048, theta: float = 10000.0):
        super().__init__()
        self.d_model = d_model
        self.theta = theta
        
        # Precompute frequencies
        inv_freq = 1.0 / (theta ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)
        
        # Precompute position encodings
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())
    
    def rotate_half(self, x):
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat([-x2, x1], dim=-1)
    
    def forward(self, q, k, seq_len):
        cos = self.cos_cached[:seq_len, :]
        sin = self.sin_cached[:seq_len, :]
        
        # Apply rotary embedding to queries and keys
        q_rot = q * cos + self.rotate_half(q) * sin
        k_rot = k * cos + self.rotate_half(k) * sin
        
        return q_rot, k_rot


In [None]:
class GatedLinearUnit(nn.Module):
    """Gated Linear Unit with GELU activation"""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
    
    def forward(self, x):
        gate = F.gelu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = self.d_model // self.n_heads
        
        assert self.head_dim * self.n_heads == self.d_model, "d_model must be divisible by n_heads"

        # Linear layers for q, k, v
        self.q_proj = nn.Linear(self.d_model, self.d_model)
        self.k_proj = nn.Linear(self.d_model, self.d_model)
        self.v_proj = nn.Linear(self.d_model, self.d_model)
        
        self.out_proj = nn.Linear(self.d_model, self.d_model)
        
        # Initialize Rotary Positional Embedding with longer max_seq_len
        self.rotary_emb = RotaryPositionalEmbedding(d_model=self.head_dim, max_seq_len=2048)

    def forward(self, x, attention_mask=None):
        batch_size, seq_len, _ = x.size()

        # Project to q, k, v
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        # Reshape for multi-head attention: (batch, seq_len, n_heads, head_dim)
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)  # (batch, heads, seq_len, head_dim)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Apply Rotary Positional Embedding to q and k
        # Note: RoPE expects (batch, heads, seq_len, head_dim)
        q, k = self.rotary_emb(q, k, seq_len)
        
        # Scaled dot-product attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (batch, heads, seq_len, seq_len)
        
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf'))
        
        attn_probs = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, v)  # (batch, heads, seq_len, head_dim)
        
        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Final output projection
        out = self.out_proj(attn_output)
        
        return out


In [None]:
class TransformerBlock(nn.Module):
    """Transformer block with bidirectional attention and GLU"""
    def __init__(self, config: MaskGCTConfig):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.feed_forward = GatedLinearUnit(config.d_model, config.d_ff)
        self.norm1 = AdaptiveRMSNorm(config.d_model, config.eps)
        self.norm2 = AdaptiveRMSNorm(config.d_model, config.eps)
        
    def forward(self, x, time_emb, attention_mask=None):
        # Pre-norm attention
        normed_x = self.norm1(x, time_emb)
        attn_out = self.attention(normed_x, attention_mask)
        x = x + attn_out
        
        # Pre-norm feed-forward
        normed_x = self.norm2(x, time_emb)
        ff_out = self.feed_forward(normed_x)
        x = x + ff_out
        
        return x

In [None]:
class TimeEmbedding(nn.Module):
    """Sinusoidal time embedding for diffusion steps"""
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        
    def forward(self, time_steps):
        half_dim = self.d_model // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=time_steps.device) * -emb)
        emb = time_steps[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb

In [32]:
class MaskGCT_T2S(nn.Module):
    """Text-to-Semantic MaskGCT Model"""
    def __init__(self, config: MaskGCTConfig):
        super().__init__()
        self.config = config
        
        # Embeddings
        self.text_embedding = nn.Embedding(config.vocab_size_text, config.d_model)
        self.semantic_embedding = nn.Embedding(config.vocab_size_semantic, config.d_model)
        
        # Time embedding for diffusion
        self.time_embedding = TimeEmbedding(config.d_model)
        self.time_mlp = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.SiLU(),
            nn.Linear(config.d_model, config.d_model)
        )
        
        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])
        
        # Output head
        self.output_norm = AdaptiveRMSNorm(config.d_model, config.eps)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size_semantic)
        
        # Special tokens
        self.mask_token_id = config.vocab_size_semantic - 1
        self.pad_token_id = 0
        
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def create_attention_mask(self, text_len, prompt_len, target_len):
        """Create attention mask for [text, prompt, target] sequence"""
        total_len = text_len + prompt_len + target_len
        mask = torch.ones(total_len, total_len)
        
        # Text can attend to itself
        mask[:text_len, :text_len] = 1
        
        # Prompt can attend to text and itself
        mask[text_len:text_len+prompt_len, :text_len+prompt_len] = 1
        
        # Target can attend to text, prompt, and itself (bidirectional)
        mask[text_len+prompt_len:, :] = 1
        
        return mask.unsqueeze(0)  # Add batch dimension
    
    def forward(self, 
                text_tokens: torch.Tensor,
                semantic_prompt: torch.Tensor,
                semantic_target: torch.Tensor,
                time_steps: torch.Tensor,
                mask_ratio: float = 0.15):
        """
        Forward pass for training
        
        Args:
            text_tokens: Text token sequence (batch, text_len)
            semantic_prompt: Prompt semantic tokens (batch, prompt_len)
            semantic_target: Target semantic tokens (batch, target_len)
            time_steps: Diffusion time steps (batch,)
            mask_ratio: Ratio of tokens to mask
        """
        batch_size = text_tokens.shape[0]
        text_len = text_tokens.shape[1]
        prompt_len = semantic_prompt.shape[1]
        target_len = semantic_target.shape[1]
        
        # Create masked target
        masked_target = semantic_target.clone()
        mask = torch.rand(batch_size, target_len) < mask_ratio
        masked_target[mask] = self.mask_token_id
        
        # Embed tokens
        text_emb = self.text_embedding(text_tokens)
        prompt_emb = self.semantic_embedding(semantic_prompt)
        target_emb = self.semantic_embedding(masked_target)
        
        # Concatenate sequences: [text, prompt, target]
        x = torch.cat([text_emb, prompt_emb, target_emb], dim=1)
        
        # Time embedding
        time_emb = self.time_embedding(time_steps)
        time_emb = self.time_mlp(time_emb)
        
        # Create attention mask
        attention_mask = self.create_attention_mask(text_len, prompt_len, target_len)
        attention_mask = attention_mask.to(x.device)
        
        # Apply transformer layers
        for layer in self.layers:
            x = layer(x, time_emb, attention_mask)
        
        # Apply output normalization and projection
        x = self.output_norm(x, time_emb)
        logits = self.output_proj(x)
        
        # Return only target logits
        target_logits = logits[:, text_len + prompt_len:, :]
        
        return target_logits
    
    def generate(self,
                 text_tokens: torch.Tensor,
                 semantic_prompt: torch.Tensor,
                 target_length: int,
                 num_steps: int = 20,
                 temperature: float = 1.0,
                 top_k: int = None,
                 top_p: float = None):
        """
        Generate semantic tokens given text and prompt
        
        Args:
            text_tokens: Text token sequence (batch, text_len)
            semantic_prompt: Prompt semantic tokens (batch, prompt_len)
            target_length: Length of target sequence to generate
            num_steps: Number of denoising steps
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
        """
        batch_size = text_tokens.shape[0]
        device = text_tokens.device
        
        # Initialize with all mask tokens
        semantic_target = torch.full(
            (batch_size, target_length), 
            self.mask_token_id, 
            device=device, 
            dtype=torch.long
        )
        
        # Iterative denoising
        for step in range(num_steps):
            # Current time step
            t = torch.full((batch_size,), step / num_steps * self.config.max_time_steps, device=device)
            
            # Forward pass
            with torch.no_grad():
                logits = self.forward_inference(text_tokens, semantic_prompt, semantic_target, t)
            
            # Apply temperature
            logits = logits / temperature
            
            # Apply top-k and top-p filtering
            if top_k is not None:
                top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)
                logits = torch.full_like(logits, float('-inf'))
                logits.scatter_(-1, top_k_indices, top_k_logits)
            
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                
                indices_to_remove = sorted_indices_to_remove.scatter(
                    -1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')
            
            # Sample tokens
            probs = F.softmax(logits, dim=-1)
            new_tokens = torch.multinomial(probs.view(-1, probs.shape[-1]), 1).view(batch_size, target_length)
            
            # Update masked positions
            mask_positions = (semantic_target == self.mask_token_id)
            
            # Schedule: unmask some tokens each step
            num_to_unmask = max(1, int(mask_positions.sum().item() * (1 - (step + 1) / num_steps)))
            
            # Choose positions to unmask based on confidence
            confidence = probs.max(dim=-1)[0]
            confidence[~mask_positions] = -1  # Don't consider already unmasked positions
            
            # Get top confident positions to unmask
            _, top_indices = torch.topk(confidence.view(batch_size, -1), num_to_unmask, dim=-1)
            
            # Update tokens
            for b in range(batch_size):
                for idx in top_indices[b]:
                    if mask_positions[b, idx]:
                        semantic_target[b, idx] = new_tokens[b, idx]
        
        return semantic_target
    
    def forward_inference(self, text_tokens, semantic_prompt, semantic_target, time_steps):
        """Forward pass for inference (no masking)"""
        batch_size = text_tokens.shape[0]
        text_len = text_tokens.shape[1]
        prompt_len = semantic_prompt.shape[1]
        target_len = semantic_target.shape[1]
        
        # Embed tokens
        text_emb = self.text_embedding(text_tokens)
        prompt_emb = self.semantic_embedding(semantic_prompt)
        target_emb = self.semantic_embedding(semantic_target)
        
        # Concatenate sequences
        x = torch.cat([text_emb, prompt_emb, target_emb], dim=1)
        
        # Time embedding
        time_emb = self.time_embedding(time_steps)
        time_emb = self.time_mlp(time_emb)
        
        # Create attention mask
        attention_mask = self.create_attention_mask(text_len, prompt_len, target_len)
        attention_mask = attention_mask.to(x.device)
        
        # Apply transformer layers
        for layer in self.layers:
            x = layer(x, time_emb, attention_mask)
        
        # Apply output normalization and projection
        x = self.output_norm(x, time_emb)
        logits = self.output_proj(x)
        
        # Return only target logits
        return logits[:, text_len + prompt_len:, :]



In [45]:

# Example usage and training setup
def create_model_config():
    """Create configuration for Arabic TTS"""
    return MaskGCTConfig(
        vocab_size_text=100,  # Adjust based on your Arabic tokenizer
        vocab_size_semantic=7500,  # Adjust based on your semantic codec
        max_seq_len=1024,
        n_layers=6,
        n_heads=2,
        d_model=512,
        d_ff=1408,
        dropout=0.1,
        eps=1e-5,
        theta=10000.0,
        max_time_steps=1000
)
    





## Dataset and Dataloaders

In [46]:
from torch.utils.data import Dataset

class TextSemanticDataset(Dataset):
    def __init__(self, df, semantic_tokens_list, max_semantic_len=1000, min_prompt_len=5, max_prompt_len=50):
        """
        df: pandas DataFrame with 'tokenized_prompts' column (list of tokens)
        semantic_tokens_list: list of semantic tokens tensors (squeezed 1D LongTensor)
        """
        assert len(df) == len(semantic_tokens_list), "Mismatch between text and semantic tokens length"

        self.text_tokens = df['tokenized_prompts'].tolist()
        self.semantic_tokens = semantic_tokens_list
        self.max_semantic_len = max_semantic_len
        self.min_prompt_len = min_prompt_len
        self.max_prompt_len = max_prompt_len

    def random_prefix(self, semantic_tokens):
        if semantic_tokens.shape[0] > self.max_semantic_len:
            semantic_tokens = semantic_tokens[:self.max_semantic_len]

        prompt_len = min(
            max(self.min_prompt_len, semantic_tokens.shape[0] // 4),
            min(self.max_prompt_len, semantic_tokens.shape[0] - 1)
        )

        semantic_prompt = semantic_tokens[:prompt_len]
        semantic_target = semantic_tokens

        max_target_len = 1024 - prompt_len
        semantic_prompt = semantic_tokens[:prompt_len]
        semantic_target = semantic_tokens[:max_target_len]

        
        return semantic_prompt, semantic_target

    def __len__(self):
        return len(self.text_tokens)

    def __getitem__(self, idx):
        text = torch.LongTensor(self.text_tokens[idx])
        
        name, semantic_tensor = self.semantic_tokens[idx]  # <-- Unpack tuple
        if semantic_tensor.dim() == 2:
            semantic_tensor = semantic_tensor.squeeze(0)

        semantic_prompt, semantic_target = self.random_prefix(semantic_tensor)

        return {
            'text_tokens': text,
            'semantic_prompt': semantic_prompt,
            'semantic_target': semantic_target,
        }


In [47]:
class TextSemanticDataset(Dataset):
    def __init__(self, df, semantic_tokens_list, max_semantic_len=1000, min_prompt_len=5, max_prompt_len=50):
        """
        df: pandas DataFrame with 'tokenized_prompts' column (list of tokens)
        semantic_tokens_list: list of semantic tokens tensors (squeezed 1D LongTensor)
        """
        assert len(df) == len(semantic_tokens_list), "Mismatch between text and semantic tokens length"
        
        self.text_tokens = df['tokenized_prompts'].tolist()
        self.semantic_tokens = semantic_tokens_list  # This should be a list of tensors, not tuples
        self.max_semantic_len = max_semantic_len
        self.min_prompt_len = min_prompt_len
        self.max_prompt_len = max_prompt_len

    def __len__(self):
        return len(self.text_tokens)

    def __getitem__(self, idx):
        text = torch.LongTensor(self.text_tokens[idx])
        
        # Get semantic tensor directly, no need to unpack
        semantic_tensor = self.semantic_tokens[idx]
        
        # Ensure semantic tensor is 1D
        if semantic_tensor.dim() == 2:
            semantic_tensor = semantic_tensor.squeeze(0)
        
        # Truncate if too long
        if len(semantic_tensor) > self.max_semantic_len:
            semantic_tensor = semantic_tensor[:self.max_semantic_len]
        
        # Truncate text if too long
        if len(text) > self.max_prompt_len:
            text = text[:self.max_prompt_len]
        
        return {
            'text': text,
            'semantic': semantic_tensor
        }

In [48]:
class TextSemanticDataset(Dataset):
    def __init__(self, df, semantic_tokens_list, max_semantic_len=1000, min_prompt_len=5, max_prompt_len=50):
        """
        df: pandas DataFrame with 'tokenized_prompts' column (list of tokens)
        semantic_tokens_list: list of semantic tokens tensors (squeezed 1D LongTensor)
        """
        assert len(df) == len(semantic_tokens_list), "Mismatch between text and semantic tokens length"
        
        self.text_tokens = df['tokenized_prompts'].tolist()
        self.semantic_tokens = semantic_tokens_list  # List of tensors
        self.max_semantic_len = max_semantic_len
        self.min_prompt_len = min_prompt_len
        self.max_prompt_len = max_prompt_len

    def random_prefix(self, semantic_tokens):
        if semantic_tokens.shape[0] > self.max_semantic_len:
            semantic_tokens = semantic_tokens[:self.max_semantic_len]

        prompt_len = min(
            max(self.min_prompt_len, semantic_tokens.shape[0] // 4),
            min(self.max_prompt_len, semantic_tokens.shape[0] - 1)
        )

        semantic_prompt = semantic_tokens[:prompt_len]
        semantic_target = semantic_tokens

        max_target_len = 1024 - prompt_len
        semantic_prompt = semantic_tokens[:prompt_len]
        semantic_target = semantic_tokens[:max_target_len]

        return semantic_prompt, semantic_target

    def __len__(self):
        return len(self.text_tokens)

    def __getitem__(self, idx):
        # Get text tokens
        text = torch.LongTensor(self.text_tokens[idx])
        
        # Get semantic tokens (already a tensor)
        semantic_tensor = self.semantic_tokens[idx]
        
        # Ensure semantic tensor is 1D
        if semantic_tensor.dim() == 2:
            semantic_tensor = semantic_tensor.squeeze(0)
        
        # Get prompt and target
        semantic_prompt, semantic_target = self.random_prefix(semantic_tensor)

        return {
            'text_tokens': text,
            'semantic_prompt': semantic_prompt,
            'semantic_target': semantic_target,
        }

In [49]:
def collate_fn(batch):
    # batch is a list of dicts with 'text_tokens', 'semantic_prompt', 'semantic_target'
    max_text_len = max(item['text_tokens'].size(0) for item in batch)
    max_prompt_len = max(item['semantic_prompt'].size(0) for item in batch)
    max_target_len = max(item['semantic_target'].size(0) for item in batch)

    max_text_len = min(max_text_len, 2048)
    max_prompt_len = min(max_prompt_len, 2048)
    max_target_len = min(max_target_len, 2048)

    # pad each tensor to max length in batch (or max 2048)
    text_tokens = torch.stack([torch.nn.functional.pad(item['text_tokens'], (0, max_text_len - item['text_tokens'].size(0))) for item in batch])
    semantic_prompt = torch.stack([torch.nn.functional.pad(item['semantic_prompt'], (0, max_prompt_len - item['semantic_prompt'].size(0))) for item in batch])
    semantic_target = torch.stack([torch.nn.functional.pad(item['semantic_target'], (0, max_target_len - item['semantic_target'].size(0))) for item in batch])

    return {
        'text_tokens': text_tokens,
        'semantic_prompt': semantic_prompt,
        'semantic_target': semantic_target,
    }


In [50]:
filtered_df=df[df['audio_file'].isin(audio_files)]

In [51]:
# 1. Create a quick lookup map from audio file names to their semantic tokens
audio_to_semantic_map = {name: token for name, token in zip(audio_files, semantic_tokens)}

# 2. Use the audio file names from your *filtered* DataFrame to get the relevant tokens
filtered_semantic_tokens = [audio_to_semantic_map[name] for name in filtered_df['audio_file'].tolist()]

In [52]:
df.shape
df.columns
print(len(filtered_df))
print(len(filtered_semantic_tokens))

7666
7666


In [53]:
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import torch.nn as nn
import torch
# Now create the dataset
dataset = TextSemanticDataset(filtered_df,filtered_semantic_tokens)


dataset

<__main__.TextSemanticDataset at 0x7fa1e0798050>

In [54]:
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import torch.nn as nn
import torch

# Assuming you have your dataset ready
# Split dataset into train and test (80-20 split)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create separate dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")


Train dataset size: 6132
Test dataset size: 1534


## Train

In [55]:
config=create_model_config()
model = MaskGCT_T2S(config).to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
device='cuda'

In [56]:
def training_step(model, batch, optimizer, criterion):
    """Single training step"""
    text_tokens = batch['text_tokens']  # (batch, text_len)
    semantic_prompt = batch['semantic_prompt']  # (batch, prompt_len)
    semantic_target = batch['semantic_target']  # (batch, target_len)
    
    # Random time steps
    batch_size = text_tokens.shape[0]
    time_steps = torch.randint(0, model.config.max_time_steps, (batch_size,), device=text_tokens.device)
    
    # Forward pass
    logits = model(text_tokens, semantic_prompt, semantic_target, time_steps)
    
    # Compute loss (cross-entropy with masked positions)
    loss = criterion(logits.reshape(-1, logits.shape[-1]), semantic_target.reshape(-1))

    
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item()


In [None]:
from tqdm import tqdm
import torch.nn as nn

# Setup
model = MaskGCT_T2S(config)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda")
model.to(device)

# Training loop
num_epochs = 5  # Define the number of epochs
model.train()

for epoch in range(num_epochs):
    total_loss = 0.0
    
    with tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
        for batch_idx, batch in enumerate(pbar):
            # Move batch to device
            for key in batch:
                batch[key] = batch[key].to(device)
            
            # Training step
            loss = training_step(model, batch, optimizer, criterion)
            total_loss += loss  # Accumulate the loss
            
            # Optionally: print every N batches
            if (batch_idx + 1) % 10 == 0:
                avg_loss = total_loss / (batch_idx + 1)
                pbar.set_postfix({"Avg Loss": avg_loss}) # Update tqdm progress bar
                
    # Print epoch summary
    epoch_avg_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs} - Avg Loss: {epoch_avg_loss:.4f}")


Epoch 1/5: 100%|██████████| 1533/1533 [08:23<00:00,  3.04it/s, Avg Loss=0.0469]


Epoch 1/5 - Avg Loss: 0.0469


Epoch 2/5: 100%|██████████| 1533/1533 [08:23<00:00,  3.04it/s, Avg Loss=0.01]  


Epoch 2/5 - Avg Loss: 0.0100


Epoch 3/5: 100%|██████████| 1533/1533 [08:23<00:00,  3.04it/s, Avg Loss=0.00949]


Epoch 3/5 - Avg Loss: 0.0095


Epoch 4/5:   7%|▋         | 106/1533 [00:34<07:47,  3.05it/s, Avg Loss=0.0093]

In [None]:
torch.save(model.state_dict(), 'T2S.pth') #saved with state