
Commit

Refactor attention upcasting code part 1.
comfyanonymous committed May 14, 2024
1 parent 2de3b69 commit b0ab31d
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions comfy/ldm/modules/attention.py
@@ -22,9 +22,9 @@
 # CrossAttn precision handling
 if args.dont_upcast_attention:
     logging.info("disabling upcasting of attention")
-    _ATTN_PRECISION = "fp16"
+    _ATTN_PRECISION = None
 else:
-    _ATTN_PRECISION = "fp32"
+    _ATTN_PRECISION = torch.float32
 
 
 def exists(val):
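
For reference, a minimal standalone sketch of the convention this hunk moves to: _ATTN_PRECISION becomes either a torch dtype or None instead of the strings "fp32"/"fp16", so downstream code can compare against torch.float32 directly. The helper name and the boolean flag below are hypothetical stand-ins for the repository's args handling.

import torch

def pick_attn_precision(dont_upcast_attention: bool):
    # None means "keep the model's dtype"; torch.float32 means
    # "upcast the attention matmul", mirroring the hunk above.
    return None if dont_upcast_attention else torch.float32

attn_precision = pick_attn_precision(dont_upcast_attention=False)
print(attn_precision == torch.float32)  # True: the upcast path is taken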
@@ -85,7 +85,7 @@ def forward(self, x):
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
-def attention_basic(q, k, v, heads, mask=None):
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
@@ -101,7 +101,7 @@ def attention_basic(q, k, v, heads, mask=None):
     )
 
     # force cast to fp32 to avoid overflowing
-    if _ATTN_PRECISION =="fp32":
+    if attn_precision == torch.float32:
         sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
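
The comment "force cast to fp32 to avoid overflowing" refers to fp16's small dynamic range: its largest finite value is 65504, so an unusually large pre-softmax logit saturates to inf in half precision. A tiny illustration, independent of the code above:

import torch

logit = torch.tensor([70000.0])
print(logit.to(torch.float16))  # tensor([inf], dtype=torch.float16)
print(logit.to(torch.float32))  # tensor([70000.])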
@@ -135,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None):
     return out
 
 
-def attention_sub_quad(query, key, value, heads, mask=None):
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
     b, _, dim_head = query.shape
     dim_head //= heads
 
@@ -146,7 +146,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
 
     dtype = query.dtype
-    upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
     if upcast_attention:
         bytes_per_token = torch.finfo(torch.float32).bits//8
     else:
@@ -195,7 +195,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
     return hidden_states
 
-def attention_split(q, k, v, heads, mask=None):
+def attention_split(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     scale = dim_head ** -0.5
@@ -214,10 +214,12 @@ def attention_split(q, k, v, heads, mask=None):
 
     mem_free_total = model_management.get_free_memory(q.device)
 
-    if _ATTN_PRECISION =="fp32":
+    if attn_precision == torch.float32:
         element_size = 4
+        upcast = True
     else:
         element_size = q.element_size()
+        upcast = False
 
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
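
To make the element_size bookkeeping concrete: with hypothetical shapes of b*heads = 16 and 4096 query and key tokens, the similarity tensor holds 16 * 4096 * 4096 ≈ 2.7e8 elements, i.e. about 1.0 GiB at 4 bytes per element on the upcast path versus 0.5 GiB at 2 bytes for fp16 inputs. A quick check:

# Hypothetical shapes, purely to illustrate the estimate above.
b_times_heads, seq_q, seq_k = 16, 4096, 4096
for element_size in (4, 2):  # fp32 upcast vs. fp16 inputs
    tensor_size = b_times_heads * seq_q * seq_k * element_size
    print(element_size, tensor_size / 1024**3)  # 4 -> 1.0 GiB, 2 -> 0.5 GiB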
@@ -251,7 +253,7 @@ def attention_split(q, k, v, heads, mask=None):
             slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
             for i in range(0, q.shape[1], slice_size):
                 end = i + slice_size
-                if _ATTN_PRECISION =="fp32":
+                if upcast:
                     with torch.autocast(enabled=False, device_type = 'cuda'):
                         s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                 else:
@@ -302,7 +304,7 @@ def attention_split(q, k, v, heads, mask=None):
 except:
     pass
 
-def attention_xformers(q, k, v, heads, mask=None):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     if BROKEN_XFORMERS:
@@ -334,7 +336,7 @@ def attention_xformers(q, k, v, heads, mask=None):
     )
     return out
 
-def attention_pytorch(q, k, v, heads, mask=None):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
     q, k, v = map(
@@ -409,9 +411,9 @@ def forward(self, x, context=None, value=None, mask=None):
         v = self.to_v(context)
 
         if mask is None:
-            out = optimized_attention(q, k, v, self.heads)
+            out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION)
         else:
-            out = optimized_attention_masked(q, k, v, self.heads, mask)
+            out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION)
         return self.to_out(out)
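
A rough usage sketch of the refactored signature, not part of the commit: each attention backend now accepts an optional attn_precision keyword, and CrossAttention threads the module-level _ATTN_PRECISION through optimized_attention. The shapes below are invented and fp32 inputs are used so the snippet also runs on CPU; it assumes the ComfyUI repository is importable.

import torch
from comfy.ldm.modules.attention import attention_basic

heads, dim_head = 8, 64
q = torch.randn(2, 77, heads * dim_head)  # (batch, tokens, heads * dim_head)
k = torch.randn(2, 77, heads * dim_head)
v = torch.randn(2, 77, heads * dim_head)

# Upcast the q @ k^T similarity to fp32 (the default behaviour)...
out_upcast = attention_basic(q, k, v, heads, attn_precision=torch.float32)
# ...or keep the input dtype, matching --dont-upcast-attention.
out_plain = attention_basic(q, k, v, heads, attn_precision=None)
print(out_upcast.shape)  # torch.Size([2, 77, 512])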

