
Performance --- the training time of the Transformer model is too long in mixed-precision mode #203

Closed
dingsiyu opened this issue Aug 2, 2018 · 11 comments


@dingsiyu commented Aug 2, 2018

When I train the Transformer in mixed-precision mode, the time per step is very long:
(1) transformer_big.py: two GPUs (V100), all parameters at their defaults, time per step: 13 s;
(2) tensor2tensor (big model): two GPUs (V100), all parameters at their defaults, time per step: 0.3 s.
According to the mixed-precision training paper (https://arxiv.org/pdf/1710.03740.pdf), training should be faster than in FP32 mode. So why is it slower here?

@okuchaiev (Member)

(1) Are you using a Volta GPU? (Only Volta has the Tensor Cores used in mixed precision.)
(2) Which versions of CUDA and TensorFlow do you have?
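
You can check the TF side quickly from Python (a minimal sketch for TF 1.x; the CUDA toolkit version itself can be checked with nvcc --version):

import tensorflow as tf

print(tf.__version__)                # TensorFlow version
print(tf.test.is_built_with_cuda())  # True if this build was compiled against CUDA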

@Zrachel commented Aug 3, 2018

(1) Yes, we use V100s; all the numbers above were measured on V100.
(2) CUDA 9.0; we tried both TensorFlow 1.8 and 1.9, with no difference.

When we use the default Transformer base model (https://github.com/tensorflow/tensor2tensor), we get the following speed:
4.3 global_steps/s with 4 V100 GPUs.

With OpenSeq2Seq (default configuration) on the same base model, we get:
13 s per step with 2 V100 GPUs in mixed mode (GPU utilization: 1%),
0.33 s per step with 2 V100 GPUs in fp32 mode (GPU utilization: 90%), and
19 s per step with 2 V100 GPUs in fp16 mode (GPU utilization: 1%).

@okuchaiev (Member)

GPU utilization of 1% means that most of the work falls on the CPU. This happens because public TensorFlow built against CUDA 9.0 does not have float16 batched GEMM integrated.
This is why we require CUDA 9.1 (see https://nvidia.github.io/OpenSeq2Seq/html/mixed-precision.html) and a TF built with this PR included.

I would recommend just using NVIDIA's TensorFlow container (18.07-py3), which you can get for free here: https://ngc.nvidia.com/registry/nvidia-tensorflow . It contains cuBLAS, CUDA, cuDNN, and a TF version tested to work nicely together, plus occasionally some GPU improvements that aren't in upstream TF yet. This way you don't need to worry about details like the above.
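
Once inside the container (or in any TF build), a quick sanity check with the TF 1.x device_lib API confirms that the GPUs are actually visible:

from tensorflow.python.client import device_lib

# Each V100 should show up as a /device:GPU:N entry.
for d in device_lib.list_local_devices():
    print(d.device_type, d.name)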

@okuchaiev (Member)

@Zrachel and @dingsiyu, were you able to get speedups using mixed precision?

@dingsiyu (Author) commented Aug 8, 2018

I was not able to get speedups using mixed precision. After upgrading CUDA from 9.0 to 9.2, we ran into another problem:

System information:
(1) OS platform and distribution: CentOS 6.3
(2) TensorFlow installed from: conda
(3) TensorFlow version: v1.8.0 (it already contains TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);)
(4) Python version: 3.6
(5) CUDA version: 9.2
(6) cuDNN version: 7.1.4
(7) GPU model: V100

Exact command to reproduce:

import tensorflow as tf

# Pin the fp16 batched matmul to the GPU; allow_soft_placement=False makes
# TF raise an error instead of silently falling back to the CPU.
with tf.device("/gpu:0"):
    a = tf.random_normal(dtype=tf.float16, shape=[5, 2, 3], name='a')
    b = tf.random_normal(dtype=tf.float16, shape=[5, 3, 2], name='b')
    c = tf.matmul(a, b)

sess = tf.Session(config=tf.ConfigProto(log_device_placement=True,
                                        allow_soft_placement=False))
print(sess.run(c).shape)
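
On a build that registers the fp16 GPU kernel for BatchMatMul, log_device_placement should show MatMul placed on /device:GPU:0; on this build it instead fails with the error below.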

Describe the problem:
fp16 matrix multiplication runs on the CPU, so we cannot speed up mixed-precision training.

Logs:
InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'MatMul_1': Could not satisfy explicit device specification '/device:GPU:0' because no supported kernel for GPU devices is available.
Registered kernels:
device='GPU'; T in [DT_DOUBLE]
device='GPU'; T in [DT_FLOAT]
device='GPU'; T in [DT_COMPLEX128]
device='GPU'; T in [DT_COMPLEX64]
device='CPU'; T in [DT_INT32]
device='CPU'; T in [DT_HALF]
device='CPU'; T in [DT_DOUBLE]
device='CPU'; T in [DT_FLOAT]
device='CPU'; T in [DT_COMPLEX128]
device='CPU'; T in [DT_COMPLEX64]

[[Node: MatMul_1 = BatchMatMul[T=DT_HALF, adj_x=false, adj_y=false, _device="/device:GPU:0"](a_1, b_1)]]

So, what can we do?

@okuchaiev (Member)

Can you please try NVIDIA's TensorFlow container (18.07-py3)? You can get it here for free: https://ngc.nvidia.com/registry/nvidia-tensorflow

I am not sure why upstream TF still doesn't have batched gemm in fp16 ...
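
Until that lands, one stopgap (my sketch, not something OpenSeq2Seq itself does) is to run the batched GEMM in float32 and cast back; placement is then correct, but you give up the fp16 Tensor Core speedup:

import tensorflow as tf

def batch_matmul_fp16_safe(a, b):
    # Hypothetical workaround: compute the batched matmul in float32 on the
    # GPU (which has a registered kernel), then cast the result back to fp16.
    c = tf.matmul(tf.cast(a, tf.float32), tf.cast(b, tf.float32))
    return tf.cast(c, tf.float16)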

@Zrachel commented Aug 9, 2018

Thank you @okuchaiev. We cannot access that website. Is there another way (like Google Drive) to get this container?

@okuchaiev (Member)

The website seems up and running for me: https://ngc.nvidia.com . NVIDIA's TF containers are available only from there; it requires registration, but it is quick and free.

@dingsiyu (Author) commented Aug 10, 2018

We have fixed the problem of FP16 matmul not running on the GPU, but we see almost no speedup:

System information:
(1) OS platform and distribution: CentOS 6.3
(2) TensorFlow installed from: conda
(3) TensorFlow version: v1.9.0
(4) Python version: 3.6
(5) CUDA version: 9.2
(6) cuDNN version: 7.1.4
(7) GPU model: V100 (count: 2)

Model: OpenSeq2Seq transformer_big.py
FP32: batch_size = 128, 2 GPUs (V100), time per step = 0.34 s
mixed: batch_size = 128, 2 GPUs (V100), time per step = 0.33 s

The FP32 and mixed speeds are almost identical. Why does mixed mode not speed up the Transformer model?

@okuchaiev (Member)

@dingsiyu
I tested "transformer-big.py" and got the following (note the increase in global_step/sec):
[screenshot from 2018-08-10: training logs showing a higher global_step/sec in mixed precision than in FP32]

This is using NVIDIA's TF container, 2 GPUs, OpenSeq2Seq from the master branch, and no Horovod.

One thing I noticed is that my FP32 model reports around 0.424 s per step, while mixed is closer to 0.33 s (the same as yours).
Can you please double-check that your FP32 model is actually FP32 and not mixed?
Did you make any changes to the model config, or use a different dataset?
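
For reference, the two runs should differ only in the config's dtype field. A minimal sketch of the relevant part of an OpenSeq2Seq config (treat the exact field names and values as assumptions based on the example configs):

import tensorflow as tf

base_params = {
    # ... model, optimizer, and data settings ...
    "dtype": tf.float32,           # pure FP32 baseline
    # "dtype": "mixed",            # mixed precision: fp16 compute, fp32 master weights
    # "loss_scaling": "Backoff",   # typically enabled together with "mixed"
}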

@dingsiyu (Author)

@okuchaiev
I tested "transformer-big.py" many times last weekend, and the FP32 speed always stays at 0.34 s per step. I have checked the model config and the dataset; when comparing FP32 and mixed, the only thing I changed was the 'dtype' parameter, from tf.float32 to "mixed".

But I don't know whether our system configurations are the same. I am not sure whether CUDA 9.2 affects the speed of FP32 mode.

So could you test "transformer-big.py" on the newest NVIDIA TF container, which may have the same configuration as mine?

Thanks a lot!
