Explosive variance observed in latents and noise_pred when using torch.autocast() #119
-
Hello, recently I tried training LoRA using your code https://github.com/PixArt-alpha/PixArt-alpha/blob/93f6bfe8942664052937b78171d3ed5ed56d8dba/train_scripts/train_pixart_lora_hf.py but ran into an issue: the DiT (PixArt model) does not work with torch.autocast(). Briefly, PixArt-alpha returns drastically different output if cast to fp16. Details below. Before everything, my bash script is:
And my accelerator configuration is:
My dataset contains 13 pixel-art-styled images. To begin with, running the code from the repo directly returns an error:
The next issue is:
Line 893 +
Now the training code can run without errors. But the validation images look as if the fine-tuning is not working at all, and every validation image stays invariant throughout the training process. Below are the validation images at steps 2, 300, 700, and 1000, and the image generated by the testing (final inference) script around L970:

Notice that although the validation images are almost invariant over time, the test image shows the actual effectiveness of the trained LoRA. In the code, the only difference is that the test inference uses

The tricky part is, if we apply

In particular, the first image is pure black; in that case PixArt is cast to fp16 by autocast() without LoRA attached. But without using autocast, the variance stays stable:

Here is my inference code with
Does anyone have an idea why?
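For reference, a minimal sketch of how one could compare noise_pred statistics for the same inputs with and without torch.autocast() (the checkpoint name, prompt, and tensor shapes are illustrative assumptions, not the original script):

```python
import torch
from diffusers import PixArtAlphaPipeline

# Illustrative sketch only: checkpoint, prompt, and shapes are assumptions.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float32
).to("cuda")

latents = torch.randn(1, 4, 64, 64, device="cuda")  # 512px -> 64x64 latent
timestep = torch.tensor([500], device="cuda")
prompt_embeds, prompt_mask, _, _ = pipe.encode_prompt("a pixel art cat", device="cuda")

@torch.no_grad()
def noise_pred_variance(use_autocast: bool) -> float:
    # One transformer forward pass; report the variance of noise_pred.
    with torch.autocast("cuda", dtype=torch.float16, enabled=use_autocast):
        noise_pred = pipe.transformer(
            latents,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_mask,
            timestep=timestep,
            added_cond_kwargs={"resolution": None, "aspect_ratio": None},
        ).sample
    return noise_pred.float().var().item()

print("fp32 variance:    ", noise_pred_variance(False))
print("autocast variance:", noise_pred_variance(True))
```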
-
Hi @AlezHibali, really nice observation. Do you mean that if we remove the
-
Hi @AlezHibali, I noticed something similar, mentioned here: on inference it generates as if no peft model is loaded. I made the change below once I noticed this problem, using the base 512 model. This works:
while this does not:
So I removed the dtype params from the "to" calls in the training script. Here is the patch that I created after I observed the point above: https://github.com/raulc0399/PixArt-alpha-finetuning/blob/main/train_pixart_lora_hf_2.patch Basically, I load the models in float16 and move them to cuda without conversion. Using the train_hf.sh from my repo, I managed after 100 epochs to fine-tune on Simpsons. Attached is a sample generated after fine-tuning. It is not quite there yet, but you can see that it is learning.
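To illustrate the pattern (a sketch under assumptions: the model id, LoRA directory, and loading approach are placeholders, not the exact code from the patch):

```python
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel
from peft import PeftModel

MODEL_ID = "PixArt-alpha/PixArt-XL-2-512x512"
LORA_DIR = "output/pixart-lora"  # hypothetical path to the trained adapter

# Works: load the weights already in float16, attach the adapter,
# then move to CUDA *without* passing a dtype to .to().
transformer = Transformer2DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=torch.float16
)
transformer = PeftModel.from_pretrained(transformer, LORA_DIR)
pipe = PixArtAlphaPipeline.from_pretrained(
    MODEL_ID, transformer=transformer, torch_dtype=torch.float16
)
pipe.to("cuda")

# Does not work (images come out as if no LoRA were loaded):
# pipe.to("cuda", torch.float16)   # converting the dtype inside the .to() call
```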
-
Hi guys. The LoRA training script fixing PR is merged: cab13f2
@lawrence-cj Two small changes in the PR: I added params for DoRA and rsLoRA:
https://huggingface.co/docs/peft/package_reference/lora#peft.LoraConfig.use_rslora
https://huggingface.co/docs/peft/package_reference/lora#peft.LoraConfig.use_dora
Both of them show better results on my dataset. @AlezHibali, they might give you better results as well.
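For reference, a minimal sketch of a peft LoraConfig with both flags enabled (the rank, alpha, and target modules are illustrative values, not the ones from the PR):

```python
from peft import LoraConfig

# Illustrative values only; rank, alpha, and target modules are assumptions.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    use_rslora=True,  # rank-stabilized scaling: lora_alpha / sqrt(r) instead of lora_alpha / r
    use_dora=True,    # weight-decomposed low-rank adaptation
)
```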