From 7548e342dd208dd0f9e4522e431d6a6dbe660f53 Mon Sep 17 00:00:00 2001 From: Nathaniel Graham Date: Tue, 21 Jul 2015 16:14:51 -0600 Subject: [PATCH 1/2] Java bindings for alltoallw functions. Includes bindings for MPI_ALLTOALLW and MPI_IALLTOALLW. Signed-off-by: Nathaniel Graham --- ompi/mpi/java/c/mpiJava.h | 9 +++- ompi/mpi/java/c/mpi_Comm.c | 76 ++++++++++++++++++++++++++++++ ompi/mpi/java/c/mpi_MPI.c | 24 ++++++++++ ompi/mpi/java/java/Comm.java | 91 ++++++++++++++++++++++++++++++++++++ 4 files changed, 199 insertions(+), 1 deletion(-) diff --git a/ompi/mpi/java/c/mpiJava.h b/ompi/mpi/java/c/mpiJava.h index 45f11dd83dd..ad4ce8eb018 100644 --- a/ompi/mpi/java/c/mpiJava.h +++ b/ompi/mpi/java/c/mpiJava.h @@ -9,6 +9,8 @@ * University of Stuttgart. All rights reserved. * Copyright (c) 2004-2005 The Regents of the University of California. * All rights reserved. + * Copyright (c) 2015 Los Alamos National Security, LLC. All rights + * reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -95,7 +97,7 @@ void ompi_java_releaseReadPtr( /* Gets a buffer pointer for writing. */ void ompi_java_getWritePtr( void **ptr, ompi_java_buffer_t **item, JNIEnv *env, - jobject buf, jboolean db, int count, MPI_Datatype type); + jobject buf, jboolean db, int count, MPI_Datatype type); /* Gets a buffer pointer for writing. * 'size' is the number of processes. */ @@ -133,6 +135,11 @@ void ompi_java_releaseIntArray( void ompi_java_forgetIntArray( JNIEnv *env, jintArray array, jint *jptr, int *cptr); +void ompi_java_getDatatypeArray( + JNIEnv *env, jlongArray array, jlong **jptr, MPI_Datatype **cptr); +void ompi_java_forgetDatatypeArray( + JNIEnv *env, jlongArray array, jlong *jptr, MPI_Datatype *cptr); + void ompi_java_getBooleanArray( JNIEnv *env, jbooleanArray array, jboolean **jptr, int **cptr); void ompi_java_releaseBooleanArray( diff --git a/ompi/mpi/java/c/mpi_Comm.c b/ompi/mpi/java/c/mpi_Comm.c index 6ebbcbf9698..b8627369690 100644 --- a/ompi/mpi/java/c/mpi_Comm.c +++ b/ompi/mpi/java/c/mpi_Comm.c @@ -1581,6 +1581,82 @@ JNIEXPORT jlong JNICALL Java_mpi_Comm_iAllToAllv( return (jlong)request; } +JNIEXPORT void JNICALL Java_mpi_Comm_allToAllw( + JNIEnv *env, jobject jthis, jlong jComm, + jobject sendBuf, jintArray sCount, jintArray sDispls, jlongArray sTypes, + jobject recvBuf, jintArray rCount, jintArray rDispls, jlongArray rTypes) +{ + MPI_Comm comm = (MPI_Comm)jComm; + + jlong* jSTypes, *jRTypes; + MPI_Datatype *cSTypes, *cRTypes; + + ompi_java_getDatatypeArray(env, sTypes, &jSTypes, &cSTypes); + ompi_java_getDatatypeArray(env, rTypes, &jRTypes, &cRTypes); + + jint *jSCount, *jRCount, *jSDispls, *jRDispls; + int *cSCount, *cRCount, *cSDispls, *cRDispls; + ompi_java_getIntArray(env, sCount, &jSCount, &cSCount); + ompi_java_getIntArray(env, rCount, &jRCount, &cRCount); + ompi_java_getIntArray(env, sDispls, &jSDispls, &cSDispls); + ompi_java_getIntArray(env, rDispls, &jRDispls, &cRDispls); + + void *sPtr = ompi_java_getDirectBufferAddress(env, sendBuf), + *rPtr = ompi_java_getDirectBufferAddress(env, recvBuf); + + int rc = MPI_Alltoallw( + sPtr, cSCount, cSDispls, cSTypes, + rPtr, cRCount, cRDispls, cRTypes, comm); + + ompi_java_exceptionCheck(env, rc); + ompi_java_forgetIntArray(env, sCount, jSCount, cSCount); + ompi_java_forgetIntArray(env, rCount, jRCount, cRCount); + ompi_java_forgetIntArray(env, sDispls, jSDispls, cSDispls); + ompi_java_forgetIntArray(env, rDispls, jRDispls, cRDispls); + ompi_java_forgetDatatypeArray(env, sTypes, jSTypes, cSTypes); + ompi_java_forgetDatatypeArray(env, rTypes, jRTypes, cRTypes); +} + +JNIEXPORT jlong JNICALL Java_mpi_Comm_iAllToAllw( + JNIEnv *env, jobject jthis, jlong jComm, + jobject sendBuf, jintArray sCount, jintArray sDispls, jlongArray sTypes, + jobject recvBuf, jintArray rCount, jintArray rDispls, jlongArray rTypes) +{ + MPI_Comm comm = (MPI_Comm)jComm; + + jlong* jSTypes, *jRTypes; + MPI_Datatype *cSTypes, *cRTypes; + + ompi_java_getDatatypeArray(env, sTypes, &jSTypes, &cSTypes); + ompi_java_getDatatypeArray(env, rTypes, &jRTypes, &cRTypes); + + jint *jSCount, *jRCount, *jSDispls, *jRDispls; + int *cSCount, *cRCount, *cSDispls, *cRDispls; + ompi_java_getIntArray(env, sCount, &jSCount, &cSCount); + ompi_java_getIntArray(env, rCount, &jRCount, &cRCount); + ompi_java_getIntArray(env, sDispls, &jSDispls, &cSDispls); + ompi_java_getIntArray(env, rDispls, &jRDispls, &cRDispls); + + void *sPtr = ompi_java_getDirectBufferAddress(env, sendBuf), + *rPtr = ompi_java_getDirectBufferAddress(env, recvBuf); + + MPI_Request request; + + int rc = MPI_Ialltoallw( + sPtr, cSCount, cSDispls, cSTypes, + rPtr, cRCount, cRDispls, cRTypes, comm, &request); + + ompi_java_exceptionCheck(env, rc); + ompi_java_forgetIntArray(env, sCount, jSCount, cSCount); + ompi_java_forgetIntArray(env, rCount, jRCount, cRCount); + ompi_java_forgetIntArray(env, sDispls, jSDispls, cSDispls); + ompi_java_forgetIntArray(env, rDispls, jRDispls, cRDispls); + ompi_java_forgetDatatypeArray(env, sTypes, jSTypes, cSTypes); + ompi_java_forgetDatatypeArray(env, rTypes, jRTypes, cRTypes); + + return (jlong)request; +} + JNIEXPORT void JNICALL Java_mpi_Comm_neighborAllGather( JNIEnv *env, jobject jthis, jlong jComm, jobject sBuf, jboolean sdb, jint sOff, diff --git a/ompi/mpi/java/c/mpi_MPI.c b/ompi/mpi/java/c/mpi_MPI.c index fcdd4966b53..23e9dd2feca 100644 --- a/ompi/mpi/java/c/mpi_MPI.c +++ b/ompi/mpi/java/c/mpi_MPI.c @@ -975,6 +975,30 @@ void ompi_java_forgetIntArray(JNIEnv *env, jintArray array, (*env)->ReleaseIntArrayElements(env, array, jptr, JNI_ABORT); } +void ompi_java_getDatatypeArray(JNIEnv *env, jlongArray array, + jlong **jptr, MPI_Datatype **cptr) +{ + jlong *jLongs = (*env)->GetLongArrayElements(env, array, NULL); + *jptr = jLongs; + + int i, length = (*env)->GetArrayLength(env, array); + MPI_Datatype *cDatatypes = calloc(length, sizeof(MPI_Datatype)); + + for(i = 0; i < length; i++){ + cDatatypes[i] = (MPI_Datatype)jLongs[i]; + } + *cptr = cDatatypes; +} + +void ompi_java_forgetDatatypeArray(JNIEnv *env, jlongArray array, + jlong *jptr, MPI_Datatype *cptr) +{ + if(jptr != cptr) + free(cptr); + + (*env)->ReleaseLongArrayElements(env, array, jptr, JNI_ABORT); +} + void ompi_java_getBooleanArray(JNIEnv *env, jbooleanArray array, jboolean **jptr, int **cptr) { diff --git a/ompi/mpi/java/java/Comm.java b/ompi/mpi/java/java/Comm.java index b52fb7060e5..3fe3d1f7f8d 100644 --- a/ompi/mpi/java/java/Comm.java +++ b/ompi/mpi/java/java/Comm.java @@ -229,6 +229,7 @@ public static int compare(Comm comm1, Comm comm2) throws MPIException /** * Test if communicator object is null (has been freed). + * Java binding of {@code MPI_COMM_NULL}. * @return true if the comm object is null, false otherwise */ public final boolean isNull() @@ -2307,6 +2308,79 @@ private native long iAllToAllv(long comm, Buffer recvbuf, int[] recvcount, int[] rdispls, long recvtype) throws MPIException; +/** + * Adds flexibility to {@code allToAll}: location of data for send is //here + * specified by {@code sDispls} and location to place data on receive + * side is specified by {@code rDispls}. + *

Java binding of the MPI operation {@code MPI_ALLTOALLW}. + * @param sendBuf send buffer + * @param sendCount number of items sent to each buffer + * @param sDispls displacements from which to take outgoing data + * @param sendTypes datatypes of send buffer items + * @param recvBuf receive buffer + * @param recvCount number of elements received from each process + * @param rDispls displacements at which to place incoming data + * @param recvTypes datatype of each item in receive buffer + * @throws MPIException Signals that an MPI exception of some sort has occurred. + */ +public final void allToAllw( + Buffer sendBuf, int[] sendCount, int[] sDispls, Datatype[] sendTypes, + Buffer recvBuf, int[] recvCount, int[] rDispls, Datatype[] recvTypes) + throws MPIException +{ + MPI.check(); + assertDirectBuffer(sendBuf, recvBuf); + + long[] sendHandles = convertTypeArray(sendTypes); + long[] recvHandles = convertTypeArray(recvTypes); + + allToAllw(handle, sendBuf, sendCount, sDispls, + sendHandles, recvBuf, recvCount, rDispls, + recvHandles); +} + +private native void allToAllw(long comm, + Buffer sendBuf, int[] sendCount, int[] sDispls, long[] sendTypes, + Buffer recvBuf, int[] recvCount, int[] rDispls, long[] recvTypes) + throws MPIException; + +/** + * Adds flexibility to {@code iAllToAll}: location of data for send is + * specified by {@code sDispls} and location to place data on receive + * side is specified by {@code rDispls}. + *

Java binding of the MPI operation {@code MPI_IALLTOALLW}. + * @param sendBuf send buffer + * @param sendCount number of items sent to each buffer + * @param sDispls displacements from which to take outgoing data + * @param sendTypes datatype send buffer items + * @param recvBuf receive buffer + * @param recvCount number of elements received from each process + * @param rDispls displacements at which to place incoming data + * @param recvTypes datatype of each item in receive buffer + * @return communication request + * @throws MPIException Signals that an MPI exception of some sort has occurred. + */ +public final Request iAllToAllw( + Buffer sendBuf, int[] sendCount, int[] sDispls, Datatype[] sendTypes, + Buffer recvBuf, int[] recvCount, int[] rDispls, Datatype[] recvTypes) + throws MPIException +{ + MPI.check(); + assertDirectBuffer(sendBuf, recvBuf); + + long[] sendHandles = convertTypeArray(sendTypes); + long[] recvHandles = convertTypeArray(recvTypes); + + return new Request(iAllToAllw( + handle, sendBuf, sendCount, sDispls, sendHandles, + recvBuf, recvCount, rDispls, recvHandles)); +} + +private native long iAllToAllw(long comm, + Buffer sendBuf, int[] sendCount, int[] sDispls, long[] sendTypes, + Buffer recvBuf, int[] recvCount, int[] rDispls, long[] recvTypes) + throws MPIException; + /** * Java binding of {@code MPI_NEIGHBOR_ALLGATHER}. * @param sendbuf send buffer @@ -3232,4 +3306,21 @@ public final String getName() throws MPIException private native String getName(long handle) throws MPIException; +/** + * A helper method to convert an array of Datatypes to + * an array of longs (handles). + * @param dArray Array of Datatypes + * @return converted Datatypes + */ +private long[] convertTypeArray(Datatype[] dArray) { + long[] lArray = new long[dArray.length]; + + for(int i = 0; i < lArray.length; i++) { + if(dArray[i] != null) { + lArray[i] = dArray[i].handle; + } + } + return lArray; +} + } // Comm From 11e1f09c25ab6de06f296fce7801762b817dd3a5 Mon Sep 17 00:00:00 2001 From: Nathaniel Graham Date: Tue, 21 Jul 2015 16:39:59 -0600 Subject: [PATCH 2/2] White space fixes Signed-off-by: Nathaniel Graham --- ompi/mpi/java/c/mpiJava.h | 2 +- ompi/mpi/java/c/mpi_Comm.c | 8 ++++---- ompi/mpi/java/java/Comm.java | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ompi/mpi/java/c/mpiJava.h b/ompi/mpi/java/c/mpiJava.h index ad4ce8eb018..4f98a9db622 100644 --- a/ompi/mpi/java/c/mpiJava.h +++ b/ompi/mpi/java/c/mpiJava.h @@ -97,7 +97,7 @@ void ompi_java_releaseReadPtr( /* Gets a buffer pointer for writing. */ void ompi_java_getWritePtr( void **ptr, ompi_java_buffer_t **item, JNIEnv *env, - jobject buf, jboolean db, int count, MPI_Datatype type); + jobject buf, jboolean db, int count, MPI_Datatype type); /* Gets a buffer pointer for writing. * 'size' is the number of processes. */ diff --git a/ompi/mpi/java/c/mpi_Comm.c b/ompi/mpi/java/c/mpi_Comm.c index b8627369690..ccc4294babc 100644 --- a/ompi/mpi/java/c/mpi_Comm.c +++ b/ompi/mpi/java/c/mpi_Comm.c @@ -1605,8 +1605,8 @@ JNIEXPORT void JNICALL Java_mpi_Comm_allToAllw( *rPtr = ompi_java_getDirectBufferAddress(env, recvBuf); int rc = MPI_Alltoallw( - sPtr, cSCount, cSDispls, cSTypes, - rPtr, cRCount, cRDispls, cRTypes, comm); + sPtr, cSCount, cSDispls, cSTypes, + rPtr, cRCount, cRDispls, cRTypes, comm); ompi_java_exceptionCheck(env, rc); ompi_java_forgetIntArray(env, sCount, jSCount, cSCount); @@ -1643,8 +1643,8 @@ JNIEXPORT jlong JNICALL Java_mpi_Comm_iAllToAllw( MPI_Request request; int rc = MPI_Ialltoallw( - sPtr, cSCount, cSDispls, cSTypes, - rPtr, cRCount, cRDispls, cRTypes, comm, &request); + sPtr, cSCount, cSDispls, cSTypes, + rPtr, cRCount, cRDispls, cRTypes, comm, &request); ompi_java_exceptionCheck(env, rc); ompi_java_forgetIntArray(env, sCount, jSCount, cSCount); diff --git a/ompi/mpi/java/java/Comm.java b/ompi/mpi/java/java/Comm.java index 3fe3d1f7f8d..97b681cb812 100644 --- a/ompi/mpi/java/java/Comm.java +++ b/ompi/mpi/java/java/Comm.java @@ -2336,7 +2336,7 @@ public final void allToAllw( allToAllw(handle, sendBuf, sendCount, sDispls, sendHandles, recvBuf, recvCount, rDispls, - recvHandles); + recvHandles); } private native void allToAllw(long comm,