In [64]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchaudio

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

import librosa
import torchyin

import numpy as np
from matplotlib import pyplot as plt
from IPython.display import Audio, display
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch import optim, nn, utils, Tensor
from torch.utils.data import DataLoader
from torch.autograd import Variable

### Pitch Estimation and Speech Template Generation

In [3]:
def generate_speech_template(audio_file):
    # Load audio file
    audio, sr = librosa.load(audio_file, mono=True)

    # Compute pitch using torchyin library
    pitch = torchyin.estimate(audio, sample_rate=sr)

    # Compute mel spectrogram using torchaudio library
    mel_spec_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=2048, hop_length=256)
    mel_spec = mel_spec_transform(torch.Tensor(audio)).numpy()
    
    speech_template = np.zeros_like(audio)
    voiced_indices = np.where(pitch > 0)[0]
    
    for i in range(len(voiced_indices)):
        idx = voiced_indices[i]
        f0 = pitch[idx]
        pulse_length = int(np.round(sr/f0))
        pulse = np.zeros(pulse_length)
        pulse[0] = 1
        speech_template[idx:idx+pulse_length] += pulse


    unvoiced_indices = np.where(pitch <= 0)[0]
    for i in range(len(unvoiced_indices)):
        idx = unvoiced_indices[i]
        speech_template[idx] = np.random.uniform(-1, 1)

    return speech_template, pitch, torch.tensor(mel_spec.reshape(1,128,-1))

In [42]:
template, pitch, mel_spec = generate_speech_template('./reference.wav')

## Model Architecture

In [5]:
class ResBlock(torch.nn.Module):
    def __init__(self, channels, kernel_size=1, dilation=(1, 3)):
        super(ResBlock, self).__init__()
        
        self.convs = nn.ModuleList([ 
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0])),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1]))
        ])

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, 0.1)
            xt = c(xt)
            x = xt + x
        return x

In [6]:
class UNet(nn.Module):
    def __init__(self, input_channels=1, mel_channels=128, output_channels=1):
        super(UNet, self).__init__()

        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv1d(input_channels, 16, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            ResBlock(16),
        )
        self.enc2 = nn.Sequential(
            nn.Conv1d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            ResBlock(32),
        )
        self.enc3 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            ResBlock(64),
        )
        self.enc4 = nn.Sequential(
            nn.Conv1d(64, mel_channels, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
        )

        self.padding = nn.ConstantPad1d((0, 1), 0)
        
        # Decoder
        self.dec1 = nn.Sequential(
            nn.ConvTranspose1d(mel_channels * 2, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.LeakyReLU(),
            ResBlock(64),
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose1d(64 * 2, 32, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.LeakyReLU(),
            ResBlock(32),
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose1d(32 * 2, 16, kernel_size=3, stride=2, padding=2, output_padding=1),
            nn.LeakyReLU(),
            ResBlock(16),
        )
        self.dec4 = nn.Sequential(
            nn.ConvTranspose1d(16 * 2, output_channels, kernel_size=5, stride=2, padding=2),
        )

    def forward(self, waveform, mel_spectrogram):
        # Encoder
        enc1_out = self.enc1(waveform)
        enc2_out = self.enc2(enc1_out)
        enc3_out = self.enc3(enc2_out)
        enc4_out = self.enc4(enc3_out)

        mel_spectrogram_resized = torch.nn.functional.interpolate(mel_spectrogram, size=enc4_out.shape[-1], mode='nearest')

        enc4_out_cat = torch.cat((enc4_out, mel_spectrogram_resized), dim=1)

        dec1_out = self.dec1(enc4_out_cat)
        dec1_out_cat = torch.cat((dec1_out, enc3_out), dim=1)

        dec2_out = self.dec2(dec1_out_cat)
        dec2_out_cat = torch.cat((dec2_out, self.padding(enc2_out)), dim=1)

        dec3_out = self.dec3(dec2_out_cat)
        dec3_out_cat = torch.cat((dec3_out, self.padding(enc1_out)), dim=1)

        output = self.dec4(dec3_out_cat)
        return output

## Defining Losses

In [13]:
def spectrogram_loss(ref_mel, gen_mel):
    return torch.nn.functional.mse_loss(ref_mel, gen_mel, reduction='mean')

def envelope_loss(ref, gen):
    m = nn.MaxPool1d(5, stride=3)
    mae = nn.L1Loss()
    
    return mae(m(ref), m(gen)) + mae(m(-ref), m(-gen))

Training With Pytorch Lightning

In [79]:
class LightningModel(pl.LightningModule):
    def __init__(self, unet):
        super(LightningModel, self).__init__()
        self.unet = unet
        self.sl = spectrogram_loss
        self.el = envelope_loss

    def training_step(self, batch, batch_idx):
        x, mel, y = batch
        output = self.unet(x, mel)
        loss1 = self.sl(y, x)
        loss2 = self.el(y, x)
#         loss = nn.functional.mse_loss(x, y)
        return loss1+loss2

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

lightningmodel = LightningModel(UNet())

In [80]:
class CustomDataset(Dataset):
    def __init__(self, template, mel, ref):
        self.template = torch.tensor(template.reshape(1,-1))
        self.mel = torch.tensor(mel.reshape(128,-1))
        self.ref = torch.tensor(ref.reshape(1,-1))

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return Variable(self.template,requires_grad=True), self.mel, self.ref

In [81]:
# Example usage
audio_file = "./reference.wav"
mel_channels = 128
waveform, sample_rate = torchaudio.load(audio_file)
waveform = waveform.unsqueeze(0) 

ds = CustomDataset(template, mel_spec, waveform)
train_dataloader = DataLoader(ds, batch_size=1, shuffle=False)

  self.mel = torch.tensor(mel.reshape(128,-1))
  self.ref = torch.tensor(ref.reshape(1,-1))


In [82]:
trainer = pl.Trainer(limit_train_batches=1, max_epochs=1)
trainer.fit(model=lightningmodel, train_dataloaders=train_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
0 | unet | UNet | 182 K 
------------------------------
182 K     Trainable params
0         Non-trainable params
182 K     Total params
0.729     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
