Transformer - token_embed outputs nan values #44

Closed
MarcusLoppe opened this issue Dec 30, 2023 · 11 comments

@MarcusLoppe
Contributor

MarcusLoppe commented Dec 30, 2023

This issue occurs if you have too high a learning rate (1e-2) at a low loss (0.3), though it also occurred when I had 1e-3 as lr at a 0.01 loss.
edit: Using flash attention it goes from 5.0 loss to nan in the 5th epoch using a 1e-4 lr.

After the codes are masked and token_embed is called, it outputs nan values (a quick check of the embedding weights is sketched after the output below).
Not sure if this issue is a pytorch, meshgpt-pytorch or user error :)

codes = codes.masked_fill(codes == self.pad_id, 0)
codes = self.token_embed(codes)
codes  after  masked_fill  torch.Size([2, 912]) tensor([[11965,   608, 11350,  ...,     0,     0,     0],
        [15507, 13398,  5400,  ...,  8247, 13231,  5280]], device='cuda:0') 

codes token_embed after  torch.Size([2, 912, 512]) tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)
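
Since an embedding layer is just a table lookup, nan at valid indices almost always means the weight table itself has already diverged from an earlier bad update rather than anything being wrong with the codes. A minimal sketch to confirm that, assuming token_embed is a plain nn.Embedding (this helper is not part of meshgpt-pytorch):

# minimal sketch, not part of meshgpt-pytorch: check whether the embedding
# weights already contain nan before blaming the input codes
import torch

def check_token_embed(token_embed, codes, pad_id):
    # a nan-free weight table cannot produce nan from a plain lookup,
    # so nan here points at an earlier optimizer step blowing up
    if torch.isnan(token_embed.weight).any():
        print("token_embed.weight already contains nan")
    codes = codes.masked_fill(codes == pad_id, 0)
    out = token_embed(codes)
    print("nan in output:", torch.isnan(out).any().item())
    return out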
@lucidrains
Owner

lucidrains commented Jan 1, 2024

yea, this is just normal transformer instability

there's a bag of tricks for tackling this

@MarcusLoppe
Contributor Author

yea, this is just normal transformer instability

there's a bag of tricks for tackling this

@lucidrains

Shoot, I'm using a dataset of 120 mesh models (1,200 after augmentation). It worked a bit better with a bigger dataset, so it might be due to the 'small' dataset.

lr 1e-4:

Epoch 1/10: 100%|██████████| 600/600 [02:34<00:00,  3.89it/s, loss=8.54]
Epoch 1 average loss: 8.743859918912252
Epoch 2/10: 100%|██████████| 600/600 [02:31<00:00,  3.95it/s, loss=8.15]
Epoch 2 average loss: 8.339149476687114
Epoch 3/10: 100%|██████████| 600/600 [02:31<00:00,  3.96it/s, loss=6.67]
Epoch 3 average loss: 7.025277642409007
Epoch 4/10: 100%|██████████| 600/600 [02:32<00:00,  3.94it/s, loss=5.83]
Epoch 4 average loss: 5.839961892763774           avg loss speed: 2.196133786572349
Epoch 5/10: 100%|██████████| 600/600 [02:31<00:00,  3.95it/s, loss=5.23]
Epoch 5 average loss: 5.08304128130277           avg loss speed: 1.9850883893171947
Epoch 6/10: 100%|██████████| 600/600 [02:32<00:00,  3.94it/s, loss=4.39]
Epoch 6 average loss: 4.479391298294067           avg loss speed: 1.5033689738644487
Epoch 7/10: 100%|██████████| 600/600 [02:23<00:00,  4.19it/s, loss=nan] 
Epoch 7 average loss: nan
Epoch 8/10: 100%|██████████| 600/600 [02:17<00:00,  4.35it/s, loss=nan]
Epoch 8 average loss: nan

@Kurokabe
Contributor

Kurokabe commented Jan 1, 2024

yea, this is just normal transformer instability

there's a bag of tricks for tackling this

Could you give some examples of how to tackle this? I'm also getting NaN after a few epochs (~5) when training on full ShapeNet (~15k different mesh models) with a 1e-4 lr. I'm still investigating, so I'm not sure if it's exactly the same problem as @MarcusLoppe's, but it would be nice to have some ideas on how to solve it :)

@lucidrains
Owner

lucidrains commented Jan 1, 2024

there are no solutions. stabilizing transformers is still an active area of research, esp as you increase parameter count. there are various bandaids however. most practitioners have a couple they apply, but none of them are panaceas yet

@lucidrains
Owner

lucidrains commented Jan 1, 2024

you can check out my x-transformers repo for more info

@MarcusLoppe
Contributor Author

MarcusLoppe commented Jan 1, 2024

you can check out my x-transformers repo for more info

Any particular feature? I'm finding gate_residual, sandwich_norm, ResiDual and scale_residual.
Btw, do you already have, or plan to implement, sliding window attention in x-transformers?

Could you give some examples on how to tackle this? I'm also having NaN after a few epochs (~5 epochs) when training on full ShapeNet (~15k different mesh models) with an 1e-4 lr. I'm still investigating so I'm not sure if it's exactly the same problem as @MarcusLoppe but it could be nice to have some ideas on how to solve this problem :)

I think experimenting with the optimizer would be a good start as well; the easiest parameters are probably max_grad_norm and weight_decay.
I'll do some testing and let you know what I find out.

In the paper they didn't mention any details other than using Adam and a batch size of 64; I believe increasing the batch size might help as well. Due to VRAM constraints I'm only using a batch size of 1 or 2. A generic sketch of these optimizer knobs follows.
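
A rough plain-PyTorch sketch of those two knobs (this is not the meshgpt-pytorch trainer API; transformer is assumed to return a scalar loss from its forward pass):

import torch
from torch.nn.utils import clip_grad_norm_

# weight_decay via AdamW, max_grad_norm via clip_grad_norm_
optimizer = torch.optim.AdamW(transformer.parameters(), lr=1e-4, weight_decay=1e-2)

def train_step(batch):
    loss = transformer(batch)            # assumed: forward returns a scalar loss
    loss.backward()
    clip_grad_norm_(transformer.parameters(), max_norm=1.0)   # the max_grad_norm knob
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()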

@lucidrains
Owner

lucidrains commented Jan 2, 2024

@MarcusLoppe you could try qk norm. some researchers at google brain are attached to this, but i suspect it has a slight generalization cost

yea, you are right about the optimizer. values to play with are beta1, beta2, and eps. your batch size def needs to be bigger once you scale up, but you can use gradient accumulation for this (which is built-in)
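
For reference, a rough sketch of what playing with those values looks like in plain PyTorch (example numbers only; the built-in trainer handles gradient accumulation for you):

import torch

optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr = 1e-4,
    betas = (0.9, 0.99),   # beta1 / beta2, example values
    eps = 1e-8,
)

accum_steps = 32           # effective batch size = accum_steps * per-step batch size

for i, batch in enumerate(dataloader):
    loss = transformer(batch) / accum_steps   # assumed: forward returns a scalar loss
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()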

@lucidrains
Owner

other things that would help are warmup, and gradient clipping of 0.5 (or 0.25 if you want to be really aggressive)
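
A minimal sketch of linear warmup plus that clipping value in plain PyTorch (the trainer or your own loop may already cover warmup; optimizer and transformer are assumed from the earlier sketches):

import torch
from torch.nn.utils import clip_grad_norm_

warmup_steps = 1000
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda = lambda step: min(1.0, (step + 1) / warmup_steps),   # ramp linearly up to the base lr
)

def train_step(batch):
    loss = transformer(batch)
    loss.backward()
    clip_grad_norm_(transformer.parameters(), 0.5)   # use 0.25 to be really aggressive
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    return loss.item()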

@lucidrains
Owner

@MarcusLoppe scratch everything i said, as Kurokabe noted that a potential source of instability was actually due to the gateloop layers

@MarcusLoppe
Contributor Author

@MarcusLoppe scratch everything i said, as Kurokabe noted that a potential source of instability was actually due to the gateloop layers

I still get nan loss at around 0.07 loss using 1e-4 as learning rate, but above that it doesn't give any issues anymore.
I'll try to replicate it and use detect_anomaly to see what happens; a minimal usage sketch is below.
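
detect_anomaly here refers to PyTorch's anomaly detection, used roughly like this (it slows training down noticeably, so only enable it while debugging):

import torch

torch.autograd.set_detect_anomaly(True)   # backward will raise at the first op producing nan/inf

loss = transformer(batch)                 # assumed: forward returns a scalar loss
loss.backward()                           # the RuntimeError names the offending backward op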

@MarcusLoppe
Contributor Author

Resolved by using a larger dataset; possible explanation: #68 (comment)
