Proof of Concept: Weighted captions #336

Merged (5 commits) on Apr 8, 2023

Conversation

AI-Casanova (Contributor)

This PR is for testing the effect of weighted captions, not for imminent merging.

Implementation of weighted captions, as requested in #312 and elsewhere (Discord, etc.).

With --weighted_captions one can now use captions with the standard (token:1.2) syntax, though one must be careful not to put a comma inside the parentheses when using --caption_tag_dropout, --shuffle_caption, etc., since those split the caption on commas and would break the closing parenthesis.

After the individual token embeddings are scaled up, the embeddings as a whole are rescaled to their previous mean, so increasing the weight of your foreground (subject) tokens effectively down-weights the rest of them.
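For reference, a minimal sketch of that mean-rescaling idea (modeled on the approach in lpw_stable_diffusion.py; the function name and tensor shapes are illustrative, not the PR's exact code):

```python
import torch

def apply_caption_weights(embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # embeddings: (batch, seq_len, dim) text-encoder outputs
    # weights:    (batch, seq_len) per-token weights parsed from "(token:1.2)" syntax
    previous_mean = embeddings.float().mean(dim=[-2, -1])
    weighted = embeddings * weights.unsqueeze(-1)  # scale each token's embedding by its weight
    current_mean = weighted.float().mean(dim=[-2, -1])
    # restore the original mean, so up-weighting some tokens effectively down-weights the rest
    rescaled = weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
    return rescaled.to(embeddings.dtype)
```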

I ran two quick runs with --weighted_captions: one with standard unweighted captions to test the tokenization, and another with (<token> woman:1.2).


Losses track exactly with deterministic training, which is unsurprising to me.

[image: xyz_grid-0190-RPGv4]
Images were prompted with only <token> woman to show the effect. It's fairly small, but my dataset has fairly short captions, so perhaps there isn't much comparative scaling.

This code was adapted from lpw_stable_diffusion.py

This is likely not the best implementation, and for tonight it is only implemented for train_network.py (it could be handled natively in train_util.py, for instance), but it disturbs the rest of the repo as little as possible.

AI-Casanova (Contributor, Author)

Oh, I think I can ditch a few of those imports, will confirm tomorrow.


prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
@swfsql commented Apr 2, 2023

@AI-Casanova Please correct me if I'm wrong, but this appears to overwrite the previous value of prompt_weights with a fixed value? Is it also overwriting the value of prompt_tokens?

AI-Casanova (Contributor, Author)

Let me consult the source material.

AI-Casanova (Contributor, Author)

Oof, I think I screwed up an if/then deletion; I'll push a fix and test as soon as I get home.
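For context, the control flow here was presumably meant to be something along these lines; this is a rough sketch assuming a weighted_captions flag, not the actual pushed fix:

```python
if weighted_captions:
    # parse the "(token:1.2)" syntax and keep a per-token weight list
    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
else:
    # plain tokenization (strip BOS/EOS) with uniform weights of 1.0
    prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
    prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
```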

swfsql commented Apr 2, 2023

I've run some tests, and I don't have many conclusions so far.

I trained on a single image with an LR of 5e-4 (quite high, I guess) for 1k steps, generating previews every 25 steps, and trained different LoRA models by changing the captions.

[training image]
  • lora-A image caption: a photo of a (shs:1.0) (shoe:1.0)
  • lora-B image caption: a photo of a (shs:1.8) (shoe:0.3)
  • lora-C image caption: a photo of a (shs:0.3) (shoe:1.8)
  • lora-D: trained on two copies of the same image, one captioned like lora-B and the other captioned like lora-C.

Then I tried generating horses (not shoes), such as A shs horse walking on the street, to see what I would get.

[loss graph]

The conclusions I got:

  • All loras: They got very similar loss (except lora-D, which was a little bit different but ended up close to the others).
  • lora-B: training a token (shs) with high attention lets you also use that token with high attention without "exploding" the image.
  • lora-D: the only one to completely overfit. The last 20% of previews (the last 200 training steps) generated basically the exact input image (the shoe), not a horse.

The models generated different images, but I couldn't see any clear difference from swapping the attentions in the tests I ran.


The only "fun" image I got was a shoe-ish horse on step 750 of lora-C, but overall all of the lora-models ended up not really interesting. I thought I would get weird glowing horses or something, but got none of those.

[image: shoe-ish horse]

@AI-Casanova (Contributor, Author)

Sorry about the dirty git log; I had to force the repo to merge.

Thanks to @swfsql's sharp eye, I caught the issue that was preventing the weighted embeddings from being calculated correctly.

All loras: They got very similar loss

This is actually to be expected. Loss is driven almost entirely (per step, at least) by the timesteps sampled in that particular batch.

In fact, with a set seed, changing caption weights leaves the loss within a rounding error, yet significantly changes the generated images:
[images: xyz_grid-0089-RPGv4, xyz_grid-0088-RPGv4]

Captions were created with https://github.com/pharmapsychotic/clip-interrogator-ext using the "Best" setting, which creates a fairly long caption. Captions were minimally cleaned, then the token and class were wrapped as (token class:1) and (token class:1.2) respectively. This makes background and putative style tags less important in comparison to the token and class.
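As a hypothetical illustration of that wrapping (the caption text below is made up, with "ohwx" standing in for the actual token):

```
clip-interrogator output:  ohwx woman, looking at the camera, bokeh background, film grain
weight 1.0 variant:        (ohwx woman:1), looking at the camera, bokeh background, film grain
weight 1.2 variant:        (ohwx woman:1.2), looking at the camera, bokeh background, film grain
```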

AI-Casanova reopened this Apr 2, 2023
@TingTingin (Contributor)

Just for clarification: if you set the value to 1, does it effectively turn it off? And does a value like (token:0.5) weaken the target token?

@AI-Casanova (Contributor, Author)

@TingTingin Correct on both counts.

I believe you can also technically use negative weights, but I have not tested them, and they would not operate in the sense of a negative prompt during generation.
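To illustrate the semantics (hypothetical captions; the weight values are as described above, not tested results):

```
a photo of (shs:1.0) shoe   ->  shs is scaled by 1.0, the same as writing it without parentheses
a photo of (shs:0.5) shoe   ->  shs is scaled by 0.5, weakening it relative to the other tokens
a photo of (shs:-0.5) shoe  ->  syntactically accepted but untested; not a negative prompt
```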

@TingTingin (Contributor)

Could this be used, for example, to ignore watermarks in images?

@kohya-ss (Owner) commented Apr 8, 2023

@AI-Casanova Thank you for this great PR! I think this PR is ready to merge.

I've merged other big PRs, so I will merge this today! If you still have any concerns, please let me know.

@AI-Casanova (Contributor, Author)

@kohya-ss Awesome!

Removed commented-out lines from earlier bugfix.
kohya-ss merged commit a75f589 into kohya-ss:dev on Apr 8, 2023
@kohya-ss (Owner) commented Apr 8, 2023

I've merged to main. I disabled the option in TI and XTI training.

Thank you for this great work!

@AI-Casanova (Contributor, Author)

Of course, this is when I realize I didn't need to duplicate code and could have just imported lpw_stable_diffusion.py directly. Well, the code runs fine as is, so there's no need to refactor at this point.

@enranime

Sorry, but is there a way to exclude the caption weights from the token length? When I train with weighted captions, the token length always exceeds the limit (225).

@CognitiveDiffusion commented Sep 28, 2023

This is amazing.

When will we see this great feature for SDXL?

And would it be possible to ALSO add a tag that lets the training know which "tags" (1 or 2 long, natural-language tags) are for TextEncoder G?

It could be:
(G: #this is a natural language prompt#)

And all other prompts (without the (G: natural language prompt) syntax) could be directed towards TextEncoder I (which is trained on tags).

Please give me feedback on this. I feel like this could be huge for SDXL training.
@bmaltais @AI-Casanova

@AI-Casanova (Contributor, Author)

@CognitiveDiffusion Honestly, for SDXL LoRA training the best solution I've found (at least for people/characters) is training a TI, then training a LoRA using the TI as a prompt (the two need to be used together at inference as well).

I got this idea from cloneofsimo's PTI implementation, and it seems to work wonderfully for the two text encoders of SDXL, far better than text encoder training.

I'd share my implementation, but it's completely hacked together and hard-coded. Unfortunately I'm on overtime at work and don't have the time to develop it.

@kohya-ss would you be interested in adding the ability to load TI embeddings into train_network.py? The results are quite impressive.

@kohya-ss (Owner) commented Oct 1, 2023

@kohya-ss would you be interested in adding the ability to load TI embeddings into train_network.py? The results are quite impressive.

It is very interesting. PTI training is really hard to implement because it trains different models/parameters simultaneously, but loading TI embeddings should be able to do a similar thing. I'd like to implement it soon.

@kohya-ss (Owner) commented Oct 1, 2023

When will we see this great feature for SDXL?

Unfortunately, I don't know when that will be. It is relatively difficult to implement and I would prefer to prioritize other tasks. I hope to implement it someday.
