In [1]:
import torch
import torchaudio

In [2]:
from helper import get_initial_table, get_final_table

## 语音数据集预处理

- 将同一个汉字的声母和韵母合并到一起，划分出单个汉字对应的时间
- 延长发音（拖音）需要标注出来（下文用 `~` 表示），确定为延长，方便之后生成

In [3]:
# Initial-consonant table from the helper module; the legacy merge_note_old reads this global.
initial_table = get_initial_table()

In [4]:
def print_all(x):
    """Print every item of ``x`` on its own line, preceded by ``len(item)``."""
    sized_items = ((len(item), item) for item in x)
    for size, item in sized_items:
        print(size, item)

In [5]:
def get_transcriptions(path):
    """Read a transcription file and return its lines without newline characters.

    A single trailing empty entry (produced by a final newline) is dropped.

    The file is decoded as UTF-8 explicitly. The lyrics are Chinese, and
    ``open()`` otherwise uses the platform default encoding, which is not
    guaranteed to be UTF-8.

    Args:
        path: Path to the transcription file.

    Returns:
        List of line strings.
    """
    with open(path, encoding='utf-8') as f:
        lines = f.read().split('\n')
    if lines[-1] == '':
        lines = lines[:-1]
    return lines

In [6]:
# Relative path to the Opencpop segment directory; train.txt and the wavs/ folder are read from here.
path = '../data/opencpop/segments/'
lines = get_transcriptions(path+'train.txt')
len(lines)

3550

In [7]:
def get_audio(id, path, sr = 16000):
    """Load ``<path>wavs/<id>.wav`` as a single-channel ``(1, time)`` waveform at ``sr`` Hz.

    Only the first channel is kept, in both branches. Downstream code indexes
    ``audio[0, ...]`` and pads with ``(1, n)`` zeros, so the output shape must
    not depend on whether the file needed resampling.

    Args:
        id: Segment id; becomes the file stem.
        path: Segment directory; must end with a path separator because the
            file path is built by string concatenation.
        sr: Target sample rate.

    Returns:
        Waveform tensor of shape ``(1, time)``.
    """
    wav_path = path + 'wavs/' + str(id) + '.wav'
    waveform, sample_rate = torchaudio.load(wav_path)
    waveform = waveform[:1]  # keep channel 0, preserving the (channels, time) layout
    if sample_rate != sr:
        waveform = torchaudio.functional.resample(waveform, sample_rate, sr)
    return waveform

def parser_line(line):
    """Split one Opencpop transcription line into its seven fields.

    Line format:
    ``id|text|phoneme|note|note_duration|phoneme_duration|slur_note``.
    The last five fields are space-separated, one value per phoneme.

    Returns:
        ``(id, text, phoneme, note, note_duration, phoneme_duration, slur_note)``.
        ``phoneme`` and ``note`` are lists of str, the durations are lists of
        float, and ``slur_note`` is a list of int.

    Raises:
        ValueError: if the line does not have exactly seven '|'-separated
            fields, or a numeric value fails to parse.
        AssertionError: if the per-phoneme sequences differ in length.
    """
    utt_id, text, phoneme, note, note_duration, phoneme_duration, slur_note = line.split('|')
    phoneme = phoneme.split(' ')
    note = note.split(' ')
    note_duration = [float(v) for v in note_duration.split(' ')]
    phoneme_duration = [float(v) for v in phoneme_duration.split(' ')]
    slur_note = [int(v) for v in slur_note.split(' ')]
    # Every per-phoneme sequence must line up, including `note`.
    lengths = {len(phoneme), len(note), len(note_duration), len(phoneme_duration), len(slur_note)}
    assert len(lengths) == 1, f'per-phoneme fields have different lengths in line {utt_id!r}'
    return utt_id, text, phoneme, note, note_duration, phoneme_duration, slur_note

看一下一共用到了多少音素（声母、韵母）和音符

In [8]:
# Collect every distinct phoneme and note symbol used in the training lines.
# (The unpacking rebinds `id`, `text`, ... at notebook level, shadowing the builtin `id`.)
phoneme_set, note_set = set(), set()
for line in lines:
    (id, text, phoneme, note,
     note_duration, phoneme_duration, slur_note) = parser_line(line)
    phoneme_set |= set(phoneme)
    note_set |= set(note)

In [9]:
# Size and contents of the phoneme and note vocabularies.
print_all([phoneme_set,note_set])

60 {'p', 'h', 'ie', 'e', 'SP', 'un', 'c', 'o', 'g', 'u', 'van', 'd', 'i', 'y', 'en', 'ing', 'ai', 'ei', 's', 'ch', 'ia', 'iong', 'uang', 'f', 'r', 'm', 'q', 'iu', 'ong', 'b', 'an', 'eng', 'k', 't', 'zh', 'v', 'a', 'vn', 'uai', 'ou', 'ian', 'ao', 'x', 'sh', 'uo', 'j', 'AP', 'ui', 'iang', 'er', 'iao', 'ang', 'l', 'n', 'in', 'ua', 'z', 'w', 'uan', 've'}
35 {'F#4/Gb4', 'G4', 'D2', 'F4', 'A5', 'C3', 'C4', 'F#3/Gb3', 'D3', 'rest', 'G#4/Ab4', 'D5', 'C#3/Db3', 'B4', 'C#2/Db2', 'D#5/Eb5', 'C5', 'G#3/Ab3', 'D#4/Eb4', 'A4', 'E3', 'F3', 'C#4/Db4', 'F#5/Gb5', 'A#3/Bb3', 'E4', 'A#4/Bb4', 'C#5/Db5', 'E5', 'F5', 'A3', 'G3', 'B3', 'D4', 'D#3/Eb3'}


举一个示例

In [10]:
# Take one raw transcription line as the running example.
line = lines[5]
line

'2001000006|漂浮在一片无奈|p iao f u z ai ai ai AP SP y i i p ian ian ian w u n ai SP AP|E4 E4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 rest rest E4 E4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 E4 E4 F#4/Gb4 F#4/Gb4 rest rest|0.185230 0.185230 0.177410 0.177410 0.193930 0.193930 0.259670 0.299340 0.215550 0.031770 0.197520 0.197520 0.165450 0.184760 0.184760 0.212290 0.246960 0.440370 0.440370 1.524950 1.524950 0.855830 0.559100|0.06011 0.12512 0.07517 0.10224 0.08603 0.1079 0.25967 0.29934 0.21555 0.03177 0.05175 0.14577 0.16545 0.0748 0.10996 0.21229 0.24696 0.09617 0.3442 0.1437 1.38125 0.85583 0.5591|0 0 0 0 0 0 1 1 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0'

In [11]:
# Parse the example and print each field with its length.
# NOTE: `id` shadows the builtin; the next audio cell reads this value.
id, text, phoneme, note, note_duration, phoneme_duration, slur_note = parser_line(line)
print_all([id, text, phoneme, note, note_duration, phoneme_duration, slur_note])

10 2001000006
7 漂浮在一片无奈
23 ['p', 'iao', 'f', 'u', 'z', 'ai', 'ai', 'ai', 'AP', 'SP', 'y', 'i', 'i', 'p', 'ian', 'ian', 'ian', 'w', 'u', 'n', 'ai', 'SP', 'AP']
23 ['E4', 'E4', 'F#4/Gb4', 'F#4/Gb4', 'G#4/Ab4', 'G#4/Ab4', 'A4', 'G#4/Ab4', 'rest', 'rest', 'E4', 'E4', 'F#4/Gb4', 'G#4/Ab4', 'G#4/Ab4', 'A4', 'G#4/Ab4', 'E4', 'E4', 'F#4/Gb4', 'F#4/Gb4', 'rest', 'rest']
23 [0.18523, 0.18523, 0.17741, 0.17741, 0.19393, 0.19393, 0.25967, 0.29934, 0.21555, 0.03177, 0.19752, 0.19752, 0.16545, 0.18476, 0.18476, 0.21229, 0.24696, 0.44037, 0.44037, 1.52495, 1.52495, 0.85583, 0.5591]
23 [0.06011, 0.12512, 0.07517, 0.10224, 0.08603, 0.1079, 0.25967, 0.29934, 0.21555, 0.03177, 0.05175, 0.14577, 0.16545, 0.0748, 0.10996, 0.21229, 0.24696, 0.09617, 0.3442, 0.1437, 1.38125, 0.85583, 0.5591]
23 [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]


音频文件读取测试

In [12]:
# Load the example clip (resampled to 16 kHz if its native rate differs) and check its shape.
waveform = get_audio(id, path)
waveform.shape

torch.Size([1, 92003])

将声母（辅音）和韵母（元音）组合为单个汉字

In [13]:
# Finals table from the helper. It spells 'uei'/'iou'/'uen', while the data uses 'ui'/'iu'/'un'.
print(get_final_table())

['i', 'u', 'v', 'a', 'ia', 'ua', 'o', 'uo', 'e', 'ie', 've', 'ai', 'uai', 'ei', 'uei', 'ao', 'iao', 'ou', 'iou', 'an', 'ian', 'uan', 'van', 'en', 'in', 'uen', 'vn', 'ang', 'iang', 'uang', 'eng', 'ing', 'ueng', 'ong', 'iong', 'er', 'ê']


In [17]:
def merge_note(text, phoneme, note, note_duration, slur_note, initials=None):
    """Group per-phoneme annotations into one entry per note.

    Pass 1 labels every phoneme with the Chinese character it belongs to:
    - A new character starts at an initial, or at a final that differs from
      the preceding non-initial phoneme.
    - A repeated final (a held note) is labelled '~'.
    - 'AP'/'SP' markers keep their own label.
    - If the phonemes imply more characters than ``text`` has, the extra
      ones are labelled '/'.

    Pass 2 merges an initial with the following phoneme when both share the
    same note duration. The duplicated entry is dropped from ``note``,
    ``note_duration``, ``slur_note`` and the character labels.

    Args:
        text: Lyric string, one character per syllable.
        phoneme, note, note_duration, slur_note: Parallel per-phoneme
            sequences as produced by ``parser_line``. They are copied and
            never modified.
        initials: Collection of initial consonants. Defaults to
            ``get_initial_table()``; pass it explicitly to avoid rebuilding
            the table on every call.

    Returns:
        ``(text_with_p, phoneme, note, note_duration, slur_note)`` lists of
        equal length. Every ``phoneme`` entry is a list: ``[initial, final]``
        for a merged pair, otherwise a one-element list such as ``['ai']`` or
        ``['SP']``.
    """
    if initials is None:
        initials = get_initial_table()
    phoneme = list(phoneme)
    note = list(note)
    note_duration = list(note_duration)
    slur_note = list(slur_note)

    # Pass 1: character label for each phoneme.
    text_with_p = list(phoneme)
    j = -1             # index into `text` of the current character
    used_flag = False  # True once the current character has been emitted
    for i, p in enumerate(phoneme):
        if p in ('AP', 'SP'):
            continue
        # `j == -1` short-circuits, so phoneme[i - 1] is never read at i == 0.
        if j == -1 or p in initials or (phoneme[i - 1] not in initials and p != phoneme[i - 1]):
            j += 1
            used_flag = False
        if used_flag:
            text_with_p[i] = '~'
        else:
            # Characters beyond the end of `text` get '/' instead of raising IndexError.
            text_with_p[i] = text[j] if j < len(text) else '/'
        used_flag = True

    # Pass 2: walk backwards so deleting index i never shifts an index still to visit.
    for i in range(len(phoneme) - 1, 0, -1):
        if note_duration[i] == note_duration[i - 1] and phoneme[i - 1] in initials:
            phoneme[i - 1] = [phoneme[i - 1], phoneme[i]]
            for seq in (phoneme, note, note_duration, text_with_p, slur_note):
                del seq[i]

    # Wrap every remaining bare string so all entries share one type. The loop
    # above never visits index 0, and the finals table lacks the 'ui'/'iu'/'un'
    # spellings used in the data.
    phoneme = [p if isinstance(p, list) else [p] for p in phoneme]
    return text_with_p, phoneme, note, note_duration, slur_note

In [15]:
def merge_note_old(text, phoneme, note, note_duration):
    """Older variant of ``merge_note`` that does not handle ``slur_note``.

    The visible notebook does not call it; it appears only in a commented-out
    line. It differs from ``merge_note`` in two ways:
    - Initials come from the notebook-level ``initial_table`` global.
    - A merged initial+final is stored as one concatenated string
      (``'p' + 'iao' -> 'piao'``) instead of a list.

    Args:
        text: Lyric string, one character per syllable.
        phoneme, note, note_duration: Parallel per-phoneme lists (copied, not
            modified).

    Returns:
        ``(text_with_p, phoneme, note, note_duration)``. Each initial whose
        note duration equals that of the following phoneme is collapsed into
        one entry with it. ``text_with_p`` holds the character for the first
        phoneme of each character and '~' for held continuations.
    """
    # remove the duplicate items in phoneme, note, and note_duration
    # use text to verify the length
    phoneme = phoneme.copy()
    note = note.copy()
    note_duration = note_duration.copy()
    j = -1
    # Pad so text[j] stays in range (yielding '/') when the phonemes imply a
    # few more characters than `text` has; more than 20 extra still raises IndexError.
    text+='////////////////////'
    text_with_p = phoneme.copy()
    used_flag = False
    for i in range(len(text_with_p)):
        if text_with_p[i] in ['AP', 'SP']:
            continue
        # New character: first one, an initial, or a final that differs from
        # the preceding non-initial phoneme (a repeated final is a held note).
        if j==-1 or phoneme[i] in initial_table or (phoneme[i-1] not in initial_table and phoneme[i] != phoneme[i-1]):
            j+=1
            used_flag = False
        text_with_p[i] = text[j] if used_flag == False else '~'
        used_flag = True
    # Walk backwards so deleting index i does not shift indices still to visit.
    for i in range(len(phoneme)-1, 0, -1):
        if (note_duration[i] == note_duration[i-1] and phoneme[i-1] in initial_table):
            del note_duration[i]
            del note[i]
            phoneme[i-1]=phoneme[i-1]+phoneme[i]
            del phoneme[i]
            del text_with_p[i]
    return text_with_p, phoneme, note, note_duration

In [18]:
# Merge the example: one entry per note, with '~' marking a held continuation of the previous character.
print_all(merge_note(text, phoneme, note, note_duration, slur_note))

16 ['漂', '浮', '在', '~', '~', 'AP', 'SP', '一', '~', '片', '~', '~', '无', '奈', 'SP', 'AP']
16 [['p', 'iao'], ['f', 'u'], ['z', 'ai'], ['ai'], ['ai'], ['AP'], ['SP'], ['y', 'i'], ['i'], ['p', 'ian'], ['ian'], ['ian'], ['w', 'u'], ['n', 'ai'], ['SP'], ['AP']]
16 ['E4', 'F#4/Gb4', 'G#4/Ab4', 'A4', 'G#4/Ab4', 'rest', 'rest', 'E4', 'F#4/Gb4', 'G#4/Ab4', 'A4', 'G#4/Ab4', 'E4', 'F#4/Gb4', 'rest', 'rest']
16 [0.18523, 0.17741, 0.19393, 0.25967, 0.29934, 0.21555, 0.03177, 0.19752, 0.16545, 0.18476, 0.21229, 0.24696, 0.44037, 1.52495, 0.85583, 0.5591]
16 [0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0]


In [19]:
from dataset import SpeechDataset
from torch.utils.data import DataLoader, Dataset, random_split

In [21]:
class OpencpopDataset(SpeechDataset):
    """Opencpop singing segments: one raw transcription line plus its waveform per item.

    Reads ``<data_path>transcriptions.txt``, which has one
    ``id|text|phoneme|note|note_duration|phoneme_duration|slur_note`` line per
    segment, and loads audio from ``<data_path>wavs/<id>.wav``. Paths are
    built by string concatenation, so ``data_path`` must end with a separator.

    Args:
        data_path: Segment directory.
        sample_rate: Target sample rate; audio is resampled when it differs.
        transform: Optional callable ``transform(sample, sample_rate)``
            applied to every sample dict.
    """

    def __init__(self, data_path, sample_rate=16000, transform=None):
        # The parent presumably stores data_path / sample_rate / transform as
        # attributes (they are read below but never assigned here). Confirm in
        # dataset.SpeechDataset.
        super().__init__(data_path, sample_rate, transform)
        transcript_file = data_path + 'transcriptions.txt'
        self.transcript = self.gen_transcript(transcript_file)
        self.dataset_file_num = len(self.transcript)
        # Max clip length in samples. MusicLoaderGenerator.batch_filter drops
        # longer clips to avoid running out of GPU memory.
        self.threshold = 120000
        self.batch_size = 80  # to avoid GPU memory used out
        self.split_ratio = [1000, 3]  # default train:test proportions used by split()

    def __len__(self):
        return self.dataset_file_num

    def __getitem__(self, idx):
        """Return ``{'audio': waveform, 'text': raw_line}``, passed through ``transform`` if set.

        Indices ``>= len(self)`` return ``{'audio': None, 'text': None}``
        without applying the transform.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if idx >= self.dataset_file_num:
            # NOTE(review): a sentinel instead of IndexError means plain
            # `for s in dataset` iteration (which stops on IndexError) never
            # terminates, unless SpeechDataset defines __iter__.
            return {'audio': None, 'text': None}
        line = self.transcript[idx]
        # parser_line also validates the line; only the id is needed here.
        utt_id = self.parser_line(line)[0]
        waveform = self.get_audio(utt_id)
        sample = {'audio': waveform, 'text': line}
        if self.transform:
            sample = self.transform(sample, self.sample_rate)
        return sample

    def get_audio(self, id):
        """Load ``wavs/<id>.wav`` as a ``(1, time)`` waveform at ``self.sample_rate``.

        Only the first channel is kept, in both branches, so the output shape
        does not depend on whether resampling was needed.
        """
        wav_path = self.data_path + 'wavs/' + str(id) + '.wav'
        waveform, sample_rate = torchaudio.load(wav_path)
        waveform = waveform[:1]
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.sample_rate)
        return waveform

    def parser_line(self, line):
        """Split one transcription line into
        ``(id, text, phoneme, note, note_duration, phoneme_duration, slur_note)``.

        Raises:
            ValueError: wrong field count or an unparsable number.
            AssertionError: per-phoneme sequences of different lengths.
        """
        utt_id, text, phoneme, note, note_duration, phoneme_duration, slur_note = line.split('|')
        phoneme = phoneme.split(' ')
        note = note.split(' ')
        note_duration = [float(v) for v in note_duration.split(' ')]
        phoneme_duration = [float(v) for v in phoneme_duration.split(' ')]
        slur_note = [int(v) for v in slur_note.split(' ')]
        # Every per-phoneme sequence must line up, including `note`.
        lengths = {len(phoneme), len(note), len(note_duration), len(phoneme_duration), len(slur_note)}
        assert len(lengths) == 1, f'per-phoneme fields have different lengths in line {utt_id!r}'
        return utt_id, text, phoneme, note, note_duration, phoneme_duration, slur_note

    def gen_transcript(self, transcript_file):
        """Read the transcript (UTF-8, Chinese lyrics) into a list of lines.

        A single trailing empty line is dropped.
        """
        with open(transcript_file, encoding='utf-8') as f:
            lines = f.read().split('\n')
        if lines[-1] == '':
            lines = lines[:-1]
        return lines

    def split(self, split_ratio=None, seed=42):
        """Randomly split into subsets with sizes proportional to ``split_ratio``.

        The last subset takes the rounding remainder. ``seed`` makes the split
        reproducible.
        """
        size = len(self)
        ratio = self.split_ratio if split_ratio is None else split_ratio
        total = sum(ratio)
        lengths = [(r * size) // total for r in ratio]
        lengths[-1] = size - sum(lengths[:-1])
        return random_split(self, lengths, generator=torch.Generator().manual_seed(seed))

In [22]:

def dataset_transform(sample, sample_rate=None):
    """Enrich a raw ``{'audio', 'text'}`` sample with per-character fields.

    Parses ``sample['text']`` (a raw transcription line) and runs
    ``merge_note`` on it. The results are stored under 'chinese', 'phoneme',
    'note', 'duration' and 'slur'. The dict is updated in place and returned.
    ``sample_rate`` matches the ``transform(sample, sample_rate)`` call
    convention and is unused.
    """
    fields = parser_line(sample['text'])
    merged = merge_note(fields[1], fields[2], fields[3], fields[4], fields[6])
    for key, value in zip(('chinese', 'phoneme', 'note', 'duration', 'slur'), merged):
        sample[key] = value
    return sample

# Full dataset whose samples are enriched by dataset_transform.
# NOTE(review): absolute, user-specific path, and a different directory from
# `path` above. Consider a single configurable data-dir constant.
dataset = OpencpopDataset('/scratch/bh2283/data/opencpop/segments/', transform=dataset_transform)

In [24]:
# Inspect one transformed sample: audio tensor, raw line, then the merged per-character fields.
print_all(dataset[18].values())

1 tensor([[-1.8905e-03, -2.8823e-03, -2.3118e-03,  ...,  5.2431e-05,
          2.1897e-04,  1.8222e-05]])
1119 2001000019|宇宙磅礡而冷漠我们的爱微小却闪烁|y v zh ou b ang ang b o er l eng eng m o o o AP w o m en d e ai w ei x iao q ve sh an sh ou uo uo AP SP|C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 C#4/Db4 C#4/Db4 D#4/Eb4 E4 E4 D#4/Eb4 E4 E4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 rest C#4/Db4 C#4/Db4 C#4/Db4 C#4/Db4 D#4/Eb4 D#4/Eb4 C#4/Db4 D#4/Eb4 D#4/Eb4 E4 E4 D#4/Eb4 D#4/Eb4 E4 E4 G#4/Ab4 G#4/Ab4 A4 G#4/Ab4 rest rest|0.194490 0.194490 0.191880 0.191880 0.219800 0.219800 0.138290 0.170840 0.170840 0.204960 0.131260 0.131260 0.219430 0.183230 0.183230 0.197770 0.379730 0.380810 0.203550 0.203550 0.165270 0.165270 0.141470 0.141470 0.159550 0.111580 0.111580 0.246980 0.246980 0.126240 0.126240 0.329620 0.329620 0.306020 0.306020 0.171160 0.302490 0.236470 0.095710|0.05355 0.14094 0.06248 0.1294 0.06762 0.15218 0.13829 0.06402 0.10682 0.20496 0.0324 0.09886 0.21943 0.07985 0.10338 0.19777 0.37973 0.38081 0.07355 0.1

In [26]:
# Reproducible split (seed 42 by default) using the dataset's default 1000:3 train/test ratio.
train_set, test_set = dataset.split()
len(train_set), len(test_set)

(3744, 12)

单个字测试

In [28]:
def play(waveform, sample_rate=16000, path='./audio-temp.wav'):
    """Save ``waveform`` to a WAV file so it can be listened to in an audio player.

    Despite the name, nothing is played; the clip is only written to disk and
    the file is overwritten on every call.

    Args:
        waveform: 1-D ``(time,)`` tensor, or 2-D ``(channels, time)`` tensor.
        sample_rate: Rate written to the file header. The default of 16 kHz
            matches the loading rate used in this notebook.
        path: Output file path.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)  # torchaudio.save expects (channels, time)
    torchaudio.save(path, waveform, sample_rate)

In [54]:
def test_one_char(data, idx, sample_rate = 16000):
    """Cut out the audio of one merged character and save it via ``play``.

    Prints the full audio shape, then the character, its duration and its
    joined phonemes. Finally it hands the clip to ``play``.

    Args:
        data: A sample produced by ``dataset_transform`` with keys 'audio',
            'chinese', 'duration' and 'phoneme'. 'audio' is indexed as
            ``[channel, time]``.
        idx: Index into the merged per-character lists.
        sample_rate: Sample rate of ``data['audio']``, used to turn
            cumulative durations (seconds) into sample offsets.
    """
    import itertools

    assert 0 <= idx < len(data['chinese'])
    chinese = data['chinese'][idx]
    time = data['duration'][idx]
    # Running end time of every entry, computed into a new list. Summing
    # data['duration'] in place would corrupt the caller's sample.
    ends = list(itertools.accumulate(data['duration']))
    start = 0 if idx == 0 else int(ends[idx - 1] * sample_rate)
    end = int(ends[idx] * sample_rate)
    print(data['audio'].shape)
    waveform = data['audio'][0, start:end]
    print(chinese, time, ''.join(data['phoneme'][idx]))
    play(waveform)

# Listening check on one training sample: the character at merged index 13.
test_one_char(train_set[2], 13)

torch.Size([1, 107614])
也 0.42954 ye


靠耳朵听，发现大部分的标注是准确的，也有小部分划分有点出入。
文本注音也有些问题，但是问题不大。

## Dataloader

这里我们制作一个以单个字为单位的dataloader，包含拼音、音调、时长（量化后）和延音信息。

之后使用lookup embedding制作decoder的输入embedding。注意这里可以利用前一个字和后一个字来辅助生成更好的声音；如果这样做，前一个字、本字、后一个字都需要使用不同的lookup table。

先做一个naive版本，不考虑前后字。即使如此，我们的dataloader仍然需要提供所有信息。

In [None]:
class MusicLoaderGenerator:
    """Build ``DataLoader``s whose collate function pads audio and label ids.

    Args:
        labels: Sequence of label symbols; a symbol's position is its id.
        k_size: Encoder kernel size, used for data augmentation. Each clip is
            shifted left by a random offset in ``[0, k_size)``; 0 or 1
            disables the shift.
        num_workers: Passed through to ``DataLoader``.
        device: Stored on the instance; not used by the methods below.
    """

    def __init__(self,
        labels, k_size=0,
        num_workers=0,
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        ) -> None:
        self.k_size = k_size
        self.labels = labels
        self.look_up = {s: i for i, s in enumerate(labels)}
        self.device = device
        self.num_workers = num_workers
        self.version = '0.02'
        # Max audio length (samples) kept by batch_filter. It is set by
        # dataloader(); None disables filtering.
        self.threshold = None

    def label2id(self, str):
        """Map each symbol of ``str`` to its id (KeyError for unknown symbols).

        The parameter name shadows the builtin; it is kept for keyword callers.
        """
        return [self.look_up[i] for i in str]

    def id2label(self, idcs):
        """Join the symbols for the ids in ``idcs`` into one string."""
        return ''.join(self.labels[i] for i in idcs)

    def batch_filter(self, batch:list):
        """Remove, in place, every sample whose audio is longer than ``self.threshold``."""
        if self.threshold is None:
            return batch
        for i in range(len(batch) - 1, -1, -1):
            if batch[i]['audio'].shape[-1] > self.threshold:
                del batch[i]
        return batch

    def collate_wrapper(self, batch:list): # RAW
        """Collate samples into padded tensors, ordered longest target first.

        Each sample needs:
        - 'audio': a ``(1, T)`` waveform (assumed mono).
        - 'text': symbols mapped through ``label2id``.
        - 'chinese': passed through unchanged.

        NOTE(review): with OpencpopDataset, 'text' is the raw transcription
        line (``id|text|phoneme|...``), so ``labels`` must cover every
        character in it. Confirm this is the intended target.

        Returns:
            Dict with keys:
            - 'audio': ``(B, T_max)``, zero-padded on the right.
            - 'audio_len': ``(B,)``, lengths after the random shift.
            - 'target': ``(B, L_max)`` int64, zero-padded.
            - 'target_len': ``(B,)``.
            - 'chinese': tuple, in the same order as the other fields.
        """
        batch = self.batch_filter(batch)
        bs = len(batch)
        if bs == 0:
            # Everything was filtered out: return empty but well-typed tensors.
            return {'audio': torch.zeros((0, 0)), 'audio_len': torch.zeros(0, dtype=torch.long),
                    'target': torch.zeros((0, 0), dtype=torch.long),
                    'target_len': torch.zeros(0, dtype=torch.long),
                    'chinese': ()}
        # randint needs high >= 1; k_size 0 (the default) then means "no shift".
        rand_shift = torch.randint(max(self.k_size, 1), (bs,)).tolist()
        audio_list = [item['audio'][:, shift:] for item, shift in zip(batch, rand_shift)]
        audio_length = [audio.shape[-1] for audio in audio_list]
        target_list = [self.label2id(item['text']) for item in batch]
        target_length = [len(t) for t in target_list]
        chinese_list = [item['chinese'] for item in batch]

        # Descending by (target_length, target, audio_length). The sort is by
        # key, so full ties keep batch order instead of falling through to
        # comparing audio tensors, which raises "Boolean value of Tensor ...
        # is ambiguous".
        order = sorted(range(bs),
                       key=lambda k: (target_length[k], target_list[k], audio_length[k]),
                       reverse=True)

        max_audio_length = max(audio_length)
        audio = torch.cat([
            torch.nn.functional.pad(audio_list[k], (0, max_audio_length - audio_length[k]))
            for k in order], 0)
        target = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(target_list[k], dtype=torch.long) for k in order],
            batch_first=True, padding_value=0)
        return {'audio': audio,
                'audio_len': torch.tensor([audio_length[k] for k in order]),
                'target': target,
                'target_len': torch.tensor([target_length[k] for k in order]),
                'chinese': tuple(chinese_list[k] for k in order)}

    def dataloader(self, audioDataset, batch_size, shuffle=True):
        """Return a ``DataLoader`` over ``audioDataset`` that uses ``collate_wrapper``.

        The length threshold is read from the dataset. For a ``Subset`` (as
        returned by ``random_split``) it comes from the wrapped dataset. If
        there is no ``threshold`` attribute, filtering is disabled.
        """
        # k_size is the kernel size for the encoder, for data augmentation
        base = getattr(audioDataset, 'dataset', audioDataset)
        self.threshold = getattr(base, 'threshold', None)
        return DataLoader(audioDataset, batch_size, shuffle=shuffle,
                          num_workers=self.num_workers, collate_fn=self.collate_wrapper)


In [None]:
# Smoke test of the loader: print shapes and lengths of the first 10 training batches.
# NOTE(review): `labels` (the symbol vocabulary used by label2id) is not
# defined anywhere in the visible notebook. Define it before running this cell.
loaderGenerator = MusicLoaderGenerator(labels, k_size=5)
train_loader = loaderGenerator.dataloader(train_set, batch_size=8)
print('train_set:', len(train_set), 'test_set:', len(test_set))
steps = 10
for i_batch, sample_batched in enumerate(train_loader):
    if steps <= 0:
        break
    print(sample_batched['audio'].shape, sample_batched['target'].shape)
    print(sample_batched['audio_len'], sample_batched['target_len'])
    steps -= 1