# Libs & Utils

In [23]:
import torch
import numpy as np
import random
import os
from torch.utils.data import Dataset
from tqdm import tqdm
import librosa
from torch.utils.data import DataLoader
from IPython.display import Audio

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [24]:
def load_checkpoint(model, optimizer, filename, lr, device):
    """Restore ``model`` and ``optimizer`` from a checkpoint file.

    The file is expected to hold a dict with the keys ``"state_dict"``,
    ``"optimizer"`` and ``"current_epoch"`` (the layout written by
    ``save_checkpoint``).

    Args:
        model: module whose ``load_state_dict`` receives ``"state_dict"``.
        optimizer: optimizer whose ``load_state_dict`` receives ``"optimizer"``.
        filename: path of the checkpoint file.
        lr: learning rate written into every param group after loading,
            replacing whatever value the checkpoint carried.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The ``"current_epoch"`` value stored in the checkpoint.
    """
    state = torch.load(filename, map_location=device)

    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])
    resume_epoch = state["current_epoch"]

    # The optimizer state brings its own learning rate along; force the
    # caller's value instead.
    for group in optimizer.param_groups:
        group["lr"] = lr

    return resume_epoch

def save_checkpoint(model, optimizer, config, filename="my_checkpoint.pth"):
    """Write model and optimizer state plus the current epoch to ``filename``.

    Args:
        model: module whose ``state_dict()`` is stored under ``"state_dict"``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``"optimizer"``.
        config: object exposing ``current_epoch``; that value is stored under
            ``"current_epoch"``.
        filename: destination path handed to ``torch.save``.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
            current_epoch=config.current_epoch,
        ),
        filename,
    )

# Dataset

In [25]:
class Voice_Dataset(Dataset):
    """Paired dataset of pre-computed spectrograms and speaker embeddings.

    Item ``i`` is loaded from ``spect{i}.pth`` / ``embed{i}.pth`` located in
    ``<source_voice_path>/speaker1/{spects,embeddings}`` and
    ``<target_voice_path>/speaker2/{spects,embeddings}``. The dataset length is
    the smaller of the two ``spects`` directory listings, so source and target
    items are always paired by the same index.

    NOTE(review): this assumes the files are numbered contiguously from 0 and
    that the ``spects`` folders contain nothing else - confirm against the
    preprocessing step.
    """

    def __init__(self, source_voice_path, target_voice_path):
        self.source_voice_spects = os.path.join(source_voice_path, "speaker1", "spects")
        self.source_voice_embeds = os.path.join(source_voice_path, "speaker1", "embeddings")
        self.target_voice_spects = os.path.join(target_voice_path, "speaker2", "spects")
        self.target_voice_embeds = os.path.join(target_voice_path, "speaker2", "embeddings")
        # Number of entries in each spectrogram folder.
        self.source_voices_len = len(os.listdir(self.source_voice_spects))
        self.target_voices_len = len(os.listdir(self.target_voice_spects))
        # Only indices available for both speakers are served.
        self.dataset_length = min(self.source_voices_len, self.target_voices_len)
        # Shuffled index list; kept as an attribute but not consulted by
        # __getitem__, which uses the incoming index directly.
        self.idxs = [*range(self.dataset_length)]
        random.shuffle(self.idxs)

    def __len__(self):
        """Return the number of (source, target) pairs."""
        return self.dataset_length

    def __getitem__(self, index):
        """Load the ``index``-th source/target pair from disk.

        Returns:
            ``(src_spect, src_embed, trg_spect, trg_embed)`` exactly as stored
            in the four ``.pth`` files.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``. Raising
                IndexError (rather than letting ``torch.load`` fail with
                FileNotFoundError) is what lets plain iteration over the
                dataset terminate.
        """
        if not 0 <= index < self.dataset_length:
            raise IndexError(
                f"index {index} is out of range for a dataset of length {self.dataset_length}"
            )
        src_voice_spect = torch.load(
            os.path.join(self.source_voice_spects, f"spect{index}.pth")
        )
        src_voice_embed = torch.load(
            os.path.join(self.source_voice_embeds, f"embed{index}.pth")
        )
        trg_voice_spect = torch.load(
            os.path.join(self.target_voice_spects, f"spect{index}.pth")
        )
        trg_voice_embed = torch.load(
            os.path.join(self.target_voice_embeds, f"embed{index}.pth")
        )
        return src_voice_spect, src_voice_embed, trg_voice_spect, trg_voice_embed

In [26]:
# Root folders handed to Voice_Dataset, which appends "speaker1/..." to the
# source path and "speaker2/..." to the target path - hence both variables can
# point at the same directory.
src_voice_path = "/content/gdrive/MyDrive/processed_data/"
target_voice_path = "/content/gdrive/MyDrive/processed_data/"

In [27]:
# Smoke test: build the dataset, pull a single un-shuffled batch of size 1 and
# print the shapes of its four tensors.
dataset = Voice_Dataset(src_voice_path, target_voice_path)
train_dataloader = DataLoader(dataset, batch_size=1)
train_dataloader_iter = iter(train_dataloader)
src_spect, src_embed, trg_spect, trg_embed = next(train_dataloader_iter)
print(*(t.shape for t in (src_spect, src_embed, trg_spect, trg_embed)))

torch.Size([1, 80, 128]) torch.Size([1, 256]) torch.Size([1, 80, 128]) torch.Size([1, 256])


# Generator

In [28]:
import torch
from torch import nn
from fastai.layers import init_linear

class CIN(nn.Module):
    """Conditional instance normalisation driven by an embedding vector.

    The input is standardised along ``dim=2`` and then scaled and shifted by
    ``gamma`` and ``beta``, each produced by its own linear layer applied to
    the embedding.

    Args:
        dim_out: number of channels of the tensor being normalised (output
            size of both linear layers).
        embed_dim: size of the conditioning embedding.
    """

    def __init__(self, dim_out, embed_dim):
        super().__init__()

        self.gamma = nn.Linear(embed_dim, dim_out)
        init_linear(self.gamma)
        self.beta = nn.Linear(embed_dim, dim_out)
        init_linear(self.beta)

    def forward(self, x, embed):
        """Normalise ``x`` over its last axis and modulate it with ``embed``.

        Args:
            x: tensor indexed as (batch, channels, length); statistics are
                taken over ``dim=2``.
            embed: tensor of shape (batch, embed_dim).

        Returns:
            ``gamma * (x - mean) / std + beta`` with the same shape as ``x``.
        """
        sigma, mu = torch.std_mean(x, dim=2, keepdim=True)
        # Guard against division by zero on constant rows.
        sigma = torch.clamp(sigma, min=1e-7)
        # Trailing axis added so (batch, channels) broadcasts over length.
        gamma = self.gamma(embed)[..., None]
        # Fixed: the shift previously reused ``self.gamma``, which left the
        # ``beta`` layer unused and tied the shift to the scale.
        beta = self.beta(embed)[..., None]
        return gamma * (x - mu) / sigma + beta


class ConditioningBlock(nn.Module):
    """1-D convolution followed by conditional instance norm and a GLU.

    The GLU splits the channel axis (``dim=1``) in two, so the block outputs
    ``dim_out // 2`` channels.

    Args:
        dim_in: input channels of the convolution.
        dim_out: output channels of the convolution (before the GLU halves them).
        kernel_size, stride, padding: forwarded to ``nn.Conv1d``.
        embed_dim: size of the conditioning embedding consumed by ``CIN``.
    """

    def __init__(self, dim_in, dim_out, kernel_size, stride, padding, embed_dim):
        super().__init__()

        self.conv = nn.Conv1d(
            dim_in,
            dim_out,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=True,
        )
        self.cin = CIN(dim_out, embed_dim)
        self.glu = nn.GLU(dim=1)

    def forward(self, x, embed):
        """Apply conv -> CIN(embed) -> GLU to ``x``."""
        modulated = self.cin(self.conv(x), embed)
        return self.glu(modulated)

In [29]:
class Generator(nn.Module):
    """Spectrogram generator conditioned on a (source, target) embedding pair.

    Pipeline: 2-D stem conv + GLU -> two stride-2 down-sampling stages ->
    reshape to 1-D -> nine ``ConditioningBlock``s -> reshape back to 2-D ->
    two PixelShuffle up-sampling stages -> output conv.

    The reshapes are hard-coded to 5120 = 256 * 20 channels, i.e. they expect
    an input with 80 rows (80 / 4 = 20) and a width that survives the two
    stride-2 stages as ``width // 4``.

    Args:
        embed_dim: size of one speaker embedding. The concatenated pair is
            projected to 256 features and that projection is what the
            conditioning blocks receive, while their ``CIN`` layers are built
            with ``embed_dim`` inputs - so this only lines up for
            ``embed_dim == 256``.
        device: target passed to ``.to(...)`` for the inputs in ``forward``;
            ``None`` is stored as-is.
    """

    # cond_1 ... cond_9; attribute names are kept so state-dict keys stay stable.
    _NUM_COND_BLOCKS = 9

    def __init__(self, embed_dim, device=None):
        super().__init__()
        self.device = device

        # Speaker-pair embedding: concat(src, trg) -> 256 features.
        self.embed_dim = embed_dim
        pair_projection = nn.Linear(self.embed_dim * 2, 256)
        self.embed_map = nn.Sequential(pair_projection, nn.SELU())
        init_linear(pair_projection)

        # Stem: 1 -> 128 channels, halved to 64 by the GLU.
        self.init_layer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(5, 15), stride=(1, 1), padding=(2, 7)),
            nn.GLU(dim=1),
        )

        # Two stride-2 stages: 64 -> 128 -> 256 channels (after each GLU).
        widths = [64, 128, 256]
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out * 2, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=True),
                nn.InstanceNorm2d(num_features=c_out * 2, affine=True),
                nn.GLU(dim=1),
            ]
        self.down_sample = nn.Sequential(*stages)

        # 2-D -> 1-D: the (256, 20) channel/row axes are merged into 5120 channels.
        self.down_converse = nn.Sequential(
            nn.Conv1d(in_channels=5120, out_channels=256, kernel_size=1, stride=1, padding=0, bias=True),
            nn.InstanceNorm1d(num_features=256, affine=True),
        )

        # Bottleneck: identical conditioning blocks registered as cond_1..cond_9.
        for k in range(1, self._NUM_COND_BLOCKS + 1):
            setattr(
                self,
                f"cond_{k}",
                ConditioningBlock(dim_in=256, dim_out=512, kernel_size=5, stride=1, padding=2, embed_dim=self.embed_dim),
            )

        # 1-D -> 2-D: back to 5120 channels, later viewed as (256, 20).
        self.up_converse = nn.Conv1d(in_channels=256, out_channels=5120, kernel_size=1, stride=1, padding=0, bias=True)

        # Each stage: conv, PixelShuffle(2) (channels / 4, spatial x 2), GLU (channels / 2).
        self.up_sample = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=5, stride=1, padding=2, bias=True),
            nn.PixelShuffle(2),
            nn.GLU(dim=1),
            nn.Conv2d(in_channels=128, out_channels=512, kernel_size=5, stride=1, padding=2, bias=True),
            nn.PixelShuffle(2),
            nn.GLU(dim=1),
        )

        self.out_layer = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=(5, 15), stride=(1, 1), padding=(2, 7), bias=True)

    def forward(self, x, src, trg):
        """Convert spectrogram ``x`` from speaker ``src`` to speaker ``trg``.

        Args:
            x: (bs, 80, width) spectrogram batch.
            src: (bs, embed_dim) source-speaker embedding (concatenated with
                ``trg`` along ``dim=1``).
            trg: (bs, embed_dim) target-speaker embedding.

        Returns:
            Tensor viewed as (bs, 80, width).
        """
        x, src, trg = (t.to(self.device) for t in (x, src, trg))

        bs, _, width = x.shape
        cond = self.embed_map(torch.cat([src, trg], dim=1))

        # Stem and down-sampling on a single-channel 2-D map.
        feats = self.init_layer(x.unsqueeze(1))
        feats = self.down_sample(feats)

        # Collapse channels and rows for the 1-D bottleneck.
        feats = feats.contiguous().view(bs, 5120, width // 4)
        feats = self.down_converse(feats)

        for k in range(1, self._NUM_COND_BLOCKS + 1):
            feats = getattr(self, f"cond_{k}")(feats, cond)

        # Restore the 2-D layout and up-sample back to the input resolution.
        feats = self.up_converse(feats)
        feats = feats.view(bs, 256, 20, width // 4)
        feats = self.up_sample(feats)
        feats = self.out_layer(feats)

        return feats.view(bs, 80, width)

In [30]:
# Smoke test: run an untrained generator on a random (1, 80, 128) input, using
# the same random embedding as both source and target. The last expression
# displays the output shape, which forward() views as (bs, 80, width).
genr = Generator(embed_dim = 256, device=None)
random_tensor = torch.randn(1, 80, 128)
embed = torch.randn(1, 256)
result = genr(random_tensor, embed, embed)
result.shape

torch.Size([1, 80, 128])

# Discriminator

In [31]:
class Discriminator(nn.Module):
    """Spectrogram discriminator conditioned on a (source, target) embedding pair.

    The score for each sample is ``linear(flattened features)`` plus the inner
    product between a 512-d embedding of the speaker pair and the features
    summed over both spatial axes.

    ``self.linear`` is sized for ``512 * 10 * 16`` features, which is what the
    three stride-2 stages produce for an 80 x 128 input; other input sizes are
    rejected in ``forward``.

    Args:
        embed_dim: size of one speaker embedding. Each embedding is mapped to
            256 features and the concatenated pair is fed to a linear layer
            with ``embed_dim * 2`` inputs, so this only lines up for
            ``embed_dim == 256``.
        device: target passed to ``.to(...)`` for the inputs in ``forward``;
            ``None`` is stored as-is.
    """

    def __init__(self, embed_dim, device=None):
        super().__init__()
        self.device = device

        self.embed_dim = embed_dim
        self.input_dropout = nn.Dropout(p=0.3)

        # Stem: 1 -> 128 channels, halved to 64 by the GLU.
        self.init_layer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.GLU(dim=1)
        )

        # Three stride-2 stages (64 -> 128 -> 256 -> 512 channels after GLU),
        # followed by a stride-1 (1, 5) conv stage that keeps 512 channels.
        dims = [64, 128, 256, 512]
        block = []
        for i in range(1, len(dims)):
            cur, nxt = dims[i-1], dims[i]
            block.append(nn.Conv2d(in_channels=cur, out_channels=nxt*2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=True))
            block.append(nn.InstanceNorm2d(num_features=nxt*2, affine=True))
            block.append(nn.GLU(dim=1))
        block.append(nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(1, 5), stride=(1, 1), padding=(0, 2), bias=True))
        block.append(nn.InstanceNorm2d(num_features=1024, affine=True))
        block.append(nn.GLU(dim=1))
        self.down_sample = nn.Sequential(*block)

        self.linear = nn.Linear(in_features=512*10*16, out_features=1)

        self.embed_map_src = nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.SELU()
        )
        self.embed_map_trg = nn.Sequential(
            nn.Linear(self.embed_dim, 256),
            nn.SELU()
        )
        self.embed = nn.Linear(self.embed_dim*2, 512)
        init_linear(self.embed_map_src[0])
        init_linear(self.embed_map_trg[0])
        init_linear(self.embed)

    def forward(self, x, src, trg, dropout=False):
        """Score a batch of spectrograms.

        Args:
            x: (bs, 80, 128) spectrogram batch (any 3-D input whose features
                flatten to ``self.linear.in_features`` per sample is accepted).
            src: (bs, embed_dim) source-speaker embedding.
            trg: (bs, embed_dim) target-speaker embedding.
            dropout: when true, ``self.input_dropout`` is applied to ``x`` first.

        Returns:
            Tensor of shape (bs,) with one score per sample.

        Raises:
            ValueError: if the down-sampled features of one sample do not match
                the input size of ``self.linear``.
        """
        x = x.to(self.device)
        src = src.to(self.device)
        trg = trg.to(self.device)

        # Unpacking three values also enforces a 3-D input.
        bs, _, _ = x.shape

        src = self.embed_map_src(src)
        trg = self.embed_map_trg(trg)
        src_trg = torch.cat([src, trg], dim=1)
        embed = self.embed(src_trg)

        # input drop out
        if dropout:
            x = self.input_dropout(x)

        # init layer
        x = x.unsqueeze(1)
        x = self.init_layer(x)

        # down sampling layer
        x = self.down_sample(x)

        # global sum pooling
        h = torch.sum(x, dim=(-1, -2))

        # Flatten per sample. The previous ``view(-1, 512*10*16)`` could fold
        # several smaller samples into one row and still broadcast back to
        # ``bs`` scores, silently mixing samples; fail loudly instead.
        flat = torch.flatten(x, start_dim=1)
        if flat.shape[1] != self.linear.in_features:
            raise ValueError(
                f"Discriminator expects {self.linear.in_features} features per sample "
                f"after down-sampling (an 80 x 128 input), got {flat.shape[1]}"
            )
        x = self.linear(flat)

        # (bs, 1, C) @ (bs, C, 1) -> per-sample inner product of embed and h.
        y = x + (embed[:, None]@h[..., None]).squeeze(-1)
        return y.view(bs)

In [32]:
# Smoke test: score a random (1, 80, 128) input with an untrained
# discriminator, using the same random embedding as both source and target.
# The last expression displays the resulting one-element score tensor.
discr = Discriminator(embed_dim = 256, device=None)
random_tensor = torch.randn(1, 80, 128)
embed = torch.randn(1, 256)
result = discr(random_tensor, embed, embed)
result

tensor([-10.6861], grad_fn=<ViewBackward0>)

# Training

In [33]:
# Data roots passed to train(); same values as the paths assigned in the
# Dataset section of this notebook.
src_voice_path = "/content/gdrive/MyDrive/processed_data/"
target_voice_path = "/content/gdrive/MyDrive/processed_data/"

In [34]:
class Config:
    """Plain container for the settings read by ``train``.

    Attributes set in ``__init__``:
        resume_path_gen / resume_path_dis: checkpoint files read when
            ``load_checkpoint`` is true.
        save_model_path: directory that receives ``generator{epoch}.pth`` and
            ``discriminator{epoch}.pth``.
        optimizers: Adam learning rates (``gen_lr``, ``dis_lr``) and betas.
        hparam: ``a`` / ``b`` are the targets in the squared-error adversarial
            losses; ``lambda_id`` and ``lambda_cyc`` weight the identity and
            cycle losses.
        num_epochs: upper bound (exclusive) of the epoch loop.
        epoch_save: a checkpoint is written when ``epoch % epoch_save == 0``
            and ``epoch != 0``.
        current_epoch: first epoch of the loop; overwritten from the generator
            checkpoint when resuming.
        gen_freq: the generator is updated on batches where
            ``idx % gen_freq == 0``.
        batch_size, num_workers: forwarded to the ``DataLoader``.
        load_checkpoint: whether ``train`` restores both networks first.
    """

    def __init__(self):
        # Checkpoint locations.
        self.resume_path_gen = "/content/gdrive/MyDrive/test/generator2.pth"
        self.resume_path_dis = "/content/gdrive/MyDrive/test/discriminator2.pth"
        self.save_model_path = "/content/gdrive/MyDrive/test"

        # Optimizer settings.
        self.optimizers = dict(gen_lr=1e-4, dis_lr=5e-5, beta1=0.9, beta2=0.999)

        # Loss targets and weights.
        self.hparam = dict(a=1, b=0, lambda_id=5, lambda_cyc=10)

        # Schedule.
        self.num_epochs = 1001
        self.epoch_save = 2
        self.current_epoch = 0
        self.gen_freq = 5

        # Data loading.
        self.batch_size = 1
        self.num_workers = 0

        # Resume switch.
        self.load_checkpoint = False

In [35]:
def train(config, source_voice_path, target_voice_path):
    """Adversarially train a ``Generator`` / ``Discriminator`` pair.

    Every batch updates the discriminator with a squared-error loss (real
    source spectrogram -> ``hparam['a']``, converted spectrogram ->
    ``hparam['b']``). Every ``config.gen_freq``-th batch additionally updates
    the generator with identity (MSE), cycle (L1) and adversarial losses.
    Checkpoints for both networks are written to ``config.save_model_path``
    whenever ``epoch % config.epoch_save == 0`` and ``epoch != 0``.

    Args:
        config: object with the attributes of ``Config``.
        source_voice_path: root folder of the source-speaker data.
        target_voice_path: root folder of the target-speaker data.

    Side effects:
        ``config.current_epoch`` is overwritten when resuming from a
        checkpoint and is advanced to the next epoch right before each
        checkpoint is written, so that a resumed run continues after the
        epoch that produced the checkpoint.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = Voice_Dataset(source_voice_path=source_voice_path,
                            target_voice_path=target_voice_path)
    train_loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=False
    )
    gen = Generator(embed_dim=256).to(device)
    dis = Discriminator(embed_dim=256).to(device)

    gen_lr = config.optimizers['gen_lr']
    dis_lr = config.optimizers['dis_lr']
    beta1 = config.optimizers['beta1']
    beta2 = config.optimizers['beta2']

    gen_opt = torch.optim.Adam(gen.parameters(), gen_lr, [beta1, beta2])
    dis_opt = torch.optim.Adam(dis.parameters(), dis_lr, [beta1, beta2])

    hparam = config.hparam
    l1_loss = nn.L1Loss()
    l2_loss = nn.MSELoss()

    if config.load_checkpoint:
        print("downloaded weights")
        # The starting epoch is taken from the generator checkpoint; the
        # epoch stored in the discriminator checkpoint is ignored.
        config.current_epoch = load_checkpoint(gen, gen_opt, config.resume_path_gen, gen_lr, device)
        load_checkpoint(dis, dis_opt, config.resume_path_dis, dis_lr, device)
    print(f"starting from epoch {config.current_epoch}")

    for epoch in range(config.current_epoch, config.num_epochs):
        loop = tqdm(train_loader, leave=True)
        # NOTE(review): the target spectrogram is loaded but not used by any
        # loss below - the discriminator only ever sees real *source*
        # spectrograms. Confirm this is intended.
        for idx, (src_spect, src_embed, _trg_spect, trg_embed) in enumerate(loop):
            src_spect = src_spect.to(device)
            src_embed = src_embed.to(device)
            trg_embed = trg_embed.to(device)

            # ---- discriminator step ----
            x_src_trg = gen(src_spect, src_embed, trg_embed)
            d_src = dis(src_spect, src_embed, trg_embed)
            # Detach the converted sample: the discriminator update must not
            # back-propagate into the generator, and without that shared graph
            # no retain_graph=True is needed.
            d_src_trg = dis(x_src_trg.detach(), trg_embed, src_embed)
            dis_loss = torch.mean((d_src_trg - hparam['b']) ** 2 + (d_src - hparam['a']) ** 2)
            dis_opt.zero_grad()
            dis_loss.backward()
            dis_opt.step()

            # ---- generator step (every gen_freq batches) ----
            if idx % config.gen_freq == 0:
                # Identity and cycle passes are only needed here, so they are
                # not computed on discriminator-only batches.
                x_src_src = gen(src_spect, src_embed, src_embed)
                x_src_trg_src = gen(x_src_trg, trg_embed, src_embed)
                id_loss = l2_loss(src_spect, x_src_src)
                cyc_loss = l1_loss(src_spect, x_src_trg_src)
                # Re-score the converted sample with the freshly updated discriminator.
                d_src_trg_2 = dis(x_src_trg, trg_embed, src_embed)
                adv_loss = torch.mean((d_src_trg_2 - hparam['a']) ** 2)
                gen_loss = hparam['lambda_id'] * id_loss + hparam['lambda_cyc'] * cyc_loss + adv_loss
                gen_opt.zero_grad()
                gen_loss.backward()
                gen_opt.step()
                loop.set_postfix({
                    'dis loss': dis_loss.item(),
                    'gen loss': gen_loss.item(),
                    'adv loss': adv_loss.item(),
                })

        if (epoch % config.epoch_save == 0) and (epoch != 0):
            # save_checkpoint stores config.current_epoch; point it at the
            # next epoch to run so a resume does not repeat finished epochs
            # (it previously always stored the epoch the run started from).
            config.current_epoch = epoch + 1
            os.makedirs(config.save_model_path, exist_ok=True)
            save_checkpoint(gen, gen_opt, config, os.path.join(config.save_model_path, f"generator{epoch}.pth"))
            save_checkpoint(dis, dis_opt, config, os.path.join(config.save_model_path, f"discriminator{epoch}.pth"))

In [36]:
# Build a configuration with the default values and start training on the
# paths defined above; the epoch loop runs up to config.num_epochs.
config = Config()
train(config, src_voice_path, target_voice_path)

downloaded weights
starting from epoch 100


  0%|          | 4/1132 [00:08<42:03,  2.24s/it, dis loss=325, gen loss=7.68e+3, adv loss=7.46e+3]


KeyboardInterrupt: ignored