@@ -137,6 +137,15 @@ def __init__(
         self.latent_shift = latent_shift
         self.scaling_factor = scaling_factor

+        self.use_slicing = False
+        self.use_tiling = False
+
+        # only relevant if vae tiling is enabled
+        self.spatial_scale_factor = 2**out_channels
+        self.tile_overlap_factor = 0.125
+        self.tile_sample_min_size = 512
+        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
@@ -149,11 +158,139 @@ def unscale_latents(self, x):
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def enable_tiling(self, use_tiling: bool = True):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.use_tiling = use_tiling
+
+    def disable_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.enable_tiling(False)
+
+    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+
+        Returns:
+            `torch.FloatTensor`: Encoded batch of images.
+        """
+        # scale of encoder output relative to input
+        sf = self.spatial_scale_factor
+        tile_size = self.tile_sample_min_size
+
+        # number of pixels to blend and to traverse between tiles
+        blend_size = int(tile_size * self.tile_overlap_factor)
+        traverse_size = tile_size - blend_size
+
+        # tiles index (up/left)
+        ti = range(0, x.shape[-2], traverse_size)
+        tj = range(0, x.shape[-1], traverse_size)
+
+        # mask for blending
+        blend_masks = torch.stack(
+            torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
+        )
+        blend_masks = blend_masks.clamp(0, 1).to(x.device)
+
+        # output array
+        out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
+        for i in ti:
+            for j in tj:
+                tile_in = x[..., i : i + tile_size, j : j + tile_size]
+                # tile result
+                tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
+                tile = self.encoder(tile_in)
+                h, w = tile.shape[-2], tile.shape[-1]
+                # blend tile result into output
+                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
+                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
+                blend_mask = blend_mask_i * blend_mask_j
+                tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
+                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
+        return out
+
+    def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        r"""Decode a batch of latents using a tiled decoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of latents.
+
+        Returns:
+            `torch.FloatTensor`: Decoded batch of images.
+        """
+        # scale of decoder output relative to input
+        sf = self.spatial_scale_factor
+        tile_size = self.tile_latent_min_size
+
+        # number of pixels to blend and to traverse between tiles
+        blend_size = int(tile_size * self.tile_overlap_factor)
+        traverse_size = tile_size - blend_size
+
+        # tiles index (up/left)
+        ti = range(0, x.shape[-2], traverse_size)
+        tj = range(0, x.shape[-1], traverse_size)
+
+        # mask for blending
+        blend_masks = torch.stack(
+            torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
+        )
+        blend_masks = blend_masks.clamp(0, 1).to(x.device)
+
+        # output array
+        out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
+        for i in ti:
+            for j in tj:
+                tile_in = x[..., i : i + tile_size, j : j + tile_size]
+                # tile result
+                tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
+                tile = self.decoder(tile_in)
+                h, w = tile.shape[-2], tile.shape[-1]
+                # blend tile result into output
+                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
+                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
+                blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
+                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
+        return out
+
     @apply_forward_hook
     def encode(
         self, x: torch.FloatTensor, return_dict: bool = True
     ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
-        output = self.encoder(x)
+        if self.use_slicing and x.shape[0] > 1:
+            output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)]
+            output = torch.cat(output)
+        else:
+            output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)

         if not return_dict:
             return (output,)
@@ -162,7 +299,11 @@ def encode(

     @apply_forward_hook
     def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
-        output = self.decoder(x)
+        if self.use_slicing and x.shape[0] > 1:
+            output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)]
+            output = torch.cat(output)
+        else:
+            output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
         # Refer to the following discussion to know why this is needed.
         # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
         output = output.mul_(2).sub_(1)
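
For reference, a minimal usage sketch of the new slicing/tiling switches (not part of the diff above). The randomly initialized model, the input size, and the printed shapes are illustrative assumptions based on the default `AutoencoderTiny` config (spatial scale factor 8, 4 latent channels); in practice the weights would come from a pretrained checkpoint.

import torch
from diffusers import AutoencoderTiny

# illustrative only: random weights; normally the model is loaded via from_pretrained(...)
vae = AutoencoderTiny().eval()

# split large inputs into overlapping tiles, and process the batch one sample at a time
vae.enable_tiling()
vae.enable_slicing()

with torch.no_grad():
    images = torch.rand(2, 3, 1024, 1024)   # dummy batch of large images (values are placeholders)
    latents = vae.encode(images).latents    # sliced + tiled encode
    recon = vae.decode(latents).sample      # sliced + tiled decode

print(latents.shape)  # (2, 4, 128, 128) with the default spatial scale factor of 8
print(recon.shape)    # (2, 3, 1024, 1024)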