Skip to content

Commit

Permalink
Fix for rnnt_loss.py (#1177)
Browse files Browse the repository at this point in the history
* Update rnnt_loss.py

* Update rnnt_loss.py

* Fix for style check

* Fix for style check

* Update rnnt_loss.py
  • Loading branch information
yfyeung committed Apr 26, 2023
1 parent 8c9044e commit a23383c
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions k2/python/k2/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,10 @@ def get_rnnt_logprobs(

# px is the probs of the actual symbols..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
Expand Down Expand Up @@ -1399,12 +1397,10 @@ def get_rnnt_logprobs_smoothed(

# px is the probs of the actual symbols (not yet normalized)..
px_am = torch.gather(
am.unsqueeze(1).expand(B, S, T, C),
dim=3,
index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
).squeeze(
-1
) # [B][S][T]
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
Expand Down

0 comments on commit a23383c

Please sign in to comment.