Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complex (Hermitian) eigenvalues/eigenvectors #105

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ generated-sources: \
$(LAPACK)/ilaenv.f \
$(LAPACK)/[sd]geqrf.f $(LAPACK)/[sd]ormqr.f \
$(LAPACK)/[sd]orgqr.f \
$(LAPACK)/[sd]sygvx.f
$(LAPACK)/[sd]sygvx.f \
$(LAPACK)/[cz]heev.f
ant javah
touch $@

Expand Down
162 changes: 161 additions & 1 deletion src/main/c/NativeBlas.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ static ComplexDouble getComplexDouble(JNIEnv *env, jobject dc)
/**********************************************************************/

static char *routine_names[] = {
"CAXPY", "CCOPY", "CDOTC", "CDOTU", "CGEEV", "CGEMM", "CGEMV", "CGERC", "CGERU", "CGESVD", "CSCAL", "CSSCAL", "CSWAP", "DASUM", "DAXPY", "DCOPY", "DDOT", "DGEEV", "DGELSD", "DGEMM", "DGEMV", "DGEQRF", "DGER", "DGESV", "DGESVD", "DGETRF", "DNRM2", "DORGQR", "DORMQR", "DPOSV", "DPOTRF", "DSCAL", "DSWAP", "DSYEV", "DSYEVD", "DSYEVR", "DSYEVX", "DSYGVD", "DSYGVX", "DSYSV", "DZASUM", "DZNRM2", "ICAMAX", "IDAMAX", "ILAENV", "ISAMAX", "IZAMAX", "SASUM", "SAXPY", "SCASUM", "SCNRM2", "SCOPY", "SDOT", "SGEEV", "SGELSD", "SGEMM", "SGEMV", "SGEQRF", "SGER", "SGESV", "SGESVD", "SGETRF", "SNRM2", "SORGQR", "SORMQR", "SPOSV", "SPOTRF", "SSCAL", "SSWAP", "SSYEV", "SSYEVD", "SSYEVR", "SSYEVX", "SSYGVD", "SSYGVX", "SSYSV", "ZAXPY", "ZCOPY", "ZDOTC", "ZDOTU", "ZDSCAL", "ZGEEV", "ZGEMM", "ZGEMV", "ZGERC", "ZGERU", "ZGESVD", "ZSCAL", "ZSWAP", 0
"CAXPY", "CCOPY", "CDOTC", "CDOTU", "CGEEV", "CGEMM", "CGEMV", "CGERC", "CGERU", "CGESVD", "CHEEV", "CSCAL", "CSSCAL", "CSWAP", "DASUM", "DAXPY", "DCOPY", "DDOT", "DGEEV", "DGELSD", "DGEMM", "DGEMV", "DGEQRF", "DGER", "DGESV", "DGESVD", "DGETRF", "DNRM2", "DORGQR", "DORMQR", "DPOSV", "DPOTRF", "DSCAL", "DSWAP", "DSYEV", "DSYEVD", "DSYEVR", "DSYEVX", "DSYGVD", "DSYGVX", "DSYSV", "DZASUM", "DZNRM2", "ICAMAX", "IDAMAX", "ILAENV", "ISAMAX", "IZAMAX", "SASUM", "SAXPY", "SCASUM", "SCNRM2", "SCOPY", "SDOT", "SGEEV", "SGELSD", "SGEMM", "SGEMV", "SGEQRF", "SGER", "SGESV", "SGESVD", "SGETRF", "SNRM2", "SORGQR", "SORMQR", "SPOSV", "SPOTRF", "SSCAL", "SSWAP", "SSYEV", "SSYEVD", "SSYEVR", "SSYEVX", "SSYGVD", "SSYGVX", "SSYSV", "ZAXPY", "ZCOPY", "ZDOTC", "ZDOTU", "ZDSCAL", "ZGEEV", "ZGEMM", "ZGEMV", "ZGERC", "ZGERU", "ZGESVD", "ZHEEV", "ZSCAL", "ZSWAP", 0
};

static char *routine_arguments[][23] = {
Expand All @@ -116,6 +116,7 @@ static char *routine_arguments[][23] = {
{ "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" },
{ "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" },
{ "JOBU", "JOBVT", "M", "N", "A", "LDA", "S", "U", "LDU", "VT", "LDVT", "WORK", "LWORK", "RWORK", "INFO" },
{ "JOBZ", "UPLO", "N", "A", "LDA", "W", "WORK", "LWORK", "RWORK", "INFO" },
{ "N", "CA", "CX", "INCX" },
{ "N", "SA", "CX", "INCX" },
{ "N", "CX", "INCX", "CY", "INCY" },
Expand Down Expand Up @@ -193,6 +194,7 @@ static char *routine_arguments[][23] = {
{ "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" },
{ "M", "N", "ALPHA", "X", "INCX", "Y", "INCY", "A", "LDA" },
{ "JOBU", "JOBVT", "M", "N", "A", "LDA", "S", "U", "LDU", "VT", "LDVT", "WORK", "LWORK", "RWORK", "INFO" },
{ "JOBZ", "UPLO", "N", "A", "LDA", "W", "WORK", "LWORK", "RWORK", "INFO" },
{ "N", "ZA", "ZX", "INCX" },
{ "N", "ZX", "INCX", "ZY", "INCY" },
};
Expand Down Expand Up @@ -5167,3 +5169,161 @@ JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_ssygvx(JNIEnv *env, jclass this
return info;
}

JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_cheev(JNIEnv *env, jclass this, jchar jobz, jchar uplo, jint n, jfloatArray a, jint aIdx, jint lda, jfloatArray w, jint wIdx, jfloatArray work, jint workIdx, jint lwork, jfloatArray rwork, jint rworkIdx)
{
extern void cheev_(char *, char *, jint *, jfloat *, jint *, jfloat *, jfloat *, jint *, jfloat *, int *);

char jobzChr = (char) jobz;
char uploChr = (char) uplo;
jfloat *rworkPtrBase = 0, *rworkPtr = 0;
if (rwork) {
rworkPtrBase = (*env)->GetFloatArrayElements(env, rwork, NULL);
rworkPtr = rworkPtrBase + rworkIdx;
}
jfloat *aPtrBase = 0, *aPtr = 0;
if (a) {
if((*env)->IsSameObject(env, a, rwork) == JNI_TRUE)
aPtrBase = rworkPtrBase;
else
aPtrBase = (*env)->GetFloatArrayElements(env, a, NULL);
aPtr = aPtrBase + 2*aIdx;
}
jfloat *wPtrBase = 0, *wPtr = 0;
if (w) {
if((*env)->IsSameObject(env, w, rwork) == JNI_TRUE)
wPtrBase = rworkPtrBase;
else
if((*env)->IsSameObject(env, w, a) == JNI_TRUE)
wPtrBase = aPtrBase;
else
wPtrBase = (*env)->GetFloatArrayElements(env, w, NULL);
wPtr = wPtrBase + wIdx;
}
jfloat *workPtrBase = 0, *workPtr = 0;
if (work) {
if((*env)->IsSameObject(env, work, rwork) == JNI_TRUE)
workPtrBase = rworkPtrBase;
else
if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
workPtrBase = aPtrBase;
else
if((*env)->IsSameObject(env, work, w) == JNI_TRUE)
workPtrBase = wPtrBase;
else
workPtrBase = (*env)->GetFloatArrayElements(env, work, NULL);
workPtr = workPtrBase + 2*workIdx;
}
int info;

cheev_(&jobzChr, &uploChr, &n, aPtr, &lda, wPtr, workPtr, &lwork, rworkPtr, &info);
if(workPtrBase) {
(*env)->ReleaseFloatArrayElements(env, work, workPtrBase, 0);
if (workPtrBase == rworkPtrBase)
rworkPtrBase = 0;
if (workPtrBase == aPtrBase)
aPtrBase = 0;
if (workPtrBase == wPtrBase)
wPtrBase = 0;
workPtrBase = 0;
}
if(wPtrBase) {
(*env)->ReleaseFloatArrayElements(env, w, wPtrBase, 0);
if (wPtrBase == rworkPtrBase)
rworkPtrBase = 0;
if (wPtrBase == aPtrBase)
aPtrBase = 0;
wPtrBase = 0;
}
if(aPtrBase) {
(*env)->ReleaseFloatArrayElements(env, a, aPtrBase, 0);
if (aPtrBase == rworkPtrBase)
rworkPtrBase = 0;
aPtrBase = 0;
}
if(rworkPtrBase) {
(*env)->ReleaseFloatArrayElements(env, rwork, rworkPtrBase, JNI_ABORT);
rworkPtrBase = 0;
}

return info;
}

JNIEXPORT jint JNICALL Java_org_jblas_NativeBlas_zheev(JNIEnv *env, jclass this, jchar jobz, jchar uplo, jint n, jdoubleArray a, jint aIdx, jint lda, jdoubleArray w, jint wIdx, jdoubleArray work, jint workIdx, jint lwork, jdoubleArray rwork, jint rworkIdx)
{
extern void zheev_(char *, char *, jint *, jdouble *, jint *, jdouble *, jdouble *, jint *, jdouble *, int *);

char jobzChr = (char) jobz;
char uploChr = (char) uplo;
jdouble *rworkPtrBase = 0, *rworkPtr = 0;
if (rwork) {
rworkPtrBase = (*env)->GetDoubleArrayElements(env, rwork, NULL);
rworkPtr = rworkPtrBase + rworkIdx;
}
jdouble *aPtrBase = 0, *aPtr = 0;
if (a) {
if((*env)->IsSameObject(env, a, rwork) == JNI_TRUE)
aPtrBase = rworkPtrBase;
else
aPtrBase = (*env)->GetDoubleArrayElements(env, a, NULL);
aPtr = aPtrBase + 2*aIdx;
}
jdouble *wPtrBase = 0, *wPtr = 0;
if (w) {
if((*env)->IsSameObject(env, w, rwork) == JNI_TRUE)
wPtrBase = rworkPtrBase;
else
if((*env)->IsSameObject(env, w, a) == JNI_TRUE)
wPtrBase = aPtrBase;
else
wPtrBase = (*env)->GetDoubleArrayElements(env, w, NULL);
wPtr = wPtrBase + wIdx;
}
jdouble *workPtrBase = 0, *workPtr = 0;
if (work) {
if((*env)->IsSameObject(env, work, rwork) == JNI_TRUE)
workPtrBase = rworkPtrBase;
else
if((*env)->IsSameObject(env, work, a) == JNI_TRUE)
workPtrBase = aPtrBase;
else
if((*env)->IsSameObject(env, work, w) == JNI_TRUE)
workPtrBase = wPtrBase;
else
workPtrBase = (*env)->GetDoubleArrayElements(env, work, NULL);
workPtr = workPtrBase + 2*workIdx;
}
int info;

zheev_(&jobzChr, &uploChr, &n, aPtr, &lda, wPtr, workPtr, &lwork, rworkPtr, &info);
if(workPtrBase) {
(*env)->ReleaseDoubleArrayElements(env, work, workPtrBase, 0);
if (workPtrBase == rworkPtrBase)
rworkPtrBase = 0;
if (workPtrBase == aPtrBase)
aPtrBase = 0;
if (workPtrBase == wPtrBase)
wPtrBase = 0;
workPtrBase = 0;
}
if(wPtrBase) {
(*env)->ReleaseDoubleArrayElements(env, w, wPtrBase, 0);
if (wPtrBase == rworkPtrBase)
rworkPtrBase = 0;
if (wPtrBase == aPtrBase)
aPtrBase = 0;
wPtrBase = 0;
}
if(aPtrBase) {
(*env)->ReleaseDoubleArrayElements(env, a, aPtrBase, 0);
if (aPtrBase == rworkPtrBase)
rworkPtrBase = 0;
aPtrBase = 0;
}
if(rworkPtrBase) {
(*env)->ReleaseDoubleArrayElements(env, rwork, rworkPtrBase, JNI_ABORT);
rworkPtrBase = 0;
}

return info;
}

16 changes: 16 additions & 0 deletions src/main/c/org_jblas_NativeBlas.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

122 changes: 120 additions & 2 deletions src/main/java/org/jblas/Eigen.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
*/
public class Eigen {
private static final DoubleMatrix dummyDouble = new DoubleMatrix(1);
private static final ComplexDoubleMatrix dummyComplexDouble = new ComplexDoubleMatrix(1);

/**
* Compute the eigenvalues for a symmetric matrix.
Expand Down Expand Up @@ -87,7 +88,6 @@ public static ComplexDoubleMatrix eigenvalues(DoubleMatrix A) {
DoubleMatrix WR = new DoubleMatrix(A.rows);
DoubleMatrix WI = WR.dup();
SimpleBlas.geev('N', 'N', A.dup(), WR, WI, dummyDouble, dummyDouble);

return new ComplexDoubleMatrix(WR, WI);
}

Expand Down Expand Up @@ -306,10 +306,70 @@ public static DoubleMatrix[] symmetricGeneralizedEigenvectors(DoubleMatrix A, Do
return result;
}

/**
* Computes the eigenvalues of a complex matrix.
*/
public static ComplexDoubleMatrix eigenvalues(ComplexDoubleMatrix A) {
A.assertSquare();
ComplexDoubleMatrix W = new ComplexDoubleMatrix(A.rows);
SimpleBlas.cgeev('N', 'N', A.dup(), W, dummyComplexDouble, dummyComplexDouble);
return W;
}

/**
* Computes the eigenvalues and eigenvectors of a complex matrix.
*
* @return an array of ComplexDoubleMatrix objects containing the (right) eigenvectors
* stored as the columns of the first matrix, and the eigenvalues as the
* diagonal elements of the second matrix.
*/
public static ComplexDoubleMatrix[] eigenvectors(ComplexDoubleMatrix A) {
A.assertSquare();
// setting up result arrays
ComplexDoubleMatrix W = new ComplexDoubleMatrix(A.rows);
ComplexDoubleMatrix VR = new ComplexDoubleMatrix(A.rows, A.rows);

SimpleBlas.cgeev('N', 'V', A.dup(), W, dummyComplexDouble, VR);
return new ComplexDoubleMatrix[]{VR, ComplexDoubleMatrix.diag(W)};
}

/**
* Computes the eigenvalues of a complex Hermitian matrix.
*
* Assumes that the input is an Hermitian matrix.
*/
public static DoubleMatrix hermitianEigenvalues(ComplexDoubleMatrix A) {
A.assertSquare();
DoubleMatrix W = new DoubleMatrix(A.rows);
SimpleBlas.cheev('N', 'U', A.dup(), W);
return W;
}

/**
* Computes the eigenvalues and eigenvectors of a complex Hermitian matrix.
*
* Assumes that the input is an Hermitian matrix.
*
* @return an array of ComplexDoubleMatrix objects containing the orthonormal eigenvectors
* stored as the columns of the first matrix, and the eigenvalues (in ascending order)
* as the diagonal elements of the second matrix.
*/
public static ComplexDoubleMatrix[] hermitianEigenvectors(ComplexDoubleMatrix A) {
A.assertSquare();
// setting up result arrays
DoubleMatrix W = new DoubleMatrix(A.rows);
ComplexDoubleMatrix eigenvectors = A.dup();

SimpleBlas.cheev('V', 'U', eigenvectors, W);
return new ComplexDoubleMatrix[]{eigenvectors, ComplexDoubleMatrix.diag(W.toComplex())};
}


//BEGIN
// The code below has been automatically generated.
// DO NOT EDIT!
private static final FloatMatrix dummyFloat = new FloatMatrix(1);
private static final ComplexFloatMatrix dummyComplexFloat = new ComplexFloatMatrix(1);

/**
* Compute the eigenvalues for a symmetric matrix.
Expand Down Expand Up @@ -346,7 +406,6 @@ public static ComplexFloatMatrix eigenvalues(FloatMatrix A) {
FloatMatrix WR = new FloatMatrix(A.rows);
FloatMatrix WI = WR.dup();
SimpleBlas.geev('N', 'N', A.dup(), WR, WI, dummyFloat, dummyFloat);

return new ComplexFloatMatrix(WR, WI);
}

Expand Down Expand Up @@ -565,5 +624,64 @@ public static FloatMatrix[] symmetricGeneralizedEigenvectors(FloatMatrix A, Floa
return result;
}

/**
* Computes the eigenvalues of a complex matrix.
*/
public static ComplexFloatMatrix eigenvalues(ComplexFloatMatrix A) {
A.assertSquare();
ComplexFloatMatrix W = new ComplexFloatMatrix(A.rows);
SimpleBlas.cgeev('N', 'N', A.dup(), W, dummyComplexFloat, dummyComplexFloat);
return W;
}

/**
* Computes the eigenvalues and eigenvectors of a complex matrix.
*
* @return an array of ComplexFloatMatrix objects containing the (right) eigenvectors
* stored as the columns of the first matrix, and the eigenvalues as the
* diagonal elements of the second matrix.
*/
public static ComplexFloatMatrix[] eigenvectors(ComplexFloatMatrix A) {
A.assertSquare();
// setting up result arrays
ComplexFloatMatrix W = new ComplexFloatMatrix(A.rows);
ComplexFloatMatrix VR = new ComplexFloatMatrix(A.rows, A.rows);

SimpleBlas.cgeev('N', 'V', A.dup(), W, dummyComplexFloat, VR);
return new ComplexFloatMatrix[]{VR, ComplexFloatMatrix.diag(W)};
}

/**
* Computes the eigenvalues of a complex Hermitian matrix.
*
* Assumes that the input is an Hermitian matrix.
*/
public static FloatMatrix hermitianEigenvalues(ComplexFloatMatrix A) {
A.assertSquare();
FloatMatrix W = new FloatMatrix(A.rows);
SimpleBlas.cheev('N', 'U', A.dup(), W);
return W;
}

/**
* Computes the eigenvalues and eigenvectors of a complex Hermitian matrix.
*
* Assumes that the input is an Hermitian matrix.
*
* @return an array of ComplexFloatMatrix objects containing the orthonormal eigenvectors
* stored as the columns of the first matrix, and the eigenvalues (in ascending order)
* as the diagonal elements of the second matrix.
*/
public static ComplexFloatMatrix[] hermitianEigenvectors(ComplexFloatMatrix A) {
A.assertSquare();
// setting up result arrays
FloatMatrix W = new FloatMatrix(A.rows);
ComplexFloatMatrix eigenvectors = A.dup();

SimpleBlas.cheev('V', 'U', eigenvectors, W);
return new ComplexFloatMatrix[]{eigenvectors, ComplexFloatMatrix.diag(W.toComplex())};
}


//END
}
Loading