Skip to content

Commit

Permalink
Revert checkpointing in espnet2/enh/espnet_enh_s2t_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Emrys365 committed Apr 27, 2023
1 parent 08534da commit b973527
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions espnet2/enh/espnet_enh_s2t_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
s2t_model: Union[ESPnetASRModel, ESPnetSTModel, ESPnetDiarizationModel],
calc_enh_loss: bool = True,
bypass_enh_prob: float = 0, # 0 means do not bypass enhancement for all data
enh_checkpointing: bool = False,
):
assert check_argument_types()

Expand All @@ -44,7 +43,6 @@ def __init__(
self.s2t_model = s2t_model # ASR or ST or DIAR model

self.bypass_enh_prob = bypass_enh_prob
self.enh_checkpointing = enh_checkpointing

self.calc_enh_loss = calc_enh_loss
if isinstance(self.s2t_model, ESPnetDiarizationModel):
Expand Down Expand Up @@ -236,17 +234,9 @@ def forward(
loss_enh = None
perm = None
if not bypass_enh_flag:
if self.enh_checkpointing:
ret = torch.utils.checkpoint.checkpoint(
self.enh_model.forward_enhance,
speech,
speech_lengths,
{"num_spk": num_spk},
)
else:
ret = self.enh_model.forward_enhance(
speech, speech_lengths, {"num_spk": num_spk}
)
ret = self.enh_model.forward_enhance(
speech, speech_lengths, {"num_spk": num_spk}
)
speech_pre, feature_mix, feature_pre, others = ret
# loss computation
if not skip_enhloss_flag:
Expand Down

0 comments on commit b973527

Please sign in to comment.