Skip to content

Commit

Permalink
add MPI::Comm#Allreduce method
Browse files Browse the repository at this point in the history
  • Loading branch information
seiya committed Apr 21, 2011
1 parent 60ba60e commit b17de54
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
23 changes: 23 additions & 0 deletions ext/mpi/mpi.c
Expand Up @@ -438,6 +438,28 @@ rb_comm_reduce(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_op, VALU
return Qnil;
}
static VALUE
rb_comm_allreduce(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_op)
{
void *sendbuf, *recvbuf;
int sendcount, recvcount;
MPI_Datatype sendtype, recvtype;
int rank, size;
struct _Comm *comm;
struct _Op *op;
OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_rank(comm->comm, &rank));
check_error(MPI_Comm_size(comm->comm, &size));
OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
if (recvcount != sendcount)
rb_raise(rb_eArgError, "sendbuf and recvbuf has the same length");
if (recvtype != sendtype)
rb_raise(rb_eArgError, "sendbuf and recvbuf has the same type");
Data_Get_Struct(rb_op, struct _Op, op);
check_error(MPI_Allreduce(sendbuf, recvbuf, recvcount, recvtype, op->op, comm->comm));
return Qnil;
}
static VALUE
rb_comm_get_Errhandler(VALUE self)
{
struct _Comm *comm;
Expand Down Expand Up @@ -537,6 +559,7 @@ void Init_mpi()
rb_define_method(cComm, "Scatter", rb_comm_scatter, 3);
rb_define_method(cComm, "Alltoall", rb_comm_alltoall, 2);
rb_define_method(cComm, "Reduce", rb_comm_reduce, 4);
rb_define_method(cComm, "Allreduce", rb_comm_allreduce, 3);
rb_define_method(cComm, "Errhandler", rb_comm_get_Errhandler, 0);
rb_define_method(cComm, "Errhandler=", rb_comm_set_Errhandler, 1);

Expand Down
12 changes: 12 additions & 0 deletions spec/ruby-mpi_spec.rb
Expand Up @@ -153,6 +153,18 @@
end
end

it "should reduce data and send to all processes (allreduce)" do
world = MPI::Comm::WORLD
rank = world.rank
size = world.size
bufsize = 2
sendbuf = NArray.to_na([rank]*bufsize)
recvbuf = NArray.new(sendbuf.typecode,bufsize)
world.Allreduce(sendbuf, recvbuf, MPI::Op::SUM)
ary = NArray.new(sendbuf.typecode,bufsize).fill(size*(size-1)/2.0)
recvbuf.should == ary
end




Expand Down

0 comments on commit b17de54

Please sign in to comment.