[WIP] Add Adversarial Diffusion Distillation (ADD) Script#6303
[WIP] Add Adversarial Diffusion Distillation (ADD) Script#6303dg845 wants to merge 58 commits intohuggingface:mainfrom
Conversation
| def transform(example): | ||
| # resize image | ||
| image = example["image"] | ||
| image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) |
There was a problem hiding this comment.
bilinear causes image artifacting that impacts training quality in major ways. use LANCZOS.
There was a problem hiding this comment.
We could make this configurable as well if users prefer to use different interpolations.
There was a problem hiding this comment.
| if self.prediction_type == "epsilon": | ||
| pred_x_0 = (sample - sigmas * model_output) / alphas | ||
| elif self.prediction_type == "sample": | ||
| pred_x_0 = model_output |
There was a problem hiding this comment.
the u-net returns the residual noise prediction, right? or is it not an intermediary phase with XL Turbo?
There was a problem hiding this comment.
the unet returns residual noise prediction only when prediction_type=epsilon , which is the default for most SD models. This is to support different prediction_types
There was a problem hiding this comment.
well when we train on v_prediction it uses the residual returned from the unet as an input to get_velocity. ergo it is an intermediary stage for v-prediction in Diffusers training. but this code makes it appear as if the sample is returned directly.
There was a problem hiding this comment.
I think this is fine as the denoiser does exactly that, it returns the predicted original sample.
|
|
||
|
|
||
| @torch.no_grad() | ||
| def update_ema(target_params, source_params, rate=0.99): |
There was a problem hiding this comment.
why not use EMAModel class?
There was a problem hiding this comment.
Yeah, +1 to that. Let's try using the EMAModel class.
There was a problem hiding this comment.
Not fully sure, but using EMA like this (fixed ema rate) might make sense for distillation as we might not want to change the model too much. So a fixed high enough value of rate could be better.
There was a problem hiding this comment.
Just to confirm, in EMAModel there's currently no option to set a fixed decay rate? In get_decay it looks like self.decay is not used whether self.use_ema_warmup is True or False:
diffusers/src/diffusers/training_utils.py
Lines 196 to 199 in 1fff527
In fact, it doesn't seem like self.decay is used in any of the EMA logic at all.
There was a problem hiding this comment.
@dg845 please correct me if I am wrong here.
This is where self.decay is used in get_decay():
diffusers/src/diffusers/training_utils.py
Line 201 in 1fff527
Shouldn't that suffice?
There was a problem hiding this comment.
Thanks, I missed that. I think a fixed decay rate can be achieved by setting decay == min_decay.
| action="store_true", | ||
| help=( | ||
| "Whether to center crop the input images to the resolution. If not set, the images will be randomly" | ||
| " cropped. The images will be resized to the resolution first before cropping." |
There was a problem hiding this comment.
there's no need to unconditionally resize before crop, especially on a diverse dataset. not resizing first allows better fine details to be learnt.
| # Enforce zero terminal SNR (see section 3.1 of ADD paper) | ||
| # TODO: is there a better way to implement this? |
There was a problem hiding this comment.
doesn't seem to be much point to this, since even @PeterL1n et al showed that zero-terminal SNR doesn't do anything meaningful for epsilon models, and SDXL Turbo doesn't use v-prediction unfortunately..
There was a problem hiding this comment.
I think,as the losses are computed in pixel-space, it could still have some effect for epsilon prediction.
|
|
||
| if accelerator.unwrap_model(unet).dtype != torch.float32: | ||
| raise ValueError( | ||
| f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" |
|
|
||
| # 9. Handle mixed precision and device placement | ||
| # For mixed precision training we cast all non-trainable weigths to half-precision | ||
| # as these weights are only used for inference, keeping weights in full precision is not required. |
There was a problem hiding this comment.
inference has a big difference with bfloat16 vs float16
| student_timestep_schedule = torch.from_numpy(student_timestep_schedule).to(accelerator.device) | ||
|
|
||
| # 10. Handle saving and loading of checkpoints | ||
| # `accelerate` 0.16.0 will have better support for customized saving |
There was a problem hiding this comment.
we're > 0.20.0 accelerate now
| # Enable TF32 for faster training on Ampere GPUs, | ||
| # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices | ||
| if args.allow_tf32: | ||
| torch.backends.cuda.matmul.allow_tf32 = True |
There was a problem hiding this comment.
other than matmul you can also enable:
torch.backends.cudnn.allow_tf32 = True| image, text = batch | ||
|
|
||
| image = image.to(accelerator.device, non_blocking=True) | ||
| encoded_text = compute_embeddings_fn(text) |
There was a problem hiding this comment.
precomputing the embeds allows for lower VRAM use.
caption dropout should be implemented too
There was a problem hiding this comment.
Think caption dropout is not necessary for this, as here we are distilling the CFG score of the teacher model
There was a problem hiding this comment.
I think it's also fine to NOT precompute the text embeddings for now as we're aiming for a bigger training run here. We can revisit this later.
| # encode pixel values with batch size of at most 32 | ||
| latents = [] | ||
| for i in range(0, pixel_values.shape[0], 32): | ||
| latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) |
There was a problem hiding this comment.
would be better not to encode latents on the fly, as that substantially increases vram use.
an option to recache the vae latents every epoch would be nice, since then the random crop and random flip are more functional.
There was a problem hiding this comment.
Absolutely okay to not consider this for now:
https://github.com/huggingface/diffusers/pull/6303/files#r1436066871
| image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) | ||
|
|
||
| # get crop coordinates and crop image | ||
| c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) |
There was a problem hiding this comment.
what about args.center_crop ? or a crop that preserves aspect bucketing? legacy SD training greatly benefits from data bucketing.
| image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) | ||
|
|
||
| # get crop coordinates and crop image | ||
| c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) |
There was a problem hiding this comment.
so this unconditionally crops the example even though the crop coords are used conditionally by the value of use_fix_crop_and_size.
Additionally, it uses RandomCrop always, without args.center_crop being taken into account.
Further, it crops to a square every time.
this isn't necessary, you can preserve the aspect ratio.
| else: | ||
| generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) | ||
|
|
||
| validation_prompts = [ |
There was a problem hiding this comment.
you can use a prompt dataset here
There was a problem hiding this comment.
For a reference example script, that ain't necessary.
There was a problem hiding this comment.
my suggestion was based on the other training example scripts :-)
patil-suraj
left a comment
There was a problem hiding this comment.
Great start! The script already covers most of the details for ADD. Left some comments.
Will try to give it a run in next few days.
| def transform(example): | ||
| # resize image | ||
| image = example["image"] | ||
| image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) |
There was a problem hiding this comment.
We could make this configurable as well if users prefer to use different interpolations.
| if self.prediction_type == "epsilon": | ||
| pred_x_0 = (sample - sigmas * model_output) / alphas | ||
| elif self.prediction_type == "sample": | ||
| pred_x_0 = model_output |
There was a problem hiding this comment.
the unet returns residual noise prediction only when prediction_type=epsilon , which is the default for most SD models. This is to support different prediction_types
| # Enforce zero terminal SNR (see section 3.1 of ADD paper) | ||
| # TODO: is there a better way to implement this? |
There was a problem hiding this comment.
I think,as the losses are computed in pixel-space, it could still have some effect for epsilon prediction.
| # 1. Decode real and fake (generated) latents back to pixel space. | ||
| # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the | ||
| # pretrained feature network for the discriminator operates in pixel space rather than latent space. |
| student_gen_image = vae.decode(student_x_0).sample | ||
|
|
||
| # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. | ||
| disc_output_real = discriminator(real_image, prompt_embeds) |
There was a problem hiding this comment.
What kind of image input does the dino model expect ? Since it's normalized with imagenet mean and std, we should convert the decoded images between 0-1 range. Like
real_image = (real_image / 2 + 0.5).clamp(0, 1)
There was a problem hiding this comment.
And I think we should resize the images here as dino expects 224x224 images iirc.
It's done in the FeatureNetwork class.
There was a problem hiding this comment.
did we check the range of expected inputs ?
| lr_scheduler.step() | ||
|
|
||
| # Checks if the accelerator has performed an optimization step behind the scenes | ||
| if accelerator.sync_gradients: |
There was a problem hiding this comment.
we should do ema update here.
| optimizer.zero_grad(set_to_none=True) | ||
|
|
||
| # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator | ||
| disc_output_fake = discriminator(student_gen_image, prompt_embeds) |
There was a problem hiding this comment.
This term is already computed above. Do we need to recompute it here ? Not sure because here we are not doing vanilla GAN training so we might as well be able to utilise that.
| validation_prompts = [ | ||
| "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", | ||
| "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", | ||
| "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", | ||
| "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", | ||
| ] |
There was a problem hiding this comment.
Prompts from the ADD paper could also be used.
There was a problem hiding this comment.
Changed prompts to those used in the ADD paper (note that all examples images from the paper are generated by ADD-XL).
|
|
||
| # 1. Create the noise scheduler and the desired noise schedule. | ||
| # Enforce zero terminal SNR (see section 3.1 of ADD paper) | ||
| # TODO: is there a better way to implement this? |
There was a problem hiding this comment.
Your current implementation looks good to me!
There was a problem hiding this comment.
Note that I'm currently using DDIMScheduler, which currently supports rescale_betas_zero_snr, but ideally I'd like to use DDPMScheduler here, since my understanding is that DDPMScheduler can typically load DDIMScheduler configs but not vice versa. Since DDPMScheduler currently does not support rescale_betas_zero_snr, I've opened a PR to add it: #6305.
There was a problem hiding this comment.
you can simply use Euler now, instead of DDIM, because it also supports zero-terminal SNR. this would match the behaviour of the ControlNet trainer, which uses Euler.
| pixel_values = image.to(dtype=weight_dtype) | ||
| if vae.dtype != weight_dtype: | ||
| vae.to(dtype=weight_dtype) |
There was a problem hiding this comment.
As hinted by @patil-suraj above, we can safely always have the VAE in a reduced precision in case of SD.
There was a problem hiding this comment.
not for SD 2.1, which ends up with NaN inside the VAE with half precision on finetuned models.. never figured that one out
There was a problem hiding this comment.
Oh, for SD2.1, that is the case?
There was a problem hiding this comment.
yes, it's noticeably an issue when finetuning a 2.1-v model in Diffusers and then trying to do inference later on an AMD system, which lacks xformers etc, the float32 VAE has OOM but float16 has NaN.
sayakpaul
left a comment
There was a problem hiding this comment.
Amazing work. Excited to see how the results come up!
I'd be keen on experimenting with simpler discriminators though. But that is obviously not a blocker.
|
Excuse my ignorance but how do I run it? I seem to get this error trying to run either the lora or SD1.5 version
Traceback (most recent call last): |
|
I think there are a few additional arguments that need to be explicitly supplied for the scripts to not raise an error. Something close to the minimal set of arguments needed is accelerate launch examples/add/train_add_distill_lora_sd_wds.py \
--pretrained_teacher_model="<teacher_model>" \
--train_shards_path_or_url="<dataset>" \
--output_dir="<output_dir>" \
--max_train_steps=1 \
--max_train_samples=20 \
--dataloader_num_workers=8 \assuming the other default values work (for example, Note that the scripts are a work in progress and there's no guarantee that they work currently. |
|
Got it running, ran into bug saving though. Validation images also looked like random noise also. |
| d_r1_regularizer = 0 | ||
| for k, head in discriminator.heads.items(): | ||
| head_grad_params = torch.autograd.grad( | ||
| outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True |
There was a problem hiding this comment.
according to the paper, the r1 penalty seems to be computed w.r.t head's input instead of head's parameters?
There was a problem hiding this comment.
Thanks for the catch, I think you're right. In fact there seem to be several problems in the current implementation:
- As you pointed out, the gradient penalty should be calculated with respect to the discriminator head inputs. That's not available in the current code, but if the input to discriminator head
kis available asfeatures[k], I think the fix would be to set theinputsargument totorch.autograd.gradtofeatures[k](if I understand autograd correctly). - It looks like I misunderstood the definition of the R1 gradient penalty. The ADD paper cites this paper when discussing the R1 gradient penalty, and the latter paper defines the R1 gradient penalty as
So it seems like we should be using the L2 norm rather than the L1 norm when calculating the gradient penalty. It's also possible that the implementation is off by a factor of
@patil-suraj @sayakpaul does this sound correct to you guys?
There was a problem hiding this comment.
yeah, regarding formula it is l2 norm, something like this
d_r1_regularizer = sum((torch.linalg.vector_norm(grad.view(grad.size(0), -1), dim=1) ** 2).mean() for grad in feature_grads)
There was a problem hiding this comment.
I have added a tentative fix for the discriminator R1 gradient penalty for the SD ADD script in commit ab46142. In particular this part
diffusers/examples/add/train_add_distill_sd_wds.py
Lines 1860 to 1862 in ab46142
feels weird to me but seems necessary because the features in features_real don't usually have gradients because the feature_network is frozen. @erliding @patil-suraj @sayakpaul would be great if you could look this over
There was a problem hiding this comment.
yep, you need to explicitly call feature.requires_grad_() before passing them to heads
There was a problem hiding this comment.
There was a problem hiding this comment.
@sayakpaul as far as my understanding, the gradient penalty could help enforce Lipschitz continuity (thus gradient of output w.r.t input) on discriminator which is a requirement of Wasserstein Gan
There was a problem hiding this comment.
Right, but we aren't using the earth mover distance here in the loss no?
There was a problem hiding this comment.
yeah, it's hinge loss here :) gradient penalty should bring similar benefits though
There was a problem hiding this comment.
@dg845 It does look good to me. We need to enable grads for feature inputs to be able to compare the grad penalty, if the input does not have grad enabled I think create_graph=True will complain.
for reference, DDGAN training has this
…'t fixed autograd call yet).
|
|
||
| if accelerator.sync_gradients: | ||
| accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm) | ||
| discriminator_optimizer.step() |
There was a problem hiding this comment.
if i understand correct, when the gradient_accumulation_steps > 1current implementation seems result in a simultaneous gradient descent up to the last accumulation step for each batch, when gradient_accumulation_steps == 1 it is alternating gradient descent, while in the stylegan-t it is always alternating gradient descent
There was a problem hiding this comment.
How important is that to follow? Is it absolutely a must for training stability?
There was a problem hiding this comment.
convergence behavior of simultaneous gradient descent and alternating gradient descent are different when achieving min-max equilibrium, from those gan implementations lately, seems alternating gradient descent is usually adopted, not sure how important they could help for the case of ADD though
…ator R1 gradient penalty.
| if args.use_image_conditioning: | ||
| image_embedding = encoded_image.pop("image_embeds").float() | ||
| # Only supply image conditioning when student timestep is not last training timestep T. | ||
| image_embedding = torch.where( |
There was a problem hiding this comment.
currently the same masked image_embedding is fed to discriminator for both real and fake images, but it seems for which real image its image_embedding being masked out could be random with a rate say 1 / num_inference_steps instead of depending on student_timesteps, not sure if this could make big difference though
There was a problem hiding this comment.
Yeah same. I am sure either how impactful this would be.
| student_index = torch.randint(0, student_distillation_steps, (bsz,), device=latents.device).long() | ||
| student_timesteps = student_timestep_schedule[student_index] | ||
| teacher_timesteps = torch.randint( | ||
| 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device |
There was a problem hiding this comment.
it might better not to sample teacher_timesteps from the full range of [0, noise_scheduler.config.num_train_timesteps) but instead ignoring timesteps that are too small or too big, e.g. the default configurable range from dream fusion is [0.02, 0.98] * num_train_timesteps
There was a problem hiding this comment.
I guess Figure 2 of the ADD paper implies that they sample from the full range of teacher timesteps:
But this is definitely something we can try out :).
|
How far away is this pr from being merged? |
|
Hi @SteamedGit the ADD implementation is nominally complete but I have not been able to test whether the script can distill good models (e.g. for SD v1.5) yet. |
…itive value instead of zero following EulerDiscreteScheduler.
…y whether we use a CLIPTextModel or CLIPTextModelWithProjection (e.g. with --use_pretrained_projection).
|
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
|
Not stale. |
|
@sayakpaul @dg845 Great job! Can someone please confirm if the effectiveness of this PR has been verified?@ |
|
regarding computing sds loss i suggest taking a look at https://arxiv.org/abs/2306.04619 which tends to produce a better target |
|
@cjt222 sorry, I haven't been able to finish testing it yet. Will hopefully find more time to work on it soon 😅. |
|
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |



What does this PR do?
This PR adds an example script for adversarial diffusion distillation (ADD) (paper, code), a distillation + adversarial training method used to distill SD/SD-XL Turbo.
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten
@sayakpaul
@patil-suraj