Features/mean var merge cleanup #445
Merged 14 commits on Jan 16, 2020
Changes from all commits
232 changes: 120 additions & 112 deletions heat/core/statistics.py
@@ -641,6 +641,7 @@ def mean(x, axis=None):
----------
x : ht.DNDarray
Values for which the mean is calculated.
The dtype of x must be a float.
axis : None, int, iterable, defaults to None
Axis along which the mean is taken. The default None calculates the mean of all data items.

@@ -705,19 +706,21 @@ def reduce_means_elementwise(output_shape_i):
if x.lshape[x.split] != 0:
mu = torch.mean(x._DNDarray__array, dim=axis)
else:
mu = factories.zeros(output_shape_i)
mu = factories.zeros(output_shape_i, device=x.device)

mu_shape = list(mu.shape) if list(mu.shape) else [1]

mu_tot = factories.zeros(([x.comm.size] + mu_shape), device=x.device)
n_tot = factories.zeros(x.comm.size, device=x.device)
mu_tot[x.comm.rank, :] = mu
n_tot[x.comm.rank] = float(x.lshape[x.split])
x.comm.Allreduce(MPI.IN_PLACE, mu_tot, MPI.SUM)
mu_tot[x.comm.rank, :] = mu
x.comm.Allreduce(MPI.IN_PLACE, n_tot, MPI.SUM)
x.comm.Allreduce(MPI.IN_PLACE, mu_tot, MPI.SUM)

for i in range(1, x.comm.size):
mu_tot[0, :], n_tot[0] = merge_means(mu_tot[0, :], n_tot[0], mu_tot[i, :], n_tot[i])
mu_tot[0, :], n_tot[0] = __merge_moments(
(mu_tot[0, :], n_tot[0]), (mu_tot[i, :], n_tot[i])
)
return mu_tot[0][0] if mu_tot[0].size == 1 else mu_tot[0]

# ----------------------------------------------------------------------------------------------
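The per-rank means above are combined with the pairwise update from Reference 1, mu = mu1 + n2 * (mu2 - mu1) / (n1 + n2). A minimal standalone sketch (plain PyTorch, hypothetical data) checking that this update reproduces the mean of the concatenated chunks:

```python
import torch

def merge_means(mu1, n1, mu2, n2):
    # pairwise mean update: mu = mu1 + n2 * (mu2 - mu1) / (n1 + n2)
    n = n1 + n2
    return mu1 + n2 * (mu2 - mu1) / n, n

a = torch.randn(5, 3)  # "local" chunk of rank 0 (hypothetical)
b = torch.randn(7, 3)  # "local" chunk of rank 1 (hypothetical)

mu, n = merge_means(a.mean(dim=0), a.shape[0], b.mean(dim=0), b.shape[0])
assert torch.allclose(mu, torch.cat([a, b]).mean(dim=0))
```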
@@ -727,27 +730,32 @@ def reduce_means_elementwise(output_shape_i):
if not x.is_distributed():
# if x is not distributed do a torch.mean on x
ret = torch.mean(x._DNDarray__array.float())
return factories.array(ret, is_split=None)
return factories.array(ret, is_split=None, device=x.device)
else:
# if x is distributed and no axis is given: return mean of the whole set
mu_in = torch.mean(x._DNDarray__array)
if torch.isnan(mu_in):
mu_in = 0.0
n = x.lnumel
mu_tot = factories.zeros((x.comm.size, 2))
mu_proc = factories.zeros((x.comm.size, 2))
mu_proc[x.comm.rank][0] = mu_in
mu_proc[x.comm.rank][1] = float(n)
mu_tot = factories.zeros((x.comm.size, 2), device=x.device)
mu_proc = factories.zeros((x.comm.size, 2), device=x.device)
mu_proc[x.comm.rank] = mu_in, float(n)
x.comm.Allreduce(mu_proc, mu_tot, MPI.SUM)

for i in range(1, x.comm.size):
merged = merge_means(mu_tot[0, 0], mu_tot[0, 1], mu_tot[i, 0], mu_tot[i, 1])
mu_tot[0, 0] = merged[0]
mu_tot[0, 1] = merged[1]
mu_tot[0, 0], mu_tot[0, 1] = __merge_moments(
(mu_tot[0, 0], mu_tot[0, 1]), (mu_tot[i, 0], mu_tot[i, 1])
)
return mu_tot[0][0]

output_shape = list(x.shape)
if isinstance(axis, (list, tuple, dndarray.DNDarray, torch.Tensor)):
if isinstance(axis, (list, tuple)):
if len(set(axis)) != len(axis):
raise ValueError("duplicate value in axis")
if isinstance(axis, (dndarray.DNDarray, torch.Tensor)):
if axis.unique().numel() != axis.numel():
raise ValueError("duplicate value in axis")
if any([not isinstance(j, int) for j in axis]):
raise ValueError(
"items in axis iterable must be integers, axes: {}".format([type(q) for q in axis])
@@ -762,7 +770,9 @@ def reduce_means_elementwise(output_shape_i):
output_shape = [output_shape[it] for it in range(len(output_shape)) if it not in axis]
# multiple dimensions
if x.split is None:
return factories.array(torch.mean(x._DNDarray__array, dim=axis), is_split=x.split)
return factories.array(
torch.mean(x._DNDarray__array, dim=axis), is_split=x.split, device=x.device
)

if x.split in axis:
# merge in the direction of the split
@@ -773,6 +783,7 @@ def reduce_means_elementwise(output_shape_i):
return factories.array(
torch.mean(x._DNDarray__array, dim=axis),
is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1,
device=x.device,
)
elif isinstance(axis, int):
if axis >= len(x.shape):
@@ -784,99 +795,73 @@ def reduce_means_elementwise(output_shape_i):
output_shape = output_shape if output_shape else (1,)

if x.split is None:
return factories.array(torch.mean(x._DNDarray__array, dim=axis), is_split=None)

return factories.array(
torch.mean(x._DNDarray__array, dim=axis), is_split=None, device=x.device
)
elif axis == x.split:
return reduce_means_elementwise(output_shape)
else:
# singular axis given (axis) not equal to split direction (x.split)
return factories.array(
torch.mean(x._DNDarray__array, dim=axis),
is_split=x.split if axis > x.split else x.split - 1,
device=x.device,
)
raise TypeError(
"axis (axis) must be an int or a list, ht.DNDarray, "
"torch.Tensor, or tuple, but was {}".format(type(axis))
)
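The reworked axis handling in mean accepts an int, list, tuple, ht.DNDarray, or torch.Tensor and rejects duplicate axes. A brief usage sketch (hypothetical shapes and axis values; assumes ht.random.randn with a split kwarg):

```python
import heat as ht

x = ht.random.randn(3, 4, 5, split=0)

print(ht.mean(x))               # mean of all elements
print(ht.mean(x, axis=0))       # single axis, equal to the split axis
print(ht.mean(x, axis=[1, 2]))  # several axes that do not touch the split
try:
    ht.mean(x, axis=(1, 1))     # duplicate axes are rejected
except ValueError as e:
    print(e)                    # "duplicate value in axis"
```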


def merge_means(mu1, n1, mu2, n2):
def __merge_moments(m1, m2, bessel=True):
"""
Function to merge two means by pairwise update.
**Note** all tensors/arrays must be either the same size or individual values (can be mixed, i.e. n can be a float)
Merge two statistical moments via pairwise update. If m1 and m2 (which must have equal length) contain
three elements, the second moment (variance) is merged as well. This function can be extended to merge
higher moments following Reference 1.
Note: all tensors/arrays must either be the same size or individual values.

Parameters
----------
mu1 : ht.DNDarray, torch.tensor, float, int
Calculated mean
n1 : ht.DNDarray, torch.tensor, float
number of elements used to calculate mu1
mu2 : ht.DNDarray, torch.tensor, float, int
Calculated mean
n2 : ht.DNDarray, torch.tensor, float
number of elements used to calculate mu2

Returns
-------
combined_set_count : int
Number of elements in the combined set.

References
----------
[1] J. Bennett, R. Grout, P. Pebay, D. Roe, D. Thompson, Numerically stable, single-pass, parallel statistics
algorithms, IEEE International Conference on Cluster Computing and Workshops, 2009, Oct 2009, New Orleans, LA,
USA.
"""
return mu1 + n2 * ((mu2 - mu1) / (n1 + n2)), n1 + n2


def merge_vars(var1, mu1, n1, var2, mu2, n2, bessel=True):
"""
Function to merge two variances by pairwise update.
**Note** this is a parallel of the merge_means function
**Note pt2.** all tensors/arrays must be either the same size or individual values

Parameters
----------
var1 : ht.DNDarray, torch.tensor, float, int
Variance.
mu1 : ht.DNDarray, torch.tensor, float, int
Calculated mean.
n1 : ht.DNDarray, torch.tensor, float, int
Number of elements used to calculate mu1.
var2 : ht.DNDarray, torch.tensor, float, int
Variance.
mu2 : ht.DNDarray, torch.tensor, float, int
Calculated mean.
n2 : ht.DNDarray, torch.tensor, float, int
Number of elements used to calculate mu2.
m1 : tuple
Moments of the first set, ordered from the highest moment down to the element count,
e.g. (var, mu, n) or (mu, n).
m2 : tuple
Moments of the second set, in the same order and of the same length as m1.
bessel : bool
Flag for the use of the bessel correction
Whether to apply Bessel's correction when merging the variance.

Returns
-------
combined_set_count : int
Number of elements in the combined set.
merged_moments : tuple
The merged moments, in the same order as the inputs.

References
----------
[1] J. Bennett, R. Grout, P. Pebay, D. Roe, D. Thompson, Numerically stable, single-pass, parallel statistics
algorithms, IEEE International Conference on Cluster Computing and Workshops, 2009, Oct 2009, New Orleans, LA,
USA.
"""
if len(m1) != len(m2):
raise ValueError(
"m1 and m2 must be same length, currently {} and {}".format(len(m1, len(m2)))
)
n1, n2 = m1[-1], m2[-1]
mu1, mu2 = m1[-2], m2[-2]
n = n1 + n2
delta = mu2 - mu1
mu = mu1 + n2 * (delta / n)
if len(m1) == 2: # merge means
return mu, n

var1, var2 = m1[-3], m2[-3]
if bessel:
return (
(var1 * (n1 - 1) + var2 * (n2 - 1) + (delta ** 2) * n1 * n2 / n) / (n - 1),
mu1 + n2 * (delta / (n1 + n2)),
n,
)
var_m = (var1 * (n1 - 1) + var2 * (n2 - 1) + (delta ** 2) * n1 * n2 / n) / (n - 1)
else:
return (
(var1 * n1 + var2 * n2 + (delta ** 2) * n1 * n2 / n) / n,
mu1 + n2 * (delta / (n1 + n2)),
n,
)
var_m = (var1 * n1 + var2 * n2 + (delta ** 2) * n1 * n2 / n) / n

if len(m1) == 3: # merge vars
return var_m, mu, n
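For the second moment, the Bessel-corrected pairwise update is var = (var1*(n1-1) + var2*(n2-1) + delta**2 * n1*n2/n) / (n-1) with delta = mu2 - mu1 and n = n1 + n2. A quick standalone check of this merge logic (plain PyTorch, hypothetical data; it mirrors __merge_moments rather than importing the module-private helper):

```python
import torch

def merge_var_mean(var1, mu1, n1, var2, mu2, n2, bessel=True):
    # pairwise merge of (variance, mean, count) following Bennett et al. (2009)
    n = n1 + n2
    delta = mu2 - mu1
    mu = mu1 + n2 * delta / n
    if bessel:
        var = (var1 * (n1 - 1) + var2 * (n2 - 1) + delta ** 2 * n1 * n2 / n) / (n - 1)
    else:
        var = (var1 * n1 + var2 * n2 + delta ** 2 * n1 * n2 / n) / n
    return var, mu, n

a, b = torch.randn(6), torch.randn(9)
var, mu, n = merge_var_mean(a.var(unbiased=True), a.mean(), 6, b.var(unbiased=True), b.mean(), 9)
assert torch.allclose(var, torch.cat([a, b]).var(unbiased=True))
assert torch.allclose(mu, torch.cat([a, b]).mean())
```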


def min(x, axis=None, out=None, keepdim=None):
@@ -1122,6 +1107,7 @@ def std(x, axis=None, bessel=True):
----------
x : ht.DNDarray
Values for which the standard deviation is calculated.
The dtype of x must be a float.
axis : None, int, defaults to None
Axis along which the standard deviation is taken. The default None calculates the standard deviation of all data items.
bessel : bool, defaults to True
@@ -1168,6 +1154,7 @@ def var(x, axis=None, bessel=True):
----------
x : ht.DNDarray
Values for which the variance is calculated.
The dtype of x must be a float.
axis : None, int, defaults to None
Axis along which the variance is taken. The default None calculates the variance of all data items.
bessel : bool, defaults to True
@@ -1235,32 +1222,24 @@ def reduce_vars_elementwise(output_shape_i):
mu = torch.mean(x._DNDarray__array, dim=axis)
var = torch.var(x._DNDarray__array, dim=axis, unbiased=bessel)
else:
mu = factories.zeros(output_shape_i)
var = factories.zeros(output_shape_i)

n_for_merge = factories.zeros(x.comm.size)
n2 = factories.zeros(x.comm.size)
n2[x.comm.rank] = x.lshape[x.split]
x.comm.Allreduce(n2, n_for_merge, MPI.SUM)
mu = factories.zeros(output_shape_i, device=x.device)
var = factories.zeros(output_shape_i, device=x.device)

var_shape = list(var.shape) if list(var.shape) else [1]

var_tot = factories.zeros(([x.comm.size, 2] + var_shape))
n_tot = factories.zeros(x.comm.size)
var_tot = factories.zeros(([x.comm.size, 2] + var_shape), device=x.device)
n_tot = factories.zeros(x.comm.size, device=x.device)
var_tot[x.comm.rank, 0, :] = var
var_tot[x.comm.rank, 1, :] = mu
n_tot[x.comm.rank] = float(x.lshape[x.split])
x.comm.Allreduce(MPI.IN_PLACE, var_tot, MPI.SUM)
x.comm.Allreduce(MPI.IN_PLACE, n_tot, MPI.SUM)

for i in range(1, x.comm.size):
var_tot[0, 0, :], var_tot[0, 1, :], n_tot[0] = merge_vars(
var_tot[0, 0, :],
var_tot[0, 1, :],
n_tot[0],
var_tot[i, 0, :],
var_tot[i, 1, :],
n_tot[i],
var_tot[0, 0, :], var_tot[0, 1, :], n_tot[0] = __merge_moments(
(var_tot[0, 0, :], var_tot[0, 1, :], n_tot[0]),
(var_tot[i, 0, :], var_tot[i, 1, :], n_tot[i]),
bessel=bessel,
)
return var_tot[0, 0, :][0] if var_tot[0, 0, :].size == 1 else var_tot[0, 0, :]
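reduce_vars_elementwise gathers the per-rank statistics by writing each rank's (var, mu) into its own slice of a zero buffer and summing with Allreduce, which amounts to an allgather. A minimal sketch of that pattern (mpi4py and NumPy assumed; run with e.g. mpirun -n 4):

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.rank, comm.size

local_stat = np.array([rank * 1.0, rank + 0.5])  # stand-in for (var, mu) of the local chunk
buf = np.zeros((size, 2))
buf[rank] = local_stat                           # only this rank's row is non-zero
comm.Allreduce(MPI.IN_PLACE, buf, op=MPI.SUM)    # summing the zero rows acts as an allgather

# every rank now holds all rows and can merge them pairwise, as in the loop above
assert np.allclose(buf[rank], local_stat)
```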

@@ -1280,32 +1259,60 @@ def reduce_vars_elementwise(output_shape_i):
mu_in = 0.0

n = x.lnumel
var_tot = factories.zeros((x.comm.size, 3))
var_proc = factories.zeros((x.comm.size, 3))
var_proc[x.comm.rank][0] = var_in
var_proc[x.comm.rank][1] = mu_in
var_proc[x.comm.rank][2] = float(n)
var_tot = factories.zeros((x.comm.size, 3), device=x.device)
var_proc = factories.zeros((x.comm.size, 3), device=x.device)
var_proc[x.comm.rank] = var_in, mu_in, float(n)
x.comm.Allreduce(var_proc, var_tot, MPI.SUM)

for i in range(1, x.comm.size):
merged = merge_vars(
var_tot[0, 0],
var_tot[0, 1],
var_tot[0, 2],
var_tot[i, 0],
var_tot[i, 1],
var_tot[i, 2],
var_tot[0, 0], var_tot[0, 1], var_tot[0, 2] = __merge_moments(
(var_tot[0, 0], var_tot[0, 1], var_tot[0, 2]),
(var_tot[i, 0], var_tot[i, 1], var_tot[i, 2]),
bessel=bessel,
)
var_tot[0, 0] = merged[0]
var_tot[0, 1] = merged[1]
var_tot[0, 2] = merged[2]

return var_tot[0][0]

else: # axis is given
Collaborator: Small inconsistency: when axis is set, integer tensors are no longer supported. NumPy supports both ints and floats, while PyTorch generally supports only floats. The same behaviour applies to 'mean'.

Member (author): Because torch does not support ints, we don't either. If it is needed, the easiest way would be to cast the input to a float, but I don't think we need to do that; we don't have to mirror numpy so closely.
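
A minimal sketch of the cast-to-float workaround mentioned here (hypothetical data; assumes the public ht.arange and DNDarray.astype API):

```python
import heat as ht

x = ht.arange(10, split=0)  # integer dtype
x = x.astype(ht.float32)    # cast first, since the torch kernels used here are float-only
print(ht.var(x, axis=0, bessel=True))
```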

# case for var in one dimension
output_shape = list(x.shape)
if isinstance(axis, int):
if isinstance(axis, (list, tuple, dndarray.DNDarray, torch.Tensor)):
if isinstance(axis, (list, tuple)):
if len(set(axis)) != len(axis):
raise ValueError("duplicate value in axis")
if isinstance(axis, (dndarray.DNDarray, torch.Tensor)):
if axis.unique().numel() != axis.numel():
raise ValueError("duplicate value in axis")
if any([not isinstance(j, int) for j in axis]):
raise ValueError(
"items in axis iterable must be integers, axes: {}".format(
[type(q) for q in axis]
)
)
if any(d < 0 for d in axis):
axis = [stride_tricks.sanitize_axis(x.shape, j) for j in axis]
if any(d > len(x.shape) for d in axis):
raise ValueError(
"axes (axis) must be < {}, currently are {}".format(len(x.shape), axis)
)

output_shape = [output_shape[it] for it in range(len(output_shape)) if it not in axis]
# multiple dimensions
if x.split is None:
return factories.array(
torch.var(x._DNDarray__array, dim=axis), is_split=x.split, device=x.device
)
if x.split in axis:
# merge in the direction of the split
return reduce_vars_elementwise(output_shape)
else:
# multiple dimensions which does *not* include the split axis
# the result stays distributed along the (shifted) split axis
return factories.array(
torch.var(x._DNDarray__array, dim=axis),
is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1,
device=x.device,
)
elif isinstance(axis, int):
if axis >= len(x.shape):
raise ValueError("axis must be < {}, currently is {}".format(len(x.shape), axis))
axis = stride_tricks.sanitize_axis(x.shape, axis)
Expand All @@ -1314,15 +1321,16 @@ def reduce_vars_elementwise(output_shape_i):
output_shape = output_shape if output_shape else (1,)

if x.split is None: # x is *not* distributed -> no need to distribute
return factories.array(torch.var(x._DNDarray__array, dim=axis, unbiased=bessel))
return factories.array(
torch.var(x._DNDarray__array, dim=axis, unbiased=bessel), device=x.device
)
elif axis == x.split: # x is distributed and axis chosen is == to split
return reduce_vars_elementwise(output_shape)
else:
# singular axis given (axis) not equal to split direction (x.split)
lcl = torch.var(x._DNDarray__array, dim=axis, keepdim=False)
return factories.array(lcl, is_split=x.split if axis > x.split else x.split - 1)
return factories.array(
lcl, is_split=x.split if axis > x.split else x.split - 1, device=x.device
)
else:
raise TypeError(
"axis (axis) must be an int, currently is {}. "
"Check if multidim var is available in PyTorch".format(type(axis))
)
raise TypeError("axis (axis) must be an int, tuple, list, etc.; currently it is {}. ")