Description
With Pivotal Tuning for SDXL it's common to train 2 tokens per text encoder. The current API for loading them is as follows:

```python
pipe.load_textual_inversion(state_dict["text_encoders_0"][0], token=["<s0>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipe.load_textual_inversion(state_dict["text_encoders_0"][1], token=["<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipe.load_textual_inversion(state_dict["text_encoders_1"][0], token=["<s0>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
pipe.load_textual_inversion(state_dict["text_encoders_1"][1], token=["<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
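For reference, `state_dict` in the snippet above is the trained embeddings file loaded beforehand. A minimal sketch, assuming a safetensors checkpoint whose keys match the snippet (the filename here is hypothetical):

```python
# Assumption: the pivotal-tuning embeddings are stored as a safetensors file
# with "text_encoders_0" / "text_encoders_1" keys, as used in the calls above.
from safetensors.torch import load_file

state_dict = load_file("pivotal_tuning_embeddings.safetensors")
# state_dict["text_encoders_0"] -> embedding rows for pipe.text_encoder   (<s0>, <s1>)
# state_dict["text_encoders_1"] -> embedding rows for pipe.text_encoder_2 (<s0>, <s1>)
```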
I think there could be better ways. To me, the ideal scenario is:
We standardize the names of the state dicts to match the names of the text encoders, so the mapping happens automatically. If the state dict contains multiple embeddings, we allow/require a matching number of tokens.
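Purely as an illustration of what that automatic mapping could look like (the key names, attribute lookup, and token handling here are assumptions, not an existing implementation):

```python
# Illustrative sketch: each key in the embeddings file is resolved to the
# pipeline attribute of the same name and its sibling tokenizer, then the
# embeddings are loaded token by token.
tokens = ["<s0>", "<s1>"]
for name, embeds in state_dict.items():  # e.g. {"text_encoder": ..., "text_encoder_2": ...}
    text_encoder = getattr(pipe, name)
    tokenizer = getattr(pipe, name.replace("text_encoder", "tokenizer"))
    for token, embed in zip(tokens, embeds):
        pipe.load_textual_inversion(embed, token=token, text_encoder=text_encoder, tokenizer=tokenizer)
```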
Suggested API
```python
pipe.load_textual_inversion("name/repo", weight_name="embeddings.safetensors", token=["<s0>", "<s1>"])
```

where embeddings.safetensors contains a state dict called text_encoder and another one called text_encoder_2, and each gets mapped to the respective text_encoder/tokenizer pair (if the user doesn't provide text_encoder and tokenizer explicitly).
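To make the layout concrete, here is a rough sketch of how such a file could be produced; the key names mirror the pipeline's text-encoder attributes and the shapes are SDXL's hidden sizes, one row per trained token (all of this is illustrative, not an existing format):

```python
# Illustrative only: a standardized embeddings.safetensors keyed by the
# pipeline attribute names, one row per trained token (<s0>, <s1>).
import torch
from safetensors.torch import save_file

embeddings = {
    "text_encoder": torch.randn(2, 768),     # CLIP ViT-L hidden size
    "text_encoder_2": torch.randn(2, 1280),  # OpenCLIP ViT-bigG hidden size
}
save_file(embeddings, "embeddings.safetensors")
```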
I don't think it needs to be this particular API idea, but I think any move in the direction of simplifying the API would be great!