XTTS: add inference_stream_text (slightly friendlier for text-streaming)
czuzu committed May 8, 2024
1 parent dbf1a08 commit 57a47d2
Showing 2 changed files with 170 additions and 58 deletions.
186 changes: 132 additions & 54 deletions TTS/tts/models/xtts.py
@@ -209,6 +209,8 @@ def __init__(self, config: Coqpit):
self.decoder_checkpoint = self.args.decoder_checkpoint # TODO: check if this is even needed
self.models_dir = config.model_dir
self.gpt_batch_size = self.args.gpt_batch_size
self._stream_text_holder = []
self._stream_generator = None

self.tokenizer = VoiceBpeTokenizer()
self.gpt = None
@@ -632,64 +634,140 @@ def inference_stream(
length_scale = 1.0 / max(speed, 0.05)
gpt_cond_latent = gpt_cond_latent.to(self.device)
speaker_embedding = speaker_embedding.to(self.device)
if enable_text_splitting:
text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]
text_streaming = (text is None)

for sent in text:
sent = sent.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
while True:
if text_streaming:
yield None
if len(self._stream_text_holder) == 0:
return
text, enable_text_splitting = self._stream_text_holder

assert (
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
if enable_text_splitting:
text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]

fake_inputs = self.gpt.compute_embeddings(
gpt_cond_latent.to(self.device),
text_tokens,
)
gpt_generator = self.gpt.get_generator(
fake_inputs=fake_inputs,
top_k=top_k,
top_p=top_p,
temperature=temperature,
do_sample=do_sample,
num_beams=1,
num_return_sequences=1,
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
output_attentions=False,
output_hidden_states=True,
**hf_generate_kwargs,
)
for sent in text:
sent = sent.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)

assert (
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."

fake_inputs = self.gpt.compute_embeddings(
gpt_cond_latent.to(self.device),
text_tokens,
)
gpt_generator = self.gpt.get_generator(
fake_inputs=fake_inputs,
top_k=top_k,
top_p=top_p,
temperature=temperature,
do_sample=do_sample,
num_beams=1,
num_return_sequences=1,
length_penalty=float(length_penalty),
repetition_penalty=float(repetition_penalty),
output_attentions=False,
output_hidden_states=True,
**hf_generate_kwargs,
)

last_tokens = []
all_latents = []
wav_gen_prev = None
wav_overlap = None
is_end = False

while not is_end:
try:
x, latent = next(gpt_generator)
last_tokens += [x]
all_latents += [latent]
except StopIteration:
is_end = True

if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
last_tokens = []
yield wav_chunk

if not text_streaming:
return

def inference_stream_text(
self,
language,
gpt_cond_latent,
speaker_embedding,
# Streaming
stream_chunk_size=20,
overlap_wav_len=1024,
# GPT inference
temperature=0.75,
length_penalty=1.0,
repetition_penalty=10.0,
top_k=50,
top_p=0.85,
do_sample=True,
speed=1.0,
**hf_generate_kwargs,
):
if self._stream_generator is not None:
raise Exception('Inference text-streaming already in progress. '
'Did you forget to call inference_finalize_text?')

# Arguments `text` and `enable_text_splitting` are passed in through the holder
self._stream_text_holder = [None, None]
self._stream_generator = self.inference_stream(
None,
language,
gpt_cond_latent,
speaker_embedding,
stream_chunk_size=stream_chunk_size,
overlap_wav_len=overlap_wav_len,
temperature=temperature,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
speed=speed,
**hf_generate_kwargs,
)

last_tokens = []
all_latents = []
wav_gen_prev = None
wav_overlap = None
is_end = False

while not is_end:
try:
x, latent = next(gpt_generator)
last_tokens += [x]
all_latents += [latent]
except StopIteration:
is_end = True

if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
)
last_tokens = []
yield wav_chunk
# Start the generator and return it
_ = next(self._stream_generator)
return self._stream_generator

def inference_add_text(self, text: str, enable_text_splitting=False):
if self._stream_generator is None:
raise Exception('Inference text-streaming not started. '
'Please call inference_stream_text first')
self._stream_text_holder[0] = text
self._stream_text_holder[1] = enable_text_splitting

def inference_finalize_text(self):
if self._stream_generator is None:
raise Exception('Inference text-streaming was not started '
'(start with inference_stream_text)')
# Finalize and reset the generator
self._stream_text_holder.clear()
try:
_ = next(self._stream_generator)
except StopIteration:
pass
self._stream_generator = None

def forward(self):
raise NotImplementedError(
42 changes: 38 additions & 4 deletions docs/source/models/xtts.md
@@ -220,7 +220,7 @@ torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
```


##### Streaming manually
##### Streaming inference

Here the goal is to stream the audio as it is being generated. This is useful for real-time applications.
Streaming inference is typically slower than regular inference, but it lets you get the first chunk of audio faster.
@@ -253,16 +253,50 @@ chunks = model.inference_stream(
speaker_embedding
)

wav_chuncks = []
wav_chunks = []
for i, chunk in enumerate(chunks):
if i == 0:
print(f"Time to first chunck: {time.time() - t0}")
print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
wav_chuncks.append(chunk)
wav = torch.cat(wav_chuncks, dim=0)
wav_chunks.append(chunk)
wav = torch.cat(wav_chunks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
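
If you want actual real-time playback rather than saving to a file, you can write each chunk to an audio output as soon as it arrives. Below is a minimal sketch of that pattern; it assumes the third-party `sounddevice` package (not a TTS dependency) and reuses `model`, `gpt_cond_latent` and `speaker_embedding` from the snippet above. XTTS outputs audio at 24 kHz, which is the sample rate used for the stream.

```python
import numpy as np
import sounddevice as sd  # assumption: third-party package, installed separately

# Open a mono 24 kHz output stream and play chunks as they are produced.
with sd.OutputStream(samplerate=24000, channels=1, dtype="float32") as audio_out:
    for chunk in model.inference_stream(
        "It took me quite a long time to develop a voice.",
        "en",
        gpt_cond_latent,
        speaker_embedding,
    ):
        # Chunks are 1-D torch tensors; move them to the CPU before playback.
        audio_out.write(chunk.squeeze().cpu().numpy().astype(np.float32).reshape(-1, 1))
```

Note that `write()` blocks until the device has consumed the data, so playback naturally paces the loop.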

If you also need to do text streaming, you can use `inference_stream_text`, like so:

```python
# ...same setup as before

def text_streaming_generator():
    yield "It took me quite a long time to develop a voice and now that I have it I am not going to be silent."
    yield "Having discovered not just one, but many voices, I will champion each."

print("Inference with text streaming...")

text_gen = text_streaming_generator()
inf_gen = model.inference_stream_text(
    "en",
    gpt_cond_latent,
    speaker_embedding
)

wav_chunks = []
for text in text_gen:
    # Add text progressively
    model.inference_add_text(text, enable_text_splitting=True)
    for chunk in inf_gen:
        if chunk is None:
            break  # all chunks generated for the current text
        print(f"Received chunk {len(wav_chunks)} of audio length {chunk.shape[-1]}")
        wav_chunks.append(chunk)

# Call finalize to discard the inference generator
model.inference_finalize_text()

wav = torch.cat(wav_chunks, dim=0)
torchaudio.save("xtts_streaming_text.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
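
`inference_stream_text` raises if a previous text stream was started but never finalized, so when your text source may fail mid-stream it can be worth guaranteeing the cleanup. The sketch below is just one way to do that (it is not part of the API) and reuses the objects from the example above.

```python
wav_chunks = []
inf_gen = model.inference_stream_text("en", gpt_cond_latent, speaker_embedding)
try:
    for text in text_streaming_generator():
        model.inference_add_text(text, enable_text_splitting=True)
        for chunk in inf_gen:
            if chunk is None:
                break  # no more audio for this piece of text
            wav_chunks.append(chunk)
finally:
    # Always reset the streaming state so a later
    # inference_stream_text call does not raise.
    model.inference_finalize_text()
```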

### Training

