From 61558639628c1af8918b29e99da27e7570d960a7 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 9 Oct 2023 17:37:42 +0000 Subject: [PATCH 1/3] zero out dropped captions --- diffusion/datasets/image_caption.py | 12 ++++++++++++ diffusion/models/stable_diffusion.py | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 24cec6dd..38155c90 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -43,6 +43,8 @@ class StreamingImageCaptionDataset(StreamingDataset): image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. sdxl (bool): Whether or not we're training SDXL. Default: `False`. + zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``. + **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader """ @@ -60,6 +62,7 @@ def __init__( image_key: str = 'image', caption_key: str = 'caption', sdxl: bool = False, + zero_dropped_captions: bool = True, **streaming_kwargs, ) -> None: @@ -87,6 +90,7 @@ def __init__( self.image_size = image_size self.image_key = image_key self.caption_key = caption_key + self.zero_dropped_captions = zero_dropped_captions def __getitem__(self, index): sample = super().__getitem__(index) @@ -122,12 +126,17 @@ def __getitem__(self, index): # Caption if torch.rand(1) < self.caption_drop_prob: caption = '' + if self.zero_dropped_captions: + out['drop_caption_mask'] = 0.0 + else: + out['drop_caption_mask'] = 1.0 else: caption = sample[self.caption_key] if isinstance(caption, List) and self.caption_selection == 'first': caption = caption[0] if isinstance(caption, List) and self.caption_selection == 'random': caption = random.sample(caption, k=1)[0] + out['drop_caption_mask'] = 1.0 max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore tokenized_caption = self.tokenizer(caption, @@ -158,6 +167,7 @@ def build_streaming_image_caption_dataloader( image_key: str = 'image', caption_key: str = 'caption', rand_crop: bool = False, + zero_dropped_captions: bool = True, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, ): @@ -178,6 +188,7 @@ def build_streaming_image_caption_dataloader( image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``. + zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``. streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. """ @@ -240,6 +251,7 @@ def build_streaming_image_caption_dataloader( caption_key=caption_key, batch_size=batch_size, sdxl=sdxl, + zero_dropped_captions=zero_dropped_captions, **streaming_kwargs, ) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 4435d0bc..69ae47a7 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -196,6 +196,11 @@ def forward(self, batch): # Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515) latents *= self.latent_scale + # Zero dropped captions if needed + conditioning *= batch['drop_caption_mask'].view(-1, 1, 1) + if pooled_conditioning is not None: + pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1) + # Sample the diffusion timesteps timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device) # Add noise to the inputs (forward diffusion) From 4d5a6bec6d14defb10f84280ae4f55d96dd97968 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 9 Oct 2023 13:35:27 -0700 Subject: [PATCH 2/3] add if statement for compatibility w other dataloders --- diffusion/models/stable_diffusion.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 69ae47a7..9a4a8c7f 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -197,9 +197,10 @@ def forward(self, batch): latents *= self.latent_scale # Zero dropped captions if needed - conditioning *= batch['drop_caption_mask'].view(-1, 1, 1) - if pooled_conditioning is not None: - pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1) + if 'drop_caption_mask' in batch.keys(): + conditioning *= batch['drop_caption_mask'].view(-1, 1, 1) + if pooled_conditioning is not None: + pooled_conditioning *= batch['drop_caption_mask'].view(-1, 1) # Sample the diffusion timesteps timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device) From ddb42fa84b13253e1df8aea4ff29f1266f4e0a59 Mon Sep 17 00:00:00 2001 From: Jasmine Collins Date: Fri, 13 Oct 2023 09:28:24 -0700 Subject: [PATCH 3/3] set zero_dropped_captions default to False --- diffusion/datasets/image_caption.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 38155c90..81583e6f 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -43,7 +43,7 @@ class StreamingImageCaptionDataset(StreamingDataset): image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. sdxl (bool): Whether or not we're training SDXL. Default: `False`. - zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``. + zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``. **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader """ @@ -62,7 +62,7 @@ def __init__( image_key: str = 'image', caption_key: str = 'caption', sdxl: bool = False, - zero_dropped_captions: bool = True, + zero_dropped_captions: bool = False, **streaming_kwargs, ) -> None: