Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ def _prepare_masks(
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
# regions to be lost entirely.
#
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
#
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2)
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)

return masks_by_seq_len

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,12 @@ def _prepare_spatial_masks(
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# regions to be lost entirely.
#
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
#
# TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g.
# nearest interpolation), and could potentially use a weighted mask rather than a binary mask.
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2, ceil_mode=True)

return batch_sample_masks_by_seq_len

Expand Down