simplify inf. api
arashb committed Aug 4, 2022
1 parent 5c8604a commit 9e910fc
Showing 1 changed file with 23 additions and 40 deletions.
63 changes: 23 additions & 40 deletions deepspeed/ops/transformer/inference/transformer_inference.py
@@ -118,12 +118,8 @@ class DeepSpeedSelfAttentionFunction(Function):
     def forward(ctx,
                 input,
                 input_mask,
-                head_mask,
                 layer_past,
                 get_present,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                output_attentions,
                 norm_w,
                 norm_b,
                 config,
@@ -529,24 +525,16 @@ def __init__(self,
     def forward(self,
                 input,
                 input_mask,
-                head_mask=None,
                 layer_past=None,
                 get_present=False,
-                encoder_hidden_states=None,
-                encoder_attention_mask=None,
-                output_attentions=False,
                 norm_w=None,
                 norm_b=None,
                 alibi=None):
         output = DeepSpeedSelfAttentionFunction.apply(
             input,
             input_mask,
-            head_mask,
             layer_past,
             get_present,
-            encoder_hidden_states,
-            encoder_attention_mask,
-            output_attentions,
             norm_w,
             norm_b,
             self.config,
@@ -795,24 +783,23 @@ def __init__(self,
                              device=device))
         self.layer_past = None
 
-    def forward(self,
-                input,
-                input_mask=None,
-                attention_mask=None,
-                head_mask=None,
-                layer_past=None,
-                get_key_value=False,
-                get_present=False,
-                encoder_output=None,
-                enc_dec_attn_mask=None,
-                encoder_hidden_states=None,
-                encoder_attention_mask=None,
-                use_cache=False,
-                alibi=None,
-                output_attentions=False):
-        get_present = (get_present or get_key_value or use_cache)
-        input_mask = input_mask if attention_mask is None else attention_mask
-        layer_past = layer_past if layer_past is not None else self.layer_past
+    def forward(self, *inputs, **kwargs):
+        # import pdb; pdb.set_trace()
+        return self._forward(*inputs,
+                             input_mask=kwargs['input_mask'] if 'input_mask' in kwargs else kwargs['attention_mask'] if 'attention_mask' in kwargs else None,
+                             layer_past=kwargs['layer_past'] if 'layer_past' in kwargs else None,
+                             get_present=kwargs['get_present'] if 'get_present' in kwargs else kwargs['get_key_value'] if 'get_key_value' in kwargs else kwargs['use_cache'] if 'use_cache' in kwargs else False,
+                             alibi=kwargs['alibi'] if 'alibi' in kwargs else None)
+
+    def _forward(self,
+                 input,
+                 input_mask=None,
+                 layer_past=None,
+                 get_present=False,
+                 alibi=None):
+        # get_present = (get_present or get_key_value or use_cache)
+        # input_mask = input_mask if attention_mask is None else attention_mask
+        # layer_past = layer_past if layer_past is not None else self.layer_past
 
         attn_mask = None
         if isinstance(input, tuple):
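
Note on the hunk above: the chained conditional expressions in the new forward wrapper are dense. The routing they implement is equivalent to the standalone sketch below (illustrative only; the helper name _route_kwargs is hypothetical and not part of this commit):

# Illustrative equivalent of the kwargs routing in the new forward wrapper.
# Not from the commit; _route_kwargs is a hypothetical helper name.
def _route_kwargs(kwargs):
    input_mask = kwargs.get('input_mask', kwargs.get('attention_mask'))
    layer_past = kwargs.get('layer_past')
    # Any one of the three legacy flags requests the key/value cache.
    get_present = kwargs.get(
        'get_present',
        kwargs.get('get_key_value', kwargs.get('use_cache', False)))
    alibi = kwargs.get('alibi')
    return input_mask, layer_past, get_present, alibi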
@@ -827,16 +814,12 @@ def forward(self,
         with torch.no_grad():
             attention_output, key, value, context_outputtn_ctx, inp_norm = \
                                 self.attention(input,
-                                               input_mask,
-                                               head_mask,
-                                               layer_past,
-                                               get_present,
-                                               encoder_hidden_states,
-                                               encoder_attention_mask,
-                                               output_attentions,
-                                               self.norm_w,
-                                               self.norm_b,
-                                               alibi)
+                                               input_mask,
+                                               layer_past,
+                                               get_present,
+                                               self.norm_w,
+                                               self.norm_b,
+                                               alibi)
             presents = (key, value)
             self.layer_past = presents if layer_past is None else None
             output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
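
Because the wrapper accepts arbitrary keyword arguments, call sites written against the old Hugging Face-style signature keep working: attention_mask maps to input_mask, any of use_cache/get_key_value/get_present enables the key/value cache, and kwargs the new API no longer reads (head_mask, output_attentions, the encoder_* arguments) are silently dropped. A minimal call-site sketch follows; layer, hidden_states, and mask are hypothetical names, assuming layer is a constructed DeepSpeedTransformerInference module:

# Hypothetical call site, assuming layer is a DeepSpeedTransformerInference
# module and hidden_states/mask are torch tensors.
output = layer(hidden_states,
               attention_mask=mask,  # routed to _forward's input_mask
               use_cache=True,       # routed to _forward's get_present
               head_mask=None)       # no longer read; silently ignored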
