diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py
index 7dabb2b084f4ba..262fae182f50ca 100644
--- a/examples/seq2seq/distillation.py
+++ b/examples/seq2seq/distillation.py
@@ -262,7 +262,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):
 
     mode = "translation"
     metric_names = ["bleu"]
-    default_val_metric = "bleu"
+    val_metric = "bleu"
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index cf47e32a4f2451..90591e1b0c0524 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer):
     mode = "summarization"
     loss_names = ["loss"]
     metric_names = ROUGE_KEYS
-    default_val_metric = "rouge2"
+    val_metric = "rouge2"
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
@@ -110,9 +110,6 @@ def __init__(self, hparams, **kwargs):
         self.dataset_class = (
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
-        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
-        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
-        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
 
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -304,8 +301,6 @@ def add_model_specific_args(parser, root_dir):
         parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
-        parser.add_argument("--eval_beams", type=int, default=None, required=False)
-        parser.add_argument("--val_metric", type=str, default=None, required=False)
         parser.add_argument(
             "--early_stopping_patience",
             type=int,
@@ -320,7 +315,7 @@ class TranslationModule(SummarizationModule):
     mode = "translation"
     loss_names = ["loss"]
     metric_names = ["bleu"]
-    default_val_metric = "bleu"
+    val_metric = "bleu"
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 410c3ee0a4c7e5..e7c795b7c5d57e 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -31,8 +31,6 @@
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
-    "eval_beams": 1,
-    "val_metric": None,
     "adafactor": True,
     "early_stopping_patience": 2,
     "logger_name": "default",
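For context, a minimal sketch of the behavioral change: before this diff, `val_metric` was resolved in `__init__` from `hparams`, with `default_val_metric` as the fallback; after it, `val_metric` is a fixed class attribute, and the `--val_metric` and `--eval_beams` flags (plus the `num_beams` fallback) are removed. The classes below are hypothetical stand-ins, not the actual `SummarizationModule`.

```python
from argparse import Namespace

class Before:
    default_val_metric = "rouge2"  # fallback, overridable via --val_metric

    def __init__(self, hparams):
        # Mirrors the resolution line removed from finetune.py's __init__ above
        self.val_metric = self.default_val_metric if hparams.val_metric is None else hparams.val_metric

class After:
    val_metric = "rouge2"  # fixed class attribute; no CLI flag, no per-run override

print(Before(Namespace(val_metric=None)).val_metric)    # rouge2
print(Before(Namespace(val_metric="bleu")).val_metric)  # bleu
print(After.val_metric)                                 # rouge2
```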