Skip to content

Commit

Permalink
fix tiled vae blend extent range (huggingface#3384)
Browse files Browse the repository at this point in the history
fix tiled vae bleand extent range
  • Loading branch information
superlabs-dev authored and dg845 committed May 21, 2023
1 parent 63abfce commit 32162aa
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/diffusers/models/autoencoder_kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,14 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
return DecoderOutput(sample=decoded)

def blend_v(self, a, b, blend_extent):
for y in range(min(a.shape[2], b.shape[2], blend_extent)):
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b

def blend_h(self, a, b, blend_extent):
for x in range(min(a.shape[3], b.shape[3], blend_extent)):
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b

Expand Down

0 comments on commit 32162aa

Please sign in to comment.