Improved fine-tuning, ConvNeXt support, improved training speed of GHNs #7

Merged
merged 10 commits into facebookresearch:main on Jul 26, 2022

Conversation


@bknyaz (Contributor) commented on Jul 21, 2022

Training times

The implementation of some steps in the forward pass of GHNs is improved to reduce training time without altering the overall behavior of GHNs.

Speed is measured on an NVIDIA A100-40GB GPU in seconds per training iteration on ImageNet (averaged over the first 50 iterations). 4xA100 are used for meta-batch size (bm) = 8. Measurements can be noisy because other users may be consuming some computational resources on the same cluster node.

| Model | AMP* | Current version | This PR | Estimated total speedup for 300/150 epochs** |
|---|---|---|---|---|
| MLP with bm = 1 | – | 0.30 sec/iter | 0.22 sec/iter | 5.0 days -> 3.7 days |
| MLP with bm = 8 | – | 1.64 sec/iter | 1.01 sec/iter | 13.7 days -> 8.4 days |
| GHN-2 with bm = 1 | – | 0.77 sec/iter | 0.70 sec/iter | 12.9 days -> 11.7 days |
| GHN-2 with bm = 8 | – | 3.80 sec/iter | 3.08 sec/iter | 31.7 days -> 25.7 days |
| GHN-2 with bm = 8 | ✓ | 3.45 sec/iter | 2.80 sec/iter | 28.8 days -> 23.4 days |
  • *Automatic Mixed Precision (enabled by the --amp argument in the code)
  • **To estimate the total training time, 300 epochs is used for bm=1 and 150 epochs is used for bm=8 (according to the paper).
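As the footnote above notes, --amp switches training to PyTorch's automatic mixed precision. The sketch below only illustrates the generic AMP pattern (GradScaler + autocast) that the flag enables; the model, data, and optimizer are toy placeholders and this is not the repository's actual training loop:

```python
import torch
import torch.nn as nn

# Toy stand-ins for a real model and ImageNet batches (placeholders, not the repo's code).
device = "cuda"
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid fp16 gradient underflow

for _ in range(10):
    images = torch.randn(8, 3, 32, 32, device=device)
    targets = torch.randint(0, 10, (8,), device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():        # forward pass runs in mixed precision
        loss = criterion(model(images), targets)
    scaler.scale(loss).backward()          # backprop through the scaled loss
    scaler.step(optimizer)                 # unscales gradients, then takes the optimizer step
    scaler.update()                        # adapts the loss scale for the next iteration
```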

Fine-tuning and ConvNeXt support

Following the report (Pretraining a Neural Network before Knowing Its Architecture), which shows improved fine-tuning results, the following arguments are added to the code: --opt, --init, --imsize, --beta, --layer, together with the file ppuda/utils/init.py containing initialization functions. In addition, the --val argument is added to evaluate on the validation data rather than the test data during training.

  • For example, to reproduce the GHN-orth fine-tuning results for ResNet-50 from the report:
    python experiments/sgd/train_net.py --val --split predefined --arch 0 --epochs 300 -d cifar10 --n_shots 100 --lr 0.01 --wd 0.01 --ckpt ./checkpoints/ghn2_imagenet.pt --opt sgd --init orth --imsize 32 --beta 3e-5 --layer 37

  • For ConvNeXt-Base:
    python experiments/sgd/train_net.py --val --arch convnext_base -b 48 --epochs 300 -d cifar10 --n_shots 100 --lr 0.001 --wd 0.1 --ckpt ./checkpoints/ghn2_imagenet.pt --opt adamw --init orth --imsize 32 --beta 3e-5 --layer 94
    Multiple warnings will be printed that some layers (layer_scale) of ConvNeXt are not supported by GHNs; this is expected.
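The contents of ppuda/utils/init.py are not reproduced in this description; as a rough, hedged illustration of what an orthogonal re-initialization helper (selected with --init orth above) could look like, here is a sketch using PyTorch's built-in initializers. The function name and exact behavior are assumptions, not the PR's actual code:

```python
import torch.nn as nn

def init_model(model: nn.Module, init: str = "orth") -> nn.Module:
    """Hypothetical sketch: re-initialize conv/linear weights of a network.

    Only an illustration of what --init orth might do; the real functions
    live in ppuda/utils/init.py in this PR.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            if init == "orth":
                nn.init.orthogonal_(m.weight)      # (semi-)orthogonal weight matrix/filters
            else:
                nn.init.kaiming_normal_(m.weight)  # fallback initialization
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    return model
```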

A simple example to try parameter prediction for ConvNeXt is to run:

python examples/torch_models.py cifar10 convnext_base
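Roughly, that example predicts the parameters of a torchvision model with a pretrained GHN-2. A minimal sketch following the README-style API of the ppuda package is shown below; the GHN2 class and checkpoint names are assumptions to be verified against the current code:

```python
import torchvision.models as models
from ppuda.ghn.nn import GHN2   # pretrained GHN-2 loader from the ppuda package (assumed API)

ghn = GHN2('cifar10')           # load the GHN-2 checkpoint trained on CIFAR-10
net = models.convnext_base()    # layer_scale parameters will trigger the "unsupported" warnings noted above
net = ghn(net)                  # predict all supported parameters of ConvNeXt-Base in one forward pass
# `net` can now be evaluated or fine-tuned starting from the predicted parameters.
```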

Code correctness

To make sure that the evaluation results (classification accuracies of predicted parameters) reported in the paper remain unchanged with this PR, the GHNs were evaluated on selected architectures and the same results were obtained (see the table below).

| Model | ResNet-50 | ViT | Test architecture (index in the test split) |
|---|---|---|---|
| GHN-2-CIFAR-10 (top-1 acc) | 58.6% | 11.4% | 77.1% (210) |
| GHN-2-ImageNet (top-5 acc) | 5.3% | 4.4% | 48.3% (85) |

To further confirm the correctness of the updated code, the training loss and top-1 accuracy of training GHN-2 on CIFAR-10 for 3 epochs are reported in the table below. The command used in this benchmark is: python experiments/train_ghn.py -m 8 -n -v 50 --ln.

| Version | Epoch 1 | Epoch 2 | Epoch 3 |
|---|---|---|---|
| Current version | loss=2.41, top1=17.23 | loss=2.02, top1=20.62 | loss=1.94, top1=24.56 |
| This PR | loss=2.51, top1=17.58 | loss=2.01, top1=21.62 | loss=1.90, top1=25.88 |

These results can be noisy due to several factors such as random batch sampling and the random initialization of the GHN.

Other

The Python script experiments/train_ghn_stable.py is added to automatically resume training GHNs from the last saved checkpoint (if any) when a run fails for some reason (e.g., OOM or NaN loss).
Instead of running python experiments/train_ghn.py -m 8 -n -v 50 --ln, one can now use
python experiments/train_ghn_stable.py experiments/train_ghn.py -m 8 -n -v 50 --ln.
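train_ghn_stable.py itself is not reproduced here; as a hedged sketch under stated assumptions, the general idea can be expressed as a small wrapper that re-launches the training command until it exits cleanly, relying on the training script to pick up its last saved checkpoint on restart. All names and limits below are illustrative, not the PR's actual implementation:

```python
import subprocess
import sys
import time

# Hypothetical auto-resume wrapper: keep re-running the training command until it
# succeeds; the training script is expected to resume from its last checkpoint.
cmd = [sys.executable] + sys.argv[1:]   # e.g. experiments/train_ghn.py -m 8 -n -v 50 --ln
max_restarts = 10                       # illustrative cap on restarts

for attempt in range(max_restarts):
    ret = subprocess.call(cmd)          # run one training attempt to completion or failure
    if ret == 0:
        break                           # training finished successfully
    print(f"Run failed with exit code {ret} (attempt {attempt + 1}); restarting...")
    time.sleep(30)                      # brief pause before resuming from the checkpoint
```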

@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jul 21, 2022
@bknyaz mentioned this pull request on Jul 21, 2022
@michaldrozdzal merged commit b038f45 into facebookresearch:main on Jul 26, 2022