Skip to content

Commit

Permalink
Merge pull request #8212 from jemiryguo/cutensorMG
Browse files Browse the repository at this point in the history
Add CutensorMg support
  • Loading branch information
asi1024 committed Mar 28, 2024
2 parents 150c903 + 269dbf5 commit a834f04
Show file tree
Hide file tree
Showing 9 changed files with 1,415 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cupy_backends/cuda/cupy_cutensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define INCLUDE_GUARD_CUDA_CUPY_CUTENSOR_H

#include <library_types.h>
#include <cutensor.h>
#include <cutensorMg.h>

#if CUTENSOR_VERSION < 10500

Expand Down
104 changes: 104 additions & 0 deletions cupy_backends/cuda/libs/cutensor.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ cpdef enum:
JIT_MODE_DEFAULT = 1, # NOQA, the corresponding plan will try to compile a dedicated kernel for the given operation. Only supported for GPUs with compute capability >= 8.0 (Ampere or newer).
JIT_MODE_ALL = 2 # NOQA, the corresponding plan will compile all the kernel candidates for the given contraction.

# cutensorMgHostDevice_t
CUTENSOR_MG_DEVICE_HOST = -1 # NOQA, regular memory on the host
CUTENSOR_MG_DEVICE_HOST_PINNED = -2 # NOQA, pinned memory on the host

# cutensorMgAlgo_t
CUTENSORMG_ALGO_DEFAULT = -1

# Version information
# Both functions take no arguments and are declared to return size_t.
# No `except` clause is declared on either of them.
cpdef size_t get_version()
cpdef size_t get_cudart_version()
Expand Down Expand Up @@ -240,3 +247,100 @@ cpdef reduce(

# Operation descriptor destruction.
# No C return type is declared, so this returns a Python object and exceptions
# raised inside propagate to the caller.
# NOTE(review): `desc` is presumably an operation-descriptor handle passed as
# an integer address -- confirm against the .pyx implementation.
cpdef destroyOperationDescriptor(intptr_t desc)

###############################################################################
# cutensorMg
###############################################################################

# MgHandle creation and destruction
# createMg takes a device count plus `devices` as an intptr_t and returns an
# intptr_t. `except? 0` marks 0 as a possible error return: when 0 comes back,
# Cython checks whether an exception is pending and propagates it if so.
# NOTE(review): `devices` is presumably the address of an array holding
# `numDevices` device ids -- confirm against the .pyx implementation.
cpdef intptr_t createMg(uint32_t numDevices, intptr_t devices) except? 0
# No C return type is declared, so destroyMg returns a Python object and
# exceptions propagate normally.
# NOTE(review): `handle` is presumably a value previously returned by createMg.
cpdef destroyMg(intptr_t handle)

# MgTensorDescriptor creation and destruction
# Returns an intptr_t; `except? 0` means a 0 return triggers a check for a
# pending exception.
# `numModes`, `numDevices` and `dataType` are passed by value; every other
# argument is an intptr_t.
# NOTE(review): `extent`, `elementStride`, `blockSize`, `blockStride`,
# `deviceCount` and `devices` are presumably addresses of caller-owned arrays
# (sized by `numModes` / `numDevices`) -- confirm element types and lengths
# against the .pyx implementation and the cuTENSORMg API.
# NOTE(review): `dataType` is a plain int, presumably a CUDA data type enum
# value -- confirm.
cpdef intptr_t createMgTensorDescriptor(
intptr_t handle,
uint32_t numModes,
intptr_t extent,
intptr_t elementStride,
intptr_t blockSize,
intptr_t blockStride,
intptr_t deviceCount,
uint32_t numDevices,
intptr_t devices,
int dataType) except? 0
# Returns a Python object (no C return type), so exceptions propagate normally.
cpdef destroyMgTensorDescriptor(intptr_t desc)

# MgCopyDescriptor creation and destruction
# Returns an intptr_t; `except? 0` means a 0 return triggers a check for a
# pending exception. All arguments are passed as intptr_t.
# NOTE(review): `descDst` / `descSrc` are presumably tensor-descriptor handles
# and `modesDst` / `modesSrc` addresses of mode arrays -- confirm against the
# .pyx implementation.
cpdef intptr_t createMgCopyDescriptor(
intptr_t handle,
intptr_t descDst,
intptr_t modesDst,
intptr_t descSrc,
intptr_t modesSrc) except? 0
# Returns a Python object (no C return type), so exceptions propagate normally.
cpdef destroyMgCopyDescriptor(intptr_t desc)

# Declared to return int64_t, with no `except` clause.
# NOTE(review): with a C return type and no exception value, Cython 0.29.x
# prints and swallows exceptions raised in the body instead of propagating
# them, while Cython 3 propagates by default. Confirm which applies to this
# build and whether an explicit exception value is wanted (any change must
# match the signature in the .pyx file).
# NOTE(review): `workspaceDeviceSize` is an intptr_t, presumably the address of
# an output array of per-device sizes, with the returned int64_t presumably the
# host workspace size -- confirm against the implementation.
cpdef int64_t getMgCopyWorkspace(
intptr_t handle,
intptr_t desc,
intptr_t workspaceDeviceSize)

# MgCopyPlan creation and destruction
# Returns an intptr_t; `except? 0` means a 0 return triggers a check for a
# pending exception.
# Note the asymmetry in the size arguments: `workspaceDeviceSize` is an
# intptr_t while `workspaceHostSize` is an int64_t passed by value.
# NOTE(review): `workspaceDeviceSize` is presumably the address of an array of
# per-device workspace sizes -- confirm against the .pyx implementation.
cpdef intptr_t createMgCopyPlan(
intptr_t handle,
intptr_t desc,
intptr_t workspaceDeviceSize,
int64_t workspaceHostSize) except? 0
# Returns a Python object (no C return type), so exceptions propagate normally.
cpdef destroyMgCopyPlan(intptr_t plan)

# copyMg
# Leading-underscore entry point; no C return type is declared, so it returns a
# Python object and exceptions propagate normally.
# Every argument is an intptr_t. The `const` on `ptrSrc` qualifies only the
# by-value integer itself, not any memory it may refer to.
# NOTE(review): `ptrDst`, `ptrSrc`, `workspaceDevice` and `_streams` are
# presumably addresses of per-device pointer/stream arrays, and `workspaceHost`
# a host buffer address -- confirm against the .pyx implementation.
cpdef _copyMg(
intptr_t handle, intptr_t plan,
intptr_t ptrDst, const intptr_t ptrSrc,
intptr_t workspaceDevice, intptr_t workspaceHost,
intptr_t _streams)

# MgContractionDescriptor creation and destruction
# Returns an intptr_t; `except? 0` means a 0 return triggers a check for a
# pending exception.
# Takes a (desc, modes) intptr_t pair for each of A, B, C and D, plus `compute`
# as a plain int.
# NOTE(review): the `desc*` arguments are presumably tensor-descriptor handles
# and the `modes*` arguments addresses of mode arrays; `compute` is presumably
# a compute-type enum value -- confirm against the .pyx implementation.
cpdef intptr_t createMgContractionDescriptor(
intptr_t handle,
intptr_t descA,
intptr_t modesA,
intptr_t descB,
intptr_t modesB,
intptr_t descC,
intptr_t modesC,
intptr_t descD,
intptr_t modesD,
int compute) except? 0
# Returns a Python object (no C return type), so exceptions propagate normally.
cpdef destroyMgContractionDescriptor(intptr_t desc)

# MgContractionFind creation and destruction
# Returns an intptr_t; `except? 0` means a 0 return triggers a check for a
# pending exception.
# NOTE(review): `algo` is a plain int, presumably a cutensorMgAlgo_t value such
# as CUTENSORMG_ALGO_DEFAULT -- confirm against the .pyx implementation.
cpdef intptr_t createMgContractionFind(
intptr_t handle,
int algo) except? 0
# Returns a Python object (no C return type), so exceptions propagate normally.
cpdef destroyMgContractionFind(intptr_t find)

# Declared to return int64_t, with no `except` clause.
# NOTE(review): with a C return type and no exception value, Cython 0.29.x
# prints and swallows exceptions raised in the body instead of propagating
# them, while Cython 3 propagates by default. Confirm which applies to this
# build and whether an explicit exception value is wanted (any change must
# match the signature in the .pyx file).
# NOTE(review): `preference` is a plain int, presumably a workspace-preference
# enum value; `workspaceDeviceSize` is presumably the address of an output
# array of per-device sizes, with the returned int64_t presumably the host
# workspace size -- confirm against the implementation.
cpdef int64_t getMgContractionWorkspace(
intptr_t handle,
intptr_t desc,
intptr_t find,
int preference,
intptr_t workspaceDeviceSize)

# MgContractionPlan creation and destruction
# Returns an intptr_t; `except? 0` means a 0 return triggers a check for a
# pending exception.
# Note the asymmetry in the size arguments: `workspaceDeviceSize` is an
# intptr_t while `workspaceHostSize` is an int64_t passed by value.
# NOTE(review): `desc` and `find` are presumably values returned by
# createMgContractionDescriptor / createMgContractionFind, and
# `workspaceDeviceSize` the address of an array of per-device workspace sizes
# -- confirm against the .pyx implementation.
cpdef intptr_t createMgContractionPlan(
intptr_t handle,
intptr_t desc,
intptr_t find,
intptr_t workspaceDeviceSize,
int64_t workspaceHostSize) except? 0
# Returns a Python object (no C return type), so exceptions propagate normally.
cpdef destroyMgContractionPlan(intptr_t plan)

# contractMg
# Leading-underscore entry point; no C return type is declared, so it returns a
# Python object and exceptions propagate normally.
# Every argument is an intptr_t. The `const` on `A`, `B` and `C` qualifies only
# the by-value integers themselves, not any memory they may refer to.
# NOTE(review): `alpha` / `beta` are presumably addresses of host scalars;
# `A`, `B`, `C`, `D`, `workspaceDevice` and `_streams` presumably addresses of
# per-device pointer/stream arrays; `workspaceHost` a host buffer address --
# confirm against the .pyx implementation.
cpdef _contractMg(
intptr_t handle, intptr_t plan,
intptr_t alpha, const intptr_t A,
const intptr_t B, intptr_t beta,
const intptr_t C, intptr_t D,
intptr_t workspaceDevice,
intptr_t workspaceHost,
intptr_t _streams)

0 comments on commit a834f04

Please sign in to comment.