
v-prediction training support #1455

Merged
merged 10 commits into from Nov 28, 2022

Conversation

patil-suraj (Contributor)

This PR adds support for v-prediction training in

  • textual-inversion
  • dreambooth
  • text-to-image fine-tuning

This allows fine-tuning the SD2 768x768 model with these scripts.

To enable this, it adds a get_velocity method to the DDPM and DDIM schedulers to compute the target during training. The type of training is detected automatically inside each script using the noise_scheduler.config.prediction_type argument.

Users just have to set the right resolution in the command: 512 for all models except the stable-diffusion-2 768 model, which needs 768.
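For reference, a minimal sketch of the target selection inside the training loop (variable names such as latents and noisy_latents are illustrative; for v-prediction the target is the velocity v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents, computed by get_velocity):

if noise_scheduler.config.prediction_type == "epsilon":
    target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")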

@HuggingFaceDocBuilderDev commented Nov 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca pcuenca (Member) left a comment

Amazing!

It looks like the DDIM scheduler is never used for training, is that correct? Do we need a get_velocity function for it in that case?


     # Add the prior loss to the instance loss.
     loss = loss + args.prior_loss_weight * prior_loss
 else:
-    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+    loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
Member

Shouldn't we rename noise_pred too? (Same comment for the other scripts)

Contributor Author

Good catch, we could call it model_output instead of noise_pred, wdyt?

Member

model_output or pred both sound fine to me, whatever you think is clearer. I think we use model_output in more places though, so that'd be better then.

Comment on lines +38 to +44
parser.add_argument(
    "--revision",
    type=str,
    default=None,
    required=False,
    help="Revision of pretrained model identifier from huggingface.co/models.",
)
Member

👍

Comment on lines +358 to +376
def get_velocity(
    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
    # Make sure alphas_cumprod and timestep have same device and dtype as sample
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(sample.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity
Member

This looks very similar to add_noise, doesn't it? Would it make sense to make both implementations rely on a common function?

Contributor Author

Yes, we could refactor it in a follow-up PR, wdyt @patrickvonplaten?

Member

Since we use both add_noise and get_velocity in the same training step for v-prediction, I'm ok with keeping them separate. But longer-term we might benefit from factoring out or condensing the alpha_prod code in both functions (everything above velocity =).

Contributor

Agree! I think it's cleaner to add it directly to add_noise and then make use of self.config.prediction_type - could we do this in this PR maybe?

Contributor

Scratch that, it doesn't work as expected.

Contributor Author

In favour of keeping get_velocity, because for v-prediction we need both the velocity and the noised image, so if we modify add_noise we'll need to return a tuple as output, which will complicate it a bit.

get_velocity is also easier to understand, as the name makes it clear that it's different from add_noise.
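For illustration, a rough sketch of the kind of follow-up refactor discussed above, factoring out the shared alpha_prod code into one helper (the helper name is hypothetical and not part of this PR):

def _broadcast_sqrt_alphas(self, sample, timesteps):
    # Shared by add_noise and get_velocity: sqrt(alpha_prod_t) and sqrt(1 - alpha_prod_t),
    # broadcast to the sample's shape.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(sample.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    return sqrt_alpha_prod, sqrt_one_minus_alpha_prod

def get_velocity(self, sample, noise, timesteps):
    # Velocity target for v-prediction: v = sqrt(alpha_prod_t) * noise - sqrt(1 - alpha_prod_t) * sample
    sqrt_alpha_prod, sqrt_one_minus_alpha_prod = self._broadcast_sqrt_alphas(sample, timesteps)
    return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample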

Comment on lines +348 to +367
def get_velocity(
    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
    # Make sure alphas_cumprod and timestep have same device and dtype as sample
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(sample.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity

Member

Same comment as in DDIM.

@patil-suraj (Contributor Author)

It looks like the DDIM scheduler is never used for training, is that correct? Do we need a get_velocity function for it in that case?

It's not really used for training, but I think it can be. Adding the method just for consistency.

@pcuenca (Member) commented Nov 28, 2022

It's not really used for training, but I think it can be. Adding the method just for consistency.

Cool! We can add support for dpm solver too in a future PR :)

@anton-l anton-l (Member) left a comment

LGTM!


@nlml commented Dec 5, 2022

Hmmm, did this break the script for older stable diffusion models? When I try to run with MODEL_NAME="runwayml/stable-diffusion-v1-5" now, I get the following error:

❱ 557                 if noise_scheduler.config.prediction_type == "epsilon":
  558                     target = noise
  559                 elif noise_scheduler.config.prediction_type == "v_prediction":
  560                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
AttributeError: 'FrozenDict' object has no attribute 'prediction_type'

because the noise_scheduler.config FrozenDict does not contain this key.

@patil-suraj (Contributor Author)

Hey @nlml, prediction_type was only added recently, so you'll need to install diffusers from main to run the example scripts.
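For example, a standard way to install from the main branch (adjust to your environment):

pip install git+https://github.com/huggingface/diffusers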

sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* add get_velocity

* add v prediction for training

* fix saving

* add revision arg

* fix saving

* save checkpoints dreambooth

* fix saving embeds

* add instruction in readme

* quality

* noise_pred -> model_pred
@pkurz3nd commented Jan 3, 2023

Hello, I tried training using stable-diffusion-2-1 with resolution 768, but I get a NaN loss in the very first iteration, before the optimizer step is performed.
I use this command for training:

!accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
  --revision="fp16" \
  --instance_data_dir="./data/jonny_depp" \
  --class_data_dir="./data/popular_person" \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="photo of xyz jonny depp" \
  --class_prompt="professional photographic of popular person, high quality image" \
  --resolution=768 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-6 \
  --lr_scheduler="polynomial" \
  --lr_warmup_steps=0 \
  --num_class_images=10 \
  --max_train_steps=800 \
  --train_text_encoder \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --gradient_checkpointing \
  --sample_batch_size=1

I am using
diffusers==0.11.1
torch==1.13.1+cu117
accelerate==0.12.0

I checked all the model parts; the problem is with the UNet, it outputs an all-NaN tensor. The inputs to the UNet don't contain any NaNs or infs.
Any ideas?

@patil-suraj (Contributor Author)

Hey @pkurz3nd, this is a known issue with the stable-diffusion-2-1 model. It overflows when using fp16 during training. To train this model, either

  • use fp32
  • or use xformers when using fp16 (see the sketch below)
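A rough sketch of both options, assuming the standard diffusers API for memory-efficient attention (whether your training script exposes a flag for it may vary):

# Option 1: train in fp32 by dropping --mixed_precision="fp16" (and --revision="fp16")
# from the launch command, so the weights stay in float32.

# Option 2: keep fp16 but enable xformers memory-efficient attention on the UNet
# inside the training script (requires the xformers package to be installed):
unet.enable_xformers_memory_efficient_attention()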

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add get_velocity

* add v prediction for training

* fix saving

* add revision arg

* fix saving

* save checkpoints dreambooth

* fix saving embeds

* add instruction in readme

* quality

* noise_pred -> model_pred