Skip to content

Commit

Permalink
Merge pull request #24 from yfyeung/yfyeung-patch-1
Browse files Browse the repository at this point in the history
Update rnnt_loss.py
  • Loading branch information
danpovey committed Jul 19, 2023
2 parents 2945bd7 + 878d7c8 commit a801adc
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,10 @@ def get_rnnt_logprobs(

# px is the probs of the actual symbols..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
Expand Down Expand Up @@ -1247,12 +1245,10 @@ def get_rnnt_logprobs_smoothed(

# px is the probs of the actual symbols (not yet normalized)..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
Expand Down

0 comments on commit a801adc

Please sign in to comment.