Implement Inverse(12) for CPU and CUDA #3485
Conversation
cudaMemcpyDeviceToHost));
for (auto i = 0; i < num_batches; ++i) {
  if (info_cpu[i] != 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Matrix is singular at batch:", i);
Curious: for the CPU version, does the Eigen implementation provide such a friendly message stating that the matrix is singular? #Resolved
No, it does not. I would have to do a full-pivoting LU decomposition to provide such checks, and the standard provides flexibility in that checking. Ke suggested I try a different approach for CUDA, so the code may still change.
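To make the singularity reporting concrete: the `info` array filled by cuBLAS `getrfBatched` holds, per matrix, 0 on success or the 1-based index of the first zero pivot. The following is a minimal plain-C++ sketch (hypothetical illustration, not code from this PR) of partial-pivoting LU factorization with the same reporting convention:

```cpp
#include <cmath>
#include <vector>

// Factor a row-major n x n matrix in place with partial pivoting.
// Returns 0 on success, or the 1-based index of the first zero pivot,
// mirroring the spirit of the cuBLAS getrfBatched `info` output.
int lu_factor_partial_pivot(std::vector<double>& a, int n) {
  for (int k = 0; k < n; ++k) {
    // Select the largest-magnitude pivot in column k.
    int pivot = k;
    for (int i = k + 1; i < n; ++i) {
      if (std::fabs(a[i * n + k]) > std::fabs(a[pivot * n + k])) pivot = i;
    }
    if (a[pivot * n + k] == 0.0) return k + 1;  // singular: zero pivot
    if (pivot != k) {
      for (int j = 0; j < n; ++j) std::swap(a[k * n + j], a[pivot * n + j]);
    }
    // Eliminate entries below the pivot.
    for (int i = k + 1; i < n; ++i) {
      a[i * n + k] /= a[k * n + k];
      for (int j = k + 1; j < n; ++j) {
        a[i * n + j] -= a[i * n + k] * a[k * n + j];
      }
    }
  }
  return 0;
}
```

Eigen's default `inverse()` path (partial-pivot LU) simply produces NaN/Inf values for a singular input rather than a diagnostic, which is why a friendlier message would need either this kind of explicit pivot check or full pivoting.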
int64_t num_batches = 1;
const int64_t rows = input_shape.GetDims()[num_dim - 2];
const int64_t cols = input_shape.GetDims()[num_dim - 1];
Do we need a check to enforce rows == cols (just like in the CUDA kernel)? #Pending
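As a sketch of the suggested check, something along these lines (hypothetical helper and struct names, plain C++, not the PR's actual code) could derive the batch count from the leading dimensions and enforce squareness on the CPU path too:

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical sketch: derive batch count and matrix dims from a shape,
// enforcing rows == cols as the review comment suggests.
struct MatrixShape {
  int64_t num_batches;
  int64_t rows;
  int64_t cols;
};

MatrixShape ValidateInverseShape(const std::vector<int64_t>& dims) {
  if (dims.size() < 2) throw std::invalid_argument("input rank must be >= 2");
  const size_t n = dims.size();
  int64_t num_batches = 1;
  // All dimensions except the trailing two contribute to the batch count.
  for (size_t i = 0; i + 2 < n; ++i) num_batches *= dims[i];
  const int64_t rows = dims[n - 2];
  const int64_t cols = dims[n - 1];
  if (rows != cols) throw std::invalid_argument("matrix must be square");
  return {num_batches, rows, cols};
}
```

For a rank-4 input of shape `[2, 3, 4, 4]` this yields 6 batches of 4x4 matrices and rejects any shape whose trailing dimensions differ.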
IAllocatorUniquePtr<double*> matrix_ptrs = inst->GetScratchBuffer<double*>(n_batches);
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<double>(input_workspace.get(), num_batches, rows, matrix_ptrs));
// Do LU factorization
CUBLAS_RETURN_IF_ERROR(cublasDgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
Why not use this instead: https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched? It seems to basically do getrfBatched + getriBatched (provided rows and cols are < 32). #Resolved
I think this is exactly what is implemented.
Ke also suggested checking https://docs.nvidia.com/cuda/cusolver/#cuds-intro
It looks like using cuSolver would require generating an identity matrix on the device for every invocation, and that means writing a kernel, which we are trying to avoid. Also, TF and PyTorch use cuBLAS, so this implementation should be good enough.
ORT_RETURN_IF_ERROR(ComputeMatrixOffsets<double>(input_workspace.get(), num_batches, rows, matrix_ptrs));
// Do LU factorization
CUBLAS_RETURN_IF_ERROR(cublasDgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
There is some external discussion as to whether this approach is performant when a single large matrix must be inverted (https://stackoverflow.com/questions/37731103/cublas-matrix-inverse-much-slower-than-matlab). Basically, the approach you are taking is well suited to inverting a batch of smaller matrices. Quoting the cuBLAS documentation: "This function is intended to be used for matrices of small sizes where the launch overhead is a significant factor." Perhaps there should be a plan to deal with a single large matrix. #Pending
My research indicates it is optimized for the small-matrix launch-overhead case as well as for solving large matrices.
Eigen::Map<const MatrixT<Eigen::half>> input_matrix(input_data, rows, cols);
Eigen::Map<MatrixT<Eigen::half>> output_matrix(output_data, rows, cols);
output_matrix = input_matrix.inverse();
Does Eigen have a limit on the size of the matrix's rows (and cols) by any chance? #Resolved
This (https://stackoverflow.com/questions/17430644/what-is-the-maximum-size-of-matrix-in-eigen) indicates that it is bounded only by the memory available on the system.
template <typename T>
using MatrixT = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
FWIW there's ConstEigenMatrixMapRowMajor and EigenMatrixMapRowMajor in math_cpuonly.h
namespace onnxruntime {
namespace test {
Needs some tests where there are batches, preferably including where multiple dimensions provide the number of batches (e.g. rank 4 input).
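A batched test along those lines could, for instance, verify that every 2x2 matrix in a rank-4 input round-trips through its inverse. The sketch below (plain C++, independent of the ORT test harness; the function name is hypothetical) uses the closed-form 2x2 inverse:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch of a batched check: invert each 2x2 matrix in a rank-4 tensor
// of shape [b0, b1, 2, 2] using the closed form, then verify A * inv(A) = I.
// Returns true if every matrix in the batch round-trips within `tol`.
bool BatchedInverseRoundTrips(const std::vector<double>& data,
                              std::size_t b0, std::size_t b1, double tol) {
  const std::size_t stride = 4;  // each 2x2 matrix holds 4 elements
  for (std::size_t b = 0; b < b0 * b1; ++b) {
    const double* m = &data[b * stride];
    const double det = m[0] * m[3] - m[1] * m[2];
    if (det == 0.0) return false;  // singular matrix in this batch
    const double inv[4] = {m[3] / det, -m[1] / det, -m[2] / det, m[0] / det};
    // Check A * inv(A) against the identity elementwise.
    const double prod[4] = {
        m[0] * inv[0] + m[1] * inv[2], m[0] * inv[1] + m[1] * inv[3],
        m[2] * inv[0] + m[3] * inv[2], m[2] * inv[1] + m[3] * inv[3]};
    const double id[4] = {1.0, 0.0, 0.0, 1.0};
    for (int i = 0; i < 4; ++i) {
      if (std::fabs(prod[i] - id[i]) > tol) return false;
    }
  }
  return true;
}
```

The point of the rank-4 shape is to exercise the batch-flattening logic: all leading dimensions must be multiplied into the batch count, not just the first.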
Description:
Matrix Inverse (or batch)