
[Critical] Very high loss rate at first few tokens (classifier free guidance not working) #80

Closed
MarcusLoppe opened this issue May 4, 2024 · 66 comments


@MarcusLoppe
Contributor

MarcusLoppe commented May 4, 2024

@lucidrains
This is an issue I've been having for a while: the cross-attention is very weak at the start of the sequence.
When the transformer starts with no tokens it has to rely on the cross-attention, but unfortunately the cross-attention doesn't work for the first token(s).

Proof

To prove this I trained on a dataset of 500 models that have unique text embeddings and no augmentations; I then only took the first 6 tokens of each mesh and trained on those.
After training for 8 hrs, it's still stuck at 1.03 loss.

Without fixing this issue, auto-regression without a prompt of tokens will never work.

This problem has been ongoing for a while, but I thought it was a training issue and that using a model trained on the first few tokens would resolve it. However, that isn't the case.
Real-life example
To highlight the issue, I trained a model on the 13k dataset, then removed all the augmentation copies and removed models with duplicate labels.
If I provide it with the first 2 tokens as a prompt it will autocomplete without any problem and no visual issues; however, if I provide it with 1 or 0 tokens it fails completely.

Checked the logits

I investigated this further and checked the logits when it generated the first token: the correct token was only the 9th most probable.
I tried to implement a beam search with a beam width of 5, but since the first token has such a low probability it would require a lot of beams. That would probably work, but it's a brute-force solution and not a good one.
It might work to do a beam search of 20 and then kill off the candidates that have a low probability/entropy, but this seems like a bandage solution that might not scale with meshgpt.
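For reference, here is a minimal sketch (not code from the repo) of how the rank of the ground-truth first token can be read out of the logits; `logits` is assumed to be a (batch, seq, vocab) tensor from a forward pass and `correct_token` the ground-truth id:

```python
import torch

def rank_of_correct_token(logits: torch.Tensor, correct_token: int, position: int = 0) -> int:
    # probabilities over the vocab for the token at `position` of the first batch element
    probs = logits[0, position].softmax(dim = -1)
    # number of tokens that are more probable than the ground truth (0 = most probable)
    return int((probs > probs[correct_token]).sum().item())
```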

Why is this a problem?

The first tokens are very important for the generation since it's a domino effect: if it gets an incorrect token at the start, the generation will fail because it relies too much on the sequence to auto-correct.
It's like if the sentence is "Dog can be a happy animal" and it predicts "Human" as the first token; it won't be able to auto-correct since the sentence is already messed up, and the chance that it recovers to something like "Human got a dog which can be a happy animal" is extremely low.

Possible solution

Since the cross-attention is used only on the "big" decoder, can it also be implemented for the fine decoder?

Attempts to fix:

  • I've tried removing the fine decoder and fine gateloop.
  • I also tried increasing cross_attn_num_mem_kv but found no significant changes.
  • I replaced the TextEmbeddingReturner with AttentionTextConditioner but still no changes.
  • Tried using different text encoders such as BGE and CLIP.

This has been a problem for a long time and I've mentioned it in other issue threads as a side note, so I'm creating an issue for it since it really prevents me from releasing fine-tuned models.

I have a model ready to go that can predict 13k models, but since the first tokens make autoregressive generation impossible, I've not released it yet.

Here are some images of the loss:
[image]

@MarcusLoppe MarcusLoppe changed the title Very high loss rate at first few tokens [Critical] Very high loss rate at first few tokens May 4, 2024
@MarcusLoppe MarcusLoppe changed the title [Critical] Very high loss rate at first few tokens [Critical] Very high loss rate at first few tokens (classifier free guidance not working) May 5, 2024
@pathquester

This sounds critical indeed. Hopefully it's an easy fix.

@MarcusLoppe
Contributor Author

MarcusLoppe commented May 6, 2024

@lucidrains

I think I've resolved this issue by tokenizing the text, inserting it at the start of the codes, and adding a special token to indicate the start of the mesh tokens.
However, the downside is that the transformer needs to use a larger vocab. Any idea if it's possible to reduce the vocab size it has to predict over?

I tested it on a smaller dataset but it seems to be working!
I think this will also guide the transformer much better.
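For illustration only, here is a rough sketch of the layout this describes: the text is tokenized, its ids are shifted past the mesh codebook range so they don't collide with the mesh codes, and a special token marks where the mesh tokens begin. All names and sizes are hypothetical, not part of meshgpt-pytorch:

```python
import torch

MESH_VOCAB  = 16384                      # mesh codebook size (illustrative)
TEXT_VOCAB  = 30522                      # e.g. a BERT-style tokenizer vocab (illustrative)
SOM_TOKEN   = MESH_VOCAB + TEXT_VOCAB    # special "start of mesh" separator
TOTAL_VOCAB = SOM_TOKEN + 1              # the larger vocab the transformer now predicts over

def build_sequence(text_ids: torch.Tensor, mesh_codes: torch.Tensor) -> torch.Tensor:
    text_part = text_ids + MESH_VOCAB                   # shift text ids out of the mesh code range
    som = torch.tensor([SOM_TOKEN])
    return torch.cat([text_part, som, mesh_codes])      # "chair <som> XXXXXXXX"
```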

[image]

@pathquester

@MarcusLoppe That is fantastic! Have you posted the fix somewhere?

@MarcusLoppe
Contributor Author

@MarcusLoppe That is fantastic! Have you posted the fix somewhere?

Not yet, my current way is a bit hacky and requires a bit of a rewrite to properly implement.

I'm currently verifying the solution on a somewhat bigger dataset and will hammer out all the possible bugs.

@lucidrains
Owner

lucidrains commented May 10, 2024

@MarcusLoppe hey Marcus, thanks for identifying this issue

have you tried turning off CFG? if you haven't, one thing to try is simply turning off CFG for the first few tokens. i think i've come across papers where they studied at which steps CFG is even effective

also try turning off CFG and do greedy sampling and see what happens. if that doesn't work, there is some big issue

@MarcusLoppe
Contributor Author

@MarcusLoppe hey Marcus, thanks for identifying this issue

have you tried turning off CFG? if you haven't, one thing to try is simply turning off CFG for the first few tokens. i think i've come across papers where they studied at which steps CFG is even effective

also try turning off CFG and do greedy sampling and see what happens. if that doesn't work, there is some big issue

With CFG, you mean classifier-free guidance?

Not sure how I would go about that; do you mean setting cond_drop_prob to 0.0?
I've tried that, and as far as I can tell the CFG just returns the embedding without any modifications (if cond_drop_prob is set to 0, it won't mask the text embedding).

The issue is when the transformer has an empty sequence and only the text embedding to go from. The text embedding doesn't seem to help very much, so it doesn't know what token to pick, hence the huge loss at the start.

@lucidrains
Owner

@MarcusLoppe oh, maybe it is already turned off

so CFG is turned on by setting cond_scale > 1. when invoking .generate

if you haven't been using cond_scale, then perhaps it was never turned on
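For context, CFG is applied at inference time roughly as in the repository README (assuming a trained MeshTransformer instance named `transformer`); cond_scale = 1. means no guidance, values above 1. blend conditional and unconditional logits:

```python
faces_coordinates, face_mask = transformer.generate(
    texts = ['a chair'],
    cond_scale = 3.   # > 1. turns classifier-free guidance on
)
```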

@lucidrains
Owner

lucidrains commented May 10, 2024

@MarcusLoppe oh crap, do i only append the start token for the fine transformer?? 🤦 yes you are correct, it is never conditioned then for the first set of fine tokens

@lucidrains
Owner

lucidrains commented May 10, 2024

thank you, this is a real issue then. i'll add cross attention to the fine transformer later today

edit: another cheap solution would be to project text embeddings, pool it, and just use that as the initial sos token
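A minimal sketch of that cheap solution (illustrative only, not the repo's actual code): project the text embeddings, mean-pool them over the text length, and use the result as the start-of-sequence token:

```python
import torch
from torch import nn

class PooledTextSOS(nn.Module):
    def __init__(self, text_dim: int, model_dim: int):
        super().__init__()
        self.proj = nn.Linear(text_dim, model_dim)

    def forward(self, text_embeds: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # text_embeds: (batch, text_len, text_dim), mask: (batch, text_len) with True = keep
        if mask is not None:
            text_embeds = text_embeds.masked_fill(~mask.unsqueeze(-1), 0.)
            pooled = text_embeds.sum(dim = 1) / mask.sum(dim = 1, keepdim = True).clamp(min = 1)
        else:
            pooled = text_embeds.mean(dim = 1)
        return self.proj(pooled).unsqueeze(1)   # (batch, 1, model_dim), used as the sos token
```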

@MarcusLoppe
Contributor Author

MarcusLoppe commented May 10, 2024

@MarcusLoppe oh crap, do i only append the start token for the fine transformer?? 🤦 yes you are correct, it is never conditioned then for the first set of fine tokens

Awesome! However, my 'fix' does seem to be working.
By providing the text in the form of tokens in the sequence, the fine decoder gets the text context, and it also helps create a stronger relationship with the tokens and speeds up training.
So the tokens it trains on look like: "chair XXXXXXXX" (where X is the mesh tokens).

The downside is that it needs a bigger vocab, which slows training a bit, but the stronger relationship between the mesh tokens and the text seems to be working :)

thank you, this is a real issue then. i'll add cross attention to the fine transformer later today

I had some issues with providing the context to the fine decoder since the vector changes shape, but you might be able to solve it.

However, I tried removing the gateloop and fine decoder so the main decoder is the last layer, but unfortunately it had the same issue.

@lucidrains
Owner

@MarcusLoppe yup, your way is also legit 😄

you have a bright future Marcus! finding this issue, the analysis, coming up with your own solution; none of that is trivial

@MarcusLoppe
Contributor Author

@MarcusLoppe yup, your way is also legit 😄

you have a bright future Marcus! finding this issue, the analysis, coming up with your own solution; none of that is trivial

Thank you very much 😄 Although it took a while, I think I've learned a thing or two along the way 😄

thank you, this is a real issue then. i'll add cross attention to the fine transformer later today

edit: another cheap solution would be to project text embeddings, pool it, and just use that as the initial sos token

I don't think the cross-attention will be enough; as per my last reply, I removed the fine decoder and gateloop and had the same issue.

If you think about multimodal generative models, they never start from token 0. For example, vision models have a prompt with a specific request from the user.
So they have the first few tokens and some sort of goal or idea of what to generate, and then the cross-attention can do its job and provide the additional context.
So the generative model has a more 'probable path' from the start to get to the correct answer.

I think projecting the text embeddings might be the better way in this case.

@lucidrains
Owner

@MarcusLoppe yup i went with the pooled text embedding summed to the sos token for now

let me know if that fixes things (or not) 🤞

@lucidrains
Owner

this was really my fault for designing the initial architecture incorrectly

the sos token should be on the coarse transformer

@MarcusLoppe
Contributor Author

@MarcusLoppe yup i went with the pooled text embedding summed to the sos token for now

let me know if that fixes things (or not) 🤞

Awesome! I'll check it out 🚀

However, with the last x-transformers update I'm getting the error below.
The num_mem_kv kwarg doesn't seem to be picked up or trimmed by:
"attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)"

And dim_head in meshgpt isn't being passed correctly; it should be "attn_dim_head".

-> 1057 assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
1059 dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1061 self.dim = dim

AssertionError: unrecognized kwargs passed in dict_keys(['dim_head', 'num_mem_kv'])

@lucidrains
Owner

@MarcusLoppe ah yes, those should have attn_ prepended, should be fixed in the latest version
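For anyone hitting the same assertion: in x-transformers, attention-specific kwargs are routed by the attn_ prefix (groupby_prefix_and_trim strips it before handing them to the attention layers), so the failing kwargs need to be spelled roughly like this (illustrative values):

```python
from x_transformers import Decoder

decoder = Decoder(
    dim = 512,
    depth = 6,
    heads = 8,
    attn_dim_head = 64,     # not `dim_head`
    attn_num_mem_kv = 4     # not `num_mem_kv`
)
```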

@MarcusLoppe
Contributor Author

MarcusLoppe commented May 10, 2024

@lucidrains

Alright, here are some results.
Using the CLIP embedding model (which has higher distances in the embedding space) with a GPT-small sized transformer:

I first trained using a small set of 350 models, which have x5 augments each. It only contains 39 unique labels, so there is some overlap between the texts.
The previous test just produced a blob of triangles; this time it output all tents and a blob.
[image]

I then took the same model and removed all augmentations, so it's x1 of each model with unique texts for each model.
This output somewhat better results, but it's still not following the text guidance.
I checked the logits, and the first token generated was for a bench model; the correct token was in 19th place with a value of 0.013.
[image]

And as you can see, the loss at the start didn't show any improvements :/
[image]

As a sanity check I trained a fresh model on the x1 set down to 0.004 loss, but as you can see it didn't help. It might even have made things worse.
[image]

When I did the same test previously using my method with tokenized text, I was able to get perfect results on the x1 set (did not test x5), so that would indicate that the cross-attention conditioning isn't strong enough when there are no tokens.

Btw, I tested just adding fake tokens by increasing the codebook size and using e.g. codebook_size + 1 (eos at +2) at the start, but that didn't change anything.

@lucidrains
Owner

@MarcusLoppe ok, i threw in cross attention conditioning for fine transformer in 1.2.3

if that doesn't work, i'll just spend some time refactoring so the start token is in the coarse transformer. that would work for sure, as it is equivalent to your solution, except the text tokens do not undergo self attention

@lucidrains
Owner

@MarcusLoppe thanks for running the experiments!

@MarcusLoppe
Contributor Author

MarcusLoppe commented May 11, 2024

@MarcusLoppe ok, i threw in cross attention conditioning for fine transformer in 1.2.3

if that doesn't work, i'll just spend some time refactoring so the start token is in the coarse transformer. that would work for sure, as it is equivalent to your solution, except the text tokens do not undergo self attention

@lucidrains

The loss improved much more quickly over the epochs, however there was a downside.
Before it generated 100 tokens/s; now it went down to 80 t/s. But I think I prefer this version much more, since it will cut down the training time.
Since inference time increased, so did the time per epoch: using a 2k dataset it went from 02:28 to 02:42, however I saw better loss improvements.

Unfortunately it did not work :(
However, something to note is that it worked when using the demo mesh dataset that consists of 9 meshes.

Cond_scale 1:
[image]
Cond_scale 3:
[image]

[image]

@lucidrains
Owner

@MarcusLoppe ah, thank you

ok, final try

will have to save this for late next week if it doesn't work

@MarcusLoppe
Contributor Author

@MarcusLoppe ah, thank you

ok, final try

will have to save this for late next week if it doesn't work

It worked better. Here is the result of training it on 39 models with unique labels; however, you can still see a spike at the start of the sequence, which means it might not be fully resolved.

[image]
[image]

Using my method I managed to get the results below; it manages to generate quite complex objects.
However, the start is still a bit weak. It would help if you manage to move the sos token to the coarse transformer; that would help the training time a lot, since it can reduce the vocab size from 32k to 2k :)

I've also experimented with using 3 tokens per triangle and the autoencoder seems to be working, however it makes the training progression for the transformer slower. But considering that the VRAM requirement for training on 800-triangle meshes would go from 22GB to 9GB and the generation time would be halved, I think that is something worth exploring.

However, I think the autoencoder could also benefit from getting the text embeddings. I tried to pass them as the context in the linear attention layer, but since it requires the same shape as the quantized input it won't accept them, and I don't think it would be very VRAM friendly to duplicate the text embedding across the number of faces.
Do you think there is an easy fix for this? I think it would reduce the codebook size a lot and help create codes with a closer relationship to the text, which would benefit the transformer a lot.

[image]

@lucidrains
Owner

lucidrains commented May 13, 2024

@MarcusLoppe that is much better (an order of magnitude)! thank you for the experiments Marcus!

i know how to improve it (can add multiple sos to give the attention more surface area)

@lucidrains
Owner

@MarcusLoppe i'll get back to this later this week 🙏

@lucidrains
Owner

@MarcusLoppe oh, the sos token has already been moved to the coarse transformer in the latest commit. that's where the improvement you are seeing is coming from

@MarcusLoppe
Contributor Author

MarcusLoppe commented May 13, 2024

@MarcusLoppe that is much better (an order of magnitude)! thank you for the experiments Marcus!

i know how to improve it (can add multiple sos to give the attention more surface area)

Oh awesome! However, the loss got very low (0.006) for these results; for the bigger datasets the loss gets to about 0.01, and then it needs something like 1000 epochs to reach a similar loss.

So some further improvements would be nice! 😄
Any thoughts about the text-embedding-aware autoencoder?

@lucidrains
Owner

lucidrains commented May 13, 2024

@MarcusLoppe yup, we can try multiple sos tokens, then if that doesn't work, i'll build in the option to use prepended text embeddings (so like the solution you came up with, additional sos excised or pooled before fine transformer)

and yes, text embedding aware autoencoder is achievable! in fact, the original soundstream paper did this
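A hedged sketch of what SoundStream-style conditioning of the autoencoder could look like: a FiLM layer that turns a text embedding into per-channel scale and shift for the encoder features. Illustrative only, not meshgpt-pytorch's implementation:

```python
import torch
from torch import nn

class FiLM(nn.Module):
    def __init__(self, cond_dim: int, feature_dim: int):
        super().__init__()
        self.to_gamma_beta = nn.Linear(cond_dim, feature_dim * 2)

    def forward(self, features: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # features: (batch, num_faces, feature_dim), cond: (batch, cond_dim)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim = -1)
        # modulate every face feature with the text-derived scale and shift
        return features * (gamma.unsqueeze(1) + 1.) + beta.unsqueeze(1)
```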

@MarcusLoppe
Contributor Author

MarcusLoppe commented May 16, 2024

@MarcusLoppe ah, you aren't referring to the number of sos tokens, but to the token number in the main sequence, my bad

try with a much larger number of sos tokens, say 16 or 32

I don't have the figures for them, but I tried 16 and got bad results.
As you can see, using 8 tokens had the worst results.
I'll spin up the test script and get you some hard numbers.

I know that setting up the sos tokens before the decoder and then inserting them after the cross-attention creates some sort of learnable relationship, and I assume the tokens change with the loss.
I don't have any data to back this up, but isn't it better to have the tokens be a representation of the text embeddings?
If the sequence is 48 tokens, the majority of the loss comes after the first few tokens, and the model will 'shape'/optimize itself to minimize that loss, meaning the tokens will adapt to work for 98% of the sequence.
Sort of like sacrificing the first wave of soldiers in a war to end up in a better position so no other soldiers need to die.

So is it possible to reshape (with any nn) the text embeddings to the dim size and then insert them at the start of the sequence, followed by a special token?

@lucidrains
Owner

ok, I'll try one more thing, but if that doesn't work, let's go for your text prefix solution

@MarcusLoppe
Contributor Author

ok, I'll try one more thing, but if that doesn't work, let's go for your text prefix solution

A little off topic, but I trained a 1-quantizer autoencoder and transformer and got good results. Progression was a little slower, but I got to about 0.03 loss with the transformer.
I didn't succeed in generating a mesh with 0 tokens, but when providing 10 tokens it managed to generate a mesh :)

So that is a big win: halving the sequence length and reducing the VRAM requirement from 22 GB to 8 GB in training (800 faces).

@MarcusLoppe
Contributor Author

ok, I'll try one more thing, but if that doesn't work, let's go for your text prefix solution

Hi again.

Here are some failed results:

  • I've tried training with a higher loss weight on the first token, or even only applying the loss to the first token.
  • I did some testing with the text embeddings: I appended text_embeds at the start, or used a linear layer (dim, dim * num_tokens) and then rearranged to (b, num_tokens, dim); see the sketch after this list.
  • Replaced the sos_tokens with text_embeds.
  • Tested using 32 sos tokens.
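Rough sketch of the linear-layer attempt from the second bullet (illustrative shapes and names only): project the pooled text embedding to num_tokens prefix tokens and rearrange:

```python
import torch
from torch import nn
from einops import rearrange

dim, num_tokens = 512, 4
to_prefix = nn.Linear(dim, dim * num_tokens)

text_embeds = torch.randn(2, dim)                  # (b, dim) pooled text embedding
prefix = rearrange(to_prefix(text_embeds), 'b (n d) -> b n d', n = num_tokens)
# prefix: (b, num_tokens, dim), prepended to the code sequence
```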

I was wondering if the decoder cross-attention layer could handle it alone, but with just the decoder layer it couldn't handle any part of the sequence.
So what's your thinking with the cross-attention? Do you think the sos token or cross-attention can handle the cold start?
Since the issue is with the first 0-3 tokens, would it be beneficial to create some kind of embedding space that contains the first 3 tokens and is indexed by text embedding? That way the text embedding provided by the user can be used to find the nearest neighbour.
It's not very novel, but it's a good way to at least kickstart the generation, although the issue might be resolved with scale later on.
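A toy sketch of that nearest-neighbour kickstart idea (purely illustrative): store the first few ground-truth tokens of each training mesh keyed by its normalized text embedding, then look up the closest entry by cosine similarity at inference to seed generation:

```python
import torch
import torch.nn.functional as F

class FirstTokenIndex:
    def __init__(self):
        self.embeds, self.prefixes = [], []

    def add(self, text_embed: torch.Tensor, first_tokens: torch.Tensor):
        # text_embed: (dim,), first_tokens: e.g. the first 3 ground-truth tokens
        self.embeds.append(F.normalize(text_embed, dim = -1))
        self.prefixes.append(first_tokens)

    def lookup(self, query_embed: torch.Tensor) -> torch.Tensor:
        # cosine similarity against all stored embeddings, return the closest prefix
        sims = torch.stack(self.embeds) @ F.normalize(query_embed, dim = -1)
        return self.prefixes[int(sims.argmax())]
```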

The best result I got was with the commit below; however, it may just be luck and not consistent behaviour. The linear attention method had similar results but without the slowness of adding cross-attention to the fine decoder.

Training for many, many epochs using the "add cross attention based text conditioning for fine transformer too" commit:
Token 1: Correct = 260, Incorrect = 95
Token 2: Correct = 319, Incorrect = 36
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 355, Incorrect = 0

Linear layer with 4 sos tokens

if not exists(cache):
    # split the packed sos tokens back out from the attended face codes
    sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
    # attention-pool the multiple sos tokens down to a single token
    attention_scores = F.softmax(self.attention_weights(sos_tokens), dim = 1)
    pooled_sos_token = torch.sum(attention_scores * sos_tokens, dim = 1, keepdim = True)
    # prepend the pooled sos token back onto the sequence
    attended_face_codes = torch.cat((pooled_sos_token, attended_face_codes), dim = 1)

Token 1: Correct = 237, Incorrect = 118
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 351, Incorrect = 4
Token 4: Correct = 354, Incorrect = 1
Token 5: Correct = 355, Incorrect = 0

@lucidrains
Owner

lucidrains commented May 18, 2024

@MarcusLoppe thank you Marcus! 🙏 will get a few solutions in soon

i also realized i wasn't caching the cross attention key / values correctly 🤦 will also get that fixed this morning

@lucidrains lucidrains reopened this May 18, 2024
@MarcusLoppe
Contributor Author

@MarcusLoppe thank you Marcus! 🙏 will get a few solutions in soon

i also realized i wasn't caching the cross attention key / values correctly 🤦 will also get that fixed this morning

Awesome! 😄
Outside of meshgpt, have you had success training a decoder and letting it generate from a cold start with just an embedding before? E.g. train on sequences with 6 tokens where the only input is an embedding used in the cross-attention of the decoder.

It works kind of well when the dataset is small (<500); I don't think it's the model size, since it can remember 10k models if it's prompted with a few tokens.

Btw, let me know if I'm doing something wrong, but during my testing I just call forward_on_codes, take the logits and pick the token by argmax.
I'm not sure if this would disable the classifier-free guidance or not.

@MarcusLoppe
Contributor Author

MarcusLoppe commented May 20, 2024

@MarcusLoppe thank you Marcus! 🙏 will get a few solutions in soon

i also realized i wasn't caching the cross attention key / values correctly 🤦 will also get that fixed this morning

Hey again,

So I've noticed some strange behaviour with the cross-attention num_mem_kv that might help you resolve the issue.
I've changed this value before without any noticeable difference.

However, using the commit with the fine-decoder cross-attention I got the results below.
Setting the cross-attention num_mem_kv to 16 seems to hit some kind of sweet spot (maybe related to the dataset size).

This made it possible to generate a mesh from token 0 since it seems to be hitting the correct tokens. As you can see, the mesh is hardly smooth, but at least it's selecting the correct first token! I'm currently training to see if using x5 augmentation of the same dataset will yield any better results, since it might be more robust.
[image]

I also tested setting the fine depth to 4 or 8, but it worsened performance; the same goes for increasing attn_num_mem_kv to 16.

I also tested using 16 cross_attn_num_mem_kv with all the other solutions you've posted, but there were no noticeable changes.

Commit: 5ef6cbf

8 cross_attn_num_mem_kv
Token 1: Correct = 6, Incorrect = 349
Token 2: Correct = 165, Incorrect = 190
Token 3: Correct = 320, Incorrect = 35
Token 4: Correct = 322, Incorrect = 33
Token 5: Correct = 341, Incorrect = 14 
 
16 cross_attn_num_mem_kv

Token 1: Correct = 293, Incorrect = 62
Token 2: Correct = 331, Incorrect = 24
Token 3: Correct = 354, Incorrect = 1
Token 4: Correct = 354, Incorrect = 1
Token 5: Correct = 355, Incorrect = 0 

16 cross_attn_num_mem_kv
8 fine_attn_depth 
Token 1: Correct = 233, Incorrect = 122
Token 2: Correct = 189, Incorrect = 166
Token 3: Correct = 321, Incorrect = 34
Token 4: Correct = 313, Incorrect = 42
Token 5: Correct = 342, Incorrect = 13 

32 cross_attn_num_mem_kv
 
Token 1: Correct = 4, Incorrect = 351
Token 2: Correct = 207, Incorrect = 148
Token 3: Correct = 345, Incorrect = 10
Token 4: Correct = 338, Incorrect = 17
Token 5: Correct = 349, Incorrect = 6 

16 attn_num_mem_kv 
16 cross_attn_num_mem_kv
Token 1: Correct = 5, Incorrect = 350
Token 2: Correct = 205, Incorrect = 150
Token 3: Correct = 353, Incorrect = 2
Token 4: Correct = 355, Incorrect = 0
Token 5: Correct = 355, Incorrect = 0 

@MarcusLoppe
Contributor Author

MarcusLoppe commented Jun 1, 2024

@MarcusLoppe thank you Marcus! 🙏 will get a few solutions in soon

i also realized i wasn't caching the cross attention key / values correctly 🤦 will also get that fixed this morning

Hey, @lucidrains
I think I've figured something out. I made quite a lot of changes, but I had success by applying the following:

  • Text embedding pooling with to_sos_text_cond
  • sos_token (single parameter).
  • Fine decoder cross-attention

Plus a few other tricks.
The training is also quite specific in regards to masking the text and other factors; if it becomes overtrained then the results are just blobs again.
When the conditions are pretty good the model will always generate a complete shape, not always what you want, but at least it's not a blob.
Btw, I also managed to train a model using 1 quantizer, which reduced the inference time by half (duh :) ).

I wouldn't say this issue is resolved: using a dataset with 1k unique labels, during generation it will steer towards the most 'average' mesh model according to the text embeddings; you can see this averaging effect in the second image (cond scale helps sometimes, but setting it too high will turn the mesh into a blob).
Hopefully this information helps you steer towards a final solution that can be used for a large amount of text labels.

Possible issue / accidental feature

I'm not sure if it's a problem, but since I add the sos_token before the main decoder and then add the pooled text embedding afterwards, 2 tokens with 'value' end up added, and with the padding it becomes 12 tokens.
The first 6 extra tokens are due to the autoregression and the other 6 are due to the text embedding pooling, since it's added just before pad_to_length is called.

The result is that 1 token is replaced/lost due to the right shift, since 2 tokens are added and only the sos_token is removed.
So the data between the decoder and fine decoder is shifted right and ends up in a different order. This might not be an issue for the fine decoder, since it's already out of order due to the rearranging and the addition of the grouped_codes, so the shape goes from (b, faces, dim) to (b * (faces+1), (quantizers * vertices_per_face), dim).
But if you think of it in a linear fashion and ignore the ML transforming the data, the output would be:
<pooled_text_embed> <mesh> <cut> <EOS> <extra tokens>
Instead of:
<mesh> <EOS> <cut> <extra tokens>

This is just a guess, but maybe since the output covers a longer window (12 tokens into the future instead of 6), it might help with inference, since during training it outputs what it thinks might come after the EOS token. However, this output is cut off and doesn't affect the loss, so I'm not sure if it matters (I also increased the padding to 18 tokens, but the performance degraded).
I also tested replacing the pooled_text_embed with a Parameter of size dim, but it got worse results, so the text embedding does affect the output.

Multi-token prediction

I've been trying to understand how the transformer trains: at the end there is always 1 extra face (6 tokens), and then the sequence is cut off so 5 tokens remain. I'm guessing this is done for the autoregression and the EOS token.
But I think it can provide an additional effect by extending the 'hidden' future tokens, which can be used for multi-token prediction.
I'm not sure where the masking is applied during training, but as a test I increased the number of codes that get cut off and set 'append_eos' to false to see if it can predict multiple tokens ahead.
Nothing as fancy as the Meta paper, just a weak proof of concept.

Here are some samples after training for 15 epochs on the first 12 tokens of 2.8k meshes with 1000 labels:
1 token: 0.3990 loss (0.5574 loss without the text embedding pooling)
2 tokens: 0.112 loss
3 tokens: 0.24 loss
4 tokens: 0.1375 loss (woah!)
6 tokens: 0.1826 loss (18th epoch 0.104 loss)

if return_loss:
    assert seq_len > 0
    # keep the full sequence as labels but drop the last `number_of_tokens` codes from the
    # input, so the model is trained to predict several tokens ahead
    codes, labels = codes[:, :-number_of_tokens], codes
.......
# keep the extra future positions in the embeddings as well
embed = embed[:, :(code_len + number_of_tokens)]

500 labels with 10 models for each label, 2k codebook, number of quantizers: 2
[image]

1000 labels with 5 models for each label, 2k codebook, number of quantizers: 2
[image]

100 labels with 25 models for each label, 16k codebook, number of quantizers: 1
[image]

@lucidrains
Owner

@MarcusLoppe thanks Marcus for the detailed report and glad to hear you found a solution!

i'll keep chipping away at it to strengthen conditioning

next up is to probably add adaptive layer/rms normalization to x-transformers

@lucidrains
Owner

@MarcusLoppe you should see a slight speed bump on inference now, as i now cache the key / values correctly for cross attention

hope i didn't break anything!

@MarcusLoppe
Contributor Author

MarcusLoppe commented Jun 1, 2024

@MarcusLoppe thanks Marcus for the detailed report and glad to hear you found a solution!

i'll keep chipping away at it to strengthen conditioning

next up is to probably add adaptive layer/rms normalization to x-transformers

Lovely :) I'll test the FiLM normalization method and let you know.
I tried replacing the PixelNorm with the FiLM batch normalization on the ResNet, but I only had mild success: the std for the first token and its relationship to the label decreased from 12 to 11.

However, the sos_token isn't quite how I implemented it; I've had more success just leaving the sos_token in without unpacking it.
I had the best success using the sos_token as per your commit "move the concatenation of the sos token so it is always conditioned".
That commit has had far better results than:

  • Unpacking multiple tokens + packing pooling
  • Repacking single
  • Unpacking multiple tokens and packing last token.

I tried explaining it before with my tests but I might have not been clear enough.

Here is the implementation I've used
https://github.com/MarcusLoppe/meshgpt-pytorch/blob/sos_token_test2/meshgpt_pytorch/meshgpt_pytorch.py

@MarcusLoppe you should see a slight speed bump on inference now, as i now cache the key / values correctly for cross attention

hope i didn't break anything!

I'll give it a go :)
Btw, I can create a pull request for this, but when using 1 quantizer the rounding-down method doesn't work in generation. Currently it's doing:
10 codes // 1 * 1 = 10 codes
Instead of:
10 codes // 3 * 3 = 9 codes

So changing the line below will make 1-quantizer generation work.
From:
round_down_code_len = code_len // self.num_quantizers * self.num_quantizers
To:
round_down_code_len = code_len // self.num_vertices_per_face * self.num_vertices_per_face

@lucidrains
Owner


thanks for reporting the rounding down issue!

and yes, i can cleanup the multiple sos tokens code if not needed. however, by setting just 1 sos token, it should be equivalent to what you deem the best working commit

@MarcusLoppe
Contributor Author

MarcusLoppe commented Jun 1, 2024

thanks for reporting the rounding down issue!

and yes, i can cleanup the multiple sos tokens code if not needed. however, by setting just 1 sos token, it should be equivalent to what you deem the best working commit

Just tested it on the first 12 tokens: using FiLM + mean gives worse performance, plus it's giving me NaN loss.
However, it might work better when applied to the full sequence; I'll get back to you.
Edit: Seems like the parameter usage in the FiLM layer improved it, it no longer gives NaN loss at the start.

Although I'm not sure it would help, since the issue might be the mean pooling itself. Say there are 1000s of text embeddings, all unique: the cross-attention receives them in their original state, but the fine decoder gets the mean of each embedding as an additional token.
After averaging, the embeddings are closer to each other than before and some of them are no longer unique.
I think this is the reason why the same model is the 'default' for several mesh models.

@lucidrains
Owner

@MarcusLoppe thanks for running it, made a few changes

yea, can definitely try attention pooling, which is a step up from mean pool
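A hedged sketch of what attention pooling over the text embeddings could look like (a learned query attends over the text tokens instead of taking a plain average); illustrative only, not the repo's code:

```python
import torch
from torch import nn

class AttentionPool(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(dim))
        self.scale = dim ** -0.5

    def forward(self, text_embeds: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # text_embeds: (batch, text_len, dim), mask: (batch, text_len) with True = keep
        sim = torch.einsum('d, b n d -> b n', self.query, text_embeds) * self.scale
        if mask is not None:
            sim = sim.masked_fill(~mask, float('-inf'))
        attn = sim.softmax(dim = -1)
        return torch.einsum('b n, b n d -> b d', attn, text_embeds)   # (batch, dim)
```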

@MarcusLoppe
Contributor Author

@MarcusLoppe thanks for running it, made a few changes

yea, can definitely try attention pooling, which is a step up from mean pool

Okay, some updates.

  1. I noticed that the masked_mean was broken: classifier-free guidance uses the pad id "0" instead of "-1", so the embeddings never stayed the same across batches because the padded values were included in the mean (see the sketch after this list).
    This would mean that the 'chair' embedding would have different values and effectively be random.
    I created a pull request to add the padding id: Added padding id option

  2. After the padding issue was resolved I still had meshes whose text was more popular than others (e.g. it matched "pallet" 183 times out of 775). So I printed out the cosine similarities and noticed a pattern: 'chair' and 'pallet' had 0.99999 cosine similarity, and 'pallet' had 0.99999 similarity with many others. I was using CLIP, so I switched to T5/BGE and the similarity was around 0.69 as it should be.
    I tried to combine CLIP & T5 using model_types = ('t5', 'clip'), however the same issue remained.
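For reference, a sketch of a masked mean that excludes the padded positions described in point 1 (illustrative, not the exact code in the pull request): if the pad id isn't masked out, padded values leak into the mean and the same text gets a different embedding depending on the batch padding:

```python
import torch

def masked_mean(text_embeds: torch.Tensor, text_ids: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    # text_embeds: (batch, text_len, dim), text_ids: (batch, text_len)
    mask = (text_ids != pad_id).unsqueeze(-1)        # (batch, text_len, 1)
    summed = (text_embeds * mask).sum(dim = 1)
    count = mask.sum(dim = 1).clamp(min = 1)
    return summed / count                            # padded positions no longer affect the mean
```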

CLIP seems to work better than T5 & BGE on longer sentences and contains more nuanced information.
Here are the cosine similarities (check out the last two):

Pair                                        BGE      T5       CLIP
['pallet', 'chair']                         58.8942  67.7801  99.9937
['a pallet', 'a chair']                     71.565   60.0675  81.1785
['a pallet on floor', 'a chair on floor']   81.9622  40.6449  76.7676
['pallet', 'pallet on floor']               89.8585  50.7333  78.1393
['chair', 'pallet on floor']                60.5253  37.083   78.2081

So after fixing these issues I tried again and had much better success :)
I tested FiLM and got mixed results; I'm not confident saying which is best, since the 1.5k-token tests say it's best with FiLM on, but the training without FiLM is better.

The results are very good, however there are some issues. For example, in the test using x5 models per label I have a very hard time generating 3 distinct sets of rows of the same furniture labels.
It does follow the label, but it seems to ignore the other models (even with cond_scale and a high temperature of 0.8).

I trained on a dataset using 775 labels with 5x examples each (3.8k meshes). The first tests were trained on only 60 tokens in total; the latest one was trained on the full 1500-token sequences. I also tested using x10 examples, but that training run needs more time to get an accurate picture of its performance.
For calculating the accuracy I took all the unique labels and created a list of all the models that have the same label. Using this list, I checked whether any of these models has the same first 3 tokens as the generated sequence; if so, I count it as "correct".
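In code, the accuracy check amounts to something like this sketch (hypothetical helper, not part of the repo):

```python
def prefix_accuracy(generated, dataset, prefix_len = 3):
    # generated and dataset are lists of (label, token_list) pairs
    by_label = {}
    for label, tokens in dataset:
        by_label.setdefault(label, []).append(tokens[:prefix_len])
    # a generation is "correct" if any training mesh with the same label shares its first tokens
    correct = sum(
        1 for label, tokens in generated
        if tokens[:prefix_len] in by_label.get(label, [])
    )
    return correct / len(generated)
```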

Using FiLM

T5, trained on 60 tokens:

Generation accuracy: 
Accuracy: 0.9987096774193548 all_correct: 774 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 774, Incorrect = 1
Token 3: Correct = 774, Incorrect = 1 

Forward accuracy:
Accuracy: 0.27225806451612905 all_correct: 211 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0 

BGE, trained on 60 tokens:

Generation accuracy: 
Accuracy: 1.0 all_correct: 775 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0 

Forward accuracy:
Accuracy: 0.36 all_correct: 279 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0 

BGE trained on 1500 tokens:

Generation accuracy: 
Accuracy: 0.984516129032258 all_correct: 763 len(test_dataset): 775
Token 1: Correct = 763, Incorrect = 12
Token 2: Correct = 763, Incorrect = 12
Token 3: Correct = 763, Incorrect = 12  
 
Generation - Super loose (cosine similarity with text 90%)
Accuracy: 0.9922580645161291 all_correct: 769 len(test_dataset): 775
Token 1: Correct = 770, Incorrect = 5
Token 2: Correct = 769, Incorrect = 6
Token 3: Correct = 769, Incorrect = 6 

Forward accuracy:
Accuracy: 0.28774193548387095 all_correct: 223 len(test_dataset): 775
Token 1: Correct = 760, Incorrect = 15
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0

Without FiLM

T5, trained on 60 tokens:

Generation accuracy: 
Accuracy: 0.9987096774193548 all_correct: 774 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 774, Incorrect = 1
Token 3: Correct = 774, Incorrect = 1 

Forward accuracy:
Accuracy: 0.34580645161290324 all_correct: 268 len(test_dataset): 775
Token 1: Correct = 774, Incorrect = 1
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0 

BGE, trained on 60 tokens:

Generation accuracy: 
Accuracy: 1.0 all_correct: 775 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0 

Forward accuracy:
Accuracy: 0.33419354838709675 all_correct: 259 len(test_dataset): 775
Token 1: Correct = 775, Incorrect = 0
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0 

BGE trained on 1500 tokens:

Generation accuracy: 
Accuracy: 0.9780645161290322 all_correct: 758 len(test_dataset): 775
Token 1: Correct = 758, Incorrect = 17
Token 2: Correct = 758, Incorrect = 17
Token 3: Correct = 758, Incorrect = 17
 
Generation - Super loose (cosine similarity with text 90%)
Accuracy: 0.9883870967741936 all_correct: 766 len(test_dataset): 775
Token 1: Correct = 766, Incorrect = 9
Token 2: Correct = 766, Incorrect = 9
Token 3: Correct = 766, Incorrect = 9
 
Forward accuracy:
Accuracy: 0.28774193548387095 all_correct: 223 len(test_dataset): 775
Token 1: Correct = 746, Incorrect = 29
Token 2: Correct = 775, Incorrect = 0
Token 3: Correct = 775, Incorrect = 0

775 labels with 10x examples each (7.7k meshes),
BGE trained on 1500 tokens:

Generation accuracy: 
Accuracy: 0.9858064516129033 all_correct: 764 len(test_dataset): 775
Token 1: Correct = 764, Incorrect = 11
Token 2: Correct = 764, Incorrect = 11
Token 3: Correct = 764, Incorrect = 11 

Generation - Super loose (cosine similarity with text 90%)
Accuracy: 0.9896774193548387 all_correct: 767 len(test_dataset): 775
Token 1: Correct = 767, Incorrect = 8
Token 2: Correct = 767, Incorrect = 8
Token 3: Correct = 767, Incorrect = 8
Token 4: Correct = 767, Incorrect = 8

Renders:
Mesh files: https://file.io/OxWCcTYpUbxN
Without FiLM
['bed', 'sofa', 'monitor', 'bench', 'chair', 'table', 'console table console', 'object on a stick', 'knife', 'billboard', 'concrete structure', 'rod', 'stick with a handle', 'shark fin', 'wooden railing', 'zigzag chair straight chair side chair', 'building', 'building with a few windows', 'wooden bench', 'screen with nothing on it', 'bird with a beak', 'apple', 'trash can with a lid', 'blocky robot', 'crystal on a base', 'sign with arrows on it', 'three different colocubes', 'crowbar', 'sheet of paper', 'metal rod', 'computer screen computer display screen crt screen', 'octagon', 'chair on a floor', 'platform bed', 'dark object', 'u shaped object', 'bed with a headboard', 'three blocks', 'brush', 'whale tail', 'staircase', 'lamp', 'broken cylinder', 'night stand', 'traffic cone', 'drill', 'four geometric shapes', 'trash can', 'rocket ship', 'traffic cone on a base', 'pyramid and a square', 'robot that is standing up', 'computer monitor', 'pixelated object', 'coffee cup', 'chaise longue chaise daybed chair', 'three rocks', 'desk table-tennis table ping-pong table pingpong table', 'letter k', 'box with a ribbon', 'dice with dots on it']
[image]

With FiLM
['bed', 'sofa', 'monitor', 'bench', 'chair', 'table', 'block animal', 'car', 'book', 'minecraft character in a shirt', 'pot', 'hanging sign', 'minecraft character wearing a shirt', 'secretary writing table escritoire secretaire', 'apple with a stem', 'two colorful blocks', 'line', 'security camera', 'pan', 'armchair', 'picket fence', 'palm tree', 'the letter t', 'bottle', 'park bench bench', 'sign', 'wooden baseball bat', 'wooden table', 'container with a lid', 'picnic table with two benches', 'geometric tree', 'operating table', 'traffic light', 'checkechair', 'dumpster', 'group of cubes', 'chair on a floor', 'two rectangular objects', 'box sitting', 'blocky chicken with an beak', 'octagon', 'couch with arm rests', 'metal object', 'top hat', 'diamond shaped object', 'two shelves', 'two rocks', 'apple', 'cartoon snail carrying a box', 'bench', 'cylinder and a cube', 'minecraft dolphin', 'question mark', 'tree made out of blocks', 'display video display', 'wooden stool', 'box with stripes on it', 'metal rod', 'tree stump', 'three dimensional object', 'side table table']
[image]

@lucidrains
Owner

@MarcusLoppe awesome, i think after adding adaptive layernorms to x-transformers, we can close this issue

that will surely be enough

@MarcusLoppe
Contributor Author

MarcusLoppe commented Jun 15, 2024

@MarcusLoppe awesome, i think after adding adaptive layernorms to x-transformers, we can close this issue

that will surely be enough

@lucidrains
A little question: I'm about to train a larger model that would require many text embeddings, and I'm a bit worried that the cross-attention with the text embeddings might 'take up' too much of the model.
I was thinking of using a dataset of 150k models, so the text guidance will need to be extremely good to represent those meshes.

I have this idea: since the cross-attention to a text embedding has a many-to-many relationship with the tokens, if the model instead just used cross-attention to the sequence itself, it would have more or less a one-to-many relationship.
So if we used learnable tokens and compressed the text embedding through them so the text is represented within the sequence, then it could cross-attend to the sequence itself, which would be more efficient and simpler.

I was wondering if using something like the below would work?

The multiple sos tokens are kept just for the main decoder but not for the rest of the network, and there is no full Q-Former architecture that takes the text embeddings and encodes the information into them.

https://arxiv.org/pdf/2405.20853
[image]
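A very rough sketch of the learnable-token idea (in the spirit of a Q-Former / perceiver-resampler style recompression): a fixed number of learned queries cross-attend to the text embeddings, and the resulting tokens are what would be prepended to the sequence. Purely illustrative, not an implementation from the paper or the repo:

```python
import torch
from torch import nn

class TextResampler(nn.Module):
    def __init__(self, dim: int, num_queries: int = 8, heads: int = 8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim))
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, text_embeds: torch.Tensor) -> torch.Tensor:
        # text_embeds: (batch, text_len, dim) -> (batch, num_queries, dim)
        b = text_embeds.shape[0]
        queries = self.queries.unsqueeze(0).expand(b, -1, -1)
        out, _ = self.attn(queries, text_embeds, text_embeds)
        return out
```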

@lucidrains
Owner

@MarcusLoppe oh, i don't even know what the q-former architecture is haha

i'll have to read it later this week, but it sounds like just a cross attention based recompression, similar to perceiver resampler

just got the adaptive layernorm conditioning into the repo! i think we can safely close this issue

@lucidrains
Owner

@MarcusLoppe we can chat about q-former in the discussions tab

@lucidrains
Owner

@MarcusLoppe oh yes, the qformer architecture is in vogue. bytedance recently used it for their vqvae to compress images even further.

will explore this in the coming month for sure!

@MarcusLoppe
Contributor Author

@MarcusLoppe oh yes, the qformer architecture is in vogue. bytedance recently used it for their vqvae to compress images even further.

will explore this in the coming month for sure!

Would it be possible to explore this sooner? :) Or maybe provide a hint on how to do it?
A student at a university in the States has offered to help with 16 H100s for 2 weeks!
He'll be granted the compute shortly (today) and we'll start the autoencoder training.

However, I'm still a bit unsure if the cross-attention is the best way, since I had some trouble using it for 10k labels.
I don't think meshgpt will get this attention again, so any advice would be helpful! :)

@lucidrains
Copy link
Owner

lucidrains commented Jun 17, 2024

@MarcusLoppe is he/she a phd or ms student? if so, you and him/her should be able to work together and implement it, could even make for a short paper

or i can take a look at it, but probably not for another month or two

@MarcusLoppe
Contributor Author

@MarcusLoppe is he/she a phd or ms student? if so, you and him/her should be able to work together and implement it, could even make for a short paper

or i can take a look at it, but probably not for another month or two

I think he's a PhD student; he applied for the compute a while ago and was granted it. It's not for a thesis or graded paper, but perhaps a technical report.

I'm happy and interested in implementing it myself, but as with many SOTA things it might be above my head.
As far as I understand it, the image patches (image / 32) are processed through an encoder and then encoded & quantized into a codebook. Then these codes can be used by the transformer as token indices, or maybe just as the encoded dim output from the quantizer.
Am I on the right track?

@lucidrains
Owner

@MarcusLoppe ok, if he's a phd student, you two should definitely be able to work it out from the code already available

@lucidrains
Owner

lucidrains commented Jun 17, 2024

@MarcusLoppe i'm not sure without spending a day reading the paper, but it looks to me they are simply using appended "query" tokens, which is similar to memory / register tokens in the literature. they simply concat it to the sequence and then attend to everything, and slice it out. it is similar to the sos tokens we've been playing around with, except it isn't autoregressive

@lucidrains
Owner

@MarcusLoppe ask your collaborator! he should know if he is in the field

@MarcusLoppe
Contributor Author

@MarcusLoppe i'm not sure without spending a day reading the paper, but it looks to me they are simply using appended "query" tokens, which is similar to memory / register tokens in the literature. they simply concat it to the sequence and then attend to everything, and slice it out. it is similar to the sos tokens we've been playing around with, except it isn't autoregressive

I've read a bit further; you might be right and I'm just not understanding your terminology.

My understanding is that they train an autoencoder (tokenizer) and only use 32 tokens to represent the image.
They extract patches from the image, represent them as token(s), and during generation they can mask certain tokens and generate new, novel images (?).

[image]

I'm not quite sure if it's applicable to this project. I played around a bit with using the sos tokens, however I got worse results.
I used an encoder layer that takes the text embeddings as cross-attention context to output the dim for the sos tokens.
I also tried using [sos_tokens, text_embedding], encoding it, and returning the output from the encoder [:32].

I was thinking that maybe the issue isn't that the text embeddings are too weak, but that the cross-attention messes things up a bit.
Since the cross-attention to the text embeddings isn't one-to-one, it might be unstable, since the model will learn that one text embedding has hundreds of different 'correct solutions'.
I think it will be more stable if the text is within the token sequence; the sos tokens haven't quite worked out.

@lucidrains
Owner

lucidrains commented Jun 17, 2024

@MarcusLoppe you should def chat with your collaborator (instead of me) since you'll be training the next model together

he will probably be more up-to-date with mesh research too, as he is following it full time
