2 changes: 1 addition & 1 deletion docs/source/ar/llm_tutorial_optimization.md
@@ -472,7 +472,7 @@ for _ in range(5):
next_token_id = torch.argmax(next_logits, dim=-1)

print("shape of input_ids", next_token_id.shape)
print("length of key-value cache", len(past_key_values[0][0])) # past_key_values are of shape [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
print("length of key-value cache", past_key_values.get_seq_length()) # past_key_values are of shape [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
generated_tokens.append(next_token_id.item())

generated_text = tokenizer.batch_decode(generated_tokens)
2 changes: 1 addition & 1 deletion docs/source/en/llm_tutorial_optimization.md
@@ -484,7 +484,7 @@ for _ in range(5):
next_token_id = torch.argmax(next_logits, dim=-1)

print("shape of input_ids", next_token_id.shape)
print("length of key-value cache", len(past_key_values[0][0])) # past_key_values are of shape [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
print("length of key-value cache", past_key_values.get_seq_length()) # past_key_values are of shape [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
generated_tokens.append(next_token_id.item())

generated_text = tokenizer.batch_decode(generated_tokens)
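For context, a minimal sketch (not part of this diff, shapes and sizes illustrative) of the `Cache` API that the updated tutorial line relies on: the number of cached tokens is read with `get_seq_length()` instead of being inferred from the legacy tuple-of-tuples layout.

```python
import torch
from transformers import DynamicCache

# Simulate one layer of a KV cache receiving 5 tokens' worth of key/value states.
# Shapes are [batch_size, num_heads, seq_len, head_dim]; the sizes are arbitrary.
cache = DynamicCache()
keys = torch.randn(1, 8, 5, 64)
values = torch.randn(1, 8, 5, 64)
cache.update(keys, values, layer_idx=0)

# The call used in the updated doc line: reports how many tokens are cached.
print(cache.get_seq_length())  # 5
```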
2 changes: 1 addition & 1 deletion docs/source/ko/llm_tutorial_optimization.md
@@ -457,7 +457,7 @@ for _ in range(5):
next_token_id = torch.argmax(next_logits, dim=-1)

print("shape of input_ids", next_token_id.shape)
print("length of key-value cache", len(past_key_values[0][0])) # past_key_values 형태: [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
print("length of key-value cache", past_key_values.get_seq_length()) # past_key_values 형태: [num_layers, 0 for k, 1 for v, batch_size, length, hidden_dim]
generated_tokens.append(next_token_id.item())

generated_text = tokenizer.batch_decode(generated_tokens)
8 changes: 1 addition & 7 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -1689,13 +1689,7 @@ def forward(
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = (
past_key_values[0][0].shape[-2]
if not isinstance(past_key_values, Cache)
else past_key_values.get_seq_length()
)
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
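The same one-line replacement recurs in the other modeling files below (blip, rembert, roformer, speecht5). As a standalone sketch of the pattern, assuming `past_key_values` is now always either `None` or a `Cache` instance, so the removed tuple-format branch is no longer needed; the helper name is illustrative:

```python
from typing import Optional

from transformers.cache_utils import Cache


def get_past_key_values_length(past_key_values: Optional[Cache]) -> int:
    """Number of tokens already stored in the cache, or 0 when no cache is passed."""
    return 0 if past_key_values is None else past_key_values.get_seq_length()
```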
8 changes: 1 addition & 7 deletions src/transformers/models/blip/modeling_blip_text.py
@@ -674,13 +674,7 @@ def forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = (
past_key_values[0][0].shape[-2]
if not isinstance(past_key_values, Cache)
else past_key_values.get_seq_length()
)
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()

if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length)).to(device)
34 changes: 0 additions & 34 deletions src/transformers/models/ctrl/modeling_ctrl.py
@@ -250,18 +250,6 @@ def forward(
**kwargs, # NOOP kwargs, for now
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPast]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
`input_ids`.

Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.

[What are input IDs?](../glossary#input-ids)

Example:

```python
@@ -424,17 +412,6 @@ def forward(
**kwargs,
) -> Union[tuple[torch.Tensor], CausalLMOutputWithPast]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
`input_ids`.

Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.

[What are input IDs?](../glossary#input-ids)
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
@@ -572,17 +549,6 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0].shape[-2]`
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

If `past_key_values` is used, only input IDs that do not have their past calculated should be passed as
`input_ids`.

Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
[`PreTrainedTokenizer.encode`] for details.

[What are input IDs?](../glossary#input-ids)
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
8 changes: 1 addition & 7 deletions
@@ -644,13 +644,7 @@ def forward(
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = (
past_key_values[0][0].shape[-2]
if not isinstance(past_key_values, Cache)
else past_key_values.get_seq_length()
)
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1 change: 0 additions & 1 deletion src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -160,7 +160,6 @@ def forward(
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_values[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]

def to_projection_shape(states):
8 changes: 1 addition & 7 deletions src/transformers/models/rembert/modeling_rembert.py
@@ -579,13 +579,7 @@ def forward(
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = (
past_key_values[0][0].shape[-2]
if not isinstance(past_key_values, Cache)
else past_key_values.get_seq_length()
)
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
8 changes: 1 addition & 7 deletions src/transformers/models/roformer/modeling_roformer.py
@@ -736,13 +736,7 @@ def forward(
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = (
past_key_values[0][0].shape[-2]
if not isinstance(past_key_values, Cache)
else past_key_values.get_seq_length()
)
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()

if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
9 changes: 1 addition & 8 deletions src/transformers/models/speecht5/modeling_speecht5.py
@@ -805,14 +805,7 @@ def forward(
else:
raise ValueError("You have to specify `decoder_input_ids`")

past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = (
past_key_values[0][0].shape[-2]
if not isinstance(past_key_values, Cache)
else past_key_values.get_seq_length()
)

past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
positions = self.embed_positions(input_ids, past_key_values_length)

inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
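In the SpeechT5 decoder the resulting `past_key_values_length` is passed straight to `embed_positions`, so position indices for the new tokens continue from the end of the cache instead of restarting at zero. A small self-contained sketch of that offset (illustrative, not SpeechT5's actual embedding module):

```python
import torch


def position_ids_for_new_tokens(input_ids: torch.Tensor, past_key_values_length: int) -> torch.Tensor:
    # With 7 cached tokens and 1 new input token, the new token gets position index 7.
    seq_length = input_ids.shape[1]
    return torch.arange(
        past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=input_ids.device
    ).unsqueeze(0)


print(position_ids_for_new_tokens(torch.tensor([[42]]), past_key_values_length=7))  # tensor([[7]])
```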
21 changes: 13 additions & 8 deletions tests/generation/test_utils.py
@@ -4665,7 +4665,7 @@ def test_generate_custom_cache_position(self):
value=1,
)
inputs_2b["past_key_values"] = outputs_1b.past_key_values
cache_length_1b = outputs_1b.past_key_values[0][0].shape[-2]
cache_length_1b = outputs_1b.past_key_values.get_seq_length()
inputs_2b["cache_position"] = torch.arange(
cache_length_1b,
cache_length_1b + inputs_2b["input_ids"].shape[1],
@@ -4677,14 +4677,19 @@

# The two sets of generated text and past kv should be equal to each other
self.assertTrue(has_similar_generate_outputs(traditional_outputs, incremental_outputs))
for layer_idx in range(len(traditional_outputs.past_key_values)):
    for kv_idx in range(len(traditional_outputs.past_key_values[layer_idx])):
        self.assertTrue(
            torch.allclose(
                traditional_outputs.past_key_values[layer_idx][kv_idx],
                incremental_outputs.past_key_values[layer_idx][kv_idx],
            )
        )
cache1, cache2 = traditional_outputs.past_key_values, incremental_outputs.past_key_values
for idx in range(len(cache1)):
    if isinstance(cache1, EncoderDecoderCache):
        for subcache in ["self_attention_cache", "cross_attention_cache"]:
            torch.testing.assert_close(
                getattr(cache1, subcache).layers[idx].keys, getattr(cache2, subcache).layers[idx].keys
            )
            torch.testing.assert_close(
                getattr(cache1, subcache).layers[idx].values, getattr(cache2, subcache).layers[idx].values
            )
    else:
        torch.testing.assert_close(cache1.layers[idx].keys, cache2.layers[idx].keys)
        torch.testing.assert_close(cache1.layers[idx].values, cache2.layers[idx].values)
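For reference, a small sketch (assumptions outside this diff) of why the new assertion branches on `EncoderDecoderCache`: encoder-decoder models carry two sub-caches, each with its own per-layer `keys`/`values`, whereas decoder-only models expose `.layers` directly on the cache.

```python
from transformers import DynamicCache, EncoderDecoderCache

# Argument order assumed here: the self-attention cache first, then the
# cross-attention cache that is filled once from the encoder outputs.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
print(cache.self_attention_cache.get_seq_length())  # 0, nothing cached yet
```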

@pytest.mark.generate
@parameterized.expand(