diff --git a/scripts/run_llama.py b/scripts/run_llama.py
index 5664c97..14bfc78 100644
--- a/scripts/run_llama.py
+++ b/scripts/run_llama.py
@@ -82,15 +82,14 @@ def timing_cuda(
     inputs: Dict,
     max_new_tokens: int,
     device: torch.device,
-    cache_length: int,
-    preallocate: bool,
     do_profile: bool,
 ):
     warmup_start_event = torch.cuda.Event(enable_timing=True)
     warmup_end_event = torch.cuda.Event(enable_timing=True)

-    if preallocate:
-        inputs["cache_length"] = cache_length
+    if do_profile:
+        num_runs = PROFILE_NUM_RUNS
+        max_new_tokens = PROFILE_NEW_TOKENS

     with torch.no_grad():
         print(f"Warming up ({WARMUP_RUNS} runs)...")
@@ -255,9 +254,7 @@ def timing_cuda(
             inputs=inp,
             device=device,
             max_new_tokens=max_new_tokens,
-            cache_length=cache_length,
             generate_method=generate_method,
-            preallocate=args.preallocate,
             do_profile=args.profile,
         )
     except:
diff --git a/src/trfs_fast/generation.py b/src/trfs_fast/generation.py
index a02c87a..a1e130e 100644
--- a/src/trfs_fast/generation.py
+++ b/src/trfs_fast/generation.py
@@ -38,7 +38,6 @@ def generate_minimal(
     max_new_tokens: int,
     inputs: Optional[torch.Tensor] = None,
     streamer: Optional["BaseStreamer"] = None,
-    cache_length: Optional[int] = None,
     **model_kwargs
 ) -> torch.LongTensor:
     r"""
@@ -113,13 +112,6 @@ def generate_minimal(
     if streamer is not None:
         streamer.put(input_ids.cpu())

-    batch_size, context_length = input_ids.shape
-    cache_length = cache_length if cache_length is not None else max_new_tokens
-
-    model_kwargs["valid_past_index"] = torch.tensor(0, dtype=torch.int64)
-    model_kwargs["past_key_values"] = self.get_empty_kv_cache(batch_size=batch_size, cache_length=cache_length, device=input_ids.device, dtype=self.dtype)
-    model_kwargs["attention_mask"] = self.get_preallocated_attention_mask(attention_mask=model_kwargs["attention_mask"], batch_size=batch_size, cache_length=cache_length, device=input_ids.device, context_length=context_length)
-
     # 11. run greedy search
     return self.greedy_search_minimal(
         input_ids,
@@ -180,8 +172,9 @@ def greedy_search_minimal(
     # keep track of which sequences are already finished
     unfinished_sequences = torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)

+    max_length = input_ids.shape[-1] + max_new_tokens
+    min_length = input_ids.shape[-1] + min_new_tokens

-    counter = 0
     result = input_ids
     while True:
         # prepare model inputs
@@ -194,7 +187,6 @@ def greedy_search_minimal(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
         )
-        counter += 1

         # argmax
         next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
@@ -206,14 +198,11 @@ def greedy_search_minimal(
             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

         input_ids = next_tokens
-
-        # update generated ids, model inputs, and length for next step
         result = torch.cat([result, next_tokens], dim=-1)
+        cur_len = result.shape[-1]
+
         if streamer is not None:
             streamer.put(next_tokens.cpu())
-        model_kwargs = self.__update_model_kwargs_for_generation(
-            outputs, model_kwargs, model_inputs
-        )

         # TODO: not sure this is correct anymore with the keepdim=True
         # if eos_token was found in one sentence, set sentence to finished
@@ -223,13 +212,18 @@ def greedy_search_minimal(
             )

         # stop when each sentence is finished
-        if unfinished_sequences.max() == 0 and counter >= min_new_tokens:
+        if unfinished_sequences.max() == 0 and cur_len >= min_length:
             break

         # stop if we exceed the maximum length
-        if counter >= max_new_tokens:
+        if cur_len >= max_length:
             break

+        # Update tensors for the next iteration
+        model_kwargs = self.__update_model_kwargs_for_generation(
+            outputs, model_kwargs, model_inputs, max_length, cur_len
+        )
+
     if streamer is not None:
         streamer.end()

@@ -239,26 +233,43 @@ def __update_model_kwargs_for_generation(
     self,
     outputs: ModelOutput,
     model_kwargs: Dict[str, Any],
-    model_inputs: Dict[str, Any]
+    model_inputs: Dict[str, Any],
+    max_length: int,
+    cur_len: int
 ) -> Dict[str, Any]:
-    model_kwargs["valid_past_index"] += outputs.logits.shape[1]
-
-    if getattr(outputs, "state", None) is not None:
-        model_kwargs["state"] = outputs.state
-
     # update attention mask
-    """
-    if "attention_mask" in model_kwargs:
-        attention_mask = model_kwargs["attention_mask"]
-        model_kwargs["attention_mask"] = torch.cat(
-            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
-        )
-    """
-    position_ids = model_inputs["position_ids"]
-    if position_ids.shape[1] > 1:
-        model_kwargs["position_ids"] = position_ids[:, -1:] + 1
+    batch_size, _, _ = outputs.logits.shape
+    device = outputs.logits.device
+
+    # Create and fill fixed-size tensors. This only occurs once, after the prefilling step. Note: at each
+    # generation step, in the attention layer, the KV values will be placed in the last position of
+    # `past_key_values` -- for that reason, the attention mask must always hold a 1 in the last position.
+    if "past_key_values" in model_inputs and model_inputs["past_key_values"] is None:
+        # create tensors for the maximum size
+        padded_attention = torch.zeros(batch_size, max_length, dtype=torch.int64, device=device)
+        padded_past_key_values = get_empty_kv_cache(config=self.config, batch_size=batch_size, max_length=max_length, device=device, dtype=self.dtype)
+        # fill with the existing values
+        padded_attention[:, :cur_len - 1] = model_inputs["attention_mask"]
+        padded_attention[:, -1] = 1  # the token being generated is appended in the last position
+        for i in range(len(padded_past_key_values)):
+            padded_past_key_values[i][..., :cur_len - 1, :] = outputs.past_key_values[i]
+        # set them to the variable expected by `generate`
+        model_kwargs["attention_mask"] = padded_attention
+        model_kwargs["past_key_values"] = padded_past_key_values
+
+        # also update the position ids, from the previous position ids
+        model_kwargs["position_ids"] = model_inputs["position_ids"][:, -1] + 1
     else:
-        model_kwargs["position_ids"] = position_ids + 1
+        # Position ids update: simply add one
+        model_kwargs["position_ids"] += 1
+        # Attention mask update: add a one in the position to backfill (corresponding to the token that was just
+        # selected)
+        backfill_pos = cur_len - 2
+        model_kwargs["attention_mask"][:, backfill_pos] = 1
+        # past_key_values update: move the cache appended in the last position to its permanent position
+        for i in range(len(model_kwargs["past_key_values"])):
+            model_kwargs["past_key_values"][i][..., backfill_pos, :] = outputs.past_key_values[i][..., -1, :]
+            model_kwargs["past_key_values"][i][..., -1, :] *= 0  # see inductor bug mentioned in the attn layer

     # NOTE: token_type_ids is not used by llama so we don't care about this one for now
     # update token_type_ids with last value
@@ -267,3 +278,19 @@
         model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

     return model_kwargs
+
+
+def get_empty_kv_cache(config, batch_size: int, max_length: int, dtype: torch.dtype, device: torch.device):
+    past_key_values = [
+        torch.zeros(
+            2,
+            batch_size,
+            config.num_attention_heads,
+            max_length,
+            config.hidden_size // config.num_attention_heads,  # head dimension
+            dtype=dtype,
+            device=device
+        )
+        for _ in range(config.num_hidden_layers)
+    ]
+    return past_key_values
diff --git a/src/trfs_fast/llama.py b/src/trfs_fast/llama.py
index 551d181..3e87433 100644
--- a/src/trfs_fast/llama.py
+++ b/src/trfs_fast/llama.py
@@ -194,7 +194,6 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
-        valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions is True:
             raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -208,41 +207,29 @@ def forward(
         key_value_states = query_key_value_states[1:]

         # key_value_states used only for dtype here
-        cos, sin = self.rotary_emb(key_value_states, seq_len=valid_past_index + q_len)
+        cos, sin = self.rotary_emb(key_value_states)
         query_states = apply_rotary_pos_emb_opt(query_states, key_value_states[0], cos, sin, position_ids)

-        # slice end is equivalent to "if valid_past_index > 0: = valid_past_index + 1; else: = q_len"
-        past_kv_slice_start = valid_past_index
-        past_kv_slice_end = torch.eq(valid_past_index, 0).int() * q_len + torch.ne(valid_past_index, 0).int() * (valid_past_index + 1)
-        past_state_slice_end = torch.eq(valid_past_index, 0).int() * key_value_states.shape[-2] + torch.ne(valid_past_index, 0).int() * (past_kv_slice_end)
-        past_key_value[..., past_kv_slice_start:past_kv_slice_end, :] = key_value_states
-        key_states, value_states = past_key_value[..., :past_state_slice_end, :]
-
-        if bsz == 1 or self.training:
-            # BEWARE: at this stage, attention_mask is not the same as in transformers llama
-            if query_states.shape[2] > 1:
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True
-                )
-            else:
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
+        # the key/value states are placed in the last position of the past_key_value tensor (if it is not prefilling)
+        if past_key_value is None:
+            past_key_value = key_value_states
         else:
-            # This line is necessary for numerical equivalence, although I'm not sure it is useful in any way.
-            attention_mask = torch.max(attention_mask, self.min_allowed)
-
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-            )
+            # past_key_value[..., -1:, :] = key_value_states -> causes a bug in the inductor, replaced by += and zero setting outside the generation loop
+            past_key_value[..., -1:, :] += key_value_states
+        key_states, value_states = past_key_value
+
+        # This line is necessary for numerical equivalence, although I'm not sure it is useful in any way.
+        attention_mask = torch.max(attention_mask, self.min_allowed)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )

         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)

-        # TODO (felix) returning past_key_value with static cache is probably useless?
-        return attn_output, None, None
+        return attn_output, None, past_key_value


 class LlamaDecoderLayer(nn.Module):
@@ -265,7 +252,6 @@ def forward(
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
-        valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -289,7 +275,6 @@ def forward(
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
-            valid_past_index=valid_past_index,
         )

         hidden_states = residual + hidden_states
@@ -485,7 +470,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -504,7 +488,11 @@ def forward(
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

         seq_length_with_past = seq_length
-        past_key_values_length = valid_past_index
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -523,11 +511,9 @@ def forward(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
             )

-        # As we use SDPA, we simply don't care about the attention mask in the batch size = 1 case
-        if batch_size > 1:
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )

         hidden_states = inputs_embeds

@@ -565,7 +551,6 @@ def custom_forward(*inputs):
                     position_ids=position_ids,
                     past_key_value=past_key_value,
                     output_attentions=output_attentions,
-                    valid_past_index=valid_past_index,
                 )

             hidden_states = layer_outputs[0]
@@ -622,25 +607,6 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.model

-    def get_empty_kv_cache(self, batch_size: int, cache_length: int, dtype: torch.dtype, device: torch.device):
-        past_key_values = [torch.empty(
-            2,
-            batch_size,
-            self.config.num_attention_heads,
-            cache_length,
-            self.config.hidden_size // self.config.num_attention_heads,  # head dimension
-            dtype=dtype,
-            device=device
-        )
-            for _ in range(self.config.num_hidden_layers)]
-        return past_key_values
-
-    def get_preallocated_attention_mask(self, attention_mask: torch.Tensor, batch_size: int, cache_length: int, device: torch.device, context_length: int):
-        attention_mask_buffer = torch.ones(batch_size, cache_length, dtype=torch.int64, device=device)
-        attention_mask_buffer[:, :context_length] = attention_mask
-
-        return attention_mask_buffer
-
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -654,7 +620,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -698,7 +663,6 @@ def forward(
             output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
-            valid_past_index=valid_past_index,
         )

         hidden_states = outputs[0]
@@ -732,13 +696,12 @@ def forward(
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        valid_past_index = kwargs.get("valid_past_index", torch.tensor(0, dtype=torch.int64))
-        if valid_past_index > 0:
+        if past_key_values is not None:
             input_ids = input_ids[:, -1:]

+        # 1st iteration: create the position ids; subsequent iterations: use them directly (to avoid a cumsum op)
         position_ids = kwargs.get("position_ids", None)
-        # create position_ids
         if position_ids is None:
             attention_mask_slice = attention_mask[:, :input_ids.shape[1]]
             position_ids = attention_mask_slice.long().cumsum(-1) - 1
@@ -755,7 +718,6 @@ def prepare_inputs_for_generation(
                 "position_ids": position_ids,
                 "past_key_values": past_key_values,
                 "attention_mask": attention_mask,
-                "valid_past_index": valid_past_index,
             }
         )
         return model_inputs
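
For readers skimming the hunks, here is a minimal standalone sketch (not part of the repository; all sizes are made up) of the cache layout that the new `get_empty_kv_cache` helper allocates: one zero-filled tensor per decoder layer, shaped [2, batch, num_heads, max_length, head_dim], which is also where `LlamaModel.forward` now reads `past_key_values_length` from.

import torch

num_layers, batch, num_heads, max_length, head_dim = 2, 1, 4, 8, 16

# One zero-filled tensor per decoder layer, mirroring `get_empty_kv_cache`.
past_key_values = [
    torch.zeros(2, batch, num_heads, max_length, head_dim) for _ in range(num_layers)
]

key0, value0 = past_key_values[0]        # unpack key/value states for layer 0
print(key0.shape)                        # torch.Size([1, 4, 8, 16])
print(past_key_values[0][0].shape[2])    # 8 == max_length, the `past_key_values_length` used in LlamaModel.forward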
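
The per-step bookkeeping the patch introduces can be hard to follow across files: the attention layer writes the in-flight token's key/value into the last slot of the fixed-size cache (via `+=`, because of the inductor issue noted in the comments), and `__update_model_kwargs_for_generation` later moves that slot to its permanent position, re-zeroes it, and backfills the attention mask. The sketch below is a standalone toy imitation of that mechanism on a single layer's cache, with made-up shapes; it is not code from the repository.

import torch

# Illustrative sizes only.
batch, n_heads, head_dim = 1, 2, 4
prompt_len, max_new_tokens = 3, 3
max_length = prompt_len + max_new_tokens

# One layer's static cache: [key/value, batch, heads, max_length, head_dim], zero-initialised.
past_key_value = torch.zeros(2, batch, n_heads, max_length, head_dim)

# Mask over the full cache: 1 for the prompt, 1 in the last slot (where the
# in-flight token's KV is written each step), 0 for not-yet-filled positions.
attention_mask = torch.zeros(batch, max_length, dtype=torch.int64)
attention_mask[:, :prompt_len] = 1
attention_mask[:, -1] = 1

# Pretend the prefill already cached the prompt's KV in positions [0, prompt_len).
past_key_value[..., :prompt_len, :] = torch.randn(2, batch, n_heads, prompt_len, head_dim)

cur_len = prompt_len + 1  # prompt plus the token produced by the prefill step
while True:
    # Attention layer: the in-flight token's KV goes into the last slot, via `+=`.
    new_kv = torch.randn(2, batch, n_heads, 1, head_dim)
    past_key_value[..., -1:, :] += new_kv
    key_states, value_states = past_key_value  # both span all max_length positions

    cur_len += 1                # one more token has been selected
    if cur_len >= max_length:   # same stop condition as `greedy_search_minimal`
        break

    # Decode-branch update: move the last slot to its permanent position,
    # re-zero it, and backfill the attention mask.
    backfill_pos = cur_len - 2
    past_key_value[..., backfill_pos, :] = past_key_value[..., -1, :]
    past_key_value[..., -1, :] *= 0
    attention_mask[:, backfill_pos] = 1

# Prompt and backfilled positions are 1; the most recent token's KV still sits in
# the last slot, which is why the last mask position always stays 1.
print(attention_mask)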