Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rcorbish #8617 #8188

Merged
merged 1 commit into from Sep 1, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -53,7 +53,8 @@
* JCublas lapack
*
* @author Adam Gibson
* @author Richard Corbishley
* @author Richard Corbishley (signed)
*
*/
@Slf4j
public class JcublasLapack extends BaseLapack {
Expand All @@ -70,7 +71,6 @@ public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {
if (A.ordering() == 'c')
a = A.dup('f');


if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

Expand Down Expand Up @@ -193,7 +193,7 @@ public void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) {


//=========================
// Q R DECOMP
// Q R DECOMP
@Override
public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
INDArray a = A;
Expand Down Expand Up @@ -306,8 +306,8 @@ public void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
if (r != null && r != R)
R.assign(r);

log.info("A: {}", A);
if (R != null) log.info("R: {}", R);
log.debug("A: {}", A);
if (R != null) log.debug("R: {}", R);
}

@Override
Expand Down Expand Up @@ -419,16 +419,18 @@ public void dgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
if (r != null && r != R)
R.assign(r);

log.info("A: {}", A);
if (R != null) log.info("R: {}", R);
log.debug("A: {}", A);
if (R != null) log.debug("R: {}", R);
}

//=========================
// CHOLESKY DECOMP
@Override
public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) {
public void spotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
INDArray a = A;

int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;

if (A.dataType() != DataType.FLOAT)
log.warn("FLOAT potrf called for " + A.dataType());

Expand Down Expand Up @@ -489,7 +491,7 @@ public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) {
if (a != A)
A.assign(a);

if (uplo == 'U') {
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
A.assign(A.transpose());
INDArrayIndex ix[] = new INDArrayIndex[2];
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
Expand All @@ -506,13 +508,15 @@ public void spotrf(byte uplo, int N, INDArray A, INDArray INFO) {
}
}

log.info("A: {}", A);
log.debug("A: {}", A);
}

@Override
public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) {
public void dpotrf(byte _uplo, int N, INDArray A, INDArray INFO) {
INDArray a = A;

int uplo = _uplo == 'L' ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;

if (A.dataType() != DataType.DOUBLE)
log.warn("DOUBLE potrf called for " + A.dataType());

Expand Down Expand Up @@ -573,7 +577,7 @@ public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) {
if (a != A)
A.assign(a);

if (uplo == 'U') {
if (uplo == CUBLAS_FILL_MODE_UPPER ) {
A.assign(A.transpose());
INDArrayIndex ix[] = new INDArrayIndex[2];
for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
Expand All @@ -590,7 +594,7 @@ public void dpotrf(byte uplo, int N, INDArray A, INDArray INFO) {
}
}

log.info("A: {}", A);
log.debug("A: {}", A);
}


Expand Down