diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ccacc466794a..5db431bc615f 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -300,9 +300,10 @@ def __getitem__(self, index): example["instance_images"] = self.image_transforms(instance_image) example["instance_prompt_ids"] = self.tokenizer( self.instance_prompt, - padding="do_not_pad", truncation=True, + padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", ).input_ids if self.class_data_root: @@ -312,14 +313,37 @@ def __getitem__(self, index): example["class_images"] = self.image_transforms(class_image) example["class_prompt_ids"] = self.tokenizer( self.class_prompt, - padding="do_not_pad", truncation=True, + padding="max_length", max_length=self.tokenizer.model_max_length, + return_tensors="pt", ).input_ids return example +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + class PromptDataset(Dataset): "A simple dataset to prepare the prompts to generate class images on multiple GPUs." @@ -510,34 +534,12 @@ def main(args): center_crop=args.center_crop, ) - def collate_fn(examples): - input_ids = [example["instance_prompt_ids"] for example in examples] - pixel_values = [example["instance_images"] for example in examples] - - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. - if args.with_prior_preservation: - input_ids += [example["class_prompt_ids"] for example in examples] - pixel_values += [example["class_images"] for example in examples] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - - input_ids = tokenizer.pad( - {"input_ids": input_ids}, - padding="max_length", - max_length=tokenizer.model_max_length, - return_tensors="pt", - ).input_ids - - batch = { - "input_ids": input_ids, - "pixel_values": pixel_values, - } - return batch - train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1 + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=1, ) # Scheduler and math around the number of training steps.