Commit

[ProphetNet] Bart-like Refactor (#10501)
* first step to refactor

* make all fast tests pass

* make all slow tests pass

* save intermediate

* correct cache

* finish PR

* make fp16 work
patrickvonplaten committed Mar 4, 2021
1 parent 6290169 commit c503a1c
Showing 3 changed files with 246 additions and 188 deletions.
@@ -92,6 +92,8 @@ class ProphetNetConfig(PretrainedConfig):
             smoothing is performed.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
     """
     model_type = "prophetnet"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -119,6 +121,7 @@ def __init__(
         num_buckets=32,
         relative_max_distance=128,
         disable_ngram_loss=False,
+        gradient_checkpointing=False,
         eps=0.0,
         use_cache=True,
         pad_token_id=0,
@@ -161,6 +164,9 @@ def __init__(
 
         self.use_cache = use_cache
 
+        # 4 Training Args (should be removed soon)
+        self.gradient_checkpointing = gradient_checkpointing
+
     @property
     def num_attention_heads(self) -> int:
         return self.num_encoder_attention_heads
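For reference, a minimal usage sketch of the configuration flag added in this commit. It assumes a transformers version that includes this change; ProphetNetForConditionalGeneration is an existing model class but is not part of the diff shown above, and the in-code comment already marks the flag as a temporary training argument:

    from transformers import ProphetNetConfig, ProphetNetForConditionalGeneration

    # Enable gradient checkpointing via the new config flag: activations are
    # recomputed during the backward pass to save memory, at the cost of a
    # slower backward pass.
    config = ProphetNetConfig(gradient_checkpointing=True)
    model = ProphetNetForConditionalGeneration(config)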
