src/transformers/models/clip/modeling_clip.py (9 changes: 5 additions & 4 deletions)

@@ -319,9 +319,10 @@ def forward(
         values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         # CLIP text model uses both `causal_attention_mask` and `attention_mask`
         # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
-        if self.config._attn_implementation == "flash_attention_2":
-            self.is_causal = causal_attention_mask is not None
+        if "flash" in self.config._attn_implementation:
+            is_causal = causal_attention_mask is not None
+        else:
+            is_causal = self.is_causal
         if attention_mask is not None and causal_attention_mask is not None:
             attention_mask = attention_mask + causal_attention_mask
         elif causal_attention_mask is not None:

@@ -337,7 +338,7 @@ def forward(
             keys,
             values,
             attention_mask,
-            is_causal=self.is_causal,
+            is_causal=is_causal,

Comment on lines -340 to +341

Member:
We can do `is_causal=causal_attention_mask is not None or self.is_text_attention` for all attention cases. Right now it probably isn't differentiating correctly between vision and text attention, because with the new masking utility the presence of a causal mask is not a reliable marker.

Contributor:
Would be careful with inline conditionals, they have caused trouble with torch compile in the past 😓

             scaling=self.scale,
             dropout=0.0 if not self.training else self.dropout,
             output_attentions=output_attentions,
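
A non-authoritative sketch of the Member's suggestion above: derive `is_causal` from either the presence of the causal mask or an explicit text/vision flag, so the decision no longer depends on the masking utility materializing a causal mask. The `is_text_attention` attribute is assumed to be set when the attention module is constructed; it is not part of the code shown in this diff.

# Hedged sketch of the reviewer's proposal, not the merged implementation.
# Assumption: the attention module records at construction time whether it
# belongs to the text tower (is_text_attention=True) or the vision tower.
class CLIPAttentionSketch:
    def __init__(self, is_text_attention: bool):
        self.is_text_attention = is_text_attention

    def resolve_is_causal(self, causal_attention_mask):
        # Causal if an explicit causal mask was passed in, or if this is the
        # text attention (CLIP's text tower is always causal), regardless of
        # whether the masking utility materialized the mask.
        return causal_attention_mask is not None or self.is_text_attention

If the Contributor's torch.compile concern applies, the expression can be assigned to a local variable before the attention call rather than written inline in the call itself.
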

@@ -611,7 +612,7 @@ def forward(
         )

         # expand attention_mask
-        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
+        if attention_mask is not None and "flash" not in self.config._attn_implementation:
             # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
             attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
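
The widened check in the second hunk replaces an exact string comparison with a substring match, so any flash-attention backend skips the 4-D mask expansion rather than only the literal "flash_attention_2". A minimal illustration; the list of implementation names here is an assumption for the example, not an exhaustive set:

# Illustrative only: hypothetical implementation names.
for impl in ["eager", "sdpa", "flash_attention_2", "flash_attention_3"]:
    needs_4d_mask = "flash" not in impl
    print(f"{impl}: expand to 4-D mask -> {needs_4d_mask}")
# eager and sdpa still receive the expanded 4-D mask; both flash variants
# skip it, since flash kernels work from the unexpanded padding mask
# (or from is_causal) rather than an additive 4-D mask.
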
src/transformers/models/metaclip_2/modeling_metaclip_2.py (7 changes: 4 additions & 3 deletions)

@@ -219,9 +219,10 @@ def forward(
         values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
         # METACLIP_2 text model uses both `causal_attention_mask` and `attention_mask`
         # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
-        if self.config._attn_implementation == "flash_attention_2":
-            self.is_causal = causal_attention_mask is not None
+        if "flash" in self.config._attn_implementation:
+            is_causal = causal_attention_mask is not None
+        else:
+            is_causal = self.is_causal
         if attention_mask is not None and causal_attention_mask is not None:
             attention_mask = attention_mask + causal_attention_mask
         elif causal_attention_mask is not None:

@@ -237,7 +238,7 @@ def forward(
             keys,
             values,
             attention_mask,
-            is_causal=self.is_causal,
+            is_causal=is_causal,
             scaling=self.scale,
             dropout=0.0 if not self.training else self.dropout,
             output_attentions=output_attentions,
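
METACLIP_2 mirrors the CLIP attention change one-for-one. As a self-contained sketch of the new control flow, assuming only the names visible in the diffs above, the mask and causality resolution can be read as a standalone helper:

# Minimal sketch of the merged branch structure, factored out for readability.
# `default_is_causal` stands in for the module's static `self.is_causal` flag.
def resolve_masks(attn_implementation, attention_mask, causal_attention_mask, default_is_causal):
    # Any flash backend infers causality from the causal mask alone; other
    # backends keep the module's default flag (previously self.is_causal).
    if "flash" in attn_implementation:
        is_causal = causal_attention_mask is not None
    else:
        is_causal = default_is_causal

    # Non-flash backends fold the causal mask into the additive attention mask.
    # The elif body is elided in the diff view; this follows the usual CLIP pattern.
    if attention_mask is not None and causal_attention_mask is not None:
        attention_mask = attention_mask + causal_attention_mask
    elif causal_attention_mask is not None:
        attention_mask = causal_attention_mask

    return attention_mask, is_causal

In the actual modules this logic stays inline in forward; the helper only makes the branch structure explicit, with the local is_causal passed to the attention interface in place of the old in-place write to self.is_causal.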