# **voice2voice**
### Welcome to the voice conversion system based on the **pix2pix** architecture! Let's train a model!

**Some utils**

In [2]:
from glob import glob
from pathlib import Path
from random import shuffle

import cv2
import librosa
import matplotlib.pyplot as plt
import numpy as np


def get_vcc2018_filenames(path, source_voice, target_voice):
    """Pair up the recordings of a source speaker and a target speaker.

    Everything below ``path`` is listed recursively.  An entry is attributed
    to a speaker when the speaker id occurs in the name of its parent
    directory.  Both selections are sorted and zipped, so the i-th source
    entry is paired with the i-th target entry and surplus entries of the
    longer selection are dropped.

    Returns a list of ``(source_path, target_path)`` tuples.
    """
    pattern = (Path(path) / '**').as_posix()
    entries = glob(pattern, recursive=True)

    def entries_of(voice):
        # Match on the parent directory name only, not on the file name.
        return sorted(e for e in entries if voice in Path(e).parent.name)

    return list(zip(entries_of(source_voice), entries_of(target_voice)))


def shuffle_split(alist, first_half_ratio):
    """Shuffle a sequence and split it in two.

    The input is copied before shuffling, so the caller's sequence keeps its
    original order and any iterable (list, tuple, ...) is accepted.

    Args:
        alist: The items to shuffle and split.
        first_half_ratio: Fraction of the items that goes into the first
            part; the size is ``int(first_half_ratio * len(alist))``.

    Returns:
        A tuple ``(first, second)`` of two lists that together contain every
        item of ``alist`` exactly once, in random order.
    """
    items = list(alist)
    shuffle(items)
    k = int(first_half_ratio * len(items))
    return items[:k], items[k:]


def audio_sampler(filenames, batch_size, return_max_power=False):
    """Endlessly yield random batches of paired (source, target) spectrograms.

    For every batch element a ``(source, target)`` file pair is drawn at
    random (with replacement).  Both recordings are zero-padded to a common
    length of at least three seconds, the same random three-second window is
    cut out of both, and each window is converted to a mel spectrogram in dB
    relative to its own peak power.

    Args:
        filenames: Sequence of ``(source_path, target_path)`` pairs.
        batch_size: Number of pairs per yielded batch.
        return_max_power: When true, also yield the peak mel power of every
            source window of the batch (the reference ``power_to_db``
            normalised by).

    Yields:
        ``(a, b)``, two arrays of shape ``(batch_size, 1, 256, 256)`` holding
        the source and target spectrograms, or ``((a, b), max_power_a)`` when
        ``return_max_power`` is set, where ``max_power_a[i]`` belongs to
        ``a[i]``.  Every batch gets freshly allocated arrays and a fresh list.

    Raises:
        ValueError: On the first ``next()`` if ``filenames`` is empty.
    """
    sample_rate = 44100
    n_fft = 2048
    hop_length = 518
    n_mels = 256

    clip_len = 3 * sample_rate
    # With librosa's default centred framing a clip yields
    # 1 + clip_len // hop_length = 256 frames.
    n_frames = 1 + clip_len // hop_length

    if len(filenames) == 0:
        raise ValueError('audio_sampler needs at least one (source, target) pair')

    def pad_to(audio, length):
        """Append zeros so that ``audio`` has at least ``length`` samples."""
        missing = length - audio.size
        if missing <= 0:
            return audio
        return np.concatenate((audio, np.zeros(missing)))

    def mel_power(audio):
        """Mel power spectrogram of ``audio`` with the settings above."""
        return librosa.feature.melspectrogram(
            y=audio,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels
        )

    while True:
        # Per-batch state: previously yielded batches are never overwritten
        # and max_power_a always lines up with the rows of ``a``.
        a = np.zeros((batch_size, 1, n_mels, n_frames))
        b = np.zeros((batch_size, 1, n_mels, n_frames))
        max_power_a = []

        indices = np.random.randint(len(filenames), size=batch_size)

        for i, idx in enumerate(indices):
            source_path, target_path = filenames[idx][0], filenames[idx][1]
            audio_a, _ = librosa.load(source_path, sr=sample_rate)
            audio_b, _ = librosa.load(target_path, sr=sample_rate)

            # Bring both signals to the same length (>= one clip) so a single
            # offset selects the same time window in source and target.
            common_len = max(audio_a.size, audio_b.size, clip_len)
            audio_a = pad_to(audio_a, common_len)
            audio_b = pad_to(audio_b, common_len)

            start = np.random.randint(common_len - clip_len + 1)
            audio_a = audio_a[start:start + clip_len]
            audio_b = audio_b[start:start + clip_len]

            S_a = mel_power(audio_a)
            S_b = mel_power(audio_b)

            # dB relative to each spectrogram's own maximum, so values are <= 0.
            a[i, 0] = librosa.power_to_db(S_a, ref=np.max)
            b[i, 0] = librosa.power_to_db(S_b, ref=np.max)

            max_power_a.append(np.max(S_a))

        if return_max_power:
            yield (a, b), max_power_a
        else:
            yield (a, b)
            
            
def reconstruct_signal(S_db, ref=1.0):
    """Turn a dB-scaled mel spectrogram back into an audio signal.

    ``S_db`` is first converted to mel power, with ``ref`` as the reference
    power, and then inverted to a waveform by librosa using the fixed
    settings below.  Returns the audio as a numpy array.
    """
    inversion_settings = dict(sr=44100, n_fft=2048, hop_length=518)

    mel_power = librosa.db_to_power(S_db, ref=ref)

    return librosa.feature.inverse.mel_to_audio(M=mel_power, **inversion_settings)


def style(im):
    """Colour a single-channel image with the viridis palette.

    The image is min-max normalised, mapped through viridis, scaled to
    8-bit and returned as a three-channel ``uint8`` image.
    """
    scale = plt.Normalize(im.min(), im.max())
    rgba = plt.cm.viridis(scale(im))
    rgba_8bit = (255 * rgba).astype(np.uint8)
    # NOTE: the colormap output is RGBA, so this BGRA->RGB conversion drops
    # alpha and swaps the first and third channels.  The result is therefore
    # in BGR order, which is the order cv2.imwrite expects.
    return cv2.cvtColor(rgba_8bit, cv2.COLOR_BGRA2RGB)

**Network definitions**

In [None]:
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Downsampling unit: 4x4 convolution, optional instance norm, LeakyReLU(0.2).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        normalization: Insert an ``InstanceNorm2d`` after the convolution.
            The convolution only carries a bias when this is false.
        stride: Stride of the convolution (2 halves the spatial size).
        track_stats: Forwarded as ``track_running_stats`` to the norm layer.
    """

    def __init__(self, in_channels, out_channels, normalization=True, stride=2, track_stats=False):
        super().__init__()

        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=not normalization,
            )
        ]

        if normalization:
            stages.append(
                nn.InstanceNorm2d(out_channels, track_running_stats=track_stats)
            )

        stages.append(nn.LeakyReLU(0.2))

        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Apply the stages one after another and return the result."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
    

class DecoderBlock(nn.Module):
    """Upsampling unit: 4x4 transposed convolution, instance norm, optional dropout, ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the transposed convolution.
        dropout: Insert a ``Dropout2d(0.5)`` between the norm and the ReLU.
        track_stats: Forwarded as ``track_running_stats`` to the norm layer.
    """

    def __init__(self, in_channels, out_channels, dropout=False, track_stats=False):
        super().__init__()

        stages = [
            # Stride 2 with kernel 4 / padding 1 doubles the spatial size.
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, track_running_stats=track_stats),
        ]

        if dropout:
            stages.append(nn.Dropout2d(0.5))

        stages.append(nn.ReLU())

        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Apply the stages one after another and return the result."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
    
    
class UNet(nn.Module):
    """U-Net generator: eight stride-2 encoder blocks, a mirrored decoder with
    skip connections, and a final transposed convolution.

    Every encoder block halves the spatial size and every decoder stage
    doubles it, so the output has the spatial size of the input.  The skip
    concatenations require matching sizes at every level; inputs whose height
    and width are multiples of 256 (e.g. 256x256) satisfy this.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.encoder = nn.ModuleList([
            EncoderBlock(in_channels, 64, normalization=False),
            EncoderBlock(64, 128),
            EncoderBlock(128, 256),
            EncoderBlock(256, 512),
            EncoderBlock(512, 512),
            EncoderBlock(512, 512),
            EncoderBlock(512, 512),
            # Bottleneck: for a 256x256 input this map is 1x1.  Instance norm
            # over a single spatial element would zero the features and is
            # rejected by current PyTorch, so this block is not normalised.
            # Its convolution therefore carries a bias, which adds
            # 'encoder.7.layers.0.bias' to the state dict.
            EncoderBlock(512, 512, normalization=False)
        ])

        # From the second stage on, the input is the previous stage's output
        # concatenated with the matching encoder skip, hence doubled channels.
        self.decoder = nn.ModuleList([
            DecoderBlock(512, 512, dropout=True),
            DecoderBlock(1024, 512, dropout=True),
            DecoderBlock(1024, 512, dropout=True),
            DecoderBlock(1024, 512),
            DecoderBlock(1024, 256),
            DecoderBlock(512, 128),
            DecoderBlock(256, 64)
        ])

        self.last_conv = nn.ConvTranspose2d(
            in_channels=128,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1
        )

    def forward(self, x):
        """Run the encoder, then decode while concatenating the skips."""
        skips = []

        for down in self.encoder:
            x = down(x)
            skips.append(x)

        # The last encoder output is the decoder input itself; the remaining
        # outputs are consumed deepest-first as skip connections.
        for up, skip in zip(self.decoder, reversed(skips[:-1])):
            x = up(x)
            x = torch.cat((skip, x), dim=1)

        return self.last_conv(x)
    
    
class PatchGAN(nn.Module):
    """Patch discriminator that outputs one score per image patch.

    Three stride-2 encoder blocks are followed by a stride-1 encoder block
    and a stride-1 convolution down to a single channel.

    Args:
        in_channels: Channels of the image to judge.
        sigmoid: Append a ``Sigmoid`` so the scores lie in (0, 1).
    """

    def __init__(self, in_channels, sigmoid=False):
        super().__init__()

        stages = [
            EncoderBlock(in_channels, 64, normalization=False),
            EncoderBlock(64, 128),
            EncoderBlock(128, 256),
            EncoderBlock(256, 512, stride=1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),
        ]

        self.layers = nn.ModuleList(stages)

        if sigmoid:
            self.layers.append(nn.Sigmoid())

    def forward(self, x):
        """Apply the stages one after another and return the score map."""
        scores = x
        for layer in self.layers:
            scores = layer(scores)
        return scores

**Hyperparameter settings**

In [None]:
# Use the GPU when a CUDA device is available
cuda = torch.cuda.is_available()

# Dataset location; download from https://datashare.is.ed.ac.uk/handle/10283/3061
data_path = 'data/'
# Output .WAV files and spectrogram images are saved here
output_path = 'outputs/'
# Model checkpoints are saved here ...
checkpoint_path = 'checkpoints/'
# ... under this base name
checkpoint_name = 'checkpoint'

# Source (voice_a) and target (voice_b) speakers.
# There are 8 speakers (4 women and 4 men): SF1, SF2, SM1, SM2, TF1, TF2, TM1, TM2
voice_a, voice_b = 'SF2', 'TM1'

batch_size = 4

# Number of training iterations, and how often (in iterations) to checkpoint
epochs = 2000
checkpoint_freq = 100

**Initialize the data**

In [None]:
# Gather the (source, target) filename pairs of the chosen speakers
filenames = get_vcc2018_filenames(data_path, voice_a, voice_b)

# Fail early with a clear message; an empty list would otherwise only
# surface later as an obscure error when the first batch is drawn.
if not filenames:
    raise FileNotFoundError(
        f'No recordings found for speakers {voice_a} and {voice_b} under {data_path!r}'
    )

# Split the data: 80% for training, the rest for testing
train_filenames, test_filenames = shuffle_split(filenames, 0.8)

if not train_filenames:
    raise ValueError('The training split is empty; more recordings are needed')

# Create a training sampler so we can ask for random batches
sampler = audio_sampler(train_filenames, batch_size)

**Initialize the networks, optimizers and loss function**

In [None]:
# Generator (spectrogram -> spectrogram) and patch discriminator
G = UNet(in_channels=1, out_channels=1)
D = PatchGAN(in_channels=1)

# One Adam optimizer per network, both with the same settings
optim_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Adversarial loss: mean squared error between patch scores and targets
criterion_adversarial = nn.MSELoss()

# Tensor type used to move numpy batches onto the training device
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

if cuda:
    G.cuda()
    D.cuda()

**Start training**

In [None]:
import time
import sys

# Set networks to training mode
G.train()
D.train()

# torch.save does not create missing directories, so make sure the
# checkpoint directory exists before the first checkpoint is due.
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

start_time = time.time()

# Each "epoch" here is one optimisation step on one random batch.
for e in range(1, epochs+1):
    # Give me a batch
    a, b = next(sampler)

    # Model inputs
    real_a = Tensor(a)
    real_b = Tensor(b)

    # Train generator: it should make the discriminator answer "real"
    optim_G.zero_grad()

    fake_b = G(real_a)
    pred_fake = D(fake_b)

    # Adversarial ground truths, shaped like the discriminator output so they
    # stay correct for any batch size or input resolution
    real = torch.ones_like(pred_fake)
    fake = torch.zeros_like(pred_fake)

    loss_G = criterion_adversarial(pred_fake, real)

    loss_G.backward()
    optim_G.step()

    # Train discriminator: real targets -> 1, generated targets -> 0.
    # zero_grad also discards the gradients the generator step left in D.
    optim_D.zero_grad()

    pred_real = D(real_b)
    loss_real = criterion_adversarial(pred_real, real)

    # detach() keeps this step from backpropagating into the generator
    pred_fake = D(fake_b.detach())
    loss_fake = criterion_adversarial(pred_fake, fake)

    loss_D = 0.5 * (loss_real + loss_fake)

    loss_D.backward()
    optim_D.step()

    # Save the states of the models, optimizers, and the training and test splits
    if e % checkpoint_freq == 0:
        torch.save(
            {
                'epoch': e,

                'G_state_dict': G.state_dict(),
                'D_state_dict': D.state_dict(),

                # Detached CPU copies: no autograd graph is pickled and the
                # losses can be loaded on a machine without a GPU
                'G_loss': loss_G.detach().cpu(),
                'D_loss': loss_D.detach().cpu(),

                'G_optim_state_dict': optim_G.state_dict(),
                'D_optim_state_dict': optim_D.state_dict(),

                'train_filenames': train_filenames,
                'test_filenames': test_filenames
            },
            (Path(checkpoint_path) / Path(checkpoint_name)).as_posix() + f'_{int(start_time)}_{e:0{len(str(epochs))}}.tar'
        )

    # How are we doing? Show a summary.
    # ETA = average time per finished step * number of remaining steps
    time_left = int((time.time() - start_time) * (epochs - e) / e)

    h = time_left // 3600
    m = (time_left % 3600) // 60
    s = (time_left % 3600) % 60

    eta = f'{h:02}h {m:02}m {s:02}s'

    # '\r' rewrites the same line; flush so the progress shows up immediately
    sys.stdout.write(
        f'\r[Epoch {e:0{len(str(epochs))}}/{epochs}] [G loss: {loss_G:.4f}] [D loss: {loss_D:.4f}] ETA: {eta}'
    )
    sys.stdout.flush()

**Let's check what we learned**

In [None]:
import matplotlib.pyplot as plt


# librosa.output.write_wav was removed in librosa 0.8; scipy (already a
# librosa dependency) writes the WAV files instead.
from scipy.io import wavfile

# Set the generator in evaluation mode
G.eval()

# We want to sample the test partition
test_sampler = audio_sampler(test_filenames, 1)

# cv2.imwrite fails silently when the target directory is missing, so create
# it first.  The directory is derived from a full output file name because
# output_path is used as a plain string prefix below.
Path(output_path + 'source_1.png').parent.mkdir(parents=True, exist_ok=True)

for i in range(10):
    # Source, target and fake
    real_a, real_b = next(test_sampler)

    # No gradients are needed for inference
    with torch.no_grad():
        fake_b = G(Tensor(real_a)).cpu().numpy()

    # Save the source, target and fake spectrograms
    cv2.imwrite(output_path + f'source_{i+1}.png', style(real_a[0, 0]))
    cv2.imwrite(output_path + f'target_{i+1}.png', style(real_b[0, 0]))
    cv2.imwrite(output_path + f'fake_{i+1}.png', style(fake_b[0, 0]))

    source = reconstruct_signal(real_a[0, 0])
    target = reconstruct_signal(real_b[0, 0])
    fake = reconstruct_signal(fake_b[0, 0])

    # Save the source, target and fake audios as 32-bit float WAV files
    wavfile.write(output_path + f'source_{i+1}.wav', 44100, np.asarray(source, dtype=np.float32))
    wavfile.write(output_path + f'target_{i+1}.wav', 44100, np.asarray(target, dtype=np.float32))
    wavfile.write(output_path + f'fake_{i+1}.wav', 44100, np.asarray(fake, dtype=np.float32))

    # Show them, too: source | fake | target
    fig = plt.figure()

    fig.add_subplot(1, 3, 1)
    plt.imshow(real_a[0, 0])

    fig.add_subplot(1, 3, 2)
    plt.imshow(fake_b[0, 0])

    fig.add_subplot(1, 3, 3)
    plt.imshow(real_b[0, 0])

    plt.show()