9 changes: 8 additions & 1 deletion comfy/ldm/ace/attention.py
@@ -133,13 +133,15 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
transformer_options=transformer_options,
**cross_attention_kwargs,
)

@@ -366,6 +368,7 @@ def __call__(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@@ -433,7 +436,7 @@ def __call__(

# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)

# linear proj
@@ -697,6 +700,7 @@ def forward(
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
transformer_options={},
):

N = hidden_states.shape[0]
@@ -720,6 +724,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@@ -729,6 +734,7 @@ def forward(
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
transformer_options=transformer_options,
)

if self.use_adaln_single:
@@ -743,6 +749,7 @@ def forward(
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states

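The change in this file follows one pattern throughout: each forward/__call__ gains a transformer_options keyword (defaulting to {} so existing call sites keep working) and forwards it into optimized_attention. A minimal, self-contained sketch of that threading, using illustrative stand-in names rather than the real ComfyUI classes:

import torch
from typing import Optional

def optimized_attention_stub(q, k, v, heads, mask=None, skip_reshape=True, transformer_options={}):
    # Stand-in for comfy's optimized_attention; after this PR the real function
    # receives transformer_options and can consult it when running attention.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

class TinyAttention(torch.nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.qkv = torch.nn.Linear(dim, dim * 3)
        self.out_proj = torch.nn.Linear(dim, dim)

    def forward(self, x, attention_mask: Optional[torch.Tensor] = None, transformer_options={}):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (batch, heads, seq, head_dim), matching skip_reshape=True above
        q, k, v = (t.view(b, s, self.heads, d // self.heads).transpose(1, 2) for t in (q, k, v))
        out = optimized_attention_stub(q, k, v, self.heads, mask=attention_mask,
                                       skip_reshape=True,
                                       transformer_options=transformer_options)
        return self.out_proj(out.transpose(1, 2).reshape(b, s, d))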
4 changes: 4 additions & 0 deletions comfy/ldm/ace/model.py
@@ -314,6 +314,7 @@ def decode(
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@@ -339,6 +340,7 @@ def decode(
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
transformer_options=transformer_options,
)

output = self.final_layer(hidden_states, embedded_timestep, output_length)
@@ -393,6 +395,7 @@ def _forward(

output_length = hidden_states.shape[-1]

transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@@ -402,6 +405,7 @@ def _forward(
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
transformer_options=transformer_options,
)

return output
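At the model entry point the dict is pulled out of **kwargs with an empty-dict fallback before being handed to decode, so callers that never pass transformer_options are unaffected. A hedged sketch of that hand-off (stub names, not the actual ACE model):

def decode_stub(hidden_states, transformer_options={}, **_):
    # the real decode() threads the dict into every transformer block
    return hidden_states

def _forward_stub(hidden_states, **kwargs):
    # default to {} so existing callers keep working unchanged
    transformer_options = kwargs.get("transformer_options", {})
    return decode_stub(hidden_states, transformer_options=transformer_options)

_forward_stub([0.0], transformer_options={"patches_replace": {}})  # with options
_forward_stub([0.0])                                               # falls back to {}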
25 changes: 14 additions & 11 deletions comfy/ldm/audio/dit.py
@@ -298,7 +298,8 @@ def forward(
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None
causal = None,
transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

@@ -363,7 +364,7 @@ def forward(
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

out = optimized_attention(q, k, v, h, skip_reshape=True)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)

if mask is not None:
@@ -488,7 +489,8 @@ def forward(
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None
rotary_pos_emb = None,
transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:

@@ -498,12 +500,12 @@ def forward(
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual

if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

if self.conformer is not None:
x = x + self.conformer(x)
@@ -517,10 +519,10 @@ def forward(
x = x + residual

else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)

if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)

if self.conformer is not None:
x = x + self.conformer(x)
@@ -606,7 +608,8 @@ def forward(
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
transformer_options = kwargs.get("transformer_options", {})
patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]

@@ -645,13 +648,13 @@ def forward(
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out

out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

if return_info:
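Because transformer_options is now included in the args dict handed to "double_block" replacement patches, a patch can read sampler-provided options without extra plumbing. A hedged sketch of such a patch (the patch itself and the "log_attention" key are hypothetical; only the call shape comes from the diff above):

def my_double_block_patch(args, extra):
    opts = args.get("transformer_options", {})
    if opts.get("log_attention", False):  # hypothetical key, for illustration only
        print("double_block input shape:", args["img"].shape)
    # delegate to the unmodified block
    return extra["original_block"](args)

# registered under the same key the loop above looks up:
# blocks_replace[("double_block", 0)] = my_double_block_patch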
29 changes: 15 additions & 14 deletions comfy/ldm/aura/mmdit.py
@@ -85,7 +85,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera
)

#@torch.compile()
def forward(self, c):
def forward(self, c, transformer_options={}):

bsz, seqlen1, _ = c.shape

@@ -95,7 +95,7 @@ def forward(self, c):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)

output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c

@@ -144,7 +144,7 @@ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, opera


#@torch.compile()
def forward(self, c, x):
def forward(self, c, x, transformer_options={}):

bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@@ -168,7 +168,7 @@ def forward(self, c, x):
torch.cat([cv, xv], dim=1),
)

output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)

c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@@ -207,7 +207,7 @@ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None,
self.is_last = is_last

#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):

cres, xres = c, x

@@ -225,7 +225,7 @@ def forward(self, c, x, global_cond, **kwargs):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)

# attention
c, x = self.attn(c, x)
c, x = self.attn(c, x, transformer_options=transformer_options)


c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@@ -255,13 +255,13 @@ def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, o
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)

#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@@ -473,13 +473,14 @@ def block_wrap(args):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
args["vec"],
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)

if len(self.single_layers) > 0:
c_len = c.size(1)
@@ -488,13 +489,13 @@ def block_wrap(args):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out

out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)

x = cx[:, c_len:]

12 changes: 6 additions & 6 deletions comfy/ldm/cascade/common.py
@@ -32,12 +32,12 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No

self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

def forward(self, q, k, v):
def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)

out = optimized_attention(q, k, v, self.heads)
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)

return self.out_proj(out)

@@ -47,13 +47,13 @@ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=No
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

def forward(self, x, kv, self_attn=False):
def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x

@@ -114,9 +114,9 @@ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, de
operations.Linear(c_cond, c, dtype=dtype, device=device)
)

def forward(self, x, kv):
def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x


14 changes: 7 additions & 7 deletions comfy/ldm/cascade/stage_b.py
@@ -173,7 +173,7 @@ def gen_c_embeddings(self, clip):
clip = self.clip_norm(clip)
return clip

def _down_encode(self, x, r_embed, clip):
def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@@ -187,7 +187,7 @@ def _down_encode(self, x, r_embed, clip):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
Expand All @@ -199,7 +199,7 @@ def _down_encode(self, x, r_embed, clip):
level_outputs.insert(0, x)
return level_outputs

def _up_decode(self, level_outputs, r_embed, clip):
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@@ -216,7 +216,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
Expand All @@ -228,7 +228,7 @@ def _up_decode(self, level_outputs, r_embed, clip):
x = upscaler(x)
return x

def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)

@@ -245,8 +245,8 @@ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)

def update_weights_ema(self, src_model, beta=0.999):