Skip to content

Commit

Permalink
fix rgbs outputted by unet upsampler to discriminator, and validate c…
Browse files Browse the repository at this point in the history
…orrectly on GigaGAN instantiation, thanks to @XavierXiao
  • Loading branch information
lucidrains committed Jul 25, 2023
1 parent 373267f commit 8dd0514
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
6 changes: 6 additions & 0 deletions gigagan_pytorch/gigagan_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,7 @@ def forward(
x = images

image_size = (self.image_size, self.image_size)

assert x.shape[-2:] == image_size

batch = x.shape[0]
Expand Down Expand Up @@ -1834,6 +1835,11 @@ def __init__(
self.D = discriminator
self.VD = vision_aided_discriminator

# validate multiscale input resolutions

if train_upsampler:
assert is_empty(set(discriminator.multiscale_input_resolutions) - set(generator.allowable_rgb_resolutions)), f'only multiscale input resolutions of {generator.allowable_rgb_resolutions} is allowed based on the unet input and output image size'

# ema

self.has_ema_generator = False
Expand Down
9 changes: 8 additions & 1 deletion gigagan_pytorch/unet_upsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,13 @@ def __init__(
self.style_to_conv_modulations = nn.Linear(style_network.dim, sum(style_embed_split_dims))
self.style_embed_split_dims = style_embed_split_dims

@property
def allowable_rgb_resolutions(self):
input_res_base = int(log2(self.input_image_size))
output_res_base = int(log2(self.image_size))
allowed_rgb_res_base = list(range(input_res_base + 1, output_res_base))
return [*map(lambda p: 2 ** p, allowed_rgb_res_base)]

@property
def device(self):
return next(self.parameters()).device
Expand Down Expand Up @@ -584,7 +591,7 @@ def forward(

# only keep those rgbs whose feature map is greater than the input image to be upsampled

rgbs = list(filter(lambda t: t.shape[-1] <= shape[-1], rgbs))
rgbs = list(filter(lambda t: t.shape[-1] > shape[-1], rgbs))

if not replace_rgb_with_input_lowres_image:
return rgb, rgbs
Expand Down
2 changes: 1 addition & 1 deletion gigagan_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.5'
__version__ = '0.2.6'

0 comments on commit 8dd0514

Please sign in to comment.