
[Feature request] Can we add batch inference or batch decoding for XTTS? #3776

Open
Onkarsus13 opened this issue Jun 5, 2024 · 3 comments
Labels: feature request (feature requests for making TTS better)

Comments

@Onkarsus13

I tried batch inference with XTTS. I pad every text sequence in the batch to the length of the longest one and pass an attention mask, but for the shorter sequences I get random noise at the end of the audio.
It would be helpful to have this feature in Coqui TTS.
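For reference, the padding scheme described above (pad every sequence to the longest one, plus an attention mask) can be sketched in plain Python as follows. This is a framework-agnostic illustration, not XTTS code: the function name `pad_with_mask` is hypothetical, and the pad id 0 is an assumption (it matches the zero-filled `text_padded` tensor in the snippet further down).

```python
from typing import List, Tuple


def pad_with_mask(seqs: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token sequences to the longest one and build a 1/0 attention mask.

    pad_with_mask is a hypothetical helper; pad_id=0 is an assumed pad value.
    """
    max_len = max(len(s) for s in seqs)
    # Append pad_id until every row has max_len tokens
    padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    # 1 marks a real token, 0 marks padding that attention should ignore
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return padded, mask


padded, mask = pad_with_mask([[5, 6, 7], [8]])
print(padded)  # [[5, 6, 7], [8, 0, 0]]
print(mask)    # [[1, 1, 1], [1, 0, 0]]
```

The mask only tells attention to ignore the padded positions; the model still produces output frames for them, so each shorter sequence's audio generally has to be cut back to its own length afterwards, which is one plausible source of the trailing noise described here.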

Onkarsus13 added the feature request label on Jun 5, 2024.
@tuanh123789

I face the same problem when I run batched inference. Did you solve it?

@Rakshith12-pixel

@Onkarsus13 Were you able to implement batched inference successfully?

@Onkarsus13
Author

Yes, Rakshith, I was able to implement it, but only as partial batch decoding: the GPT code generation still runs one sentence at a time, and only the latent pass and the HiFi-GAN decoding are batched.
Here is the code snippet:

```python
import torch
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils


# Meant to be added as a method of the XTTS model class: it uses self.gpt,
# self.tokenizer, self.hifigan_decoder, self.device and self.gpt_batch_size.
@torch.inference_mode()
def Pbatch_inference(
    self,
    text,  # list of sentences synthesized as one batch
    language,
    gpt_cond_latent,
    speaker_embedding,
    # GPT inference
    temperature=0.75,
    length_penalty=1.0,
    repetition_penalty=10.0,
    top_k=50,
    top_p=0.85,
    do_sample=True,
    num_beams=1,
    speed=1.0,
    enable_text_splitting=False,  # unused; kept for parity with inference()
    **hf_generate_kwargs,
):
    stop_audio_token = 1025  # XTTS v2 stop audio token, also used as the code padding value
    language = language.split("-")[0]  # remove the country code
    length_scale = 1.0 / max(speed, 0.05)
    gpt_cond_latent = gpt_cond_latent.to(self.device)
    speaker_embedding = speaker_embedding.to(self.device)
    batch_size = len(text)

    # Repeat the single speaker conditioning once per sentence
    xg = gpt_cond_latent.repeat(batch_size, 1, 1)
    xse = speaker_embedding.repeat(batch_size, 1, 1)

    text_tokens = []
    lens = []
    gpt_in = []
    # Step 1 (not batched): autoregressive code generation, one sentence at a time
    for sent in text:
        sent = sent.strip().lower()
        text_token = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0)
        lens.append(text_token.shape[1])
        text_tokens.append(text_token)

        gpt_codes = self.gpt.generate(
            cond_latents=gpt_cond_latent,
            text_inputs=text_token.to(self.device),
            input_tokens=None,
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            num_return_sequences=self.gpt_batch_size,
            num_beams=num_beams,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            output_attentions=False,
            **hf_generate_kwargs,
        )
        gpt_in.append(gpt_codes[0])

    # Step 2 (batched): right-pad text tokens with 0 and codes with the stop token
    max_text_len = max(lens)
    text_padded = torch.zeros(batch_size, max_text_len, dtype=torch.int32)
    for i, t in enumerate(text_tokens):
        text_padded[i, : lens[i]] = t.squeeze(0)
    text_padded = text_padded.to(self.device)

    gpt_codes = rnn_utils.pad_sequence(gpt_in, batch_first=True, padding_value=stop_audio_token)

    # One expected output length per sentence instead of a single shared value
    expected_output_len = torch.tensor(
        [c.shape[-1] * self.gpt.code_stride_len for c in gpt_in], device=self.device
    )
    text_len = torch.tensor(lens, device=self.device)
    gpt_latents = self.gpt(
        text_padded,
        text_len,
        gpt_codes,
        expected_output_len,
        cond_latents=xg,
        return_attentions=False,
        return_latent=True,
    )

    # Zero every latent frame from each row's first stop/padding token onward;
    # rows without a stop token (the longest one) are left untouched
    for i in range(gpt_codes.shape[0]):
        stop_pos = (gpt_codes[i] == stop_audio_token).nonzero()
        if stop_pos.numel() > 0:
            gpt_latents[i, int(stop_pos[0, 0]) :, :] = 0

    if length_scale != 1.0:
        gpt_latents = F.interpolate(
            gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
        ).transpose(1, 2)
    # Step 3 (batched): HiFi-GAN decoding for the whole batch -> (batch, 1, samples)
    wav = self.hifigan_decoder(gpt_latents, g=xse).cpu()

    return {
        "wav": wav,
        "gpt_latents": gpt_latents.cpu().numpy(),
        "speaker_embedding": speaker_embedding,
    }
```
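One way to deal with the leftover noise after batched decoding is to cut each waveform back to the part covered by its own codes. The sketch below is plain Python with hypothetical helper names (`valid_code_lengths`, `trim_batch`). It assumes, as the zeroing loop in the snippet does, that latent frame positions line up with code positions, and that the decoder emits a fixed number of samples per frame, so the fraction of valid frames can be read off the padded batch instead of hard-coding a hop size. 1025 is the padding value used in the snippet above.

```python
from typing import List, Sequence

STOP_AUDIO_TOKEN = 1025  # padding value used for the codes in the snippet above


def valid_code_lengths(padded_codes: Sequence[Sequence[int]],
                       stop_token: int = STOP_AUDIO_TOKEN) -> List[int]:
    """Count the codes before the first stop/padding token in each row."""
    lengths = []
    for row in padded_codes:
        row = list(row)
        lengths.append(row.index(stop_token) if stop_token in row else len(row))
    return lengths


def trim_batch(wavs: Sequence[Sequence[float]], valid_frames: Sequence[int],
               total_frames: int) -> List[List[float]]:
    """Keep only the share of each waveform that its valid frames cover.

    Assumes every row was decoded from `total_frames` frames and that the
    decoder emits the same number of samples per frame (hypothetical helper).
    """
    trimmed = []
    for wav, n_valid in zip(wavs, valid_frames):
        n_keep = round(len(wav) * n_valid / total_frames)
        trimmed.append(list(wav[:n_keep]))
    return trimmed


lengths = valid_code_lengths([[3, 4, 1025, 1025], [5, 6, 7, 8]])
short, full = trim_batch([[0.1] * 8, [0.2] * 8], lengths, total_frames=4)
print(lengths, len(short), len(full))  # [2, 4] 4 8
```

With the torch tensors, `gpt_codes.tolist()` gives the padded code rows, `total_frames` is `gpt_codes.shape[-1]`, and for a decoder output of shape `(batch, 1, samples)`, `wav[:, 0].tolist()` gives one sample list per sentence. Because the ratio is taken as a fraction of the padded length, a uniform `length_scale` interpolation should leave it roughly unchanged.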
