From 0bdc2b15c75426af75a326d5966ad47aab5b76d3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 18 May 2024 10:11:44 -0400
Subject: [PATCH] Cleanup.

---
 comfy/ldm/modules/attention.py                    | 10 +++-------
 comfy/ldm/modules/diffusionmodules/openaimodel.py |  2 +-
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 88ee2f32d5..aa74b63233 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -6,7 +6,7 @@
 from typing import Optional, Any
 import logging
 
-from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
+from .diffusionmodules.util import AlphaBlender, timestep_embedding
 from .sub_quadratic_attention import efficient_dot_product_attention
 
 from comfy import model_management
@@ -454,15 +454,11 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
 
         self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
-        self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
         self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
 
     def forward(self, x, context=None, transformer_options={}):
-        return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
-
-    def _forward(self, x, context=None, transformer_options={}):
         extra_options = {}
         block = transformer_options.get("block", None)
         block_index = transformer_options.get("block_index", 0)
@@ -629,7 +625,7 @@ def forward(self, x, context=None, transformer_options={}):
         x = self.norm(x)
         if not self.use_linear:
             x = self.proj_in(x)
-        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        x = x.movedim(1, -1).flatten(1, 2).contiguous()
         if self.use_linear:
             x = self.proj_in(x)
         for i, block in enumerate(self.transformer_blocks):
@@ -637,7 +633,7 @@ def forward(self, x, context=None, transformer_options={}):
             x = block(x, context=context[i], transformer_options=transformer_options)
         if self.use_linear:
             x = self.proj_out(x)
-        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(-1, 1).contiguous()
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 1f5a4ded29..ba8fc2c4a0 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -258,7 +258,7 @@ def _forward(self, x, emb):
         else:
             if emb_out is not None:
                 if self.exchange_temb_dims:
-                    emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+                    emb_out = emb_out.movedim(1, 2)
                 h = h + emb_out
         h = self.out_layers(h)
         return self.skip_connection(x) + h
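
Note on the checkpoint removal: BasicTransformerBlock.forward previously
routed through the ldm checkpoint util into a separate _forward method, an
activation-checkpointing pattern that saves memory during training by
recomputing intermediate activations on the backward pass. ComfyUI only runs
these modules for inference, where there is no backward pass, so the patch
inlines _forward into forward and drops the now-unused import. Below is a
minimal sketch of the removed pattern, using torch.utils.checkpoint as a
stand-in for the ldm util; the module is illustrative, not code from this
repo:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class BlockWithCheckpoint(nn.Module):
        """Illustrative module using the wrapper pattern this patch removes."""
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)

        def forward(self, x):
            # Old pattern: indirect through checkpoint() so the activations
            # of _forward are recomputed during backward instead of stored.
            return checkpoint(self._forward, x, use_reentrant=False)

        def _forward(self, x):
            return self.linear(x)

    x = torch.randn(2, 8, requires_grad=True)
    out = BlockWithCheckpoint()(x)
    out.sum().backward()  # _forward runs again here to rebuild activations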
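
Note on the rearrange replacements in attention.py: both einops calls in
SpatialTransformer.forward are swapped for native tensor ops that produce
identical results, removing the einops dependency from this code path. A
quick equivalence check, assuming torch and einops are installed (the shapes
are arbitrary):

    import torch
    from einops import rearrange

    b, c, h, w = 2, 4, 8, 8
    x = torch.randn(b, c, h, w)

    # 'b c h w -> b (h w) c': move channels last, then merge the spatial dims.
    flat_old = rearrange(x, 'b c h w -> b (h w) c')
    flat_new = x.movedim(1, -1).flatten(1, 2)
    assert torch.equal(flat_old, flat_new)

    # 'b (h w) c -> b c h w': split the token dim back out, channels first.
    back_old = rearrange(flat_old, 'b (h w) c -> b c h w', h=h, w=w)
    back_new = flat_new.reshape(flat_new.shape[0], h, w, flat_new.shape[-1]).movedim(-1, 1)
    assert torch.equal(back_old, back_new)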
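
Note on the openaimodel.py change: the pattern "b t c ... -> b c t ..." only
swaps the adjacent dims 1 and 2, which movedim(1, 2) expresses directly. A
quick check with an illustrative 5-D shape:

    import torch
    from einops import rearrange

    emb_out = torch.randn(2, 3, 5, 1, 1)  # b, t, c, then trailing dims
    assert torch.equal(rearrange(emb_out, "b t c ... -> b c t ..."),
                       emb_out.movedim(1, 2))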