OpenCL counterpart of cuDNN #34
Comments
@dagamayank I can also provide you with example kernel strings if you don't want to look at that part of the code and are only interested in helping optimize the kernels for AMD GPUs, which would also be very welcome. |
@naibaf7 Have you seen the latest updates on the Tensorflow thread? |
@bhack |
Because I think that your work could fit well in https://docs.google.com/spreadsheets/d/1YbHn7dAFPPG_PgTtgCJlWhMGorUPYsF681TsZ4Y4LP0/edit?usp=sharing |
@naibaf7 |
@dagamayank The kernel string is assembled here:
ss << generate_bw_defs();
ss << generate_bw_kernels("conv_backward");
ss << generate_wg_defs();
ss << generate_wg_kernels("conv_weights");
// Write complete kernel string
kernel_ = ss.str();
// std::cout << kernel_ << std::endl;
}
(it's line https://github.com/naibaf7/caffe/blob/master/src/caffe/greentea/libdnn.cpp#L1588)
This will give you the kernel string in kernel_. The algorithm variants for the weight gradient and backward passes are defined as:
typedef enum {
// Stack the batch update into one GEMM block
// (deterministic, 1 kernel call)
// Serializes the batch and may therefore underuse
// the GPU's compute units.
LIBDNN_CONVOLUTION_WG_ALGO_DIRECT = 0,
// Use multiple GEMM blocks in parallel and update weights atomically
// (non-deterministic, 1 kernel call, not supported on all devices)
// Parallelizes the batch and therefore has higher GPU usage.
LIBDNN_CONVOLUTION_WG_ALGO_ATOMIC = 1,
// Use multiple GEMM blocks and an intermediate buffer
// to reduce weight updates
// (deterministic, >= 2 kernel calls)
// Parallelizes the batch and therefore has higher GPU usage.
// NOT IMPLEMENTED YET
LIBDNN_CONVOLUTION_WG_ALGO_REDUCTION = 2
} libdnnConvolutionWeightAlgo_t;
typedef enum {
// Transform data before GEMM (load, im2col, gemm, store)
// This method is suitable for convolutions where the spatial
// input and output sizes are similar, but it can become inefficient
// if input >> output (with large strides and kernels).
LIBDNN_CONVOLUTION_BW_ALGO_IM2COL = 0,
// Transform data after GEMM (load, gemm, col2im, store)
// Sometimes faster than im2col method, but uses
// atomic operations and is not deterministic.
LIBDNN_CONVOLUTION_BW_ALGO_COL2IM_ATOMIC = 1
} libdnnConvolutionBackwardAlgo_t;
Which one is being used can be changed here:
Finally, you need to run a network in order to instantiate the layers and get some kernel strings. The recommended starting point for that is using the following command:
Together with the instructions above, you can dump the kernel strings to a text file like that, and look for optimization possibilities that way. Note that every convolution layer gets its own set of kernels, so the above command will give you many different ones. |
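For illustration, one way to dump an assembled kernel string to a text file is a small helper used in place of the commented-out std::cout line quoted above; this is a sketch, not code from the repository, and the file name is arbitrary:

#include <fstream>
#include <string>

// Sketch only: append one generated kernel string to a dump file for
// offline inspection. Call it where kernel_ = ss.str() is assigned, once
// per instantiated convolution layer.
void dump_kernel_string(const std::string& kernel_source,
                        const std::string& path = "libdnn_kernels.txt") {
  std::ofstream out(path, std::ios::app);  // append: one entry per layer
  out << kernel_source << "\n\n";
}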
@naibaf7 |
I get failures when running "make runtest" on the code in the master branch of your repo. Is this expected? Two of the errors are from libDNN. My development environment is an AMD W9100 with Ubuntu 14.04. |
@dagamayank The _Spatial failures are from Intel's convolution implementation. I think the fix here is to use the latest ViennaCL development branch (https://github.com/viennacl/viennacl-dev) instead of what Ubuntu supplies. As for libDNN, this test should definitely not fail. Here it would be helpful to get the failure message from the runtest itself (i.e. where the runtest on libDNN aborted). You can test this in detail by using: |
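(For reference: the Caffe test binary is a regular googletest executable, so one way to isolate the libDNN tests is to pass it a filter such as --gtest_filter='*LibDNN*'; the exact binary path and the device-ID argument depend on the build, so this is an assumption about the intended command, not a quote of it.)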
@naibaf7 |
@dagamayank Do you have any other OpenCL device to check if the backward pass passes the test? |
@naibaf7 I am going through the kernels right now. Can you mention the reason for putting the constants into defines? |
@dagamayank I put it into defines rather than directly into the kernel string for better readability of the kernel itself (i.e. easier to see where a constant is used and why). |
Are you using autotuning to generate the values of those constants? |
@dagamayank I hope that helps. |
@dagamayank |
@naibaf7 |
@dagamayank Using vectors of size 4 and 16x16 thread blocks (64x64xTSK shared memory tiling) seems to work best on both cards so far though. |
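To make those numbers concrete: a 16x16 thread block covering a 64x64 output tile means each thread accumulates a 4x4 block of outputs in registers, with vector width 4 for memory accesses. In the generated kernels this kind of configuration typically shows up as defines at the top of the kernel string, roughly like the sketch below (parameter names are illustrative rather than libDNN's exact ones; only TSK appears in the comment above):

// Illustrative tuning defines; values follow the 64x64xTSK tiling,
// 16x16 thread blocks and vector width 4 mentioned above.
#define TSM 64   // tile size in M (output feature maps per work-group)
#define TSN 64   // tile size in N (output pixels per work-group)
#define TSK 8    // tile size in K (reduction depth per loop step), example value
#define WPTM 4   // outputs per thread in M: TSM / 16 thread rows
#define WPTN 4   // outputs per thread in N: TSN / 16 thread columns
#define VWM 4    // vector width for loads/stores along M
#define VWN 4    // vector width for loads/stores along N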
@naibaf7 One question I had was: do I have to run the entire AlexNet, or can I just run the first convolution layer using cifar10? What kind of performance are you seeing right now? |
@dagamayank
Especially the clBLAS forward performance is extremely poor, which was my main motivation to create libDNN. At this stage, libDNN beats cuBLAS-based implementations. The goal is to get within 70-80% of cuDNN. |
@dagamayank |
@naibaf7 I am very interested in LibDNN. It performs well. Since I am not familiar with OpenCL and have only glanced over LibDNN, it seems that it also uses matrix multiplication. If possible, would you tell me whether its principle is the same as cuDNN's? Or could you kindly provide references such as a paper or documentation? Thank you. |
@zazd Yes it uses a local-memory and register-level GEMM. |
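To unpack "local-memory and register-level GEMM": the kernel stages tiles of both matrices in local (shared) memory, and each work-item accumulates its output(s) in registers. Below is a deliberately minimal sketch of that technique, written as an OpenCL C kernel embedded in a C++ string the way libDNN assembles its kernels. It is not libDNN's actual kernel (that one tiles more aggressively, keeps many outputs per thread, and fuses the im2col/col2im transforms); the tile size and names here are assumptions.

#include <string>

// Minimal local-memory tiled GEMM sketch: C (MxN) = A (MxK) * B (KxN),
// one output element per work-item. Launch with a 16x16 local size and
// global sizes rounded up to multiples of 16.
const std::string tiled_gemm_src = R"CLC(
#define TILE 16
__kernel void tiled_gemm(const int M, const int N, const int K,
                         __global const float* A,
                         __global const float* B,
                         __global float* C) {
  const int row  = get_global_id(1);   // row of C this work-item computes
  const int col  = get_global_id(0);   // column of C this work-item computes
  const int lrow = get_local_id(1);
  const int lcol = get_local_id(0);

  __local float Asub[TILE][TILE];      // tile of A staged in local memory
  __local float Bsub[TILE][TILE];      // tile of B staged in local memory

  float acc = 0.0f;                    // register-level accumulator

  for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
    const int a_col = t * TILE + lcol;
    const int b_row = t * TILE + lrow;
    Asub[lrow][lcol] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
    Bsub[lrow][lcol] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
    barrier(CLK_LOCAL_MEM_FENCE);      // tile fully loaded before use
    for (int k = 0; k < TILE; ++k) {
      acc += Asub[lrow][k] * Bsub[k][lcol];
    }
    barrier(CLK_LOCAL_MEM_FENCE);      // done with tile before overwriting
  }
  if (row < M && col < N) {
    C[row * N + col] = acc;
  }
}
)CLC";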
@bhack @gstoner With LibDNN on both the GTX 1080 and RX 480, the RX 480 performs at exactly half the speed of the GTX 1080, just as expected. |
Do you have v2 kernels? |
@bhack |
@naibaf7 It is hard to talk about this topic... We actually are the only ones that use libdnn as upstream :wink:. It would be nice if Caffe could use libdnn as upstream naturally instead of having libdnn downstream. /cc @edgarriba |
@bhack |
I also think that @hughperkins could be interested in the standalone upstream |
@naibaf7 do you have Winograd kernels in libDNN? |
@dagamayank No not yet... |
It could be interesting if @dicecco1 would contribute upstream to the standalone libdnn |
I'd be interested in being involved in this, though the way that OpenCL is used with FPGAs has some differences/conflicts with the current way that greentea has been set up. Currently, kernel compile times for FPGA implementations are on the order of hours, so they use offline compilation and program the FPGA with the binary (this still takes on the order of 300-400 ms), so between kernels there has to be little or no reprogramming. |
So it is practically impossible to have an autotuning approach like libdnn's. Right? |
Apart from that, I think it's quite straightforward to provide a couple of interfaces for offline building and importing built binaries. Is that right, @naibaf7? |
Yeah, essentially for the FPGA implementations you need to decide more on an architecture (since in FPGAs you're configuring circuits rather than processing instructions) and it is usually best to have something that is either general (e.g. can handle different sizes/strides) or is very specific to a model (e.g. tuned to be very high performance for AlexNet). Autotuning for different layers would fit more into the model specific approach to FPGA implementations, but this would still be offline. |
@dicecco1 I have not checked your paper in detail, but could your Winograd kernel also be ported to GPU/CPU, or would it need to be heavily re-engineered? |
The winograd kernel would need to be heavily re-engineered for CPU/GPU implementations. |
I don't know if @keryell is also interested in dicecco1's kernels |
For everyone in the thread: I'm talking about https://github.com/dicecco1/fpga_caffe |
There certainly are ways to either cache or tune the kernels on a surrogate platform. |
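One concrete caching mechanism the standard OpenCL API already provides: build a program once (offline or on a surrogate machine), retrieve its binary with clGetProgramInfo, store it, and later reload it with clCreateProgramWithBinary instead of recompiling from source. A rough single-device sketch, with error handling mostly omitted (this is not existing greentea/libdnn code):

#include <CL/cl.h>
#include <vector>

// Sketch: retrieve the binary of an already-built program so it can be
// cached on disk (e.g. an FPGA bitstream or a pre-tuned GPU kernel).
std::vector<unsigned char> get_program_binary(cl_program program) {
  size_t size = 0;
  clGetProgramInfo(program, CL_PROGRAM_BINARY_SIZES,
                   sizeof(size), &size, nullptr);
  std::vector<unsigned char> binary(size);
  unsigned char* ptr = binary.data();
  clGetProgramInfo(program, CL_PROGRAM_BINARIES,
                   sizeof(ptr), &ptr, nullptr);
  return binary;
}

// Sketch: re-create a program from a previously stored binary instead of
// compiling from source at runtime.
cl_program load_program_binary(cl_context context, cl_device_id device,
                               const std::vector<unsigned char>& binary) {
  const unsigned char* ptr = binary.data();
  size_t size = binary.size();
  cl_int status = CL_SUCCESS;
  cl_int err = CL_SUCCESS;
  cl_program program = clCreateProgramWithBinary(
      context, 1, &device, &size, &ptr, &status, &err);
  clBuildProgram(program, 1, &device, "", nullptr, nullptr);  // finalize
  return program;
}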
@bhack @dicecco1 |
@naibaf7 Can you notify us if you get feedback from others interested in having the v3 kernels and standalone libdnn as upstream? |
@bhack |
Observation: I'm still waiting on an example of calling libdnn from C++ :-) |
You can see an example with tuning commented at https://github.com/tiny-dnn/tiny-dnn/blob/master/tiny_dnn/core/kernels/conv2d_op_libdnn.h |
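Condensed, the calling pattern in that file is: fill a config struct describing the convolution, construct a LibDNNConv object from it, then call its forward/backward methods with device buffers. A heavily abridged sketch follows; the field names reflect my reading of the tiny-dnn integration and the enums quoted earlier in this thread, so treat them as approximate and check the linked header rather than taking this as the authoritative API:

// #include "libdnn.hpp"  // header name in the standalone repo (assumption)

// Abridged usage sketch for a single NCHW conv layer (shapes are examples).
greentea::LibDNNConfig config;
config.dev_ptr   = device_ptr;                           // greentea device wrapper
config.in_shape  = std::vector<int32_t>{1, 3, 32, 32};   // N, C, H, W
config.out_shape = std::vector<int32_t>{1, 16, 32, 32};
config.kernel    = std::vector<int32_t>{3, 3};
config.pad       = std::vector<int32_t>{1, 1};
config.stride    = std::vector<int32_t>{1, 1};
config.dilation  = std::vector<int32_t>{1, 1};
config.group     = 1;
config.bias_term = true;
// The enums quoted earlier select the algorithm variants:
config.wgalgo = greentea::LIBDNN_CONVOLUTION_WG_ALGO_ATOMIC;
config.bwalgo = greentea::LIBDNN_CONVOLUTION_BW_ALGO_COL2IM_ATOMIC;

greentea::LibDNNConv<float> conv(config);
// conv.Forward(bottom_data, weights, bias, top_data, batch_size);
// conv.Backward(...) analogously; see the linked header for the exact calls.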
@naibaf7 OK, please give us an update when you can, because the standalone version is quite on hold. |
@bhack |
Status update: non-atomic backward kernels for pooling are finished; the library is unit tested & verified with the style-transfer and MNIST examples. |
Latest? Is it the end of the project? |
No, that is just the latest point in time by which I project being done with this step :) |
Ok ;) |
After that I don't know what the next optimization is going to be. Either mobile ARM chips with integrated GPUs or AMD's Vega and FP16, depending on what I can get my hands on first. |
I came across your post on the Tensorflow thread saying that you are developing an OpenCL counterpart to cuDNN. I would like to help/contribute on that project. Let me know where and how I can help. I have extensive OpenCL programming experience and am currently focused on ML activities at AMD.