Commit 67ea2b7

Support tiled encode/decode for AutoencoderTiny (#4627)
* Impl tae slicing and tiling
* add tae tiling test
* add parameterized test
* formatted code
* fix failed test
* style docs
1 parent a10107f commit 67ea2b7
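
As a usage note (not part of the diff), the new switches are meant to be toggled on a loaded AutoencoderTiny before calling encode/decode. A minimal sketch follows; the checkpoint id and the input scaling are illustrative assumptions, not taken from the commit.

# Minimal sketch of the feature added by this commit; the checkpoint id and
# input range are assumptions for illustration, not taken from the diff.
import torch
from diffusers import AutoencoderTiny

vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
vae.enable_tiling()   # encode/decode large inputs tile by tile
vae.enable_slicing()  # process batch elements one at a time

with torch.no_grad():
    image = torch.rand(1, 3, 1024, 1024)  # assumed input range [0, 1]
    latents = vae.encode(image).latents   # (1, 4, 128, 128) with a spatial scale factor of 8
    recon = vae.decode(latents).sample    # back to (1, 3, 1024, 1024), values in [-1, 1]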

File tree

2 files changed: 168 additions, 2 deletions


src/diffusers/models/autoencoder_tiny.py

Lines changed: 151 additions & 2 deletions
@@ -137,6 +137,15 @@ def __init__(
         self.latent_shift = latent_shift
         self.scaling_factor = scaling_factor
 
+        self.use_slicing = False
+        self.use_tiling = False
+
+        # only relevant if vae tiling is enabled
+        self.spatial_scale_factor = 2**out_channels
+        self.tile_overlap_factor = 0.125
+        self.tile_sample_min_size = 512
+        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
@@ -149,11 +158,147 @@ def unscale_latents(self, x):
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
 
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def enable_tiling(self, use_tiling: bool = True):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.use_tiling = use_tiling
+
+    def disable_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.enable_tiling(False)
+
+    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
+                plain `tuple` is returned.
+        """
+        # scale of encoder output relative to input
+        sf = self.spatial_scale_factor
+        tile_size = self.tile_sample_min_size
+
+        # number of pixels to blend and to traverse between tiles
+        blend_size = int(tile_size * self.tile_overlap_factor)
+        traverse_size = tile_size - blend_size
+
+        # tiles index (up/left)
+        ti = range(0, x.shape[-2], traverse_size)
+        tj = range(0, x.shape[-1], traverse_size)
+
+        # mask for blending
+        blend_masks = torch.stack(
+            torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
+        )
+        blend_masks = blend_masks.clamp(0, 1).to(x.device)
+
+        # output array
+        out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
+        for i in ti:
+            for j in tj:
+                tile_in = x[..., i : i + tile_size, j : j + tile_size]
+                # tile result
+                tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
+                tile = self.encoder(tile_in)
+                h, w = tile.shape[-2], tile.shape[-1]
+                # blend tile result into output
+                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
+                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
+                blend_mask = blend_mask_i * blend_mask_j
+                tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
+                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
+        return out
+
+    def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+        r"""Decode a batch of latents using a tiled decoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of latents.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        # scale of decoder output relative to input
+        sf = self.spatial_scale_factor
+        tile_size = self.tile_latent_min_size
+
+        # number of pixels to blend and to traverse between tiles
+        blend_size = int(tile_size * self.tile_overlap_factor)
+        traverse_size = tile_size - blend_size
+
+        # tiles index (up/left)
+        ti = range(0, x.shape[-2], traverse_size)
+        tj = range(0, x.shape[-1], traverse_size)
+
+        # mask for blending
+        blend_masks = torch.stack(
+            torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
+        )
+        blend_masks = blend_masks.clamp(0, 1).to(x.device)
+
+        # output array
+        out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
+        for i in ti:
+            for j in tj:
+                tile_in = x[..., i : i + tile_size, j : j + tile_size]
+                # tile result
+                tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
+                tile = self.decoder(tile_in)
+                h, w = tile.shape[-2], tile.shape[-1]
+                # blend tile result into output
+                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
+                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
+                blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
+                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
+        return out
+
     @apply_forward_hook
     def encode(
         self, x: torch.FloatTensor, return_dict: bool = True
     ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
-        output = self.encoder(x)
+        if self.use_slicing and x.shape[0] > 1:
+            output = [self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)]
+            output = torch.cat(output)
+        else:
+            output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
 
         if not return_dict:
             return (output,)
@@ -162,7 +307,11 @@ def encode(
 
     @apply_forward_hook
     def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
-        output = self.decoder(x)
+        if self.use_slicing and x.shape[0] > 1:
+            output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)]
+            output = torch.cat(output)
+        else:
+            output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
         # Refer to the following discussion to know why this is needed.
         # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
         output = output.mul_(2).sub_(1)
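
For readers unfamiliar with the overlap-and-blend trick used by _tiled_encode / _tiled_decode above, here is a standalone sketch of the same idea; it is not the library code. The process argument is a hypothetical stand-in for the encoder or decoder and is assumed to preserve spatial size, so the coordinate rescaling by the spatial scale factor done by the real methods is omitted.

# Standalone sketch of the overlap-and-blend tiling idea (not the library code).
# `process` is a hypothetical stand-in for the encoder/decoder; it is assumed to
# preserve spatial size here for clarity.
import torch

def tiled_apply(x, process, tile_size=64, overlap_factor=0.125):
    blend_size = int(tile_size * overlap_factor)  # width of the blended border
    traverse = tile_size - blend_size             # stride between tile origins

    # linear ramp 0 -> 1 across the overlapping border, 1 elsewhere
    ramp = (torch.arange(tile_size) / max(blend_size - 1, 1)).clamp(0, 1)
    mask_i, mask_j = ramp.view(-1, 1), ramp.view(1, -1)

    out = torch.zeros_like(x)
    for i in range(0, x.shape[-2], traverse):
        for j in range(0, x.shape[-1], traverse):
            tile = process(x[..., i : i + tile_size, j : j + tile_size])
            h, w = tile.shape[-2], tile.shape[-1]
            # the first row/column has no neighbour to blend with: use a full mask
            bi = torch.ones_like(mask_i) if i == 0 else mask_i
            bj = torch.ones_like(mask_j) if j == 0 else mask_j
            blend = (bi * bj)[:h, :w]
            dst = out[..., i : i + h, j : j + w]
            dst.copy_(blend * tile + (1 - blend) * dst)
    return out

# e.g. tiled_apply(torch.rand(1, 3, 200, 200), lambda t: t) reproduces its input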

tests/models/test_models_vae.py

Lines changed: 17 additions & 0 deletions
@@ -285,6 +285,23 @@ def get_sd_vae_model(self, model_id="hf-internal-testing/taesd-diffusers", fp16=
         model.to(torch_device).eval()
         return model
 
+    @parameterized.expand(
+        [
+            [(1, 4, 73, 97), (1, 3, 584, 776)],
+            [(1, 4, 97, 73), (1, 3, 776, 584)],
+            [(1, 4, 49, 65), (1, 3, 392, 520)],
+            [(1, 4, 65, 49), (1, 3, 520, 392)],
+            [(1, 4, 49, 49), (1, 3, 392, 392)],
+        ]
+    )
+    def test_tae_tiling(self, in_shape, out_shape):
+        model = self.get_sd_vae_model()
+        model.enable_tiling()
+        with torch.no_grad():
+            zeros = torch.zeros(in_shape).to(torch_device)
+            dec = model.decode(zeros).sample
+            assert dec.shape == out_shape
+
     def test_stable_diffusion(self):
         model = self.get_sd_vae_model()
         image = self.get_sd_image(seed=33)
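
The expected output shapes in the parameterized cases above follow from a spatial scale factor of 8 (assuming the default 2**out_channels with three output channels): each latent spatial size is multiplied by 8. A quick sanity check of that relation, not part of the test file:

# Sanity check (not in the commit): each output spatial size equals the latent
# spatial size times the assumed spatial scale factor of 8.
cases = [
    ((1, 4, 73, 97), (1, 3, 584, 776)),
    ((1, 4, 97, 73), (1, 3, 776, 584)),
    ((1, 4, 49, 65), (1, 3, 392, 520)),
    ((1, 4, 65, 49), (1, 3, 520, 392)),
    ((1, 4, 49, 49), (1, 3, 392, 392)),
]
sf = 8
for in_shape, out_shape in cases:
    assert out_shape[-2:] == (in_shape[-2] * sf, in_shape[-1] * sf)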
