@@ -136,7 +136,7 @@ def __init__(self, in_dim, out_dim, dropout=0.0):
136
136
if in_dim != out_dim else nn .Identity ())
137
137
138
138
def forward (self , x , feat_cache = None , feat_idx = [0 ]):
139
- h = self . shortcut ( x )
139
+ old_x = x
140
140
for layer in self .residual :
141
141
if isinstance (layer , CausalConv3d ) and feat_cache is not None :
142
142
idx = feat_idx [0 ]
@@ -156,7 +156,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
156
156
feat_idx [0 ] += 1
157
157
else :
158
158
x = layer (x )
159
- return x + h
159
+ return x + self . shortcut ( old_x )
160
160
161
161
162
162
def patchify (x , patch_size ):
@@ -327,7 +327,7 @@ def __init__(self,
327
327
self .downsamples = nn .Sequential (* downsamples )
328
328
329
329
def forward (self , x , feat_cache = None , feat_idx = [0 ]):
330
- x_copy = x . clone ()
330
+ x_copy = x
331
331
for module in self .downsamples :
332
332
x = module (x , feat_cache , feat_idx )
333
333
@@ -369,7 +369,7 @@ def __init__(self,
369
369
self .upsamples = nn .Sequential (* upsamples )
370
370
371
371
def forward (self , x , feat_cache = None , feat_idx = [0 ], first_chunk = False ):
372
- x_main = x . clone ()
372
+ x_main = x
373
373
for module in self .upsamples :
374
374
x_main = module (x_main , feat_cache , feat_idx )
375
375
if self .avg_shortcut is not None :
0 commit comments