In [29]:
import pickle
from glob import glob
import malaya_speech
from datasets import Audio
from sklearn.utils import shuffle
import random
import torch
import json
from librosa.util import normalize
from torch.nn.utils.rnn import pad_sequence
from malaya_speech.augmentation.waveform import random_sampling
from malaya_speech.torch_model.hifivoice.env import AttrDict
from malaya_speech.torch_model.hifivoice.meldataset import mel_spectrogram, mel_normalize
from malaya_speech.torch_model.mediumvc.any2any import MagicModel

In [43]:
class Dataset(torch.utils.data.IterableDataset):
    """
    Endless stream of (normalized mel-spectrogram, speaker embedding) pairs.

    The speaker pickle files are globbed and shuffled once, at construction
    time. Every pass over those files samples up to 4 entries per file, crops
    each decoded waveform with `random_sampling`, and, whenever at least 32
    examples have been collected, shuffles that group and yields it one
    example at a time. Examples still pending at the end of a pass (fewer
    than 32) are dropped when the next pass starts.
    """

    # Sampling rate used both to decode the stored audio and to crop it.
    sr = 22050

    def __init__(self):
        # Zero-argument form: `super(Dataset)` only builds an unbound super
        # object, so calling `__init__` on it never reaches the parent class.
        super().__init__()

        self.speakers = glob('random-embedding-*.pkl') + glob('/home/husein/ssd2/processed-youtube/*.pkl')
        self.speakers = shuffle(self.speakers)

        self.audio = Audio(sampling_rate=self.sr)
        config = 'hifigan-config.json'
        with open(config) as fopen:
            json_config = json.load(fopen)

        self.config = AttrDict(json_config)

    def _featurize(self, y, spk_emb):
        """Convert one (waveform, speaker embedding) pair into (mel, embedding tensor)."""
        spk_emb = normalize(spk_emb)
        # Normalize the waveform, then scale it by 0.95.
        audio = normalize(y) * 0.95
        # mel_spectrogram is fed a tensor with a leading axis of size 1.
        audio = torch.FloatTensor(audio).unsqueeze(0)

        mel = mel_spectrogram(audio,
                              self.config["n_fft"],
                              self.config["num_mels"],
                              self.config["sampling_rate"],
                              self.config["hop_size"],
                              self.config["win_size"],
                              self.config["fmin"],
                              self.config["fmax"],
                              center=False)

        # Drop the leading axis again and swap the two remaining axes.
        mel = mel.squeeze(0).transpose(0, 1)
        mel = mel_normalize(mel)

        return mel, torch.tensor(spk_emb)

    def __iter__(self):
        """
        Yield `(mel, spk_emb)` tuples forever.

        Raises:
            ValueError: if no speaker files were found; without this guard the
                outer loop would spin forever without yielding anything.
        """
        if not self.speakers:
            raise ValueError('no speaker pickle files were found, nothing to iterate over')

        while True:
            batch = []
            for path in self.speakers:
                # NOTE: pickle.load can execute arbitrary code; only point this
                # at speaker files from a trusted source.
                with open(path, 'rb') as fopen:
                    data = pickle.load(fopen)

                data = random.sample(data, min(len(data), 4))
                for d in data:
                    spk_emb = d['classification_model'][0]
                    # Use this instance's decoder rather than a notebook-level
                    # global, so the class works whatever the instance is named.
                    y = self.audio.decode_example(self.audio.encode_example(d['wav_data']))
                    y = y['array']
                    y = random_sampling(y, self.sr, length = 8000)
                    batch.append((y, spk_emb))

                if len(batch) >= 32:
                    # Shuffle so consecutive examples are not grouped by file.
                    batch = shuffle(batch)
                    for y, spk_emb in batch:
                        yield self._featurize(y, spk_emb)

                    batch = []

In [44]:
# Build the streaming dataset; the constructor globs the speaker pickle files
# and reads hifigan-config.json from the working directory.
dataset = Dataset()

In [45]:
def batch(batches):
    """
    Collate a list of `(mel, speaker_embedding)` pairs into padded tensors.

    Args:
        batches: sequence of 2-tuples; the first item is a tensor whose first
            axis may differ in length between examples, the second a tensor
            of identical shape for every example.

    Returns:
        A 4-tuple `(speaker_embeddings, padded_mels, masks, lengths)`:
        the stacked embeddings, the mels zero-padded along their first axis
        (batch first), a boolean mask that is True on padded positions, and
        the list of original first-axis lengths.
    """
    mels, embeddings = zip(*batches)

    stacked_embeddings = torch.stack(embeddings)
    lengths = [len(mel) for mel in mels]

    padded_mels = pad_sequence(mels, batch_first=True)

    # One row of position indices, compared against each example's length:
    # positions at or beyond the length belong to the padding.
    positions = torch.arange(padded_mels.size(1))
    padding_masks = torch.stack([positions >= length for length in lengths])

    return stacked_embeddings, padded_mels, padding_masks, lengths

# Wrap the dataset in a DataLoader (4 examples per batch, collated by `batch`)
# and keep only its iterator, which is what the following cells pull from.
loader = iter(
    torch.utils.data.DataLoader(
        dataset,
        collate_fn=batch,
        batch_size=4,
    )
)

In [47]:
# Pull one collated batch: (spk_input_mels, ori_mels, mel_masks, overlap_lens),
# in the order returned by `batch`.
d = next(loader)

In [48]:
# Instantiate MagicModel (imported from malaya_speech's mediumvc.any2any module)
# with d_model=192.
Generator = MagicModel(d_model = 192)

In [49]:
# L1 loss, used below to compare the generated mels with the input mels.
criterion = torch.nn.L1Loss()

In [51]:
# Unpack the collated batch in the order produced by `batch`.
spk_embs, input_mels, input_masks, overlap_lens = d
# Forward pass: the generator receives the speaker embeddings, the padded mels
# and the padding masks.
fake_mels = Generator(spk_embs,input_mels,input_masks)

In [52]:
# Compute the L1 loss per example, then average over the batch.
losses = []
for fake_mel, target_mel, overlap_len in zip(fake_mels.unbind(), input_mels.unbind(), overlap_lens):
    # overlap_len is the example's original (unpadded) mel length, so slicing
    # to it keeps the padded positions out of the loss.
    temp_loss = criterion(fake_mel[:overlap_len, :], target_mel[:overlap_len, :])
    losses.append(temp_loss)
# Plain mean of the per-example losses: each example counts equally,
# whatever its length.
loss = sum(losses) / len(losses)

In [53]:
# Display the averaged loss for this single batch.
loss

tensor(0.5427, grad_fn=<DivBackward0>)

In [56]:
# Persist the generator's state dict to mediumvc.pt.
# NOTE(review): no backward pass or optimizer step is visible in this notebook,
# so these appear to be freshly initialised weights — confirm this is intended.
torch.save(Generator.state_dict(), 'mediumvc.pt')