Fix SAG.
comfyanonymous committed May 14, 2024
1 parent bb4940d commit ec6f16a
Showing 2 changed files with 9 additions and 7 deletions.
6 changes: 4 additions & 2 deletions comfy/ldm/modules/attention.py
@@ -420,14 +420,15 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
             inner_dim = dim

         self.is_res = inner_dim == dim
+        self.attn_precision = attn_precision

         if self.ff_in:
             self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
             self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-                              context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
+                              context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
         self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

         if disable_temporal_crossattention:
@@ -441,7 +442,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
                 context_dim_attn2 = context_dim

             self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
-                              heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
+                              heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
             self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

         self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@@ -471,6 +472,7 @@ def _forward(self, x, context=None, transformer_options={}):

         extra_options["n_heads"] = self.n_heads
         extra_options["dim_head"] = self.d_head
+        extra_options["attn_precision"] = self.attn_precision

         if self.ff_in:
             x_skip = x
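With these changes, every BasicTransformerBlock stores its attn_precision and publishes it through extra_options, so attention patches no longer need a module-level precision flag. A minimal sketch of a custom attn1 replacement patch that reads the new key (the patch signature and the optimized_attention call mirror nodes_sag.py; the commented-out registration is illustrative and assumes ComfyUI's ModelPatcher.set_model_attn1_replace API):

from comfy.ldm.modules.attention import optimized_attention

def my_attn1_patch(q, k, v, extra_options):
    # extra_options now carries the per-layer precision set in BasicTransformerBlock._forward
    attn_precision = extra_options.get("attn_precision")  # e.g. torch.float32, or None for the model default
    heads = extra_options["n_heads"]
    # pass the precision through so fp16 models still upcast the attention matmul when requested
    return optimized_attention(q, k, v, heads=heads, attn_precision=attn_precision)

# illustrative registration on a cloned model, mirroring how the SAG node attaches attn_and_record:
# m = model.clone()
# m.set_model_attn1_replace(my_attn1_patch, "middle", 0)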
10 changes: 5 additions & 5 deletions comfy_extras/nodes_sag.py
@@ -5,12 +5,12 @@

 from einops import rearrange, repeat
 import os
-from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
+from comfy.ldm.modules.attention import optimized_attention
 import comfy.samplers

 # from comfy/ldm/modules/attention.py
 # but modified to return attention scores as well as output
-def attention_basic_with_sim(q, k, v, heads, mask=None):
+def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
@@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
     )

     # force cast to fp32 to avoid overflowing
-    if _ATTN_PRECISION =="fp32":
+    if attn_precision == torch.float32:
         sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
@@ -121,13 +121,13 @@ def attn_and_record(q, k, v, extra_options):
            if 1 in cond_or_uncond:
                uncond_index = cond_or_uncond.index(1)
                # do the entire attention operation, but save the attention scores to attn_scores
-               (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
+               (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
                # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
                n_slices = heads * b
                attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
                return out
            else:
-               return optimized_attention(q, k, v, heads=heads)
+               return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])

        def post_cfg_function(args):
            nonlocal attn_scores
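On the SAG side, attention_basic_with_sim previously keyed its fp32 upcast off the module-level _ATTN_PRECISION string imported from attention.py; it now takes the precision as an argument and compares it against torch.float32, with the value supplied at the call sites via extra_options["attn_precision"]. A standalone sketch of just that upcast guard, under the assumption that the precision value is either torch.float32 or None:

import torch
from torch import einsum

def scaled_similarity(q, k, scale, attn_precision=None):
    # same guard as the patched attention_basic_with_sim:
    # upcast q and k to float32 before the similarity matmul so fp16 runs do not overflow
    if attn_precision == torch.float32:
        return einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    return einsum('b i d, b j d -> b i j', q, k) * scale

# toy check with fp16 inputs and the upcast requested
q = torch.randn(2, 4, 8, dtype=torch.float16)
k = torch.randn(2, 4, 8, dtype=torch.float16)
print(scaled_similarity(q, k, 8 ** -0.5, attn_precision=torch.float32).dtype)  # torch.float32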
