## Download LJSpeech

In [260]:
!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xjf LJSpeech-1.1.tar.bz2

In [261]:
%pip install librosa

In [262]:
#!g1.1
%pip install torch==1.10.0+cu111 torchaudio==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
#!git clone https://github.com/darya-baranovskaya/TTS_with_FastSpeach.git

In [263]:
#!g1.1
!git clone https://github.com/NVIDIA/waveglow.git
%pip install googledrivedownloader

In [265]:
#!g1.1
%pip install wandb

In [1]:
#!g1.1
import wandb

In [2]:
#!g1.1
from typing import Tuple, Dict, Optional, List, Union
from itertools import islice

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [3]:
#!g1.1
%load_ext autoreload
%autoreload 2

In [683]:
#!g1.1
from models.fastspeech_model import FastSpeech
from utils.batch_sampler import Batch
from utils.ljspeech_dataset import LJSpeechDataset#, LJSpeechCollator
from utils.melspectrogram import MelSpectrogram

from configs.melspectrogram_config import MelSpectrogramConfig
from models.vocoder import Vocoder
from models.grapheme_aligner import GraphemeAligner, Point, Segment

from typing import Tuple, Dict, Optional, List, Union
from itertools import islice

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from IPython import display
from dataclasses import dataclass

import torch
from torch import nn

import torchaudio

import librosa
from matplotlib import pyplot as plt

import warnings
import sys
sys.path.append('waveglow/')

warnings.filterwarnings('ignore')

---

## Dataset

In [661]:
#!g1.1
dataset = LJSpeechDataset('.')

In [273]:
#!g1.1
dataset[0]

In [839]:
#!g1.1
device = torch.device('cuda:0')
aligner = GraphemeAligner().to(device)
featurizer = MelSpectrogram(MelSpectrogramConfig())


class LJSpeechCollator:
    """Collate LJSpeech dataset items into a single ``Batch``.

    The collator reads the module-level ``featurizer``, ``aligner`` and
    ``device`` each time it is called, so rebinding those globals changes
    its behaviour.
    """

    def __call__(self, instances: List[Tuple]) -> Dict:
        """Pad the items, then attach ``melspec`` and ``durations`` to the batch.

        Args:
            instances: dataset items, each a 5-tuple of
                (waveform, waveform length, transcript, tokens, token length).

        Returns:
            A ``Batch`` (the ``Dict`` annotation is kept from the original
            signature) with two extra attributes: ``melspec`` and ``durations``.
        """
        wavs, wav_lens, transcript, token_seqs, token_lens = zip(*instances)

        # Index 0 of every item is the tensor to pad; padded results are
        # batch-first.
        waveform = pad_sequence([wav[0] for wav in wavs], batch_first=True)
        tokens = pad_sequence([seq[0] for seq in token_seqs], batch_first=True)
        waveforn_length = torch.cat(wav_lens)
        token_lengths = torch.cat(token_lens)

        batch = Batch(waveform, waveforn_length, transcript, tokens, token_lengths)
        batch.melspec = featurizer(batch.waveform)

        # Featurize every utterance again without its padding to learn how
        # many spectrogram frames it really occupies.
        frame_counts = [
            featurizer(batch.waveform[row:row + 1, :batch.waveforn_length[row]]).shape[-1]
            for row in range(batch.melspec.shape[0])
        ]

        # Presumably the aligner returns per-token fractions of the utterance,
        # so scaling by the frame count yields durations in frames - TODO confirm.
        relative_durations = aligner(
            batch.waveform.to(device), batch.waveforn_length, batch.transcript
        )
        batch.durations = torch.Tensor(frame_counts).unsqueeze(1) * relative_durations
        return batch

In [840]:
#!g1.1
dataloader = DataLoader(LJSpeechDataset('.'), batch_size=3, collate_fn=LJSpeechCollator())

In [720]:
#!g1.1
batch = next(iter(dataloader))
batch.tokens.shape

In [140]:
#!g1.1
batch.__dict__.keys()

In [141]:
#!g1.1
print(batch.durations.shape)
print(batch.waveform.shape) 
print(batch.waveforn_length)
print(len(batch.transcript), batch.transcript[0])
print(batch.tokens.shape)
print(batch.token_lengths)
print(batch.melspec.shape)

In [117]:
#!g1.1
dummy_batch = list(islice(dataloader, 1))[0]

---

## Vocoder

In [681]:
#!g1.1
from google_drive_downloader import GoogleDriveDownloader as gdd

In [682]:
#!g1.1
gdd.download_file_from_google_drive(
    file_id='1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF',
    dest_path='./waveglow_256channels_universal_v5.pt'
)

In [841]:
#!g1.1
from models.vocoder import Vocoder

In [842]:
#!g1.1
vocoder = Vocoder().to('cuda:0').eval()

---

# Batch overfit

In [730]:
#!g1.1
from models.fastspeech_model import *
from models.model_layers import *
class ModelConfig:
    """Hyperparameters for the ``FastSpeech`` model used in the batch-overfit experiment.

    The class object itself (not an instance) is handed to ``FastSpeech``,
    so every value is read as a class attribute.
    """
    # NOTE(review): the full-training config later in this notebook uses 57;
    # confirm which vocabulary size matches the tokenizer.
    vocab_size: int = 1000
    # Presumably the model/embedding width - confirm against models.fastspeech_model.
    hidden_size: int = 384
    # Presumably the inner width of the FFT-block feed-forward part - TODO confirm.
    hidden_size_fft: int = 1536
    num_heads: int = 2
    kernel_size: int = 3
    # The wandb run name of this experiment ("..._2fft") refers to this value.
    n_fft_blocks: int = 2
    # Unlike the other fields, this one carries no type annotation.
    dropout = 0.1
        
model = FastSpeech(ModelConfig)
model = model.to(device)


In [489]:
#!g1.1
reconstructed_wav = vocoder.inference(batch.melspec.to(device)).cpu()

In [490]:
#!g1.1
display.display(display.Audio(reconstructed_wav[0], rate=22050))
display.display(display.Audio(reconstructed_wav[1], rate=22050))
display.display(display.Audio(reconstructed_wav[2], rate=22050))

In [722]:
#!g1.1
# Sanity check: overfit the model on the single `batch` fetched earlier.
# Relies on `batch`, `model`, `device` and `vocoder` created in other cells.
wandb.init(name='Batch_overfit1_Sequential_2fft')

# The overfitted batch holds three utterances; log each transcript
# (one wandb.log call per transcript).
for sample_idx in range(3):
    wandb.log({f'batch_text{sample_idx}': wandb.Html(batch.transcript[sample_idx])})

N_ITERATIONS = 10000
model.train()

loss_fn_mel = nn.L1Loss()
loss_fn_align = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.98999)

# Value used to right-pad spectrograms along the frame axis.
# Presumably log(1e-5), i.e. the log-mel floor - TODO confirm against MelSpectrogramConfig.
MEL_PAD_VALUE = -11.5129

# Running sums of the per-step losses (kept as globals for inspection).
mel_losses = 0
align_losses = 0

# The batch never changes, so move it to the device once instead of on every step.
inp = batch.tokens.to(device)
align_gt = batch.durations.to(device)
mel_target = batch.melspec.to(device)

for i in range(N_ITERATIONS):
    optimizer.zero_grad()

    # Ground-truth durations are fed to the model during training.
    mel, align = model(inp, align_gt)

    # Prediction and target may differ in number of frames; right-pad the
    # shorter one so the L1 loss can be computed element-wise.
    gt = mel_target
    frame_gap = gt.shape[-1] - mel.shape[-1]
    if frame_gap > 0:
        mel = nn.functional.pad(mel, (0, frame_gap), value=MEL_PAD_VALUE)
    elif frame_gap < 0:
        gt = nn.functional.pad(gt, (0, -frame_gap), value=MEL_PAD_VALUE)

    mel_loss = loss_fn_mel(mel, gt)
    align_loss = loss_fn_align(align, align_gt)
    loss = mel_loss + align_loss
    loss.backward()

    optimizer.step()
    scheduler.step()

    mel_losses += mel_loss.item()
    align_losses += align_loss.item()
    # Log plain floats so wandb never holds on to the autograd graph.
    wandb.log({
        'train_mel_loss': mel_loss.item(),
        'train_align_loss': align_loss.item(),
        'train_loss': loss.item(),
    })

    if i % 500 == 0:
        # Periodically vocode the current prediction to listen to the progress.
        with torch.no_grad():
            reconstructed_wav = vocoder.inference(mel.to(device)).cpu()
            print(f"step {i} sampled_audios")
            display.display(display.Audio(reconstructed_wav[0], rate=22050))
            step_media = {}
            for sample_idx in range(3):
                step_media[f'reconstructed_wav{sample_idx}'] = wandb.Audio(
                    reconstructed_wav[sample_idx].detach().cpu().numpy(), sample_rate=22050
                )
                step_media[f'melspec{sample_idx}'] = wandb.Image(
                    mel[sample_idx].detach().cpu().numpy()
                )
            wandb.log(step_media)

In [564]:
#!g1.1
# Evaluate the overfitted model on the same batch, this time letting the model
# predict durations itself (no ground-truth alignment is passed in).
print("Результат переобучения на батче за 10к шагов с собственным alignment")
model.eval()
inp = batch.tokens.to(device)
gt = batch.melspec.to(device)
align_gt = batch.durations.to(device) # revert that with torch.exp(torch.log(align_gt + 1)) - 1

# Pure inference: disable autograd so no graph is built through model and vocoder.
with torch.no_grad():
    mel, align = model(inp)
    reconstructed_wav = vocoder.inference(mel.to(device)).cpu()

# NOTE: `i` is the loop counter left over from the training cell.
print(f"step {i} sampled_audios")

eval_media = {}
for sample_idx in range(3):
    eval_media[f'reconstructed_wav{sample_idx}_eval'] = wandb.Audio(
        reconstructed_wav[sample_idx].detach().cpu().numpy(), sample_rate=22050
    )
    eval_media[f'melspec{sample_idx}_eval'] = wandb.Image(
        mel[sample_idx].detach().cpu().numpy()
    )
wandb.log(eval_media)
display.display(display.Audio(reconstructed_wav[0], rate=22050))

# Train


In [844]:
#!g1.1
device = torch.device('cuda:0')
def train_epoch(model, optimizer, dataloader, scheduler, loss_fn_mel, loss_fn_align):
    """Train ``model`` for one pass over ``dataloader`` and log the epoch to wandb.

    Uses the module-level ``device``, ``vocoder``, ``wandb`` and ``display``.

    Args:
        model: called as ``model(tokens, durations)``; returns ``(mel, align)``.
        optimizer: optimizer stepped once per batch.
        dataloader: yields batches exposing ``tokens``, ``melspec``,
            ``durations`` and ``transcript``.
        scheduler: learning-rate scheduler, stepped once per batch.
        loss_fn_mel: loss between predicted and target spectrograms.
        loss_fn_align: loss between predicted and target durations.

    Returns:
        ``(mean alignment loss, mean mel loss)`` over the batches of the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Value used to right-pad spectrograms along the frame axis.
    # Presumably log(1e-5), i.e. the log-mel floor - TODO confirm against MelSpectrogramConfig.
    pad_value = -11.5129

    model.train()
    mel_losses = 0.0
    align_losses = 0.0
    n_batches = 0

    for batch in dataloader:
        inp = batch.tokens.to(device)
        gt = batch.melspec.to(device)
        align_gt = batch.durations.to(device)

        # Tokens and durations can disagree in length; keep their common prefix
        # so the model receives one duration per token.
        if inp.shape[-1] != align_gt.shape[-1]:
            min_shape = min(inp.shape[-1], align_gt.shape[-1])
            inp = inp[:, :min_shape]
            align_gt = align_gt[:, :min_shape]

        optimizer.zero_grad()
        mel, align = model(inp, align_gt)

        # Right-pad whichever spectrogram has fewer frames so both have equal length.
        frame_gap = gt.shape[-1] - mel.shape[-1]
        if frame_gap > 0:
            mel = nn.functional.pad(mel, (0, frame_gap), value=pad_value)
        elif frame_gap < 0:
            gt = nn.functional.pad(gt, (0, -frame_gap), value=pad_value)

        mel_loss = loss_fn_mel(mel, gt)
        align_loss = loss_fn_align(align, align_gt)
        loss = mel_loss + align_loss
        loss.backward()

        optimizer.step()
        scheduler.step()

        mel_losses += mel_loss.item()
        align_losses += align_loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")

    mean_mel_loss = mel_losses / n_batches
    mean_align_loss = align_losses / n_batches
    # All three losses are epoch means, mirroring what validate() reports.
    wandb.log({
        'train_mel_loss': mean_mel_loss,
        'train_align_loss': mean_align_loss,
        'train_loss': mean_mel_loss + mean_align_loss,
        'optimizer_lr': optimizer.param_groups[0]['lr'],
    })

    # Vocode the prediction for the last batch of the epoch as an audible sample.
    with torch.no_grad():
        reconstructed_wav = vocoder.inference(mel.to(device)).cpu()
        print("train sampled_audios")
        display.display(display.Audio(reconstructed_wav[0], rate=22050))
        for i in range(reconstructed_wav.shape[0]):
            wandb.log({
                f'train_text{i}': wandb.Html(batch.transcript[i]),
                f'reconstructed_train_wav{i}': wandb.Audio(
                    reconstructed_wav[i].detach().cpu().numpy(), sample_rate=22050
                ),
                f'melspec_train_{i}': wandb.Image(mel[i].detach().cpu().numpy()),
            })
    return mean_align_loss, mean_mel_loss

In [757]:
#!g1.1
device = torch.device('cuda:0')
def validate(model, dataloader, loss_fn_mel, loss_fn_align):
    """Evaluate ``model`` on ``dataloader`` and log losses and samples to wandb.

    The model is called without ground-truth durations (``model(tokens)``).
    Uses the module-level ``device``, ``vocoder``, ``wandb`` and ``display``.

    Args:
        model: called as ``model(tokens)``; returns ``(mel, align)``.
        dataloader: yields batches exposing ``tokens``, ``melspec``,
            ``durations`` and ``transcript``.
        loss_fn_mel: loss between predicted and target spectrograms.
        loss_fn_align: loss between predicted and target durations.

    Returns:
        ``(mean alignment loss, mean mel loss)`` over the validation batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Value used to right-pad spectrograms along the frame axis.
    # Presumably log(1e-5), i.e. the log-mel floor - TODO confirm against MelSpectrogramConfig.
    pad_value = -11.5129

    model.eval()
    mel_losses = 0.0
    align_losses = 0.0
    n_batches = 0

    for batch in dataloader:
        with torch.no_grad():
            inp = batch.tokens.to(device)
            gt = batch.melspec.to(device)
            align_gt = batch.durations.to(device)

            # Same reconciliation as in train_epoch: tokens and durations can
            # disagree in length, which would make the duration loss compare
            # tensors of different shapes.
            if inp.shape[-1] != align_gt.shape[-1]:
                min_shape = min(inp.shape[-1], align_gt.shape[-1])
                inp = inp[:, :min_shape]
                align_gt = align_gt[:, :min_shape]

            mel, align = model(inp)

            # Right-pad whichever spectrogram has fewer frames so both have equal length.
            frame_gap = gt.shape[-1] - mel.shape[-1]
            if frame_gap > 0:
                mel = nn.functional.pad(mel, (0, frame_gap), value=pad_value)
            elif frame_gap < 0:
                gt = nn.functional.pad(gt, (0, -frame_gap), value=pad_value)

            mel_losses += loss_fn_mel(mel, gt).item()
            align_losses += loss_fn_align(align, align_gt).item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validate received an empty dataloader")

    # Vocode the prediction for the last validation batch as an audible sample.
    with torch.no_grad():
        reconstructed_wav = vocoder.inference(mel.to(device)).cpu()
        print("validation sampled_audios")
        display.display(display.Audio(reconstructed_wav[0], rate=22050))
        for i in range(reconstructed_wav.shape[0]):
            wandb.log({
                f'val_text{i}': wandb.Html(batch.transcript[i]),
                f'reconstructed_val_wav{i}': wandb.Audio(
                    reconstructed_wav[i].detach().cpu().numpy(), sample_rate=22050
                ),
                f'melspec_val_{i}': wandb.Image(mel[i].detach().cpu().numpy()),
            })

    mean_mel_loss = mel_losses / n_batches
    mean_align_loss = align_losses / n_batches
    wandb.log({
        'val_mel_loss': mean_mel_loss,
        'val_align_loss': mean_align_loss,
        'val_loss': (mel_losses + align_losses) / n_batches,
    })
    return mean_align_loss, mean_mel_loss


In [775]:
#!g1.1
tokenizer = torchaudio.pipelines.TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH.get_text_processor()
def test(model, texts: Optional[List[str]] = None):
    """Synthesize a few held-out sentences and log audio, spectrogram and text to wandb.

    Uses the module-level ``tokenizer``, ``device``, ``vocoder``, ``wandb``
    and ``display``.

    Args:
        model: called as ``model(tokens)``; returns ``(mel, align)``.
        texts: sentences to synthesize. When omitted, three built-in
            sentences are used.
    """
    model.eval()
    if texts is None:
        test_texts = ['A defibrillator is a device that gives a high energy electric shock to the heart of someone who is in cardiac arrest',
                     'Massachusetts Institute of Technology may be best known for its math, science and engineering education',
                     'Wasserstein distance or Kantorovich Rubinstein metric is a distance function defined between probability distributions on a given metric space']
    else:
        test_texts = list(texts)

    # The tokenizer returns a pair; only its first element (the tokens) is used.
    tokens = [tokenizer(text)[0] for text in test_texts]

    for i, token in enumerate(tokens):
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            mel, align = model(token.to(device))
            reconstructed_wav = vocoder.inference(mel.to(device)).cpu()
        print("test sampled_audios")
        display.display(display.Audio(reconstructed_wav[0], rate=22050))
        # The spectrogram key uses the "test" prefix so that it does not
        # collide with the `melspec_val_{i}` images logged by validate().
        wandb.log({f'reconstructed_test_wav{i}': wandb.Audio(reconstructed_wav[0].detach().cpu().numpy(), sample_rate=22050),
                   f'melspec_test_{i}': wandb.Image(mel[0].detach().cpu().numpy()),
                   f'test_text{i}': wandb.Html(test_texts[i])})

In [None]:
#!g1.1
# Rebuild the dataset and the module-level objects that LJSpeechCollator
# reads at call time (`device`, `aligner`, `featurizer`).
dataset = LJSpeechDataset('.')
device = torch.device('cuda:0')
aligner = GraphemeAligner().to(device)
featurizer = MelSpectrogram(MelSpectrogramConfig())

# Seed torch's global RNG right before the random split below.
torch.manual_seed(3407)
# Hold out a fixed number of utterances; everything else is used for training.
test_size = 30
train_size = len(dataset) -  test_size  # alternative kept for reference: int(0.95 * len(dataset))
# Counterpart of that alternative: test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

In [848]:
#!g1.1
train_dataloader = DataLoader(train_dataset, batch_size=10, collate_fn=LJSpeechCollator())
val_dataloader = DataLoader(test_dataset, batch_size=3, collate_fn=LJSpeechCollator())

In [849]:
#!g1.1
len(train_dataloader), len(val_dataloader)

In [None]:
#!g1.1
from models.model_layers import *
from models.fastspeech_model import FastSpeechEncoder, FastSpeechDecoder, FastSpeech

class ModelConfig:
    """Hyperparameters for the ``FastSpeech`` model used in the full training run.

    The class object itself (not an instance) is handed to ``FastSpeech``,
    so every value is read as a class attribute.
    """
    # NOTE(review): the batch-overfit config earlier in this notebook uses 1000;
    # presumably 57 is the size of the character tokenizer's vocabulary - confirm.
    vocab_size: int = 57
    # Presumably the model/embedding width - confirm against models.fastspeech_model.
    hidden_size: int = 384
    # Presumably the inner width of the FFT-block feed-forward part - TODO confirm.
    hidden_size_fft: int = 1536
    num_heads: int = 2
    kernel_size: int = 3
    n_fft_blocks: int = 6
    # Unlike the other fields, this one carries no type annotation.
    dropout = 0.1
        
model = FastSpeech(ModelConfig)
model = model.to(torch.device('cuda:0'))

In [None]:
#!g1.1
from tts.model.fastspeech import FastSpeech

In [847]:
#!g1.1
scheduler

In [None]:
#!g1.1
# Full training run. Relies on `model`, `train_dataloader`, `val_dataloader`,
# `train_epoch`, `validate` and `test` defined in other cells.
# Rebinding `aligner` here also affects LJSpeechCollator, which looks the
# global up every time it builds a batch.
aligner = GraphemeAligner().to(torch.device('cuda:0'))

N_EPOCHS = 60
# The run name is also the filename prefix of the checkpoints saved below.
RUN_NAME = 'model_nfft-6_hidden1-384_hidden2-1536_dropout-0.1'
wandb.init(name=RUN_NAME)



loss_fn_mel = nn.L1Loss()
loss_fn_align = nn.MSELoss(reduction='mean')
# NOTE(review): OneCycleLR below drives the learning rate itself, so the `lr`
# passed to Adam is presumably overridden - confirm.
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9, lr = 0.00001)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=len(train_dataloader) // 5, gamma=0.98999)
# The schedule is sized for N_EPOCHS * len(train_dataloader) steps;
# train_epoch steps the scheduler once per batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.0005, steps_per_epoch=len(train_dataloader), epochs=N_EPOCHS)
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0008, max_lr=0.01,
#                                               step_size_up=len(train_dataloader), step_size_down=5*len(train_dataloader),
#                                               cycle_momentum=False)


# NOTE(review): the loop starts at epoch 23 while the optimizer and scheduler
# above are freshly created and no checkpoint is loaded in this cell, so only
# N_EPOCHS - 23 epochs of the one-cycle schedule are executed. Looks like a
# manually resumed run - confirm that this is intended.
for epoch in range(23, N_EPOCHS):
    train_epoch(model, optimizer, train_dataloader, scheduler, loss_fn_mel, loss_fn_align)
    print(f'epoch {epoch} ended')
    validate(model, val_dataloader, loss_fn_mel, loss_fn_align)
    test(model)
    # Always overwrite the rolling "last epoch" checkpoint...
    torch.save(model.state_dict(), RUN_NAME + '_last_epoch.pth')
    # ...and keep a separate snapshot every 5th epoch once past epoch 20.
    if epoch % 5 == 0 and epoch > 20 :
        torch.save(model.state_dict(), RUN_NAME + '_epoch_' + str(epoch) + '.pth')


In [None]:
#!g1.1
