Skip to content

Commit

Permalink
Revert "[s2s] command line args for faster val steps (huggingface#6833)"
Browse files Browse the repository at this point in the history
This reverts commit 07d0a53.
  • Loading branch information
fabiocapsouza committed Nov 15, 2020
1 parent 2805245 commit e5abdf8
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/seq2seq/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class BartTranslationDistiller(BartSummarizationDistiller):

mode = "translation"
metric_names = ["bleu"]
default_val_metric = "bleu"
val_metric = "bleu"

def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
Expand Down
9 changes: 2 additions & 7 deletions examples/seq2seq/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SummarizationModule(BaseTransformer):
mode = "summarization"
loss_names = ["loss"]
metric_names = ROUGE_KEYS
default_val_metric = "rouge2"
val_metric = "rouge2"

def __init__(self, hparams, **kwargs):
super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
Expand Down Expand Up @@ -110,9 +110,6 @@ def __init__(self, hparams, **kwargs):
self.dataset_class = (
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
)
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
Expand Down Expand Up @@ -304,8 +301,6 @@ def add_model_specific_args(parser, root_dir):
parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
parser.add_argument("--src_lang", type=str, default="", required=False)
parser.add_argument("--tgt_lang", type=str, default="", required=False)
parser.add_argument("--eval_beams", type=int, default=None, required=False)
parser.add_argument("--val_metric", type=str, default=None, required=False)
parser.add_argument(
"--early_stopping_patience",
type=int,
Expand All @@ -320,7 +315,7 @@ class TranslationModule(SummarizationModule):
mode = "translation"
loss_names = ["loss"]
metric_names = ["bleu"]
default_val_metric = "bleu"
val_metric = "bleu"

def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
Expand Down
2 changes: 0 additions & 2 deletions examples/seq2seq/test_seq2seq_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
CUDA_AVAILABLE = torch.cuda.is_available()
CHEAP_ARGS = {
"label_smoothing": 0.2,
"eval_beams": 1,
"val_metric": None,
"adafactor": True,
"early_stopping_patience": 2,
"logger_name": "default",
Expand Down

0 comments on commit e5abdf8

Please sign in to comment.