
Commit

Fix cpu builds
agibsonccc committed Apr 20, 2023
1 parent d7247c5 commit 6779e8b
Showing 50 changed files with 574 additions and 1,159 deletions.
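
The recurring pattern in the diff below is widening dimension indices, counts, and axis parameters from int (or long long int) to sd::LongType, a 64-bit integer type. A plausible reason the CPU build broke is that helpers such as shape::isCommonVector now take a sd::LongType& out-parameter, so an int local can no longer bind to it. The following is a minimal, hypothetical sketch of that failure mode under those assumptions; the names and bodies are simplified stand-ins, not the real libnd4j declarations.

// Minimal, self-contained sketch (hypothetical names; not the actual libnd4j headers).
#include <cstdint>

using LongType = int64_t;  // stand-in for sd::LongType

// Assumed shape of the migrated 64-bit helper: the out-parameter is LongType&.
bool isCommonVector(const LongType* shapeInfo, LongType& posOfNonUnityDim) {
  posOfNonUnityDim = 0;          // dummy body, for illustration only
  return shapeInfo != nullptr;
}

int main() {
  const LongType shapeInfo[] = {2, 1, 3, 3, 1, 0, 1, 99};

  // Before this commit: an int lvalue cannot bind to a LongType& parameter,
  // so a CPU build compiling against the 64-bit helper fails.
  // int temp;
  // isCommonVector(shapeInfo, temp);   // error: no matching function call

  // After this commit: the local is widened to match the 64-bit API.
  LongType temp;
  isCommonVector(shapeInfo, temp);
  return 0;
}
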
3 changes: 2 additions & 1 deletion libnd4j/include/helpers/ConstantShapeHelper.h
@@ -70,7 +70,8 @@ class SD_LIB_EXPORT ConstantShapeHelper {
ConstantShapeBuffer* createShapeInfoWithNoUnitiesForReduce(const sd::LongType* maxShapeInfo,
const std::vector<sd::LongType>& dimsWithUnities,
sd::memory::Workspace* workspace = nullptr);
ConstantShapeBuffer* createSubArrShapeInfo(const sd::LongType* inShapeInfo, const LongType* dims, const int dimsSize,
ConstantShapeBuffer* createSubArrShapeInfo(const sd::LongType* inShapeInfo, const LongType* dims,
const LongType dimsSize,
sd::memory::Workspace* workspace = nullptr);

const sd::LongType* emptyShapeInfo(sd::DataType dataType);
4 changes: 2 additions & 2 deletions libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp
@@ -208,7 +208,7 @@ ConstantShapeBuffer* ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(inShapeInfo) - dimsWithUnities.size()),
sd::LongType);

int temp;
sd::LongType temp;
if (dimsWithUnities.size() == 1 && shape::isCommonVector(inShapeInfo, temp) && temp == dimsWithUnities[0]) {
auto dims = ShapeUtils::evalDimsToExclude(shape::rank(inShapeInfo), {temp});
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dims.data(), dims.size(), newShapeInfo);
@@ -227,7 +227,7 @@ ConstantShapeBuffer* ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(

////////////////////////////////////////////////////////////////////////
ConstantShapeBuffer* ConstantShapeHelper::createSubArrShapeInfo(const sd::LongType* inShapeInfo, const LongType* dims,
const int dimsSize, sd::memory::Workspace* workspace) {
const sd::LongType dimsSize, sd::memory::Workspace* workspace) {
sd::LongType* newShapeInfo = ShapeBuilders::createSubArrShapeInfo(inShapeInfo, dims, dimsSize, workspace);

ShapeDescriptor *descriptor = new ShapeDescriptor(newShapeInfo);
4 changes: 2 additions & 2 deletions libnd4j/include/helpers/cpu/MmulHelper.cpp
@@ -283,7 +283,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
if (Y != nullptr && X->dataType() != Y->dataType())
throw datatype_exception::build("mmulMxV expects all data types to be the same", A->dataType(), Y->dataType());

int xLenDim, yLenDim(0);
sd::LongType xLenDim, yLenDim(0);

if (A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxV: rank of A array is not equal 2 !");
if (!shape::isCommonVector(X->shapeInfo(), xLenDim))
@@ -356,7 +356,7 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
if (Z != nullptr && X->dataType() != Z->dataType())
throw datatype_exception::build("Dot expects all data types to be the same", X->dataType(), Z->dataType());

int xLenDim(0), yLenDim(0);
sd::LongType xLenDim(0), yLenDim(0);

if (!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot: X array must be vector !");
9 changes: 4 additions & 5 deletions libnd4j/include/helpers/cuda/ConstantShapeHelper.cu
@@ -206,14 +206,13 @@ ConstantShapeBuffer * ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcas
}

////////////////////////////////////////////////////////////////////////
ConstantShapeBuffer * ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(const sd::LongType* inShapeInfo,
const std::vector<int>& dimsWithUnities,
ConstantShapeBuffer * ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce(const sd::LongType* inShapeInfo, const std::vector<LongType>& dimsWithUnities,
sd::memory::Workspace* workspace) {
sd::LongType* newShapeInfo = nullptr;
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(inShapeInfo) - dimsWithUnities.size()),
sd::LongType);

int temp;
sd::LongType temp;
if (dimsWithUnities.size() == 1 && shape::isCommonVector(inShapeInfo, temp) && temp == dimsWithUnities[0]) {
auto dims = ShapeUtils::evalDimsToExclude(shape::rank(inShapeInfo), {temp});
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dims.data(), dims.size(), newShapeInfo);
@@ -231,8 +230,8 @@ ConstantShapeBuffer * ConstantShapeHelper::createShapeInfoWithNoUnitiesForReduce
}

////////////////////////////////////////////////////////////////////////
ConstantShapeBuffer *ConstantShapeHelper::createSubArrShapeInfo(const sd::LongType* inShapeInfo, const int* dims,
const int dimsSize, sd::memory::Workspace* workspace) {
ConstantShapeBuffer *ConstantShapeHelper::createSubArrShapeInfo(const sd::LongType* inShapeInfo, const LongType* dims,
const LongType dimsSize, sd::memory::Workspace* workspace) {
sd::LongType* newShapeInfo = ShapeBuilders::createSubArrShapeInfo(inShapeInfo, dims, dimsSize, workspace);

ShapeDescriptor *descriptor = new ShapeDescriptor(newShapeInfo);
12 changes: 6 additions & 6 deletions libnd4j/include/helpers/cuda/ConstantTadHelper.cu
@@ -45,23 +45,23 @@ ConstantTadHelper &ConstantTadHelper::getInstance() {
return instance;
}

TadPack ConstantTadHelper::tadForDimensions(const sd::LongType *originalShape, int dimension,
TadPack ConstantTadHelper::tadForDimensions(const sd::LongType *originalShape, LongType dimension,
const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
}

TadPack ConstantTadHelper::tadForDimensions(const sd::LongType *originalShape, const std::vector<int> &dimensions,
TadPack ConstantTadHelper::tadForDimensions(const sd::LongType *originalShape, const std::vector<LongType> &dimensions,
const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
return tadForDimensions(originalShape, const_cast<sd::LongType *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
}

TadPack ConstantTadHelper::tadForDimensions(const sd::LongType *originalShape, int *dimensions, int dimLength,
TadPack ConstantTadHelper::tadForDimensions(const sd::LongType *originalShape, LongType *dimensions, LongType dimLength,
const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor);
}

TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions,
TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<LongType> &dimensions,
const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape);
return tadForDimensions(tadDescriptor);
@@ -75,7 +75,7 @@ TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
if (_cache[deviceId].count(descriptor) == 0) {
const auto shapeInfo = descriptor.originalShape().toShapeInfo();
const int rank = shape::rank(shapeInfo);
const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis());
const std::vector<sd::LongType > dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis());
const sd::LongType numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
const int subArrRank =
(rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
45 changes: 21 additions & 24 deletions libnd4j/include/helpers/cuda_off/MmulHelper.cu
@@ -45,14 +45,14 @@ static SD_KERNEL void usualCudaGemm(const void* vA, const sd::LongType* aShapeIn
const T2* B = reinterpret_cast<const T2*>(vB);
T3* C = reinterpret_cast<T3*>(vC);

__shared__ int K, *coords;
__shared__ sd::LongType K, *coords;
__shared__ bool betaPresent;
__shared__ sd::LongType cLen, totalThreads;
__shared__ T3 alphaZ, betaZ;

if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
coords = reinterpret_cast<int*>(shmem);
coords = reinterpret_cast<sd::LongType *>(shmem);
cLen = shape::length(cShapeInfo);

K = shape::shapeOf(const_cast<sd::LongType*>(aShapeInfo))[aKaxis];
@@ -364,7 +364,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
// MXN x N = M
NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y, const double alpha, const double beta,
const char outOrder) {
int xLenDim, yLenDim(0);
sd::LongType xLenDim, yLenDim(0);

if (A->rankOf() != 2) throw std::runtime_error("MmulHelper::mmulMxV cuda: rank of A array is not equal 2 !");
if (!shape::isCommonVector(X->shapeInfo(), xLenDim))
@@ -468,7 +468,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
////////////////////////////////////////////////////////////////////////////
// (X * Y) = Z[0]
NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, const double alpha, const double beta) {
int xLenDim(0), yLenDim(0);
sd::LongType xLenDim(0), yLenDim(0);

if (!shape::isCommonVector(X->shapeInfo(), xLenDim))
throw std::runtime_error("MmulHelper::dot cuda: X array must be vector !");
@@ -527,21 +527,22 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, sd::NDArray* Z, con
template <typename T1, typename T2, typename T3>
static SD_KERNEL void batchedCudaGemm(const void* vA, const sd::LongType* aShapeInfo, const void* vB,
const sd::LongType* bShapeInfo, void* vC, const sd::LongType* cShapeInfo,
const int* aBatchDims, const int* bBatchDims, const int* cBatchDims,
const int aMaxis, const int aKaxis, const int bKaxis, const int bNaxis,
const int cMaxis, const int cNaxis, const double alpha, const double beta) {
const LongType* aBatchDims, const LongType* bBatchDims,
const LongType* cBatchDims, const LongType aMaxis, const LongType aKaxis,
const LongType bKaxis, const LongType bNaxis, const LongType cMaxis,
const LongType cNaxis, const double alpha, const double beta) {
const T1* A = reinterpret_cast<const T1*>(vA);
const T2* B = reinterpret_cast<const T2*>(vB);
T3* C = reinterpret_cast<T3*>(vC);

__shared__ bool betaPresent;
__shared__ int aRank, bRank, cRank, K, *coords;
__shared__ sd::LongType aRank, bRank, cRank, K, *coords;
__shared__ sd::LongType cLen, totalThreads;
__shared__ T3 alphaZ, betaZ;

if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
coords = reinterpret_cast<int*>(shmem);
coords = reinterpret_cast<sd::LongType *>(shmem);
cLen = shape::length(cShapeInfo);

K = shape::shapeOf(const_cast<sd::LongType*>(aShapeInfo))[aKaxis];
@@ -607,9 +608,9 @@ template <typename T1, typename T2, typename T3>
SD_HOST static void batchedGemm(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem,
cudaStream_t* stream, const void* vA, const sd::LongType* aShapeInfo, const void* vB,
const sd::LongType* bShapeInfo, void* vC, const sd::LongType* cShapeInfo,
const int* aBatchDims, const int* bBatchDims, const int* cBatchDims, const int aMaxis,
const int aKaxis, const int bKaxis, const int bNaxis, const int cMaxis,
const int cNaxis, const double alpha, const double beta) {
const LongType* aBatchDims, const LongType* bBatchDims, const LongType* cBatchDims,
const LongType aMaxis, const LongType aKaxis, const LongType bKaxis,
const LongType bNaxis, const LongType cMaxis, const LongType cNaxis, const double alpha, const double beta) {
batchedCudaGemm<T1, T2, T3><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(
vA, aShapeInfo, vB, bShapeInfo, vC, cShapeInfo, aBatchDims, bBatchDims, cBatchDims, aMaxis, aKaxis, bKaxis,
bNaxis, cMaxis, cNaxis, alpha, beta);
@@ -618,8 +619,8 @@ SD_HOST static void batchedGemm(const int blocksPerGrid, const int threadsPerBlo
///////////////////////////////////////////////////////////////////
NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, const double alpha, const double beta,
const char outOrder) {
const int aRank = A->rankOf();
const int bRank = B->rankOf();
const sd::LongType aRank = A->rankOf();
const sd::LongType bRank = B->rankOf();

// input ranks validation
if (aRank > bRank && bRank != 2)
@@ -651,9 +652,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con

if (C->isEmpty()) return C;

const int cRank = C->rankOf();
const sd::LongType cRank = C->rankOf();

const int aMaxis(aRank - 2), aKaxis(aRank - 1), bKaxis(bRank - 2), bNaxis(bRank - 1), cMaxis(cRank - 2),
const sd::LongType aMaxis(aRank - 2), aKaxis(aRank - 1), bKaxis(bRank - 2), bNaxis(bRank - 1), cMaxis(cRank - 2),
cNaxis(cRank - 1);

const int threadsPerBlock = SD_MAX_NUM_THREADS / 8;
@@ -662,23 +663,19 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con

PointersManager manager(A->getContext(), "MmulHelper::mmulNxN");

const int *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr);
const sd::LongType *aBatchDims(nullptr), *bBatchDims(nullptr), *cBatchDims(nullptr);

if (aRank > 2)
aBatchDims = reinterpret_cast<int*>(manager.replicatePointer(
aBatchDims = reinterpret_cast<sd::LongType *>(manager.replicatePointer(
ShapeUtils::evalDimsToExclude(aRank, {aMaxis, aKaxis}).data(), (aRank - 2) * sizeof(int)));
if (bRank > 2)
bBatchDims = reinterpret_cast<int*>(manager.replicatePointer(
bBatchDims = reinterpret_cast<sd::LongType *>(manager.replicatePointer(
ShapeUtils::evalDimsToExclude(bRank, {bKaxis, bNaxis}).data(), (bRank - 2) * sizeof(int)));
if (cRank > 2)
cBatchDims = reinterpret_cast<int*>(manager.replicatePointer(
cBatchDims = reinterpret_cast<sd::LongType *>(manager.replicatePointer(
ShapeUtils::evalDimsToExclude(cRank, {cMaxis, cNaxis}).data(), (cRank - 2) * sizeof(int)));

NDArray::prepareSpecialUse({C}, {A, B});
// BUILD_TRIPLE_SELECTOR(A->dataType(), b->dataType(), C->dataType(), batchedGemm, (blocksPerGrid, threadsPerBlock,
// A->getContext()->getCudaStream(), A->specialBuffer(), A->specialShapeInfo(), B->specialBuffer(),
// B->specialShapeInfo(), C->specialBuffer(), C->special(), aMaxis, aKaxis, bKaxis, bNaxis, cMaxis, cNaxis, alpha,
// beta), SD_NUMERIC_TYPES, SD_NUMERIC_TYPES, SD_FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(
A->dataType(), batchedGemm,
(blocksPerGrid, threadsPerBlock, sharedMem, A->getContext()->getCudaStream(), A->specialBuffer(),
2 changes: 1 addition & 1 deletion libnd4j/include/helpers/impl/MmulHelper.cpp
@@ -227,7 +227,7 @@ NDArray* sd::MmulHelper::tensorDot(const sd::NDArray* a, const sd::NDArray* b,
//////////////////////////////////////////////////////////////////////////
sd::NDArray* MmulHelper::mmul(const sd::NDArray* A, const sd::NDArray* B, sd::NDArray* C, const double alpha,
const double beta, const char outOrder) {
int lenDim;
sd::LongType lenDim;
const int aRank = A->rankOf();
const int bRank = B->rankOf();
const bool isAVector = shape::isCommonVector(A->shapeInfo(), lenDim);
3 changes: 1 addition & 2 deletions libnd4j/include/helpers/shape.h
@@ -2422,8 +2422,7 @@ SD_LIB_EXPORT SD_INLINE SD_HOST_DEVICE bool isEmpty(const sd::LongType *shapeInf
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array
// (already stored in maxIdxs)
SD_LIB_EXPORT SD_INLINE SD_HOST_DEVICE void maxIndToMinInd(sd::LongType *maxIdxs, sd::LongType *minIdxs, const sd::LongType *maxShapeInfo,
const sd::LongType *minShapeInfo, const sd::LongType *dimsToExclude,
int dimsLen) {
const sd::LongType *minShapeInfo, const sd::LongType *dimsToExclude, long long int dimsLen) {
const auto maxRank = shape::rank(maxShapeInfo);
const auto minRank = shape::rank(minShapeInfo);

2 changes: 1 addition & 1 deletion libnd4j/include/legacy/NativeOpExecutioner.h
@@ -458,7 +458,7 @@ class SD_LIB_EXPORT NativeOpExecutioner {
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortGeneric(x, xShapeInfo, descending), SD_COMMON_TYPES);
}

static void execSort(void *x, const sd::LongType *xShapeInfo, long long int *dimension, int dimensionLength,
static void execSort(void *x, const sd::LongType *xShapeInfo, long long int *dimension, long long int dimensionLength,
const sd::LongType *tadShapeInfo, const sd::LongType *tadOffsets, bool descending) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo);

11 changes: 6 additions & 5 deletions libnd4j/include/legacy/NativeOps.h
@@ -1366,16 +1366,17 @@ SD_LIB_EXPORT void sortByValue(sd::Pointer* extraPointers, void* x, sd::LongType
sd::LongType const* dyShapeInfo, bool descending);

SD_LIB_EXPORT void sortTad(sd::Pointer* extraPointers, void* hX, sd::LongType const* hXShapeInfo, void* dX,
sd::LongType const* dXShapeInfo, long long int* dimension, int dimensionLength,
sd::LongType const* dXShapeInfo,sd::LongType * dimension, sd::LongType dimensionLength,
sd::LongType const* tadShapeInfo, sd::LongType const* tadOffsets, bool descending);

SD_LIB_EXPORT void sortTadByKey(sd::Pointer* extraPointers, void* x, sd::LongType const* xShapeInfo, void* dx,
sd::LongType const* dxShapeInfo, void* y, sd::LongType const* yShapeInfo, void* dy,
sd::LongType const* dyShapeInfo, long long int* dimension, int dimensionLength, bool descending);
SD_LIB_EXPORT void sortTadByKey(sd::Pointer* extraPointers, void* x, sd::LongType const* xShapeInfo, void* dX,
sd::LongType const* dXShapeInfo, void* y, sd::LongType const* yShapeInfo, void* dy,
sd::LongType const* dyShapeInfo, sd::LongType * dimension, long long int dimensionLength, bool descending);

SD_LIB_EXPORT void sortTadByValue(sd::Pointer* extraPointers, void* x, sd::LongType const* xShapeInfo, void* dx,
sd::LongType const* dxShapeInfo, void* y, sd::LongType const* yShapeInfo, void* dy,
sd::LongType const* dyShapeInfo, long long int* dimension, int dimensionLength,
sd::LongType const* dyShapeInfo, sd::LongType * dimension,
sd::LongType dimensionLength,
bool descending);

// special sort impl for sorting out COO indices and values
