Skip to content

Commit c60dc41

Browse files
Remove unecessary clones in the wan2.2 VAE. (comfyanonymous#9083)
1 parent 5d4cc3b commit c60dc41

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

comfy/ldm/wan/vae2_2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(self, in_dim, out_dim, dropout=0.0):
136136
if in_dim != out_dim else nn.Identity())
137137

138138
def forward(self, x, feat_cache=None, feat_idx=[0]):
139-
h = self.shortcut(x)
139+
old_x = x
140140
for layer in self.residual:
141141
if isinstance(layer, CausalConv3d) and feat_cache is not None:
142142
idx = feat_idx[0]
@@ -156,7 +156,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
156156
feat_idx[0] += 1
157157
else:
158158
x = layer(x)
159-
return x + h
159+
return x + self.shortcut(old_x)
160160

161161

162162
def patchify(x, patch_size):
@@ -327,7 +327,7 @@ def __init__(self,
327327
self.downsamples = nn.Sequential(*downsamples)
328328

329329
def forward(self, x, feat_cache=None, feat_idx=[0]):
330-
x_copy = x.clone()
330+
x_copy = x
331331
for module in self.downsamples:
332332
x = module(x, feat_cache, feat_idx)
333333

@@ -369,7 +369,7 @@ def __init__(self,
369369
self.upsamples = nn.Sequential(*upsamples)
370370

371371
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
372-
x_main = x.clone()
372+
x_main = x
373373
for module in self.upsamples:
374374
x_main = module(x_main, feat_cache, feat_idx)
375375
if self.avg_shortcut is not None:

0 commit comments

Comments
 (0)