In [1]:
import os
import time
import math
import torch
import torchaudio
import torch.nn as nn
import torchaudio.transforms as T
from torch.utils.data import DataLoader
from datasets import load_dataset

############################################################
# Character set and indexing
############################################################
# Vocabulary used by the text front-end. Every symbol is a single character
# except the '<unk>' fallback, which is placed at index 0.
# NOTE(review): tone-marked vowels (e.g. 'á', 'ệ') and the upper-case forms of
# 'ê', 'ô', 'ơ', 'ư' are not in this string, so __getitem__ maps them to
# '<unk>' -- confirm this is intended for Vietnamese transcripts.
# NOTE(review): collate_fn pads text with 0, i.e. padding and '<unk>' share
# the same index.
character_set = [
    *" aăâbcdđeêfghijklmnoôơpqrstuưvwxyzAĂÂBCDĐEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!?"
]
if '<unk>' not in character_set:
    character_set = ['<unk>'] + character_set

# Reverse lookup: symbol -> integer id (its position in character_set).
char2idx = dict(zip(character_set, range(len(character_set))))

############################################################
# Helper to compute mean and std of mel-spectrograms
############################################################
def compute_mel_stats(dataset_name, split, cache_dir="./dataset_cache", n_mels=80, sample_rate=16000):
    """
    Compute the global mean and std of mel-spectrogram values over a dataset split.

    Every item is converted with the same ``T.MelSpectrogram`` settings and the
    statistics are taken over all (mel bin, frame) values of all items.
    This is a full pass over the data and can be slow for large datasets;
    consider caching the result.

    Args:
        dataset_name: Hugging Face dataset identifier passed to ``load_dataset``.
        split: Split name, e.g. "train".
        cache_dir: Directory used by ``load_dataset`` for its cache.
        n_mels: Number of mel bands.
        sample_rate: Sample rate handed to the mel transform. The waveform is
            NOT resampled here -- assumes the dataset audio already has this
            rate (TODO confirm against the dataset card).

    Returns:
        (mean, std) as Python floats.

    Raises:
        ValueError: If the split yields no spectrogram values at all.
    """
    dataset = load_dataset(dataset_name, split=split, cache_dir=cache_dir)
    mel_transform = T.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels)
    sums = 0.0
    squared_sums = 0.0
    count = 0

    print("Computing mel stats (mean/std) over the dataset...")
    # Pure statistics pass: no autograd bookkeeping needed.
    with torch.no_grad():
        for item in dataset:
            audio_tensor = torch.tensor(item["audio"]["array"], dtype=torch.float32)
            # [n_mels, time]; reduce in float64 so that long utterances do not
            # lose precision in the per-item sums.
            mel_spec = mel_transform(audio_tensor).double()
            sums += mel_spec.sum().item()
            squared_sums += (mel_spec ** 2).sum().item()
            count += mel_spec.numel()

    if count == 0:
        raise ValueError(
            f"Cannot compute mel stats: split '{split}' of '{dataset_name}' produced no spectrogram values."
        )

    mean = sums / count
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clamp so
    # math.sqrt never sees a negative number.
    var = max((squared_sums / count) - (mean ** 2), 0.0)
    std = math.sqrt(var)
    print(f"Computed mel stats: mean={mean:.4f}, std={std:.4f}")
    return mean, std

############################################################
# Dataset & DataLoader
############################################################
class TTSDataLoader:
    """
    Map-style dataset yielding (text_tensor, mel_spec) pairs for TTS training.

    The underlying Hugging Face dataset is loaded in the constructor (with
    retries on transient network failures). Mel-spectrograms are normalized
    with ``(mel - mel_mean) / (mel_std + 1e-5)`` when both statistics are given.
    """

    def __init__(self, dataset_name="doof-ferb/vlsp2020_vinai_100h", split="train",
                 cache_dir="./dataset_cache", max_retries=5, retry_delay=5,
                 mel_mean=None, mel_std=None, n_mels=80, sample_rate=16000):
        """
        Args:
            dataset_name: Hugging Face dataset identifier.
            split: Split name to load.
            cache_dir: Cache directory handed to ``load_dataset``.
            max_retries: Maximum number of load attempts on transient errors.
            retry_delay: Seconds to sleep between attempts.
            mel_mean: Optional mean used for normalization.
            mel_std: Optional std used for normalization.
            n_mels: Number of mel bands.
            sample_rate: Sample rate handed to the mel transform (audio is not
                resampled here).
        """
        self.dataset_name = dataset_name
        self.split = split
        self.cache_dir = cache_dir
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.dataset = self.load_dataset_with_retry()

        # Mel transform
        self.mel_transform = T.MelSpectrogram(sample_rate=self.sample_rate, n_mels=self.n_mels)

        # Normalization params
        self.mel_mean = mel_mean
        self.mel_std = mel_std

    @staticmethod
    def _is_transient_error(error):
        """Return True for failures worth retrying (connection problems, HTTP 503)."""
        message = str(error)
        return (
            # The exception *type* is the reliable signal; the message of a
            # ConnectionError usually does not contain the word itself.
            isinstance(error, ConnectionError)
            or "ConnectionError" in type(error).__name__
            or "503" in message
            or "ConnectionError" in message
        )

    def load_dataset_with_retry(self):
        """
        Load the dataset, retrying up to ``max_retries`` times on transient errors.

        Returns:
            The dataset object returned by ``load_dataset``.

        Raises:
            RuntimeError: If every attempt failed with a transient error (the
                last error is attached as ``__cause__``).
            Exception: Any non-transient error is re-raised unchanged.
        """
        last_error = None
        for attempt in range(1, self.max_retries + 1):
            try:
                return load_dataset(self.dataset_name, split=self.split, cache_dir=self.cache_dir)
            except Exception as e:
                if not self._is_transient_error(e):
                    raise  # bare raise keeps the original traceback
                last_error = e
                print(f"Server error ({type(e).__name__}), retrying {attempt}/{self.max_retries} in {self.retry_delay}s...")
                time.sleep(self.retry_delay)
        raise RuntimeError(f"Failed to load dataset after {self.max_retries} retries.") from last_error

    def __getitem__(self, index):
        """
        Return text_tensor and mel_spec for the given index.
        text_tensor: LongTensor [text_length]
        mel_spec: FloatTensor [n_mels, time_steps]
        """
        item = self.dataset[index]
        text = item["transcription"]
        audio_array = item["audio"]["array"]

        # Map characters to indices; anything outside the vocabulary becomes '<unk>'.
        unk_index = char2idx['<unk>']
        text_indices = [char2idx.get(c, unk_index) for c in text]
        text_tensor = torch.tensor(text_indices, dtype=torch.long)

        # Convert audio to mel spectrogram
        audio_tensor = torch.tensor(audio_array, dtype=torch.float32)
        mel_spec = self.mel_transform(audio_tensor)
        if self.mel_mean is not None and self.mel_std is not None:
            # The epsilon guards against a zero std.
            mel_spec = (mel_spec - self.mel_mean) / (self.mel_std + 1e-5)

        return text_tensor, mel_spec

    def __len__(self):
        """Number of items in the loaded split."""
        return len(self.dataset)


def collate_fn(batch):
    """
    Collate (text, mel) pairs into zero-padded batch tensors.

    Args:
        batch: Sequence of (text, mel) pairs; text is a 1-D LongTensor and
            mel is a [n_mels, time] FloatTensor.

    Returns:
        padded_texts: LongTensor [batch, max_text_len], right-padded with 0.
        padded_mels: FloatTensor [batch, n_mels, max_mel_len], right-padded with 0.
        text_lengths: LongTensor [batch] with the unpadded text lengths.
        mel_lengths: LongTensor [batch] with the unpadded frame counts.
    """
    texts, mels = zip(*batch)
    batch_size = len(texts)

    text_lengths = torch.tensor([len(t) for t in texts], dtype=torch.long)
    mel_lengths = torch.tensor([m.size(1) for m in mels], dtype=torch.long)

    # Allocate the padded containers up front, sized by the longest item.
    padded_texts = torch.zeros(batch_size, text_lengths.max().item(), dtype=torch.long)
    padded_mels = torch.zeros(
        batch_size, mels[0].size(0), mel_lengths.max().item(), dtype=torch.float32
    )

    # Copy every sample into the leading part of its row; the tail stays zero.
    for row, (text, mel) in enumerate(zip(texts, mels)):
        padded_texts[row, :len(text)] = text
        padded_mels[row, :, :mel.size(1)] = mel

    return padded_texts, padded_mels, text_lengths, mel_lengths

############################################################
# Model Components
############################################################
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding followed by dropout.

    The table is stored in the non-trainable buffer ``pe`` with shape
    [max_len, 1, d_model]; even feature indices hold sines, odd indices hold
    cosines.

    Args:
        d_model: Feature dimension (odd values are supported).
        max_len: Longest sequence length the table covers.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column, so
        # trim the frequencies to fit (a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Add positional encodings to ``x`` and apply dropout.

        Args:
            x: Tensor of shape [seq_len, batch, d_model].

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            RuntimeError: If ``seq_len`` exceeds the table length ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            # Fail with an actionable message instead of an opaque broadcast error.
            raise RuntimeError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {self.pe.size(0)}."
            )
        x = x + self.pe[:seq_len]
        return self.dropout(x)


class TransformerEncoder(nn.Module):
    """
    Stack of ``num_layers`` ``nn.TransformerEncoderLayer`` modules.

    Tensors are sequence-first ([seq_len, batch, d_model]) because the layers
    are built with ``batch_first=False``.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False,
        )
        # The attribute name is part of the checkpoint key layout; keep it.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode ``src``; both masks are forwarded unchanged to the stack."""
        encoded = self.transformer_encoder(
            src,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        return encoded


class TransformerDecoder(nn.Module):
    """
    Stack of ``num_layers`` ``nn.TransformerDecoderLayer`` modules.

    Tensors are sequence-first ([seq_len, batch, d_model]) because the layers
    are built with ``batch_first=False``.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False,
        )
        # The attribute name is part of the checkpoint key layout; keep it.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Decode ``tgt`` against ``memory``; all masks are forwarded unchanged."""
        mask_kwargs = dict(
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.transformer_decoder(tgt, memory, **mask_kwargs)


class TransformerTTS(nn.Module):
    """
    Encoder-decoder Transformer mapping character ids to mel-spectrogram frames.

    Text ids go through an embedding, mel frames through a linear projection;
    both are scaled by sqrt(d_model) and receive sinusoidal positional
    encodings before entering the encoder / decoder stacks. A final linear
    layer projects decoder states back to ``output_dim`` mel bins.
    """

    def __init__(self, input_dim, output_dim, d_model=256, nhead=4, num_layers=4,
                 dim_feedforward=1024, dropout=0.1):
        """
        Args:
            input_dim: Vocabulary size of the text embedding.
            output_dim: Number of mel bins per frame.
            d_model: Model width shared by encoder and decoder.
            nhead: Attention heads per layer.
            num_layers: Layers in each of the encoder and decoder stacks.
            dim_feedforward: Hidden width of the feed-forward sublayers.
            dropout: Dropout probability used throughout.
        """
        super().__init__()
        self.d_model = d_model

        # Text side: token embedding + positions.
        self.text_embedding = nn.Embedding(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)

        # Mel side: per-frame projection + positions.
        self.mel_embed = nn.Linear(output_dim, d_model)
        self.pos_decoder = PositionalEncoding(d_model, dropout=dropout)

        # Encoder and decoder stacks share the same hyperparameters.
        stack_args = (d_model, nhead, num_layers, dim_feedforward, dropout)
        self.encoder = TransformerEncoder(*stack_args)
        self.decoder = TransformerDecoder(*stack_args)

        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        """
        Args:
            src: Text ids, [batch, src_len].
            tgt: Decoder input mel frames, [batch, output_dim, tgt_len].
            src_mask: Optional attention mask for the encoder.
            tgt_mask: Optional attention mask for decoder self-attention; no
                mask is created here, so causal masking is up to the caller.
            src_key_padding_mask: Optional padding mask for ``src``; it is also
                used as the decoder's memory padding mask.
            tgt_key_padding_mask: Optional padding mask for ``tgt``.

        Returns:
            Predicted mel frames, [batch, output_dim, tgt_len].
        """
        scale = math.sqrt(self.d_model)

        # Text: [batch, src_len] -> [src_len, batch, d_model]
        text_seq = self.text_embedding(src) * scale
        text_seq = self.pos_encoder(text_seq.permute(1, 0, 2))

        # Mel: [batch, output_dim, tgt_len] -> [tgt_len, batch, d_model]
        mel_seq = self.mel_embed(tgt.permute(0, 2, 1)) * scale
        mel_seq = self.pos_decoder(mel_seq.permute(1, 0, 2))

        memory = self.encoder(
            text_seq,
            src_mask=src_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        decoded = self.decoder(
            mel_seq,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )

        # [tgt_len, batch, output_dim] -> [batch, output_dim, tgt_len]
        return self.fc_out(decoded).permute(1, 2, 0)

############################################################
# Training Function
############################################################
def _load_or_compute_mel_stats(dataset_name, split, n_mels, sample_rate, stats_path):
    """Return (mean, std) from the JSON cache at stats_path, computing and saving them on a miss."""
    import json  # local import: only this helper needs it

    cache_key = {
        "dataset_name": dataset_name,
        "split": split,
        "n_mels": n_mels,
        "sample_rate": sample_rate,
    }
    try:
        with open(stats_path, "r", encoding="utf-8") as handle:
            cached = json.load(handle)
        # Only trust the cache when it was produced with the same settings.
        if all(cached.get(key) == value for key, value in cache_key.items()):
            print(f"Loaded cached mel stats from {stats_path}")
            return float(cached["mel_mean"]), float(cached["mel_std"])
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        # Missing, unreadable or malformed cache: fall through and recompute.
        pass

    mel_mean, mel_std = compute_mel_stats(dataset_name, split, n_mels=n_mels, sample_rate=sample_rate)
    os.makedirs(os.path.dirname(stats_path) or ".", exist_ok=True)
    with open(stats_path, "w", encoding="utf-8") as handle:
        json.dump({**cache_key, "mel_mean": mel_mean, "mel_std": mel_std}, handle)
    return mel_mean, mel_std


def train():
    """
    Train TransformerTTS on the VLSP2020 dataset with teacher forcing.

    Per batch:
      * the decoder input is the target mel shifted right by one frame (an
        all-zero frame is prepended), combined with a causal self-attention
        mask, so frame t is predicted from frames < t only -- feeding the
        unshifted target without a mask lets the model simply copy its input;
      * key-padding masks built from the batch lengths stop attention from
        looking at padded text / mel positions;
      * the L1 loss is averaged over real (unpadded) frames only.

    Side effects: writes ``checkpoints/mel_stats.json`` (normalization stats,
    also needed to de-normalize at inference) and one
    ``checkpoints/transformer_tts_epoch{N}.pth`` state dict per epoch.
    """
    # Hyperparameters
    input_dim = len(character_set)
    output_dim = 80  # Mel bands
    d_model = 256
    nhead = 4
    num_layers = 4           # Reduced from 6 to 4 for stability
    dim_feedforward = 1024
    dropout = 0.05
    batch_size = 8
    learning_rate = 5e-5     # Reduced LR from 1e-4 to 5e-5
    num_epochs = 10
    grad_clip = 0.5

    dataset_name = "doof-ferb/vlsp2020_vinai_100h"
    split = "train"
    sample_rate = 16000

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the output directory once, not on every epoch.
    os.makedirs("checkpoints", exist_ok=True)

    # The stats pass walks the entire dataset, so cache its result on disk.
    mel_mean, mel_std = _load_or_compute_mel_stats(
        dataset_name, split, n_mels=output_dim, sample_rate=sample_rate,
        stats_path="checkpoints/mel_stats.json",
    )

    # Initialize model, criterion, optimizer
    model = TransformerTTS(
        input_dim=input_dim,
        output_dim=output_dim,
        d_model=d_model,
        nhead=nhead,
        num_layers=num_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    ).to(device)

    # Use L1Loss for more stable training
    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Initialize dataset & dataloader
    dataset = TTSDataLoader(
        dataset_name=dataset_name,
        split=split,
        mel_mean=mel_mean,
        mel_std=mel_std,
        n_mels=output_dim,
        sample_rate=sample_rate
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=4)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        for i, (src, tgt, src_lengths, tgt_lengths) in enumerate(dataloader):
            src = src.to(device)                  # [batch, src_len]
            tgt = tgt.to(device)                  # [batch, n_mels, tgt_len]
            src_lengths = src_lengths.to(device)
            tgt_lengths = tgt_lengths.to(device)

            # Key-padding masks: True marks a padded position to be ignored.
            src_padding_mask = torch.arange(src.size(1), device=device)[None, :] >= src_lengths[:, None]
            tgt_padding_mask = torch.arange(tgt.size(2), device=device)[None, :] >= tgt_lengths[:, None]
            # A fully masked row would turn attention into NaN (empty
            # transcription); always keep the first position visible.
            src_padding_mask[:, :1] = False

            # Teacher forcing: prepend a zero "go" frame and drop the last
            # frame, so position t of the input holds target frame t-1.
            go_frame = torch.zeros_like(tgt[:, :, :1])
            decoder_input = torch.cat([go_frame, tgt[:, :, :-1]], dim=2)

            # Causal mask (True = blocked): position t may attend to <= t only.
            num_frames = tgt.size(2)
            causal_mask = torch.triu(
                torch.ones(num_frames, num_frames, dtype=torch.bool, device=device),
                diagonal=1,
            )

            optimizer.zero_grad()
            output = model(
                src,
                decoder_input,
                tgt_mask=causal_mask,
                src_key_padding_mask=src_padding_mask,
                tgt_key_padding_mask=tgt_padding_mask,
            )

            # Average the loss over real frames only; padded frames would
            # otherwise dilute it in batches with uneven lengths.
            valid_frames = (~tgt_padding_mask).unsqueeze(1).expand_as(tgt)
            loss = criterion(output[valid_frames], tgt[valid_frames])
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()
            epoch_loss += loss.item()

            if i % 10 == 0:
                print(f"Epoch {epoch}, Step {i}, Loss: {loss.item():.4f}")

        # max(..., 1) avoids a ZeroDivisionError on an empty dataloader.
        avg_epoch_loss = epoch_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch} completed. Average Loss: {avg_epoch_loss:.4f}")

        # Save checkpoint
        torch.save(model.state_dict(), f"checkpoints/transformer_tts_epoch{epoch}.pth")


# Entry point: start training when this code runs as the main module.
# NOTE(review): the recorded output below shows dataset loading starting right
# after this cell ran, so executing the cell launches a full training run.
if __name__ == "__main__":
    train()

Resolving data files:   0%|          | 0/35 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/35 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/35 [00:00<?, ?files/s]

Generating train split:   0%|          | 0/56427 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [3]:
import torch

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the checkpoint with its tensors mapped onto `device`.
# Presumably a HiFi-GAN generator checkpoint, judging by the file name.
ckpt = torch.load("checkpoints/hifigan_gen_universal.pth", map_location=device, weights_only=True)
# Show the top-level keys; the recorded output has the single key 'generator'.
print(ckpt.keys())

dict_keys(['generator'])


In [6]:
import torch

# Load the checkpoint onto the GPU when one is available, otherwise onto the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt = torch.load("checkpoints/checkpoints_23_12_2024_TTS_Transformer/transformer_tts_epoch24.pth", map_location=device, weights_only=True)
# Show the parameter names. In the recorded output every key starts with
# 'module.' (presumably saved from a DataParallel-style wrapper -- TODO confirm);
# that prefix has to be stripped before load_state_dict on a bare TransformerTTS.
print(ckpt.keys())

odict_keys(['module.text_embedding.weight', 'module.pos_encoder.pe', 'module.mel_embed.weight', 'module.mel_embed.bias', 'module.pos_decoder.pe', 'module.encoder.transformer_encoder.layers.0.self_attn.in_proj_weight', 'module.encoder.transformer_encoder.layers.0.self_attn.in_proj_bias', 'module.encoder.transformer_encoder.layers.0.self_attn.out_proj.weight', 'module.encoder.transformer_encoder.layers.0.self_attn.out_proj.bias', 'module.encoder.transformer_encoder.layers.0.linear1.weight', 'module.encoder.transformer_encoder.layers.0.linear1.bias', 'module.encoder.transformer_encoder.layers.0.linear2.weight', 'module.encoder.transformer_encoder.layers.0.linear2.bias', 'module.encoder.transformer_encoder.layers.0.norm1.weight', 'module.encoder.transformer_encoder.layers.0.norm1.bias', 'module.encoder.transformer_encoder.layers.0.norm2.weight', 'module.encoder.transformer_encoder.layers.0.norm2.bias', 'module.encoder.transformer_encoder.layers.1.self_attn.in_proj_weight', 'module.encoder.