<a href="https://colab.research.google.com/github/hussainturii/TTS/blob/main/f5_tts_mini(with_DiTs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Mel Spectrogram**

In [None]:
import torch
import torchaudio
import librosa
import numpy as np

# Mel-spectrogram configuration. Feature extraction, training, sampling and
# vocoding must all agree on these values, so change them in one place only.
SR = 22050                 # sample rate in Hz
N_MELS = 80                # mel bins per frame
N_FFT = 1024               # STFT size
HOP_LENGTH = 256           # samples between successive frames
WIN_LENGTH = 1024          # analysis window length in samples
FMIN = 0                   # lowest filterbank frequency (Hz)
FMAX = SR // 2             # highest filterbank frequency: the Nyquist frequency

# Batched torchaudio transform. power=1.0 yields a magnitude (amplitude) mel;
# torchaudio's default (power=2.0) would yield a power mel instead.
# NOTE(review): the pretrained HiFiGAN bundle used later in this notebook has
# its own mel front-end (hifigan_bundle.get_mel_transform()); confirm this
# configuration matches it before vocoding mels produced here.
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SR,
    n_fft=N_FFT,
    win_length=WIN_LENGTH,
    hop_length=HOP_LENGTH,
    f_min=FMIN,
    f_max=FMAX,
    n_mels=N_MELS,
    power=1.0,
)

def wav_to_mel(waveform, sr=SR):
    """Compute a linear-amplitude mel spectrogram using the module-level mel config.

    Args:
        waveform: mono audio as a numpy array or torch tensor, shape (samples,)
            or (1, samples). Tensors may live on any device and may require grad.
        sr: sample rate of ``waveform``. When it differs from ``SR`` the audio
            is resampled to ``SR`` first, because ``mel_transform`` is built for SR.

    Returns:
        numpy float32 array of shape (N_MELS, frames). Values are linear
        amplitude (``power=1.0``), NOT log-compressed; apply a log/dB transform
        yourself if a downstream consumer needs it.
    """
    if isinstance(waveform, np.ndarray):
        # ascontiguousarray: torch.from_numpy rejects negative-stride views.
        wav = torch.from_numpy(np.ascontiguousarray(waveform)).float()
    else:
        # mel_transform is never moved off the CPU, and .numpy() below refuses
        # tensors that require grad, so normalize the input here.
        wav = waveform.detach().float().cpu()
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)  # (1, samples)
    if sr != SR:
        wav = torchaudio.functional.resample(wav, sr, SR)
    mel = mel_transform(wav)  # (1, n_mels, T)
    return mel.squeeze(0).cpu().numpy()


In [None]:
# Minimal character-level tokenizer.
# Id 0 is reserved for padding, the known characters get ids 1..len(CHARS),
# and the id immediately after them stands for any unknown character.
CHARS = list("abcdefghijklmnopqrstuvwxyz!'?,. ")
VOCAB = {ch: idx for idx, ch in enumerate(CHARS, start=1)}
VOCAB["<unk>"] = len(VOCAB) + 1
PAD = 0


def text_to_ids(text):
    """Lower-case ``text`` and map every character to its vocabulary id.

    Characters outside ``CHARS`` map to ``VOCAB["<unk>"]``.
    """
    unk_id = VOCAB["<unk>"]
    return [VOCAB.get(ch, unk_id) for ch in text.lower()]


# Example
print(text_to_ids("Hello, world!"))


[8, 5, 12, 12, 15, 30, 32, 23, 15, 18, 12, 4, 27]


**Transformer**

In [None]:
import torch.nn as nn
import math

class TextEncoder(nn.Module):
    """Transformer encoder mapping character ids to contextual embeddings.

    Args:
        vocab_size: number of token ids, including the PAD id.
        model_dim: embedding / hidden size D.
        nhead: attention heads per encoder layer.
        num_layers: number of encoder layers.
        max_len: longest token sequence supported by the learned position table.
        dropout: dropout used inside each encoder layer.
    """

    def __init__(self, vocab_size, model_dim=256, nhead=4, num_layers=3, max_len=512, dropout=0.1):
        super().__init__()
        self.model_dim = model_dim
        self.embed = nn.Embedding(vocab_size, model_dim, padding_idx=PAD)
        # Learned absolute positional embeddings, one row per token position.
        self.pos = nn.Parameter(torch.zeros(max_len, model_dim))
        nn.init.normal_(self.pos, std=0.02)
        encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, dim_feedforward=model_dim*4, dropout=dropout, activation="gelu")
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, token_ids):
        """Encode a batch of token ids.

        Args:
            token_ids: long tensor (B, L); entries equal to PAD are padding.

        Returns:
            float tensor (B, L, D). Outputs at PAD positions are still produced,
            but real tokens do not attend to PAD positions.

        Raises:
            ValueError: if L exceeds ``max_len``.
        """
        _, L = token_ids.shape
        max_len = self.pos.shape[0]
        if L > max_len:
            raise ValueError(f"token sequence length {L} exceeds max_len={max_len}")
        x = self.embed(token_ids) + self.pos[:L].unsqueeze(0)  # (B, L, D)
        x = x.permute(1, 0, 2)  # (L, B, D): the encoder layers are batch_first=False
        # True = ignore this key position (padding produced by collate_fn).
        pad_mask = token_ids.eq(PAD)
        # A row that is entirely PAD would mask every key and softmax would
        # return NaN, so such rows are left unmasked.
        pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)
        if not pad_mask.any():
            pad_mask = None  # nothing to mask: identical to the unmasked call
        out = self.encoder(x, src_key_padding_mask=pad_mask)  # (L, B, D)
        return out.permute(1, 0, 2)  # (B, L, D)


In [None]:
import torch.nn.functional as F
import torch

def get_timestep_embedding(timesteps, dim):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: tensor of shape (B,) (int or float); a 0-dim scalar is
            treated as a batch of one.
        dim: embedding size.

    Returns:
        float tensor (B, dim). The first ``dim // 2`` columns are
        sin(t * f_k) and the next ``dim // 2`` are cos(t * f_k), with
        f_k = exp(-ln(10000) * k / max(dim // 2 - 1, 1)). When ``dim`` is odd,
        one zero column is appended.
    """
    timesteps = timesteps.float()
    if timesteps.dim() == 0:
        timesteps = timesteps.unsqueeze(0)
    half = dim // 2
    # The exponent is normalized by (half - 1); for dim 2 or 3 that is 0 and
    # would give 0/0 = NaN frequencies, so clamp the denominator to 1.
    denom = max(half - 1, 1)
    freqs = torch.exp(-math.log(10000) * torch.arange(half, device=timesteps.device) / denom)
    args = timesteps[:, None] * freqs[None, :]  # (B, half)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb  # (B, dim)


In [None]:
import torch.nn as nn
import torch

class DiffusionTransformer(nn.Module):
    """Noise-prediction network for mel-spectrogram diffusion.

    Mel frames form the decoder's target sequence, and the text encoder output
    is its cross-attention memory. The network predicts the noise ``eps`` that
    was mixed into the clean mel.

    Args:
        n_mels: mel bins per frame (input and output feature size).
        model_dim: hidden size D; must equal the last dim of the text memory.
        nhead: attention heads per decoder layer.
        num_layers: number of decoder layers.
        max_frames: longest frame sequence supported by the learned positions.
    """

    def __init__(self, n_mels=80, model_dim=256, nhead=4, num_layers=6, max_frames=1024):
        super().__init__()
        self.n_mels = n_mels
        self.model_dim = model_dim
        # frame projection: n_mels -> D
        self.mel_proj = nn.Linear(n_mels, model_dim)
        # learned positional embeddings for mel frames
        self.pos = nn.Parameter(torch.zeros(max_frames, model_dim))
        nn.init.normal_(self.pos, std=0.02)
        # timestep MLP applied on top of the sinusoidal embedding
        self.t_mlp = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.GELU(),
            nn.Linear(model_dim, model_dim),
        )
        # Decoder layers (mel frames = tgt, memory = text encoder output)
        dec_layer = nn.TransformerDecoderLayer(d_model=model_dim, nhead=nhead, dim_feedforward=model_dim*4, activation="gelu")
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)
        # output projection back to mel dims (predicted eps)
        self.out = nn.Linear(model_dim, n_mels)

    def forward(self, x_t, t, text_memory, text_padding_mask=None, mel_padding_mask=None):
        """Predict the noise contained in ``x_t``.

        Args:
            x_t: noisy mels (B, n_mels, T).
            t: (B,) long tensor of timesteps.
            text_memory: (B, L, D) text encoder output.
            text_padding_mask: optional bool (B, L); True marks padded text
                positions that cross-attention should ignore. A row that is all
                True makes attention produce NaN, so don't pass one.
            mel_padding_mask: optional bool (B, T); True marks padded frames that
                self-attention should ignore.

        Returns:
            pred_eps: (B, n_mels, T).

        Raises:
            ValueError: if T exceeds ``max_frames``.
        """
        _, _, T = x_t.shape
        max_frames = self.pos.shape[0]
        if T > max_frames:
            raise ValueError(f"mel length {T} exceeds max_frames={max_frames}")
        # (B, n_mels, T) -> (T, B, n_mels) -> (T, B, D); decoder is batch_first=False
        x = self.mel_proj(x_t.permute(2, 0, 1))
        x = x + self.pos[:T].unsqueeze(1)  # (T, 1, D) broadcast over the batch
        # Timestep conditioning: one (B, D) vector added to every frame.
        t_emb = self.t_mlp(get_timestep_embedding(t, self.model_dim).to(x.device))
        x = x + t_emb.unsqueeze(0)
        memory = text_memory.permute(1, 0, 2)  # (L, B, D)
        # Cross-attention over the text memory; masks default to None (no masking).
        dec = self.decoder(
            tgt=x,
            memory=memory,
            tgt_key_padding_mask=mel_padding_mask,
            memory_key_padding_mask=text_padding_mask,
        )  # (T, B, D)
        return self.out(dec).permute(1, 2, 0)  # (B, n_mels, T)


In [None]:
class DiffusionSchedule:
    """Linear-beta DDPM noise schedule with precomputed cumulative products.

    Args:
        T: number of diffusion steps.
        beta_start: beta at step 0.
        beta_end: beta at step T - 1.
        device: device the schedule tensors are created on.
    """

    def __init__(self, T=200, beta_start=1e-4, beta_end=2e-2, device='cpu'):
        self.device = device
        self.T = T
        self.betas = torch.linspace(beta_start, beta_end, T, device=device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def q_sample(self, x0, t, noise=None):
        """Sample x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

        Args:
            x0: clean data whose first dim is the batch, e.g. (B, n_mels, Tframes).
                Any number of trailing dims is supported.
            t: (B,) long tensor of timesteps in [0, T).
            noise: optional noise shaped like ``x0``; drawn from N(0, I) if None.

        Returns:
            (x_t, noise), both shaped like ``x0``.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Index on the schedule's device, then move the coefficients to x0's.
        t = t.to(self.sqrt_alphas_cumprod.device)
        # (B,) -> (B, 1, ..., 1) so coefficients broadcast over every non-batch dim.
        shape = (-1,) + (1,) * (x0.dim() - 1)
        a = self.sqrt_alphas_cumprod[t].view(shape).to(x0.device)
        b = self.sqrt_one_minus_alphas_cumprod[t].view(shape).to(x0.device)
        return a * x0 + b * noise, noise


In [None]:
from torch.utils.data import Dataset, DataLoader
import random
import os

# Example dataset using TTS audio files + text file (replace with real)
class SimpleTTSDataset(Dataset):
    """Pairs audio files with transcripts and yields ``(mel, text_ids)``.

    Args:
        wav_paths: sequence of audio file paths readable by ``torchaudio.load``.
        texts: transcripts aligned with ``wav_paths`` by index.
        sr: target sample rate. Audio is resampled to it when the file's rate
            differs. Keep it equal to ``SR``: ``wav_to_mel`` uses a transform
            configured for SR.
        n_mels: currently unused; kept for API compatibility.

    Raises:
        ValueError: if ``wav_paths`` and ``texts`` differ in length.
    """

    def __init__(self, wav_paths, texts, sr=SR, n_mels=N_MELS):
        if len(wav_paths) != len(texts):
            raise ValueError(
                f"wav_paths ({len(wav_paths)}) and texts ({len(texts)}) must have the same length"
            )
        self.wav_paths = wav_paths
        self.texts = texts
        self.sr = sr

    def __len__(self):
        return len(self.wav_paths)

    def __getitem__(self, idx):
        """Return (mel float32 (n_mels, T), text_ids long (L,)) for item ``idx``."""
        wav, file_sr = torchaudio.load(self.wav_paths[idx])  # (channels, samples)
        wav = wav.mean(0)  # downmix to mono -> (samples,); a no-op for 1 channel
        if file_sr != self.sr:
            wav = torchaudio.functional.resample(wav, file_sr, self.sr)
        mel = torch.from_numpy(wav_to_mel(wav.numpy())).float()  # (n_mels, T)
        text_ids = torch.tensor(text_to_ids(self.texts[idx]), dtype=torch.long)
        return mel, text_ids

def collate_fn(batch):
    """Zero-pad a list of ``(mel, text_ids)`` pairs into batch tensors.

    Args:
        batch: non-empty list of (mel (n_mels, T_i) tensor, text_ids (L_i,) long tensor).
            All mels must share the same n_mels.

    Returns:
        mels_p: (B, n_mels, max T_i) with the first mel's dtype and device,
            zero-padded along time.
        texts_p: (B, max L_i) long tensor, padded with 0 (the PAD id).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    mels = [mel for mel, _ in batch]
    texts = [ids for _, ids in batch]
    n_mels = mels[0].shape[0]  # taken from the data rather than a global constant
    max_T = max(m.shape[1] for m in mels)
    max_L = max(t.shape[0] for t in texts)
    mels_p = mels[0].new_zeros(len(batch), n_mels, max_T)
    texts_p = torch.zeros(len(batch), max_L, dtype=torch.long)
    for i, (m, t) in enumerate(zip(mels, texts)):
        mels_p[i, :, :m.shape[1]] = m
        texts_p[i, :t.shape[0]] = t
    return mels_p, texts_p

# Instantiate dataset (you must supply wav_paths + texts)
# dataset = SimpleTTSDataset(wav_paths, texts)
# loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Training loop skeleton (single step demonstration)
def train_step(model, text_encoder, schedule, optimizer, batch):
    """Run one DDPM noise-prediction training step and return the scalar loss.

    Args:
        model: noise predictor called as ``model(x_t, t, text_memory)``.
        text_encoder: maps (B, L) token ids to (B, L, D) memory.
        schedule: object providing ``T`` and ``q_sample(x0, t)``.
        optimizer: optimizer over the trainable parameters (model and encoder).
        batch: ``(mels (B, n_mels, T), texts (B, L))``; may live on any device.
            The tensors are moved to the model's device, taken from its first
            parameter (assumes all parameters share one device).

    Returns:
        float: MSE between predicted and true noise.

    Note:
        The loss also covers zero-padded frames from ``collate_fn``, and the
        model is not given padding masks here.
    """
    model.train()
    text_encoder.train()
    device = next(model.parameters()).device
    mels, texts = batch
    # DataLoader batches arrive on CPU; move them next to the model.
    mels = mels.to(device)
    texts = texts.to(device)
    B = mels.size(0)
    t = torch.randint(0, schedule.T, (B,), device=device)  # one timestep per sample
    x_t, eps = schedule.q_sample(mels, t)
    text_memory = text_encoder(texts)
    pred = model(x_t, t, text_memory)
    loss = torch.nn.functional.mse_loss(pred, eps.to(pred.device))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

In [None]:
@torch.no_grad()
def sample_ddpm(model, text_encoder, schedule, text_ids, mel_frames=120, device='cuda'):
    """Generate mel spectrograms by ancestral DDPM sampling.

    Both ``model`` and ``text_encoder`` are switched to eval mode, and they are
    left in eval mode afterwards.

    Args:
        model: noise predictor called as ``model(x_t, t, text_memory)``; must
            expose ``n_mels``.
        text_encoder: maps (B, L) token ids to (B, L, D) memory.
        schedule: object with ``T``, ``betas``, ``alphas`` and ``alphas_cumprod``.
        text_ids: LongTensor (L,) for one utterance, or (B, L) for a batch.
        mel_frames: number of frames to generate.
        device: device to sample on; the model and encoder must already be there.

    Returns:
        Tensor (B, n_mels, mel_frames), clamped to >= 0 (suits linear-amplitude mels).
    """
    model.eval()
    text_encoder.eval()
    if text_ids.dim() == 1:
        text_ids = text_ids.unsqueeze(0)
    text_ids = text_ids.to(device)
    text_memory = text_encoder(text_ids)  # (B, L, D)
    B = text_ids.size(0)
    x = torch.randn(B, model.n_mels, mel_frames, device=device)
    # Plain Python floats work on any device and avoid a host sync per step.
    betas = schedule.betas.tolist()
    alphas = schedule.alphas.tolist()
    alphas_cumprod = schedule.alphas_cumprod.tolist()
    for t_int in reversed(range(schedule.T)):
        t = torch.full((B,), t_int, device=device, dtype=torch.long)
        eps_pred = model(x, t, text_memory)
        beta_t = betas[t_int]
        # Posterior mean: (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
        coef1 = 1.0 / math.sqrt(alphas[t_int])
        coef2 = beta_t / math.sqrt(1.0 - alphas_cumprod[t_int])
        mean = coef1 * (x - coef2 * eps_pred)
        if t_int > 0:
            # Variance choice sigma_t^2 = beta_t.
            x = mean + math.sqrt(beta_t) * torch.randn_like(x)
        else:
            x = mean
    return x.clamp(min=0.0)  # clamp depending on your mel scaling


In [None]:
# Best-effort load of torchaudio's pretrained HiFiGAN vocoder (prototype API;
# may be missing or deprecated in some torchaudio versions). On any failure
# `vocoder` is set to None so the rest of the notebook still runs.
# `device` is defined here because no earlier cell defines it; without this,
# an in-order run raised NameError, which was reported as "not available".
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as hifigan_bundle
    vocoder = hifigan_bundle.get_vocoder().to(device)
except Exception as e:  # import errors, weight-download failures, etc.
    print("torchaudio hifigan not available:", e)
    vocoder = None
else:
    hifi_sr = hifigan_bundle.sample_rate
    print("Loaded hifigan bundle, sr:", hifi_sr)
    if hifi_sr != SR:
        print(f"warning: vocoder sample rate {hifi_sr} != mel config SR {SR}")


Loaded hifigan bundle, sr: 22050


  VGGISH = VGGishBundle(_get_state_dict)
  HIFIGAN_VOCODER_V3_LJSPEECH = HiFiGANVocoderBundle(
  vocoder = hifigan_bundle.get_vocoder().to(device)
  model = hifigan_vocoder(**self._vocoder_params)
  return HiFiGANVocoder(


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_size = len(VOCAB) + 1  # ids 1..len(VOCAB) plus PAD = 0
text_encoder = TextEncoder(vocab_size=vocab_size, model_dim=256).to(device)
# Build the model once; a second default DiffusionTransformer() used to
# silently replace this configured one.
model = DiffusionTransformer(n_mels=N_MELS, model_dim=256, num_layers=4).to(device)
schedule = DiffusionSchedule(T=100, device=device)  # small T for quick loops

# One optimizer over both networks, which train_step updates jointly.
optimizer = torch.optim.Adam(list(model.parameters()) + list(text_encoder.parameters()), lr=1e-4)

# fake batch
mel_dummy = torch.randn(2, N_MELS, 120, device=device)
text_dummy = torch.randint(1, vocab_size, (2, 20), device=device)

# one training step (demo)
loss = train_step(model, text_encoder, schedule, optimizer, (mel_dummy, text_dummy))
print("demo loss:", loss)

# sample from random text
ids = torch.tensor(text_to_ids("hello world"), dtype=torch.long)
mel_gen = sample_ddpm(model, text_encoder, schedule, ids, mel_frames=120, device=device)
print("generated mel shape:", mel_gen.shape)
# To listen, pass mel_gen through the HiFiGAN vocoder (see the vocoder cells);
# the vocoder's expected mel format may differ from this linear-amplitude mel.




demo loss: 1.3710402250289917
generated mel shape: torch.Size([1, 80, 120])


In [None]:
import torch
import torchaudio
from torchaudio.prototype.pipelines import HIFIGAN_VOCODER_V3_LJSPEECH as hifigan_bundle
from IPython.display import Audio

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 1. Generate a mel with the full reverse-diffusion chain ---
# DiffusionTransformer.forward returns *predicted noise* (train_step fits it to
# eps), so its raw output is not a mel and must not be vocoded directly.
# sample_ddpm also uses only timesteps in [0, schedule.T).
ids = torch.tensor(text_to_ids("hello world"), dtype=torch.long)
mel = sample_ddpm(model, text_encoder, schedule, ids, mel_frames=120, device=device)  # (1, n_mels, T)

# --- 2. Vocoder ---
# wav_to_mel produces linear-amplitude mels, so log-compress them before vocoding.
# NOTE(review): the bundle's vocoder is meant for mels produced by
# hifigan_bundle.get_mel_transform(), which has its own filterbank. This
# notebook's mel config may not match it, so expect poor audio until the
# feature pipelines are aligned.
mel_log = torch.log(mel.clamp(min=1e-5))
vocoder = hifigan_bundle.get_vocoder().to(device)
with torch.no_grad():
    waveform = vocoder(mel_log)  # (B, 1, T_audio)

# --- 3. Play directly ---
Audio(waveform.squeeze().cpu().numpy(), rate=hifigan_bundle.sample_rate)

  vocoder = hifigan_bundle.get_vocoder().to(device)
  model = hifigan_vocoder(**self._vocoder_params)
  return HiFiGANVocoder(
