Add debug options to force on and off attention upcasting.
comfyanonymous committed May 16, 2024
1 parent 58f8388 commit 46daf0a
Showing 2 changed files with 19 additions and 0 deletions.
5 changes: 5 additions & 0 deletions comfy/cli_args.py
@@ -95,6 +95,11 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")


vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
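The two new flags sit in a mutually exclusive group, so argparse refuses a launch that passes both at once. A minimal standalone sketch of that pattern (this is not ComfyUI's actual cli_args module, just the same argparse construct in isolation):

import argparse

# Hypothetical standalone parser reproducing the mutually exclusive pair above.
parser = argparse.ArgumentParser()
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true")
upcast.add_argument("--dont-upcast-attention", action="store_true")

args = parser.parse_args(["--force-upcast-attention"])
assert args.force_upcast_attention and not args.dont_upcast_attention

# Passing both flags together makes argparse exit with an error such as:
#   argument --dont-upcast-attention: not allowed with argument --force-upcast-attention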
14 changes: 14 additions & 0 deletions comfy/ldm/modules/attention.py
@@ -19,6 +19,14 @@
import comfy.ops
ops = comfy.ops.disable_weight_init


def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
if attn_precision is None and args.force_upcast_attention:
return torch.float32
return attn_precision

def exists(val):
return val is not None
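get_attn_precision keeps any precision a caller asked for and only lets the new flags act at the margins: --dont-upcast-attention always wins and returns None, while --force-upcast-attention only upgrades callers that passed attn_precision=None. A small self-check of that precedence, with the parsed args replaced by a SimpleNamespace so it runs outside ComfyUI (the resolve helper below is an illustrative stand-in, not part of this commit):

import torch
from types import SimpleNamespace

def resolve(attn_precision, args):
    # Same precedence as get_attn_precision, with args passed explicitly for the sketch.
    if args.dont_upcast_attention:
        return None
    if attn_precision is None and args.force_upcast_attention:
        return torch.float32
    return attn_precision

default = SimpleNamespace(dont_upcast_attention=False, force_upcast_attention=False)
forced  = SimpleNamespace(dont_upcast_attention=False, force_upcast_attention=True)
off     = SimpleNamespace(dont_upcast_attention=True,  force_upcast_attention=False)

assert resolve(None, default) is None                    # default behaviour unchanged
assert resolve(None, forced) is torch.float32            # force flag fills in a missing preference
assert resolve(torch.float32, forced) is torch.float32   # explicit request kept as-is
assert resolve(torch.float32, off) is None               # dont flag overrides everything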

@@ -78,6 +86,8 @@ def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)

b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
@@ -128,6 +138,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None):


def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)

b, _, dim_head = query.shape
dim_head //= heads

@@ -188,6 +200,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None)
return hidden_states

def attention_split(q, k, v, heads, mask=None, attn_precision=None):
attn_precision = get_attn_precision(attn_precision)

b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
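Each of the three attention backends now routes its attn_precision argument through the helper before doing any math. The function bodies are outside this hunk, so the sketch below only illustrates the kind of upcast the flags toggle, assuming (as the flag's help text suggests) that float32 is applied to the query/key similarity step; toy_attention is a stand-in, not the project's attention_basic:

import torch

def toy_attention(q, k, v, attn_precision=None):
    # Upcast the similarity computation to float32 when requested; this is the
    # class of numerical fix --force-upcast-attention exists to test for black images.
    scale = q.shape[-1] ** -0.5
    if attn_precision == torch.float32:
        sim = (q.float() @ k.float().transpose(-2, -1)) * scale
    else:
        sim = (q @ k.transpose(-2, -1)) * scale
    attn = sim.softmax(dim=-1)
    return attn.to(v.dtype) @ v

# In practice q/k/v would be fp16/bf16 on the GPU; float32 inputs keep the demo CPU-safe.
q = k = v = torch.randn(1, 8, 64)
out = toy_attention(q, k, v, attn_precision=torch.float32)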
