
Commit

Cleanup.
comfyanonymous committed May 18, 2024
1 parent 98f828f commit 0bdc2b1
Showing 2 changed files with 4 additions and 8 deletions.
10 changes: 3 additions & 7 deletions comfy/ldm/modules/attention.py
@@ -6,7 +6,7 @@
 from typing import Optional, Any
 import logging
 
-from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
+from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
 
 from comfy import model_management
@@ -454,15 +454,11 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
 
         self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
-        self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
         self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
 
     def forward(self, x, context=None, transformer_options={}):
-        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
-
-    def _forward(self, x, context=None, transformer_options={}):
         extra_options = {}
         block = transformer_options.get("block", None)
         block_index = transformer_options.get("block_index", 0)
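This hunk drops the checkpoint indirection: the forward/_forward split existed only so the block could be routed through the util checkpoint(func, inputs, params, flag) helper. Below is a minimal sketch of that wrapper pattern, assuming an OpenAI-style helper; the implementation shown is illustrative, not ComfyUI's exact code.

```python
from torch.utils.checkpoint import checkpoint as torch_checkpoint

def checkpoint(func, inputs, params, flag):
    # Illustrative OpenAI-style helper (not ComfyUI's code): when `flag` is
    # set, re-run `func` during the backward pass to save activation memory;
    # otherwise it is a plain pass-through.
    if flag:
        return torch_checkpoint(func, *inputs, use_reentrant=False)
    return func(*inputs)

# Before this commit the block was effectively:
#
#     def forward(self, x, context=None, transformer_options={}):
#         return checkpoint(self._forward, (x, context, transformer_options),
#                           self.parameters(), self.checkpoint)
#
#     def _forward(self, x, context=None, transformer_options={}):
#         ...  # real work
#
# After the cleanup the body of _forward() lives directly in forward(),
# so the extra hop disappears.
```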
@@ -629,15 +625,15 @@ def forward(self, x, context=None, transformer_options={}):
         x = self.norm(x)
         if not self.use_linear:
             x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        x = x.movedim(1, -1).flatten(1, 2).contiguous()
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
             transformer_options["block_index"] = i
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(-1, 1).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
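Here the two einops rearrange calls in this forward() are swapped for plain movedim/flatten/reshape tensor ops. A small self-contained check (example shapes are arbitrary) that the old and new forms produce identical tensors:

```python
import torch
from einops import rearrange

b, c, h, w = 2, 320, 8, 8          # arbitrary example sizes
x = torch.randn(b, c, h, w)

# old: rearrange(x, 'b c h w -> b (h w) c')   new: movedim + flatten
tokens_old = rearrange(x, 'b c h w -> b (h w) c')
tokens_new = x.movedim(1, -1).flatten(1, 2)
assert torch.equal(tokens_old, tokens_new)

# old: rearrange(t, 'b (h w) c -> b c h w')   new: reshape + movedim
back_old = rearrange(tokens_old, 'b (h w) c -> b c h w', h=h, w=w)
back_new = tokens_new.reshape(tokens_new.shape[0], h, w, tokens_new.shape[-1]).movedim(-1, 1)
assert torch.equal(back_old, back_new)
assert torch.equal(back_old, x)    # round trip recovers the input
```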
2 changes: 1 addition & 1 deletion comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -258,7 +258,7 @@ def _forward(self, x, emb):
         else:
             if emb_out is not None:
                 if self.exchange_temb_dims:
-                    emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+                    emb_out = emb_out.movedim(1, 2)
                 h = h + emb_out
             h = self.out_layers(h)
         return self.skip_connection(x) + h
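Same idea in openaimodel.py: emb_out.movedim(1, 2) is the native equivalent of the removed "b t c ... -> b c t ..." rearrange. A quick check with arbitrary dimension sizes:

```python
import torch
from einops import rearrange

# b, t, c plus trailing spatial dims (sizes arbitrary)
emb_out = torch.randn(2, 4, 8, 1, 1)
assert torch.equal(rearrange(emb_out, "b t c ... -> b c t ..."),
                   emb_out.movedim(1, 2))
```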

0 comments on commit 0bdc2b1
