From c9e4f5c7acde481c35e8e8d96cde5104de127476 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 11 Oct 2021 10:39:49 -0700
Subject: [PATCH] fix --gradient_checkpointing

---
 examples/research_projects/wav2vec2/run_asr.py          | 8 +-------
 examples/research_projects/wav2vec2/run_common_voice.py | 8 +-------
 examples/research_projects/wav2vec2/run_pretrain.py     | 8 +-------
 3 files changed, 3 insertions(+), 21 deletions(-)

diff --git a/examples/research_projects/wav2vec2/run_asr.py b/examples/research_projects/wav2vec2/run_asr.py
index f4c2561ccf580a..9b031cca1972e1 100755
--- a/examples/research_projects/wav2vec2/run_asr.py
+++ b/examples/research_projects/wav2vec2/run_asr.py
@@ -54,12 +54,6 @@ class ModelArguments:
     freeze_feature_extractor: Optional[bool] = field(
         default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
     )
-    gradient_checkpointing: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
-        },
-    )
     verbose_logging: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to log verbose messages or not."},
@@ -352,7 +346,7 @@ def main():
     model = Wav2Vec2ForCTC.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
-        gradient_checkpointing=model_args.gradient_checkpointing,
+        gradient_checkpointing=training_args.gradient_checkpointing,
         vocab_size=len(processor.tokenizer),
     )
 
diff --git a/examples/research_projects/wav2vec2/run_common_voice.py b/examples/research_projects/wav2vec2/run_common_voice.py
index bb69784a8d2cd3..edae86641e0b94 100644
--- a/examples/research_projects/wav2vec2/run_common_voice.py
+++ b/examples/research_projects/wav2vec2/run_common_voice.py
@@ -84,12 +84,6 @@ class ModelArguments:
             "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
         },
     )
-    gradient_checkpointing: Optional[bool] = field(
-        default=True,
-        metadata={
-            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
-        },
-    )
     layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
 
 
@@ -373,7 +367,7 @@ def extract_all_chars(batch):
         hidden_dropout=model_args.hidden_dropout,
         feat_proj_dropout=model_args.feat_proj_dropout,
         mask_time_prob=model_args.mask_time_prob,
-        gradient_checkpointing=model_args.gradient_checkpointing,
+        gradient_checkpointing=training_args.gradient_checkpointing,
         layerdrop=model_args.layerdrop,
         ctc_loss_reduction="mean",
         pad_token_id=processor.tokenizer.pad_token_id,
diff --git a/examples/research_projects/wav2vec2/run_pretrain.py b/examples/research_projects/wav2vec2/run_pretrain.py
index f0ef04d1814fe4..af354e24b03127 100755
--- a/examples/research_projects/wav2vec2/run_pretrain.py
+++ b/examples/research_projects/wav2vec2/run_pretrain.py
@@ -50,12 +50,6 @@ class ModelArguments:
     freeze_feature_extractor: Optional[bool] = field(
         default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
     )
-    gradient_checkpointing: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
-        },
-    )
     verbose_logging: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to log verbose messages or not."},
@@ -364,7 +358,7 @@ def normalize(batch):
     config = Wav2Vec2Config.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
-        gradient_checkpointing=model_args.gradient_checkpointing,
+        gradient_checkpointing=training_args.gradient_checkpointing,
     )
 
     if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
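
For context, a minimal sketch (not part of the patch) of how the relocated flag is consumed after this change: `gradient_checkpointing` is dropped from each script's `ModelArguments` dataclass and read from the shared `TrainingArguments` instead, so `--gradient_checkpointing` keeps working as a CLI flag. The checkpoint name and the bare-bones argument parsing below are illustrative; the actual scripts parse their own `ModelArguments`/`DataTrainingArguments` alongside `TrainingArguments` and use `model_args.model_name_or_path`.

```python
# Sketch, assuming a transformers version where TrainingArguments exposes
# `gradient_checkpointing` (as this patch relies on). Checkpoint name is illustrative.
from transformers import HfArgumentParser, TrainingArguments, Wav2Vec2ForCTC

# e.g. invoked as: python run_asr.py --output_dir out --gradient_checkpointing
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    # The flag now comes from TrainingArguments, not a script-local ModelArguments field.
    gradient_checkpointing=training_args.gradient_checkpointing,
)
```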