Describe the bug
The loss becomes NaN after 1 epoch when fine-tuning a pretrained resmlp_12_224.fb_in1k with mixed precision on CIFAR-10.
To Reproduce
I am using timm version 0.8.18.dev0.
Steps to reproduce the behavior:
- Run the command:
  python train.py data/ --dataset torch/cifar10 --model resmlp_12_224.fb_in1k --pretrained --amp
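For context, the `--amp` flag turns on native PyTorch mixed-precision training. Below is a minimal, self-contained sketch of that mechanism (autocast plus GradScaler), assuming the standard torch.cuda.amp API available in torch 1.12; the model, optimizer, and dummy data here are placeholders and this is not timm's actual train.py loop.

```python
import torch
import torch.nn as nn

device = "cuda"
# Placeholder model/optimizer/loss, only to illustrate the AMP path.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    inputs = torch.randn(8, 3, 224, 224, device=device)   # dummy batch
    targets = torch.randint(0, 10, (8,), device=device)
    optimizer.zero_grad()
    # Forward pass and loss are computed in reduced precision where autocast deems it safe.
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    # The scaler multiplies the loss to avoid fp16 gradient underflow,
    # unscales before the optimizer step, and skips the step on inf/nan gradients.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```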
Expected behavior
The model can be fine-tuned normally without the loss becoming NaN.
Screenshots
Learning rate (0.05) calculated from base learning rate (0.1) and global batch size (128) with linear scaling.
Using native Torch AMP. Training in mixed precision.
Scheduled epochs: 300. LR stepped per epoch.
Train: 0 [ 0/390 ( 0%)] Loss: 8.364 (8.36) Time: 1.735s, 73.77/s (1.735s, 73.77/s) LR: 1.000e-05 Data: 0.683 (0.683)
Train: 0 [ 50/390 ( 13%)] Loss: 3.141 (5.03) Time: 0.684s, 187.02/s (0.694s, 184.33/s) LR: 1.000e-05 Data: 0.009 (0.021)
Train: 0 [ 100/390 ( 26%)] Loss: 2.805 (3.97) Time: 0.643s, 199.16/s (0.684s, 187.03/s) LR: 1.000e-05 Data: 0.008 (0.015)
Train: 0 [ 150/390 ( 39%)] Loss: 2.700 (3.54) Time: 0.585s, 218.96/s (0.663s, 193.18/s) LR: 1.000e-05 Data: 0.008 (0.013)
Train: 0 [ 200/390 ( 51%)] Loss: 2.415 (3.29) Time: 0.582s, 219.86/s (0.643s, 199.16/s) LR: 1.000e-05 Data: 0.008 (0.012)
Train: 0 [ 250/390 ( 64%)] Loss: 2.560 (3.11) Time: 0.586s, 218.30/s (0.631s, 202.90/s) LR: 1.000e-05 Data: 0.008 (0.011)
Train: 0 [ 300/390 ( 77%)] Loss: 2.296 (2.98) Time: 0.584s, 219.01/s (0.623s, 205.42/s) LR: 1.000e-05 Data: 0.008 (0.010)
Train: 0 [ 350/390 ( 90%)] Loss: 2.184 (2.88) Time: 0.583s, 219.72/s (0.617s, 207.35/s) LR: 1.000e-05 Data: 0.008 (0.010)
Train: 0 [ 389/390 (100%)] Loss: 2.217 (2.81) Time: 0.571s, 224.31/s (0.614s, 208.57/s) LR: 1.000e-05 Data: 0.000 (0.010)
Test: [ 0/78] Time: 0.612 (0.612) Loss: 1.0401 (1.0401) Acc@1: 73.4375 (73.4375) Acc@5: 99.2188 (99.2188)
Test: [ 50/78] Time: 0.199 (0.204) Loss: 1.0240 (1.0031) Acc@1: 71.8750 (70.3891) Acc@5: 98.4375 (98.7439)
Test: [ 78/78] Time: 0.103 (0.200) Loss: 0.8205 (1.0043) Acc@1: 75.0000 (70.1100) Acc@5: 100.0000 (98.6800)
Current checkpoints:
('./output/train/20230412-030416-resmlp_12_224_fb_in1k-224/checkpoint-0.pth.tar', 70.11)
Train: 1 [ 0/390 ( 0%)] Loss: 2.160 (2.16) Time: 1.105s, 115.80/s (1.105s, 115.80/s) LR: 1.001e-02 Data: 0.528 (0.528)
Train: 1 [ 50/390 ( 13%)] Loss: nan (nan) Time: 0.578s, 221.63/s (0.586s, 218.58/s) LR: 1.001e-02 Data: 0.008 (0.018)
Train: 1 [ 100/390 ( 26%)] Loss: nan (nan) Time: 0.575s, 222.61/s (0.581s, 220.44/s) LR: 1.001e-02 Data: 0.009 (0.014)
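Note that the NaN first appears in epoch 1, right after the learning rate steps up from the warmup value of 1.000e-05 to 1.001e-02. As a purely illustrative aid (not part of timm's train.py; the names below are placeholders), a check like the following could be dropped into a training step to see whether the loss or specific parameter gradients are the first to become non-finite:

```python
import torch

def report_non_finite(model, loss):
    # Flag a non-finite loss value for the current step.
    if not torch.isfinite(loss):
        print("non-finite loss detected:", loss.item())
    # Flag any parameter whose gradient contains inf/nan.
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            print("non-finite gradient in:", name)

# Usage (inside the training loop, after backward()):
# report_non_finite(model, loss)
```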
Desktop (please complete the following information):
- OS: Ubuntu 18.04
- This repository version: 0.8.18.dev0
- PyTorch version w/ CUDA/cuDNN: torch 1.12.1, CUDA 10.2