
The batch-sizes of single machine commands are not adjusted #33

Closed
datumbox opened this issue Jan 21, 2022 · 7 comments

Comments

@datumbox
Contributor

In the training doc, I believe we need to adjust the batch size (or the LR) in the single-machine commands so that the total batch size stays the same.

For example, currently the ConvNeXt-S reports:

  • Multi-node: --nodes 4 --ngpus 8 --batch_size 128 --lr 4e-3
  • Single-machine: --nproc_per_node=8 --batch_size 128 --lr 4e-3 <- I believe here it should be --batch_size 512

Same applies for the other variants.
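
For concreteness, here is the arithmetic behind my concern (a quick illustrative check, using the numbers from the commands above):

```python
# Effective batch size = nodes * gpus_per_node * per_gpu_batch_size
multi_node     = 4 * 8 * 128   # 4096
single_machine = 1 * 8 * 128   # 1024 -> 4x smaller than the multi-node run
# To match the multi-node total on one 8-GPU machine: 4096 / 8 = 512 per GPU
```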

@liuzhuang13
Contributor

liuzhuang13 commented Jan 21, 2022

Hi Vasilis, thanks for your question!

We use the --update_freq argument to control the total batch size as well: it is the number of gradient accumulation steps per parameter update. The total effective batch size is num_gpus * bs_per_gpu * update_freq. We do adjust update_freq to 4 in the case you pointed out, and in the others; otherwise a single machine does not have enough GPU memory to train at a 4096 batch size.
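
For readers unfamiliar with the mechanism, here is a minimal illustrative sketch of gradient accumulation as controlled by an update_freq-style knob (a generic PyTorch example, not the actual ConvNeXt training code; `loader` is assumed to exist and yield per-GPU mini-batches):

```python
import torch

update_freq = 4  # e.g. 4 on one 8-GPU machine with --batch_size 128 -> 8 * 128 * 4 = 4096

model = torch.nn.Linear(128, 10)                      # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-3)
loss_fn = torch.nn.CrossEntropyLoss()

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(loader):     # `loader` is assumed to exist
    loss = loss_fn(model(inputs), targets)
    (loss / update_freq).backward()                   # scale so gradients average over the accumulation window
    if (step + 1) % update_freq == 0:
        optimizer.step()                              # one parameter update every update_freq mini-batches
        optimizer.zero_grad()
```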

@datumbox
Contributor Author

datumbox commented Jan 21, 2022

Oh, that makes sense. TorchVision's scripts don't support that specific parameter, and I didn't go through the code to see that yours handles it. Thanks for the clarification!

@liuzhuang13
Contributor

Glad I could help. Since TorchVision's scripts don't support that parameter (update_freq), do you have a workaround for training? Could you train with multiple nodes?

@datumbox
Contributor Author

Yes, we train with multiple nodes like you do, using submitit. If we run into a situation where a model can't fit on a single node, we provide the single-machine equivalent of the command in the documentation but indicate that in practice it was trained with X machines and Y GPUs. Part of the reason I didn't notice the extra parameter is that our training script works similarly, so in this case I assumed you did the same and missed the conversion.

Your update_freq approach sounds very interesting: as I understand it, it lets you reduce the per-GPU batch size for memory reasons without affecting the total batch size, and thus without requiring re-parameterization of the LR, epochs, etc. I wonder if you would be interested in contributing this to TorchVision's script. If you find the time to do it, I'm happy to help with the review.

@liuzhuang13
Contributor

We inherited this gradient accumulation feature from the BEiT codebase. In the coming week we will be busy with some other paper-related work, so I'm not sure I can contribute this on short notice. If it is still relevant or needed after one week, I'm happy to contribute. For me the main thing to work out would be the process; the code part should be simple enough.

BTW, the single-machine conversion you mentioned (using --batch_size 512) would very likely OOM on a typical machine for a rather large model...

@anonymoussss

Hi, I am trying to reproduce the ConvNeXt-Tiny result on ImageNet-1k.

I have a single machine with 8 GPUs, so I set nproc_per_node=8. But I didn't notice the setting you mentioned for "the effective batch size 4096" in the training doc. I just set nproc_per_node=8, batch_size=96, update_freq=4, so my effective batch size is only 8 × 96 × 4 = 3072.

Will it affect my reproduction results?
@liuzhuang13

@liuzhuang13
Contributor

Hi @anonymoussss,

It could affect the reproduction results, as each batch size has a different optimal learning rate. It is common practice to scale the LR in proportion to the batch size, meaning you may use 3e-3 (instead of 4e-3 for 4096) as the learning rate if your effective batch size is 3072.
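
For example, under that linear scaling rule the numbers work out as follows (a quick illustrative calculation):

```python
base_lr, base_bs = 4e-3, 4096   # reference recipe: lr 4e-3 at effective batch size 4096
effective_bs = 8 * 96 * 4       # nproc_per_node * batch_size * update_freq = 3072
scaled_lr = base_lr * effective_bs / base_bs
print(scaled_lr)                # 0.003, i.e. 3e-3
```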
