diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index d80e3388a69a..454faff0a176 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -1178,7 +1178,7 @@ def main(args): grads_text_encoder = text_encoder.get_input_embeddings().weight.grad # Get the index for tokens that we want to zero the grads for index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0] - for i in range(len(modifier_token_id[1:])): + for i in range(1, len(modifier_token_id)): index_grads_to_zero = index_grads_to_zero & ( torch.arange(len(tokenizer)) != modifier_token_id[i] )