In [1]:
# imports
import numpy as np
import soundfile as sf
from functools import partial

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence, pack_padded_sequence
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import GaussianBlur

from dataset import parse_manifest, get_yaapt_f0

In [2]:
# constants
data_path = '.'  # 'datasets/VCTK/hubert100'
f0_path = 'datasets/VCTK/f0_stats.th'
device = 'cpu'

In [3]:
# def get_spkrs_dict(path):
#     names, _ = parse_manifest(path)
#     speakers = [n.name.split('_')[0] for n in names]
#     spk_id_dict = {n:i for i, n in enumerate(np.unique(speakers))}
#     return spk_id_dict

In [4]:
def get_spkrs_dict(path):
    speakers = []
    with open(path, 'r') as f:
        for line in f.readlines():
            val_dict = eval(line)
            speakers.append(val_dict['audio'].split('_')[0])
    return {n:i for i, n in enumerate(np.unique(speakers))}

In [5]:
# def mean_pool_1d(arr, pool_size):
#     arr = arr.reshape(-1, pool_size)
#     return arr.mean(axis=1).reshape(-1)

In [6]:
# from multiprocessing import Pool
# import tqdm

# def parse_audio(p, spk_id_dict, f0_param_dict):
#     aud, _ = sf.read(p)
#     spk_id = spk_id_dict[p.name.split('_')[0]]
#     f0 = get_yaapt_f0(aud.reshape((1, 1, -1)))
#     return torch.from_numpy((f0 - f0_param_dict[spk_id]['f0_mean'])/f0_param_dict[spk_id]['f0_std']).view(-1), torch.IntTensor([spk_id])
    

# def prepare_dataset(path, spk_id_dict, f0_param_dict):
#     paths, seqs = parse_manifest(path)
#     seqs = [torch.from_numpy(s) for s in seqs]
#     seqs = seqs[:1]
#     paths = paths[:1]
#     with Pool() as p:
#         fs, spk_ids = zip(*list(tqdm.tqdm(p.imap(partial(parse_audio, spk_id_dict=spk_id_dict, f0_param_dict=f0_param_dict), paths), total=len(paths))))
#     fs = list(fs)
#     for i in range(len(seqs)):
#         fs[i] = mean_pool_1d(fs[i][:len(seqs[i])*4], 4)
#     return pad_sequence(seqs, batch_first=True, padding_value=100), pad_sequence(fs, batch_first=True, padding_value=100), torch.concat(spk_ids).view(-1, 1)

In [7]:
def prepare_dataset(path, spk_id_dict, f0_param_dict):
    fs, seqs, spk_ids = [], [], []
    with open(path, 'r') as f:
        for line in f.readlines():
            val_dict = eval(line)
            seqs.append(torch.IntTensor(val_dict['units']))
            fs.append(torch.FloatTensor(val_dict['f0']))
            name = val_dict['audio'].split('_')[0]
            spk_ids.append(torch.IntTensor([spk_id_dict[name]]))
    return pad_sequence(seqs, batch_first=True, padding_value=100), pad_sequence(fs, batch_first=True, padding_value=100), torch.concat(spk_ids).view(-1, 1)

In [14]:
EPS = 0.001

def get_scaling(fs, nbins=50, f_min=None, scale=None):
    if f_min is None:
        f_min = fs.min()
    if scale is None:
        scale = (fs.max() + EPS - f_min)/nbins
    return f_min, scale

def quantise_f0(fs, nbins=50, f_min=None, scale=None):
    if f_min is None:
        f_min = fs.min()
    if scale is None:
        scale = (fs.max() + EPS - f_min)/nbins
    q_fs = torch.clip(torch.div(fs - f_min, scale, rounding_mode='floor').type(torch.LongTensor), min=0, max=nbins-1)
    return nn.functional.one_hot(q_fs, num_classes=nbins), f_min, scale

def prepare_f0(fs, nbins=50, f_min=None, scale=None):
    res = torch.zeros((fs.shape[0], fs.shape[1], nbins)).long()
    res[fs!=100], fmin, scale = quantise_f0(fs[fs!=100], nbins, f_min, scale)
    
    filt = GaussianBlur(kernel_size=(5, 1), sigma=0.5)
    res = filt(res.float())
    res[fs==100] = -1
    return res, fmin, scale

In [15]:
class PitchDataset(Dataset):
    def __init__(self, path, spk_id_dict, f0_param_dict, nbins=50, f_min=None, scale=None):
        self.vals, self.fs, self.spk_ids = prepare_dataset(path, spk_id_dict, f0_param_dict)
#         self.fs1, self.f_min, self.scale = prepare_f0(self.fs, nbins, f_min, scale)
        self.f_min, self.scale = get_scaling(self.fs, nbins, f_min, scale)
        self.nbins = nbins
        
    def __len__(self):
        return len(self.vals)
    
    def __getitem__(self, i):
#         return self.vals[i], self.fs1[i], self.fs[i], self.spk_ids[i]
        return self.vals[i], self.fs[i], self.spk_ids[i]

In [10]:
class PitchPredictor(nn.Module):
    def __init__(self, token_dict_size=100, spk_dict_size=199, emb_size=32, nbins=50):
        super(PitchPredictor, self).__init__()
        self.nbins = nbins
        self.token_emb = nn.Embedding(token_dict_size + 1, emb_size, padding_idx=token_dict_size)
        self.spk_emb = nn.Embedding(spk_dict_size + 1, emb_size, padding_idx=spk_dict_size)
        self.leaky = nn.LeakyReLU()
        self.dropout = nn.Dropout(p=0.0)
        
        self.cnn1 = nn.Conv1d(2*emb_size, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(128)
        self.cnn11 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm1d(128)
        self.cnn12 = nn.Conv1d(128, 128, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm1d(128)
        
        self.cnn2 = nn.Conv1d(128, nbins, kernel_size=3, padding=1)
        
    def forward(self, seq, spk_id):
        emb_seq = self.token_emb(seq)
        emb_spk = torch.repeat_interleave(self.spk_emb(spk_id), seq.shape[-1], dim=1)
        emb_seq = torch.cat([emb_seq, emb_spk], dim=-1)
        
        cnn1 = self.leaky(self.dropout(self.bn1(self.cnn1(emb_seq.transpose(1, 2)))))
        cnn1 = self.leaky(self.dropout(self.bn11(self.cnn11(cnn1))))
        cnn1 = self.leaky(self.dropout(self.bn12(self.cnn12(cnn1))))
        return self.cnn2(cnn1).squeeze(1)
    
    def infer_norm_freq(self, seq, spk_id, fmin, scale):
        preds = torch.sigmoid(self(seq, spk_id).transpose(1, 2))  # calculate class probs
        f_weights = torch.linspace(fmin + 0.5 * scale, fmin + (self.nbins - 0.5) * scale, self.nbins)
        return torch.inner(preds, f_weights)
    
    def calc_norm_freq(self, preds, fmin, scale):
        preds = torch.sigmoid(preds.transpose(1, 2))  # calculate class probs
        f_weights = torch.linspace(fmin + 0.5 * scale, fmin + (self.nbins - 0.5) * scale, self.nbins)
        return torch.inner(preds, f_weights)

In [11]:
# class PitchLoss(nn.Module):
#     def __init__(self, pad_idx=-1):
#         super(PitchLoss, self).__init__()
#         self.pad_idx = pad_idx
#         self.bce = nn.BCEWithLogitsLoss(reduction='none')
    
#     def forward(self, preds, gt):
#         mask = (gt != self.pad_idx)
#         total_loss = self.bce(preds, gt)        
#         return (mask * total_loss).sum() / mask.sum()

class PitchLoss(nn.Module):
    def __init__(self, f_min, scale, nbins, pad_idx=-1):
        super(PitchLoss, self).__init__()
        self.pad_idx = pad_idx
        self.f_min = f_min
        self.scale = scale
        self.nbins = nbins
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
    
    def forward(self, preds, gt):
        gt, _, _ = prepare_f0(gt, self.nbins, self.f_min, self.scale)
        mask = (gt != self.pad_idx)
        total_loss = self.bce(preds, gt)        
        return (mask * total_loss).sum() / mask.sum()
    
class PitchRegLoss(nn.Module):
    def __init__(self, pad_idx=100):
        super(PitchRegLoss, self).__init__()
        self.pad_idx = pad_idx
        self.mse = nn.MSELoss(reduction='none')
    
    def forward(self, preds, gts):
        mask = (gts != self.pad_idx)
        total_loss = self.mse(preds, gts)        
        return (mask * total_loss).sum() / mask.sum()

In [12]:
f0_param_dict = torch.load(f0_path)
spk_id_dict = get_spkrs_dict(f'{data_path}/train.txt')

In [17]:
ds_train = PitchDataset(f'{data_path}/train.txt', spk_id_dict, f0_param_dict, nbins=50)
dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)

ds_val = PitchDataset(f'{data_path}/val.txt', spk_id_dict, f0_param_dict, nbins=ds_train.nbins, f_min=ds_train.f_min, scale=ds_train.scale)
dl_val = DataLoader(ds_val, batch_size=32, shuffle=True)

In [18]:
model = PitchPredictor(nbins=50)
model.to(device)

PitchPredictor(
  (token_emb): Embedding(101, 32, padding_idx=100)
  (spk_emb): Embedding(200, 32, padding_idx=199)
  (leaky): LeakyReLU(negative_slope=0.01)
  (dropout): Dropout(p=0.0, inplace=False)
  (cnn1): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (cnn11): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn11): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (cnn12): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn12): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (cnn2): Conv1d(128, 50, kernel_size=(3,), stride=(1,), padding=(1,))
)

In [19]:
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
pitch_loss = PitchLoss(ds_train.f_min, ds_train.scale, ds_train.nbins)
reg_loss = PitchRegLoss()

In [36]:
# dl_train = dl_val
# ds_train = ds_val


Epoch: 0
 finished: 99.90%, train loss: 0.06771
total_train_loss: 0.05443, train MSE: 194573.07812
total_val_loss: 0.04092, val MSE: 1772.62000

Epoch: 1
 finished: 99.90%, train loss: 0.05906
total_train_loss: 0.04041, train MSE: 1682.88159
total_val_loss: 0.03941, val MSE: 1387.76257

Epoch: 2
 finished: 99.90%, train loss: 0.05162
total_train_loss: 0.03949, train MSE: 1343.99573
total_val_loss: 0.03900, val MSE: 1339.63123

Epoch: 3
 finished: 99.90%, train loss: 0.04731
total_train_loss: 0.03907, train MSE: 1224.34924
total_val_loss: 0.03881, val MSE: 1215.50439

Epoch: 4
 finished: 99.90%, train loss: 0.04696
total_train_loss: 0.03882, train MSE: 1174.66016
total_val_loss: 0.03863, val MSE: 1222.37939

Epoch: 5
 finished: 99.90%, train loss: 0.07348
total_train_loss: 0.03869, train MSE: 1156.67810
total_val_loss: 0.03851, val MSE: 1265.03882

Epoch: 6
 finished: 99.90%, train loss: 0.03669
total_train_loss: 0.03853, train MSE: 1139.99341
total_val_loss: 0.03839, val MSE: 1186.385

In [44]:
for epoch in range(100):
    print(f'\nEpoch: {epoch}')

    model.train()
    total_train_loss = 0
    total_train_mse = 0
    for i, batch in enumerate(dl_train):
        seqs, gts_reg, spk_id = batch
        seqs = seqs.to(device)
        gts_reg = gts_reg.to(device)
        spk_id = spk_id.to(device)
        opt.zero_grad()

        preds = model(seqs, spk_id)
        loss = pitch_loss(preds.transpose(1,2), gts_reg)
        loss.backward()
        opt.step()
        total_train_loss += loss
        total_train_mse += reg_loss(model.calc_norm_freq(preds, ds_train.f_min, ds_train.scale), gts_reg)

        print(f'\r finished: {100*i/len(dl_train):.2f}%, train loss: {loss:.5f}', end='')

    # validation
    model.eval()
    total_val_loss = 0
    total_val_mse = 0
    for i, batch in enumerate(dl_val):
        seqs, gts_reg, spk_id = batch
        seqs = seqs.to(device)
        gts_reg = gts_reg.to(device)
        spk_id = spk_id.to(device)
        with torch.no_grad():
            preds = model(seqs, spk_id)
            loss = pitch_loss(preds.transpose(1,2), gts_reg)
        total_val_loss += loss
        total_val_mse += reg_loss(model.calc_norm_freq(preds, ds_val.f_min, ds_val.scale), gts_reg)

    print(f'\ntotal_train_loss: {total_train_loss/len(dl_train):.5f}, train MSE: {total_train_mse/len(dl_train):.5f}')
    print(f'total_val_loss: {total_val_loss/len(dl_val):.5f}, val MSE: {total_val_mse/len(dl_val):.5f}')
for epoch in range(100):
    print(f'\nEpoch: {epoch}')
    
    model.train()
    total_train_loss = 0
    total_train_mse = 0
    for i, batch in enumerate(dl_train):
        seqs, gts_cls, gts_reg, spk_id = batch
        seqs = seqs.to(device)
        gts_cls = gts_cls.to(device)
        gts_reg = gts_reg.to(device)
        spk_id = spk_id.to(device)
        opt.zero_grad()

        preds = model(seqs, spk_id)
        loss = pitch_loss(preds, gts_cls.transpose(1,2))
        loss.backward()
        opt.step()
        total_train_loss += loss
        total_train_mse += reg_loss(model.calc_norm_freq(preds, ds_train.f_min, ds_train.scale), gts_reg)
        
        print(f'\r finished: {100*i/len(dl_train):.2f}%, train loss: {loss:.5f}', end='')
    
    # validation 
    model.eval()
    total_val_loss = 0
    total_val_mse = 0
    for i, batch in enumerate(dl_val):
        seqs, gts_cls, gts_reg, spk_id = batch
        seqs = seqs.to(device)
        gts_cls = gts_cls.to(device)
        gts_reg = gts_reg.to(device)
        spk_id = spk_id.to(device)
        with torch.no_grad():
            preds = model(seqs, spk_id)
            loss = pitch_loss(preds, gts_cls.transpose(1,2))
        total_val_loss += loss
        total_val_mse += reg_loss(model.calc_norm_freq(preds, ds_val.f_min, ds_val.scale), gts_reg)
        
    print(f'\ntotal_train_loss: {total_train_loss/len(dl_train):.5f}, train MSE: {total_train_mse/len(dl_train):.5f}')
    print(f'total_val_loss: {total_val_loss/len(dl_val):.5f}, val MSE: {total_val_mse/len(dl_val):.5f}')


Epoch: 0
 finished: 99.76%, train loss: 0.04000
total_train_loss: 0.03820, train MSE: 1100.38123
total_val_loss: 0.03798, val MSE: 1061.28638

Epoch: 1
 finished: 99.76%, train loss: 0.03668
total_train_loss: 0.03815, train MSE: 1093.88330
total_val_loss: 0.03794, val MSE: 1094.55273

Epoch: 2
 finished: 99.76%, train loss: 0.03792
total_train_loss: 0.03810, train MSE: 1086.51672
total_val_loss: 0.03795, val MSE: 1086.67590

Epoch: 3
 finished: 99.76%, train loss: 0.03787
total_train_loss: 0.03804, train MSE: 1082.90088
total_val_loss: 0.03787, val MSE: 1058.09058

Epoch: 4
 finished: 99.76%, train loss: 0.03737
total_train_loss: 0.03802, train MSE: 1081.33630
total_val_loss: 0.03778, val MSE: 1042.58691

Epoch: 5
 finished: 99.76%, train loss: 0.03698
total_train_loss: 0.03797, train MSE: 1075.89636
total_val_loss: 0.03776, val MSE: 1031.78772

Epoch: 6
 finished: 99.76%, train loss: 0.03707
total_train_loss: 0.03793, train MSE: 1069.02795
total_val_loss: 0.03768, val MSE: 1039.17578

KeyboardInterrupt: 

In [43]:
(((ds_val.fs[ds_val.fs!=100] - ds_val.fs[ds_val.fs!=100].mean())**2)**0.5).mean()

tensor(76.9706)

In [12]:
# naive train test split
with open('results.txt', 'r') as f, open('train.txt', 'a+') as f_tr, open('val.txt', 'a+') as f_val:
    for line in f.readlines():
        if np.random.rand() <= .7:
            f_tr.write(line)
        else:
            f_val.write(line)

43873 30785 13088


In [112]:
all_fs, all_seqs, all_spk_id = [], [], []
with open('results.txt', 'r') as f:
    for line in f.readlines():
        val_dict = eval(line)
        all_seqs.append(torch.IntTensor(val_dict['units']))
        all_fs.append(torch.FloatTensor(val_dict['f0']))
        name = val_dict['audio'].split('_')[0]
        if name in spk_id_dict:
            all_spk_id.append(torch.IntTensor([spk_id_dict[name]]))
        else:
            print(name)

s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s5
s