Skip to content

Commit

Permalink
Fix the inference API (temporary solution)
Browse files (browse the repository at this point in the history)
  • Loading branch information
arashb committed Aug 9, 2022
1 parent f8aef93 commit e7fecac
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion deepspeed/ops/transformer/inference/transformer_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,11 @@ def forward(self,
encoder_attention_mask=None,
use_cache=False,
alibi=None,
output_attentions=False):
output_attentions=False,
# TODO(arashb): 'layer_head_mask' and 'past_key_value' were added only to satisfy the OPT models' API.
# This needs to be redesigned later!
layer_head_mask=None,
past_key_value=None):
get_present = (get_present or get_key_value or use_cache)
input_mask = input_mask if attention_mask is None else attention_mask

Expand All @@ -823,6 +827,7 @@ def forward(self,
self.layer_past = None

layer_past = layer_past if layer_past is not None else self.layer_past
head_mask = layer_head_mask if layer_head_mask is not None else head_mask

attn_mask = None
if isinstance(input, tuple):
Expand Down

0 comments on commit e7fecac

Please sign in to comment.