Simplify the load_textual_inversion API #5733

@apolinario

Description

With Pivotal Tuning for SDXL it's common to train 2 tokens per text encoder. The current API for loading them is as follows:

```python
pipe.load_textual_inversion(state_dict["text_encoders_0"][0], token=["<s0>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipe.load_textual_inversion(state_dict["text_encoders_0"][1], token=["<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipe.load_textual_inversion(state_dict["text_encoders_1"][0], token=["<s0>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
pipe.load_textual_inversion(state_dict["text_encoders_1"][1], token=["<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
```
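
For context, the `state_dict` above would come from reading the trained embeddings file, e.g. with safetensors (the file name and key layout follow the example above; the shapes are illustrative assumptions):

```python
from safetensors.torch import load_file

# Each key maps to one (num_tokens, hidden_dim) tensor per text encoder;
# indexing with [0]/[1] selects the embedding for <s0>/<s1>.
state_dict = load_file("embeddings.safetensors")
# state_dict["text_encoders_0"]  # e.g. shape (2, 768) for SDXL's first encoder
# state_dict["text_encoders_1"]  # e.g. shape (2, 1280) for the second encoder
```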

I think there could be better ways to do this.

To me, the ideal scenario is:
We standardize the names of the state dicts to match the names of the text encoders, so the mapping happens automatically. If the state dict contains multiple embeddings, we allow/require a matching number of tokens.

Suggested API

pipe.load_textual_inversion("name/repo", weight_name="embeddings.safetensors", token=["<s0>", "<s1>"])

where embeddings.safetensors contains one state dict called text_encoder and another called text_encoder_2, and each is mapped automatically to the corresponding text_encoder/tokenizer pair (if the user doesn't provide text_encoder and tokenizer explicitly). A sketch of that mapping follows.
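
For illustration, a minimal sketch of what that auto-mapping could look like, written as a user-side helper rather than the actual diffusers implementation. The helper name, the encoder-to-tokenizer table, and the per-row tensor layout are all assumptions; passing a raw embedding tensor to load_textual_inversion follows the pattern in the current-API example above.

```python
from safetensors.torch import load_file

# Assumed key layout: one (num_tokens, hidden_dim) tensor per encoder,
# stored under keys that match the pipeline's attribute names.
ENCODER_TO_TOKENIZER = {"text_encoder": "tokenizer", "text_encoder_2": "tokenizer_2"}

def load_multi_encoder_textual_inversion(pipe, path, tokens):
    state_dict = load_file(path)  # e.g. {"text_encoder": ..., "text_encoder_2": ...}
    for encoder_name, embeddings in state_dict.items():
        text_encoder = getattr(pipe, encoder_name)
        tokenizer = getattr(pipe, ENCODER_TO_TOKENIZER[encoder_name])
        # One embedding row per token: <s0> gets embeddings[0], <s1> gets embeddings[1]
        for token, embedding in zip(tokens, embeddings):
            pipe.load_textual_inversion(
                embedding, token=token, text_encoder=text_encoder, tokenizer=tokenizer
            )

# Usage (hypothetical helper):
# load_multi_encoder_textual_inversion(pipe, "embeddings.safetensors", ["<s0>", "<s1>"])
```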

It doesn't need to be this particular API idea, but I think any move in the direction of simplifying the API would be great!
