
[src] Adding GPU/CUDA lattice batched decoder + binary #3114

Merged
merged 21 commits into kaldi-asr:master from master-develop on Apr 26, 2019

Conversation

hugovbraun
Contributor

Adding the GPU batched lattice decoder. We're currently at 3500 XRTF end-to-end on a V100, while generating full lattices, using this decoder and a TDNN acoustic model.
Using the same model, we get 381 XRTF on a dual-socket Titanium Xeon, using all 96 hardware threads.
We'll publish more performance results in the next few days, along with a Docker container containing this branch ready to use, so that everybody can easily run the binary and test performance.

This PR has two components:

  • cudadecoder/ , which contains the decoder itself.
  • cudadecoderbin/, which contains a binary using the decoder. It can be used directly or viewed as an example of how to use the (batched) cuda decoder.

We've been testing stability/correctness on a variety of models.

Other authors: @luitjens @ryanleary
Many thanks to @danpovey @chenzhehuai @galv for their help during the pre-review process

@hugovbraun hugovbraun changed the title [src] Adding GPU/Cuda lattice batched decoder + binary [src] Adding GPU/CUDA lattice batched decoder + binary Mar 15, 2019
@luitjens
Contributor

Dan, any idea why Travis is failing at link time in code outside of what this patchset touches?

/home/travis/build/kaldi-asr/kaldi/src/online/online-audio-source.cc:161: undefined reference to `PaUtil_GetRingBufferWriteAvailable'
/home/travis/build/kaldi-asr/kaldi/src/online/online-audio-source.cc:164: undefined reference to `PaUtil_WriteRingBuffer'
kaldi-online.a(online-audio-source.o): In function `OnlinePaSource':
/home/travis/build/kaldi-asr/kaldi/src/online/online-audio-source.cc:71: undefined reference to `PaUtil_InitializeRingBuffer'
kaldi-online.a(online-audio-source.o): In function `kaldi::OnlinePaSource::Read(kaldi::Vector<float>*)':
/home/travis/build/kaldi-asr/kaldi/src/online/online-audio-source.cc:122: undefined reference to `PaUtil_GetRingBufferReadAvailable'
/home/travis/build/kaldi-asr/kaldi/src/online/online-audio-source.cc:122: undefined reference to `PaUtil_GetRingBufferReadAvailable'
/home/travis/build/kaldi-asr/kaldi/src/online/online-audio-source.cc:138: undefined reference to `PaUtil_ReadRingBuffer'

@danpovey
Contributor

danpovey commented Mar 15, 2019 via email

@ChunhuiWang-China
Contributor

Interesting: when compiling this branch, building cuda-decoder.o fails with an error, /tools/openfst/include/fst/push.h:132:126: error: ‘g’ was not declared in this scope
total_weight = GallicWeight(
^
/tools/openfst/include/fst/push.h:132:193: error: template argument 2 is invalid
total_weight = GallicWeight(
^

--error 0x1 --

make: *** [cuda-decoder.o] Error 1

@btiplitz
Contributor

@Plusmonkey I've had that exact error before. I believe the cause was that something in my toolchain was out of date.

@danpovey
Contributor

danpovey commented Mar 19, 2019 via email

@btiplitz
Contributor

@danpovey I tracked down when I got this exact issue. I was compiling on Ubuntu 14, where the default compiler is gcc 4.8.4. It seems there was a bug in 4.8.4, and I needed 4.8.5 to get past it. Having 4.9.4 around, I upgraded to that version. I looked in the changelog, but I am unsure how I determined that the issue was fixed in 4.8.5.

@Plusmonkey can you tell me which version of gcc you are running?

Contributor

@btiplitz btiplitz left a comment

On line 34 of cudadecoder/Makefile, please add the following (the non-CUDA build fails without this):
else
depend:

@btiplitz
Contributor

@hugovbraun On line 34 of cudadecoder/Makefile, please add the following (the non-CUDA build fails without this) - sorry, the browser at work won't allow in-line comments.
else
depend:

@btiplitz
Contributor

@danpovey The Windows build is almost done. One error at a time.

@hugovbraun
Contributor Author

hugovbraun commented Mar 22, 2019

Thanks @btiplitz and @chenzhehuai, I'll add those shortly. I just found a bug that appeared after the commit squash (after rebasing on the latest upstream); I'm fixing that and then adding the new rules.

@hugovbraun
Contributor Author

The bug was with the aspire model and was due to rebasing on top of last week's master. The good news is that rebasing on today's master solved the issue.
@btiplitz when opening the file to apply your changes, I noticed that something was wrong with the nested if statements (in that version: https://github.com/kaldi-asr/kaldi/blob/080c805665c54329f716a8b7f08378c51ba57299/src/cudadecoder/Makefile ). I fixed it, but I'm not sure what the "else" in your change should refer to now; could you check the latest version?

@btiplitz
Contributor

@hugovbraun the else branch only runs on non-GPU builds and basically does nothing.

@hugovbraun
Contributor Author

@btiplitz Updated; the non-CUDA builds now have an empty depend target.

@cloudhan
Contributor

cloudhan commented Mar 23, 2019

@Plusmonkey are you using CUDA 8.0?

I also failed to build this branch with CUDA 8.0.61 and gcc 5.4.0, with exactly the same error. But it compiles and runs smoothly with CUDA 9.0.

@hugovbraun Does this decoder officially support CUDA 8?

@luitjens
Contributor

luitjens commented Mar 23, 2019 via email

@ChunhuiWang-China
Contributor

@btiplitz Hi, I use gcc version 4.8.5 20150623 (Red Hat 4.8.5-36).

@cloudhan
Contributor

cloudhan commented Mar 25, 2019

@luitjens @Plusmonkey
I drilled down into the /tools/openfst/include/fst/push.h:132:126: error: ‘g’ was not declared in this scope error, and it seems to be related to the compiler.

in include/fst/push.h

      total_weight = GallicWeight(
          ptype & kPushRemoveCommonAffix
              ? total_weight.Value1()
              : StringWeight<Label, GallicStringType(gtype)>::One(),
          ptype & kPushRemoveTotalWeight ? total_weight.Value2()
                                         : Weight::One());

in the generated cuda-decoder-kernels.cu.cpp.ii at line 214785

total_weight = GallicWeight((ptype & kPushRemoveCommonAffix) ? (total_weight.Value1()) : StringWeight< typename Arc::Label, (g == 0) ? STRING_LEFT : ((g == 1) ? STRING_RIGHT :        STRING_RESTRICT)> ::One(), (ptype & kPushRemoveTotalWeight) ? (total_weight.Value2()) : Weight::One());

Apparently GallicStringType has been expanded; it is defined as

constexpr StringType GallicStringType(GallicType g) {
  return g == GALLIC_LEFT
             ? STRING_LEFT
             : (g == GALLIC_RIGHT ? STRING_RIGHT : STRING_RESTRICT);
}

and unfortunately the formal parameter g was not replaced by the actual argument gtype.

If I manually replace the g with gtype at line 214785 of cuda-decoder-kernels.cu.cpp.ii, the file compiles smoothly.
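
For reference, here is a minimal, hypothetical reduction of the pattern that seems to trip CUDA 8's nvcc (illustration only, not OpenFst code): a constexpr function used to compute a non-type template argument.

```cpp
enum GallicType { GALLIC_LEFT, GALLIC_RIGHT, GALLIC_RESTRICT };
enum StringType { STRING_LEFT, STRING_RIGHT, STRING_RESTRICT };

// constexpr function used to compute a non-type template argument.
constexpr StringType GallicStringType(GallicType g) {
  return g == GALLIC_LEFT
             ? STRING_LEFT
             : (g == GALLIC_RIGHT ? STRING_RIGHT : STRING_RESTRICT);
}

template <StringType S>
struct StringWeight {
  static StringWeight One() { return StringWeight(); }
};

// When an older front end re-expands GallicStringType(gtype) below, it can
// leave the formal parameter `g` in place of `gtype`, producing the
// "'g' was not declared in this scope" error quoted above.
template <GallicType gtype>
StringWeight<GallicStringType(gtype)> MakeOne() {
  return StringWeight<GallicStringType(gtype)>::One();
}

int main() {
  StringWeight<GallicStringType(GALLIC_LEFT)> w = MakeOne<GALLIC_LEFT>();
  (void)w;
  return 0;
}
```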

@ChunhuiWang-China
Contributor

@luitjens @cloudhan Yes, I also use CUDA 8.0.

@hugovbraun
Contributor Author

Thanks @cloudhan for looking into that. I'll remove the OpenFst dependency from the kernel files; it is not really needed anyway. That should fix this issue.

@hugovbraun
Contributor Author

I added wrappers around the kernels and did some refactoring so that only the kernels are compiled by nvcc; all other files are compiled directly by (pure) g++. OpenFst is never compiled by nvcc, which fixes the CUDA 8.0 compilation issue.
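
As a rough illustration of that split (file and function names here are hypothetical, not necessarily the ones used in the PR): host code only ever includes a plain C++ header, and only the .cu file is handed to nvcc.

```cpp
// --- cuda-decoder-kernels-interface.h (hypothetical) ----------------------
// Included by g++-compiled host code: plain C++ only, no CUDA/OpenFst headers.
struct ExpandArcsParams {   // hypothetical name
  const int *d_arcs;        // device pointers are just raw pointers here
  int *d_tokens;
  int num_arcs;
};
void LaunchExpandArcsKernel(const ExpandArcsParams &params, int nlanes,
                            void *stream);

// --- cuda-decoder-kernels.cu: the only translation unit seen by nvcc ------
__global__ static void ExpandArcsKernel(ExpandArcsParams params) {
  // ... kernel body elided ...
}

// Thin wrapper: host code calls this instead of launching the kernel itself.
void LaunchExpandArcsKernel(const ExpandArcsParams &params, int nlanes,
                            void *stream) {
  ExpandArcsKernel<<<nlanes, 256, 0, static_cast<cudaStream_t>(stream)>>>(
      params);
}
```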

@danpovey
Contributor

Nice-- easy compilation is important.
We were trying to compile the marian project on the CLSP grid and it was no end of trouble because they use things like std::tuple in code that is compiled by nvcc.

@ChunhuiWang-China
Contributor

@hugovbraun Hi, can you give an example of how to use this CUDA decoder?

Contributor

@luitjens luitjens left a comment

@hugovbraun Hi , Can you give an example about how to use this cuda decoder?

There are already examples in the code.

If you just want to do offline decoding with nnet3 & ivectors, see src/cudadecoderbin/batched-wav-nnet3-cuda.cc.

If you want to use just the decoder, you can look at src/cudadecoder/batched-threaded-cuda-decoder.cc for an example, or modify it to match your requirements.

To use just the decoder, you will need to come up with an efficient way to batch audio files together. That is what batched-threaded-cuda-decoder.cc implements.

Note that we don't have an example of online decoding at the moment, as it wasn't a target of our initial work. In theory the decoder supports it, but a user of the decoder would have to handle channels/lanes appropriately.

@dpny518
Contributor

dpny518 commented Apr 11, 2019

@hugovbraun If I want to use this with other packages that create binaries, such as https://github.com/jimbozhang/kaldi-gop, what steps would need to be done?

@cloudhan
Contributor

If I want to keep the beam and lattice-beam unchanged, then what is the key factor that will affect the decoding speed?

@luitjens
Contributor

luitjens commented Apr 12, 2019 via email

@luitjens
Contributor

luitjens commented Apr 12, 2019 via email

@cloudhan
Contributor

cloudhan commented Apr 12, 2019

@luitjens To be clear: with the same machine, GPU, model, and audio input, what will affect this decoder's throughput? There seem to be some options that affect the scheduling, e.g.

--max-batch-size=200
--batch_drain_size=10
--cuda-control-threads=2
--cuda-worker-threads=16
--max-outstanding-queue-length=4000

From the documentation, I can see that the larger max-batch-size is, the faster (and the more memory-hungry) decoding becomes.
Will any other option also affect the speed, i.e. the throughput?

As far as I can see from nvidia-smi, the power consumption is only 145/250 W for a 1080 Ti, and the temperature is only 45 degrees C, so I don't think it is busy enough...

Edit:
OK, now that I have switched back to CUDA 10, it runs an order of magnitude faster. Maybe add a warning for CUDA 8 users about the degraded performance (speed). But now the GPU is even more idle.

#define KALDI_CUDA_DECODER_DIV_ROUND_UP(a, b) ((a + b - 1) / b)

#define KALDI_CUDA_DECODER_ASSERT(val, recoverable) \
{ \
Contributor

Use do { } while(0) for these macros. It's important.

I know that there are some macros in kaldi that don't do this. But still.
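
For readers unfamiliar with the idiom, a sketch (the error handler called here is hypothetical):

```cpp
// do { } while (0) makes the macro expand to a single statement, so it
// composes safely with an unbraced if/else:
//   if (ok) KALDI_CUDA_DECODER_ASSERT(x, true); else Recover();
// (HandleDecoderError is a hypothetical error handler.)
#define KALDI_CUDA_DECODER_ASSERT(val, recoverable)      \
  do {                                                   \
    if (!(val)) HandleDecoderError(#val, (recoverable)); \
  } while (0)
```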

// M is usually the batch size
inline dim3 KaldiCudaDecoderNumBlocks(int N, int M) {
dim3 grid;
// TODO MAX_NUM_BLOCKS.
Contributor

Why would we need a max number of blocks? Not all blocks need to be resident on the GPU at once, anyway.

Contributor Author

There are a few reasons:

  • Some kernels aggregate values locally in shared memory first (like the histogram kernel). If the same CTA is in charge of more data (iterating with thread_idx += gridDim.x*blockDim.x) and uses shared memory for local aggregates, it pushes the local aggregates to global memory only once. Those global aggregates are atomics, so it's a big deal (a sketch of this pattern follows after this comment).
  • Some data is loaded only once into registers and reused when the CTA loops on thread_idx += gridDim.x*blockDim.x.
  • Launching/terminating CTAs has a cost.
  • More specific to our case: the amount of work we have to do for each frame depends on that frame (and the previous ones), which means that each utterance in a batch will have a different number of tokens to process, for instance. Because we launch a 2D grid, limiting the number of CTAs allows some load balancing to happen (if we always used the max to decide the x-dimension, all other utterances in the batch would launch CTAs at the end for nothing).

None of this matters much when we generate only a few thousand tokens per frame, but it really speeds up the kernels for challenging models (for this we usually test the aspire model evaluated on librispeech/test_other). It gives around 10x on the histogram kernel, 2x on expand, etc.

And for older architectures, the max number of CTAs per grid dimension is ~65,000. Even though this is far more than we need, it's still cleaner to set a clear upper limit.
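
A minimal sketch of the shared-memory aggregation point above (a hypothetical kernel, not the decoder's actual histogram code): each CTA grid-strides over the input, aggregates into shared memory, and only pushes one set of atomic updates to global memory at the end, so capping gridDim directly caps the number of global atomics and CTA launches.

```cpp
__global__ void HistogramKernel(const int *bin_of_item, int n,
                                int *d_global_hist, int nbins) {
  // bin_of_item[i] is assumed to be in [0, nbins).
  extern __shared__ int s_hist[];  // one shared-memory histogram per CTA
  for (int i = threadIdx.x; i < nbins; i += blockDim.x) s_hist[i] = 0;
  __syncthreads();
  // Grid-stride loop: the same CTA covers several chunks of the input, so the
  // per-CTA shared-memory aggregates amortize more work.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x)
    atomicAdd(&s_hist[bin_of_item[i]], 1);  // cheap shared-memory atomic
  __syncthreads();
  // Only one global atomic per bin per CTA, instead of one per input item.
  for (int i = threadIdx.x; i < nbins; i += blockDim.x)
    atomicAdd(&d_global_hist[i], s_hist[i]);
}
```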

// and the d_cutoff
// We use a 1:1 conversion between CostType <--> IntegerCostType
// IntegerCostType is used because it triggers native atomic operations
// (CostType does not)
Contributor

See 15-2 Comparing Floating-Point Numbers Using Integer Operations in Hacker's Delight Second Edition to learn more about the 1:1 conversion.
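
For context, the standard mapping looks roughly like this (helper names are hypothetical; the decoder's own conversion functions may differ):

```cpp
// Map float bits to an int whose signed ordering matches the float ordering,
// so that integer atomicMin/atomicMax behave like float atomic min/max.
__device__ __forceinline__ int FloatToOrderedInt(float f) {
  int i = __float_as_int(f);
  // Non-negative floats already order correctly as ints; for negative floats,
  // flip the non-sign bits so that "more negative" maps to "smaller int".
  return (i >= 0) ? i : i ^ 0x7FFFFFFF;
}

__device__ __forceinline__ float OrderedIntToFloat(int i) {
  return __int_as_float((i >= 0) ? i : i ^ 0x7FFFFFFF);
}
```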

// Data is stored as 2D matrices (BatchSize, 1D_Size)
// For example, for the token queue, (BatchSize, max_tokens_per_frame_)
// DeviceMatrix owns the data but is not used to access it.
// DeviceMatrix is inherited in DeviceLaneMatrix and DeviceChannelMatrix
Contributor

Nit: "is inherited in" would read better as "is the parent of".

to the device. For best performance this should be between 2-4.
cuda-worker-threads: CPU threads for worker tasks like determinization and
feature extraction. For best performance this should take up all spare
CPU threads available on the system.
Contributor

So it should be MY_NUM_CORES - cuda-control-threads, if you assume that your server is otherwise free, right?

Contributor Author

Yes. We could use std::thread::hardware_concurrency() by default if you think it's a good idea.
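
Something along these lines, presumably (a sketch of that suggested default, not code from the PR):

```cpp
#include <algorithm>
#include <thread>

// Hypothetical helper: default the worker-thread count to "all spare cores",
// i.e. hardware threads minus the CUDA control threads.
int DefaultCudaWorkerThreads(int num_control_threads) {
  int hw = static_cast<int>(std::thread::hardware_concurrency());
  if (hw == 0) hw = 1;  // hardware_concurrency() may return 0 when unknown
  return std::max(1, hw - num_control_threads);
}
```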

cuda-control-threads: Each control thread is a concurrent pipeline. Thus
the GPU memory scales linearly with this parameter. This should always be
at least 2 but should probably not be higher than 4 as more concurrent
pipelines leads to more driver contention reducing performance.
Contributor

Is there any public documentation on which driver versions reduce the amount of kernel-call lock contention?

// Hashmap value. Used when computing the hashmap in PostProcessingMainQueue
struct __align__(16) HashmapValueT {
// Map key : fst state
int32 key;
Contributor

StateId would be better. You're guaranteed that it's an int32 since you have the align pragma.

// Number of tokens associated to that state
int32 count;
// minimum cost for that state + argmin
int2 min_and_argmin_int_cost;
Contributor

This is the best part of the whole code base. I just love how you can just shove the argmin into the 32 bits below the 32-bit score, and atomicMin on 64-bit integers preserves the argmin.
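
For readers following along, the idea is roughly this (a sketch only; the PR stores an int2 and, per the note below, still uses a compare-and-swap path rather than a native 64-bit atomicMin). It assumes the integer cost is treated as non-negative, so unsigned 64-bit ordering matches the cost ordering:

```cpp
// Pack the 32-bit integer cost in the high word and the 32-bit argmin in the
// low word; a single 64-bit atomicMin then keeps (min cost, its argmin).
__device__ void AtomicMinCostAndArg(unsigned long long *address,
                                    unsigned int int_cost, unsigned int arg) {
  unsigned long long packed =
      (static_cast<unsigned long long>(int_cost) << 32) | arg;
  atomicMin(address, packed);  // needs compute capability >= 3.5
}
```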

}

// binsearch_maxle (device)
// With L=[all indexes low<=i<=high such as vec[i]<= val]
Contributor

What is L here?

You should document that this code assumes that vec[low] <= val <= vec[high-1] , since it seems to assume that anyway.
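
For reference, a hedged sketch of the contract this presumably implements (not the PR's exact code): return the largest index i in [low, high] with vec[i] <= val, assuming vec is non-decreasing and vec[low] <= val.

```cpp
__device__ int BinsearchMaxle(const int *vec, int val, int low, int high) {
  while (low < high) {
    int mid = (low + high + 1) / 2;  // upper midpoint so the loop terminates
    if (vec[mid] <= val)
      low = mid;        // mid is still a valid candidate
    else
      high = mid - 1;
  }
  return low;
}
```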

return uold.i2;
}

// We should switch to native atom64 on atomicMinI2 and atomicSubI2
Contributor

Yes, why isn't this using atomicMin overloaded for unsigned long long int? It seems like it would be way faster that way. Do you really need to support compute capability 3.0?

Contributor Author

Yes, this is on the list. As for 3.0, I guess it is still supported in the Kaldi makefiles (we can keep the duplicate and guard the native version with #if __CUDA_ARCH__).
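
A sketch of what that guard could look like (hypothetical helper; the merged code may differ): sm_35 and newer use the native 64-bit atomicMin, older architectures keep a CAS loop.

```cpp
__device__ void AtomicMin64(unsigned long long *address,
                            unsigned long long val) {
#if __CUDA_ARCH__ >= 350
  atomicMin(address, val);  // native 64-bit atomicMin (sm_35 and newer)
#else
  // Fallback for older architectures: emulate with atomicCAS.
  unsigned long long old = *address, assumed;
  while (val < old) {
    assumed = old;
    old = atomicCAS(address, assumed, val);
    if (old == assumed) break;
  }
#endif
}
```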

@galv
Contributor

galv commented Apr 12, 2019

@hugovbraun Have you uploaded your GTC 2019 slides anywhere?

@hugovbraun
Contributor Author

@cloudhan with beam and lattice-beam fixed, the most important factor for speed is the batch size. For now we don't support batch sizes > 200, but that will be part of a patch in the next few weeks. You can try increasing cuda-control-threads for now.
Printing a warning for CUDA 8.0 is probably a good idea.

@galv Thanks a lot for the review. Kernel-call lock contention has been improved continuously since CUDA 9. This is still a work in progress.
However, the recommended path going forward is clearly to use batches as big as possible. For a good model / well-aligned eval data (leading to the creation of a few thousand tokens per frame), going from batch size 1 to batch size 200 is almost free for the decoder (constant time). For more challenging models, the remaining CPU tasks on the main cuda control thread take more time (linear in the batch size), but I've got a patch in progress to remove those. Those tasks are the host-to-host memory copies and GetRawLattice, which is currently called on the main thread.

@hugovbraun
Contributor Author

@galv GTC19 slides: GTC19-Kaldi-Acceleration.pdf
They do not contain anything too technical, though; it was a high-level presentation of what we've done, plus performance numbers in various configurations.

@danpovey danpovey merged commit b8a35fd into kaldi-asr:master Apr 26, 2019
@danpovey
Contributor

Thanks, guys!

@hugovbraun hugovbraun deleted the master-develop branch June 12, 2019 22:35
danpovey pushed a commit to danpovey/kaldi that referenced this pull request Jun 19, 2019
@lkf123010

I am only getting RealTimeX 9.46:

batched-wav-nnet3-cuda --max-batch-size=200 --iterations=4 --cuda-control-threads=2 --cuda-worker-threads=50 --acoustic-scale=1.0 --beam=15.0 --max-active=10000 --lattice-beam=4.0 --ivector-silence-weighting.max-state-duration=40 --ivector-silence-weighting.silence-phones=1 --ivector-silence-weighting.silence-weight=0.00001 --config=/asr/kaldi/egs/aishell2/s5_nopitch/no_pitch_model/online_test/conf/online.conf --word-symbol-table=/asr/kaldi/egs/aishell2/s5_nopitch/no_pitch_model//words.txt /asr/kaldi/egs/aishell2/s5_nopitch/no_pitch_model//final.mdl /asr/kaldi/egs/aishell2/s5_nopitch/no_pitch_model//HCLG.fst 'scp:echo invite1_right2 invite1_right2.wav|' 'ark:|lattice-scale --acoustic-scale=10.0 ark:- ark:- | gzip -c >local/tmp/lat.invite1_right2.gz'
LOG (batched-wav-nnet3-cuda[5.5.673-fa957]:SelectGpuId():cu-device.cc:223) CUDA setup operating under Compute Exclusive Mode.
LOG (batched-wav-nnet3-cuda[5.5.673-fa957]:FinalizeActiveGpu():cu-device.cc:308) The active GPU is [0]: Tesla M40 24GB free:22689M, used:256M, total:22945M, free/total:0.988832 version 5.2


LOG (batched-wav-nnet3-cuda[5.5.673-fa957]:main():batched-wav-nnet3-cuda.cc:304) ~Group 3 completed Aggregate Total Time: 422.199 Audio: 3994.74 RealTimeX: 9.46175
LOG (batched-wav-nnet3-cuda[5.5.673-fa957]:main():batched-wav-nnet3-cuda.cc:315) Decoded 1 utterances, 0 with errors.
LOG (batched-wav-nnet3-cuda[5.5.673-fa957]:main():batched-wav-nnet3-cuda.cc:317) Overall likelihood per frame was 1.18383 per frame over 399468 frames.
LOG (batched-wav-nnet3-cuda[5.5.673-fa957]:main():batched-wav-nnet3-cuda.cc:320) Overall: Aggregate Total Time: 422.199 Total Audio: 3994.74 RealTimeX: 9.46175

@luitjens
Contributor

luitjens commented Apr 15, 2020 via email
