In [5]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils import Tokenizer 
import utils
from config import chars
import random
from model2 import Transformer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import json
import time
import os

# Load data_dict from the JSON file
with open("shart_dict.json", "r") as f:
    data_dict = json.load(f)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load training parameters
batch_size, learning_rate, epochs, char_size, d_model, n_heads, dropout_rate, head_size, n_layers = utils.get_train_params() #take out head_size
n_fft, hop_length, sr, n_mels, max_timesteps = utils.get_audio_params()
training_mels, training_text = utils.get_training_data()

# Define custom dataset class
class TTSDataSet(Dataset):
    def __init__(self, keys, training_mels, training_text, max_timesteps):
        self.keys = keys
        self.mels_data_path = training_mels
        self.text_data_path = training_text
        self.max_timesteps = max_timesteps

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):  
        key = self.keys[idx]
        txt = torch.load(f"{self.text_data_path}/{key}.pt", weights_only=True) #text S
        tgt = torch.load(f"{self.mels_data_path}/{key}.pt", weights_only=True) #audio S,C
        txt = torch.nn.functional.pad(txt, (0, tgt.size(0) - txt.size(0)), value=4)  #this adds special filler tokens "FILL" to align txt text (S,) dimension with tgt audio dimension (S,)
        txt = torch.nn.functional.pad(txt, (0, self.max_timesteps - txt.size(0)), value=0) #adds padding to max timesteps 
        tgt = torch.nn.functional.pad(tgt, (0, 0, 0, self.max_timesteps - tgt.size(0)), value=-3) # s,c where padded s gets -3 and c relating to the rows get -3
        return txt, tgt

# Define collate function
def collate_fn(batch):
    txt_batch, tgt_batch = zip(*batch)  
    txt_padded = torch.stack(txt_batch) # txt -> (B, S), 
    tgt_padded = torch.stack(tgt_batch) # tgt -> (B, S, C)
    attn_mask = (tgt_padded != -3).any(dim=-1)  # Shape: (B, max_timesteps) - attn mask has true for non-padded positions
    valid_lengths = attn_mask.sum(dim=-1)  # Shape: (B,), number of valid timesteps per sequence
    # Apply 70-100% contiguous masking only to the valid portion
    random_mask = torch.zeros_like(attn_mask, dtype=torch.bool)  # Initialize with all False
    for i, valid_len in enumerate(valid_lengths):
        mask_percentage = random.uniform(0.7, 1.0)  # Random percentage between 70% and 100%
        timesteps_to_mask = int(mask_percentage * valid_len)  # Compute timesteps to mask
        temp_mask = torch.zeros(valid_len, dtype=torch.bool)             # Generate a contiguous mask for valid timesteps
        start_idx = valid_len - timesteps_to_mask
        temp_mask[start_idx:] = True  # Mask the first `timesteps_to_mask` elements            
        random_mask[i, :valid_len] = temp_mask  # Apply the contiguous mask to the valid region
    audio_mask = attn_mask & ~random_mask  # Keep valid positions, excluding masked ones (B,S), the random mask has true for the masked positions and gets inverted here to false
    return txt_padded, tgt_padded, audio_mask, attn_mask

#Prepare data and dataloader
data_keys = list(data_dict.keys())
random.shuffle(data_keys)
dataset = TTSDataSet(data_keys, training_mels, training_text, max_timesteps)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda batch: collate_fn(batch), pin_memory=True)


output_dir = "/home/kunit17/Her/Data/TrainingOutput"



In [7]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils import Tokenizer 
import utils
from config import chars
import random
from model2 import Transformer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import json
import time
import os
print(data_keys[:25])
char_size, d_model, n_heads, dropout_rate, n_layers, n_mels, device = 64,1024,16,.001,24,100,'cuda'
# Initialize model, tokenizer, and optimizer
model = Transformer(char_size, d_model, n_heads, n_layers, n_mels, dropout_rate, device) 
model = model.to(device)
# model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(sum(p.numel() for p in model.parameters()), 'M parameters')
torch.set_float32_matmul_precision('high')




['11855', '2122', '6822', '8804', '3490', '10670', '10246', '11730', '4797', '1974', '2793', '9845', '2604', '7170', '7912', '12512', '3169', '5104', '9941', '2373', '2386', '11876', '9224', '2653', '2598']
328177764 M parameters


In [26]:
# Get a single batch from the dataloader
for batch in dataloader:
    txt_padded, tgt_padded, audio_mask, attn_mask = batch

    # Print the first 5 entries of tgt_padded and tgt_mask
    print("First 5 entries of tgt_padded:")
    #print(tgt_padded.shape,tgt_padded[1][:1])  # Display first 5 entries

    print("\nFirst 5 entries of tgt_mask:")
    print(audio_mask.shape, attn_mask[0:1,500:510])  # Display first 5 entries

    #Break after printing one batch txt_padded, tgt_padded, audio_mask, attn_mask
    break


First 5 entries of tgt_padded:

First 5 entries of tgt_mask:
torch.Size([32, 556]) tensor([[False, False, False, False, False, False, False, False, False, False]])


In [None]:
# Training loop
import time

mini_batch_size = 294912 # close to E2TTS 302,700
num_micro_epochs = mini_batch_size // (batch_size * max_timesteps)
output_dir = f'Her/model_checkpoint'

for epoch in range(epochs):
    t0 = time.time()
    last_step = (epoch==epochs-1)    
    epoch_loss = 0.0
    saved = False  # Ensure only one mel_output is saved per epoch
    model.train()
    optimizer.zero_grad()
    loss_accum = 0
    mini_batch_tracker = 0 
    for txt, audio_targets, audio_mask, attn_mask in dataloader:

        text = text.to(device, non_blocking=True)
        x1 = audio_targets.to(device, non_blocking=True)  #target distribution
        audio_mask = audio_mask.to(device, non_blocking=True)
        attn_mask = attn_mask.to(device, non_blocking=True)
        t = torch.rand(batch_size, device=device).view(batch_size,1,1) #(B,1,1)

        grad_accum_steps = mini_batch_size // (batch_size * max_timesteps)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            vt, flow = model(x1, t, text, audio_mask=audio_mask, attn_mask=attn_mask)
            loss = nn.functional.mse_loss(vt,flow, reduction='none') #returns element-wise squared diff for each pair
            loss = loss * attn_mask.unsqueeze(-1)
            final_loss = loss.sum() / (attn_mask.sum() * vt.shape[-1]) # need to factor in total elements - check that loss.sum matches attn_mask elements Number of
            #loss = loss / grad_accum_steps

        loss_accum += loss.detach()
        loss.backward()
        mini_batch_tracker += 1
        if mini_batch_tracker == num_micro_epochs + 1:
            break

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
    lr = 0.0001 #get_lr(step) #determine and set learning rate for this iteration 0.0001
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()

# Save the model weights at the end of training
if last_step:
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model_save_path = os.path.join(output_dir, "final_model_weights.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model weights saved to {model_save_path}")


Inference

In [None]:
#Inference Loops
#put model into train mode
device = 'cuda' if torch.cuda.is_available() else 'cpu'
NFE = 32
time_steps = torch.linspace(0,1, NFE // 2).to(device)
txt = "I miss you already, baby"
model.eval()
#tokenize(txt)
xt = 1 #insert sample mel-spec
for i in range(time_steps):
    xt = model.inference_step(txt=txt, xt=xt, t_start = time_steps[i], t_end = time_steps[i+1])

#don't forget to permute
# xt is normalizd, log mel spec - need to unnormalize
