remove MPIRequest.wait() #672

Merged: 5 commits, Sep 22, 2020
CHANGELOG.md (1 addition, 0 deletions)

@@ -43,6 +43,7 @@
 - [#666](https://github.com/helmholtz-analytics/heat/pull/666) New feature: distributed prepend/append for diff().
 - [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter
 - [#670](https://github.com/helmholtz-analytics/heat/pull/670) New Feature: `bincount()`
+- [#672](https://github.com/helmholtz-analytics/heat/pull/672) Bug/Enhancement: remove `MPIRequest.wait()` and rewrite calls to use the capitalized `Wait()`; lowercase `wait()` now falls back to the `mpi4py` method

 # v0.4.0
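To make the changelog entry concrete: heat's capitalized `Wait()` also copies a staged receive buffer back into a CUDA tensor when MPI is not CUDA-aware, while lowercase `wait()` now resolves to the raw `mpi4py` request. A minimal sketch of the two call paths, assuming a 2-process run with heat's `MPI_WORLD` communicator (the buffer contents and script name are illustrative, not part of this PR):

```python
# Illustrative sketch, not part of the PR. Run with: mpirun -np 2 python demo.py
import torch
from heat.core.communication import MPI_WORLD

comm = MPI_WORLD
if comm.rank == 0:
    req = comm.Isend(torch.arange(5.0), dest=1)
    req.Wait()   # heat's MPIRequest.Wait(): waits, then copies staged recv
                 # buffers back into CUDA tensors when MPI is not CUDA-aware
else:
    buf = torch.zeros(5)
    req = comm.Irecv(buf, source=0)
    req.wait()   # after this PR: passes through to the mpi4py handle, so no
                 # copy-back happens; on GPU tensors prefer Wait()
```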
heat/core/arithmetics.py (2 additions, 2 deletions)

@@ -387,7 +387,7 @@ def diff(a, n=1, axis=-1, prepend=None, append=None):
         ret.lloc[diff_slice] = dif

         if rank > 0:
-            snd.wait()  # wait for the send to finish
+            snd.Wait()  # wait for the send to finish
         if rank < size - 1:
             cr_slice = [slice(None)] * len(a.shape)
             # slice of 1 element in the selected axis for the shape creation
@@ -399,7 +399,7 @@
             axis_slice_end = [slice(None)] * len(a.shape)
             # select the last elements in the selected axis
             axis_slice_end[axis] = slice(-1, None)
-            rec.wait()
+            rec.Wait()
             # diff logic
             ret.lloc[axis_slice_end] = (
                 recv_data.reshape(ret.lloc[axis_slice_end].shape) - ret.lloc[axis_slice_end]
heat/core/communication.py (0 additions, 12 deletions)

@@ -1103,18 +1103,6 @@ def Wait(self, status=None):
                 self.recvbuf = self.recvbuf.permute(self.permutation)
             self.tensor.copy_(self.recvbuf)

-    def wait(self, status=None):
-        self.handle.wait(status)
-        if (
-            self.tensor is not None
-            and isinstance(self.tensor, torch.Tensor)
-            and self.tensor.is_cuda
-            and not CUDA_AWARE_MPI
-        ):
-            if self.permutation is not None:
-                self.recvbuf = self.recvbuf.permute(self.permutation)
-            self.tensor.copy_(self.recvbuf)
-
     def __getattr__(self, name):
         """
         Default pass-through for the communicator methods.
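With the lowercase method deleted, an attribute lookup for `wait` falls through to the `__getattr__` pass-through retained below the deletion. A condensed sketch of that delegation pattern (abridged for illustration, not heat's full class):

```python
# Condensed sketch of the delegation pattern (abridged, not the full heat class).
class MPIRequest:
    def __init__(self, handle, tensor=None, recvbuf=None, permutation=None):
        self.handle = handle              # underlying mpi4py request
        self.tensor = tensor              # target torch tensor, if any
        self.recvbuf = recvbuf            # staged host-side receive buffer
        self.permutation = permutation    # axis order to undo after receiving

    def Wait(self, status=None):
        # heat-specific: wait, then copy recvbuf back into a CUDA tensor
        # when MPI is not CUDA-aware (see the retained Wait() in the diff)
        ...

    def __getattr__(self, name):
        # anything not defined here (including lowercase wait()) resolves
        # on the raw mpi4py request
        return getattr(self.handle, name)
```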
heat/core/dndarray.py (2 additions, 2 deletions)

@@ -370,7 +370,7 @@ def get_halo(self, halo_size):
             req_list.append(self.comm.Irecv(res_next, source=prev_rank))

         for req in req_list:
-            req.wait()
+            req.Wait()

         self.__halo_next = res_prev
         self.__halo_prev = res_next
@@ -2792,7 +2792,7 @@ def resplit_(self, axis=None):
             lp_arr = None
             for k in lp_keys:
                 if rcv[k][0] is not None:
-                    rcv[k][0].wait()
+                    rcv[k][0].Wait()
                 if lp_arr is None:
                     lp_arr = rcv[k][1]
                 else:
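The `get_halo` hunk follows the standard non-blocking idiom: post every `Irecv`/`Isend` first, then block on all requests. A pure `mpi4py` sketch of that idiom (not heat's wrapper; the capitalized calls here are mpi4py's own buffer-based API, and the array shapes are illustrative):

```python
# Pure mpi4py sketch of the post-then-wait halo idiom (not heat's wrapper).
# Run with: mpirun -np 2 python halo.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.rank, comm.size
prev_rank = (rank - 1) % size
next_rank = (rank + 1) % size

local = np.full(4, rank, dtype=np.float64)
halo_prev = np.empty(1, dtype=np.float64)   # ghost cell from the left neighbor
halo_next = np.empty(1, dtype=np.float64)   # ghost cell from the right neighbor

# post all requests before waiting so no rank blocks on an unmatched message
req_list = [
    comm.Irecv(halo_prev, source=prev_rank),
    comm.Irecv(halo_next, source=next_rank),
    comm.Isend(local[:1], dest=prev_rank),
    comm.Isend(local[-1:], dest=next_rank),
]
for req in req_list:
    req.Wait()   # buffer-based requests use mpi4py's capitalized API
```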
heat/core/linalg/basics.py (9 additions, 9 deletions)

@@ -364,7 +364,7 @@ def matmul(a, b, allow_resplit=False):
     if any(lshape_map[:, 0, :][:, 1] == 1):
         a_d1_1s_flag = True

-    index_map_comm.wait()
+    index_map_comm.Wait()
     for pr in range(a.comm.size):
         start0 = index_map[pr, 0, 0, 0].item()
         stop0 = index_map[pr, 0, 0, 1].item()
@@ -382,7 +382,7 @@ def matmul(a, b, allow_resplit=False):
                 a_block_map[pr, dim0, dim1] = torch.tensor(
                     (dim0 * mB, dim1 * kB), dtype=torch.int, device=a._DNDarray__array.device
                 )
-    rem_map_comm.wait()
+    rem_map_comm.Wait()
     if b.split == 0:
         # the blocks are shifted in the 2nd dimension of A for as many remainders
         # there are between the blocks in the first dim of B
@@ -440,7 +440,7 @@ def matmul(a, b, allow_resplit=False):
             b_block_map[:, cnt:, :, 0] += 1

     # work loop: loop over all processes (also will incorporate the remainder calculations)
-    c_wait.wait()
+    c_wait.Wait()

     if split_0_flag:
         # need to send b here and not a
@@ -484,7 +484,7 @@ def matmul(a, b, allow_resplit=False):

             # receive the data from the last loop and do the calculation with that
             if pr != 0:
-                req[pr - 1].wait()
+                req[pr - 1].Wait()
                 # after receiving the last loop's bcast
                 __mm_c_block_setter(
                     b_proc=pr - 1,
@@ -518,7 +518,7 @@ def matmul(a, b, allow_resplit=False):

             # need to wait if its the last loop, also need to collect the remainders
             if pr == b.comm.size - 1:
-                req[pr].wait()
+                req[pr].Wait()
                 __mm_c_block_setter(
                     b_proc=pr,
                     a_proc=a.comm.rank,
@@ -610,7 +610,7 @@ def matmul(a, b, allow_resplit=False):
             # receive the data from the last loop and do the calculation with that
             if pr != 0:
                 # after receiving the last loop's bcast
-                req[pr - 1].wait()
+                req[pr - 1].Wait()
                 __mm_c_block_setter(
                     a_proc=pr - 1,
                     b_proc=b.comm.rank,
@@ -645,7 +645,7 @@ def matmul(a, b, allow_resplit=False):

             # need to wait if its the last loop, also need to collect the remainders
             if pr == b.comm.size - 1:
-                req[pr].wait()
+                req[pr].Wait()
                 __mm_c_block_setter(
                     a_proc=pr,
                     b_proc=a.comm.rank,
@@ -706,7 +706,7 @@ def matmul(a, b, allow_resplit=False):

             # receive the data from the last loop and do the calculation with that
             if pr != 0:
-                req[pr - 1].wait()
+                req[pr - 1].Wait()
                 # after receiving the last loop's bcast
                 st0 = index_map[pr - 1, 0, 0, 0].item()
                 sp0 = index_map[pr - 1, 0, 0, 1].item() + 1
@@ -717,7 +717,7 @@
                 del b_lp_data[pr - 1]

             if pr == b.comm.size - 1:
-                req[pr].wait()
+                req[pr].Wait()
                 st0 = index_map[pr, 0, 0, 0].item()
                 sp0 = index_map[pr, 0, 0, 1].item() + 1
                 st1 = index_map[pr, 1, 1, 0].item()
heat/core/linalg/qr.py (4 additions, 4 deletions)

@@ -671,7 +671,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
     if col in q_dict_waits.keys():
         for key in q_dict_waits[col].keys():
             new_key = q_dict_waits[col][key][3] + key + "e"
-            q_dict_waits[col][key][0][1].wait()
+            q_dict_waits[col][key][0][1].Wait()
             q_dict[col][new_key] = [
                 q_dict_waits[col][key][0][0],
                 q_dict_waits[col][key][1].wait(),
@@ -728,7 +728,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
     for pr in range(diag_process, active_procs[-1] + 1):
         if local_merge_q[pr][1] is not None:
             # receive q from the other processes
-            local_merge_q[pr][1].wait()
+            local_merge_q[pr][1].Wait()
         if rank in active_procs:
             sum_row = sum(q0_tiles.tile_rows_per_process[:pr])
             end_row = q0_tiles.tile_rows_per_process[pr] + sum_row
@@ -790,7 +790,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
                 )
                 for ind in qi_mult[qi_col]:
                     if global_merge_dict[ind][1] is not None:
-                        global_merge_dict[ind][1].wait()
+                        global_merge_dict[ind][1].Wait()
                     lp_q = global_merge_dict[ind][0]
                     if mult_qi_col.shape[1] < lp_q.shape[1]:
                         new_mult = torch.zeros(
@@ -810,7 +810,7 @@ def __split0_q_loop(col, r_tiles, proc_tile_start, active_procs, q0_tiles, q_dic
             q0_tiles.arr.lloc[:, write_inds[2] : write_inds[2] + hold.shape[1]] = hold
     else:
         for ind in merge_dict_keys:
-            global_merge_dict[ind][1].wait()
+            global_merge_dict[ind][1].Wait()
     if col in q_dict.keys():
         del q_dict[col]

heat/core/manipulations.py (6 additions, 6 deletions)

@@ -318,8 +318,8 @@ def concatenate(arrays, axis=0):
             chunk_map[arr0.comm.rank, i] = chk[i].stop - chk[i].start
         chunk_map_comm = arr0.comm.Iallreduce(MPI.IN_PLACE, chunk_map, MPI.SUM)

-        lshape_map_comm.wait()
-        chunk_map_comm.wait()
+        lshape_map_comm.Wait()
+        chunk_map_comm.Wait()

        if s0 is not None:
            send_slice = [slice(None)] * arr0.ndim
@@ -342,7 +342,7 @@
                        tag=pr + arr0.comm.size + spr,
                    )
                    arr0._DNDarray__array = arr0.lloc[keep_slice].clone()
-                   send.wait()
+                   send.Wait()
                for pr in range(spr):
                    snt = abs((chunk_map[pr, s0] - lshape_map[0, pr, s0]).item())
                    snt = (
@@ -389,7 +389,7 @@
                        tag=pr + arr1.comm.size + spr,
                    )
                    arr1._DNDarray__array = arr1.lloc[keep_slice].clone()
-                   send.wait()
+                   send.Wait()
                for pr in range(arr1.comm.size - 1, spr, -1):
                    snt = abs((chunk_map[pr, axis] - lshape_map[1, pr, axis]).item())
                    snt = (
@@ -2355,9 +2355,9 @@ def resplit(arr, axis=None):
                buf = torch.zeros_like(new_tiles[key])
                rcv_waits[key] = [arr.comm.Irecv(buf=buf, source=spr, tag=spr), buf]
    for w in waits:
-       w.wait()
+       w.Wait()
    for k in rcv_waits.keys():
-       rcv_waits[k][0].wait()
+       rcv_waits[k][0].Wait()
        new_tiles[k] = rcv_waits[k][1]

    return new_arr