🐛 fix refine #117
- revert `gpt.py`
zhzLuke96 committed Jul 29, 2024
1 parent a9c9d16 commit f0c2602
Showing 2 changed files with 35 additions and 95 deletions.
3 changes: 2 additions & 1 deletion modules/core/models/zoo/ChatTTSInfer.py
@@ -142,7 +142,7 @@ def _infer(
text_tokens = [
i[i.less(self.instance.tokenizer.break_0_ids)] for i in text_tokens
]
text = self.get_tokenizer().batch_decode(text_tokens)
text = self.instance.tokenizer.decode(text_tokens)
refined.destroy()
if refine_text_only:
yield text
@@ -157,6 +157,7 @@ def _infer(
params_infer_code,
):
wavs = self._decode_to_wavs(result, length, use_decoder)
result.destroy()
yield wavs
else:
# NOTE: doesn't seem to be useful...?
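The ChatTTSInfer hunk above swaps which tokenizer call turns the refined text tokens back into strings. As a rough sketch of the usual distinction between the two method names (assuming a Hugging Face-style tokenizer, not the project's actual wrapper; the model name and ids below are illustrative only): `batch_decode` takes a list of token-id sequences and returns a list of strings, while `decode` takes a single sequence.

```python
# Minimal sketch, assuming a Hugging Face-style tokenizer; not the project's wrapper.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative model

batch_ids = [[7592, 2088], [2204, 2154]]  # two token-id sequences

# batch_decode: list of sequences in, list of strings out
texts = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)

# decode: a single sequence in, one string out
text = tokenizer.decode(batch_ids[0], skip_special_tokens=True)
```

In the diff, `self.instance.tokenizer` is ChatTTS's own tokenizer object, whose `decode` may itself accept a batch; the sketch only illustrates the conventional difference between the two names.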
127 changes: 33 additions & 94 deletions modules/repos_static/ChatTTS/ChatTTS/model/gpt.py
@@ -373,10 +373,8 @@ def generate(
stream_batch=24,
context=Context(),
):

attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
hiddens = []
stream_iter = 0

start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
@@ -405,19 +403,6 @@ def generate(
attention_mask
)

progress = inputs_ids.size(1)
# pre-allocate inputs_ids
inputs_ids_buf = torch.zeros(
inputs_ids.size(0),
progress + max_new_token,
inputs_ids.size(2),
dtype=inputs_ids.dtype,
device=inputs_ids.device,
)
inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
del inputs_ids
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)

pbar: Optional[tqdm] = None

if show_tqdm:
@@ -430,12 +415,11 @@ def generate(
past_key_values = None

for i in range(max_new_token):

model_input = self._prepare_generation_inputs(
inputs_ids,
past_key_values,
attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
use_cache=not self.is_te_llama,
use_cache=True,
)

if i > 0:
@@ -483,16 +467,14 @@ def generate(
hidden_states.size(1),
self.num_audio_tokens,
self.num_vq,
dtype=torch.float,
dtype=self.gpt.dtype,
device=self.device,
)
for num_vq_iter in range(self.num_vq):
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
logits[..., num_vq_iter] = x
del x

del hidden_states

# logits = logits[:, -1].float()
logits = logits.narrow(1, -1, 1).squeeze_(1).to(dtype=self.gpt.dtype)

@@ -501,26 +483,13 @@ def generate(
logits = logits.permute(0, 2, 1)
logits = logits.reshape(-1, logits.size(2))
# logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
inputs_ids_sliced = inputs_ids.narrow(
1,
start_idx,
inputs_ids.size(1) - start_idx,
).permute(0, 2, 1)
inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
logits_token = inputs_ids_sliced.reshape(
inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
-1,
).to(self.device)
del inputs_ids_sliced
else:
logits_token = (
inputs_ids.narrow(
1,
start_idx,
inputs_ids.size(1) - start_idx,
)
.narrow(2, 0, 1)
.to(self.device)
)
logits_token = inputs_ids[:, start_idx:, 0].to(self.device)

logits /= temperature

@@ -537,87 +506,57 @@ def generate(

scores = F.softmax(logits, dim=-1)

del logits

if i == 0:
# when i == 0, we want to ensure that the first token is not eos_token
scores[:, eos_token] = 0

del logits

idx_next = torch.multinomial(scores, num_samples=1).to(
device=finish.device, dtype=self.gpt.dtype
)

del scores
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)

if not infer_text:
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
idx_next = idx_next.view(-1, self.num_vq)
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
inputs_ids_tmp = torch.cat([inputs_ids, idx_next.unsqueeze_(1)], 1)
else:
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_buf.narrow(1, progress, 1).copy_(
idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
inputs_ids_tmp = torch.cat(
[
inputs_ids,
idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
],
1,
)

if i == 0 and finish.any():
self.logger.warning(
self.logger.warn(
"unexpected end at index %s",
str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
)
if ensure_non_empty:
if show_tqdm:
pbar.close()
self.logger.warning("regenerate in order to ensure non-empty")
del_all(attentions)
del_all(hiddens)
del (
start_idx,
end_idx,
finish,
temperature,
attention_mask_cache,
past_key_values,
idx_next,
inputs_ids_buf,
self.logger.warn(
"ensure_non_empty is deprecated, you can safely ignore it"
)
new_gen = self.generate(
emb,
inputs_ids,
old_temperature,
eos_token,
attention_mask,
max_new_token,
min_new_token,
logits_warpers,
logits_processors,
infer_text,
return_attn,
return_hidden,
stream,
show_tqdm,
ensure_non_empty,
stream_batch,
context,
)
for result in new_gen:
yield result
del inputs_ids
return

del idx_next
progress += 1
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)

not_finished = finish.logical_not().to(end_idx.device)
end_idx.add_(not_finished.int())
stream_iter += not_finished.any().int()

del inputs_ids
inputs_ids = inputs_ids_tmp
del inputs_ids_tmp, idx_next

if stream:
minus_prev_end_index = end_idx.neg()

end_idx.add_((finish.logical_not().to(end_idx.device)).int())
if stream:
if stream_iter > 0 and stream_iter % stream_batch == 0:
if (
end_idx.all()
and end_idx.fmod(stream_batch).eq(0).any()
and minus_prev_end_index.add_(end_idx).any()
):
self.logger.debug("yield stream result, end: %d", end_idx)
yield self._prepare_generation_outputs(
inputs_ids,
@@ -627,7 +566,7 @@ def generate(
hiddens,
infer_text,
)
del not_finished
del minus_prev_end_index

if finish.all() or context.get():
break
@@ -646,7 +585,7 @@ def generate(
f"incomplete result. hit max_new_token: {max_new_token}"
)

del finish, inputs_ids_buf
del finish

yield self._prepare_generation_outputs(
inputs_ids,
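A large part of the gpt.py change toggles between two ways of growing `inputs_ids` during decoding: a buffer pre-allocated to `seq_len + max_new_token` that each step fills through `narrow(...).copy_(...)`, and a per-step `torch.cat` that reallocates on every iteration. A minimal sketch of both strategies, with purely illustrative shapes:

```python
# Minimal sketch of the two token-append strategies seen in the diff; the
# shapes (batch, seq_len, num_vq) are illustrative assumptions.
import torch

batch, seq_len, num_vq, max_new_token = 2, 8, 4, 16
inputs_ids = torch.zeros(batch, seq_len, num_vq, dtype=torch.long)
idx_next = torch.ones(batch, 1, num_vq, dtype=torch.long)  # one decoded step

# Strategy A: pre-allocate the full buffer once, then copy each step into a view.
inputs_ids_buf = torch.zeros(batch, seq_len + max_new_token, num_vq, dtype=torch.long)
inputs_ids_buf.narrow(1, 0, seq_len).copy_(inputs_ids)
progress = seq_len
inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next)
progress += 1
ids_a = inputs_ids_buf.narrow(1, 0, progress)  # view over the filled prefix

# Strategy B: concatenate a fresh tensor each step (reallocates every iteration).
ids_b = torch.cat([inputs_ids, idx_next], dim=1)

assert torch.equal(ids_a, ids_b)
```

The pre-allocated variant saves one allocation per token at the cost of fixing the maximum length up front; judging by the hunk line counts, this revert appears to move gpt.py back to the `torch.cat` form.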

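The sampling step visible in the middle of the diff follows a common pattern: softmax over the logits, zero out the EOS column on the very first step so generation cannot end immediately, then draw the next token with `torch.multinomial`. A self-contained sketch under assumed shapes:

```python
# Minimal sketch of the first-step EOS masking and sampling shown above; the
# real code also applies temperature, logits warpers, and repetition penalties.
import torch
import torch.nn.functional as F

vocab_size, eos_token = 10, 9
logits = torch.randn(3, vocab_size)      # one row per sampled stream

scores = F.softmax(logits, dim=-1)
scores[:, eos_token] = 0                 # first step only: forbid immediate EOS
idx_next = torch.multinomial(scores, num_samples=1)  # (3, 1) sampled token ids
finish = idx_next.eq(eos_token).any(1)   # per-row finish flags, all False here
```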