Skip to content

Commit

Permalink
Full Dreambooth IF stage II upscaling (#3561)
Browse files Browse the repository at this point in the history
* update dreambooth lora to work with IF stage II

* Update dreambooth script for IF stage II upscaler
  • Loading branch information
williamberman committed May 31, 2023
1 parent f751b88 commit 4f14b36
Showing 1 changed file with 46 additions and 9 deletions.
55 changes: 46 additions & 9 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor


if is_wandb_available():
Expand Down Expand Up @@ -114,16 +115,17 @@ def log_validation(

pipeline_args = {}

if text_encoder is not None:
pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)

if vae is not None:
pipeline_args["vae"] = vae

if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)

# create pipeline (note: unet and vae are loaded again in float32)
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
tokenizer=tokenizer,
text_encoder=text_encoder,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
torch_dtype=weight_dtype,
Expand Down Expand Up @@ -156,10 +158,16 @@ def log_validation(
# run inference
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
images.append(image)
if args.validation_images is None:
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
images.append(image)
else:
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
Expand Down Expand Up @@ -525,6 +533,19 @@ def parse_args(input_args=None):
parser.add_argument(
"--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
)
parser.add_argument(
"--validation_images",
required=False,
default=None,
nargs="+",
help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
)
parser.add_argument(
"--class_labels_conditioning",
required=False,
default=None,
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -1169,7 +1190,7 @@ def compute_text_embeddings(prompt):
)
else:
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
Expand All @@ -1191,8 +1212,24 @@ def compute_text_embeddings(prompt):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

if unet.config.in_channels > channels:
needed_additional_channels = unet.config.in_channels - channels
additional_latents = randn_tensor(
(bsz, needed_additional_channels, height, width),
device=noisy_model_input.device,
dtype=noisy_model_input.dtype,
)
noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)

if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
else:
class_labels = None

# Predict the noise residual
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
model_pred = unet(
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
).sample

if model_pred.shape[1] == 6:
model_pred, _ = torch.chunk(model_pred, 2, dim=1)
Expand Down

0 comments on commit 4f14b36

Please sign in to comment.