Commit

🐛 add back commits that were overwritten
nateraw committed Jul 17, 2020
1 parent 01d31b6 commit d0c47a9
Showing 2 changed files with 24 additions and 5 deletions.
25 changes: 22 additions & 3 deletions examples/seq2seq/finetune.py
@@ -14,11 +14,12 @@
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import get_linear_schedule_with_warmup
+from transformers import MBartTokenizer, get_linear_schedule_with_warmup
 
 
 try:
     from .utils import (
+        assert_all_frozen,
         use_task_specific_params,
         SummarizationDataset,
         lmap,
@@ -47,6 +48,7 @@
         get_git_info,
         ROUGE_KEYS,
         calculate_bleu_score,
+        assert_all_frozen,
     )
     from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
 
@@ -92,9 +94,12 @@ def __init__(self, hparams, **kwargs):
         if self.hparams.freeze_embeds:
             self.freeze_embeds()
         if self.hparams.freeze_encoder:
-            freeze_params(self.model.model.encoder)  # TODO: this will break for t5
+            freeze_params(self.model.get_encoder())
+            assert_all_frozen(self.model.get_encoder())
+
         self.hparams.git_sha = get_git_info()["repo_sha"]
         self.num_workers = hparams.num_workers
+        self.decoder_start_token_id = None
 
     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -160,7 +165,12 @@ def _generative_step(self, batch: dict) -> dict:
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
-        generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
+        generated_ids = self.model.generate(
+            input_ids=source_ids,
+            attention_mask=source_mask,
+            use_cache=True,
+            decoder_start_token_id=self.decoder_start_token_id,
+        )
         gen_time = (time.time() - t0) / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
@@ -276,6 +286,8 @@ def add_model_specific_args(parser, root_dir):
         parser.add_argument(
             "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
         )
+        parser.add_argument("--src_lang", type=str, default="", required=False)
+        parser.add_argument("--tgt_lang", type=str, default="", required=False)
         return parser
 
 
@@ -285,6 +297,13 @@ class TranslationModule(SummarizationModule):
     metric_names = ["bleu"]
     val_metric = "bleu"
 
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        self.dataset_kwargs["src_lang"] = hparams.src_lang
+        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
+        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
+            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+
     def calc_generative_metrics(self, preds, target) -> dict:
         return calculate_bleu_score(preds, target)
 
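Taken together, the finetune.py changes restore mBART translation support: TranslationModule now derives the decoder's start token from MBartTokenizer whenever the model config leaves decoder_start_token_id unset, and _generative_step forwards that id to generate() (None keeps the model's own default, which is why SummarizationModule initializes it to None). A standalone sketch of that lookup; the facebook/mbart-large-cc25 checkpoint, the BartForConditionalGeneration class, the example sentence, and the "ro_RO" target code are illustrative assumptions, not part of the commit:

from transformers import BartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = BartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

# Mirrors TranslationModule.__init__: if the checkpoint's config does not fix a
# decoder start token, fall back to the tokenizer's target-language code.
decoder_start_token_id = None
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
    decoder_start_token_id = tokenizer.lang_code_to_id["ro_RO"]

# Mirrors _generative_step: the resolved id is passed straight to generate().
batch = tokenizer(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")
generated_ids = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    use_cache=True,
    decoder_start_token_id=decoder_start_token_id,
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
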
4 changes: 2 additions & 2 deletions examples/seq2seq/finetune.sh
@@ -5,9 +5,9 @@ export PYTHONPATH="../":"${PYTHONPATH}"
 python finetune.py \
 --learning_rate=3e-5 \
 --fp16 \
---gpus 2 \
+--gpus 1 \
 --do_train \
 --do_predict \
 --n_val 1000 \
 --val_check_interval 0.1 \
-$@
+$@
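The new --src_lang/--tgt_lang arguments reach finetune.py through the script's trailing $@, so an en→ro mBART run could look like the following (the checkpoint, data/output paths, and the other flags are placeholders for illustration, not dictated by this commit):

./finetune.sh \
    --model_name_or_path facebook/mbart-large-cc25 \
    --data_dir $ENRO_DIR \
    --output_dir mbart_en_ro_ft \
    --task translation \
    --src_lang en_XX \
    --tgt_lang ro_RO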
