Instability in training in RNN #10

Closed
bratao opened this issue Oct 17, 2020 · 7 comments

Comments

bratao commented Oct 17, 2020

Hello,

Congratulations on this awesome paper, and thank you for providing the code to test it.
I'm training a small RNN (2 layers of SRU (https://github.com/asappresearch/sru), hidden size 256, with a CRF on top) for the NER task.

Following the README, I disabled gradient clipping and used an epsilon of 1e-12. The task converges well with Ranger, SGD, and Adam, but with AdaBelief the loss randomly explodes.
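Roughly, the optimizer setup looks like this (a minimal sketch, not my actual training script; the adabelief_pytorch import, the stand-in model, and all values except eps are placeholders):

```python
# Minimal sketch of the setup above -- not the real training script.
# Only eps comes from this thread; everything else is a placeholder.
import torch
from adabelief_pytorch import AdaBelief

# Stand-in model; the real network is a 2-layer SRU with a CRF head.
model = torch.nn.LSTM(input_size=128, hidden_size=256, num_layers=2)
optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-12, betas=(0.9, 0.999))

for step in range(10):
    optimizer.zero_grad()
    x = torch.randn(20, 8, 128)              # (seq_len, batch, features) dummy batch
    out, _ = model(x)
    loss = out.pow(2).mean()                  # placeholder loss; the real task uses a CRF NLL
    loss.backward()
    # Gradient clipping intentionally disabled, following the README.
    optimizer.step()
```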

Am I doing something wrong?

accuracy: 0.8366, accuracy3: 0.8366, precision-overall: 0.0040, recall-overall: 0.0163, f1-measure-overall: 0.0065, batch_loss: 7236.0938, loss: 57461.7845 ||: : 30it [09:29, 18.99s/it]                        
accuracy: 0.9254, accuracy3: 0.9255, precision-overall: 0.1612, recall-overall: 0.2104, f1-measure-overall: 0.1825, batch_loss: 51126.7266, loss: 18637.9896 ||: : 30it [08:47, 17.60s/it]                       
accuracy: 0.9645, accuracy3: 0.9645, precision-overall: 0.3207, recall-overall: 0.4666, f1-measure-overall: 0.3801, batch_loss: 11046.6484, loss: 13583.7611 ||: : 30it [08:59, 17.99s/it]                      
accuracy: 0.9828, accuracy3: 0.9829, precision-overall: 0.6505, recall-overall: 0.7602, f1-measure-overall: 0.7011, batch_loss: 8434.5000, loss: 3932.2246 ||: : 29it [08:37, 17.86s/it]                       
accuracy: 0.9856, accuracy3: 0.9856, precision-overall: 0.7832, recall-overall: 0.8383, f1-measure-overall: 0.8098, batch_loss: 122.3125, loss: 3008.3288 ||: : 29it [09:13, 19.09s/it]                        
accuracy: 0.9930, accuracy3: 0.9930, precision-overall: 0.8261, recall-overall: 0.8861, f1-measure-overall: 0.8551, batch_loss: 2115.6699, loss: 1362.0373 ||: : 30it [08:55, 17.84s/it]                       
accuracy: 0.9948, accuracy3: 0.9948, precision-overall: 0.8893, recall-overall: 0.9243, f1-measure-overall: 0.9065, batch_loss: 1569.0469, loss: 1011.7590 ||: : 30it [08:33, 17.10s/it]                       
accuracy: 0.9972, accuracy3: 0.9972, precision-overall: 0.9367, recall-overall: 0.9571, f1-measure-overall: 0.9468, batch_loss: 591.5840, loss: 426.5681 ||: : 29it [08:58, 18.56s/it]                       
accuracy: 0.9977, accuracy3: 0.9977, precision-overall: 0.9514, recall-overall: 0.9660, f1-measure-overall: 0.9587, batch_loss: 23.7188, loss: 279.9471 ||: : 29it [08:32, 17.69s/it]                        
accuracy: 0.9977, accuracy3: 0.9977, precision-overall: 0.9501, recall-overall: 0.9627, f1-measure-overall: 0.9564, batch_loss: 93.2188, loss: 243.8314 ||: : 30it [09:16, 18.54s/it]                        
accuracy: 0.9984, accuracy3: 0.9984, precision-overall: 0.9641, recall-overall: 0.9732, f1-measure-overall: 0.9686, batch_loss: 53.5000, loss: 199.5779 ||: : 29it [08:44, 18.10s/it]                        
accuracy: 0.9984, accuracy3: 0.9984, precision-overall: 0.9702, recall-overall: 0.9789, f1-measure-overall: 0.9745, batch_loss: 52.5781, loss: 156.1823 ||: : 30it [09:14, 18.47s/it]                       
accuracy: 0.9994, accuracy3: 0.9994, precision-overall: 0.9816, recall-overall: 0.9871, f1-measure-overall: 0.9843, batch_loss: 61.4688, loss: 69.1954 ||: : 29it [09:01, 18.66s/it]                        
accuracy: 0.9990, accuracy3: 0.9990, precision-overall: 0.9813, recall-overall: 0.9858, f1-measure-overall: 0.9836, batch_loss: 29.5312, loss: 90.0869 ||: : 29it [08:51, 18.33s/it]                        
accuracy: 0.9996, accuracy3: 0.9996, precision-overall: 0.9846, recall-overall: 0.9896, f1-measure-overall: 0.9871, batch_loss: 74.0625, loss: 53.9213 ||: : 29it [08:40, 17.94s/it]                       
accuracy: 0.9995, accuracy3: 0.9995, precision-overall: 0.9822, recall-overall: 0.9868, f1-measure-overall: 0.9845, batch_loss: 33.9844, loss: 49.5508 ||: : 30it [08:35, 17.19s/it]                       
accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9854, recall-overall: 0.9869, f1-measure-overall: 0.9862, batch_loss: 19.3906, loss: 34.1199 ||: : 30it [09:03, 18.11s/it]                       
accuracy: 0.9995, accuracy3: 0.9995, precision-overall: 0.9938, recall-overall: 0.9950, f1-measure-overall: 0.9944, batch_loss: 709.4336, loss: 48.0945 ||: : 29it [08:38, 17.88s/it]                      
accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9914, recall-overall: 0.9937, f1-measure-overall: 0.9925, batch_loss: 14.9688, loss: 38.2326 ||: : 29it [08:36, 17.79s/it]                       
accuracy: 0.9996, accuracy3: 0.9996, precision-overall: 0.9852, recall-overall: 0.9894, f1-measure-overall: 0.9873, batch_loss: 79.4688, loss: 51.3397 ||: : 29it [08:55, 18.46s/it]                       
accuracy: 0.9998, accuracy3: 0.9998, precision-overall: 0.9926, recall-overall: 0.9936, f1-measure-overall: 0.9931, batch_loss: 39.0625, loss: 22.0619 ||: : 30it [09:00, 18.03s/it]                      
accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9915, recall-overall: 0.9937, f1-measure-overall: 0.9926, batch_loss: 16.9062, loss: 33.6324 ||: : 30it [09:32, 19.07s/it]                       
accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9939, recall-overall: 0.9947, f1-measure-overall: 0.9943, batch_loss: 0.7812, loss: 27.4840 ||: : 30it [09:13, 18.44s/it]                        
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9951, recall-overall: 0.9959, f1-measure-overall: 0.9955, batch_loss: 27.0786, loss: 15.0342 ||: : 29it [09:08, 18.92s/it]                      
accuracy: 0.9996, accuracy3: 0.9996, precision-overall: 0.9938, recall-overall: 0.9963, f1-measure-overall: 0.9951, batch_loss: 7.7500, loss: 25.8246 ||: : 29it [09:00, 18.63s/it]                       
accuracy: 0.9998, accuracy3: 0.9998, precision-overall: 0.9957, recall-overall: 0.9966, f1-measure-overall: 0.9961, batch_loss: 27.6875, loss: 17.3096 ||: : 30it [08:47, 17.58s/it]                      
accuracy: 0.9997, accuracy3: 0.9997, precision-overall: 0.9949, recall-overall: 0.9968, f1-measure-overall: 0.9958, batch_loss: 35.4727, loss: 26.2837 ||: : 29it [08:24, 17.40s/it]                      
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9968, recall-overall: 0.9975, f1-measure-overall: 0.9972, batch_loss: 40.9062, loss: 13.3182 ||: : 30it [09:12, 18.42s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9965, recall-overall: 0.9979, f1-measure-overall: 0.9972, batch_loss: 0.5000, loss: 8.9580 ||: : 29it [08:27, 17.51s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9973, recall-overall: 0.9978, f1-measure-overall: 0.9976, batch_loss: 0.6250, loss: 10.6955 ||: : 29it [08:08, 16.84s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9983, recall-overall: 0.9990, f1-measure-overall: 0.9986, batch_loss: 5.4375, loss: 9.3031 ||: : 30it [08:18, 16.63s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9978, recall-overall: 0.9982, f1-measure-overall: 0.9980, batch_loss: 6.3047, loss: 6.1776 ||: : 29it [08:19, 17.22s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9977, recall-overall: 0.9980, f1-measure-overall: 0.9979, batch_loss: 0.8438, loss: 5.7469 ||: : 29it [08:14, 17.04s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9975, recall-overall: 0.9976, f1-measure-overall: 0.9976, batch_loss: 9.0176, loss: 7.7605 ||: : 30it [08:18, 16.60s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9964, recall-overall: 0.9966, f1-measure-overall: 0.9965, batch_loss: 1.8438, loss: 11.5324 ||: : 30it [08:11, 16.37s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9962, recall-overall: 0.9969, f1-measure-overall: 0.9966, batch_loss: 9.9844, loss: 12.8704 ||: : 29it [08:27, 17.51s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9980, recall-overall: 0.9988, f1-measure-overall: 0.9984, batch_loss: 3.5742, loss: 4.8728 ||: : 30it [08:36, 17.23s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9993, recall-overall: 0.9993, f1-measure-overall: 0.9993, batch_loss: 0.7031, loss: 2.8980 ||: : 30it [08:26, 16.88s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9986, recall-overall: 0.9987, f1-measure-overall: 0.9986, batch_loss: 7.0625, loss: 4.2808 ||: : 30it [08:50, 17.69s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9988, recall-overall: 0.9990, f1-measure-overall: 0.9989, batch_loss: 2.1562, loss: 4.5667 ||: : 30it [08:08, 16.28s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9987, recall-overall: 0.9990, f1-measure-overall: 0.9988, batch_loss: 15.0625, loss: 3.0480 ||: : 30it [08:36, 17.22s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9986, recall-overall: 0.9989, f1-measure-overall: 0.9987, batch_loss: 21.6094, loss: 2.7449 ||: : 30it [08:18, 16.60s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9995, recall-overall: 0.9997, f1-measure-overall: 0.9996, batch_loss: 0.7812, loss: 2.5399 ||: : 29it [08:06, 16.78s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9995, recall-overall: 0.9995, f1-measure-overall: 0.9995, batch_loss: -0.0625, loss: 2.2463 ||: : 29it [08:13, 17.03s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9992, recall-overall: 0.9993, f1-measure-overall: 0.9992, batch_loss: 2.7969, loss: 3.0429 ||: : 30it [08:21, 16.71s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9997, recall-overall: 0.9998, f1-measure-overall: 0.9997, batch_loss: 2.4316, loss: 2.3025 ||: : 30it [08:30, 17.02s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9996, recall-overall: 0.9998, f1-measure-overall: 0.9997, batch_loss: 1.3281, loss: 4.6582 ||: : 29it [08:09, 16.89s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9973, recall-overall: 0.9980, f1-measure-overall: 0.9977, batch_loss: -0.0000, loss: 4.8893 ||: : 30it [08:36, 17.23s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9956, recall-overall: 0.9976, f1-measure-overall: 0.9966, batch_loss: 0.6875, loss: 4.2254 ||: : 30it [08:21, 16.71s/it]
accuracy: 0.9999, accuracy3: 0.9999, precision-overall: 0.9980, recall-overall: 0.9981, f1-measure-overall: 0.9981, batch_loss: 0.0312, loss: 5.8634 ||: : 30it [08:10, 16.34s/it]
accuracy: 0.9984, accuracy3: 0.9984, precision-overall: 0.9787, recall-overall: 0.9515, f1-measure-overall: 0.9649, batch_loss: 22304.5000, loss: 749.8296 ||: : 30it [08:32, 17.08s/it]
accuracy: 0.9570, accuracy3: 0.9570, precision-overall: 0.2782, recall-overall: 0.4189, f1-measure-overall: 0.3343, batch_loss: 731722.4375, loss: 65948.9812 ||: : 30it [08:25, 16.85s/it]
accuracy: 0.9383, accuracy3: 0.9383, precision-overall: 0.1668, recall-overall: 0.2775, f1-measure-overall: 0.2083, batch_loss: 778091.5625, loss: 337316.9677 ||: : 29it [08:08, 16.83s/it]
Epoch    53: reducing learning rate of group 0 to 3.0000e-03.
accuracy: 0.9668, accuracy3: 0.9669, precision-overall: 0.3510, recall-overall: 0.5322, f1-measure-overall: 0.4230, batch_loss: 77123.0000, loss: 253831.3728 ||: : 29it [08:23, 17.36s/it]
accuracy: 0.9767, accuracy3: 0.9767, precision-overall: 0.4897, recall-overall: 0.6151, f1-measure-overall: 0.5453, batch_loss: -1.0000, loss: 137048.0448 ||: : 30it [08:35, 17.19s/it]
accuracy: 0.9839, accuracy3: 0.9839, precision-overall: 0.6340, recall-overall: 0.7326, f1-measure-overall: 0.6798, batch_loss: 43615.0000, loss: 103847.1062 ||:  19%|#8        | 5/27 [01:36<07:03, 19.27s/it]
juntang-zhuang (Owner) commented Oct 17, 2020

Thanks for the feedback. I think you are right: this might be caused by the fact that the update is roughly m_t / sqrt((g_t - m_t)^2), and the denominator is sometimes too small. Even if the denominator is small for just one element, that element will explode. This is an issue I'm trying to fix in the next release, for example by hard-thresholding 1/sqrt((g_t - m_t)^2) at some rather large value.
Please keep this issue open as a reminder of problems to fix for improvement.
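A schematic of what I mean (a sketch only, not the shipped implementation; bias correction is omitted and denom_floor is an arbitrary placeholder for whatever threshold ends up in the release):

```python
import torch

def adabelief_step_sketch(param, grad, state, lr=1e-3, beta1=0.9, beta2=0.999,
                          eps=1e-12, denom_floor=None):
    # m_t: EMA of the gradient; s_t: EMA of (g_t - m_t)^2, the "belief" term.
    state['m'].mul_(beta1).add_(grad, alpha=1 - beta1)
    diff = grad - state['m']
    state['s'].mul_(beta2).addcmul_(diff, diff, value=1 - beta2)
    denom = state['s'].sqrt().add_(eps)
    # Proposed mitigation: floor the denominator (i.e. cap 1/denom) so a single
    # element with g_t very close to m_t cannot take an enormous step.
    if denom_floor is not None:
        denom.clamp_(min=denom_floor)
    param.data.addcdiv_(state['m'], denom, value=-lr)
```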

juntang-zhuang (Owner) commented:

I noticed that the learning rate is quite large; even after the reduction it is still 3e-3. Perhaps a large lr also contributes to the instability.
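For what it's worth, the "Epoch 53: reducing learning rate of group 0 to 3.0000e-03" line in the log above appears to come from ReduceLROnPlateau; a smaller base lr could be tried, e.g. with hypothetical values:

```python
import torch
from adabelief_pytorch import AdaBelief
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(10, 2)                                   # placeholder model
optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-12)    # smaller base lr (hypothetical)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=2)

# after each validation epoch:
# scheduler.step(val_loss)
```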

bratao (Author) commented Oct 17, 2020

Just an update, @juntang-zhuang: your modified Ranger got the best accuracy in half the epochs compared to regular Ranger. It is already my favorite optimizer. Thank you!

juntang-zhuang (Owner) commented:

Wow, excited to hear that, thanks so much for trying it out.

henriyl commented Oct 17, 2020

This could be entirely unrelated to this issue, but at the first step in AdaBelief we have m_t = grad, causing v_t = 0 and step_size = step_size / epsilon_t, which seems like unintended behaviour.

Edit 2:
I deleted my earlier follow-up comments, as they were not exactly correct.
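A back-of-the-envelope illustration of why this matters (not the library code; the numbers are made up): with v_t ≈ 0 for some element and eps = 1e-12, the per-element step becomes enormous.

```python
# If v_t ~ 0 for some element while eps = 1e-12, the effective step explodes.
lr, eps = 1e-3, 1e-12
m_t = 0.5                      # momentum estimate for that element
v_t = 0.0                      # EMA of (g_t - m_t)^2; ~0 when the gradient matches its EMA
step = lr * m_t / (v_t ** 0.5 + eps)
print(step)                    # 5e+08 -- one parameter element moves by half a billion
```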

juntang-zhuang (Owner) commented Oct 17, 2020

Thanks for the comment, @henriyl. The current implementation is definitely not perfect and might suffer from numerical issues; you raise a very good point. This paper is more of a proof of concept, considering the key modification to Adam(W) is only two lines of code, so many details are not yet well handled (these details might not be a big problem for CV tasks, but can be more serious in RNNs with exploding or vanishing gradients). We are working on improvements, both in the implementation and in the theory (personally, I suspect the convergence bound in the paper is too loose). Thanks again for pointing this out.

juntang-zhuang (Owner) commented Oct 19, 2020

@bratao Just an update: I may have confused "gradient clipping" with "gradient thresholding" before; please see the discussion in the README. Gradient clipping may still help, since it shrinks the gradient's magnitude while keeping its direction, although it might require different clip ranges than with Adam. Gradient thresholding, in contrast, is an element-wise operation that clamps each element to a value in a fixed range, with each dimension of the parameter thresholded independently; this can cause a zero denominator.
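In PyTorch terms, my reading of the distinction is roughly the following (example values are arbitrary):

```python
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

param = torch.nn.Parameter(torch.randn(10))
param.grad = torch.randn(10) * 100        # pretend backward() just produced large gradients

# "Gradient clip": rescale the whole gradient if its norm exceeds max_norm,
# preserving the direction of the update.
clip_grad_norm_([param], max_norm=5.0)

# "Gradient threshold": clamp each element into [-clip_value, clip_value]
# independently; as noted above, this can lead to a (near-)zero denominator
# in AdaBelief's (g_t - m_t)^2 term.
clip_grad_value_([param], clip_value=1.0)
```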
