In [1]:
from dataset import OpencpopDataset, MusicLoaderGenerator
from helper import parser_line, merge_note, get_pitch_labels, get_transposed_phoneme_labels, print_all

In [16]:
# Scratch data: a list of token lists, scanned by the membership check in the following cell.
cc = [['a', 'b'], ['AP']]

In [17]:
# Report the first sub-list that contains the exact element 'A', then stop looking.
# `in` on a list compares whole elements, so an entry such as 'AP' is not a match.
first_hit = next((group for group in cc if 'A' in group), None)
if first_hit is not None:
    print('----------ERROR2---------')
    print(first_hit)

## 数据集加载

In [2]:
def dataset_transform(sample, sample_rate=None):
    """Parse the raw annotation in ``sample['text']`` and attach the merged fields.

    Args:
        sample: mapping with a ``'text'`` entry; it is updated in place.
        sample_rate: accepted for interface compatibility; not used here.

    Returns:
        The same ``sample`` object with ``'chinese'``, ``'phoneme'``, ``'note'``,
        ``'duration'`` and ``'slur'`` entries set from ``merge_note``'s output.
    """
    # parser_line yields seven fields; the first (id) and sixth (phoneme duration)
    # are not needed by this transform.
    _, text, phoneme, note, note_duration, _, slur_note = parser_line(sample['text'])
    text_with_p, phoneme, note, note_duration, slur_note = merge_note(
        text, phoneme, note, note_duration, slur_note
    )
    merged_fields = (
        ('chinese', text_with_p),
        ('phoneme', phoneme),
        ('note', note),
        ('duration', note_duration),
        ('slur', slur_note),
    )
    for key, value in merged_fields:
        sample[key] = value
    return sample

# Dataset location and audio sample rate, named so they can be changed in one place.
# NOTE: this is a machine-specific absolute path; point it at the local copy of the data.
OPENCPOP_SEGMENTS_DIR = '/scratch/bh2283/data/opencpop/segments/'
SAMPLE_RATE = 22050

dataset = OpencpopDataset(
    OPENCPOP_SEGMENTS_DIR,
    transform=dataset_transform,
    sample_rate=SAMPLE_RATE,
)

In [3]:
# Split the dataset and show the size of each subset as the cell output.
train_set, test_set = dataset.split()
tuple(len(subset) for subset in (train_set, test_set))

(3744, 12)

In [4]:
BATCH_SIZE = 4

# Label vocabularies for the categorical features.
note_labels = get_pitch_labels()
phoneme_labels = get_transposed_phoneme_labels()
slur_labels = [0, 1]
# Resolution note (translated): 0-1 -> 0.01, 1-2 -> 0.05, 2-7 -> 0.2
duration_labels = list(range(7))

# The loader generator receives the vocabularies in the order (phoneme, note, slur).
labels = (phoneme_labels, note_labels, slur_labels)
loaderGenerator = MusicLoaderGenerator(labels)
train_loader = loaderGenerator.dataloader(train_set, batch_size=BATCH_SIZE)
print('train_set:', len(train_set), 'test_set:', len(test_set))
# Preview: print the contents of `steps` batches (one) from the training loader.
steps = 1
for i_batch, sample_batched in enumerate(train_loader):
    if steps <= 0:
        break

    print(sample_batched['chinese'])
    # Shapes of the index tensors.
    for shape_key in ('phoneme', 'phoneme_pre', 'note_post'):
        print(sample_batched[shape_key].shape)
    print(sample_batched['audio_duration_quant'])
    print(sample_batched['mel'].shape)
    mel_len = sample_batched['mel_len']
    print(mel_len.shape, mel_len)
    steps -= 1

train_set: 3744 test_set: 12
['得', '过', '且', '过', '是', '我', '如', '今', '速', '写', 'SP', 'AP', 'SP', '爱', '是', '不', '是', 'AP', '不', '开', '口', '才', '~', '珍', '贵', '~', 'AP', '打', '从', '心', '里', '暖', '暖', '的', 'SP', 'AP', '白', '色', '的', '约', 'SP', '定', 'AP', '我', '不', '愿', '醒', '过', '来', '~', '~', 'AP']
torch.Size([52, 2])
torch.Size([52, 2])
torch.Size([52])
tensor([ 5,  9,  7,  9,  7,  9,  6,  9, 30, 35, 18, 15,  4,  8, 24,  6, 31, 10,
         8, 19,  9, 27, 10, 22, 15, 69, 12,  8, 12, 12, 12, 12, 16, 21, 84, 10,
        15, 28, 16, 12,  6, 43, 17, 13,  9,  9, 24, 15, 12,  6, 64, 15])
torch.Size([52, 80, 169])
torch.Size([52]) tensor([ 12,  19,  15,  19,  16,  19,  13,  20,  62,  72,  38,  32,   9,  17,
         50,  14,  63,  21,  17,  40,  20,  56,  21,  46,  32, 139,  26,  18,
         25,  26,  26,  26,  34,  43, 169,  22,  31,  57,  34,  25,  13,  87,
         35,  28,  20,  19,  49,  31,  25,  14, 130,  32])


## Model 设计

- 尝试使用逆卷积，上采样得到所需的音频
- 我们无需去计算时间停止符，只需要在输出的时间内计算loss并且最小化即可
- 设定一个最大时间长度，比如两秒，超过的就不要了（用阈值筛掉）
- 多层上采样得到最佳的输出
- 使用梅尔频谱配合解码器（声码器），输出音质应当比直接使用 STFT 更好（猜测，需要验证）；一般基于机器学习的声码器效果都会好一些

但是有问题：
- 使用逆卷积太过刚直，没有变化性，导致无法很好地收敛
- 一般逆卷积和GAN一起使用，用判别器取代刚直的loss
- 使用tacotron模式就会好很多，无需GAN，自收敛

In [5]:
from torchaudio.models.tacotron2 import Tacotron2, _get_mask_from_lengths, _Decoder, _Encoder, _Postnet
from torchaudio.pipelines._tts.utils import _get_taco_params
import torch
from torch import Tensor
from typing import Tuple, List, Optional, Union, overload

class TacotronTail(Tacotron2):
    """Tacotron2 decoder and postnet conditioned on summed categorical embeddings.

    Each feature named in ``labels_lens`` gets its own embedding table. For one
    batch the per-feature embeddings are summed into a single vector per item, and
    that vector is passed to the decoder as a memory sequence of length 1. The
    parent class' text encoder is constructed but not used by this class.
    """

    def __init__(
        self,
        labels_lens: dict,
        decoder = None,
        postnet = None,
        max_seconds: float = 4,
        sample_rate: int = 22050,
        hop_length: int = 256,
    ) -> None:
        """Build the embedding tables and configure the decoder.

        Args:
            labels_lens: maps an input key to the vocabulary size of that feature.
            decoder: optional replacement for the decoder created by ``Tacotron2``.
            postnet: optional replacement for the postnet created by ``Tacotron2``.
            max_seconds, sample_rate, hop_length: together set the decoder step
                limit to ``int(max_seconds * sample_rate / hop_length)``. The defaults
                reproduce the previous hard-coded ``int(4 * 22050 / 256)``.
                NOTE(review): 256 is assumed to be the mel hop size -- confirm
                against the dataset's mel configuration.
        """
        _tacotron2_params=_get_taco_params(n_symbols=5) # ignore n_symbols, encoder not used
        super().__init__(**_tacotron2_params)

        embedding_dim = _tacotron2_params['encoder_embedding_dim']
        self.embeddings = {
            key: torch.nn.Embedding(value, embedding_dim)
            for key, value in labels_lens.items()
        }
        # A plain dict is invisible to nn.Module, so the same Embedding objects are
        # also placed in a ModuleList. ModuleList stores references (no copies), which
        # makes their weights show up in parameters(), state_dict() and .to(device).
        self.embedding_register = torch.nn.ModuleList(self.embeddings.values())
        if decoder is not None:
            self.decoder: _Decoder = decoder
        if postnet is not None:
            self.postnet: _Postnet = postnet
        self.decoder.decoder_max_step = int(max_seconds * sample_rate / hop_length)
        self.version = '0.01'

    @staticmethod
    def reduce_phoneme(x: Tensor) -> Tensor:
        """Sum a 3-D tensor over dim 1; return tensors of any other rank unchanged."""
        return torch.sum(x, 1) if len(x.shape) == 3 else x

    def _embed_inputs(self, inputs: dict) -> Tensor:
        """Embed every configured feature of ``inputs`` and sum them: [bs, 1, emb_size]."""
        embedded = [
            self.reduce_phoneme(table(inputs[key])) for key, table in self.embeddings.items()
        ]
        return torch.stack(embedded).sum(0).unsqueeze(1)

    def forward(
        self,
        inputs: dict,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Teacher-forced pass.

        Args:
            inputs: batch with one index tensor per key of ``labels_lens`` plus
                ``'mel'`` (n_batch, n_mels, max mel length) and ``'mel_len'``.

        Returns:
            ``(mel_specgram, mel_specgram_postnet, gate_outputs, alignments)``.
        """
        embedded_inputs = self._embed_inputs(inputs)
        mel_specgram = inputs['mel'] # (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
        mel_specgram_lengths = inputs['mel_len']
        # The memory holds exactly one vector per item, hence a length of 1 everywhere.
        mel_specgram, gate_outputs, alignments = self.decoder(
            embedded_inputs, mel_specgram, memory_lengths=torch.ones_like(mel_specgram_lengths),
        )

        mel_specgram_postnet = self.postnet(mel_specgram)
        mel_specgram_postnet = mel_specgram + mel_specgram_postnet

        if self.mask_padding:
            mask = _get_mask_from_lengths(mel_specgram_lengths)
            mask = mask.expand(self.n_mels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            mel_specgram.masked_fill_(mask, 0.0)
            mel_specgram_postnet.masked_fill_(mask, 0.0)
            gate_outputs.masked_fill_(mask[:, 0, :], 1e3)

        return mel_specgram, mel_specgram_postnet, gate_outputs, alignments

    @torch.jit.export
    def infer(
        self,
        inputs: dict,
        ) -> Tuple[Tensor, Tensor, Tensor]:
        """Autoregressive generation from the embedded features.

        Args:
            inputs: batch with one index tensor per key of ``labels_lens``.

        Returns:
            ``(mel_outputs_postnet, mel_specgram_lengths, alignments)``.
        """
        embedded_inputs = self._embed_inputs(inputs)

        n_batch = embedded_inputs.shape[0]
        # Create the lengths on the same device as the memory (and as integers, like
        # forward() does) so the decoder's attention mask matches the model's device.
        memory_lengths = torch.ones(n_batch, dtype=torch.long, device=embedded_inputs.device)
        mel_specgram, mel_specgram_lengths, _, alignments = \
            self.decoder.infer(embedded_inputs, memory_lengths=memory_lengths)

        mel_outputs_postnet = self.postnet(mel_specgram)
        mel_outputs_postnet = mel_specgram + mel_outputs_postnet

        alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)

        return mel_outputs_postnet, mel_specgram_lengths, alignments

In [6]:
# Vocabulary size per categorical input feature. Key order is kept as-is because
# the model registers one embedding per key in this order.
n_phonemes = len(phoneme_labels)
n_notes = len(note_labels)
labels_lens = {
    'audio_duration_quant': 130,  # computed result after quantisation
    'phoneme': n_phonemes,        # pinyin
    'phoneme_pre': n_phonemes,    # pinyin of the previous character
    'phoneme_post': n_phonemes,   # pinyin of the next character
    'note': n_notes,              # pitch / musical note
    'note_pre': n_notes,
    'note_post': n_notes,
    'slur': len(slur_labels),     # whether the note is a sustained (slur) note
}
model = TacotronTail(labels_lens)
# Smoke test: push one batch through model.infer to check that the inputs flow
# through the model. The result is discarded, so autograd tracking is switched off
# to avoid building a graph across every decoder step.
steps = 1
for i_batch, sample_batched in enumerate(train_loader):
    if steps <= 0:
        break
    with torch.no_grad():
        model.infer(sample_batched)
    steps -= 1




## 训练

- 初始参数配置
- 训练

In [7]:
LOG_DIR = './log/tacotron-1-'  # used as a filename prefix: directory plus file stem
LEARNING_RATE = 0.001
LOAD_PATH = './checkpoint/pre.pt'
def save_log(file_name, log, mode='a', path = LOG_DIR):
    """Write one log entry to ``path + file_name`` and echo it to stdout.

    Args:
        file_name: suffix appended to ``path`` to form the target file name.
        log: a string, or an iterable whose items are stringified and space-joined.
        mode: file mode passed to ``open``; in append mode (``'a'``) a newline is
            written before the entry.
        path: prefix for the target file (defaults to ``LOG_DIR``).
    """
    import os

    target = path + file_name
    # Create the log directory on first use so a fresh checkout does not fail.
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)

    message = log if isinstance(log, str) else ' '.join(str(item) for item in log)
    # Explicit UTF-8: log entries contain non-ASCII text (e.g. emoji markers), which
    # would fail on platforms whose default encoding cannot represent them.
    with open(target, mode, encoding='utf-8') as f:
        if mode == 'a':
            f.write('\n')
        f.write(message)
        print(message)

In [8]:
from os.path import exists
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _scoped_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` under ``prefix`` with the prefix removed.

    Falls back to the unmodified ``state_dict`` when no key carries the prefix, so
    a checkpoint whose keys are already un-prefixed is handled as before.
    """
    scoped = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    return scoped or state_dict


def load_checkpoint(path):
    """Warm-start ``model.decoder`` and ``model.postnet`` from a checkpoint file.

    Does nothing when ``path`` does not exist or the checkpoint has no
    ``'model_state_dict'`` entry.

    A full-model state dict names its entries ``decoder.*`` / ``postnet.*``.
    Passing it unchanged to a submodule with ``strict=False`` matches no key and
    silently loads nothing, so the matching prefix is stripped first and the number
    of tensors actually loaded is logged.

    Args:
        path: checkpoint file to read.
    """
    if not exists(path):
        return
    save_log('e.txt', ['path', path, 'exist, loading...'])
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    if 'model_state_dict' not in checkpoint:
        return

    state_dict = checkpoint['model_state_dict']
    for name, module in (('decoder', model.decoder), ('postnet', model.postnet)):
        result = module.load_state_dict(_scoped_state_dict(state_dict, name + '.'), strict=False)
        n_missing = len(result.missing_keys)
        n_loaded = len(module.state_dict()) - n_missing
        save_log('e.txt', [name, 'loaded', n_loaded, 'tensors,', n_missing, 'missing'])

# Warm-start from the checkpoint at LOAD_PATH (load_checkpoint skips a missing file).
load_checkpoint(LOAD_PATH)

path ./checkpoint/pre.pt exist, loading...


In [10]:
# model

In [9]:
# Optimise every registered parameter of the model.
params = model.parameters()
# Earlier alternatives, kept for reference:
# params = list(model.embedding.parameters())+list(model.encoder.parameters())+list(model.speaker_encoder.parameters())
# optimizer = torch.optim.SGD(params, lr=LEARNING_RATE, momentum=0.5)
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)
# Multiplies the learning rate by 0.5 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
initial_epoch = 0

# Loss functions.
mse_loss = torch.nn.MSELoss()
bce_loss = torch.nn.BCELoss()
cos_loss = torch.nn.CosineEmbeddingLoss()


def mean(x):
    """Arithmetic mean of a non-empty sequence."""
    return sum(x)/len(x)

In [10]:
def dump_model(EPOCH, LOSS, PATH):
    """Save the global model and optimizer state, plus epoch and loss, to ``PATH``.

    Args:
        EPOCH: epoch index stored under ``'epoch'``.
        LOSS: loss value stored under ``'loss'``.
        PATH: destination file; its parent directory is created if missing.
    """
    import os

    # torch.save does not create directories, so make sure the target one exists.
    parent = os.path.dirname(PATH)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
            'epoch': EPOCH,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

def save_temp(EPOCH, LOSS):
    """Overwrite the rolling temporary checkpoint with the current state."""
    dump_model(EPOCH, LOSS, "./checkpoint/model_temp.pt")
    
def save_checkpoint(EPOCH, LOSS):
    """Save a checkpoint whose file name carries the epoch and the 3-decimal loss."""
    loss_tag = '%.3f' % LOSS
    target = f"./checkpoint/model_{EPOCH}_{loss_tag}.pt"
    dump_model(EPOCH, LOSS, target)

In [11]:
def train(epoch=1):
    """Train the global ``model`` on ``train_loader``.

    Runs epoch indices ``initial_epoch`` up to (not including) ``epoch``. Each step
    minimises the MSE of the decoder and postnet mels against the target mel plus
    a BCE loss on the stop token. Progress is logged and a temporary checkpoint is
    saved every ``3000 // BATCH_SIZE`` batches; a named checkpoint is saved at the
    end of every epoch.

    Args:
        epoch: exclusive upper bound of the epoch index range.

    Raises:
        RuntimeError: if the loss becomes NaN or infinite. It is raised before the
            backward pass, so the model weights are not updated with a bad step
            and the notebook kernel (and its state) stays alive for inspection.
    """
    # Model, inputs and targets must share one device. Module.to() moves the
    # parameters in place, so the already-built optimizer keeps valid references.
    model.to(device)
    train_loss_q = []
    test_loss_q = []
    log_every = max(1, 3000 // BATCH_SIZE)  # guard against modulo-by-zero for huge batches
    for epoch_idx in range(initial_epoch, epoch):
        batch_train_loss = []
        for i_batch, sample_batched in enumerate(train_loader):
            model.train()
            # Step 1. Prepare Data: move every tensor of the batch to the device and
            # leave non-tensor entries (e.g. the text) untouched.
            batch = {
                key: value.to(device) if torch.is_tensor(value) else value
                for key, value in sample_batched.items()
            }
            mels_tensor = batch['mel'] # [bs, mel_bins, L]
            mel_length = batch['mel_len']

            # Step 2. Run our forward pass
            org_mel, pos_mel, stop_token, _ = model(batch)
            loss1 = mse_loss(mels_tensor, org_mel)
            loss1 += mse_loss(mels_tensor, pos_mel)

            # Stop target: 1.0 for every frame index >= the item's mel length, else 0.0.
            # NOTE(review): the last valid frame itself is labelled 0 -- confirm this is
            # the intended stop-token convention.
            frame_idx = torch.arange(stop_token.shape[1], device=mel_length.device)
            true_stop_token = (frame_idx.unsqueeze(0) >= mel_length.unsqueeze(1)).to(stop_token.dtype)
            loss2 = bce_loss(torch.sigmoid(stop_token), true_stop_token)

            loss = loss1 + loss2
            if not torch.isfinite(loss).item():
                raise RuntimeError(f'NaN hit! non-finite loss at epoch {epoch_idx}, batch {i_batch}')

            # Step 3. Run our backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_train_loss.append(loss.item())

            if i_batch % log_every == 0: # log about each n data
                # test_loss = test()
                test_loss = 0  # placeholder: no evaluation is run here
                train_loss = mean(batch_train_loss)
                test_loss_q.append(test_loss)
                train_loss_q.append(train_loss)
                save_log(f'e{epoch_idx}.txt', ['🟣 epoch', epoch_idx, 'data', i_batch*BATCH_SIZE,
                    'lr', scheduler.get_last_lr(),
                    'train_loss', '{:.3f}'.format(train_loss),
                    'test_loss', test_loss,
                    'bce_loss', '{:.3f}'.format(loss2.item())])
                save_temp(epoch_idx, test_loss) # save temp checkpoint
                # test_decoder(epoch_idx, 5)

        # scheduler.step()  # left disabled, as before: the learning rate stays constant
        # test_loss_q is empty only when the loader produced no batches.
        save_checkpoint(epoch_idx, mean(test_loss_q) if test_loss_q else 0)
        save_log(f'e{epoch_idx}.txt', ['============= Final Test ============='])
        # test_decoder(epoch_idx, 10) # run some sample prediction and see the result


In [12]:
# Run training with the default argument (epoch=1).
train()

🟣 epoch 0 data 0 lr [0.001] train_loss 191.046 test_loss 0 bce_loss 1.031


KeyboardInterrupt: 

: 