9 changes: 3 additions & 6 deletions scripts/run_llama.py
@@ -82,15 +82,14 @@ def timing_cuda(
inputs: Dict,
max_new_tokens: int,
device: torch.device,
cache_length: int,
preallocate: bool,
do_profile: bool,
):
warmup_start_event = torch.cuda.Event(enable_timing=True)
warmup_end_event = torch.cuda.Event(enable_timing=True)

if preallocate:
inputs["cache_length"] = cache_length
if do_profile:
num_runs = PROFILE_NUM_RUNS
max_new_tokens = PROFILE_NEW_TOKENS

with torch.no_grad():
print(f"Warming up ({WARMUP_RUNS} runs)...")
@@ -255,9 +254,7 @@ def timing_cuda(
inputs=inp,
device=device,
max_new_tokens=max_new_tokens,
cache_length=cache_length,
generate_method=generate_method,
preallocate=args.preallocate,
do_profile=args.profile,
)
except:
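For reference, a minimal sketch of CUDA-event timing with a warmup phase and an optional profiling mode, in the spirit of the reworked timing_cuda above. The constants mirror WARMUP_RUNS and PROFILE_NUM_RUNS from the script (values assumed here), and run_generation is a hypothetical callable standing in for the benchmarked generate call:

import torch

WARMUP_RUNS = 2        # assumed values; the real constants are defined in run_llama.py
PROFILE_NUM_RUNS = 1

def time_generation(run_generation, num_runs: int, do_profile: bool) -> float:
    # Warm up so kernel compilation and allocator effects are excluded from the timing.
    for _ in range(WARMUP_RUNS):
        run_generation()
    if do_profile:
        num_runs = PROFILE_NUM_RUNS  # profiling only needs a few short runs
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(num_runs):
        run_generation()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_runs  # average milliseconds per run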
95 changes: 61 additions & 34 deletions src/trfs_fast/generation.py
@@ -38,7 +38,6 @@ def generate_minimal(
max_new_tokens: int,
inputs: Optional[torch.Tensor] = None,
streamer: Optional["BaseStreamer"] = None,
cache_length: Optional[int] = None,
**model_kwargs
) -> torch.LongTensor:
r"""
@@ -113,13 +112,6 @@ def generate_minimal(
if streamer is not None:
streamer.put(input_ids.cpu())

batch_size, context_length = input_ids.shape
cache_length = cache_length if cache_length is not None else max_new_tokens

model_kwargs["valid_past_index"] = torch.tensor(0, dtype=torch.int64)
model_kwargs["past_key_values"] = self.get_empty_kv_cache(batch_size=batch_size, cache_length=cache_length, device=input_ids.device, dtype=self.dtype)
model_kwargs["attention_mask"] = self.get_preallocated_attention_mask(attention_mask=model_kwargs["attention_mask"], batch_size=batch_size, cache_length=cache_length, device=input_ids.device, context_length=context_length)

# 11. run greedy search
return self.greedy_search_minimal(
input_ids,
@@ -180,8 +172,9 @@ def greedy_search_minimal(

# keep track of which sequences are already finished
unfinished_sequences = torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
max_length = input_ids.shape[-1] + max_new_tokens
min_length = input_ids.shape[-1] + min_new_tokens

counter = 0
result = input_ids
while True:
# prepare model inputs
@@ -194,7 +187,6 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
counter += 1

# argmax
next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
@@ -206,14 +198,11 @@
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

input_ids = next_tokens

# update generated ids, model inputs, and length for next step
result = torch.cat([result, next_tokens], dim=-1)
cur_len = result.shape[-1]

if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self.__update_model_kwargs_for_generation(
outputs, model_kwargs, model_inputs
)

# TODO: not sure this is correct anymore with the keepdim=True
# if eos_token was found in one sentence, set sentence to finished
@@ -223,13 +212,18 @@
)

# stop when each sentence is finished
if unfinished_sequences.max() == 0 and counter >= min_new_tokens:
if unfinished_sequences.max() == 0 and cur_len >= min_length:
break

# stop if we exceed the maximum length
if counter >= max_new_tokens:
if cur_len >= max_length:
break

# Update tensors for the next iteration
model_kwargs = self.__update_model_kwargs_for_generation(
outputs, model_kwargs, model_inputs, max_length, cur_len
)

if streamer is not None:
streamer.end()
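The loop now stops on absolute sequence length (cur_len vs. max_length/min_length) instead of a step counter. A standalone toy sketch of that stopping logic, with illustrative names not taken from the PR:

import torch

prompt = torch.tensor([[1, 2, 3]])                      # (batch, prompt_len)
max_new_tokens, min_new_tokens = 5, 1
max_length = prompt.shape[-1] + max_new_tokens
min_length = prompt.shape[-1] + min_new_tokens

result = prompt
unfinished = torch.ones((1, 1), dtype=torch.long)
while True:
    next_tokens = torch.randint(0, 10, (1, 1))          # stand-in for argmax over the logits
    result = torch.cat([result, next_tokens], dim=-1)
    cur_len = result.shape[-1]
    if unfinished.max() == 0 and cur_len >= min_length:  # all sequences finished, min length reached
        break
    if cur_len >= max_length:                            # absolute length bound, not a counter
        break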

@@ -239,26 +233,43 @@ def __update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
model_inputs: Dict[str, Any]
model_inputs: Dict[str, Any],
max_length: int,
cur_len: int
) -> Dict[str, Any]:
model_kwargs["valid_past_index"] += outputs.logits.shape[1]

if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state

# update attention mask
"""
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
"""
position_ids = model_inputs["position_ids"]
if position_ids.shape[1] > 1:
model_kwargs["position_ids"] = position_ids[:, -1:] + 1
batch_size, _, _ = outputs.logits.shape
device = outputs.logits.device

# Create and fill fixed-sized tensors. This only occurs once, after the prefilling step. Note: at each
# generation step, in the attention layer, the KV values will be placed in the last position of
# `past_key_values` -- for that reason, the attention mask must always hold a 1 in the last position.
if "past_key_values" in model_inputs and model_inputs["past_key_values"] is None:
# create tensors for the maximum size
padded_attention = torch.zeros(batch_size, max_length, dtype=torch.int64, device=device)
padded_past_key_values = get_empty_kv_cache(config=self.config, batch_size=batch_size, max_length=max_length, device=device, dtype=self.dtype)
# fill with the existing values
padded_attention[:, :cur_len - 1] = model_inputs["attention_mask"]
padded_attention[:, -1] = 1 # the token being generated is appended in the last position
for i in range(len(padded_past_key_values)):
padded_past_key_values[i][..., :cur_len - 1, :] = outputs.past_key_values[i]
# set them to the variable expected by `generate`
model_kwargs["attention_mask"] = padded_attention
model_kwargs["past_key_values"] = padded_past_key_values

# also update the position ids, based on the previous position ids
model_kwargs["position_ids"] = model_inputs["position_ids"][:, -1] + 1
else:
model_kwargs["position_ids"] = position_ids + 1
# Position ids update: simply add one
model_kwargs["position_ids"] += 1
# Attention mask update: add a one in the position to backfill (corresponding to the token that was just
# selected)
backfill_pos = cur_len - 2
model_kwargs["attention_mask"][:, backfill_pos] = 1
# past_key_values update: move the KV entry appended at the last position to its permanent position
for i in range(len(model_kwargs["past_key_values"])):
model_kwargs["past_key_values"][i][..., backfill_pos, :] = outputs.past_key_values[i][..., -1, :]
model_kwargs["past_key_values"][i][..., -1, :] *= 0 # see inductor bug mentioned in the attn layer

# NOTE: token_type_ids is not used by llama so we don't care about this one for now
# update token_type_ids with last value
@@ -267,3 +278,19 @@ def __update_model_kwargs_for_generation(
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

return model_kwargs


def get_empty_kv_cache(config, batch_size: int, max_length: int, dtype: torch.dtype, device: torch.device):
past_key_values = [
torch.zeros(
2,
batch_size,
config.num_attention_heads,
max_length,
config.hidden_size // config.num_attention_heads, # head dimension
dtype=dtype,
device=device
)
for _ in range(config.num_hidden_layers)
]
return past_key_values
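To make the cache layout concrete, below is a self-contained sketch of the two update modes implemented in __update_model_kwargs_for_generation: the one-time fill right after prefill, and the per-step relocation of the scratch (last) slot into its permanent position. The sizes are toy numbers; the buffer construction mirrors get_empty_kv_cache above:

import torch

num_layers, num_heads, hidden_size = 2, 4, 32
batch_size, prompt_len, max_length = 1, 5, 9
head_dim = hidden_size // num_heads

# One (2, batch, heads, max_length, head_dim) buffer per layer; index 0 holds keys, index 1 values.
cache = [
    torch.zeros(2, batch_size, num_heads, max_length, head_dim)
    for _ in range(num_layers)
]

# After prefill: copy the prompt's KV into the first `prompt_len` slots of every layer.
prefill_kv = torch.randn(2, batch_size, num_heads, prompt_len, head_dim)
for layer_cache in cache:
    layer_cache[..., :prompt_len, :] = prefill_kv

# Each decode step: the attention layer has written the new token's KV into the last slot;
# move it to its permanent position and clear the scratch slot for the next step.
backfill_pos = prompt_len                               # position of the token generated this step
for layer_cache in cache:
    layer_cache[..., backfill_pos, :] = layer_cache[..., -1, :]
    layer_cache[..., -1, :] *= 0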
86 changes: 24 additions & 62 deletions src/trfs_fast/llama.py
@@ -194,7 +194,6 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
@@ -208,41 +207,29 @@
key_value_states = query_key_value_states[1:]

# key_value_states used only for dtype here
cos, sin = self.rotary_emb(key_value_states, seq_len=valid_past_index + q_len)
cos, sin = self.rotary_emb(key_value_states)
query_states = apply_rotary_pos_emb_opt(query_states, key_value_states[0], cos, sin, position_ids)

# slice end is equivalent to "if valid_past_index > 0: = valid_past_index + 1; else: = q_len"
past_kv_slice_start = valid_past_index
past_kv_slice_end = torch.eq(valid_past_index, 0).int() * q_len + torch.ne(valid_past_index, 0).int() * (valid_past_index + 1)
past_state_slice_end = torch.eq(valid_past_index, 0).int() * key_value_states.shape[-2] + torch.ne(valid_past_index, 0).int() * (past_kv_slice_end)
past_key_value[..., past_kv_slice_start:past_kv_slice_end, :] = key_value_states
key_states, value_states = past_key_value[..., :past_state_slice_end, :]

if bsz == 1 or self.training:
# BEWARE: at this stage, attention_mask is not the same as in transformers llama
if query_states.shape[2] > 1:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=True
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=None, dropout_p=0.0, is_causal=False
)
# the key/value states are placed in the last position of the past_key_value tensor (unless this is the prefill step)
if past_key_value is None:
past_key_value = key_value_states
else:
# This line is necessary for numerical equivalence, although I'm not sure it is useful in any way.
attention_mask = torch.max(attention_mask, self.min_allowed)

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# past_key_value[..., -1:, :] = key_value_states -> causes a bug in the inductor, replaced by += and zero setting outside the generation loop
past_key_value[..., -1:, :] += key_value_states
key_states, value_states = past_key_value

# This line is necessary for numerical equivalence, although I'm not sure it is useful in any way.
attention_mask = torch.max(attention_mask, self.min_allowed)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

attn_output = self.o_proj(attn_output)

# TODO (felix) returning past_key_value with static cache is probably useless?
return attn_output, None, None
return attn_output, None, past_key_value


class LlamaDecoderLayer(nn.Module):
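The += into the last cache slot above only works because that slot is zero when the attention layer runs: the generation-side update clears it after relocating its contents (the *= 0 in __update_model_kwargs_for_generation). A toy illustration of that invariant, independent of the model code:

import torch

cache = torch.zeros(2, 1, 4, 8, 16)        # (k/v, batch, heads, cache_len, head_dim)
new_kv = torch.randn(2, 1, 4, 1, 16)       # KV of the token currently being generated

# Direct slice assignment (`cache[..., -1:, :] = new_kv`) reportedly miscompiles under
# inductor here, so the code adds into the slot instead -- valid only while the slot is zero.
assert torch.all(cache[..., -1:, :] == 0)
cache[..., -1:, :] += new_kv

# After the forward pass, the caller relocates the slot and restores the invariant.
permanent_pos = 5                          # example permanent position for this token
cache[..., permanent_pos, :] = cache[..., -1, :]
cache[..., -1, :] *= 0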
@@ -265,7 +252,6 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -289,7 +275,6 @@ def forward(
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
valid_past_index=valid_past_index,
)
hidden_states = residual + hidden_states

@@ -485,7 +470,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -504,7 +488,11 @@
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

seq_length_with_past = seq_length
past_key_values_length = valid_past_index
past_key_values_length = 0

if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -523,11 +511,9 @@
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)

# As we use SDPA, we simply don't care about the attention mask in the batch size = 1 case
if batch_size > 1:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

hidden_states = inputs_embeds

@@ -565,7 +551,6 @@ def custom_forward(*inputs):
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
valid_past_index=valid_past_index,
)

hidden_states = layer_outputs[0]
@@ -622,25 +607,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model

def get_empty_kv_cache(self, batch_size: int, cache_length: int, dtype: torch.dtype, device: torch.device):
past_key_values = [torch.empty(
2,
batch_size,
self.config.num_attention_heads,
cache_length,
self.config.hidden_size // self.config.num_attention_heads, # head dimension
dtype=dtype,
device=device
)
for _ in range(self.config.num_hidden_layers)]
return past_key_values

def get_preallocated_attention_mask(self, attention_mask: torch.Tensor, batch_size: int, cache_length: int, device: torch.device, context_length: int):
attention_mask_buffer = torch.ones(batch_size, cache_length, dtype=torch.int64, device=device)
attention_mask_buffer[:, :context_length] = attention_mask

return attention_mask_buffer

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -654,7 +620,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
valid_past_index: torch.Tensor = torch.tensor(0, dtype=torch.int64),
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -698,7 +663,6 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
valid_past_index=valid_past_index,
)

hidden_states = outputs[0]
@@ -732,13 +696,12 @@ def forward(
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
valid_past_index = kwargs.get("valid_past_index", torch.tensor(0, dtype=torch.int64))

if valid_past_index > 0:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

# 1st iteration: create the position ids; subsequent iterations: reuse them directly (to avoid a cumsum op)
position_ids = kwargs.get("position_ids", None)
# create position_ids
if position_ids is None:
attention_mask_slice = attention_mask[:, :input_ids.shape[1]]
position_ids = attention_mask_slice.long().cumsum(-1) - 1
@@ -755,7 +718,6 @@
"position_ids": position_ids,
"past_key_values": past_key_values,
"attention_mask": attention_mask,
"valid_past_index": valid_past_index,
}
)
return model_inputs
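Finally, a small illustration of the position-id handling above: on the first call the ids are derived from a cumulative sum over the attention mask, and on later calls they are reused and incremented (see __update_model_kwargs_for_generation) so the cumsum is not recomputed every step. The tensors are toy values:

import torch

# First call (prefill): derive position ids from the attention mask.
attention_mask = torch.tensor([[1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1     # tensor([[0, 1, 2, 3]])

# Subsequent calls: only the single new token is fed in, so its position id is just
# the previous last position plus one.
next_position_ids = position_ids[:, -1:] + 1            # tensor([[4]])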