Skip to content

Commit

Permalink
fix all issues with training on transparent pngs
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 10, 2022
1 parent bebc280 commit a6776c8
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
7 changes: 4 additions & 3 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
temperature = 0.9,
straight_through = False,
kl_div_loss_weight = 0.,
normalization = ((0.5,) * 3, (0.5,) * 3)
normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1))
):
super().__init__()
assert log2(image_size).is_integer(), 'image size must be a power of 2'
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(
self.kl_div_loss_weight = kl_div_loss_weight

# take care of normalization within class
self.normalization = normalization
self.normalization = tuple(map(lambda t: t[:channels], normalization))

self._register_external_parameters()

Expand Down Expand Up @@ -594,7 +594,8 @@ def forward(

if is_raw_image:
image_size = self.vae.image_size
assert tuple(image.shape[1:]) == (3, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'
channels = self.vae.channels
assert tuple(image.shape[1:]) == (channels, image_size, image_size), f'invalid image of dimensions {image.shape} passed in during training'

image = self.vae.get_codebook_indices(image)

Expand Down
8 changes: 6 additions & 2 deletions dalle_pytorch/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self,
image_size=128,
truncate_captions=False,
resize_ratio=0.75,
transparent=False,
tokenizer=None,
shuffle=False
):
Expand Down Expand Up @@ -43,9 +44,12 @@ def __init__(self,
self.truncate_captions = truncate_captions
self.resize_ratio = resize_ratio
self.tokenizer = tokenizer

image_mode = 'RGBA' if transparent else 'RGB'

self.image_transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB')
if img.mode != 'RGB' else img),
T.Lambda(lambda img: img.convert(image_mode)
if img.mode != image_mode else img),
T.RandomResizedCrop(image_size,
scale=(self.resize_ratio, 1.),
ratio=(1., 1.)),
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '1.6.0',
version = '1.6.1',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down
1 change: 1 addition & 0 deletions train_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
args.image_text_folder,
text_len=TEXT_SEQ_LEN,
image_size=IMAGE_SIZE,
transparent=TRANSPARENT,
resize_ratio=args.resize_ratio,
truncate_captions=args.truncate_captions,
tokenizer=tokenizer,
Expand Down

0 comments on commit a6776c8

Please sign in to comment.