
I need some help to reproduce DeiT-III finetuning result #167

Closed
bhheo opened this issue Jun 2, 2022 · 23 comments

@bhheo
Contributor

bhheo commented Jun 2, 2022

Hi

Thank you for sharing the fine-tuning code & training logs.
On IN-1k pretraining, I got results similar to your log: ViT-S 81.43 and ViT-B 82.88.
But I failed to reproduce the fine-tuning performance even with your official fine-tuning settings, so I would like to ask for advice or help.

Here is my fine-tune result with ViT-B on IN-1k.
[fine-tuning results screenshot]

I expected the performance to increase as in your fine-tuning log, but instead fine-tuning degrades the performance.
I can't use submitit, so I used the following command on a single node with 8 A100 GPUs:

OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=${num_gpus_per_node} --nnodes=${WORLD_SIZE} --node_rank=${RANK}  --master_addr=${MASTER_ADDR}  --master_port=${MASTER_PORT} --use_env main.py \
    --model deit_base_patch16_LS \
    --data-path ${local_data_path} \
    --finetune ${SAVE_BASE_PATH}/pretraining/checkpoint-${epoch}.pth \
    --output_dir ${SAVE_BASE_PATH}/finetune4 \
    --batch-size 64 \
    --print_freq 400 \
    --epochs 20 \
    --smoothing 0.1 \
    --reprob 0.0 \
    --opt adamw \
    --lr 1e-5 \
    --weight-decay 0.1 \
    --input-size 224 \
    --drop 0.0 \
    --drop-path 0.2 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --unscale-lr \
    --no-repeated-aug \
    --aa rand-m9-mstd0.5-inc1

and the full args printed on the command line:

Namespace(ThreeAugment=False, aa='rand-m9-mstd0.5-inc1', attn_only=False, auto_resume=True, batch_size=64, bce_loss=False, clip_grad=None, color_jitter=0.3, cooldown_epochs=10, cutmix=1.0, cutmix_minmax=None, data_path='/mnt/ddn/datasets/ILSVRC2015/train/Data/CLS-LOC', data_set='IMNET', decay_epochs=30, decay_rate=0.1, device='cuda', dist_backend='nccl', dist_eval=False, dist_url='env://', distillation_alpha=0.5, distillation_tau=1.0, distillation_type='none', distributed=True, drop=0.0, drop_path=0.2, epochs=20, eval=False, finetune='/mnt/backbone-nfs/bhheo/checkpoints/deit_codebase_deit_base_patch16_LS_800epoch_reproduce/pretraining/checkpoint-800.pth', gpu=0, inat_category='name', input_size=224, log_dir='nsmlv2', log_name='finetune', lr=1e-05, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, min_lr=1e-05, mixup=0.8, mixup_mode='batch', mixup_prob=1.0, mixup_switch_prob=0.5, model='deit_base_patch16_LS', model_ema=True, model_ema_decay=0.99996, model_ema_force_cpu=False, momentum=0.9, num_workers=10, opt='adamw', opt_betas=None, opt_eps=1e-08, output_dir='/mnt/backbone-nfs/bhheo/checkpoints/deit_codebase_deit_base_patch16_LS_800epoch_reproduce/finetune4', patience_epochs=10, pin_mem=True, print_freq=400, rank=0, recount=1, remode='pixel', repeated_aug=False, reprob=0.0, resplit=False, resume='', save_periods=['last2'], sched='cosine', seed=0, smoothing=0.1, src=False, start_epoch=0, teacher_model='regnety_160', teacher_path='', train_interpolation='bicubic', unscale_lr=True, warmup_epochs=5, warmup_lr=1e-06, weight_decay=0.1, world_size=8)

I think it is the same as your fine-tuning setting.
I double-checked my code, but I still don't know why the result is so different.

I'm using different library versions: torch 1.11.0a0+b6df043, torchvision 0.11.0a0, timm 0.5.4.
This might cause some problems, but there was no problem in pretraining, and the performance difference is too severe for a simple library-version issue.

I'm sorry to keep bothering you, but could you please let me know if there is something wrong with my settings?
Or could you please share the ViT-B weights pretrained on IN-1k at 192x192 resolution, without the 224x224 fine-tuning?
If you share the weights before fine-tuning, I can verify my fine-tuning code without doubting my pretraining.

@TouvronHugo
Contributor

Hi @bhheo,
Thanks for your message.
At first sight I don't see any error in your fine-tuning.
I have uploaded the 192 model to help you:
DeiT_B_192
Keep me informed.
Best,
Hugo

@bhheo
Contributor Author

bhheo commented Jun 3, 2022

Thank you for sharing the 192 model.
I will let you know if I find an error in my code.

@TouvronHugo TouvronHugo self-assigned this Jun 7, 2022
@bhheo
Contributor Author

bhheo commented Jun 13, 2022

Hi, @TouvronHugo

I still haven't succeeded in reproducing the fine-tuning results, but I want to share my progress.

I have tried fine-tuning from your 192 model weights.
The result is similar to the previous one, so I think my pretraining is not the cause of the problem.
[fine-tuning results screenshot]

I found that a test crop ratio of 1.0 is not implemented in your code,
so I tried fine-tuning with test crop ratio 1.0.
Performance improved, but it is still far from the target.
[results screenshot]
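For reference, a minimal sketch of building an eval transform with crop ratio 1.0 via timm's create_transform (illustrative only; the actual change and option name may differ):

    from timm.data import create_transform

    # crop_pct=1.0 makes the resize size equal to the crop size, so the center
    # crop no longer discards the extra border implied by the default 0.875 ratio
    eval_transform = create_transform(
        input_size=224,
        is_training=False,
        interpolation='bicubic',
        crop_pct=1.0,
    )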

Next, I will try downgrading my library versions (torch and timm).
I will inform you when I have a result.

Regards

@bhheo
Contributor Author

bhheo commented Jun 14, 2022

I got similar results with torch==1.7.1 and timm==0.3.4.
I also tried fine-tuning on a V100 machine, but it was not effective.

[results screenshot]

I don't know what I should do next to reproduce the result.
Could you give me a full log for fine-tuning ViT-B?
I think comparing train_loss or print(args) might help me find errors in my code.

Best

@TouvronHugo
Contributor

This is quite strange, as the most complex procedure is clearly pre-training, not fine-tuning.
Here is the full log with train loss and memory consumption: log_vit_b.txt
Best,

@TouvronHugo
Contributor

What is interesting in your logs is that from epoch 0 it already looks a bit worse than in our logs. But after epoch 0 the model weights normally don't change much, since the lr is very low. Maybe there is a problem with the interpolation of the position encoding or with the loading of the weights.
Are you using the DeiT repo code for these steps? If so, I can check on my side whether there is a difference between the DeiT repo and my internal code base. Normally there isn't, but we never know. ;)
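For reference, a minimal sketch of the usual position-embedding interpolation when going from 192x192 to 224x224 with ViT-B/16 (a 12x12 patch grid resized to a 14x14 grid), following the pattern used in the DeiT code base; names here are illustrative:

    import torch
    import torch.nn.functional as F

    def interpolate_pos_embed(pos_embed_ckpt, new_num_patches, num_extra_tokens):
        # pos_embed_ckpt: [1, num_extra_tokens + old_grid**2, dim] from the checkpoint;
        # num_extra_tokens is 1 if the class token has its own position entry, else 0
        dim = pos_embed_ckpt.shape[-1]
        old_grid = int((pos_embed_ckpt.shape[-2] - num_extra_tokens) ** 0.5)
        new_grid = int(new_num_patches ** 0.5)
        extra = pos_embed_ckpt[:, :num_extra_tokens]
        patch = pos_embed_ckpt[:, num_extra_tokens:]
        patch = patch.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
        patch = F.interpolate(patch, size=(new_grid, new_grid),
                              mode='bicubic', align_corners=False)
        patch = patch.permute(0, 2, 3, 1).flatten(1, 2)
        return torch.cat((extra, patch), dim=1)   # [1, num_extra_tokens + new_grid**2, dim]

If this step is skipped or the extra tokens are mixed into the interpolation, the weights still load, but the model starts from a worse point, which could explain a degraded epoch 0.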

@bhheo
Contributor Author

bhheo commented Jun 14, 2022

Thank you for your kindness.

Yes, I'm using the DeiT repo at a2ffd162 with minor changes, such as a logger and a bugfix.
Checking the differences against your internal code might be helpful.

I added a tab here, because otherwise it raises an undefined linear_scaled_lr error when unscale_lr is True.

deit/main.py, line 341 in 9bfdc73:

    args.lr = linear_scaled_lr
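With the tab restored, the block reads roughly as follows (a sketch; the 512-image reference batch follows the original DeiT convention):

    # main.py: only scale the lr when --unscale-lr is NOT set, so that
    # linear_scaled_lr is never referenced while undefined
    if not args.unscale_lr:
        linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
        args.lr = linear_scaled_lr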

Apart from this, I don't think my changes affect the training process.

I observed that my train loss is much lower than in your logs,
so my fine-tuning might be missing some regularization or trick. I will also double-check my settings.

@TouvronHugo
Contributor

I'll check that.

Yes, you are right, a tab was missing. I fixed that ;)

@TouvronHugo
Contributor

Hi @bhheo,

Did you solve your fine-tuning issue?

I haven't had time to compare my internal code and the public repo yet, but I should have some time in the coming months.

Best,

Hugo

@bhheo
Contributor Author

bhheo commented Jul 9, 2022

Hi @TouvronHugo

Unfortunately, the fine-tuning issue is not solved yet.
I have been busy with other issues.
But I launched a few different fine-tuning settings on our servers yesterday.
I think they might help figure out what the problem with my fine-tuning is.
I will let you know when the experiments are done.

Best
Byeongho Heo

@Yuxin-CV

Yuxin-CV commented Jul 9, 2022

Hi @TouvronHugo and @bhheo, I also failed to reproduce the fine-tuning performance using the officially released code & pre-trained weights.

Specifically, under the same setting & configuration as @bhheo (I use the pre-trained weights @ 192 px from #167 (comment) and set crop-pct to 1.0 during inference), my best fine-tuning accuracy is 83.47 (reached within the first 5 epochs), which is similar to @bhheo's.

Since the officially released code has crop-pct=0.875, I guess there must be some other differences between the internal code and the public DeiT III repo.

Best
Yuxin

@TouvronHugo
Contributor

TouvronHugo commented Jul 11, 2022

Hi @bhheo and @Yuxin-CV,
I will try to look at this by September.
Just for your information, I have launched a multi-seed experiment with my internal codebase.
I get 83.81% ± 0.01%, so the gap is not related to the data-aug seed.
Best,
Hugo

@bhheo
Contributor Author

bhheo commented Jul 18, 2022

Hi

I haven't reached 83.8% accuracy yet.
But I want to share my trials. I measured accuracy at the last epoch.

  • my fine-tuning: 83.06%
  • AdamW -> lamb: 83.39%
  • RA -> no RA: 82.44%
  • RA -> ThreeAugment: 82.68%
  • CE -> BCE (no smoothing): 83.05%
  • mixup 0.8 -> mixup 0.0: 83.07%
  • AdamW -> lamb, mixup 0.8 -> mixup 0.0: 83.51%
  • attn-only: 83.24%
  • wd 0.1 -> wd 0.2: 83.15%

The lamb optimizer improves the performance.
But I don't think it is the correct setting, because it has much lower accuracy at epoch 0.

Fine-tuning costs only 20 epochs, so I can test diverse settings.
I will keep searching.
Please let me know if there is any setting you find suspicious.

Best
Byeongho Heo

@tangjiasheng

tangjiasheng commented Jul 26, 2022

Hi,

I also have trouble reproducing some of the results. For example, I tried to reproduce deit3_huge with and without ImageNet-21k pretraining.
For 1k training from scratch, I ran:

python -m torch.distributed.launch \
           --nproc_per_node=8 train.py \
           --model deit_huge_patch14_LS \
           --data-path ${imagenet_dir} \
           --output_dir ${output_dir} \
           --nb-classes 1000 \
           --batch 256 \
           --lr 3e-3 --epochs 800 \
           --weight-decay 0.05 --sched cosine \
           --input-size 160 \
           --reprob 0.0 \
           --smoothing 0.0 --warmup-epochs 5 --drop 0.0 \
           --seed 0 --opt lamb \
           --warmup-lr 1e-6 --mixup .8 --drop-path 0.6 --cutmix 1.0 \
           --unscale-lr --repeated-aug \
           --bce-loss  \
           --color-jitter 0.3 --ThreeAugment

But at epoch 55 I got Max accuracy: 43.85%, which is quite a bit lower than the 50.34 at the same epoch in the log from your README.
For 21k pretraining, when I fine-tuned my trained 21k model, the best accuracy I can reach is 83.xx%.
It should be pointed out that my version of ImageNet-21k is winter21, not fall11, which most related works were done with; fall11 is no longer provided on the official ImageNet website. Also, I'm not sure of the correct way to do a train-test split with 21k, so I use all of 21k. But I don't think these can explain such a big gap in training.

Best,
Jiasheng

@Yuxin-CV

Hi @tangjiasheng & @TouvronHugo, I guess there is something wrong with the input size for the ViT-H model.
The ViT-H model has patch_size=14 instead of 16, so --input-size shouldn't be 160 (160 / 14 ≈ 11.43, not an integer).
I think you can try --input-size=126 or --input-size=154, based on Tab. 6 of the paper.
[Tab. 6 from the paper]
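A quick check of which resolutions are multiples of the 14-pixel patch (illustrative only):

    patch = 14
    for size in (126, 154, 160, 224):
        n, rem = divmod(size, patch)
        print(size, '->', n, 'patches per side, remainder', rem)
    # 126 -> 9, 154 -> 11, 224 -> 16; 160 leaves a 6-pixel border unused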
Best
Yuxin

@TouvronHugo
Contributor

Hi @Yuxin-CV and @bhheo,
Did you try evaluating the 192x192 model at resolution 224x224 without fine-tuning?
(Just to get some additional insight before looking at the code in more detail.)

@TouvronHugo
Contributor

TouvronHugo commented Jul 28, 2022

Hi @tangjiasheng,
We use the ImageNet-1k val set as the validation set during 21k training. I think this explains the difference in the logs.
In your case, you are using all of 21k as val, right?
Best,

Hugo

@TouvronHugo
Contributor

Hi @Yuxin-CV,

Yes, good catch on the resolutions. With ViT-H/14 at resolutions 128 and 160 the code still works; it only removes a few pixels from the border of the image, which does not have a significant effect. But it's cleaner to use 126 and 154 ;)

Best,

Hugo

@TouvronHugo
Contributor

TouvronHugo commented Jul 28, 2022

Hi @Yuxin-CV and @bhheo,
I haven't had time to test it, but I just saw an error in the code:

deit/main.py, line 419 in cb1f48a:

    set_training_mode=args.finetune == '',  # keep in eval mode during finetuning
The model must be in training mode for the fine-tuning of DeiT III. So try replacing set_training_mode=args.finetune == '' with set_training_mode=True; if you have the time, don't hesitate to test this ;)
(Without training mode, drop-path is not activated.)
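For reference, a minimal sketch of the change in the train_one_epoch call (the surrounding arguments are abbreviated and assumed from the repo's engine.py; treat this as a sketch rather than the exact diff):

    train_stats = train_one_epoch(
        model, criterion, data_loader_train,
        optimizer, device, epoch, loss_scaler,
        args.clip_grad, model_ema, mixup_fn,
        set_training_mode=True,  # was: args.finetune == '', which kept eval mode and disabled drop-path
        args=args,
    )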
Best,
Hugo
(If this doesn't solve the problem I will look into it further as promised by early September ;) )

@bhheo
Contributor Author

bhheo commented Jul 28, 2022

@TouvronHugo
Oh, that looks critical.
I will test it ASAP.

Best
Heo

@tangjiasheng

I use the same setting: I validate the model on the 1k val set by mapping the data with class_map inside timm's create_dataset function. Do you use a different strategy?

Best,
Jiasheng

@bhheo
Contributor Author

bhheo commented Jul 29, 2022

Hi

I got the result, and it is almost the same as the official log.
set_training_mode=True solves the fine-tuning problem.
[results screenshot matching the official log]

Thank you for your advice @TouvronHugo

Best
Heo

@TouvronHugo
Contributor


Great! I just fixed that in the code by adding a train_mode argument, which is True by default.
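A minimal sketch of what such a flag can look like (hypothetical argparse wiring; the actual commit may differ):

    # main.py: expose the training mode as a flag that defaults to True
    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    # ...and in the training loop:
    #     set_training_mode=args.train_mode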

@bhheo bhheo closed this as completed Aug 2, 2022