Merged
3 changes: 2 additions & 1 deletion .gitignore
@@ -8,4 +8,5 @@ dist/
*.egg-info/
.DS_Store/
.pytest_cache/
.ruff_cache/
.ruff_cache/
CLAUDE.md
@@ -21,5 +21,6 @@
"vision_start_token_id": 151652,
"vision_end_token_id": 151653,
"image_token_id": 151655,
"video_token_id": 151656
"video_token_id": 151656,
"attn_impl": "sdpa"
}
@@ -0,0 +1,29 @@
{
"do_convert_rgb": true,
"do_normalize": true,
"do_rescale": true,
"do_resize": true,
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_processor_type": "Qwen2VLImageProcessor",
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"max_pixels": 12845056,
"merge_size": 2,
"min_pixels": 3136,
"patch_size": 14,
"processor_class": "Qwen2_5_VLProcessor",
"resample": 3,
"rescale_factor": 0.00392156862745098,
"size": {
"longest_edge": 12845056,
"shortest_edge": 3136
},
"temporal_patch_size": 2
}
6 changes: 3 additions & 3 deletions diffsynth_engine/models/basic/attention.py
@@ -1,9 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from typing import Optional

import torch.nn.functional as F
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.flag import (
FLASH_ATTN_3_AVAILABLE,
@@ -42,11 +42,11 @@ def xformers_attn(q, k, v, attn_mask=None, scale=None):

if SDPA_AVAILABLE:

def sdpa_attn(q, k, v, attn_mask=None, scale=None):
def sdpa_attn(q, k, v, attn_mask=None, is_causal=False, scale=None):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal, scale=scale)
return out.transpose(1, 2)


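A minimal usage sketch of the patched sdpa_attn follows; this is an illustrative standalone copy, not part of the diff, and it assumes the (batch, seq_len, num_heads, head_dim) layout implied by the transposes above.

import torch
import torch.nn.functional as F

def sdpa_attn_sketch(q, k, v, attn_mask=None, is_causal=False, scale=None):
    # F.scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal, scale=scale)
    return out.transpose(1, 2)  # back to (batch, seq_len, num_heads, head_dim)

q = k = v = torch.randn(1, 16, 4, 64)
out = sdpa_attn_sketch(q, k, v, is_causal=True)  # causal mask built internally, no attn_mask needed
print(out.shape)  # torch.Size([1, 16, 4, 64])

Note that PyTorch raises an error when both attn_mask and is_causal=True are supplied, so callers should pass one or the other.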
98 changes: 41 additions & 57 deletions diffsynth_engine/models/qwen_image/qwen2_5_vl.py
@@ -7,7 +7,7 @@

from diffsynth_engine.models.base import PreTrainedModel
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
from diffsynth_engine.models.basic.attention import attention
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.models.utils import no_init_weights
from diffsynth_engine.utils.cache import Cache, DynamicCache
from diffsynth_engine.utils import logging
@@ -152,17 +152,15 @@ def __init__(
self,
dim: int = 80,
theta: float = 10000.0,
device: str = "cuda:0",
dtype: torch.dtype = torch.bfloat16,
):
super().__init__()
with torch.device(device):
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
with torch.device("cpu"):
self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
def forward(self, seqlen: int, device: str) -> torch.Tensor:
inv_freq = self.inv_freq.to(device=device)
seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
freqs = torch.outer(seq, inv_freq)
return freqs


@@ -222,7 +220,7 @@ def forward(
q = rearrange(q, "s n d -> 1 s n d")
k = rearrange(k, "s n d -> 1 s n d")
v = rearrange(v, "s n d -> 1 s n d")
out = attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
out = rearrange(out, "1 s n d -> s (n d)")
out = self.proj(out)
return out
@@ -301,7 +299,7 @@ def __init__(self, config: Qwen2_5_VLVisionConfig, device: str = "cuda:0", dtype
dtype=dtype,
)
head_dim = config.hidden_size // config.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, device=device, dtype=dtype)
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList(
[
Qwen2_5_VisionBlock(
@@ -348,7 +346,7 @@ def rot_pos_emb(self, grid_thw):
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb

@@ -488,7 +486,6 @@ def __init__(
hidden_size: int = 3584,
num_attention_heads: int = 28,
num_key_value_heads: int = 4,
# dropout: float = 0.0,
mrope_section: List[int] = [16, 24, 24],
attn_impl: Optional[str] = None,
device: str = "cuda:0",
@@ -501,7 +498,6 @@ def __init__(
self.head_dim = hidden_size // num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = num_attention_heads // num_key_value_heads
# self.dropout = dropout
self.mrope_section = mrope_section
self.attn_impl = attn_impl

@@ -521,8 +517,6 @@ def __init__(
self.num_attention_heads * self.head_dim, self.hidden_size, bias=False, device=device, dtype=dtype
)

self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=self.head_dim, device=device, dtype=dtype)

def forward(
self,
hidden_states: torch.Tensor,
@@ -556,14 +550,18 @@ def forward(
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[1]]

# TODO: attention_mask for flash attention 2
out = attention(
query_states,
key_states,
value_states,
attn_impl=self.attn_impl,
attn_mask=causal_mask,
)
# TODO: use is_causal when attention mask is causal
if self.attn_impl == "sdpa":
out = attention_ops.sdpa_attn(query_states, key_states, value_states, is_causal=True)
else:
# TODO: attention_mask for flash attention 2
out = attention_ops.attention(
query_states,
key_states,
value_states,
attn_impl=self.attn_impl,
attn_mask=causal_mask,
)
out = rearrange(out, "b s n d -> b s (n d)")
out = self.o_proj(out)
return out, past_key_values
@@ -647,29 +645,29 @@ def forward(


class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, dim: int = 128, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
def __init__(self, dim: int = 128):
super().__init__()
with torch.device(device):
inv_freq = self.compute_rope(dim) # default rope without dynamic frequency
self.register_buffer("inv_freq", inv_freq, persistent=False)
with torch.device("cpu"):
self.inv_freq = self.compute_rope(dim) # default rope without dynamic frequency

def compute_rope(self, dim: int, theta: float = 1000000.0):
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
return inv_freq

@torch.no_grad()
def forward(self, x, position_ids):
def forward(self, position_ids: torch.LongTensor, device: str, dtype: torch.dtype):
# In contrast to other models, Qwen2_5_VL has different position ids for the grids
# So we expand the inv_freq to shape (3, ...)
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
inv_freq = self.inv_freq.to(device=device)
inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)

freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()

return cos.to(device=x.device, dtype=x.dtype), sin.to(device=x.device, dtype=x.dtype)
return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)


class Qwen2_5_VLModel(nn.Module):
@@ -702,7 +700,7 @@ def __init__(self, config: Qwen2_5_VLConfig, device: str = "cuda:0", dtype: torc
)
self.norm = Qwen2_5_RMSNorm(config.hidden_size, config.rms_norm_eps, device=device, dtype=dtype)
head_dim = config.hidden_size // config.num_attention_heads
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim, device=device, dtype=dtype)
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim)

def get_input_embeddings(self):
return self.embed_tokens
@@ -749,7 +747,7 @@ def forward(
hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
position_embeddings = self.rotary_emb(position_ids, device=hidden_states.device, dtype=hidden_states.dtype)

# decoder layers
for decoder_layer in self.layers:
@@ -940,8 +938,7 @@ def from_state_dict(
with torch.device("meta"), no_init_weights():
model = cls(vision_config=vision_config, config=config, device=device, dtype=dtype)
model.load_state_dict(state_dict, assign=True)
for param in model.parameters(): # skip buffers
param.data = param.data.to(device=device, dtype=dtype, non_blocking=True)
model.to(device=device, dtype=dtype, non_blocking=True)
return model

def get_input_embeddings(self):
@@ -1202,27 +1199,14 @@ def forward(
if position_ids is None:
assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
attention_mask,
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
attention_mask,
)
self.rope_deltas = rope_deltas
Comment on lines 1199 to +1209
Contributor

high

The logic for handling subsequent generation steps (incremental decoding) has been removed. The previous implementation had an else block to calculate position_ids when cache_position was not 0. This change limits the model to only perform pre-fill (single forward pass), which might break any autoregressive text generation capabilities. If this model is intended to be used for multi-step generation, this is a significant regression. Was this intentional?
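
For reference, a sketch of the removed decode-step logic, reconstructed from the deleted lines above as a hypothetical free-standing helper (the name incremental_position_ids and its signature are illustrative only):

import torch

def incremental_position_ids(inputs_embeds, cache_position, rope_deltas):
    # Mirrors the deleted else branch: reuse the cached rope_deltas instead of
    # recomputing get_rope_index() on every decode step.
    batch_size, seq_length, _ = inputs_embeds.shape
    delta = (cache_position[0] + rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
    position_ids = torch.arange(seq_length, device=inputs_embeds.device)
    position_ids = position_ids.view(1, -1).expand(batch_size, -1)
    if cache_position is not None:  # otherwise `delta` is an int `0`
        delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
    position_ids = position_ids.add(delta)
    return position_ids.unsqueeze(0).expand(3, -1, -1)  # one set of ids per mRoPE axis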


hidden_states, present_key_values = self.model(
input_ids=None,