
Change to use apex for better fp16 and multi-gpu support #116

Merged — 2 commits merged into huggingface:master on Dec 13, 2018

Conversation

@FDecaYed (Contributor) commented Dec 12, 2018

Hi there,

This PR includes changes to improve FP16 and multi-GPU performance. We see more than a 3.5x performance increase on Tesla V100 across all examples.

NVIDIA Apex (https://github.com/NVIDIA/apex) is added as a new dependency. It fixes issues with the existing fp16 implementation (for example, not converting the loss/gradients to float before scaling) and provides a more efficient implementation.

Below are the test results we ran on the MRPC and SQuAD examples. All baselines ("before" numbers) are fp32, since we found that to actually be the best-performing configuration: with the existing fp16 implementation, the optimizer is forced onto the CPU.
The "after" numbers are runs with --fp16 after this PR. All tests were done on a single Tesla V100 16GB.
MRPC on BERT-base:

before: 109 seconds, 9GB memory needed
after: 27 seconds, 5.5GB
speedup: 4x

SQuAD on BERT-base:

before: 90 minutes, 12.5GB
after: 24 minutes, 7.5GB
speedup: 3.75x

SQuAD on BERT-large:

before: 250 minutes, 15GB, with --train_batch_size 24 --gradient_accumulation_steps 6
after: 68 minutes, 14.5GB, with --train_batch_size 24 --gradient_accumulation_steps 3
speedup: 3.68x

The optimize_on_cpu option is also removed entirely from the code, since I can't find any situation where it is faster than gradient_accumulation_steps, assuming of course that at least a batch size of 1 fits into GPU memory.
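For readers who want to see what the apex-based path looks like, here is a minimal, hedged sketch of an fp16 training loop with FusedAdam and FP16_Optimizer. Module paths follow the apex releases current at the time; `build_model` and `train_dataloader` are placeholders, and this is not the exact diff in this PR.

```python
# Minimal sketch of the apex fp16 path described above (not the exact PR diff).
import torch
from apex.optimizers import FusedAdam          # fused CUDA Adam kernel
from apex.fp16_utils import FP16_Optimizer     # fp32 master weights + loss scaling

model = build_model().half().cuda()            # `build_model` is a placeholder for loading BERT

optimizer = FusedAdam(model.parameters(), lr=3e-5)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

for batch in train_dataloader:                 # `train_dataloader` is a placeholder
    loss = model(*batch)                       # BERT models return the loss when labels are given
    optimizer.backward(loss)                   # scales the loss, backprops, unscales master grads
    optimizer.step()
    optimizer.zero_grad()
```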

@thomwolf (Member)

That's really awesome! I love the work you guys did on apex and I would be super happy to have an 'official' implementation of BERT using apex (plus it showcases all the major modules: FusedAdam, FusedLayerNorm, 16-bit training, the distributed optimizer...). And the speed improvement is impressive; fine-tuning BERT-large on SQuAD in 1h is amazing!

Just three general questions:

  1. could you reproduce the numerical results of the examples (SQuAD and MRPC) with this implementation?
  2. did you test distributed training?
  3. the main issue I see right now is the fact that apex is not on PyPI and users have to install it manually. Now that pytorch-pretrained-bert is used as a dependency in downstream libraries like AllenNLP, it's important to keep a smooth install process. Can you guys put apex on PyPI? If not, we should add some logic to handle the case when apex is not installed. That's fine for the examples (run_classifier and run_squad), which are not part of the package per se, but the modifications in modeling.py need to be taken care of.

@FDecaYed (Contributor, Author)

Hi @thomwolf ,

  1. I have been able to reproduce the numerical results of the examples. They show some variance with different random seeds, especially with MRPC, but that should be somewhat expected, and overall the results seem the same as the baseline.
    For example, I got {"exact_match": 84.0491958372753, "f1": 90.94106705651285} running SQuAD BERT-Large with default dynamic loss scaling and seed. I did not store other results since they should be very easy to re-run.
  2. I sanity-checked distributed training results while developing. I'll run more tests and post the results here.
  3. Adding a fallback to modeling.py should be easy since we can use BertLayerNorm there. We just need to make sure it shares the same interface, for example the parameter names, in case a user wants to build parameter groups based on names (see the sketch after this list). As for PyPI, @mcarilli what are your thoughts?
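For context, here is a minimal sketch of the kind of name-based parameter grouping meant above. It assumes the layer norm parameters are exposed as weight/bias and that `model` is the BERT model from the example scripts, so treat it as illustrative rather than the exact repository code.

```python
# Illustrative only: build optimizer parameter groups from parameter names,
# excluding biases and LayerNorm parameters from weight decay.
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]  # assumed naming scheme
param_optimizer = list(model.named_parameters())           # `model` assumed from context
optimizer_grouped_parameters = [
    {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
```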

-Deyu

@FDecaYed (Contributor, Author) commented Dec 12, 2018

update:

  1. I have tested SQuAD BERT-Large with 4 V100s on a DGX Station. Here are the results:
training time: 20:56
speedup over 1 V100: 3.2x
evaluation result: {"exact_match": 83.6329233680227, "f1": 90.68315529756794}

command used:
python3 -m torch.distributed.launch --nproc_per_node=4 ./run_squad.py --bert_model bert-large-uncased --do_train --do_predict --do_lower_case --train_file $SQUAD_DIR/train-v1.1.json --predict_file $SQUAD_DIR/dev-v1.1.json --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir /tmp/debug_squad/ --train_batch_size 6 --fp16

  2. I modified modeling.py so it now falls back to BertLayerNorm when apex is not installed (a sketch of the fallback is shown below).
    The parameters gamma and beta are renamed to weight and bias.
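For reference, the fallback pattern looks roughly like this. It is a sketch, not the exact diff, and it assumes apex's FusedLayerNorm module path; the pure-PyTorch class keeps the weight/bias names so optimizer groups built from parameter names keep working.

```python
# Hedged sketch of the fallback described above (not the exact diff).
import torch
import torch.nn as nn

try:
    # Prefer apex's fused CUDA kernel when apex is installed.
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    class BertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))   # formerly `gamma`
            self.bias = nn.Parameter(torch.zeros(hidden_size))    # formerly `beta`
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias
```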

-Deyu

@thomwolf (Member)

Ok thanks for the update!

It looks good to me. I will do a few tests on various hardware setups, and it'll be included in the new 0.4.0 release coming out today (hopefully).

Congrats on the MLPerf results by the way!

@thomwolf merged commit 91aab2a into huggingface:master on Dec 13, 2018
@thomwolf (Member) commented Dec 13, 2018

@FDecaYed I am trying to reproduce your numbers but I can't get very close. I am using an Azure NDv2 server with 8 NVLINK-interconnected NVIDIA Tesla V100 GPUs and 40 Intel Skylake cores.

Switching to fp16 does indeed lower memory usage by half, but the training time stays about the same (around 100 seconds for run_classifier on 1 GPU, and about 50 minutes for the 2 epochs of your distributed training command on run_squad, with 4 GPUs in that case).

I have the new PyTorch 1.0.0 release and CUDA 10, and I installed apex with the cpp/cuda extensions. I am using the fourth-release branch of the present repo, which was rebased from master with your PR.

If you have any insight I would be interested. Could the difference come from using a DGX versus an Azure server? Can you give me the exact command you used to train the run_classifier example for instance?

@FDecaYed (Contributor, Author) commented Dec 13, 2018

There could be a lot of things; let's sort them out one by one.
The command I used for the MRPC example is
CUDA_VISIBLE_DEVICES=0 python3 ./run_classifier.py --task_name MRPC --do_train --do_eval --do_lower_case --data_dir $GLUE_DIR/MRPC/ --bert_model bert-base-uncased --max_seq_length 128 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir /tmp/mrpc_output/ --fp16
CUDA_VISIBLE_DEVICES is there to make sure only one GPU is used. I noticed the code uses DataParallel when there is only one process but more than 1 GPU in the box, and torch.nn.DataParallel may not provide good speed in some cases. Are you running just one GPU in your 100-second run? I reported the time printed by tqdm's trange; is that the same number you are talking about here?
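For reference, the device-selection logic in the example scripts works roughly like this (a paraphrased sketch; `args` and `model` come from the surrounding script, and the exact names may differ from the repository code):

```python
# Paraphrased sketch of how the examples choose between DataParallel and
# DistributedDataParallel; not the exact repository code.
import torch

if args.local_rank == -1:                          # not launched via torch.distributed.launch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()              # all visible GPUs get used via DataParallel
else:
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    n_gpu = 1
    torch.distributed.init_process_group(backend="nccl")

model.to(device)
if args.local_rank != -1:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
elif n_gpu > 1:
    model = torch.nn.DataParallel(model)           # why CUDA_VISIBLE_DEVICES=0 forces single-GPU
```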

From my past experience with cloud instances, single-GPU numbers should not be that far off from any DGX, unless you are bound by input, and I doubt that's the case given the workload. If we are indeed running and reporting the same thing, there must be some software difference. We are still in the process of moving up to PyTorch 1.0, so my test was on 0.4. I'll merge your release branch and try PyTorch 1.0 on my side on a DGX today.

Meanwhile, this is the container I used for testing. You could try it on Azure and see if you can reproduce my result. Note that it does not have the latest apex installed, so you need to uninstall apex and build the latest version inside.
https://ngc.nvidia.com/catalog/containers/nvidia%2Fpytorch

-Deyu

@thomwolf (Member) commented Dec 13, 2018

Thanks for the quick reply!

The timing I was reporting was the full timing for the training (3 iterations for the MRPC example).
Using your MRPC example command, I get the following from training on a single V100: about 1 min 24 s of training, i.e. around 84 seconds (~27 seconds per iteration).
Using a static loss scale gives the same results.

And training without 16-bit gives a roughly similar total training time: 1 min 31 seconds.

@FDecaYed (Contributor, Author)

I tested on PyTorch 1.0 and am still getting the same speedup.
I used the fourth-release branch and the public Docker Hub 1.0-cuda10.0-cudnn7-devel image here:
https://hub.docker.com/r/pytorch/pytorch/tags/
The only modification I needed was adding encoding='utf-8' when reading the CSV files.
Could you run the same Docker image and see if the speed is still the same? If so, could you do a quick profile with nvprof -o bert-profile.nvvp, training just 1 epoch, and share the output? I don't have access to Azure right now.

@thomwolf (Member)

Ok, I got the 3-4x speed-up using the PyTorch Docker Hub 1.0-cuda10.0-cudnn7-devel image 🔥
Thanks a lot for your help!

I'm still wondering why I can't get these speedups outside of the Docker container, so I will try to investigate that a bit further (in particular since other people may start opening issues here :-).

If you have any further insight, don't hesitate to share :-)

@thomwolf (Member)

Ok, nailed it. I think it was a matter of not installing cuda100 together with pytorch.
Everything seems to work fine now!

@FDecaYed (Contributor, Author) commented Dec 14, 2018

Great! It would also be good if we could later update the README to document the expected speed on V100.

@donglixp (Contributor)

Thanks for the nice work! @FDecaYed @thomwolf

I tried fp16 training for bert-large. It has an imbalanced memory usage problem, which wastes a lot of GPU capacity. The nvidia-smi results are shown below:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.79       Driver Version: 410.79       CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla V100-PCIE...  Off  | 0000A761:00:00.0 Off |                    0 |
| N/A   39C    P0   124W / 250W |  15128MiB / 16130MiB |     99%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE...  Off  | 0000C0BA:00:00.0 Off |                    0 |
| N/A   41C    P0   116W / 250W |  10012MiB / 16130MiB |     95%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-PCIE...  Off  | 0000D481:00:00.0 Off |                    0 |
| N/A   38C    P0    80W / 250W |  10012MiB / 16130MiB |     91%      Default |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-PCIE...  Off  | 0000EC9F:00:00.0 Off |                    0 |
| N/A   40C    P0    61W / 250W |  10012MiB / 16130MiB |     95%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     11870      C   python                                     15117MiB |
|    1     11870      C   python                                     10001MiB |
|    2     11870      C   python                                     10001MiB |
|    3     11870      C   python                                     10001MiB |
+-----------------------------------------------------------------------------+

qwang70 pushed a commit to DRL36/pytorch-pretrained-BERT that referenced this pull request Mar 2, 2019
Change to use apex for better fp16 and multi-gpu support
@Oxi84 commented Jul 22, 2019

Is this already added in pytorch-transformers? If so, how do I use it? Where should I specify the settings for fp16 and apex, and is apex already included when installing pytorch-transformers on Anaconda 3?
