From 8bf3ca8b7a0b0077b9018aac77999b1ef2ec14ea Mon Sep 17 00:00:00 2001
From: fabiocapsouza
Date: Sun, 15 Nov 2020 12:30:46 -0300
Subject: [PATCH] Revert "PL: --adafactor option (#6776)"

This reverts commit 1179f8772118ebad692bbccf3435aa40d6c46456.
---
 examples/lightning_base.py                | 12 +-----------
 examples/seq2seq/test_seq2seq_examples.py |  1 -
 2 files changed, 1 insertion(+), 12 deletions(-)

diff --git a/examples/lightning_base.py b/examples/lightning_base.py
index 4c8d3649f9693a..d23757a9bcb894 100644
--- a/examples/lightning_base.py
+++ b/examples/lightning_base.py
@@ -22,7 +22,6 @@
     PreTrainedTokenizer,
 )
 from transformers.optimization import (
-    Adafactor,
     get_cosine_schedule_with_warmup,
     get_cosine_with_hard_restarts_schedule_with_warmup,
     get_linear_schedule_with_warmup,
@@ -138,15 +137,7 @@ def configure_optimizers(self):
                 "weight_decay": 0.0,
             },
         ]
-        if self.hparams.adafactor:
-            optimizer = Adafactor(
-                optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
-            )
-
-        else:
-            optimizer = AdamW(
-                optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
-            )
+        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
         self.opt = optimizer
 
         scheduler = self.get_lr_scheduler()
@@ -260,7 +251,6 @@ def add_model_specific_args(parser, root_dir):
         parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
         parser.add_argument("--train_batch_size", default=32, type=int)
         parser.add_argument("--eval_batch_size", default=32, type=int)
-        parser.add_argument("--adafactor", action="store_true")
 
 
 class LoggingCallback(pl.Callback):
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index f853557f189ccb..2f397c7adcba08 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -30,7 +30,6 @@
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
-    "adafactor": True,
     "early_stopping_patience": 2,
     "logger_name": "default",
     "length_penalty": 0.5,
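
For context, the reverted code selected between Adafactor and AdamW based on
the --adafactor flag. Below is a minimal standalone sketch of that selection
logic; build_optimizer and the Namespace-based hparams are illustrative names
rather than identifiers from this patch, and it assumes a 2020-era
transformers release where Adafactor lives in transformers.optimization and
AdamW is importable from the top-level package.

    from argparse import Namespace

    import torch
    from transformers import AdamW
    from transformers.optimization import Adafactor


    def build_optimizer(params, hparams):
        """Choose the optimizer the way the reverted configure_optimizers did."""
        if hparams.adafactor:
            # scale_parameter/relative_step are turned off so Adafactor uses
            # the externally supplied learning rate instead of its own schedule.
            return Adafactor(
                params, lr=hparams.learning_rate, scale_parameter=False, relative_step=False
            )
        return AdamW(params, lr=hparams.learning_rate, eps=hparams.adam_epsilon)


    # Exercising both branches with a toy model:
    model = torch.nn.Linear(4, 2)
    for use_adafactor in (True, False):
        hparams = Namespace(adafactor=use_adafactor, learning_rate=1e-3, adam_epsilon=1e-8)
        print(type(build_optimizer(model.parameters(), hparams)).__name__)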