[Bug] torch.isin(elements=inputs, test_elements=pad_token_id).any() TypeError: isin() received an invalid combination of arguments - got (elements=Tensor, test_elements=int, ) #3786
Labels: bug
Describe the bug
```
torch.isin(elements=inputs, test_elements=pad_token_id).any()
TypeError: isin() received an invalid combination of arguments - got (elements=Tensor, test_elements=int, ), but expected one of:
```
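For what it's worth, the keyword combination itself seems to be what fails: in the PyTorch builds that raise this, `isin`'s scalar overload is exposed under a different keyword (`test_element`, singular), so passing a plain `int` as `test_elements=` matches no overload. A minimal sketch outside the model (values made up) that reproduces the error and the tensor workaround:

```python
import torch

inputs = torch.tensor([1, 1025, 3])
pad_token_id = 1025  # plain Python int, as the library passes it internally

# Raises the TypeError from this report: no isin() overload accepts an
# int under the keyword test_elements=.
# torch.isin(elements=inputs, test_elements=pad_token_id)

# Wrapping the id in a tensor selects the Tensor/Tensor overload:
mask = torch.isin(elements=inputs, test_elements=torch.tensor(pad_token_id))
print(mask.any())  # tensor(True)
```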
To Reproduce
```python
print("Loading model...")
json_path = "D:/xtts2/config.json"
xtts_checkpoint = "D:/xtts2/"
config = XttsConfig()
config.load_json(json_path)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=xtts_checkpoint, use_deepspeed=False)
model.cuda()

print("Computing speaker latents...")
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["../tts/reference.wav"])
```
Expected behavior
No response
Logs
No response
Environment
Additional context
I solved this by subclassing StreamGenerationConfig and overriding its update method.
The root cause is that update() overwrites the configured values via kwargs. (It is a pity Python has no built-in equivalent of Spring Boot's copy-properties that skips attributes which are already set.)
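In other words, the desired behaviour is a "copy but don't clobber" update. Roughly, as a standalone sketch (`soft_update` and `protected` are hypothetical names, not part of TTS or transformers):

```python
def soft_update(obj, protected=(), **kwargs):
    """Copy kwargs onto obj, skipping any attribute named in `protected`."""
    unused = {}
    for key, value in kwargs.items():
        if hasattr(obj, key) and key not in protected:
            setattr(obj, key, value)
        else:
            unused[key] = value  # report kwargs we did not apply
    return unused
```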
The complete code is below; I hope it helps:
```python
from TTS.api import TTS
import torch
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig
from TTS.tts.models.xtts import Xtts
import torchaudio


class TokenConfig(StreamGenerationConfig):
    def __init__(self, pad_token_id, eos_token_id, **kwargs):
        super().__init__(**kwargs)
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id

    def update(self, **kwargs):
        # Copy kwargs onto the config, but never overwrite the token ids
        # set in __init__ (the stock update() would clobber them).
        for key, value in kwargs.items():
            if hasattr(self, key) and key != 'pad_token_id' and key != 'eos_token_id':
                setattr(self, key, value)
        return {}


device = "cuda" if torch.cuda.is_available() else "cpu"
print("use device {}".format(device))

if __name__ == '__main__':
    print("Loading model...")
    json_path = "D:/xtts2/config.json"
    xtts_checkpoint = "D:/xtts2/"
    config = XttsConfig()
    config.load_json(json_path)
    model = Xtts.init_from_config(config)
    model.load_checkpoint(config, checkpoint_dir=xtts_checkpoint, use_deepspeed=False)
    model.cuda()

    print("Computing speaker latents...")
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["../tts/reference.wav"])

    chunks = model.inference_stream(
        "今天天气真好",
        "zh-cn",
        gpt_cond_latent,
        speaker_embedding,
        generation_config=TokenConfig(
            pad_token_id=torch.tensor([1025], device=model.device),
            eos_token_id=torch.tensor([1025], device=model.device)
        )
    )

    wav_chunks = []
    for i, chunk in enumerate(chunks):
        print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
        wav_chunks.append(chunk)
    wav = torch.cat(wav_chunks, dim=0)
    torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
```
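Note that pad_token_id and eos_token_id are passed as tensors (`torch.tensor([1025], device=model.device)`) rather than plain ints, and the update() override keeps the generation machinery from replacing them with its int defaults; together those appear to be what keeps the torch.isin call from the traceback on its Tensor/Tensor overload.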