In [0]:
# Mount Google Drive inside the Colab runtime; the dataset, checkpoint and
# output paths configured further down all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Upgrade librosa to the newest release on PyPI (the version is not pinned).
!pip install -U librosa

In [0]:
from glob import glob
from pathlib import Path
from random import shuffle

import librosa
import numpy as np


def get_vcc2018_filenames(path, source_voice, target_voice):
    """Pair source-speaker files with target-speaker files found under ``path``.

    Every regular file below ``path`` is considered. A file belongs to a
    speaker when the speaker tag (e.g. ``'SF2'``) is a substring of the name of
    the directory that directly contains it. Both groups are sorted by path and
    paired position by position.

    Args:
        path: Root directory that is searched recursively.
        source_voice: Substring identifying the source speaker's directory.
        target_voice: Substring identifying the target speaker's directory.

    Returns:
        A list of ``(source_file, target_file)`` path-string tuples. Its length
        is the size of the smaller group; a ``UserWarning`` is emitted when the
        two groups differ in size, because positional pairing is then
        unlikely to match the same utterances.
    """
    import warnings

    pattern = (Path(path) / Path('**')).as_posix()

    # A recursive '**' glob also returns directories (including the root
    # itself); keep regular files only so every entry can be loaded as audio.
    filenames = [f for f in glob(pattern, recursive=True) if Path(f).is_file()]

    sources = sorted(f for f in filenames if source_voice in Path(f).parent.name)
    targets = sorted(f for f in filenames if target_voice in Path(f).parent.name)

    if len(sources) != len(targets):
        warnings.warn(
            f'found {len(sources)} source files but {len(targets)} target files; '
            'pairs are formed by sorted position and the surplus is dropped'
        )

    return list(zip(sources, targets))


def filter_by_duration(filenames, seconds=3, sample_rate=44100):
    """Keep the pairs whose shorter recording lasts at least ``seconds``.

    Args:
        filenames: Iterable of ``(source_path, target_path)`` pairs.
        seconds: Minimum duration required of both recordings.
        sample_rate: Rate passed to ``librosa.load``; lengths are measured in
            samples at this rate.

    Returns:
        A list of ``(source_path, target_path)`` tuples, in input order, for
        which both files hold at least ``seconds * sample_rate`` samples.
    """
    min_samples = seconds * sample_rate

    def sample_count(path):
        # Decode the file at the requested rate and report how long it is.
        signal, _ = librosa.load(path, sr=sample_rate)
        return len(signal)

    return [
        (src, tgt)
        for src, tgt in filenames
        if min(sample_count(src), sample_count(tgt)) >= min_samples
    ]


def shuffle_split(alist, first_half_ratio):
    """Randomly split a sequence into two lists.

    The input is copied before shuffling, so the caller's sequence keeps its
    order and any sequence type (list, tuple, ...) is accepted.

    Args:
        alist: Sequence of items to split.
        first_half_ratio: Fraction (0..1) of the items placed in the first
            list; the count is truncated towards zero.

    Returns:
        ``(first, second)`` - two new lists that together contain every item
        of ``alist`` exactly once, in random order.
    """
    shuffled = list(alist)
    shuffle(shuffled)

    k = int(first_half_ratio * len(shuffled))
    return shuffled[:k], shuffled[k:]


def _pad_to_length(audio, length):
    """Right-pad ``audio`` with zeros so that it holds at least ``length`` samples."""
    missing = length - audio.size
    if missing > 0:
        audio = np.concatenate((audio, np.zeros(missing)))
    return audio


def audio_sampler(filenames, batch_size, return_max_power=False, dtw=False):
    """Endlessly yield random batches of aligned 3-second mel spectrograms.

    For every batch, ``batch_size`` pairs are drawn (with replacement) from
    ``filenames``. Both recordings of a pair are loaded at 44.1 kHz, zero-padded
    to a common length of at least three seconds, and the same random
    3-second window is cut from each. Each window is converted to a 256-band
    mel power spectrogram and expressed in dB relative to its own maximum.

    Args:
        filenames: Sequence of ``(source_path, target_path)`` pairs.
        batch_size: Number of pairs per batch.
        return_max_power: When true, also yield the peak mel power of every
            source window in the batch.
        dtw: When true, the source array gets a second channel holding the
            first matrix returned by ``librosa.sequence.dtw`` for the chroma
            features of both windows (infinite entries replaced by -1).

    Yields:
        ``(a, b)`` or ``((a, b), max_power_a)``. ``a`` has shape
        ``(batch_size, 1 or 2, 256, 256)``, ``b`` has shape
        ``(batch_size, 1, 256, 256)``; ``max_power_a`` is a list with one value
        per batch item. Fresh arrays and a fresh list are created for every
        batch, so previously yielded batches are never overwritten.

    Raises:
        ValueError: If ``filenames`` is empty (raised on the first ``next``).
    """
    sample_rate = 44100
    n_fft = 2048
    hop_length = 518
    n_mels = 256

    clip_len = 3 * sample_rate
    # With librosa's default centred framing a clip yields 1 + len // hop frames,
    # i.e. 256 for these settings, which makes the spectrograms square.
    n_frames = 1 + clip_len // hop_length

    if len(filenames) == 0:
        raise ValueError('audio_sampler needs at least one (source, target) pair')

    a_channels = 2 if dtw else 1

    while True:
        # Per-batch state: allocating here keeps earlier batches intact and
        # keeps max_power_a aligned with the items of the current batch.
        a = np.zeros((batch_size, a_channels, n_mels, n_frames))
        b = np.zeros((batch_size, 1, n_mels, n_frames))
        max_power_a = []

        indices = np.random.randint(len(filenames), size=batch_size)

        for i, idx in enumerate(indices):
            audio_a, _ = librosa.load(filenames[idx][0], sr=sample_rate)
            audio_b, _ = librosa.load(filenames[idx][1], sr=sample_rate)

            # Bring both signals to one common length of at least one clip.
            size = max(audio_a.size, audio_b.size, clip_len)
            audio_a = _pad_to_length(audio_a, size)
            audio_b = _pad_to_length(audio_b, size)

            # The same offset is used for both signals so the crops stay aligned.
            r = np.random.randint(size - clip_len + 1)
            audio_a = audio_a[r:r + clip_len]
            audio_b = audio_b[r:r + clip_len]

            S_a = librosa.feature.melspectrogram(
                y=audio_a,
                sr=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels
            )

            S_b = librosa.feature.melspectrogram(
                y=audio_b,
                sr=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels
            )

            # dB relative to each spectrogram's own peak (maximum maps to 0 dB).
            a[i, 0] = librosa.power_to_db(S_a, ref=np.max)
            b[i, 0] = librosa.power_to_db(S_b, ref=np.max)

            if dtw:
                a_chroma = librosa.feature.chroma_stft(
                    y=audio_a, sr=sample_rate, tuning=0, norm=2, hop_length=hop_length, n_fft=n_fft
                )

                b_chroma = librosa.feature.chroma_stft(
                    y=audio_b, sr=sample_rate, tuning=0, norm=2, hop_length=hop_length, n_fft=n_fft
                )

                # D is frames x frames; it fits the channel only because the
                # number of frames equals the number of mel bands (both 256).
                D, _ = librosa.sequence.dtw(X=a_chroma, Y=b_chroma, metric='cosine')
                D[D == np.inf] = -1
                a[i, 1] = D

            max_power_a.append(np.max(S_a))

        if return_max_power:
            yield (a, b), max_power_a
        else:
            yield (a, b)
            
            
def reconstruct_signal(S_db, ref=1.0):
    """Turn a dB-scaled mel spectrogram back into a waveform.

    Args:
        S_db: Mel spectrogram in decibels.
        ref: Reference power handed to ``librosa.db_to_power`` when undoing
            the dB scaling.

    Returns:
        Whatever ``librosa.feature.inverse.mel_to_audio`` produces for the
        recovered power spectrogram.
    """
    # Same sample rate / FFT size / hop length that audio_sampler analyses with.
    analysis = dict(sr=44100, n_fft=2048, hop_length=518)

    power_spec = librosa.db_to_power(S_db, ref=ref)
    return librosa.feature.inverse.mel_to_audio(M=power_spec, **analysis)

In [0]:
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Downsampling unit: 4x4 convolution -> optional instance norm -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        batch_norm: When true, an ``InstanceNorm2d`` follows the convolution
            and the convolution is created without a bias term.
        stride: Stride of the convolution.
        track_stats: Forwarded to ``InstanceNorm2d`` as ``track_running_stats``.
    """

    def __init__(self, in_channels, out_channels, batch_norm=True, stride=2, track_stats=False):
        super().__init__()

        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=not batch_norm
        )

        stages = [conv]
        if batch_norm:
            stages.append(nn.InstanceNorm2d(out_channels, track_running_stats=track_stats))
        stages.append(nn.LeakyReLU(0.2))

        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Apply the stages in order and return the resulting feature map."""
        out = x
        for stage in self.layers:
            out = stage(out)

        return out
    

class DecoderBlock(nn.Module):
    """Upsampling unit: 4x4 stride-2 transposed convolution -> instance norm -> [dropout] -> ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
        dropout: When true, a ``Dropout2d(0.5)`` is placed before the ReLU.
        track_stats: Forwarded to ``InstanceNorm2d`` as ``track_running_stats``.
    """

    def __init__(self, in_channels, out_channels, dropout=False, track_stats=False):
        super().__init__()

        stages = [
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False
            ),
            nn.InstanceNorm2d(out_channels, track_running_stats=track_stats),
        ]

        if dropout:
            stages.append(nn.Dropout2d(0.5))

        stages.append(nn.ReLU())

        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Apply the stages in order and return the resulting feature map."""
        out = x
        for stage in self.layers:
            out = stage(out)

        return out
    
    
class UNet(nn.Module):
    """Eight-level encoder/decoder generator with skip connections.

    The encoder stacks eight ``EncoderBlock``s, the decoder seven
    ``DecoderBlock``s; after every decoder block the matching encoder feature
    map is concatenated on the channel axis. A final stride-2 transposed
    convolution maps the 128 concatenated channels to ``out_channels``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Output widths of the eight encoder levels; only the first level
        # skips normalisation.
        widths = [64, 128, 256, 512, 512, 512, 512, 512]
        down = [EncoderBlock(in_channels, widths[0], batch_norm=False)]
        down += [EncoderBlock(c_in, c_out) for c_in, c_out in zip(widths, widths[1:])]
        self.encoder = nn.ModuleList(down)

        # (in, out, dropout) per decoder level. From the second level on the
        # input width is doubled by the concatenated skip connection.
        up_spec = [
            (512, 512, True),
            (1024, 512, True),
            (1024, 512, True),
            (1024, 512, False),
            (1024, 256, False),
            (512, 128, False),
            (256, 64, False),
        ]
        self.decoder = nn.ModuleList(
            [DecoderBlock(c_in, c_out, dropout=drop) for c_in, c_out, drop in up_spec]
        )

        self.last_conv = nn.ConvTranspose2d(
            in_channels=128,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1
        )

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the output layer."""
        features = []
        for block in self.encoder:
            x = block(x)
            features.append(x)

        # The deepest feature map is the decoder input itself; the remaining
        # ones are consumed deepest-first as skip connections.
        for block, skip in zip(self.decoder, reversed(features[:-1])):
            x = torch.cat((skip, block(x)), dim=1)

        return self.last_conv(x)
    
    
class PatchGAN(nn.Module):
    """Patch discriminator scoring an image pair with a grid of outputs.

    Four ``EncoderBlock``s (the last with stride 1) are followed by a
    one-channel 4x4 convolution; each output cell rates one patch of the input.

    Args:
        in_channels: Total channels of the two images after concatenation.
        sigmoid: When true, a ``Sigmoid`` is appended so the outputs are
            probabilities instead of raw scores.
    """

    def __init__(self, in_channels, sigmoid=False):
        super().__init__()

        stack = [
            EncoderBlock(in_channels, 64, batch_norm=False),
            EncoderBlock(64, 128),
            EncoderBlock(128, 256),
            EncoderBlock(256, 512, stride=1),
            nn.Conv2d(
                in_channels=512,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=1
            ),
        ]

        if sigmoid:
            stack.append(nn.Sigmoid())

        self.layers = nn.ModuleList(stack)

    def forward(self, a, b):
        """Concatenate ``a`` and ``b`` on the channel axis and score the pair."""
        out = torch.cat((a, b), dim=1)

        for layer in self.layers:
            out = layer(out)

        return out

In [0]:
# True when a CUDA device is available; later cells use it to pick the tensor
# type and to move the networks to the GPU.
cuda = torch.cuda.is_available()

# Dataset root and the speaker tags of the source (a) and target (b) voices.
# A tag is matched as a substring of the directory that contains the files.
data_path = '/content/drive/My Drive/voice2voice/vcc2018/'
voice_a = 'SF2'
voice_b = 'TM1'

# Number of (source, target) spectrogram pairs per training step.
batch_size = 4

# `epochs` counts training steps: the loop draws one batch per step.
epochs = 10000
# A checkpoint is written every `checkpoint_freq` steps to
# <checkpoint_path>/<checkpoint_name>_<start time>_<step>.tar
checkpoint_freq = 500
checkpoint_path = '/content/drive/My Drive/voice2voice/checkpoints/'
checkpoint_name = 'checkpoint'

# Weight of the pixel-wise (L1) term in the generator loss.
lambda_pixel = 0.001

# Directory that receives the reconstructed .wav files.
output_path = '/content/drive/My Drive/voice2voice/outputs/'

# When True the generator input gets a second channel with the DTW matrix.
dtw = False

In [0]:
# Discover the recordings of both speakers and draw a fresh 80/20 train/test
# split, so this cell works after "Restart & Run All" instead of relying on
# variables left over from an earlier kernel session.
filenames = get_vcc2018_filenames(data_path, voice_a, voice_b)

train_filenames, test_filenames = shuffle_split(filenames, 0.8)

# To resume with the split stored in an earlier checkpoint, replace the split
# above with:
# checkpoint = torch.load(
#     (Path(checkpoint_path) / Path(checkpoint_name)).as_posix() + '_1568414757_0003.tar'
# )
#
# train_filenames = checkpoint['train_filenames']
# test_filenames = checkpoint['test_filenames']

# Endless generator of random training batches.
sampler = audio_sampler(train_filenames, batch_size, dtw=dtw)

In [0]:
# Generator and discriminator. With DTW enabled the generator sees two input
# channels (spectrogram + DTW matrix); the discriminator always sees two
# one-channel images stacked together.
G_in_channels = 2 if dtw else 1

G = UNet(in_channels=G_in_channels, out_channels=1)
D = PatchGAN(in_channels=2)

# One Adam optimiser per network, identical hyper-parameters.
optim_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Adversarial loss works on raw discriminator scores; pixel loss is L1.
criterion_adversarial = nn.BCEWithLogitsLoss()
criterion_pixelwise = nn.L1Loss()

# Tensor constructor matching the device, plus device placement of the nets.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

if cuda:
    G.cuda()
    D.cuda()

In [0]:
# Resizes a batch of images to 32x32 with bicubic interpolation; the training
# loop applies it to spectrograms before computing the pixel-wise loss.
downscale = nn.Upsample(size=(32, 32), mode='bicubic').cuda()

In [0]:
import time
import sys


G.train()
D.train()

# Create the checkpoint directory up front so a missing folder is noticed now
# rather than after `checkpoint_freq` training steps.
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

# Adversarial ground truths. For 256x256 inputs the PatchGAN produces a 30x30
# score map; the labels never change, so they are built once.
real = torch.ones(batch_size, 1, 30, 30).type(Tensor)
fake = torch.zeros(batch_size, 1, 30, 30).type(Tensor)

start_time = time.time()

for e in range(1, epochs+1):
    a, b = next(sampler)

    # Model inputs
    real_a = Tensor(a)
    real_b = Tensor(b)

    # ---- Train generator ----
    optim_G.zero_grad()

    fake_b = G(real_a)

    if dtw:
        # Drop the DTW channel: from here on only the spectrogram is used.
        real_a = real_a[:, :1, ...]

    pred_fake = D(real_a, fake_b)

    loss_adv = criterion_adversarial(pred_fake, real)

    # Pixel term: L1 distance between the 32x32 versions of the generated
    # spectrogram and the *source* spectrogram.
    # NOTE(review): pix2pix compares against the target (real_b); the source is
    # kept here because it looks deliberate - confirm.
    small_fake_b = downscale(fake_b)
    small_real_a = downscale(real_a)
    loss_pix = criterion_pixelwise(small_fake_b, small_real_a)

    # Full-resolution alternative:
    # loss_pix = criterion_pixelwise(fake_b, real_a)

    loss_G = loss_adv + lambda_pixel * loss_pix

    loss_G.backward()
    optim_G.step()

    # ---- Train discriminator ----
    optim_D.zero_grad()

    # The discriminator is trained with a blank conditioning image.
    # zeros_like already lives on real_a's device, so no explicit .cuda() call
    # is needed (it would fail on CPU-only runs).
    # NOTE(review): the generator step above feeds the real source to D while
    # this step feeds zeros - confirm the mismatch is intended.
    real_a = torch.zeros_like(real_a)
    pred_real = D(real_a, real_b)
    loss_real = criterion_adversarial(pred_real, real)

    pred_fake = D(real_a, fake_b.detach())
    loss_fake = criterion_adversarial(pred_fake, fake)

    loss_D = 0.5 * (loss_real + loss_fake)

    loss_D.backward()
    optim_D.step()

    # ---- Checkpoint ----
    if e % checkpoint_freq == 0:
        torch.save(
            {
                'epoch': e,

                'G_state_dict': G.state_dict(),
                'D_state_dict': D.state_dict(),

                # Plain floats: keeps device tensors and autograd references
                # out of the checkpoint file.
                'G_loss': loss_G.item(),
                'G_loss_adv': loss_adv.item(),
                'G_loss_pix': loss_pix.item(),
                'D_loss': loss_D.item(),

                'G_optim_state_dict': optim_G.state_dict(),
                'D_optim_state_dict': optim_D.state_dict(),

                'train_filenames': train_filenames,
                'test_filenames': test_filenames
            },
            (Path(checkpoint_path) / Path(checkpoint_name)).as_posix() + f'_{int(start_time)}_{e:0{len(str(epochs))}}.tar'
        )

    # ---- Log ----
    # Remaining time extrapolated from the average duration of the steps so far.
    time_left = int((time.time() - start_time) * (epochs - e) / e)

    h = time_left // 3600
    m = (time_left % 3600) // 60
    s = (time_left % 3600) % 60

    eta = f'{h:02}h {m:02}m {s:02}s'

    sys.stdout.write(
        f'\r[Epoch {e:0{len(str(epochs))}}/{epochs}] [G loss: {loss_G:.3f}, pix: {loss_pix:.3f}, adv: {loss_adv:.3f}] [D loss: {loss_D:.3f}] ETA: {eta}'
    )

In [0]:
import matplotlib.pyplot as plt


# Restore the generator weights from a saved checkpoint. Loading onto the CPU
# first also works for checkpoints written on a GPU machine; load_state_dict
# then copies the weights to wherever G lives.
name = (Path(checkpoint_path) / Path(checkpoint_name)).as_posix() + '_1568420748_01500.tar'
checkpoint = torch.load(name, map_location='cpu')
G.load_state_dict(checkpoint['G_state_dict'])
G.eval()

# Draw one held-out pair. The sampler must use the same `dtw` setting as the
# generator so the number of input channels matches.
test_sampler = audio_sampler(test_filenames, 1, dtw=dtw)

real_a, real_b = next(test_sampler)

if dtw:
    # Blank the DTW channel before running the generator.
    real_a[:, 1] = -1

# Inference only: no autograd graph is needed.
with torch.no_grad():
    fake_b = G(Tensor(real_a)).cpu().numpy()

# Source, generated and target spectrograms side by side.
fig = plt.figure()

panels = [('source', real_a), ('generated', fake_b), ('target', real_b)]
for position, (title, batch) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 3, position)
    ax.set_title(title)
    ax.imshow(batch[0, 0])

plt.show()

In [0]:
# librosa.output.write_wav was removed in librosa 0.8 (and the first cell
# installs the newest librosa). It was a thin wrapper around
# scipy.io.wavfile.write, which is called directly here.
from scipy.io import wavfile

for i in range(1):
    # Invert the dB mel spectrograms of the source, the target and the
    # generated sample back to waveforms.
    source = reconstruct_signal(real_a[i][0])
    target = reconstruct_signal(real_b[i][0])
    fake = reconstruct_signal(fake_b[i][0])

    wavfile.write(output_path + f'source_{i}.wav', 44100, source)
    wavfile.write(output_path + f'target_{i}.wav', 44100, target)
    wavfile.write(output_path + f'fake_{i}.wav', 44100, fake)

In [0]:
import torchvision

# Preview what the 32x32 bicubic downscaling does to the target spectrogram,
# followed by the full-resolution source spectrogram for comparison.
im = Tensor(real_b)

shrink = nn.Upsample(size=(32, 32), mode='bicubic')
im2 = shrink(im)

for picture in (im2.cpu().data.numpy()[0, 0], real_a[0, 0]):
    plt.imshow(picture)
    plt.show()

In [0]:
# name = (Path(checkpoint_path) / Path(checkpoint_name)).as_posix() + '_1568364460_0250.tar'
# checkpoint = torch.load(name)
# G.load_state_dict(checkpoint['G_state_dict'])