
Features/187 balance #308

Merged: 10 commits (Jun 18, 2019)
152 changes: 149 additions & 3 deletions heat/core/dndarray.py
@@ -5,6 +5,7 @@
from . import arithmetics
from . import devices
from . import exponential
from . import factories
from . import indexing
from . import io
from . import linalg
@@ -483,6 +484,135 @@ def astype(self, dtype, copy=True):

return self

def balance_(self):
"""
Function for balancing a DNDarray between all nodes. To determine whether balancing is needed, use the is_balanced() function.
If the DNDarray is already balanced, this function does nothing. The DNDarray is modified in place and nothing is returned.

Examples
--------
>>> a = ht.zeros((10, 2), split=0)
>>> a[:, 0] = ht.arange(10)
>>> b = a[3:]
>>> b
[0/2] tensor([[3., 0.]])
[1/2] tensor([[4., 0.],
              [5., 0.],
              [6., 0.]])
[2/2] tensor([[7., 0.],
              [8., 0.],
              [9., 0.]])
>>> print(b.gshape, b.lshape)
[0/2] (7, 2) (1, 2)
[1/2] (7, 2) (3, 2)
[2/2] (7, 2) (3, 2)
>>> b.balance_()
>>> b
[0/2] tensor([[3., 0.],
[4., 0.],
[5., 0.]])
[1/2] tensor([[6., 0.],
[7., 0.]])
[2/2] tensor([[8., 0.],
[9., 0.]])
>>> print(b.gshape, b.lshape)
[0/2] (7, 2) (3, 2)
[1/2] (7, 2) (2, 2)
[2/2] (7, 2) (2, 2)
"""
if self.is_balanced():
return
# units -> {pr, 1st index, 2nd index}
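# lshape_map[pr, dim] will hold the current local extent of process pr in dimension dim (filled locally, then combined with a non-blocking allreduce)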
lshape_map = factories.zeros((self.comm.size, len(self.gshape)), dtype=int)
lshape_map[self.comm.rank, :] = torch.Tensor(self.lshape)
lshape_map_comm = self.comm.Iallreduce(MPI.IN_PLACE, lshape_map, MPI.SUM)

chunk_map = factories.zeros((self.comm.size, len(self.gshape)), dtype=int)
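# chunk_map[pr, dim] will hold the local extent process pr would have if the data were distributed as evenly as possible (the redistribution target)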
_, _, chk = self.comm.chunk(self.shape, self.split)
for i in range(len(self.gshape)):
chunk_map[self.comm.rank, i] = chk[i].stop - chk[i].start
chunk_map_comm = self.comm.Iallreduce(MPI.IN_PLACE, chunk_map, MPI.SUM)

lshape_map_comm.wait()
chunk_map_comm.wait()

# create list of which processes need to send data to lower ranked nodes
send_list = [True if lshape_map[pr, self.split] != (chunk_map[pr, self.split]) and lshape_map[pr, self.split] != 0 else False
for pr in range(1, self.comm.size)]
send_list.insert(0, True if lshape_map[0, self.split] > (chunk_map[0, self.split]) else False)
first_pr_w_data = send_list.index(True) # first process with *too much* data
last_pr_w_data = next((i for i in reversed(range(len(lshape_map[:, self.split]))) if lshape_map[i, self.split] > chunk_map[i, self.split]))
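# first_pr_w_data and last_pr_w_data bracket the range of processes from which data is pulled forward in the first pass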

# create arbitrary slices for which data to send and which data to keep
send_slice = [slice(None), ] * self.numdims
keep_slice = [slice(None), ] * self.numdims
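# only the slice along the split dimension is ever narrowed below; all other dimensions are always sent/kept in full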

# first send the leading entries of the data to process 0, then the next entries to process 1, ...
# this redistributes the data forward (toward the lower-ranked processes)
if first_pr_w_data != 0:
for spr in range(first_pr_w_data, last_pr_w_data + 1):
if self.comm.rank == spr:
for pr in range(spr):
send_amt = abs((chunk_map[pr, self.split] - lshape_map[pr, self.split]).item())
if send_amt:
send_amt = send_amt if send_amt < self.lshape[self.split] else self.lshape[self.split]
send_slice[self.split] = slice(0, send_amt)
keep_slice[self.split] = slice(send_amt, self.lshape[self.split])
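# the tag below is computed identically on the sending and receiving side, so each Isend is matched by the corresponding blocking Recv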

self.comm.Isend(self.__array[send_slice].clone(), dest=pr, tag=pr + self.comm.size + spr)
self.__array = self.__array[keep_slice].clone()

# every process runs the bookkeeping below so that lshape_map stays consistent on all ranks
for pr in range(spr):
snt = abs((chunk_map[pr, self.split] - lshape_map[pr, self.split]).item())
snt = snt if snt < lshape_map[spr, self.split] else lshape_map[spr, self.split].item()

if self.comm.rank == pr and snt:
shp = list(self.gshape)
shp[self.split] = snt
data = torch.zeros(shp)
self.comm.Recv(data, source=spr, tag=pr + self.comm.size + spr)
self.__array = torch.cat((self.__array, data), dim=self.split)
lshape_map[pr, self.split] += snt
lshape_map[spr, self.split] -= snt

if self.is_balanced():
return

# the DNDarray is now balanced from process 0 up to some process x (data was pulled forward from the higher-ranked processes)
# next, balance the data from process x up to process self.comm.size - 1
send_list = [True if lshape_map[pr, self.split] > (chunk_map[pr, self.split]) else False
for pr in range(self.comm.size)]
first_pr_w_data = send_list.index(True) # first process with *too much* data
last_pr_w_data = next((i for i in reversed(range(len(lshape_map[:, self.split]))) if lshape_map[i, self.split] > chunk_map[i, self.split]))

send_slice = [slice(None), ] * self.numdims
keep_slice = [slice(None), ] * self.numdims
# need to send from the last one with data
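# mirror image of the first pass: each sending process ships its trailing slice along the split dimension to higher-ranked processes that do not yet match their target size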
for spr in range(last_pr_w_data, first_pr_w_data - 1, -1):
if self.comm.rank == spr:
for pr in range(self.comm.size - 1, spr, -1):
send_amt = abs((chunk_map[pr, self.split] - lshape_map[pr, self.split]).item())
if send_amt:
send_amt = send_amt if send_amt < self.lshape[self.split] else self.lshape[self.split]
send_slice[self.split] = slice(self.lshape[self.split] - send_amt, self.lshape[self.split])
keep_slice[self.split] = slice(0, self.lshape[self.split] - send_amt)

self.comm.Isend(self.__array[send_slice].clone(), dest=pr, tag=pr + self.comm.size + spr)
self.__array = self.__array[keep_slice].clone()

for pr in range(self.comm.size - 1, spr, -1):
snt = abs((chunk_map[pr, self.split] - lshape_map[pr, self.split]).item())
snt = snt if snt < lshape_map[spr, self.split] else lshape_map[spr, self.split].item()

if self.comm.rank == pr and snt:
shp = list(self.gshape)
shp[self.split] = snt
data = torch.zeros(shp)
self.comm.Recv(data, source=spr, tag=pr + self.comm.size + spr)
self.__array = torch.cat((data, self.__array), dim=self.split)
lshape_map[pr, self.split] += snt
lshape_map[spr, self.split] -= snt

def __bool__(self):
"""
Boolean scalar casting.
@@ -982,7 +1112,7 @@ def __getitem__(self, key):
arr = self.__array[key]
gout = list(arr.shape)
else:
warnings.warn("This process (rank: {}) is without data after slicing".format(
warnings.warn("This process (rank: {}) is without data after slicing, running the .balance_() function is recommended".format(
self.comm.rank), ResourceWarning)
# arr is empty and gout is zeros

Expand Down Expand Up @@ -1019,7 +1149,7 @@ def __getitem__(self, key):
arr = self.__array[tuple(key)]
gout = list(arr.shape)
else:
warnings.warn("This process (rank: {}) is without data after slicing".format(
warnings.warn("This process (rank: {}) is without data after slicing, running the .balance_() function is recommended".format(
self.comm.rank), ResourceWarning)
# arr is empty
# gout is all 0s and is the proper shape
@@ -1045,7 +1175,7 @@ def __getitem__(self, key):
arr = self.__array[key]
gout = list(arr.shape)
else:
warnings.warn("This process (rank: {}) is without data after slicing".format(
warnings.warn("This process (rank: {}) is without data after slicing, running the .balance_() function is recommended".format(
self.comm.rank), ResourceWarning)
# arr is empty
# gout is all 0s and is the proper shape
@@ -1135,6 +1265,22 @@ def __int__(self):
"""
return self.__cast(int)

def is_balanced(self):
"""
Determine whether the data of the DNDarray is distributed evenly (or as evenly as possible) across all nodes.

Returns
-------
balanced : bool
True if balanced, False if not
"""
_, _, chk = self.comm.chunk(self.shape, self.split)
test_lshape = tuple([x.stop - x.start for x in chk])
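# test_lshape is the local shape this process would have under a balanced (chunked) distribution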
balanced = 1 if test_lshape == self.lshape else 0

out = self.comm.allreduce(balanced, MPI.SUM)
return out == self.comm.size

def is_distributed(self):
"""
Determines whether the data of this tensor is distributed across multiple processes.
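For reviewers who want to see the redistribution logic in isolation: the sketch below (plain Python, no MPI, no heat) mimics how the first pass uses the current local extents (lshape_map) and the target extents from comm.chunk (chunk_map) to decide how much each process sends forward. The helper plan_forward_sends and its flat integer lists are illustrative only and are not part of this PR.

```python
# Illustrative sketch only: the real method works on DNDarrays and MPI buffers;
# here plain integers stand in for the per-process extents along the split axis.

def plan_forward_sends(current, target):
    """Return (src, dst, count) moves that shift surplus data toward
    lower-ranked processes, mimicking the first redistribution pass."""
    current = list(current)
    moves = []
    for dst in range(len(current)):
        src = dst + 1
        # pull data from higher-ranked processes until dst reaches its target size
        while current[dst] < target[dst] and src < len(current):
            count = min(target[dst] - current[dst], current[src])
            if count:
                moves.append((src, dst, count))
                current[src] -= count
                current[dst] += count
            src += 1
    return moves

# the b = a[3:] example from the docstring: local row counts before balancing
print(plan_forward_sends([1, 3, 3], [3, 2, 2]))
# -> [(1, 0, 2), (2, 1, 1)]: rank 1 sends 2 rows to rank 0, rank 2 sends 1 row to rank 1
```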
19 changes: 19 additions & 0 deletions heat/core/tests/test_dndarray.py
@@ -29,6 +29,17 @@ def test_astype(self):
self.assertEqual(as_float64._DNDarray__array.dtype, torch.float64)
self.assertIs(as_float64, data)

def test_balance_(self):
data = ht.zeros((70, 20), split=0)
data = data[:50]
data.balance_()
self.assertTrue(data.is_balanced())

data = ht.zeros((4, 120), split=1)
data = data[:, 40:70]
data.balance_()
self.assertTrue(data.is_balanced())

def test_bool_cast(self):
# simple scalar tensor
a = ht.ones(1)
@@ -149,6 +160,14 @@ def test_int_cast(self):
with self.assertRaises(TypeError):
int(ht.full((ht.MPI_WORLD.size,), 2, split=0))

def test_is_balanced(self):
data = ht.zeros((70, 20), split=0)
if data.comm.size != 1:
data = data[:50]
self.assertFalse(data.is_balanced())
data.balance_()
self.assertTrue(data.is_balanced())

def test_is_distributed(self):
data = ht.zeros((5, 5,))
self.assertFalse(data.is_distributed())
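Outside the unit tests, the new methods can be exercised with a short script along these lines. This is a sketch that assumes heat is installed and the script is launched under MPI, e.g. `mpirun -np 3 python balance_demo.py`; the file name and process count are arbitrary.

```python
import heat as ht

# create a split array and slice it so the local chunks become uneven
a = ht.zeros((10, 2), split=0)
a[:, 0] = ht.arange(10)
b = a[3:]

print(b.comm.rank, b.lshape, b.is_balanced())  # uneven local shapes when run on >1 process

b.balance_()  # in-place redistribution along the split dimension

print(b.comm.rank, b.lshape, b.is_balanced())  # now as even as possible
```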