4 changes: 2 additions & 2 deletions src/transformers/generation/utils.py
@@ -1154,7 +1154,7 @@ def _validate_assistant(self, assistant_model):
"Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
)

if not self.config.vocab_size == assistant_model.config.vocab_size:
if not self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
raise ValueError("Make sure the main and assistant model use the same tokenizer")

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
@@ -1473,7 +1473,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):
layer_device_map = get_layer_device_map(execution_device_map)

cache_kwargs = {
"config": self.config if hasattr(self.config, "text_config") else self.config,
"config": self.config.get_text_config(),
"max_batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
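
For context on the two utils.py changes above, here is a minimal, hedged sketch of how `config.get_text_config()` resolves to the nested text config for composite vision-language configs while returning the config itself for text-only models. The checkpoint ids are only illustrative and may require accepting the model licenses.

from transformers import AutoConfig

main_cfg = AutoConfig.from_pretrained("google/paligemma-3b-pt-224")  # composite vision-language config
assistant_cfg = AutoConfig.from_pretrained("google/gemma-2b")        # plain text-only config

# Both calls land on a text config, so the assistant vocab check and the
# static-cache setup work the same way for multimodal and text-only models.
print(main_cfg.get_text_config().vocab_size)
print(assistant_cfg.get_text_config().vocab_size)
print(main_cfg.get_text_config().vocab_size == assistant_cfg.get_text_config().vocab_size)
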
126 changes: 111 additions & 15 deletions src/transformers/models/paligemma/modeling_paligemma.py
@@ -45,6 +45,74 @@
_CONFIG_FOR_DOC = "PaliGemmaConfig"


# Adapted from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
# But PaliGemma applies no causal mask on the prefix
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
is_training: bool,
token_type_ids: torch.Tensor,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`; if the input `attention_mask` is already 4D, it is returned unchanged.

Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache to account for the 0 padding, i.e. the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
is_training (`bool`):
Whether the model is in training or inference mode; inferred from the presence of `token_type_ids` and `labels`.
token_type_ids (`torch.Tensor`):
Token type ids separating the prefix (image + text prompt, value 0) from the suffix (value 1); during training the prefix receives full (non-causal) attention while the suffix stays causal.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask = torch.zeros_like(causal_mask)

causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# We are training, so create a full mask on the image + prefix but keep the suffix causal
if is_training:
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)
return causal_mask
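
As a quick sanity check (not part of the diff), the following sketch builds a mask for a batch of one with a 4-token prefix and a 2-token suffix against a static cache of length 8, assuming the function above is in scope; it illustrates the prefix-full / suffix-causal behaviour the training branch implements.

import torch

seq_len, cache_len = 6, 8
dtype = torch.float32

mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, seq_len, dtype=torch.long),   # no padding
    sequence_length=seq_len,
    target_length=cache_len,
    dtype=dtype,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(dtype).min,
    cache_position=torch.arange(seq_len),                      # prefill step
    batch_size=1,
    is_training=True,
    token_type_ids=torch.tensor([[0, 0, 0, 0, 1, 1]]),         # prefix = 0, suffix = 1
)

# mask has shape (1, 1, 6, 8): every query row attends to the whole prefix,
# suffix rows are causal among themselves, and the unfilled cache slots stay masked.
print((mask[0, 0] == 0).int())
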


@dataclass
class PaliGemmaCausalLMOutputWithPast(ModelOutput):
"""
@@ -285,7 +353,7 @@ def _update_causal_mask(
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
):
using_static_cache = isinstance(past_key_values, StaticCache)
dtype, device = inputs_embeds.dtype, inputs_embeds.device
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
sequence_length = inputs_embeds.shape[1]
if using_static_cache:
@@ -299,19 +367,19 @@

if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask = torch.zeros_like(causal_mask)

causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
return attention_mask

causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask = torch.zeros_like(causal_mask)

causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
@@ -420,7 +488,8 @@ def forward(
image_features = self.multi_modal_projector(selected_image_feature)
image_features = image_features / (self.config.hidden_size**0.5)

special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
raise ValueError(
@@ -508,11 +577,38 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
use_cache=use_cache,
num_logits_to_keep=num_logits_to_keep,
**kwargs,
)

if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device

dtype = self.get_output_embeddings().weight.dtype
min_dtype = torch.finfo(dtype).min
is_training = token_type_ids is not None and kwargs.get("labels", None) is not None

model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
is_training=is_training,
token_type_ids=token_type_ids,
)

model_inputs["token_type_ids"] = token_type_ids

# position_ids in Paligemma are 1-indexed
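
A hedged end-to-end sketch of the path the prepare_inputs_for_generation change serves: generating with cache_implementation="static" makes the model build the fixed-shape 4D mask up front, which is what a compiled forward pass needs. Checkpoint id, image URL, and prompt are illustrative only.

import requests
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-mix-224"  # illustrative, gated checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="caption en", images=image, return_tensors="pt")

# With a StaticCache and a 2D attention mask, prepare_inputs_for_generation now expands the
# mask to (batch, 1, seq_len, max_cache_len) up front using the helper added above.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.decode(out[0], skip_special_tokens=True))
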
4 changes: 3 additions & 1 deletion src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1070,7 +1070,9 @@ def __init__(self, config) -> None:
self.blocks = nn.ModuleList(
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
)
self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
self.merger = PatchMerger(
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
)

def get_dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
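
The qwen2_vl fix simply threads config.spatial_merge_size through to PatchMerger instead of relying on its default. A toy sketch of the shape contract (this mirrors the idea, it is not the library class): the merger concatenates spatial_merge_size**2 neighbouring patch embeddings before projecting, so its input width depends on that value; the dimensions below are illustrative.

import torch
import torch.nn as nn

class ToyPatchMerger(nn.Module):
    """Toy stand-in: groups spatial_merge_size**2 patch embeddings, then projects to `dim`."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.proj = nn.Linear(self.hidden_size, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (num_patches, context_dim) -> (num_patches / merge**2, context_dim * merge**2) -> (..., dim)
        return self.proj(x.view(-1, self.hidden_size))

merger = ToyPatchMerger(dim=3584, context_dim=1280, spatial_merge_size=2)
patches = torch.randn(64, 1280)    # 64 patch embeddings from the vision tower
print(merger(patches).shape)       # torch.Size([16, 3584])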