Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 31 additions & 29 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,10 @@ def __getitem__(self, index):
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids

if self.class_data_root:
Expand All @@ -312,14 +313,37 @@ def __getitem__(self, index):
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids

return example


def collate_fn(examples, with_prior_preservation=False):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]

# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]

pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

input_ids = torch.cat(input_ids, dim=0)

batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch


class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."

Expand Down Expand Up @@ -510,34 +534,12 @@ def main(args):
center_crop=args.center_crop,
)

def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]

# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]

pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

input_ids = tokenizer.pad(
{"input_ids": input_ids},
padding="max_length",
max_length=tokenizer.model_max_length,
return_tensors="pt",
).input_ids

batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch

train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=1,
)

# Scheduler and math around the number of training steps.
Expand Down