Only enable attention upcasting on models that actually need it.
comfyanonymous committed May 14, 2024
1 parent b0ab31d · commit bb4940d
Showing 5 changed files with 27 additions and 24 deletions.
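This commit drops the global `--dont-upcast-attention` switch and the import-time `_ATTN_PRECISION` constant in `comfy/ldm/modules/attention.py`, and instead threads an `attn_precision` argument through `CrossAttention`, `BasicTransformerBlock`, `SpatialTransformer`, the video transformer variant, and the UNet, so that only the model configs known to need it (SD2.x and SVD, via `unet_extra_config`) request fp32 attention; every other model runs attention in its own dtype and gets the speed benefit automatically. For background, the sketch below is a plain attention implementation, not ComfyUI's `optimized_attention` kernels; the function name and shapes are illustrative and it exists only to show what upcasting the similarity computation means.

```python
import torch

def naive_cross_attention(q, k, v, heads, attn_precision=None):
    # Illustrative only. When attn_precision is torch.float32, q and k are cast
    # to fp32 before the q @ k^T similarity; in fp16 those logits can overflow
    # or lose precision, which is the usual explanation for the black images
    # seen on models such as SD2.x when upcasting is disabled.
    b, n, _ = q.shape
    dim_head = q.shape[-1] // heads
    q, k, v = (t.reshape(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))

    if attn_precision is not None:
        q, k = q.to(attn_precision), k.to(attn_precision)

    sim = (q * dim_head ** -0.5) @ k.transpose(-1, -2)  # (b, heads, n, n)
    attn = sim.softmax(dim=-1).to(v.dtype)

    out = attn @ v                                       # (b, heads, n, dim_head)
    return out.transpose(1, 2).reshape(b, n, heads * dim_head)
```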
README.md (6 changes: 0 additions & 6 deletions)
@@ -207,12 +207,6 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
```embedding:embedding_filename.pt```


-## How to increase generation speed?

-On non Nvidia hardware you can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything.

-```--dont-upcast-attention```

## How to show high-quality previews?

Use ```--preview-method auto``` to enable previews.
comfy/cli_args.py (1 change: 0 additions & 1 deletion)
@@ -51,7 +51,6 @@ def __call__(self, parser, namespace, values, option_string=None):
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")

-parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")

fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
comfy/ldm/modules/attention.py (28 changes: 12 additions & 16 deletions)
@@ -19,14 +19,6 @@
import comfy.ops
ops = comfy.ops.disable_weight_init

-# CrossAttn precision handling
-if args.dont_upcast_attention:
-    logging.info("disabling upcasting of attention")
-    _ATTN_PRECISION = None
-else:
-    _ATTN_PRECISION = torch.float32


def exists(val):
return val is not None

@@ -386,10 +378,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):


class CrossAttention(nn.Module):
-def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
+def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
+self.attn_precision = attn_precision

self.heads = heads
self.dim_head = dim_head
@@ -411,15 +404,15 @@ def forward(self, x, context=None, value=None, mask=None):
v = self.to_v(context)

if mask is None:
-out = optimized_attention(q, k, v, self.heads, attn_precision=_ATTN_PRECISION)
+out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
-out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=_ATTN_PRECISION)
+out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)


class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
-disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
+disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()

self.ff_in = ff_in or inner_dim is not None
@@ -434,7 +427,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=

self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
-context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
+context_dim=context_dim if self.disable_self_attn else None, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

if disable_temporal_crossattention:
@@ -448,7 +441,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
context_dim_attn2 = context_dim

self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
-heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
+heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@@ -588,7 +581,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
-use_checkpoint=True, dtype=None, device=None, operations=ops):
+use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
@@ -606,7 +599,7 @@ def __init__(self, in_channels, n_heads, d_head,

self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
-disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
+disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)]
)
if not use_linear:
@@ -662,6 +655,7 @@ def __init__(
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
+attn_precision=None,
dtype=None, device=None, operations=ops
):
super().__init__(
@@ -674,6 +668,7 @@ def __init__(
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
+attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
self.time_depth = time_depth
@@ -703,6 +698,7 @@ def __init__(
inner_dim=time_mix_inner_dim,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
+attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for _ in range(self.depth)
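In `attention.py` the precision is no longer a module-level global chosen at import time; each block stores the `attn_precision` it was constructed with and forwards it to `optimized_attention` / `optimized_attention_masked`. A hedged construction sketch follows; the dimensions and context sizes are illustrative, not taken from a real checkpoint.

```python
import torch
from comfy.ldm.modules.attention import SpatialTransformer

# Illustrative dimensions only; real values come from the model's UNet config.
# A block built without attn_precision runs attention in the model dtype,
# while passing torch.float32 requests the upcasting behaviour for that block.
fast_block = SpatialTransformer(320, 8, 40, depth=1, context_dim=768)
upcast_block = SpatialTransformer(320, 8, 40, depth=1, context_dim=1024,
                                  attn_precision=torch.float32)
```

Because the default is `None`, callers that do not pass the argument now run attention in the model dtype, which is exactly the "only enable upcasting where needed" behaviour in the commit title.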
comfy/ldm/modules/diffusionmodules/openaimodel.py (4 changes: 3 additions & 1 deletion)
@@ -431,6 +431,7 @@ def __init__(
video_kernel_size=None,
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
+attn_precision=None,
device=None,
operations=ops,
):
@@ -550,13 +551,14 @@ def get_attention_layer(
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
max_time_embed_period=max_ddpm_temb_period,
+attn_precision=attn_precision,
dtype=self.dtype, device=device, operations=operations
)
else:
return SpatialTransformer(
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
-use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
)

def get_resblock(
comfy/supported_models.py (12 changes: 12 additions & 0 deletions)
@@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
"use_temporal_attention": False,
}

+unet_extra_config = {
+    "num_heads": -1,
+    "num_head_channels": 64,
+    "attn_precision": torch.float32,
+}

latent_format = latent_formats.SD15

def model_type(self, state_dict, prefix=""):
@@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
"use_temporal_resblock": True
}

+unet_extra_config = {
+    "num_heads": -1,
+    "num_head_channels": 64,
+    "attn_precision": torch.float32,
+}

clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

latent_format = latent_formats.SD15
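`supported_models.py` is where the opt-in happens: `SD20` and `SVD_img2vid` now define their own `unet_extra_config`, repeating what appear to be the inherited head settings and adding `attn_precision: torch.float32`, while the other model classes keep the base config and therefore skip the upcast. Below is a minimal sketch of that override pattern with hypothetical names; the real merge into the UNet kwargs happens in comfy's model detection code and is not shown here.

```python
import torch

# Hypothetical dictionaries; only the override pattern is the point.
base_unet_extra_config = {
    "num_heads": -1,
    "num_head_channels": 64,
}

sd20_unet_extra_config = {
    **base_unet_extra_config,
    "attn_precision": torch.float32,  # SD2.x is known to need fp32 attention
}

def build_unet_kwargs(detected_config, extra_config):
    # Per-model extras win over whatever was detected from the checkpoint.
    return {**detected_config, **extra_config}

print(build_unet_kwargs({"context_dim": 1024}, sd20_unet_extra_config))
```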
