Skip to content

Commit

Permalink
add autocast fix to other training examples
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Mar 31, 2024
1 parent 20a7c8f commit ad3eb80
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available

from contextlib import nullcontext

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0")
Expand Down Expand Up @@ -1844,7 +1844,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}

with torch.cuda.amp.autocast():
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import warnings
from pathlib import Path
from typing import List, Optional
from contextlib import nullcontext

import numpy as np
import torch
Expand Down Expand Up @@ -2192,13 +2193,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
inference_ctx = (
contextlib.nullcontext()
if "playground" in args.pretrained_model_name_or_path
else torch.cuda.amp.autocast()
)
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with inference_ctx:
with autocast_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
Expand Down
16 changes: 13 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import shutil
from pathlib import Path
from typing import List, Union
from contextlib import nullcontext

import accelerate
import numpy as np
Expand Down Expand Up @@ -238,6 +239,10 @@ def train_dataloader(self):

def log_validation(vae, unet, args, accelerator, weight_dtype, step):
logger.info("Running validation... ")
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

unet = accelerator.unwrap_model(unet)
pipeline = StableDiffusionPipeline.from_pretrained(
Expand Down Expand Up @@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):

for _, prompt in enumerate(validation_prompts):
images = []
with torch.autocast("cuda", dtype=weight_dtype):
with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
Expand Down Expand Up @@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]

if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

# 16. Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

Expand Down Expand Up @@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
with torch.autocast("cuda"):
with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
Expand Down Expand Up @@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
with autocast_ctx:
target_noise_pred = unet(
x_prev.float(),
timesteps,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import random
import shutil
from pathlib import Path
from contextlib import nullcontext

import accelerate
import numpy as np
Expand Down Expand Up @@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin

for _, prompt in enumerate(validation_prompts):
images = []
with torch.autocast("cuda", dtype=weight_dtype):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import shutil
from pathlib import Path
from typing import List, Union
from contextlib import nullcontext

import accelerate
import numpy as np
Expand Down Expand Up @@ -256,6 +257,10 @@ def train_dataloader(self):

def log_validation(vae, unet, args, accelerator, weight_dtype, step):
logger.info("Running validation... ")
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

unet = accelerator.unwrap_model(unet)
pipeline = StableDiffusionXLPipeline.from_pretrained(
Expand Down Expand Up @@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):

for _, prompt in enumerate(validation_prompts):
images = []
with torch.autocast("cuda", dtype=weight_dtype):
with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
Expand Down Expand Up @@ -1353,7 +1358,12 @@ def compute_embeddings(
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
with torch.autocast("cuda"):
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
Expand Down Expand Up @@ -1416,7 +1426,12 @@ def compute_embeddings(
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
with torch.no_grad():
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

with autocast_ctx:
target_noise_pred = unet(
x_prev.float(),
timesteps,
Expand Down
22 changes: 19 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import shutil
from pathlib import Path
from typing import List, Union
from contextlib import nullcontext

import accelerate
import numpy as np
Expand Down Expand Up @@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe

for _, prompt in enumerate(validation_prompts):
images = []
with torch.autocast("cuda"):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
Expand Down Expand Up @@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
with torch.autocast("cuda"):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
Expand Down Expand Up @@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

with autocast_ctx:
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
Expand Down
22 changes: 19 additions & 3 deletions examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import shutil
from pathlib import Path
from typing import List, Union
from contextlib import nullcontext

import accelerate
import numpy as np
Expand Down Expand Up @@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe

for _, prompt in enumerate(validation_prompts):
images = []
with torch.autocast("cuda"):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
Expand Down Expand Up @@ -1355,7 +1361,12 @@ def compute_embeddings(
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
with torch.autocast("cuda"):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
Expand Down Expand Up @@ -1417,7 +1428,12 @@ def compute_embeddings(

# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)

with autocast_ctx:
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
Expand Down
13 changes: 6 additions & 7 deletions examples/controlnet/train_controlnet_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and

import argparse
import contextlib
from contextlib import nullcontext
import functools
import gc
import logging
Expand Down Expand Up @@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
)

image_logs = []
inference_ctx = (
contextlib.nullcontext()
if (is_final_validation or torch.backends.mps.is_available())
else torch.autocast("cuda")
)
if is_final_validation or torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
Expand All @@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
images = []

for _ in range(args.num_validation_images):
with inference_ctx:
with autocast_ctx:
image = pipeline(
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
).images[0]
Expand Down
11 changes: 5 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from contextlib import nullcontext

import diffusers
from diffusers import (
Expand Down Expand Up @@ -207,14 +208,12 @@ def log_validation(
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
enable_autocast = True
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
enable_autocast = False
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)

with torch.autocast(
accelerator.device.type,
enabled=enable_autocast,
):
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
Expand Down
Loading

0 comments on commit ad3eb80

Please sign in to comment.