Add support for Thai #117

Open. Wants to merge 29 commits into base: main.

Changes from all commits (29 commits)
ffd8f41  Add support for Thai (Apr 29, 2024)
c960e81  Add Thai Wiktionary Corpus (Apr 29, 2024)
80460cc  Fix issues with g2p function (Apr 29, 2024)
4d1fd9c  Add better test cases (Apr 29, 2024)
6d7f2ac  Remove debug line (Apr 29, 2024)
aff4b04  More fixes (Apr 29, 2024)
5cd7da1  Fix normalize and dictionary (Apr 29, 2024)
4484651  Update expected format of test_word_tokenize test case (Apr 29, 2024)
17cafba  Fix format of test_thai_text_to_phonemes test case (Apr 30, 2024)
1c74eb7  Add test to compare with Korean (Apr 30, 2024)
11c5a75  Clean up import statements (Apr 30, 2024)
d8a5e13  Add missing code for training Thai language and Thai Grapheme to Phon… (May 8, 2024)
28dd03e  Initialize the TextEncoder with fixed size (May 9, 2024)
809a043  Add get_resized_embeddings to SynthesizerTrn class (May 9, 2024)
818e4b4  Use self.enc_p.emb (May 9, 2024)
b25a99d  Remove th_symbols from sorted list (May 9, 2024)
11c55ef  Fix unsupported characters in g2p function (May 10, 2024)
3004182  Fix tones list from g2p function being initialized to zeroes and adjus… (May 16, 2024)
5b1faee  Assign multiple phones based on the number of phones (May 16, 2024)
c1870d9  Fix regression (May 16, 2024)
adb9232  Fix tones mapping (May 18, 2024)
17a3bd7  Add new test case (May 18, 2024)
5c19c03  Change format of word2ph to include count (May 19, 2024)
a65cc2f  Keep special characters in word2ph to be consistent with other languages (May 20, 2024)
af38d49  Update test case for Thai bert to match format without special cases (May 20, 2024)
2be87f2  Fix dictionary lookup and use bert tokenizer in g2p (May 28, 2024)
120bcfd  Fix tone issues and add warning for mismatch due to _ underscore char… (May 28, 2024)
1251906  Squash 2 nasty bugs: 1.) Assign tone value to each phoneme excluding … (May 28, 2024)
d4c9124  Add gradient clipping, process underscores to align with bert feature… (Jun 5, 2024)
melo/data_utils.py (1 addition, 1 deletion)
@@ -173,7 +173,7 @@ def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
         if language_str in ["ZH"]:
             bert = bert
             ja_bert = torch.zeros(768, len(phone))
-        elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", 'SP', 'ES', 'FR', 'DE', 'RU']:
+        elif language_str in ["JP", "EN", "ZH_MIX_EN", "KR", 'SP', 'ES', 'FR', 'DE', 'RU', 'TH']:
             ja_bert = bert
             bert = torch.zeros(1024, len(phone))
         else:
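For context, this one-line change routes 'TH' through the same feature slots as the other non-Chinese languages: the language-specific BERT features fill ja_bert, and the 1024-dim Chinese slot is zero-filled. A minimal sketch of the resulting tensor shapes, assuming a hypothetical 10-phone Thai utterance and the 768-dim features implied by the branch above:

    import torch

    # Hypothetical Thai example: 10 phones, 768-dim BERT features per phone.
    phone = list(range(10))
    th_bert = torch.randn(768, len(phone))

    ja_bert = th_bert                      # Thai BERT features fill the 768-dim slot
    bert = torch.zeros(1024, len(phone))   # Chinese-only slot stays zero-padded
    assert ja_bert.shape == (768, 10) and bert.shape == (1024, 10)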
melo/models.py (36 additions, 2 deletions)
@@ -830,6 +830,23 @@ def __init__(
             num_languages=num_languages,
             num_tones=num_tones,
         )
+        self.enc_p = TextEncoder(
+            219,  # Initialize with the original symbol size
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            gin_channels=self.enc_gin_channels,
+            num_languages=num_languages,
+            num_tones=num_tones,
+        )
+        if n_vocab != 219:
+            old_embeddings = self.enc_p.emb
+            new_num_tokens = n_vocab
+            self.enc_p.emb = self.get_resized_embeddings(old_embeddings, new_num_tokens)
         self.dec = Generator(
             inter_channels,
             resblock,
@@ -884,6 +901,23 @@ def __init__(
         self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
         self.use_vc = use_vc

+    def get_resized_embeddings(self, old_embeddings, new_num_tokens):
+        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        if old_num_tokens == new_num_tokens:
+            return old_embeddings
+
+        if not isinstance(old_embeddings, nn.Embedding):
+            raise TypeError(
+                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. "
+                f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Embedding}."
+            )
+
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(
+            device=old_embeddings.weight.device, dtype=old_embeddings.weight.dtype
+        )
+        new_embeddings.weight.data[:old_num_tokens, :] = old_embeddings.weight.data[:old_num_tokens, :]
+
+        return new_embeddings

     def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
         if self.n_speakers > 0:
@@ -998,7 +1032,7 @@ def infer(
             sdp_ratio
         ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
         w = torch.exp(logw) * x_mask * length_scale

         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
@@ -1020,7 +1054,7 @@ def infer(
         # print('max/min of o:', o.max(), o.min())
         return o, attn, y_mask, (z, z_p, m_p, logs_p)

-    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
+    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
         g_src = sid_src
         g_tgt = sid_tgt
         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
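The two hunks above build the text encoder with the original 219-symbol table and then swap in a resized embedding, so pretrained checkpoints still line up with the enlarged Thai symbol set. A standalone sketch of the same resize idea (the 219 comes from the diff above; the 192-dim embeddings and the 3 extra symbols are hypothetical):

    import torch
    import torch.nn as nn

    def resize_embeddings(old_emb: nn.Embedding, new_num_tokens: int) -> nn.Embedding:
        # Copy the pretrained rows into the prefix of a larger table;
        # rows for newly added symbols keep their fresh random init.
        old_num_tokens, dim = old_emb.weight.size()
        if old_num_tokens == new_num_tokens:
            return old_emb
        new_emb = nn.Embedding(new_num_tokens, dim).to(
            device=old_emb.weight.device, dtype=old_emb.weight.dtype
        )
        new_emb.weight.data[:old_num_tokens, :] = old_emb.weight.data
        return new_emb

    emb = resize_embeddings(nn.Embedding(219, 192), 222)
    assert emb.weight.shape == (222, 192)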
melo/text/__init__.py (3 additions, 2 deletions)
@@ -28,8 +28,9 @@ def get_bert(norm_text, word2ph, language, device):
     from .spanish_bert import get_bert_feature as sp_bert
     from .french_bert import get_bert_feature as fr_bert
     from .korean import get_bert_feature as kr_bert
+    from .thai import get_bert_feature as th_bert

-    lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
-                          'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert}
+    lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert, 'ZH_MIX_EN': zh_mix_en_bert,
+                          'FR': fr_bert, 'SP': sp_bert, 'ES': sp_bert, "KR": kr_bert, "TH": th_bert}
     bert = lang_bert_func_map[language](norm_text, word2ph, device)
     return bert
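With the "TH" entry in the dispatch map, Thai BERT features are fetched through the same get_bert call as every other language. A minimal usage sketch (the Thai sentence and the word2ph counts are hypothetical placeholders):

    from melo.text import get_bert

    norm_text = "สวัสดีครับ"       # hypothetical normalized Thai text
    word2ph = [1, 6, 4, 1]          # hypothetical phones-per-token counts
    bert = get_bert(norm_text, word2ph, "TH", device="cpu")
    # bert: per-phone BERT features, laid out as (feature_dim, n_phones)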
melo/text/cleaner.py (5 additions, 5 deletions)
@@ -1,9 +1,9 @@
-from . import chinese, japanese, english, chinese_mix, korean, french, spanish
+from . import chinese, japanese, english, chinese_mix, korean, french, spanish, thai
 from . import cleaned_text_to_sequence
 import copy

 language_module_map = {"ZH": chinese, "JP": japanese, "EN": english, 'ZH_MIX_EN': chinese_mix, 'KR': korean,
-                    'FR': french, 'SP': spanish, 'ES': spanish}
+                    'FR': french, 'SP': spanish, 'ES': spanish, 'TH': thai}


 def clean_text(text, language):
@@ -17,13 +17,13 @@ def clean_text_bert(text, language, device=None):
     language_module = language_module_map[language]
     norm_text = language_module.text_normalize(text)
     phones, tones, word2ph = language_module.g2p(norm_text)

     word2ph_bak = copy.deepcopy(word2ph)
     for i in range(len(word2ph)):
         word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
     bert = language_module.get_bert_feature(norm_text, word2ph, device=device)

     return norm_text, phones, tones, word2ph_bak, bert


@@ -33,4 +33,4 @@ def text_to_sequence(text, language):


 if __name__ == "__main__":
-    pass
+    pass
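With 'TH' registered in language_module_map, the usual cleaning entry point covers Thai as well; a minimal sketch, assuming clean_text returns (norm_text, phones, tones, word2ph) as it does for the other languages, with a hypothetical input sentence:

    from melo.text.cleaner import clean_text

    norm_text, phones, tones, word2ph = clean_text("สวัสดีครับ", "TH")
    print(phones)   # phoneme symbols from the Thai g2p
    print(tones)    # one tone id per phoneme
    print(word2ph)  # phones-per-token counts (clean_text_bert doubles these before BERT extraction)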
melo/text/english_bert.py (24 additions, 14 deletions)
@@ -2,12 +2,13 @@
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import sys

-model_id = 'bert-base-uncased'
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = None
+models = {}
+tokenizers = {}

-def get_bert_feature(text, word2ph, device=None):
+def get_bert_feature(text, word2ph, device=None, model_id='airesearch/wangchanberta-base-att-spm-uncased'):
     global model
+    global tokenizer

     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -16,24 +17,33 @@ def get_bert_feature(text, word2ph, device=None):
         device = "mps"
     if not device:
         device = "cuda"
-    if model is None:

+    if model_id not in models:
         model = AutoModelForMaskedLM.from_pretrained(model_id).to(
             device
         )
+        models[model_id] = model
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizers[model_id] = tokenizer
+    else:
+        model = models[model_id]
+        tokenizer = tokenizers[model_id]

     with torch.no_grad():
         inputs = tokenizer(text, return_tensors="pt")
+        import pdb; pdb.set_trace();
         for i in inputs:
             inputs[i] = inputs[i].to(device)
         res = model(**inputs, output_hidden_states=True)
         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
-    assert inputs["input_ids"].shape[-1] == len(word2ph)
-    word2phone = word2ph
-    phone_level_feature = []
-    for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
-        phone_level_feature.append(repeat_feature)

-    phone_level_feature = torch.cat(phone_level_feature, dim=0)

+    assert inputs["input_ids"].shape[-1] == len(word2ph), f"{inputs['input_ids'].shape[-1]}/{len(word2ph)}"

+    word2phone = word2ph[1:-1]
+    phone_level_feature = []
+    for i in range(len(word2phone)):
+        repeat_feature = res[i].repeat(word2phone[i], 1)
+        phone_level_feature.append(repeat_feature)
+    phone_level_feature = torch.cat(phone_level_feature, dim=0)

     return phone_level_feature.T
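The rewritten tail of get_bert_feature expands token-level hidden states to phone level by repeating each token's vector word2ph[i] times, after slicing word2ph[1:-1], presumably to drop the special-token entries. A standalone sketch of that expansion with hypothetical sizes:

    import torch

    res = torch.randn(4, 768)      # hypothetical token-level hidden states
    word2phone = [2, 3, 1, 4]      # hypothetical phones per token

    phone_level_feature = []
    for i in range(len(word2phone)):
        # Repeat token i's vector once per phone aligned to it.
        phone_level_feature.append(res[i].repeat(word2phone[i], 1))
    phone_level_feature = torch.cat(phone_level_feature, dim=0)

    assert phone_level_feature.shape == (sum(word2phone), 768)   # (10, 768)
    print(phone_level_feature.T.shape)                           # (768, 10), the returned layout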