Skip to content

Commit

Permalink
use autocast wrapper + add back autocast condition for warp-transducer
Browse files Browse the repository at this point in the history
  • Loading branch information
b-flo committed Jul 6, 2023
1 parent 8c70533 commit 1d590ff
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions espnet2/asr_transducer/espnet_transducer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,13 @@ def _calc_transducer_loss(
)
exit(1)

loss_transducer = self.criterion_transducer(
joint_out,
target,
t_len,
u_len,
)
with autocast(False):
loss_transducer = self.criterion_transducer(
joint_out.float(),
target,
t_len,
u_len,
)

return loss_transducer

Expand Down Expand Up @@ -511,7 +512,7 @@ def _calc_k2_transducer_pruned_loss(
lm = self.lm_proj(decoder_out)
am = self.am_proj(encoder_out)

with torch.cuda.amp.autocast(enabled=False):
with autocast(False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm.float(),
am.float(),
Expand Down Expand Up @@ -540,7 +541,7 @@ def _calc_k2_transducer_pruned_loss(

joint_out = self.joint_network(am_pruned, lm_pruned, no_projection=True)

with torch.cuda.amp.autocast(enabled=False):
with autocast(False):
pruned_loss = k2.rnnt_loss_pruned(
joint_out.float(),
target_padded,
Expand Down

0 comments on commit 1d590ff

Please sign in to comment.