In [1]:
import malaya_speech
import torch
import json
from librosa.util import normalize
from torch.nn.utils.rnn import pad_sequence
from malaya_speech.augmentation.waveform import random_sampling
from malaya_speech.torch_model.hifivoice.env import AttrDict
from malaya_speech.torch_model.hifivoice.meldataset import mel_spectrogram, mel_normalize, mel_denormalize
from malaya_speech.torch_model.mediumvc.any2any import MagicModel

`pyaudio` is not available, `malaya_speech.streaming.pyaudio` is not able to use.


In [2]:
# PyTorch HiFi-GAN vocoder; used at the end of the notebook to turn the
# converted mel spectrogram back into a waveform (vocoder.predict below).
vocoder = malaya_speech.vocoder.pt_hifigan()

In [3]:
# TitaNet-large speaker-embedding model. local_files_only=True presumably
# restricts loading to an already-downloaded copy (no network fetch).
speaker_v = malaya_speech.speaker_vector.nemo(
    model='huseinzol05/nemo-titanet_large',
    local_files_only=True,
)
# Inference only; assigning to `_` keeps the model repr out of the cell output.
_ = speaker_v.eval()

In [4]:
# HiFi-GAN config: supplies the STFT / mel parameters used to build features.
# The path has its own name so `config` only ever refers to the parsed AttrDict.
config_path = 'hifigan-config.json'
with open(config_path) as fopen:
    json_config = json.load(fopen)

config = AttrDict(json_config)

In [5]:
# Content utterance, loaded at 22.05 kHz for the mel features computed next.
y, _ = malaya_speech.load('speech/example-speaker/husein-zolkepli.wav', sr = 22050)
# Reference utterance of the target speaker, loaded at the loader's default
# rate (presumably 16 kHz, per the variable name).
y_16k, _ = malaya_speech.load('speech/example-speaker/khalil-nooh.wav')
# Target-speaker embedding, rescaled with librosa.util.normalize.
spk_emb = normalize(speaker_v([y_16k])[0])

In [6]:
# Peak-normalise, scale by 0.95 and add a batch axis -> tensor of shape (1, samples).
audio = torch.FloatTensor(normalize(y) * 0.95).unsqueeze(0)

mel = mel_spectrogram(
    audio,
    config["n_fft"], config["num_mels"], config["sampling_rate"],
    config["hop_size"], config["win_size"], config["fmin"], config["fmax"],
    center=False,
)

# Drop the batch axis, put time first, then apply mel_normalize
# (484 frames x 80 mels for this clip, per the shape check further down).
mel = mel_normalize(mel.squeeze(0).transpose(0, 1))

  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


In [7]:
from pytorch_lightning import LightningModule
from transformers import get_linear_schedule_with_warmup



In [8]:
class Module(LightningModule):
    """Lightning wrapper that trains a MediumVC ``MagicModel`` generator.

    The generator reconstructs its input mel spectrogram conditioned on a
    speaker embedding; the training loss is an L1 loss over each sample's
    first ``overlap_len`` frames.

    Every constructor argument defaults to the value previously hard-coded
    here, so the no-argument ``Module.load_from_checkpoint(path)`` call used
    in this notebook keeps working.

    Args:
        config_path: HiFi-GAN JSON config. Its ``adam_b1`` / ``adam_b2`` set
            the AdamW betas; the parsed config is kept on ``self.config``.
        learning_rate: peak AdamW learning rate.
        num_warmup_steps: linear warmup length of the LR schedule, in steps.
        num_training_steps: total steps of the linear warmup/decay schedule.
        d_model: hidden size passed to ``MagicModel``.
    """

    def __init__(
        self,
        config_path='hifigan-config.json',
        learning_rate=5e-5,
        num_warmup_steps=10000,
        num_training_steps=1000000,
        d_model=192,
    ):
        super().__init__()

        self.Generator = MagicModel(d_model=d_model)
        self.criterion = torch.nn.L1Loss()

        with open(config_path) as fopen:
            json_config = json.load(fopen)
        self.config = AttrDict(json_config)

        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

        # Only trainable generator parameters are optimised. 'initial_lr' is
        # set explicitly, presumably so a scheduler can be rebuilt with
        # last_epoch != -1 when resuming.
        trainable = [p for p in self.Generator.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            [{'params': trainable, 'initial_lr': learning_rate}],
            learning_rate,
            betas=[self.config["adam_b1"], self.config["adam_b2"]],
        )

    def configure_optimizers(self):
        """Return the AdamW optimizer with a per-step linear warmup/decay schedule."""
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_training_steps,
        )
        scheduler_config = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [self.optimizer], [scheduler_config]

    def training_step(self, batch, batch_idx):
        """Compute the mean per-sample L1 reconstruction loss for one batch.

        ``batch`` is unpacked as ``(spk_embs, input_mels, input_masks,
        overlap_lens)``. For each sample only the first ``overlap_len`` frames
        contribute. NOTE(review): assumes mels are (batch, frames, mels), as in
        the inference cells of this notebook — confirm for the training data.
        """
        spk_embs, input_mels, input_masks, overlap_lens = batch
        fake_mels = self.Generator(spk_embs, input_mels, input_masks)
        losses = [
            self.criterion(fake_mel[:overlap_len, :], target_mel[:overlap_len, :])
            for fake_mel, target_mel, overlap_len in zip(
                fake_mels.unbind(), input_mels.unbind(), overlap_lens)
        ]
        # Average of per-sample losses: each sample is weighted equally,
        # regardless of its overlap length.
        loss = sum(losses) / len(losses)
        self.log("Losses/training_loss", loss, on_step=True, on_epoch=True)
        return loss

In [9]:
# List the saved checkpoints; the one with the highest step is loaded below.
!ls mediumvc-32

'model-epoch=00-step=10000.ckpt'    'model-epoch=00-step=1155000.ckpt'
'model-epoch=00-step=1145000.ckpt'  'model-epoch=00-step=5000.ckpt'
'model-epoch=00-step=1150000.ckpt'


In [10]:
# Load the latest checkpoint (highest step in the listing above) for inference.
# nn.Module.eval() is recursive, so this also puts model.Generator and all of
# its submodules in eval mode; weight norm is removed in a later cell.
model = Module.load_from_checkpoint('mediumvc-32/model-epoch=00-step=1155000.ckpt').eval()

In [11]:
# Export the generator weights *before* the next cell strips weight norm.
# Presumably so 'mediumvc.pt' loads into a freshly built MagicModel (which still
# has weight norm applied) — confirm against the loader that consumes it.
torch.save(model.Generator.state_dict(), 'mediumvc.pt')

In [12]:
# Strip weight norm from the content encoder and the generator for inference,
# after the weights were exported above. NOTE(review): presumably not safe to
# run twice on the same loaded model — re-run the load cell first.
model.Generator.cont_encoder.remove_weight_norm()
model.Generator.generator.remove_weight_norm()

In [13]:
# Batch of one utterance:
#   ori_mels       -> padded mels, (batch, frames, mels)
#   spk_input_mels -> speaker embeddings, (batch, embedding_dim)
#   mel_masks      -> bool, True on padded frames (index >= utterance length)
ori_lens = [len(mel)]
overlap_lens = ori_lens
ori_mels = pad_sequence([mel], batch_first=True)
spk_input_mels = torch.tensor(spk_emb).unsqueeze(0)
frame_idx = torch.arange(ori_mels.size(1)).unsqueeze(0)
mel_masks = frame_idx >= torch.tensor(ori_lens).unsqueeze(1)
del frame_idx

In [14]:
# Inference only: disable autograd so no computation graph is built for the
# generator pass (the result is only post-processed and vocoded afterwards).
with torch.no_grad():
    fake_mels = model.Generator(spk_input_mels, ori_mels, mel_masks)

In [15]:
# The generator is expected to emit one mel frame per input frame.
assert fake_mels.shape == ori_mels.shape, (ori_mels.shape, fake_mels.shape)
ori_mels.shape, fake_mels.shape

(torch.Size([1, 484, 80]), torch.Size([1, 484, 80]))

In [16]:
# Clamp to [0, 1] (presumably the range mel_normalize maps into), undo the
# normalisation, swap to (batch, mels, frames) and hand NumPy to the vocoder.
fake_mel = (
    mel_denormalize(fake_mels.clamp(min=0, max=1))
    .transpose(1, 2)
    .detach()
    .cpu()
    .numpy()
)

In [17]:
# Synthesize a waveform from the converted mel spectrogram with the HiFi-GAN vocoder.
r = vocoder.predict(fake_mel)

In [18]:
# Length of the synthesized waveform. 123904 = 484 frames x 256, which suggests
# a hop of 256 samples — TODO confirm against config["hop_size"].
r[0].shape

(123904,)

In [19]:
# Length of the original 22.05 kHz input, for comparison with the vocoder output.
# NOTE(review): the shortfall is presumably the partial last frame dropped by
# the center=False mel framing — not verified here.
len(y)

124156

In [20]:
import IPython.display as ipd
# Play back the converted waveform at 22.05 kHz.
ipd.Audio(data=r[0], rate=22050)