In [1]:
from dataset import OpencpopDataset, MusicLoaderGenerator
from helper import parser_line, merge_note, get_pitch_labels, get_transposed_phoneme_labels, print_all

## 数据集加载

In [2]:
def dataset_transform(sample, sample_rate=None):
    """Attach the parsed score annotations to one dataset sample.

    The annotation line stored in ``sample['text']`` is split by ``parser_line``;
    the text, phoneme, note, note-duration and slur fields are then run through
    ``merge_note`` and written back onto ``sample``.

    Args:
        sample: Mapping with a 'text' entry. It is modified in place.
        sample_rate: Not used by this function.

    Returns:
        The same ``sample`` object, with the keys 'chinese', 'phoneme', 'note',
        'duration' and 'slur' set.
    """
    # parser_line yields seven fields; the leading id and the phoneme
    # durations are not needed here.
    _, text, phoneme, note, note_duration, _, slur_note = parser_line(sample['text'])
    chinese, phonemes, notes, durations, slurs = merge_note(
        text, phoneme, note, note_duration, slur_note
    )
    merged_fields = (
        ('chinese', chinese),
        ('phoneme', phonemes),
        ('note', notes),
        ('duration', durations),
        ('slur', slurs),
    )
    for key, value in merged_fields:
        sample[key] = value
    return sample

# Build the Opencpop segment dataset, supplying dataset_transform as its transform.
# NOTE(review): the root is an absolute, machine-specific path — adjust it for your environment.
dataset = OpencpopDataset(
    '/scratch/bh2283/data/opencpop/segments/',
    transform=dataset_transform,
    sample_rate=22050,
)

In [3]:
# Split the dataset into train/test subsets and display their sizes.
train_set, test_set = dataset.split()
tuple(len(subset) for subset in (train_set, test_set))

(3744, 12)

In [9]:
note_labels = get_pitch_labels()
phoneme_labels = get_transposed_phoneme_labels()
slur_labels = [-1, 0, 1]
# Intended duration quantization (units presumably seconds — TODO confirm):
# 0-1 at resolution 0.01, 1-2 at resolution 0.05, 2-7 at resolution 0.2.
# NOTE(review): duration_labels below does not implement that scheme and is
# not passed to the loader generator in this cell.
duration_labels = list(range(7))

# Label vocabularies handed to the loader generator, in this order.
labels = (
    phoneme_labels,
    note_labels,
    slur_labels,
)
loaderGenerator = MusicLoaderGenerator(labels)
train_loader = loaderGenerator.dataloader(train_set, batch_size=2)
print('train_set:', len(train_set), 'test_set:', len(test_set))

# Preview the first `steps` batches (expects steps >= 1).
steps = 1
for i_batch, sample_batched in enumerate(train_loader):
    print(sample_batched['chinese'])
    print(sample_batched['phoneme'].shape)
    print(sample_batched['phoneme_pre'].shape)
    print(sample_batched['note_post'].shape)
    print(sample_batched['audio_duration_quant'])
    print(sample_batched['mel'].shape)
    print(sample_batched['mel_len'].shape, sample_batched['mel_len'])
    steps -= 1
    # Stop right after the last previewed batch so the loader is not asked to
    # build one more batch that would only be thrown away.
    if steps <= 0:
        break

train_set: 3744 test_set: 12
['你', '说', '空', '瓶', '适', '合', '许', 'SP', '愿', 'SP', 'AP', '根', '本', '不', '重', '要', 'SP']
torch.Size([17, 2])
torch.Size([17, 2])
torch.Size([17])
tensor([12, 17, 18, 14, 17, 14, 33,  4, 42, 82, 15, 15, 12, 15, 28, 32, 98])
torch.Size([17, 80, 225])
torch.Size([17]) tensor([ 26,  36,  38,  29,  36,  29,  67,  10,  86, 166,  32,  32,  25,  31,
         58,  65, 225])


# Model 设计

- 尝试使用逆卷积，上采样得到所需的音频
- 我们无需去计算时间停止符，只需要在输出的时间内计算loss并且最小化即可
- 设定一个最大时间长度，比如两秒，超过的就不要了（用阈值筛掉）
- 多层上采样得到最佳的输出
- 使用梅尔频谱并配合解码器，可以使输出音质比 STFT 更好（猜测，需要验证）；一般来说，机器学习声码器的效果都会好一些

但是有问题：
- 使用逆卷积太过刚直，缺乏变化性，导致无法很好地收敛
- 一般逆卷积和GAN一起使用，用判别器取代刚直的loss
- 使用 Tacotron 模式就会好很多：无需 GAN，可以自行收敛

In [21]:
from torchaudio.models.tacotron2 import Tacotron2, _get_mask_from_lengths, _Decoder, _Encoder, _Postnet
from torchaudio.pipelines._tts.utils import _get_taco_params
import torch
from torch import Tensor
from typing import Tuple, List, Optional, Union, overload

class TacotronTail(Tacotron2):
    """Tacotron2 decoder and postnet driven by summed label embeddings.

    The Tacotron2 text encoder is not used. Instead, every entry of
    ``labels_lens`` gets its own embedding table; the embedded features are
    summed into a single memory vector of shape ``[batch, 1, embedding_dim]``
    that the decoder attends to.

    Args:
        labels_lens: Maps an input key (looked up in the ``inputs`` dict given to
            ``forward`` / ``infer``) to the number of classes of that feature.
        decoder: Optional replacement for the decoder built by ``Tacotron2``.
        postnet: Optional replacement for the postnet built by ``Tacotron2``.

    Attributes:
        debug_shapes: When true (the default), ``forward`` and ``infer`` print
            the shapes of their intermediate tensors. Set to ``False`` to
            silence them, e.g. during training.
    """

    def __init__(
        self,
        labels_lens: dict,
        decoder = None,
        postnet = None,
    ) -> None:
        # n_symbols is irrelevant here: the encoder is not used.
        _tacotron2_params = _get_taco_params(n_symbols=5)
        super().__init__(**_tacotron2_params)

        embedding_dim = _tacotron2_params['encoder_embedding_dim']
        self.embeddings = {
            key: torch.nn.Embedding(value, embedding_dim)
            for key, value in labels_lens.items()
        }
        # A plain dict is invisible to nn.Module, so the very same Embedding
        # objects (references, not copies) are also kept in a ModuleList. That
        # registers their parameters for .parameters(), .to() and state_dict().
        self.embedding_register = torch.nn.ModuleList(self.embeddings.values())
        if decoder is not None:
            self.decoder: _Decoder = decoder
        if postnet is not None:
            self.postnet: _Postnet = postnet
        # Upper bound on decoder steps during inference: presumably 4 s of audio
        # at 22050 Hz with a hop length of 256 samples — TODO confirm the hop length.
        self.decoder.decoder_max_step = int(4 * 22050 / 256)
        self.debug_shapes = True
        self.version = '0.01'

    @staticmethod
    def reduce_phoneme(x: Tensor) -> Tensor:
        """Sum a 3-D tensor over dim 1; return any other tensor unchanged.

        Defined as a method rather than a lambda stored on the instance so the
        module stays picklable.
        """
        return torch.sum(x, 1) if x.dim() == 3 else x

    def _embed_inputs(self, inputs: dict) -> Tensor:
        """Embed every registered feature and sum them into one memory tensor."""
        embedded = [
            self.reduce_phoneme(embedding(inputs[key]))
            for key, embedding in self.embeddings.items()
        ]
        embedded_inputs = torch.stack(embedded).sum(0).unsqueeze(1)  # [bs, 1, emb_size]
        if self.debug_shapes:
            print('embedded_inputs', embedded_inputs.shape)
        return embedded_inputs

    def forward(
        self,
        inputs: dict,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Teacher-forced pass used for training.

        Args:
            inputs: Batch dict holding one index tensor per key of
                ``labels_lens`` plus 'mel' (n_batch, n_mels, max mel length)
                and 'mel_len' (per-item mel lengths).

        Returns:
            ``(mel_specgram, mel_specgram_postnet, gate_outputs, alignments)``.
        """
        embedded_inputs = self._embed_inputs(inputs)
        mel_specgram = inputs['mel']  # (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
        if self.debug_shapes:
            print('mel_specgram', mel_specgram.shape)
        mel_specgram_lengths = inputs['mel_len']
        # The memory holds exactly one vector per item, so every memory length is 1.
        mel_specgram, gate_outputs, alignments = self.decoder(
            embedded_inputs, mel_specgram, memory_lengths=torch.ones_like(mel_specgram_lengths),
        )

        mel_specgram_postnet = self.postnet(mel_specgram)
        mel_specgram_postnet = mel_specgram + mel_specgram_postnet

        if self.mask_padding:
            # Blank out the frames past each item's true length; the gate is
            # pushed to a large value there.
            mask = _get_mask_from_lengths(mel_specgram_lengths)
            mask = mask.expand(self.n_mels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            mel_specgram.masked_fill_(mask, 0.0)
            mel_specgram_postnet.masked_fill_(mask, 0.0)
            gate_outputs.masked_fill_(mask[:, 0, :], 1e3)

        return mel_specgram, mel_specgram_postnet, gate_outputs, alignments

    @torch.jit.export
    def infer(
        self,
        inputs: dict,
        ) -> Tuple[Tensor, Tensor, Tensor]:
        """Autoregressive generation without ground-truth mel frames.

        Args:
            inputs: Batch dict holding one index tensor per key of ``labels_lens``.

        Returns:
            ``(mel_outputs_postnet, mel_specgram_lengths, alignments)``.
        """
        embedded_inputs = self._embed_inputs(inputs)

        n_batch = embedded_inputs.shape[0]
        # One memory vector per item. Build the lengths as integers on the same
        # device as the memory, matching what forward() passes, so the attention
        # mask does not end up on the CPU when the model runs elsewhere.
        memory_lengths = torch.ones(
            n_batch, dtype=torch.long, device=embedded_inputs.device
        )
        mel_specgram, mel_specgram_lengths, _, alignments = \
            self.decoder.infer(embedded_inputs, memory_lengths=memory_lengths)

        mel_outputs_postnet = self.postnet(mel_specgram)
        mel_outputs_postnet = mel_specgram + mel_outputs_postnet

        alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)

        return mel_outputs_postnet, mel_specgram_lengths, alignments

In [22]:
# Number of classes per input feature; TacotronTail builds one embedding table per key.
labels_lens = {
    'audio_duration_quant': 130,  # computed from the quantization result (presumably the number of duration classes — TODO confirm)
    'phoneme': len(phoneme_labels),  # pinyin
    'phoneme_pre': len(phoneme_labels),  # pinyin of the previous Chinese character
    'phoneme_post': len(phoneme_labels),  # pinyin of the next Chinese character
    'note': len(note_labels),  # pitch / musical note
    'note_pre': len(note_labels),
    'note_post': len(note_labels),
    'slur': len(slur_labels),  # whether the note is a slur (sustained note)
}
model = TacotronTail(labels_lens)

# Smoke test: run inference on the first `steps` batches (expects steps >= 1).
steps = 1
for i_batch, sample_batched in enumerate(train_loader):
    # Only checking that inference runs, so skip autograd bookkeeping.
    with torch.no_grad():
        model.infer(sample_batched)
    steps -= 1
    # Stop right after the last batch so the loader is not asked to build one
    # more batch that would only be thrown away.
    if steps <= 0:
        break


embedded_inputs torch.Size([25, 1, 512])


