Skip to content

Commit

Permalink
Merge pull request #245 from laoniu85/master
Browse files Browse the repository at this point in the history
[#243] ....fix mistake of cublas<T>getrfBatched
  • Loading branch information
lebedov committed Apr 17, 2018
2 parents ac597c0 + b9d1fc4 commit 249538c
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions skcuda/cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5774,10 +5774,10 @@ def cublasSgetrfBatched(handle, n, A, lda, P, info, batchSize):
ctypes.c_void_p,
ctypes.c_int]
@_cublas_version_req(5.0)
def cublasCgetrfBatched(handle, n, A, lda, P, info, batchSize):
def cublasDgetrfBatched(handle, n, A, lda, P, info, batchSize):
"""
This function performs the LU factorization of an array of n x n matrices.
References
----------
`cublas<t>getrfBatched <http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-getrfbatched>`_
Expand All @@ -5788,7 +5788,6 @@ def cublasCgetrfBatched(handle, n, A, lda, P, info, batchSize):
int(info), batchSize)
cublasCheckStatus(status)


if _cublas_version >= 5000:
_libcublas.cublasCgetrfBatched.restype = int
_libcublas.cublasCgetrfBatched.argtypes = [_types.handle,
Expand All @@ -5799,6 +5798,31 @@ def cublasCgetrfBatched(handle, n, A, lda, P, info, batchSize):
ctypes.c_void_p,
ctypes.c_int]
@_cublas_version_req(5.0)
def cublasCgetrfBatched(handle, n, A, lda, P, info, batchSize):
"""
This function performs the LU factorization of an array of n x n matrices.
References
----------
`cublas<t>getrfBatched <http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-getrfbatched>`_
"""

status = _libcublas.cublasCgetrfBatched(handle, n,
int(A), lda, int(P),
int(info), batchSize)
cublasCheckStatus(status)


if _cublas_version >= 5000:
_libcublas.cublasZgetrfBatched.restype = int
_libcublas.cublasZgetrfBatched.argtypes = [_types.handle,
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int]
@_cublas_version_req(5.0)
def cublasZgetrfBatched(handle, n, A, lda, P, info, batchSize):
"""
This function performs the LU factorization of an array of n x n matrices.
Expand Down

0 comments on commit 249538c

Please sign in to comment.