[Bug] torch.isin(elements=inputs, test_elements=pad_token_id).any() TypeError: isin() received an invalid combination of arguments - got (elements=Tensor, test_elements=int, ) #3786

Closed
XPDD opened this issue Jun 12, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@XPDD

XPDD commented Jun 12, 2024

Describe the bug

torch.isin(elements=inputs, test_elements=pad_token_id).any()
TypeError: isin() received an invalid combination of arguments - got (elements=Tensor, test_elements=int, ), but expected one of:

  • (Tensor elements, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
  • (Number element, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
  • (Tensor elements, Number test_element, *, bool assume_unique, bool invert, Tensor out)


To Reproduce

import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
from TTS.tts.models.xtts import Xtts

print("Loading model...")
json_path = "D:/xtts2/config.json"
xtts_checkpoint = "D:/xtts2/"
config = XttsConfig()
config.load_json(json_path)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=xtts_checkpoint, use_deepspeed=False)
model.cuda()
print("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["../tts/reference.wav"])

chunks = model.inference_stream(
    "今天天气真好",  # "The weather is really nice today"
    "zh-cn",
    gpt_cond_latent,
    speaker_embedding,
    generation_config=StreamGenerationConfig(
        pad_token_id=torch.tensor([1025], device=model.device),
        eos_token_id=torch.tensor([1025], device=model.device)
    )
)
wav_chuncks = []
for i, chunk in enumerate(chunks):
    print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
    wav_chuncks.append(chunk)
wav = torch.cat(wav_chuncks, dim=0)
torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)

Expected behavior

No response

Logs

No response

Environment

TTS version 0.22.0
XTTS v2
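The environment above lists only the TTS version, while the maintainer's diagnosis further down depends on the installed transformers version. A minimal sketch for collecting the relevant versions, assuming only the standard library (`importlib.metadata`); `collect_versions` is a hypothetical helper, and the package list is an assumption (`coqui-tts` is the fork the maintainer mentions).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def collect_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # not installed in this environment
    return found


if __name__ == "__main__":
    # Hypothetical package list: original Coqui TTS, its fork, and the two
    # libraries involved in the torch.isin() failure.
    for pkg, ver in collect_versions(["TTS", "coqui-tts", "transformers", "torch"]).items():
        print(f"{pkg}: {ver or 'not installed'}")
```

Pasting that output under "Environment" would have shown the transformers version right away.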

Additional context

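I solved this with my own class that inherits from StreamGenerationConfig and overrides its update method.
The cause is that update() overwrites the original configuration with whatever is passed in through kwargs, so the pad_token_id tensor from my config is replaced by a plain int (hence `test_elements=int` in the error). Why doesn't Python have a method like Spring Boot's that copies properties without overwriting existing ones?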
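That "copy but don't overwrite" behavior takes only a few lines in plain Python. The sketch below is an illustration, not TTS or transformers code: `DummyGenerationConfig` is a hypothetical stand-in for StreamGenerationConfig (lists replace tensors so it runs without torch), and `update_except` is a hypothetical helper that, like the TokenConfig override, skips protected keys and, like transformers' `GenerationConfig.update()`, returns the kwargs that match no attribute.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Optional


@dataclass
class DummyGenerationConfig:
    """Hypothetical stand-in for StreamGenerationConfig (lists replace tensors)."""
    pad_token_id: Optional[Any] = None
    eos_token_id: Optional[Any] = None
    top_k: int = 50
    temperature: float = 1.0


def update_except(obj: Any, protected: Iterable[str], **kwargs: Any) -> Dict[str, Any]:
    """Copy kwargs onto matching attributes of `obj`, skipping names in `protected`.

    Values for protected names are dropped, so the existing attribute wins.
    Returns the kwargs that match no attribute of `obj`.
    """
    keep = set(protected)
    unused: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in keep:
            continue  # never overwrite a protected attribute
        if hasattr(obj, key):
            setattr(obj, key, value)
        else:
            unused[key] = value
    return unused


if __name__ == "__main__":
    cfg = DummyGenerationConfig(pad_token_id=[1025], eos_token_id=[1025])
    # Ints arrive for the token IDs, matching `test_elements=int` in the error.
    rest = update_except(
        cfg,
        {"pad_token_id", "eos_token_id"},
        pad_token_id=1025,
        eos_token_id=1025,
        top_k=30,
        attention_mask=None,
    )
    print(cfg)   # pad/eos stay [1025]; top_k becomes 30
    print(rest)  # {'attention_mask': None}
```

Passing the protected names explicitly keeps the helper reusable for other configs instead of hard-coding the two token IDs.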
Here is the complete code; I hope it helps:
```python
import torch
import torchaudio
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
from TTS.tts.models.xtts import Xtts


class TokenConfig(StreamGenerationConfig):
    """StreamGenerationConfig whose pad/eos token IDs are not overridden by update()."""

    def __init__(self, pad_token_id, eos_token_id, **kwargs):
        super().__init__(**kwargs)
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

    def update(self, **kwargs):
        to_remove = []
        for key, value in kwargs.items():
            if key in ("pad_token_id", "eos_token_id"):
                # Keep the tensors set in __init__ and drop the incoming values.
                to_remove.append(key)
            elif hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)
        # Like the parent update(), return only the kwargs that were not consumed.
        return {key: value for key, value in kwargs.items() if key not in to_remove}


device = "cuda" if torch.cuda.is_available() else "cpu"
print("use device {}".format(device))

if __name__ == "__main__":
    print("Loading model...")
    json_path = "D:/xtts2/config.json"
    xtts_checkpoint = "D:/xtts2/"
    config = XttsConfig()
    config.load_json(json_path)
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir=xtts_checkpoint, use_deepspeed=False)
    model.to(device)  # use the device detected above instead of hard-coding CUDA
    print("Computing speaker latents...")
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["../tts/reference.wav"])
    chunks = model.inference_stream(
        "今天天气真好",  # "The weather is really nice today"
        "zh-cn",
        gpt_cond_latent,
        speaker_embedding,
        generation_config=TokenConfig(
            pad_token_id=torch.tensor([1025], device=model.device),
            eos_token_id=torch.tensor([1025], device=model.device),
        ),
    )
    wav_chuncks = []
    for i, chunk in enumerate(chunks):
        print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
        wav_chuncks.append(chunk)
    wav = torch.cat(wav_chuncks, dim=0)
    torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```

@XPDD XPDD added the bug Something isn't working label Jun 12, 2024
@eginhard
Contributor

Duplicate of idiap#31. It should work fine if you install transformers version 4.40.2 or lower. Installing our fork (pip install coqui-tts) will also take care of that.
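To check whether an environment is under that ceiling before running XTTS, here is a minimal standard-library sketch. `parse_release` and `is_compatible` are hypothetical helpers, the 4.40.2 bound comes from the comment above, and the parser is deliberately simplified: it compares only the leading numeric release segment and is not a full PEP 440 implementation (`packaging.version` would be the more robust choice).

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import List, Tuple

# Upper bound taken from the maintainer's comment: transformers 4.40.2 or lower.
MAX_COMPATIBLE: Tuple[int, ...] = (4, 40, 2)


def parse_release(ver: str) -> Tuple[int, ...]:
    """Return the leading numeric release segment, e.g. '4.41.0.dev0' -> (4, 41, 0).

    Simplified: suffixes such as 'rc1' or '.dev0' are ignored.
    """
    parts: List[int] = []
    for piece in ver.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # non-numeric component such as 'dev0'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # numeric prefix followed by a suffix, e.g. '0rc1'
    while len(parts) < 3:
        parts.append(0)  # treat '4.40' as (4, 40, 0)
    return tuple(parts)


def is_compatible(ver: str) -> bool:
    """True if the release segment of `ver` is at most MAX_COMPATIBLE."""
    return parse_release(ver) <= MAX_COMPATIBLE


if __name__ == "__main__":
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        print("transformers is not installed")
    else:
        verdict = "OK" if is_compatible(installed) else "newer than 4.40.2"
        print(f"transformers {installed}: {verdict}")
```

If the check fails, the two options named above are pinning with `pip install "transformers<=4.40.2"` or installing the fork with `pip install coqui-tts`.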

@XPDD XPDD closed this as completed Jun 13, 2024
@XPDD
Author

XPDD commented Jun 13, 2024

Thank you, it can work now.

1 similar comment
@wanjiecc

Thank you, it can work now.


3 participants