Conversation

yiyixuxu (Collaborator) commented Sep 4, 2023

🚨🚨🚨 Note: the main author of this PR is @cene555 and the Kandinsky team. In this PR we mainly make some small modifications to the training script provided by the authors in the original PR so that it is consistent with all our other training scripts. Thanks a million for the contribution, @cene555! 🚨🚨🚨

Authors of this PR:
Arseniy Shakhmatov
Anton Razzhigaev
Aleksandr Nikolich
Igor Pavlov
Andrey Kuznetsov
Denis Dimitrov

Note: DreamBooth support will be included in a separate PR.

HuggingFaceDocBuilderDev commented Sep 4, 2023

The documentation is not available anymore as the PR was closed or merged.



- class PriorTransformer(ModelMixin, ConfigMixin):
+ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Collaborator Author:
I'm a little surprised to see `UNet2DConditionLoadersMixin` work out of the box for `PriorTransformer` :)

Member:

Just diffusers things 😎

Comment on lines 309 to 317
### Training with xFormers:

You can enable memory efficient attention by [installing xFormers](https://huggingface.co/docs/diffusers/main/en/optimization/xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script.

xFormers training is not available for prior model fine-tuning.

**Note**:

According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training on some GPUs. If you observe this problem, please install a development version as indicated in that comment.
Member:

We have LoRA too. Let's include a section on LoRA as well in the README.

block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnAddedKVProcessor(
Member:

Hmm, it seems we don't have a faster implementation of this processor. Based on the usage of the training scripts, I think we can monitor and add one later if needed. WDYT?
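For context, the snippet above maps each attention-processor name to a hidden size before assigning `LoRAAttnAddedKVProcessor`. A minimal, self-contained sketch of that name-to-channel lookup, following the `UNet2DConditionModel` naming convention (the helper name and the channel values here are illustrative, not from the PR):

```python
# Illustrative sketch: resolve the hidden size for each attention processor
# from its module name. Down blocks index block_out_channels directly,
# up blocks index it in reverse, and the mid block uses the last entry.

def hidden_size_for(name, block_out_channels):
    """Map a processor name like 'down_blocks.1.attentions...' to its width."""
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected processor name: {name}")

channels = [320, 640, 1280]  # example UNet config, not the Kandinsky values
print(hidden_size_for("down_blocks.1.attentions.0.processor", channels))  # 640
```

The `int(name[len("down_blocks.")])` trick works because the block index immediately follows the prefix in the module path.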

model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]

if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
Member:

IIRC Kandinsky models weren't trained using the epsilon prediction objective. If so, does this still lead to reasonable results?

Collaborator Author:

Actually, I'm not sure what objective was used for decoder training. I looked back at the DALL·E 2 paper, and I think it only mentions that the prior predicts samples directly.
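For readers following along, the `--snr_gamma` branch in the excerpt above refers to min-SNR loss weighting. A rough, self-contained sketch of the weighting for an epsilon-prediction objective (the function names are illustrative; the real script uses `compute_snr` from the diffusers training utilities):

```python
# Illustrative sketch of min-SNR loss weighting for epsilon prediction.
# snr(t) = alpha_bar_t / (1 - alpha_bar_t); the per-timestep MSE loss is
# rescaled by min(snr, gamma) / snr, which down-weights low-noise timesteps.

def compute_snr(alphas_cumprod, timesteps):
    return [alphas_cumprod[t] / (1.0 - alphas_cumprod[t]) for t in timesteps]

def min_snr_weights(snr_values, gamma=5.0):
    return [min(s, gamma) / s for s in snr_values]

alphas_cumprod = [0.99, 0.5, 0.1]  # toy noise schedule, three timesteps
snr = compute_snr(alphas_cumprod, [0, 1, 2])
weights = min_snr_weights(snr, gamma=5.0)
```

When `args.snr_gamma is None`, the plain unweighted MSE in the excerpt is used instead.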

sayakpaul (Member) left a comment:

Amazing work!

Questions / comments:

stevhliu (Member) left a comment:

Very nice, thanks for writing this training guide up! 😄

from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
Member:

Suggested change:
- pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
+ pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)

Collaborator Author:

Actually, here we want to create a combined pipeline, so we have to use the decoder checkpoint for that. Once we have the combined pipeline, we can access the prior with `pipe.prior_prior`.
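To illustrate, a sketch of the access pattern described above (this assumes the `kandinsky-community/kandinsky-2-2-decoder` checkpoint resolves to the combined Kandinsky 2.2 pipeline; running it downloads the model weights):

```python
import torch
from diffusers import AutoPipelineForText2Image

# Loading the decoder repo through AutoPipelineForText2Image yields the
# combined Kandinsky 2.2 pipeline, which wraps both the prior and the decoder.
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)

# The prior's PriorTransformer is then reachable as a sub-component:
prior = pipe.prior_prior
```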

yiyixuxu and others added 16 commits September 10, 2023 18:48
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
yiyixuxu (Collaborator Author):

Is it okay to merge now? Maybe @sayakpaul can take a look again.
I don't think we need to add this to the docs; I added a link to this folder on the text-to-image training doc page instead.

yiyixuxu (Collaborator Author):

@sayakpaul, about your questions:

> Do we not LoRA fine-tune the text encoders as typically done in the SD world?

Kandinsky does not have a pure text-conditioned diffusion process; it decodes image embeddings instead. It does use a text encoder, but I'm not sure how important it is to fine-tune it. Maybe the community can try it out.

sayakpaul (Member) left a comment:

Dope! Thank you so much for seeing this through, Yiyi!

@yiyixuxu yiyixuxu merged commit e70cb12 into main Sep 14, 2023
@yiyixuxu yiyixuxu deleted the kandinsky-finetune branch September 14, 2023 16:58
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Add files via upload

Co-authored-by: Shahmatov Arseniy <62886550+cene555@users.noreply.github.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* Add files via upload

Co-authored-by: Shahmatov Arseniy <62886550+cene555@users.noreply.github.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>