
Two questions about the library #148

Closed
RuABraun opened this issue Nov 14, 2019 · 12 comments

@RuABraun

  1. It seems the F32 GEMM implementation quantizes the input and output? I got that from the usage here, where one has to pass min/max values for the output. I'm worried that the approximation will degrade accuracy significantly (I'm already quantizing all the layers I can to int8). I still have to test that, but just to confirm: is there an SGEMM implementation that doesn't quantize the input and output?

  2. The readme says

XNNPACK is a highly optimized library of floating-point neural network inference operators

however, the code also seems to contain GEMM implementations with int8 weights, etc.? I'm using QNNPACK for that at the moment; would it make sense to switch to XNNPACK for int8 layers?

@Maratyszcza

Maratyszcza commented Nov 14, 2019

  1. F32 functions don't quantize computations. However, many F32 operators accept output_min and output_max arguments, which enable clamping the output to an arbitrary range (helpful e.g. for fusing simple activation functions). If you don't want to clamp the output, just set them to ±std::numeric_limits<float>::infinity() (a sketch follows after this list).
  2. XNNPACK is a fork of QNNPACK. Q8 operators in XNNPACK are remnants of QNNPACK code, but many internal optimizations were removed, and performance of these operators would be worse than in QNNPACK.
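
For concreteness, here is a minimal sketch (not from the thread) of creating an F32 Fully Connected operator with clamping effectively disabled by passing ±infinity. The argument order follows the late-2019 xnn_create_fully_connected_nc_f32 signature and may differ in newer releases; check xnnpack.h for your version.

```cpp
// Sketch only: assumes xnn_initialize() has already been called successfully.
#include <cstddef>
#include <limits>
#include <xnnpack.h>

xnn_operator_t create_unclamped_fc(size_t input_dim, size_t output_dim,
                                   const float* kernel,  // [output_dim * input_dim]
                                   const float* bias) {  // [output_dim]
  xnn_operator_t fc_op = nullptr;
  const xnn_status status = xnn_create_fully_connected_nc_f32(
      /*input_channels=*/input_dim, /*output_channels=*/output_dim,
      /*input_stride=*/input_dim, /*output_stride=*/output_dim,
      kernel, bias,
      /*output_min=*/-std::numeric_limits<float>::infinity(),
      /*output_max=*/+std::numeric_limits<float>::infinity(),
      /*flags=*/0, &fc_op);
  return status == xnn_status_success ? fc_op : nullptr;
}
```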

@RuABraun

Okay thank you!

@RuABraun

RuABraun commented Nov 14, 2019

@Maratyszcza one more question: is the kernel shaped (input_dim, output_dim) or (output_dim, input_dim)?

edit: seems to be (output_dim, input_dim)

@Maratyszcza

For the Fully Connected operator, it is (output_dim, input_dim). Generally, operators use the NHWC layout.
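
To make the layout concrete, a plain reference loop (a sketch, not XNNPACK code) over a kernel stored as (output_dim, input_dim) in row-major order would look like this:

```cpp
#include <cstddef>

// Row o of the kernel holds the weights for output channel o.
void fully_connected_reference(size_t input_dim, size_t output_dim,
                               const float* input,   // [input_dim]
                               const float* kernel,  // [output_dim * input_dim]
                               const float* bias,    // [output_dim]
                               float* output) {      // [output_dim]
  for (size_t o = 0; o < output_dim; o++) {
    float acc = bias[o];
    for (size_t i = 0; i < input_dim; i++) {
      acc += kernel[o * input_dim + i] * input[i];
    }
    output[o] = acc;
  }
}
```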

@RuABraun

Works great!

Two more questions: are there any plans to add an elementwise product operator? And is it somehow possible to make Fully Connected add to the output instead of overwriting it?

@Maratyszcza

Elementwise product (including broadcasting support) landed just yesterday; see xnn_create_multiply_nd_f32 and xnn_setup_multiply_nd_f32. Currently the Fully Connected operator doesn't support fused addition, but you can use a separate Add operator (see xnn_create_add_nc_f32 and xnn_setup_add_nc_f32).
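
A hedged sketch of the newly landed broadcasting multiply, based on the function names mentioned above. The argument lists follow the late-2019 API (create/setup split, threadpool passed to setup) and may have changed in later releases.

```cpp
// Multiply a [batch, channels] tensor by a per-channel [channels] vector,
// broadcasting the vector over the batch dimension.
// Assumes xnn_initialize() has already been called.
#include <cstddef>
#include <limits>
#include <xnnpack.h>

xnn_status multiply_broadcast(size_t batch, size_t channels,
                              const float* x, const float* scale, float* y) {
  xnn_operator_t mul_op = nullptr;
  xnn_status status = xnn_create_multiply_nd_f32(
      /*output_min=*/-std::numeric_limits<float>::infinity(),
      /*output_max=*/+std::numeric_limits<float>::infinity(),
      /*flags=*/0, &mul_op);
  if (status != xnn_status_success) return status;

  const size_t x_shape[2] = {batch, channels};
  const size_t scale_shape[1] = {channels};  // broadcast over the batch dim
  status = xnn_setup_multiply_nd_f32(mul_op, 2, x_shape, 1, scale_shape,
                                     x, scale, y, /*threadpool=*/nullptr);
  if (status != xnn_status_success) return status;

  status = xnn_run_operator(mul_op, /*threadpool=*/nullptr);
  xnn_delete_operator(mul_op);
  return status;
}
```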

@RuABraun

Haha perfect timing! Thanks

@RuABraun

RuABraun commented Nov 21, 2019

@Maratyszcza it seems broadcasting support is not available for the Add operation (for when one wants to implement batch normalization, for example)? Is my impression correct?

Could disabling the check and setting the bias/mean stride to 0 work as a quick hack?

edit: hah, looks like it does!
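
A sketch of the quick hack described above, assuming the late-2019 xnn_create_add_nc_f32 signature and that the library's internal stride validation has been relaxed as noted in the comment; not an officially supported usage.

```cpp
// Create an Add operator whose second input is a single [channels] row that is
// re-read for every batch element (b_stride == 0), i.e. a broadcasted
// per-channel bias. Assumes xnn_initialize() has already been called.
#include <cstddef>
#include <limits>
#include <xnnpack.h>

xnn_status create_bias_add(size_t channels, xnn_operator_t* add_op_out) {
  return xnn_create_add_nc_f32(
      channels,
      /*a_stride=*/channels, /*b_stride=*/0, /*sum_stride=*/channels,
      /*sum_min=*/-std::numeric_limits<float>::infinity(),
      /*sum_max=*/+std::numeric_limits<float>::infinity(),
      /*flags=*/0, add_op_out);
}
```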

@Maratyszcza

For batch normalization, I'd recommend converting it into a 1x1 depthwise convolution. XNNPACK has special optimizations for 1x1 DW convolution.
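
Background for this suggestion (a sketch, independent of XNNPACK): inference-time batch norm is an affine per-channel transform, y = gamma * (x - mean) / sqrt(var + eps) + beta = scale * x + shift, so it maps onto a 1x1 depthwise convolution whose per-channel weight is `scale` and whose bias is `shift`.

```cpp
#include <cmath>
#include <cstddef>

// Fold batch-norm parameters into a per-channel scale (the 1x1 DW kernel)
// and shift (the bias).
void fold_batch_norm(size_t channels, const float* gamma, const float* beta,
                     const float* mean, const float* var, float eps,
                     float* scale, float* shift) {
  for (size_t c = 0; c < channels; c++) {
    scale[c] = gamma[c] / std::sqrt(var[c] + eps);
    shift[c] = beta[c] - mean[c] * scale[c];
  }
}
```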

@RuABraun

RuABraun commented Nov 21, 2019

I'm not working with images, so my input/output matrices are only 2D. I assume xnn_create_convolution2d... would crash then (I don't see a 1D version)? I also don't see an option to set the stride to 0 in one direction.

@Maratyszcza

Maratyszcza commented Nov 21, 2019

You can set all height dimensions (input height, kernel height, height subsampling) to 1; this would be equivalent to a 1D convolution.
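
A hedged sketch combining the two suggestions: batch norm over a 2D [rows, channels] input expressed as a 1x1 depthwise convolution with all height dimensions set to 1. The argument order follows the late-2019 xnn_create_convolution2d_nhwc_f32 signature and may differ in newer releases.

```cpp
// Assumes xnn_initialize() has already been called, and that `scale`/`shift`
// were computed by folding the batch-norm parameters (see the earlier sketch).
#include <cstddef>
#include <limits>
#include <xnnpack.h>

xnn_status create_bn_as_dw_conv(size_t channels,
                                const float* scale,  // per-channel 1x1 DW kernel
                                const float* shift,  // per-channel bias
                                xnn_operator_t* conv_op_out) {
  return xnn_create_convolution2d_nhwc_f32(
      /*input_padding_top=*/0, /*input_padding_right=*/0,
      /*input_padding_bottom=*/0, /*input_padding_left=*/0,
      /*kernel_height=*/1, /*kernel_width=*/1,
      /*subsampling_height=*/1, /*subsampling_width=*/1,
      /*dilation_height=*/1, /*dilation_width=*/1,
      /*groups=*/channels,  // depthwise: one group per channel
      /*group_input_channels=*/1, /*group_output_channels=*/1,
      /*input_pixel_stride=*/channels, /*output_pixel_stride=*/channels,
      scale, shift,
      /*output_min=*/-std::numeric_limits<float>::infinity(),
      /*output_max=*/+std::numeric_limits<float>::infinity(),
      /*flags=*/0, conv_op_out);
}
```

At setup time the 2D case then maps to input_height = 1, with the row count carried by the width (or batch) argument, e.g. xnn_setup_convolution2d_nhwc_f32(op, batch_size, 1, num_rows, input, output, threadpool) in the late-2019 API.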

@RuABraun

Ah thanks, forgot to check the setup function.
