From 8dc023010f2d0e5aac7e56b06c4378df07d0c3aa Mon Sep 17 00:00:00 2001 From: rbrq03 Date: Wed, 13 Mar 2024 06:33:19 +0000 Subject: [PATCH] fix index in set textencoder grad --- examples/custom_diffusion/train_custom_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 790454be7a2e..aa00f7ecf61a 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -1178,7 +1178,7 @@ def main(args): grads_text_encoder = text_encoder.get_input_embeddings().weight.grad # Get the index for tokens that we want to zero the grads for index_grads_to_zero = torch.arange(len(tokenizer)) != modifier_token_id[0] - for i in range(len(modifier_token_id[1:])): + for i in range(1, len(modifier_token_id)): index_grads_to_zero = index_grads_to_zero & ( torch.arange(len(tokenizer)) != modifier_token_id[i] )