Commit 7ad5f58

Merge pull request #4875 from DongjiGao/uasr

EURO: small fix (temporarily remove support for nbest_rescoring)

ftshijt committed Jan 18, 2023
2 parents eb28c38 + c6b6a4a commit 7ad5f58

Showing 1 changed file with 3 additions and 86 deletions.
89 changes: 3 additions & 86 deletions espnet2/bin/uasr_inference_k2.py
@@ -105,7 +105,6 @@ def __init__(
         assert check_argument_types()
 
         # 1. Build UASR model
-        logging.info(f"==========device to build model from: {device}===========")
         uasr_model, uasr_train_args = UASRTask.build_model_from_file(
             uasr_train_config, uasr_model_file, device
         )
@@ -252,93 +251,11 @@ def __call__(
         lattices.scores *= self.lattice_weight
 
         results = []
+        # TODO(Dongji): add nbest_rescoring
         if self.use_nbest_rescoring:
-            (
-                am_scores,
-                lm_scores,
-                token_ids,
-                new2old,
-                path_to_seq_map,
-                seq_to_path_splits,
-            ) = nbest_am_lm_scores(
-                lattices, self.num_paths, self.device, self.nbest_batch_size
+            raise ValueError(
+                "Currently nbest rescoring is not supported"
             )
-
-            ys_pad_lens = torch.tensor([len(hyp) for hyp in token_ids]).to(self.device)
-            max_token_length = max(ys_pad_lens)
-            ys_pad_list = []
-            for hyp in token_ids:
-                ys_pad_list.append(
-                    torch.cat(
-                        [
-                            torch.tensor(hyp, dtype=torch.long),
-                            torch.tensor(
-                                [self.uasr_model_ignore_id]
-                                * (max_token_length.item() - len(hyp)),
-                                dtype=torch.long,
-                            ),
-                        ]
-                    )
-                )
-
-            ys_pad = (
-                torch.stack(ys_pad_list).to(torch.long).to(self.device)
-            )  # [batch, max_token_length]
-
-            encoder_out = generated_sample.index_select(
-                0, path_to_seq_map.to(torch.long)
-            ).to(
-                self.device
-            )  # [batch, T, dim]
-            encoder_out_lens = encoder_out_lens.index_select(
-                0, path_to_seq_map.to(torch.long)
-            ).to(
-                self.device
-            )  # [batch]
-
-            decoder_scores = -self.uasr_model.batchify_nll(
-                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, self.nll_batch_size
-            )
-
-            # padded_value for nnlm is 0
-            ys_pad[ys_pad == self.uasr_model_ignore_id] = 0
-            nnlm_nll, x_lengths = self.lm.batchify_nll(
-                ys_pad, ys_pad_lens, self.nll_batch_size
-            )
-            nnlm_scores = -nnlm_nll.sum(dim=1)
-
-            batch_tot_scores = (
-                self.am_weight * am_scores
-                + self.decoder_weight * decoder_scores
-                + self.nnlm_weight * nnlm_scores
-            )
-            split_size = indices_to_split_size(
-                seq_to_path_splits.tolist(), total_elements=batch_tot_scores.size(0)
-            )
-            batch_tot_scores = torch.split(
-                batch_tot_scores,
-                split_size,
-            )
-
-            hyps = []
-            scores = []
-            processed_seqs = 0
-            for tot_scores in batch_tot_scores:
-                if tot_scores.nelement() == 0:
-                    # the last element by torch.tensor_split may be empty
-                    # e.g.
-                    # torch.tensor_split(torch.tensor([1,2,3,4]), torch.tensor([2,4]))
-                    # (tensor([1, 2]), tensor([3, 4]), tensor([], dtype=torch.int64))
-                    break
-                best_seq_idx = processed_seqs + torch.argmax(tot_scores)
-
-                assert best_seq_idx < len(token_ids)
-                best_token_seqs = token_ids[best_seq_idx]
-                processed_seqs += tot_scores.nelement()
-                hyps.append(best_token_seqs)
-                scores.append(tot_scores.max().item())
-
-            assert len(hyps) == len(split_size)
         else:
             best_paths = one_best_decoding(lattices, use_double_scores=True)
             scores = best_paths.get_tot_scores(
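For context, the branch disabled here performed n-best rescoring: each utterance contributes several candidate paths from the lattice, each path's acoustic score is combined with a decoder negative log-likelihood and a neural-LM negative log-likelihood under configurable weights, and the highest-scoring path per utterance is kept. Below is a minimal standalone sketch of that score-combination step; pick_best_paths and all of its inputs are hypothetical illustrations, not part of the espnet API.

import torch

def pick_best_paths(am_scores, decoder_scores, nnlm_scores, split_size,
                    am_weight=1.0, decoder_weight=0.5, nnlm_weight=0.5):
    """Pick the best path per utterance from flattened per-path scores.

    All three score tensors have shape [total_num_paths]; split_size lists
    how many candidate paths belong to each utterance. The weights are
    illustrative assumptions, not espnet defaults.
    """
    # Weighted combination of acoustic, decoder, and neural-LM scores,
    # mirroring the batch_tot_scores computation in the removed branch.
    tot_scores = (
        am_weight * am_scores
        + decoder_weight * decoder_scores
        + nnlm_weight * nnlm_scores
    )
    best_indices, best_scores = [], []
    processed = 0
    for per_utt in torch.split(tot_scores, split_size):
        if per_utt.nelement() == 0:
            break  # guard kept from the original loop: a trailing split may be empty
        # argmax within this utterance's paths, offset back into the flat tensor
        best_indices.append(processed + torch.argmax(per_utt).item())
        best_scores.append(per_utt.max().item())
        processed += per_utt.nelement()
    return best_indices, best_scores

# Example: 7 candidate paths split across two utterances (4 + 3).
am = torch.tensor([0.9, 1.2, 0.8, 1.0, 0.5, 0.7, 0.6])
dec = torch.tensor([-0.2, -0.1, -0.3, -0.4, -0.2, -0.1, -0.5])
lm = torch.tensor([-0.3, -0.2, -0.1, -0.2, -0.4, -0.3, -0.2])
print(pick_best_paths(am, dec, lm, [4, 3]))  # best path index and score per utterance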
