From b8d86b6b1caa5f834d9bbf093038bbc0bae40b15 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Fri, 26 Jun 2020 16:50:20 +0200 Subject: [PATCH 01/14] abstracted moments with axis to own function. added skew, no tests done yet --- heat/core/statistics.py | 333 ++++++++++++++++++++++++---------------- 1 file changed, 201 insertions(+), 132 deletions(-) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 53f40b4dae..c4054b4559 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -1,3 +1,4 @@ +import math import numpy as np import torch @@ -729,7 +730,6 @@ def reduce_means_elementwise(output_shape_i): return mu_tot[0][0] if mu_tot[0].size == 1 else mu_tot[0] # ---------------------------------------------------------------------------------------------- - if axis is None: # full matrix calculation if not x.is_distributed(): @@ -752,71 +752,7 @@ def reduce_means_elementwise(output_shape_i): (mu_tot[0, 0], mu_tot[0, 1]), (mu_tot[i, 0], mu_tot[i, 1]) ) return mu_tot[0][0] - - output_shape = list(x.shape) - if isinstance(axis, (list, tuple, dndarray.DNDarray, torch.Tensor)): - if isinstance(axis, (list, tuple)): - if len(set(axis)) != len(axis): - raise ValueError("duplicate value in axis") - if isinstance(axis, (dndarray.DNDarray, torch.Tensor)): - if axis.unique().numel() != axis.numel(): - raise ValueError("duplicate value in axis") - if any([not isinstance(j, int) for j in axis]): - raise ValueError( - "items in axis iterable must be integers, axes: {}".format([type(q) for q in axis]) - ) - if any(d < 0 for d in axis): - axis = [stride_tricks.sanitize_axis(x.shape, j) for j in axis] - if any(d > len(x.shape) for d in axis): - raise ValueError( - "axes (axis) must be < {}, currently are {}".format(len(x.shape), axis) - ) - - output_shape = [output_shape[it] for it in range(len(output_shape)) if it not in axis] - # multiple dimensions - if x.split is None: - return factories.array( - torch.mean(x._DNDarray__array, dim=axis), is_split=x.split, device=x.device - ) - - if x.split in axis: - # merge in the direction of the split - return reduce_means_elementwise(output_shape) - else: - # multiple dimensions which does *not* include the split axis - # combine along the split axis - return factories.array( - torch.mean(x._DNDarray__array, dim=axis), - is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1, - device=x.device, - ) - elif isinstance(axis, int): - if axis >= len(x.shape): - raise ValueError("axis (axis) must be < {}, currently is {}".format(len(x.shape), axis)) - axis = stride_tricks.sanitize_axis(x.shape, axis=axis) - - # only one axis given - output_shape = [output_shape[it] for it in range(len(output_shape)) if it != axis] - output_shape = output_shape if output_shape else (1,) - - if x.split is None: - - return factories.array( - torch.mean(x._DNDarray__array, dim=axis), is_split=None, device=x.device - ) - elif axis == x.split: - return reduce_means_elementwise(output_shape) - else: - # singular axis given (axis) not equal to split direction (x.split) - return factories.array( - torch.mean(x._DNDarray__array, dim=axis), - is_split=x.split if axis > x.split else x.split - 1, - device=x.device, - ) - raise TypeError( - "axis (axis) must be an int or a list, ht.DNDarray, " - "torch.Tensor, or tuple, but was {}".format(type(axis)) - ) + return __moment_w_axis(torch.mean, x, axis, reduce_means_elementwise) def __merge_moments(m1, m2, bessel=True): @@ -868,6 +804,29 @@ def __merge_moments(m1, m2, bessel=True): if len(m1) == 3: # merge vars return var_m, mu, n + sk1, sk2 = m1[-4], m2[-4] + dn = delta / n + if var_m != 0: # Skewness does nto exist of var is 0 + skew_m = ( + sk1 + + sk2 + + 3 * ((n1 * var2 - n2 * var1) + (dn ** 2) * (n1 * n2) * ((n1 ** 2) - (n2 ** 2))) * dn + ) + else: + skew_m = None + if len(m1) == 4: + return skew_m, var_m, mu, n + + k1, k2 = m1[-5], m2[-5] + if skew_m is None: + return None, skew_m, var_m, mu, n + s1 = 4 * dn * (n1 * sk2 - n2 * sk1) + s2 = 6 * (dn ** 2) * ((n1 ** 2) * var2 + (n2 ** 2) * var1) + s3 = (dn ** 4) * n1 * n2 * ((n1 ** 3) + (n2 ** 3)) + k = k1 + k2 + s1 + s2 + s3 + if len(m1) == 5: + return k, skew_m, var_m, mu, n + def min(x, axis=None, out=None, keepdim=None): # TODO: initial : scalar, optional Issue #101 @@ -1151,6 +1110,181 @@ def std(x, axis=None, ddof=0, **kwargs): return exponential.sqrt(var(x, axis, ddof, **kwargs), out=None) +def skew(x, axis, unbiased=True): + bessel = unbiased + + def reduce_skews_elementwise(output_shape_i): + """ + Function to combine the calculated vars together. This does an element-wise update of the + calculated vars to merge them together using the merge_vars function. This function operates + using x from the var function parameters. + + Parameters + ---------- + output_shape_i : iterable + Iterable with the dimensions of the output of the var function. + + Returns + ------- + variances : ht.DNDarray + The calculated variances. + """ + + if x.lshape[x.split] != 0: + mu = torch.mean(x._DNDarray__array, dim=axis) + var = torch.var(x._DNDarray__array, dim=axis, unbiased=bessel) + sk = __torch_skew(x._DNDarray__array, dim=axis, unbiased=unbiased) + else: + mu = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + var = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + sk = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + + sk_shape = list(sk.shape) if list(sk.shape) else [1] + + tot = factories.zeros(([x.comm.size, 4] + sk_shape), dtype=x.dtype, device=x.device) + tot[x.comm.rank, 0, :] = sk + tot[x.comm.rank, 1, :] = var + tot[x.comm.rank, 2, :] = mu + tot[x.comm.rank, 3, :] = float(x.lshape[x.split]) + x.comm.Allreduce(MPI.IN_PLACE, tot, MPI.SUM) + + for i in range(1, x.comm.size): + tot[0, 0, :], tot[0, 1, :], tot[0, 2, :], tot[0, 3, :] = __merge_moments( + (tot[0, 0, :], tot[0, 1, :], tot[0, 2, :], tot[0, 3, :]), + (tot[i, 0, :], tot[i, 1, :], tot[0, 2, :], tot[0, 3, :]), + bessel=bessel, + ) + return tot[0, 0, :][0] if tot[0, 0, :].size == 1 else tot[0, 0, :] + + # ---------------------------------------------------------------------------------------------- + if axis is None: # no axis given + if not x.is_distributed(): # not distributed (full tensor on one node) + ret = __torch_skew(x._DNDarray__array.float(), unbiased=unbiased) + return factories.array(ret) + + else: # case for full matrix calculation (axis is None) + mu_in = torch.mean(x._DNDarray__array) + var_in = torch.var(x._DNDarray__array, unbiased=unbiased) + skew_in = __torch_skew(x._DNDarray__array.float(), biased=unbiased) + # Nan is returned when local tensor is empty + if torch.isnan(skew_in): + skew_in = 0.0 + if torch.isnan(var_in): + var_in = 0.0 + if torch.isnan(mu_in): + mu_in = 0.0 + + n = x.lnumel + tot = factories.zeros((x.comm.size, 4), dtype=x.dtype, device=x.device) + skew_proc = factories.zeros((x.comm.size, 4), dtype=x.dtype, device=x.device) + skew_proc[x.comm.rank] = skew_in, var_in, mu_in, float(n) + x.comm.Allreduce(skew_proc, tot, MPI.SUM) + + for i in range(1, x.comm.size): + tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3] = __merge_moments( + (tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3]), + (tot[i, 0], tot[i, 1], tot[i, 2], tot[0, 3]), + bessel=bessel, + ) + return tot[0][0] + + else: # axis is given + return __moment_w_axis(__torch_skew, x, axis, reduce_skews_elementwise, unbiased) + + +def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None): + # helper for calculating statistical moment with a given axis + # case for var in one dimension + kwargs = {"dim": axis} + if unbiased: + kwargs["unbiased"] = unbiased + + output_shape = list(x.shape) + if isinstance(axis, (list, tuple, dndarray.DNDarray, torch.Tensor)): + if isinstance(axis, (list, tuple)) and len(set(axis)) != len(axis): + raise ValueError("duplicate value in axis") + if ( + isinstance(axis, (dndarray.DNDarray, torch.Tensor)) + and axis.unique().numel() != axis.numel() + ): + raise ValueError("duplicate value in axis") + if any([not isinstance(j, int) for j in axis]): + raise ValueError( + "items in axis iterable must be integers, axes: {}".format([type(q) for q in axis]) + ) + if any(d < 0 for d in axis): + axis = [stride_tricks.sanitize_axis(x.shape, j) for j in axis] + if any(d > len(x.shape) for d in axis): + raise ValueError( + "axes (axis) must be < {}, currently are {}".format(len(x.shape), axis) + ) + + output_shape = [output_shape[it] for it in range(len(output_shape)) if it not in axis] + # multiple dimensions + if x.split is None: + return factories.array( + function(x._DNDarray__array, **kwargs), is_split=x.split, device=x.device + ) + if x.split in axis: + # merge in the direction of the split + return elementwise_function(output_shape) + # multiple dimensions which does *not* include the split axis + # combine along the split axis + return factories.array( + function(x._DNDarray__array, **kwargs), + is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1, + device=x.device, + ) + elif isinstance(axis, int): + if axis >= len(x.shape): + raise ValueError("axis must be < {}, currently is {}".format(len(x.shape), axis)) + axis = stride_tricks.sanitize_axis(x.shape, axis) + # only one axis given + output_shape = [output_shape[it] for it in range(len(output_shape)) if it != axis] + output_shape = output_shape if output_shape else (1,) + + if x.split is None: # x is *not* distributed -> no need to distributed + return factories.array( + function(x._DNDarray__array, **kwargs), dtype=x.dtype, device=x.device + ) + elif axis == x.split: # x is distributed and axis chosen is == to split + return elementwise_function(output_shape) + # singular axis given (axis) not equal to split direction (x.split) + lcl = function(x._DNDarray__array, **kwargs) + return factories.array( + lcl, is_split=x.split if axis > x.split else x.split - 1, dtype=x.dtype, device=x.device + ) + else: + raise TypeError("axis (axis) must be an int, tuple, list, etc.; currently it is {}. ") + + +def __torch_skew(torch_tensor, dim=None, unbiased=False): + # calculate the sample skewness of a torch tensor + # return the bias corrected Fischer-Pearson standardized moment coefficient by default + n = torch_tensor.numel() + diff = torch_tensor - torch.mean(torch_tensor, dim) + m3 = torch.true_divide(torch.pow(diff, 3), n) + m2 = torch.true_divide(torch.pow(diff, 2), n) + coeff = torch.sqrt(n * (n - 1)) / (n - 2) + if not unbiased: + return torch.true_divide(m3, torch.pow(m2, 1.5)) + return coeff * torch.true_divide(m3, torch.pow(m2, 1.5)) + + +def __torch_kurtosis(torch_tensor, dim=None, excess=True): + # calculate the sample kurtosis of a torch tensor, Pearson's definition + # returns the excess Kurtosis if excess is True + # there is not unbiased estimator for Kurtosis + n = torch_tensor.numel() + diff = torch_tensor - torch.mean(torch_tensor, dim) + m4 = torch.true_divide(torch.pow(diff, 4), n) + m2 = torch.true_divide(torch.pow(diff, 2), n) + k = torch.true_divide(m4, torch.pow(m2, 2)) + if excess: + k -= 3.0 + return k + + def var(x, axis=None, ddof=0, **kwargs): """ Calculates and returns the variance of a tensor. If an axis is given, the variance will be @@ -1299,69 +1433,4 @@ def reduce_vars_elementwise(output_shape_i): return var_tot[0][0] else: # axis is given - # case for var in one dimension - output_shape = list(x.shape) - if isinstance(axis, (list, tuple, dndarray.DNDarray, torch.Tensor)): - if isinstance(axis, (list, tuple)): - if len(set(axis)) != len(axis): - raise ValueError("duplicate value in axis") - if isinstance(axis, (dndarray.DNDarray, torch.Tensor)): - if axis.unique().numel() != axis.numel(): - raise ValueError("duplicate value in axis") - if any([not isinstance(j, int) for j in axis]): - raise ValueError( - "items in axis iterable must be integers, axes: {}".format( - [type(q) for q in axis] - ) - ) - if any(d < 0 for d in axis): - axis = [stride_tricks.sanitize_axis(x.shape, j) for j in axis] - if any(d > len(x.shape) for d in axis): - raise ValueError( - "axes (axis) must be < {}, currently are {}".format(len(x.shape), axis) - ) - - output_shape = [output_shape[it] for it in range(len(output_shape)) if it not in axis] - # multiple dimensions - if x.split is None: - return factories.array( - torch.var(x._DNDarray__array, dim=axis), is_split=x.split, device=x.device - ) - if x.split in axis: - # merge in the direction of the split - return reduce_vars_elementwise(output_shape) - else: - # multiple dimensions which does *not* include the split axis - # combine along the split axis - return factories.array( - torch.var(x._DNDarray__array, dim=axis), - is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1, - device=x.device, - ) - elif isinstance(axis, int): - if axis >= len(x.shape): - raise ValueError("axis must be < {}, currently is {}".format(len(x.shape), axis)) - axis = stride_tricks.sanitize_axis(x.shape, axis) - # only one axis given - output_shape = [output_shape[it] for it in range(len(output_shape)) if it != axis] - output_shape = output_shape if output_shape else (1,) - - if x.split is None: # x is *not* distributed -> no need to distributed - return factories.array( - torch.var(x._DNDarray__array, dim=axis, unbiased=bessel), - dtype=x.dtype, - device=x.device, - ) - elif axis == x.split: # x is distributed and axis chosen is == to split - return reduce_vars_elementwise(output_shape) - else: - # singular axis given (axis) not equal to split direction (x.split) - lcl = torch.var(x._DNDarray__array, dim=axis, keepdim=False) - return factories.array( - lcl, - is_split=x.split if axis > x.split else x.split - 1, - dtype=x.dtype, - device=x.device, - ) - else: - raise TypeError("axis (axis) must be an int, tuple, list, etc.; currently it is {}. ") + return __moment_w_axis(torch.var, x, axis, reduce_vars_elementwise, bessel) From 861b6fbac14beb4b17af9cd2b93d4c1891dae104 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Fri, 26 Jun 2020 17:34:19 +0200 Subject: [PATCH 02/14] minor correction so skew --- heat/core/statistics.py | 114 ++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 56 deletions(-) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index c4054b4559..6c66a02e9d 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -25,6 +25,7 @@ "mean", "min", "minimum", + "skew", "std", "var", ] @@ -1062,55 +1063,7 @@ def mpi_argmin(a, b, _): MPI_ARGMIN = MPI.Op.Create(mpi_argmin, commute=True) -def std(x, axis=None, ddof=0, **kwargs): - """ - Calculates and returns the standard deviation of a tensor with the bessel correction. - If a axis is given, the variance will be taken in that direction. - - Parameters - ---------- - x : ht.DNDarray - Values for which the std is calculated for. - The dtype of x must be a float - axis : None, Int, iterable, defaults to None - Axis which the std is taken in. Default None calculates std of all data items. - ddof : int, optional - Delta Degrees of Freedom: the denominator implicitely used in the calculation is N - ddof, where N - represents the number of elements. Default: ddof=0. If ddof=1, the Bessel correction will be applied. - Setting ddof > 1 raises a NotImplementedError. - - Returns - ------- - stds : ht.DNDarray - The std/s, if split, then split in the same direction as x. - - Examples - -------- - >>> a = ht.random.randn(1,3) - >>> a - tensor([[ 0.3421, 0.5736, -2.2377]]) - >>> ht.std(a) - tensor(1.2742) - >>> a = ht.random.randn(4,4) - >>> a - tensor([[-1.0206, 0.3229, 1.1800, 1.5471], - [ 0.2732, -0.0965, -0.1087, -1.3805], - [ 0.2647, 0.5998, -0.1635, -0.0848], - [ 0.0343, 0.1618, -0.8064, -0.1031]]) - >>> ht.std(a, 0, ddof=1) - tensor([0.6157, 0.2918, 0.8324, 1.1996]) - >>> ht.std(a, 1, ddof=1) - tensor([1.1405, 0.7236, 0.3506, 0.4324]) - >>> ht.std(a, 1) - tensor([0.9877, 0.6267, 0.3037, 0.3745]) - """ - if not axis: - return np.sqrt(var(x, axis, ddof, **kwargs)) - else: - return exponential.sqrt(var(x, axis, ddof, **kwargs), out=None) - - -def skew(x, axis, unbiased=True): +def skew(x, axis=None, unbiased=True): bessel = unbiased def reduce_skews_elementwise(output_shape_i): @@ -1163,9 +1116,9 @@ def reduce_skews_elementwise(output_shape_i): return factories.array(ret) else: # case for full matrix calculation (axis is None) - mu_in = torch.mean(x._DNDarray__array) - var_in = torch.var(x._DNDarray__array, unbiased=unbiased) - skew_in = __torch_skew(x._DNDarray__array.float(), biased=unbiased) + mu_in = torch.mean(x._DNDarray__array.float()) + var_in = torch.var(x._DNDarray__array.float(), unbiased=unbiased) + skew_in = __torch_skew(x._DNDarray__array.float(), unbiased=unbiased) # Nan is returned when local tensor is empty if torch.isnan(skew_in): skew_in = 0.0 @@ -1192,6 +1145,54 @@ def reduce_skews_elementwise(output_shape_i): return __moment_w_axis(__torch_skew, x, axis, reduce_skews_elementwise, unbiased) +def std(x, axis=None, ddof=0, **kwargs): + """ + Calculates and returns the standard deviation of a tensor with the bessel correction. + If a axis is given, the variance will be taken in that direction. + + Parameters + ---------- + x : ht.DNDarray + Values for which the std is calculated for. + The dtype of x must be a float + axis : None, Int, iterable, defaults to None + Axis which the std is taken in. Default None calculates std of all data items. + ddof : int, optional + Delta Degrees of Freedom: the denominator implicitely used in the calculation is N - ddof, where N + represents the number of elements. Default: ddof=0. If ddof=1, the Bessel correction will be applied. + Setting ddof > 1 raises a NotImplementedError. + + Returns + ------- + stds : ht.DNDarray + The std/s, if split, then split in the same direction as x. + + Examples + -------- + >>> a = ht.random.randn(1,3) + >>> a + tensor([[ 0.3421, 0.5736, -2.2377]]) + >>> ht.std(a) + tensor(1.2742) + >>> a = ht.random.randn(4,4) + >>> a + tensor([[-1.0206, 0.3229, 1.1800, 1.5471], + [ 0.2732, -0.0965, -0.1087, -1.3805], + [ 0.2647, 0.5998, -0.1635, -0.0848], + [ 0.0343, 0.1618, -0.8064, -0.1031]]) + >>> ht.std(a, 0, ddof=1) + tensor([0.6157, 0.2918, 0.8324, 1.1996]) + >>> ht.std(a, 1, ddof=1) + tensor([1.1405, 0.7236, 0.3506, 0.4324]) + >>> ht.std(a, 1) + tensor([0.9877, 0.6267, 0.3037, 0.3745]) + """ + if not axis: + return np.sqrt(var(x, axis, ddof, **kwargs)) + else: + return exponential.sqrt(var(x, axis, ddof, **kwargs), out=None) + + def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None): # helper for calculating statistical moment with a given axis # case for var in one dimension @@ -1262,10 +1263,11 @@ def __torch_skew(torch_tensor, dim=None, unbiased=False): # calculate the sample skewness of a torch tensor # return the bias corrected Fischer-Pearson standardized moment coefficient by default n = torch_tensor.numel() - diff = torch_tensor - torch.mean(torch_tensor, dim) - m3 = torch.true_divide(torch.pow(diff, 3), n) - m2 = torch.true_divide(torch.pow(diff, 2), n) - coeff = torch.sqrt(n * (n - 1)) / (n - 2) + diff = torch_tensor - torch.mean(torch_tensor) + m3 = torch.true_divide(torch.sum(torch.pow(diff, 3)), n) + m2 = torch.true_divide(torch.sum(torch.pow(diff, 2)), n) + print(n) + coeff = (n * (n - 1) / float((n - 2))) ** 0.5 if not unbiased: return torch.true_divide(m3, torch.pow(m2, 1.5)) return coeff * torch.true_divide(m3, torch.pow(m2, 1.5)) From 0d8ca88c48cf0a54325e5505a71be9236df940cb Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Mon, 29 Jun 2020 17:01:56 +0200 Subject: [PATCH 03/14] added skew function and tests, renamed gnumel to numel --- heat/core/dndarray.py | 7 +- heat/core/statistics.py | 254 ++++++++++++++++++----------- heat/core/tests/test_statistics.py | 100 +++++++++++- 3 files changed, 260 insertions(+), 101 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 648539c691..faf29f80f1 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -147,12 +147,12 @@ def size(self): number of total elements of the tensor """ try: - return np.prod(self.__gshape) + return torch.prod(torch.tensor(self.gshape, device=self.device.torch_device)).item() except TypeError: return 1 @property - def gnumel(self): + def numel(self): """ Returns @@ -3225,6 +3225,9 @@ def sinh(self, out=None): """ return trigonometrics.sinh(self, out) + def skew(self, axis=None, unbiased=True): + return statistics.skew(self, axis, unbiased) + def sqrt(self, out=None): """ Return the non-negative square-root of the tensor element-wise. diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 6c66a02e9d..3ac223134b 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -3,6 +3,7 @@ import torch from .communication import MPI +from . import arithmetics from . import exponential from . import factories from . import linalg @@ -756,7 +757,7 @@ def reduce_means_elementwise(output_shape_i): return __moment_w_axis(torch.mean, x, axis, reduce_means_elementwise) -def __merge_moments(m1, m2, bessel=True): +def __merge_moments(m1, m2, unbiased=True): """ Merge two statistical moments. If the length of m1/m2 (must be equal) is == 3 then the second moment (variance) is merged. This function can be expanded to merge other moments according to Reference 1 as well. @@ -770,8 +771,8 @@ def __merge_moments(m1, m2, bessel=True): m2 : tuple Tuple of the moments to merge together, the 0th element is the moment to be merged. The tuple must be sorted in descending order of moments - bessel : bool - Flag for the use of the bessel correction for the calculation of the variance + unbiased : bool + Flag for the use of unbiased estimators (when available) Returns ------- @@ -797,7 +798,7 @@ def __merge_moments(m1, m2, bessel=True): return mu, n var1, var2 = m1[-3], m2[-3] - if bessel: + if unbiased: var_m = (var1 * (n1 - 1) + var2 * (n2 - 1) + (delta ** 2) * n1 * n2 / n) / (n - 1) else: var_m = (var1 * n1 + var2 * n2 + (delta ** 2) * n1 * n2 / n) / n @@ -806,13 +807,17 @@ def __merge_moments(m1, m2, bessel=True): return var_m, mu, n sk1, sk2 = m1[-4], m2[-4] + # print(sk1) + # print(sk2) dn = delta / n - if var_m != 0: # Skewness does nto exist of var is 0 - skew_m = ( - sk1 - + sk2 - + 3 * ((n1 * var2 - n2 * var1) + (dn ** 2) * (n1 * n2) * ((n1 ** 2) - (n2 ** 2))) * dn - ) + # todo: fix this condition + if all(var_m != 0): # Skewness does not exist if var is 0 + s1 = sk1 + sk2 + s2 = dn * (n1 * var2 - n2 * var1) / 6.0 + # print('n', dn / 6) + s3 = (dn ** 3) * n1 * n2 * (n1 ** 2 - n2 ** 2) + # print(s1, s2) + skew_m = s1 + s2 + s3 else: skew_m = None if len(m1) == 4: @@ -821,8 +826,8 @@ def __merge_moments(m1, m2, bessel=True): k1, k2 = m1[-5], m2[-5] if skew_m is None: return None, skew_m, var_m, mu, n - s1 = 4 * dn * (n1 * sk2 - n2 * sk1) - s2 = 6 * (dn ** 2) * ((n1 ** 2) * var2 + (n2 ** 2) * var1) + s1 = (1 / 24.0) * dn * (n1 * sk2 - n2 * sk1) + s2 = (1 / 12.0) * (dn ** 2) * ((n1 ** 2) * var2 + (n2 ** 2) * var1) s3 = (dn ** 4) * n1 * n2 * ((n1 ** 3) + (n2 ** 3)) k = k1 + k2 + s1 + s2 + s3 if len(m1) == 5: @@ -1064,90 +1069,135 @@ def mpi_argmin(a, b, _): def skew(x, axis=None, unbiased=True): - bessel = unbiased + """ + Compute the sample skewness of a data set. + + Parameters + ---------- + x : ht.DNDarray + Input array + axis : NoneType or Int or iterable + Axis along which skewness is calculated, Default is to compute over the whole array `x` + unbiased : Bool + if True (default) the calculations are corrected for bias - def reduce_skews_elementwise(output_shape_i): + """ + + def __reduce_skews_elementwise(output_shape_i): """ - Function to combine the calculated vars together. This does an element-wise update of the + Function to combine the calculated skews together. This does an element-wise update of the calculated vars to merge them together using the merge_vars function. This function operates - using x from the var function parameters. + using x from the var function parameters. Parameters ---------- output_shape_i : iterable Iterable with the dimensions of the output of the var function. - - Returns - ------- - variances : ht.DNDarray - The calculated variances. """ if x.lshape[x.split] != 0: mu = torch.mean(x._DNDarray__array, dim=axis) - var = torch.var(x._DNDarray__array, dim=axis, unbiased=bessel) - sk = __torch_skew(x._DNDarray__array, dim=axis, unbiased=unbiased) + var = torch.var(x._DNDarray__array, dim=axis, unbiased=unbiased) + skl = __torch_skew(x._DNDarray__array, dim=axis, unbiased=unbiased) else: mu = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) var = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - sk = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + skl = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - sk_shape = list(sk.shape) if list(sk.shape) else [1] + sk_shape = list(skl.shape) if list(skl.shape) else [1] - tot = factories.zeros(([x.comm.size, 4] + sk_shape), dtype=x.dtype, device=x.device) - tot[x.comm.rank, 0, :] = sk - tot[x.comm.rank, 1, :] = var - tot[x.comm.rank, 2, :] = mu - tot[x.comm.rank, 3, :] = float(x.lshape[x.split]) - x.comm.Allreduce(MPI.IN_PLACE, tot, MPI.SUM) + rtot = factories.zeros(([x.comm.size, 4] + sk_shape), dtype=x.dtype, device=x.device) + rtot[x.comm.rank, 0, :] = skl + rtot[x.comm.rank, 1, :] = var + rtot[x.comm.rank, 2, :] = mu + rtot[x.comm.rank, 3, :] = float(x.lshape[x.split]) + x.comm.Allreduce(MPI.IN_PLACE, rtot, MPI.SUM) + # print(rtot[x.comm.rank, 0, :]) for i in range(1, x.comm.size): - tot[0, 0, :], tot[0, 1, :], tot[0, 2, :], tot[0, 3, :] = __merge_moments( - (tot[0, 0, :], tot[0, 1, :], tot[0, 2, :], tot[0, 3, :]), - (tot[i, 0, :], tot[i, 1, :], tot[0, 2, :], tot[0, 3, :]), - bessel=bessel, + rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :] = __merge_moments( + (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :]), + (rtot[i, 0, :], rtot[i, 1, :], rtot[i, 2, :], rtot[i, 3, :]), + unbiased=unbiased, ) - return tot[0, 0, :][0] if tot[0, 0, :].size == 1 else tot[0, 0, :] + return rtot[0, 0, :][0] if rtot[0, 0, :].size == 1 else rtot[0, 0, :] # ---------------------------------------------------------------------------------------------- if axis is None: # no axis given - if not x.is_distributed(): # not distributed (full tensor on one node) - ret = __torch_skew(x._DNDarray__array.float(), unbiased=unbiased) - return factories.array(ret) - - else: # case for full matrix calculation (axis is None) - mu_in = torch.mean(x._DNDarray__array.float()) - var_in = torch.var(x._DNDarray__array.float(), unbiased=unbiased) - skew_in = __torch_skew(x._DNDarray__array.float(), unbiased=unbiased) - # Nan is returned when local tensor is empty - if torch.isnan(skew_in): - skew_in = 0.0 - if torch.isnan(var_in): - var_in = 0.0 - if torch.isnan(mu_in): - mu_in = 0.0 - - n = x.lnumel - tot = factories.zeros((x.comm.size, 4), dtype=x.dtype, device=x.device) - skew_proc = factories.zeros((x.comm.size, 4), dtype=x.dtype, device=x.device) - skew_proc[x.comm.rank] = skew_in, var_in, mu_in, float(n) - x.comm.Allreduce(skew_proc, tot, MPI.SUM) - - for i in range(1, x.comm.size): - tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3] = __merge_moments( - (tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3]), - (tot[i, 0], tot[i, 1], tot[i, 2], tot[0, 3]), - bessel=bessel, - ) - return tot[0][0] - - else: # axis is given - return __moment_w_axis(__torch_skew, x, axis, reduce_skews_elementwise, unbiased) + # todo: determine if this is a valid (and fast implementation) + mu = mean(x) + # diff = x + # diff._DNDarray__array -= float(mu.item()) + diff = x - mu + n = x.numel + + m3 = arithmetics.sum(arithmetics.pow(diff, 3)) / n + m2 = arithmetics.sum(arithmetics.pow(diff, 2)) / n + res = m3 / arithmetics.pow(m2, 1.5) + if unbiased: + res *= ((n * (n - 1)) ** 0.5) / (n - 2.0) + return res.item() + # if not x.is_distributed(): # not distributed (full tensor on one node) + # print('here') + # ret = __torch_skew(x._DNDarray__array.float(), unbiased=unbiased) + # return factories.array(ret) + # + # else: # case for full matrix calculation (axis is None) + # # todo: time this to see if its faster to do this than merging + # + # m2 = torch.true_divide(torch.sum(torch.pow(diff, 2)), n) + # work_arr = x._DNDarray__array.float() + # mu_in = torch.mean(work_arr) + # # work_arr -= mu_in + # var_in = torch.var(work_arr, unbiased=unbiased) + # skew_in = __torch_skew(work_arr, unbiased=unbiased) + # # Nan is returned when local tensor is empty + # if torch.isnan(skew_in): + # skew_in = 0.0 + # if torch.isnan(var_in): + # var_in = 0.0 + # if torch.isnan(mu_in): + # mu_in = 0.0 + # + # n = x.lnumel + # tot = factories.zeros( + # (x.comm.size, 4), dtype=types.promote_types(x.dtype, types.float), device=x.device + # ) + # skew_proc = factories.zeros( + # (x.comm.size, 4), dtype=types.promote_types(x.dtype, types.float), device=x.device + # ) + # skew_proc[x.comm.rank] = skew_in, var_in, mu_in, n + # x.comm.Allreduce(skew_proc, tot, MPI.SUM) + # print(skew_proc[x.comm.rank]) + # + # for i in range(1, x.comm.size): + # tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3] = __merge_moments( + # (tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3]), + # (tot[i, 0], tot[i, 1], tot[i, 2], tot[0, 3]), + # unbiased=unbiased, + # ) + # return tot[0][0] + + elif isinstance(axis, int) and x.split == axis: # axis is given + if axis > 0: + diff = x - mean(x, axis=axis).expand_dims(axis) + else: + diff = x - mean(x, axis=axis) + n = float(x.shape[axis]) + + m3 = arithmetics.sum(arithmetics.pow(diff, 3.0), axis) / n + m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n + res = m3 / arithmetics.pow(m2, 1.5) + if unbiased: + res *= ((n * (n - 1.0)) ** 0.5) / (n - 2.0) + return res + else: + return __moment_w_axis(__torch_skew, x, axis, __reduce_skews_elementwise, unbiased) def std(x, axis=None, ddof=0, **kwargs): """ - Calculates and returns the standard deviation of a tensor with the bessel correction. + Calculates and returns the standard deviation of a tensor. The default estimator is biased. If a axis is given, the variance will be taken in that direction. Parameters @@ -1194,8 +1244,7 @@ def std(x, axis=None, ddof=0, **kwargs): def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None): - # helper for calculating statistical moment with a given axis - # case for var in one dimension + # helper for calculating a statistical moment with a given axis kwargs = {"dim": axis} if unbiased: kwargs["unbiased"] = unbiased @@ -1262,14 +1311,19 @@ def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None): def __torch_skew(torch_tensor, dim=None, unbiased=False): # calculate the sample skewness of a torch tensor # return the bias corrected Fischer-Pearson standardized moment coefficient by default - n = torch_tensor.numel() - diff = torch_tensor - torch.mean(torch_tensor) - m3 = torch.true_divide(torch.sum(torch.pow(diff, 3)), n) - m2 = torch.true_divide(torch.sum(torch.pow(diff, 2)), n) - print(n) - coeff = (n * (n - 1) / float((n - 2))) ** 0.5 + if dim is not None: + n = torch_tensor.shape[dim] + diff = torch_tensor - torch.mean(torch_tensor, dim=dim, keepdim=True) + m3 = torch.true_divide(torch.sum(torch.pow(diff, 3), dim=dim), n) + m2 = torch.true_divide(torch.sum(torch.pow(diff, 2), dim=dim), n) + else: + n = torch_tensor.numel() + diff = torch_tensor - torch.mean(torch_tensor) + m3 = torch.true_divide(torch.sum(torch.pow(diff, 3)), n) + m2 = torch.true_divide(torch.sum(torch.pow(diff, 2)), n) if not unbiased: return torch.true_divide(m3, torch.pow(m2, 1.5)) + coeff = ((n * (n - 1)) ** 0.5) / (n - 2.0) return coeff * torch.true_divide(m3, torch.pow(m2, 1.5)) @@ -1277,10 +1331,16 @@ def __torch_kurtosis(torch_tensor, dim=None, excess=True): # calculate the sample kurtosis of a torch tensor, Pearson's definition # returns the excess Kurtosis if excess is True # there is not unbiased estimator for Kurtosis - n = torch_tensor.numel() - diff = torch_tensor - torch.mean(torch_tensor, dim) - m4 = torch.true_divide(torch.pow(diff, 4), n) - m2 = torch.true_divide(torch.pow(diff, 2), n) + if dim is not None: + n = torch_tensor.shape[dim] + diff = torch_tensor - torch.mean(torch_tensor, dim=dim, keepdim=True) + m4 = torch.true_divide(torch.sum(torch.pow(diff, 4), dim=dim), n) + m2 = torch.true_divide(torch.sum(torch.pow(diff, 2), dim=dim), n) + else: + n = torch_tensor.numel() + diff = torch_tensor - torch.mean(torch_tensor) + m4 = torch.true_divide(torch.pow(diff, 4), n) + m2 = torch.true_divide(torch.pow(diff, 2), n) k = torch.true_divide(m4, torch.pow(m2, 2)) if excess: k -= 3.0 @@ -1289,8 +1349,8 @@ def __torch_kurtosis(torch_tensor, dim=None, excess=True): def var(x, axis=None, ddof=0, **kwargs): """ - Calculates and returns the variance of a tensor. If an axis is given, the variance will be - taken in that direction. + Calculates and returns the variance of a tensor. The default estimator is biased. + If an axis is given, the variance will be taken in that direction. Parameters ---------- @@ -1359,9 +1419,9 @@ def var(x, axis=None, ddof=0, **kwargs): raise ValueError("Expected ddof=0 or ddof=1, got {}".format(ddof)) else: if kwargs.get("bessel"): - bessel = kwargs.get("bessel") + unbiased = kwargs.get("bessel") else: - bessel = bool(ddof) + unbiased = bool(ddof) def reduce_vars_elementwise(output_shape_i): """ @@ -1382,38 +1442,36 @@ def reduce_vars_elementwise(output_shape_i): if x.lshape[x.split] != 0: mu = torch.mean(x._DNDarray__array, dim=axis) - var = torch.var(x._DNDarray__array, dim=axis, unbiased=bessel) + var = torch.var(x._DNDarray__array, dim=axis, unbiased=unbiased) else: mu = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) var = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) var_shape = list(var.shape) if list(var.shape) else [1] - var_tot = factories.zeros(([x.comm.size, 2] + var_shape), dtype=x.dtype, device=x.device) - n_tot = factories.zeros(x.comm.size, device=x.device) + var_tot = factories.zeros(([x.comm.size, 3] + var_shape), dtype=x.dtype, device=x.device) var_tot[x.comm.rank, 0, :] = var var_tot[x.comm.rank, 1, :] = mu - n_tot[x.comm.rank] = float(x.lshape[x.split]) + var_tot[x.comm.rank, 2, :] = float(x.lshape[x.split]) x.comm.Allreduce(MPI.IN_PLACE, var_tot, MPI.SUM) - x.comm.Allreduce(MPI.IN_PLACE, n_tot, MPI.SUM) for i in range(1, x.comm.size): - var_tot[0, 0, :], var_tot[0, 1, :], n_tot[0] = __merge_moments( - (var_tot[0, 0, :], var_tot[0, 1, :], n_tot[0]), - (var_tot[i, 0, :], var_tot[i, 1, :], n_tot[i]), - bessel=bessel, + var_tot[0, 0, :], var_tot[0, 1, :], var_tot[0, 2, :] = __merge_moments( + (var_tot[0, 0, :], var_tot[0, 1, :], var_tot[0, 2, :]), + (var_tot[i, 0, :], var_tot[i, 1, :], var_tot[i, 2, :]), + unbiased=unbiased, ) return var_tot[0, 0, :][0] if var_tot[0, 0, :].size == 1 else var_tot[0, 0, :] # ---------------------------------------------------------------------------------------------- if axis is None: # no axis given if not x.is_distributed(): # not distributed (full tensor on one node) - ret = torch.var(x._DNDarray__array.float(), unbiased=bessel) + ret = torch.var(x._DNDarray__array.float(), unbiased=unbiased) return factories.array(ret) else: # case for full matrix calculation (axis is None) mu_in = torch.mean(x._DNDarray__array) - var_in = torch.var(x._DNDarray__array, unbiased=bessel) + var_in = torch.var(x._DNDarray__array, unbiased=unbiased) # Nan is returned when local tensor is empty if torch.isnan(var_in): var_in = 0.0 @@ -1430,9 +1488,9 @@ def reduce_vars_elementwise(output_shape_i): var_tot[0, 0], var_tot[0, 1], var_tot[0, 2] = __merge_moments( (var_tot[0, 0], var_tot[0, 1], var_tot[0, 2]), (var_tot[i, 0], var_tot[i, 1], var_tot[i, 2]), - bessel=bessel, + unbiased=unbiased, ) return var_tot[0][0] else: # axis is given - return __moment_w_axis(torch.var, x, axis, reduce_vars_elementwise, bessel) + return __moment_w_axis(torch.var, x, axis, reduce_vars_elementwise, unbiased) diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index 40675f7ec4..e74a0c8d43 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -3,6 +3,7 @@ import unittest from itertools import combinations import os +from scipy import stats as ss import heat as ht from .test_suites.basic_test import TestCase @@ -836,6 +837,103 @@ def test_minimum(self): with self.assertRaises(ValueError): ht.minimum(random_volume_1, random_volume_2, out=output) + def test_skew(self): + # x = ht.zeros((2, 3, 4)) + # with self.assertRaises(ValueError): + # x.skew(axis=10) + # with self.assertRaises(ValueError): + # x.skew(axis=[4]) + # with self.assertRaises(ValueError): + # x.skew(axis=[-4]) + # with self.assertRaises(TypeError): + # ht.skew(x, axis="01") + # with self.assertRaises(ValueError): + # ht.skew(x, axis=(0, "10")) + # with self.assertRaises(ValueError): + # ht.skew(x, axis=(0, 0)) + # with self.assertRaises(ValueError): + # ht.skew(x, axis=torch.Tensor([0, 0])) + + a = ht.arange(1, 5) + self.assertEqual(a.skew(), 0.0) + + # todo: 1 dim, 2 dims, 3 dims, 1 axis, 2 axes, 3 axes, + # split != axis, split == axis, split in axis, split not in axis + # comp with numpy for all, rand and randn + + def __split_calc(ht_split, axis): + if ht_split is None: + return None + else: + sp = ht_split if axis > ht_split else ht_split - 1 + if axis == ht_split: + sp = None + return sp + + # 1 dim + ht_data = ht.random.rand(50) + np_data = ht_data.copy().numpy() + np_skew32 = ht.array((ss.skew(np_data, bias=False)), dtype=ht_data.dtype) + self.assertAlmostEqual(ht.skew(ht_data), np_skew32.item(), places=5) + ht_data = ht.resplit(ht_data, 0) + self.assertAlmostEqual(ht.skew(ht_data), np_skew32.item(), places=5) + + # 2 dim + ht_data = ht.random.rand(50, 30) + np_data = ht_data.copy().numpy() + np_skew32 = ss.skew(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.skew(ht_data) - np_skew32, 0, places=5) + ht_data = ht.resplit(ht_data, 0) + for ax in range(2): + np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) + ht_skew = ht.skew(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + ht_data = ht.resplit(ht_data, 1) + for ax in range(2): + np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) + ht_skew = ht.skew(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + + # 2 dim float64 + ht_data = ht.random.rand(50, 30, dtype=ht.float64) + np_data = ht_data.copy().numpy() + np_skew32 = ss.skew(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.skew(ht_data) - np_skew32, 0, places=5) + ht_data = ht.resplit(ht_data, 0) + for ax in range(2): + np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) + ht_skew = ht.skew(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + self.assertEqual(ht_skew.dtype, ht.float64) + ht_data = ht.resplit(ht_data, 1) + for ax in range(2): + np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) + ht_skew = ht.skew(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + self.assertEqual(ht_skew.dtype, ht.float64) + + # 3 dim + ht_data = ht.random.rand(50, 30, 16) + np_data = ht_data.copy().numpy() + np_skew32 = ss.skew(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.skew(ht_data) - np_skew32, 0, places=5) + for split in range(3): + ht_data = ht.resplit(ht_data, split) + for ax in range(3): + np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) + ht_skew = ht.skew(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + def test_std(self): # test basics a = ht.arange(1, 5) @@ -881,7 +979,7 @@ def test_var(self): with self.assertRaises(ValueError): ht.var(x, ddof=-2) with self.assertRaises(ValueError): - ht.mean(x, axis=torch.Tensor([0, 0])) + ht.var(x, axis=torch.Tensor([0, 0])) a = ht.arange(1, 5) self.assertEqual(a.var(ddof=1), 1.666666666666666) From 80577f70cf7c91822569288817296bfd858cff73 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 11:20:51 +0200 Subject: [PATCH 04/14] Ghost == busted. distance uses max value instead of inf --- heat/cluster/spectral.py | 8 +++++++- heat/core/linalg/solver.py | 4 ++-- heat/core/statistics.py | 4 ++-- heat/core/tests/test_dndarray.py | 10 +++++----- heat/spatial/distance.py | 3 ++- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/heat/cluster/spectral.py b/heat/cluster/spectral.py index 04c17c32e4..6d228d08c2 100644 --- a/heat/cluster/spectral.py +++ b/heat/cluster/spectral.py @@ -107,7 +107,13 @@ def _spectral_embedding(self, X): """ L = self._laplacian.construct(X) # 3. Eigenvalue and -vector calculation via Lanczos Algorithm - v0 = ht.ones((L.shape[0],), dtype=L.dtype, split=0, device=L.device) / math.sqrt(L.shape[0]) + v0 = ht.full( + (L.shape[0],), + fill_value=1.0 / math.sqrt(L.shape[0]), + dtype=L.dtype, + split=0, + device=L.device, + ) V, T = ht.lanczos(L, self.n_lanczos, v0) # 4. Calculate and Sort Eigenvalues and Eigenvectors of tridiagonal matrix T diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index ab070d6246..66a3733c00 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -155,13 +155,13 @@ def lanczos(A, m, v0=None, V_out=None, T_out=None): b = torch.dot(vi_loc, vi_loc) A.comm.Allreduce(ht.communication.MPI.IN_PLACE, a, ht.communication.MPI.SUM) A.comm.Allreduce(ht.communication.MPI.IN_PLACE, b, ht.communication.MPI.SUM) - vr._DNDarray__array = vr._DNDarray__array - a / b * vi_loc + vr._DNDarray__array -= a / b * vi_loc vi = vr / ht.norm(vr) w = ht.matmul(A, vi) alpha = ht.dot(w, vi) - w = w - alpha * vi - beta * V[:, i - 1] + w -= alpha * vi - beta * V[:, i - 1] T[i - 1, i] = beta T[i, i - 1] = beta diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 3ac223134b..60940b6ba7 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -313,7 +313,7 @@ def average(x, axis=None, weights=None, returned=False): if weights is None: result = mean(x, axis) - num_elements = x.gnumel / result.gnumel + num_elements = x.numel / result.numel cumwgt = factories.empty(1, dtype=result.dtype) cumwgt._DNDarray__array = num_elements else: @@ -434,7 +434,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None): norm = x.shape[1] - ddof # find normalization: if norm <= 0: - raise ValueError("ddof >= number of elements in m, {} {}".format(ddof, m.gnumel)) + raise ValueError("ddof >= number of elements in m, {} {}".format(ddof, m.numel)) x -= avg.expand_dims(1) c = linalg.dot(x, x.T) c /= norm diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index ea6a1b91e6..92ab8e1209 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1011,22 +1011,22 @@ def test_setitem_getitem(self): a[0] = ht.array([6, 6, 6, 6, 6]) self.assertTrue((a[ht.array((0,))] == 6).all()) - def test_size_gnumel(self): + def test_size_numel(self): a = ht.zeros((10, 10, 10), split=None) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.gnumel, 10 * 10 * 10) + self.assertEqual(a.numel, 10 * 10 * 10) a = ht.zeros((10, 10, 10), split=0) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.gnumel, 10 * 10 * 10) + self.assertEqual(a.numel, 10 * 10 * 10) a = ht.zeros((10, 10, 10), split=1) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.gnumel, 10 * 10 * 10) + self.assertEqual(a.numel, 10 * 10 * 10) a = ht.zeros((10, 10, 10), split=2) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.gnumel, 10 * 10 * 10) + self.assertEqual(a.numel, 10 * 10 * 10) self.assertEqual(ht.array(0).size, 1) diff --git a/heat/spatial/distance.py b/heat/spatial/distance.py index 25de1705ca..0a4184a8f1 100644 --- a/heat/spatial/distance.py +++ b/heat/spatial/distance.py @@ -69,7 +69,8 @@ def _quadratic_expand(x, y): y_norm = (y ** 2).sum(1).view(1, -1) dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) - return torch.clamp(dist, 0.0, np.inf) + info = torch.finfo(dist.dtype) + return torch.clamp(dist, 0.0, info.max) def _gaussian(x, y, sigma=1.0): From 6413aa43fc8e292e9a5210c8911db1e2c0b9f440 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 11:49:13 +0200 Subject: [PATCH 05/14] added based off --- heat/core/statistics.py | 145 ++++++++++++++++++++++++++++------------ 1 file changed, 102 insertions(+), 43 deletions(-) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 60940b6ba7..84f78bd1a2 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -21,6 +21,7 @@ "argmin", "average", "cov", + "kurtosis", "max", "maximum", "mean", @@ -441,6 +442,104 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None): return c +def kurtosis(x, axis=None, unbiased=True, Fischer=True): + """ + Compute the kurtosis (Fisher or Pearson) of a dataset. + + Kurtosis is the fourth central moment divided by the square of the variance. + If Fisher’s definition is used, then 3.0 is subtracted from the result to give 0.0 for a normal distribution. + + If unbiased is True (defualt) then the kurtosis is calculated using k statistics to + eliminate bias coming from biased moment estimators + + Parameters + ---------- + x : ht.DNDarray + Input array + axis : NoneType or Int or iterable + Axis along which skewness is calculated, Default is to compute over the whole array `x` + unbiased : Bool + if True (default) the calculations are corrected for bias + Fischer : bool + Wheather use Fischer's definition or not, if true 3. is subtracted from the result + + """ + + def __reduce_kurts_elementwise(output_shape_i): + """ + Function to combine the calculated skews together. This does an element-wise update of the + calculated vars to merge them together using the merge_vars function. This function operates + using x from the var function parameters. + + Parameters + ---------- + output_shape_i : iterable + Iterable with the dimensions of the output of the var function. + """ + + if x.lshape[x.split] != 0: + mu = torch.mean(x._DNDarray__array, dim=axis) + var = torch.var(x._DNDarray__array, dim=axis, unbiased=unbiased) + skl = __torch_skew(x._DNDarray__array, dim=axis, unbiased=unbiased) + kurtl = __torch_kurtosis(x._DNDarray__array, dim=axis, Fischer=Fischer) + else: + mu = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + var = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + skl = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + kurtl = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) + + sk_shape = list(skl.shape) if list(skl.shape) else [1] + + rtot = factories.zeros(([x.comm.size, 4] + sk_shape), dtype=x.dtype, device=x.device) + rtot[x.comm.rank, 0, :] = kurtl + rtot[x.comm.rank, 1, :] = skl + rtot[x.comm.rank, 2, :] = var + rtot[x.comm.rank, 3, :] = mu + rtot[x.comm.rank, 4, :] = float(x.lshape[x.split]) + x.comm.Allreduce(MPI.IN_PLACE, rtot, MPI.SUM) + + # print(rtot[x.comm.rank, 0, :]) + for i in range(1, x.comm.size): + rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :], rtot[ + 0, 4, : + ] = __merge_moments( + (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :], rtot[0, 4, :]), + (rtot[i, 0, :], rtot[i, 1, :], rtot[i, 2, :], rtot[i, 3, :], rtot[i, 4, :]), + unbiased=unbiased, + ) + return rtot[0, 0, :][0] if rtot[0, 0, :].size == 1 else rtot[0, 0, :] + + # ---------------------------------------------------------------------------------------------- + if axis is None: # no axis given + # todo: determine if this is a valid (and fast implementation) + mu = mean(x) + diff = x - mu + n = x.numel + + m4 = arithmetics.sum(arithmetics.pow(diff, 4.0)) / n + m2 = arithmetics.sum(arithmetics.pow(diff, 2.0)) / n + res = m4 / arithmetics.pow(m2, 2.0) + if Fischer: + res -= 3.0 + return res.item() + + elif isinstance(axis, int) and x.split == axis: # axis is given + if axis > 0: + diff = x - mean(x, axis=axis).expand_dims(axis) + else: + diff = x - mean(x, axis=axis) + n = float(x.shape[axis]) + # possible speedup to be had here: + m4 = arithmetics.sum(arithmetics.pow(diff, 4.0), axis) / n + m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n + res = m4 / arithmetics.pow(m2, 2.0) + if Fischer: + res -= 3.0 + return res + else: + return __moment_w_axis(__torch_kurtosis, x, axis, __reduce_kurts_elementwise, Fischer) + + def max(x, axis=None, out=None, keepdim=None): # TODO: initial : scalar, optional Issue #101 """ @@ -1137,46 +1236,6 @@ def __reduce_skews_elementwise(output_shape_i): if unbiased: res *= ((n * (n - 1)) ** 0.5) / (n - 2.0) return res.item() - # if not x.is_distributed(): # not distributed (full tensor on one node) - # print('here') - # ret = __torch_skew(x._DNDarray__array.float(), unbiased=unbiased) - # return factories.array(ret) - # - # else: # case for full matrix calculation (axis is None) - # # todo: time this to see if its faster to do this than merging - # - # m2 = torch.true_divide(torch.sum(torch.pow(diff, 2)), n) - # work_arr = x._DNDarray__array.float() - # mu_in = torch.mean(work_arr) - # # work_arr -= mu_in - # var_in = torch.var(work_arr, unbiased=unbiased) - # skew_in = __torch_skew(work_arr, unbiased=unbiased) - # # Nan is returned when local tensor is empty - # if torch.isnan(skew_in): - # skew_in = 0.0 - # if torch.isnan(var_in): - # var_in = 0.0 - # if torch.isnan(mu_in): - # mu_in = 0.0 - # - # n = x.lnumel - # tot = factories.zeros( - # (x.comm.size, 4), dtype=types.promote_types(x.dtype, types.float), device=x.device - # ) - # skew_proc = factories.zeros( - # (x.comm.size, 4), dtype=types.promote_types(x.dtype, types.float), device=x.device - # ) - # skew_proc[x.comm.rank] = skew_in, var_in, mu_in, n - # x.comm.Allreduce(skew_proc, tot, MPI.SUM) - # print(skew_proc[x.comm.rank]) - # - # for i in range(1, x.comm.size): - # tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3] = __merge_moments( - # (tot[0, 0], tot[0, 1], tot[0, 2], tot[0, 3]), - # (tot[i, 0], tot[i, 1], tot[i, 2], tot[0, 3]), - # unbiased=unbiased, - # ) - # return tot[0][0] elif isinstance(axis, int) and x.split == axis: # axis is given if axis > 0: @@ -1184,7 +1243,7 @@ def __reduce_skews_elementwise(output_shape_i): else: diff = x - mean(x, axis=axis) n = float(x.shape[axis]) - + # possible speedup to be had here: m3 = arithmetics.sum(arithmetics.pow(diff, 3.0), axis) / n m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n res = m3 / arithmetics.pow(m2, 1.5) @@ -1327,7 +1386,7 @@ def __torch_skew(torch_tensor, dim=None, unbiased=False): return coeff * torch.true_divide(m3, torch.pow(m2, 1.5)) -def __torch_kurtosis(torch_tensor, dim=None, excess=True): +def __torch_kurtosis(torch_tensor, dim=None, Fischer=True): # calculate the sample kurtosis of a torch tensor, Pearson's definition # returns the excess Kurtosis if excess is True # there is not unbiased estimator for Kurtosis @@ -1342,7 +1401,7 @@ def __torch_kurtosis(torch_tensor, dim=None, excess=True): m4 = torch.true_divide(torch.pow(diff, 4), n) m2 = torch.true_divide(torch.pow(diff, 2), n) k = torch.true_divide(m4, torch.pow(m2, 2)) - if excess: + if Fischer: k -= 3.0 return k From ac0a38eae251b008da9a9b5114ef57e8ad9ba913 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 11:50:22 +0200 Subject: [PATCH 06/14] added based off --- heat/core/statistics.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 84f78bd1a2..732f12a1a1 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -500,9 +500,7 @@ def __reduce_kurts_elementwise(output_shape_i): # print(rtot[x.comm.rank, 0, :]) for i in range(1, x.comm.size): - rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :], rtot[ - 0, 4, : - ] = __merge_moments( + rtot[0] = __merge_moments( (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :], rtot[0, 4, :]), (rtot[i, 0, :], rtot[i, 1, :], rtot[i, 2, :], rtot[i, 3, :], rtot[i, 4, :]), unbiased=unbiased, @@ -1214,7 +1212,7 @@ def __reduce_skews_elementwise(output_shape_i): # print(rtot[x.comm.rank, 0, :]) for i in range(1, x.comm.size): - rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :] = __merge_moments( + rtot[0] = __merge_moments( (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :]), (rtot[i, 0, :], rtot[i, 1, :], rtot[i, 2, :], rtot[i, 3, :]), unbiased=unbiased, From 7d73f197911deabf6a176b8b3a30e65951ddd784 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 14:10:27 +0200 Subject: [PATCH 07/14] added kurtosis tests and expanded skew with bias flag --- heat/core/dndarray.py | 24 +++ heat/core/statistics.py | 41 +++-- heat/core/tests/test_statistics.py | 254 ++++++++++++++++++++--------- 3 files changed, 227 insertions(+), 92 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index faf29f80f1..cbf0fc5004 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -1702,6 +1702,30 @@ def item(self): """ return self.__array.item() + def kurtosis(self, axis=None, unbiased=True, Fischer=True): + """ + Compute the kurtosis (Fisher or Pearson) of a dataset. + + Kurtosis is the fourth central moment divided by the square of the variance. + If Fisher’s definition is used, then 3.0 is subtracted from the result to give 0.0 for a normal distribution. + + If unbiased is True (defualt) then the kurtosis is calculated using k statistics to + eliminate bias coming from biased moment estimators + + Parameters + ---------- + x : ht.DNDarray + Input array + axis : NoneType or Int or iterable + Axis along which skewness is calculated, Default is to compute over the whole array `x` + unbiased : Bool + if True (default) the calculations are corrected for bias + Fischer : bool + Wheather use Fischer's definition or not, if true 3. is subtracted from the result + + """ + return statistics.kurtosis(self, axis, unbiased, Fischer) + def __le__(self, other): """ Element-wise rich comparison of relation "less than or equal" with values from second operand (scalar or tensor) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 732f12a1a1..d84f379c0b 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -481,7 +481,9 @@ def __reduce_kurts_elementwise(output_shape_i): mu = torch.mean(x._DNDarray__array, dim=axis) var = torch.var(x._DNDarray__array, dim=axis, unbiased=unbiased) skl = __torch_skew(x._DNDarray__array, dim=axis, unbiased=unbiased) - kurtl = __torch_kurtosis(x._DNDarray__array, dim=axis, Fischer=Fischer) + kurtl = __torch_kurtosis( + x._DNDarray__array, dim=axis, unbiased=unbiased, Fischer=Fischer + ) else: mu = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) var = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) @@ -490,7 +492,7 @@ def __reduce_kurts_elementwise(output_shape_i): sk_shape = list(skl.shape) if list(skl.shape) else [1] - rtot = factories.zeros(([x.comm.size, 4] + sk_shape), dtype=x.dtype, device=x.device) + rtot = factories.zeros(([x.comm.size, 5] + sk_shape), dtype=x.dtype, device=x.device) rtot[x.comm.rank, 0, :] = kurtl rtot[x.comm.rank, 1, :] = skl rtot[x.comm.rank, 2, :] = var @@ -505,7 +507,10 @@ def __reduce_kurts_elementwise(output_shape_i): (rtot[i, 0, :], rtot[i, 1, :], rtot[i, 2, :], rtot[i, 3, :], rtot[i, 4, :]), unbiased=unbiased, ) - return rtot[0, 0, :][0] if rtot[0, 0, :].size == 1 else rtot[0, 0, :] + res = rtot[0, 0, :][0] if rtot[0, 0, :].size == 1 else rtot[0, 0, :] + if unbiased: + res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3.0 * (n - 1.0)) + 3.0 + return res # ---------------------------------------------------------------------------------------------- if axis is None: # no axis given @@ -517,6 +522,8 @@ def __reduce_kurts_elementwise(output_shape_i): m4 = arithmetics.sum(arithmetics.pow(diff, 4.0)) / n m2 = arithmetics.sum(arithmetics.pow(diff, 2.0)) / n res = m4 / arithmetics.pow(m2, 2.0) + if unbiased: + res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3 * (n - 1.0)) + 3.0 if Fischer: res -= 3.0 return res.item() @@ -531,11 +538,15 @@ def __reduce_kurts_elementwise(output_shape_i): m4 = arithmetics.sum(arithmetics.pow(diff, 4.0), axis) / n m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n res = m4 / arithmetics.pow(m2, 2.0) + if unbiased: + res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3 * (n - 1.0)) + 3.0 if Fischer: res -= 3.0 return res else: - return __moment_w_axis(__torch_kurtosis, x, axis, __reduce_kurts_elementwise, Fischer) + return __moment_w_axis( + __torch_kurtosis, x, axis, __reduce_kurts_elementwise, unbiased, Fischer + ) def max(x, axis=None, out=None, keepdim=None): @@ -1300,11 +1311,13 @@ def std(x, axis=None, ddof=0, **kwargs): return exponential.sqrt(var(x, axis, ddof, **kwargs), out=None) -def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None): +def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None, Fischer=None): # helper for calculating a statistical moment with a given axis kwargs = {"dim": axis} if unbiased: kwargs["unbiased"] = unbiased + if Fischer: + kwargs["Fischer"] = Fischer output_shape = list(x.shape) if isinstance(axis, (list, tuple, dndarray.DNDarray, torch.Tensor)): @@ -1384,24 +1397,26 @@ def __torch_skew(torch_tensor, dim=None, unbiased=False): return coeff * torch.true_divide(m3, torch.pow(m2, 1.5)) -def __torch_kurtosis(torch_tensor, dim=None, Fischer=True): +def __torch_kurtosis(torch_tensor, dim=None, Fischer=True, unbiased=False): # calculate the sample kurtosis of a torch tensor, Pearson's definition # returns the excess Kurtosis if excess is True # there is not unbiased estimator for Kurtosis if dim is not None: n = torch_tensor.shape[dim] diff = torch_tensor - torch.mean(torch_tensor, dim=dim, keepdim=True) - m4 = torch.true_divide(torch.sum(torch.pow(diff, 4), dim=dim), n) - m2 = torch.true_divide(torch.sum(torch.pow(diff, 2), dim=dim), n) + m4 = torch.true_divide(torch.sum(torch.pow(diff, 4.0), dim=dim), n) + m2 = torch.true_divide(torch.sum(torch.pow(diff, 2.0), dim=dim), n) else: n = torch_tensor.numel() diff = torch_tensor - torch.mean(torch_tensor) - m4 = torch.true_divide(torch.pow(diff, 4), n) - m2 = torch.true_divide(torch.pow(diff, 2), n) - k = torch.true_divide(m4, torch.pow(m2, 2)) + m4 = torch.true_divide(torch.pow(diff, 4.0), n) + m2 = torch.true_divide(torch.pow(diff, 2.0), n) + res = torch.true_divide(m4, torch.pow(m2, 2.0)) + if unbiased: + res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3.0 * (n - 1.0)) + 3.0 if Fischer: - k -= 3.0 - return k + res -= 3.0 + return res def var(x, axis=None, ddof=0, **kwargs): diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index e74a0c8d43..8440f79a48 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -193,64 +193,6 @@ def test_argmin(self): with self.assertRaises(ValueError): ht.argmin(data, axis=-4) - def test_cov(self): - x = ht.array([[0, 2], [1, 1], [2, 0]], dtype=ht.float, split=1).T - if x.comm.size < 3: - cov = ht.cov(x) - actual = ht.array([[1, -1], [-1, 1]], split=0) - self.assertTrue(ht.equal(cov, actual)) - - data = np.loadtxt("heat/datasets/data/iris.csv", delimiter=";") - np_cov = np.cov(data[:, 0], data[:, 1:3], rowvar=False) - - htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=0) - ht_cov = ht.cov(htdata[:, 0], htdata[:, 1:3], rowvar=False) - comp = ht.array(np_cov, dtype=ht.float) - self.assertTrue(ht.allclose(comp - ht_cov, 0, atol=1e-4)) - - np_cov = np.cov(data, rowvar=False) - ht_cov = ht.cov(htdata, rowvar=False) - self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) - - np_cov = np.cov(data, rowvar=False, ddof=1) - ht_cov = ht.cov(htdata, rowvar=False, ddof=1) - self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) - - np_cov = np.cov(data, rowvar=False, bias=True) - ht_cov = ht.cov(htdata, rowvar=False, bias=True) - self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) - - if 1 < x.comm.size < 5: - htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=1) - np_cov = np.cov(data, rowvar=False) - ht_cov = ht.cov(htdata, rowvar=False) - self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float), ht_cov, atol=1e-4)) - - np_cov = np.cov(data, data, rowvar=True) - - htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=0) - ht_cov = ht.cov(htdata, htdata, rowvar=True) - self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float), ht_cov, atol=1e-4)) - - htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=0) - with self.assertRaises(RuntimeError): - ht.cov(htdata[1:], rowvar=False) - with self.assertRaises(RuntimeError): - ht.cov(htdata, htdata[1:], rowvar=False) - - with self.assertRaises(TypeError): - ht.cov(np_cov) - with self.assertRaises(TypeError): - ht.cov(htdata, np_cov) - with self.assertRaises(TypeError): - ht.cov(htdata, ddof="str") - with self.assertRaises(ValueError): - ht.cov(ht.zeros((1, 2, 3))) - with self.assertRaises(ValueError): - ht.cov(htdata, ht.zeros((1, 2, 3))) - with self.assertRaises(ValueError): - ht.cov(htdata, ddof=10000) - def test_average(self): data = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] @@ -390,6 +332,160 @@ def test_average(self): with self.assertRaises(ValueError): ht.average(ht_array, axis=-4) + def test_cov(self): + x = ht.array([[0, 2], [1, 1], [2, 0]], dtype=ht.float, split=1).T + if x.comm.size < 3: + cov = ht.cov(x) + actual = ht.array([[1, -1], [-1, 1]], split=0) + self.assertTrue(ht.equal(cov, actual)) + + data = np.loadtxt("heat/datasets/data/iris.csv", delimiter=";") + np_cov = np.cov(data[:, 0], data[:, 1:3], rowvar=False) + + htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=0) + ht_cov = ht.cov(htdata[:, 0], htdata[:, 1:3], rowvar=False) + comp = ht.array(np_cov, dtype=ht.float) + self.assertTrue(ht.allclose(comp - ht_cov, 0, atol=1e-4)) + + np_cov = np.cov(data, rowvar=False) + ht_cov = ht.cov(htdata, rowvar=False) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) + + np_cov = np.cov(data, rowvar=False, ddof=1) + ht_cov = ht.cov(htdata, rowvar=False, ddof=1) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) + + np_cov = np.cov(data, rowvar=False, bias=True) + ht_cov = ht.cov(htdata, rowvar=False, bias=True) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) + + if 1 < x.comm.size < 5: + htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=1) + np_cov = np.cov(data, rowvar=False) + ht_cov = ht.cov(htdata, rowvar=False) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float), ht_cov, atol=1e-4)) + + np_cov = np.cov(data, data, rowvar=True) + + htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=0) + ht_cov = ht.cov(htdata, htdata, rowvar=True) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float), ht_cov, atol=1e-4)) + + htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=0) + with self.assertRaises(RuntimeError): + ht.cov(htdata[1:], rowvar=False) + with self.assertRaises(RuntimeError): + ht.cov(htdata, htdata[1:], rowvar=False) + + with self.assertRaises(TypeError): + ht.cov(np_cov) + with self.assertRaises(TypeError): + ht.cov(htdata, np_cov) + with self.assertRaises(TypeError): + ht.cov(htdata, ddof="str") + with self.assertRaises(ValueError): + ht.cov(ht.zeros((1, 2, 3))) + with self.assertRaises(ValueError): + ht.cov(htdata, ht.zeros((1, 2, 3))) + with self.assertRaises(ValueError): + ht.cov(htdata, ddof=10000) + + def test_kurtosis(self): + x = ht.zeros((2, 3, 4)) + with self.assertRaises(ValueError): + x.kurtosis(axis=10) + with self.assertRaises(ValueError): + x.kurtosis(axis=[4]) + with self.assertRaises(ValueError): + x.kurtosis(axis=[-4]) + with self.assertRaises(TypeError): + ht.kurtosis(x, axis="01") + with self.assertRaises(ValueError): + ht.kurtosis(x, axis=(0, "10")) + with self.assertRaises(ValueError): + ht.kurtosis(x, axis=(0, 0)) + with self.assertRaises(ValueError): + ht.kurtosis(x, axis=torch.Tensor([0, 0])) + + # 1 dim, 2 dims, 3 dims, 1 axis, 2 axes, 3 axes, + # split != axis, split == axis, split in axis, split not in axis + # comp with numpy for all, rand and randn + + def __split_calc(ht_split, axis): + if ht_split is None: + return None + else: + sp = ht_split if axis > ht_split else ht_split - 1 + if axis == ht_split: + sp = None + return sp + + # 1 dim + ht_data = ht.random.rand(50) + np_data = ht_data.copy().numpy() + np_skew32 = ht.array((ss.kurtosis(np_data, bias=False)), dtype=ht_data.dtype) + self.assertAlmostEqual(ht.kurtosis(ht_data), np_skew32.item(), places=5) + ht_data = ht.resplit(ht_data, 0) + self.assertAlmostEqual(ht.kurtosis(ht_data), np_skew32.item(), places=5) + + # 2 dim + ht_data = ht.random.rand(50, 30) + np_data = ht_data.copy().numpy() + np_skew32 = ss.kurtosis(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.kurtosis(ht_data) - np_skew32, 0, places=5) + ht_data = ht.resplit(ht_data, 0) + for ax in range(2): + np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=True)), dtype=ht_data.dtype) + ht_skew = ht.kurtosis(ht_data, axis=ax, unbiased=False) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + ht_data = ht.resplit(ht_data, 1) + for ax in range(2): + np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=True)), dtype=ht_data.dtype) + ht_skew = ht.kurtosis(ht_data, axis=ax, unbiased=False) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + + # 2 dim float64 + ht_data = ht.random.rand(50, 30, dtype=ht.float64) + np_data = ht_data.copy().numpy() + np_skew32 = ss.kurtosis(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.kurtosis(ht_data) - np_skew32, 0, places=5) + ht_data = ht.resplit(ht_data, 0) + for ax in range(2): + np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) + ht_skew = ht.kurtosis(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + self.assertEqual(ht_skew.dtype, ht.float64) + ht_data = ht.resplit(ht_data, 1) + for ax in range(2): + np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) + ht_skew = ht.kurtosis(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + self.assertEqual(ht_skew.dtype, ht.float64) + + # 3 dim + ht_data = ht.random.rand(50, 30, 16) + np_data = ht_data.copy().numpy() + np_skew32 = ss.kurtosis(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.kurtosis(ht_data) - np_skew32, 0, places=5) + for split in range(3): + ht_data = ht.resplit(ht_data, split) + for ax in range(3): + np_skew32 = ht.array( + (ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype + ) + ht_skew = ht.kurtosis(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + sp = __split_calc(ht_data.split, ax) + self.assertEqual(ht_skew.split, sp) + def test_max(self): data = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] @@ -838,21 +934,21 @@ def test_minimum(self): ht.minimum(random_volume_1, random_volume_2, out=output) def test_skew(self): - # x = ht.zeros((2, 3, 4)) - # with self.assertRaises(ValueError): - # x.skew(axis=10) - # with self.assertRaises(ValueError): - # x.skew(axis=[4]) - # with self.assertRaises(ValueError): - # x.skew(axis=[-4]) - # with self.assertRaises(TypeError): - # ht.skew(x, axis="01") - # with self.assertRaises(ValueError): - # ht.skew(x, axis=(0, "10")) - # with self.assertRaises(ValueError): - # ht.skew(x, axis=(0, 0)) - # with self.assertRaises(ValueError): - # ht.skew(x, axis=torch.Tensor([0, 0])) + x = ht.zeros((2, 3, 4)) + with self.assertRaises(ValueError): + x.skew(axis=10) + with self.assertRaises(ValueError): + x.skew(axis=[4]) + with self.assertRaises(ValueError): + x.skew(axis=[-4]) + with self.assertRaises(TypeError): + ht.skew(x, axis="01") + with self.assertRaises(ValueError): + ht.skew(x, axis=(0, "10")) + with self.assertRaises(ValueError): + ht.skew(x, axis=(0, 0)) + with self.assertRaises(ValueError): + ht.skew(x, axis=torch.Tensor([0, 0])) a = ht.arange(1, 5) self.assertEqual(a.skew(), 0.0) @@ -881,19 +977,19 @@ def __split_calc(ht_split, axis): # 2 dim ht_data = ht.random.rand(50, 30) np_data = ht_data.copy().numpy() - np_skew32 = ss.skew(np_data, axis=None, bias=False) - self.assertAlmostEqual(ht.skew(ht_data) - np_skew32, 0, places=5) + np_skew32 = ss.skew(np_data, axis=None, bias=True) + self.assertAlmostEqual(ht.skew(ht_data, unbiased=False) - np_skew32, 0, places=5) ht_data = ht.resplit(ht_data, 0) for ax in range(2): - np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) - ht_skew = ht.skew(ht_data, axis=ax) + np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=True)), dtype=ht_data.dtype) + ht_skew = ht.skew(ht_data, axis=ax, unbiased=False) self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) sp = __split_calc(ht_data.split, ax) self.assertEqual(ht_skew.split, sp) ht_data = ht.resplit(ht_data, 1) for ax in range(2): - np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) - ht_skew = ht.skew(ht_data, axis=ax) + np_skew32 = ht.array((ss.skew(np_data, axis=ax, bias=True)), dtype=ht_data.dtype) + ht_skew = ht.skew(ht_data, axis=ax, unbiased=False) self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) sp = __split_calc(ht_data.split, ax) self.assertEqual(ht_skew.split, sp) From d7a2c2c49f87d9132919881212cc5f08a6b73a04 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 14:12:09 +0200 Subject: [PATCH 08/14] doc string correction --- heat/core/statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index d84f379c0b..9caf151891 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -461,7 +461,7 @@ def kurtosis(x, axis=None, unbiased=True, Fischer=True): unbiased : Bool if True (default) the calculations are corrected for bias Fischer : bool - Wheather use Fischer's definition or not, if true 3. is subtracted from the result + Whether use Fischer's definition or not, if true 3. is subtracted from the result """ From ff386d8a66c76d5e9f02842078b1f6a804a48383 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 14:33:40 +0200 Subject: [PATCH 09/14] changelog update, dead code removal, cleanup --- CHANGELOG.md | 14 ++++++++------ heat/core/statistics.py | 9 --------- heat/spatial/distance.py | 1 - 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5d698b5bb..d3718a5464 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,13 +6,15 @@ - [#575](https://github.com/helmholtz-analytics/heat/pull/558) Bugfix: `where` and `cov` convert ints to floats when given as parameters - [#577](https://github.com/helmholtz-analytics/heat/pull/577) Add ndim property in dndarray - [#578](https://github.com/helmholtz-analytics/heat/pull/578) Bugfix: Bad variable in reshape -- [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: fliplr() -- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New feature: DNDarray.tolist() -- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature: arctan2() +- [#580](https://github.com/helmholtz-analytics/heat/pull/580) New feature: `fliplr()` +- [#581](https://github.com/helmholtz-analytics/heat/pull/581) New feature: `DNDarray.tolist()` +- [#593](https://github.com/helmholtz-analytics/heat/pull/593) New feature: `arctan2()` - [#594](https://github.com/helmholtz-analytics/heat/pull/594) New feature: Advanced indexing -- [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: getitem and setitem memory consumption heavily reduced -- [#596](https://github.com/helmholtz-analytics/heat/pull/596) New feature: outer() -- [#600](https://github.com/helmholtz-analytics/heat/pull/600) New feature: shape() +- [#594](https://github.com/helmholtz-analytics/heat/pull/594) Bugfix: `__getitem__` and `__setitem__` memory consumption heavily reduced +- [#596](https://github.com/helmholtz-analytics/heat/pull/596) New feature: `outer()` +- [#600](https://github.com/helmholtz-analytics/heat/pull/600) New feature: `shape()` +- [#615](https://github.com/helmholtz-analytics/heat/pull/615) New feature: `skew()` +- [#615](https://github.com/helmholtz-analytics/heat/pull/615) New feature: `kurtosis()` # v0.4.0 diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 9caf151891..c318185d56 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -500,7 +500,6 @@ def __reduce_kurts_elementwise(output_shape_i): rtot[x.comm.rank, 4, :] = float(x.lshape[x.split]) x.comm.Allreduce(MPI.IN_PLACE, rtot, MPI.SUM) - # print(rtot[x.comm.rank, 0, :]) for i in range(1, x.comm.size): rtot[0] = __merge_moments( (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :], rtot[0, 4, :]), @@ -915,16 +914,11 @@ def __merge_moments(m1, m2, unbiased=True): return var_m, mu, n sk1, sk2 = m1[-4], m2[-4] - # print(sk1) - # print(sk2) dn = delta / n - # todo: fix this condition if all(var_m != 0): # Skewness does not exist if var is 0 s1 = sk1 + sk2 s2 = dn * (n1 * var2 - n2 * var1) / 6.0 - # print('n', dn / 6) s3 = (dn ** 3) * n1 * n2 * (n1 ** 2 - n2 ** 2) - # print(s1, s2) skew_m = s1 + s2 + s3 else: skew_m = None @@ -1221,7 +1215,6 @@ def __reduce_skews_elementwise(output_shape_i): rtot[x.comm.rank, 3, :] = float(x.lshape[x.split]) x.comm.Allreduce(MPI.IN_PLACE, rtot, MPI.SUM) - # print(rtot[x.comm.rank, 0, :]) for i in range(1, x.comm.size): rtot[0] = __merge_moments( (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :]), @@ -1234,8 +1227,6 @@ def __reduce_skews_elementwise(output_shape_i): if axis is None: # no axis given # todo: determine if this is a valid (and fast implementation) mu = mean(x) - # diff = x - # diff._DNDarray__array -= float(mu.item()) diff = x - mu n = x.numel diff --git a/heat/spatial/distance.py b/heat/spatial/distance.py index 0a4184a8f1..bf39f457a5 100644 --- a/heat/spatial/distance.py +++ b/heat/spatial/distance.py @@ -1,5 +1,4 @@ import torch -import numpy as np from mpi4py import MPI from ..core import factories From b255d9d10efce86b1977c47a19a6baf369d142e2 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 14:47:59 +0200 Subject: [PATCH 10/14] more cleanup --- heat/core/statistics.py | 49 ++++++++++------------------------------- 1 file changed, 12 insertions(+), 37 deletions(-) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index c318185d56..818aebc132 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -512,28 +512,14 @@ def __reduce_kurts_elementwise(output_shape_i): return res # ---------------------------------------------------------------------------------------------- - if axis is None: # no axis given + if axis is None or (isinstance(axis, int) and x.split == axis): # no axis given # todo: determine if this is a valid (and fast implementation) - mu = mean(x) + mu = mean(x, axis=axis) + if axis is not None and axis > 0: + mu = mu.expand_dims(axis) diff = x - mu - n = x.numel - - m4 = arithmetics.sum(arithmetics.pow(diff, 4.0)) / n - m2 = arithmetics.sum(arithmetics.pow(diff, 2.0)) / n - res = m4 / arithmetics.pow(m2, 2.0) - if unbiased: - res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3 * (n - 1.0)) + 3.0 - if Fischer: - res -= 3.0 - return res.item() + n = float(x.shape[axis]) if axis is not None else x.numel - elif isinstance(axis, int) and x.split == axis: # axis is given - if axis > 0: - diff = x - mean(x, axis=axis).expand_dims(axis) - else: - diff = x - mean(x, axis=axis) - n = float(x.shape[axis]) - # possible speedup to be had here: m4 = arithmetics.sum(arithmetics.pow(diff, 4.0), axis) / n m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n res = m4 / arithmetics.pow(m2, 2.0) @@ -541,7 +527,7 @@ def __reduce_kurts_elementwise(output_shape_i): res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3 * (n - 1.0)) + 3.0 if Fischer: res -= 3.0 - return res + return res.item() if res.numel == 1 else res else: return __moment_w_axis( __torch_kurtosis, x, axis, __reduce_kurts_elementwise, unbiased, Fischer @@ -1224,32 +1210,21 @@ def __reduce_skews_elementwise(output_shape_i): return rtot[0, 0, :][0] if rtot[0, 0, :].size == 1 else rtot[0, 0, :] # ---------------------------------------------------------------------------------------------- - if axis is None: # no axis given + if axis is None or (isinstance(axis, int) and x.split == axis): # no axis given # todo: determine if this is a valid (and fast implementation) - mu = mean(x) + mu = mean(x, axis=axis) + if axis is not None and axis > 0: + mu = mu.expand_dims(axis) diff = x - mu - n = x.numel - m3 = arithmetics.sum(arithmetics.pow(diff, 3)) / n - m2 = arithmetics.sum(arithmetics.pow(diff, 2)) / n - res = m3 / arithmetics.pow(m2, 1.5) - if unbiased: - res *= ((n * (n - 1)) ** 0.5) / (n - 2.0) - return res.item() + n = float(x.shape[axis]) if axis is not None else x.numel - elif isinstance(axis, int) and x.split == axis: # axis is given - if axis > 0: - diff = x - mean(x, axis=axis).expand_dims(axis) - else: - diff = x - mean(x, axis=axis) - n = float(x.shape[axis]) - # possible speedup to be had here: m3 = arithmetics.sum(arithmetics.pow(diff, 3.0), axis) / n m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n res = m3 / arithmetics.pow(m2, 1.5) if unbiased: res *= ((n * (n - 1.0)) ** 0.5) / (n - 2.0) - return res + return res.item() if res.numel == 1 else res else: return __moment_w_axis(__torch_skew, x, axis, __reduce_skews_elementwise, unbiased) From 244c2517bd4bdc02d499cfac5a96c7240d60786c Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Tue, 30 Jun 2020 14:55:02 +0200 Subject: [PATCH 11/14] added scipy requirement, needed for tests of skew and kurtosis added reverences to skew and kurtosis cleanup, removal of useless code added warning to documentation coverage increases and removal of code block in merge moments needed for skew and kurtosis, only required if said functions support multiple axes coverage update moved to the correct alphabetical place in statistics removed excess try except in requested changes, doc strings, reordering of moments w axis --- heat/core/dndarray.py | 33 ++- heat/core/statistics.py | 319 ++++++++++++----------------- heat/core/tests/test_statistics.py | 122 +++++------ setup.py | 2 +- 4 files changed, 205 insertions(+), 271 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index cbf0fc5004..811c8d0dc8 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -4,7 +4,7 @@ import math import torch import warnings -from typing import List +from typing import List, Union from . import arithmetics from . import devices @@ -146,10 +146,7 @@ def size(self): size : int number of total elements of the tensor """ - try: - return torch.prod(torch.tensor(self.gshape, device=self.device.torch_device)).item() - except TypeError: - return 1 + return torch.prod(torch.tensor(self.gshape, device=self.device.torch_device)).item() @property def numel(self): @@ -1705,6 +1702,7 @@ def item(self): def kurtosis(self, axis=None, unbiased=True, Fischer=True): """ Compute the kurtosis (Fisher or Pearson) of a dataset. + TODO: add return type annotation (DNDarray) and x annotation (DNDarray) Kurtosis is the fourth central moment divided by the square of the variance. If Fisher’s definition is used, then 3.0 is subtracted from the result to give 0.0 for a normal distribution. @@ -1716,13 +1714,17 @@ def kurtosis(self, axis=None, unbiased=True, Fischer=True): ---------- x : ht.DNDarray Input array - axis : NoneType or Int or iterable + axis : NoneType or Int Axis along which skewness is calculated, Default is to compute over the whole array `x` unbiased : Bool if True (default) the calculations are corrected for bias Fischer : bool - Wheather use Fischer's definition or not, if true 3. is subtracted from the result + Whether use Fischer's definition or not. If true 3. is subtracted from the result. + Warnings + -------- + UserWarning: Dependent on the axis given and the split configuration a UserWarning may be thrown during this + function as data is transferred between processes """ return statistics.kurtosis(self, axis, unbiased, Fischer) @@ -3250,6 +3252,23 @@ def sinh(self, out=None): return trigonometrics.sinh(self, out) def skew(self, axis=None, unbiased=True): + """ + Compute the sample skewness of a data set. + + Parameters + ---------- + x : ht.DNDarray + Input array + axis : NoneType or Int + Axis along which skewness is calculated, Default is to compute over the whole array `x` + unbiased : Bool + if True (default) the calculations are corrected for bias + + Warnings + -------- + UserWarning: Dependent on the axis given and the split configuration a UserWarning may be thrown during this + function as data is transferred between processes + """ return statistics.skew(self, axis, unbiased) def sqrt(self, out=None): diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 818aebc132..64f28a7886 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -1,6 +1,6 @@ -import math import numpy as np import torch +from typing import Callable, Union, Tuple from .communication import MPI from . import arithmetics @@ -435,7 +435,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None): norm = x.shape[1] - ddof # find normalization: if norm <= 0: - raise ValueError("ddof >= number of elements in m, {} {}".format(ddof, m.numel)) + raise ValueError(f"ddof >= number of elements in m, {ddof} {m.numel}") x -= avg.expand_dims(1) c = linalg.dot(x, x.T) c /= norm @@ -445,6 +445,8 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None): def kurtosis(x, axis=None, unbiased=True, Fischer=True): """ Compute the kurtosis (Fisher or Pearson) of a dataset. + TODO: type annotations: + def kurtosis(x : DNDarray, axis : Union[None, int] = None, unbiased : bool = True, Fischer : bool = True) -> DNDarray: Kurtosis is the fourth central moment divided by the square of the variance. If Fisher’s definition is used, then 3.0 is subtracted from the result to give 0.0 for a normal distribution. @@ -456,64 +458,20 @@ def kurtosis(x, axis=None, unbiased=True, Fischer=True): ---------- x : ht.DNDarray Input array - axis : NoneType or Int or iterable + axis : NoneType or Int Axis along which skewness is calculated, Default is to compute over the whole array `x` unbiased : Bool if True (default) the calculations are corrected for bias Fischer : bool - Whether use Fischer's definition or not, if true 3. is subtracted from the result + Whether use Fischer's definition or not. If true 3. is subtracted from the result. + Warnings + -------- + UserWarning: Dependent on the axis given and the split configuration a UserWarning may be thrown during this + function as data is transferred between processes """ - - def __reduce_kurts_elementwise(output_shape_i): - """ - Function to combine the calculated skews together. This does an element-wise update of the - calculated vars to merge them together using the merge_vars function. This function operates - using x from the var function parameters. - - Parameters - ---------- - output_shape_i : iterable - Iterable with the dimensions of the output of the var function. - """ - - if x.lshape[x.split] != 0: - mu = torch.mean(x._DNDarray__array, dim=axis) - var = torch.var(x._DNDarray__array, dim=axis, unbiased=unbiased) - skl = __torch_skew(x._DNDarray__array, dim=axis, unbiased=unbiased) - kurtl = __torch_kurtosis( - x._DNDarray__array, dim=axis, unbiased=unbiased, Fischer=Fischer - ) - else: - mu = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - var = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - skl = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - kurtl = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - - sk_shape = list(skl.shape) if list(skl.shape) else [1] - - rtot = factories.zeros(([x.comm.size, 5] + sk_shape), dtype=x.dtype, device=x.device) - rtot[x.comm.rank, 0, :] = kurtl - rtot[x.comm.rank, 1, :] = skl - rtot[x.comm.rank, 2, :] = var - rtot[x.comm.rank, 3, :] = mu - rtot[x.comm.rank, 4, :] = float(x.lshape[x.split]) - x.comm.Allreduce(MPI.IN_PLACE, rtot, MPI.SUM) - - for i in range(1, x.comm.size): - rtot[0] = __merge_moments( - (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :], rtot[0, 4, :]), - (rtot[i, 0, :], rtot[i, 1, :], rtot[i, 2, :], rtot[i, 3, :], rtot[i, 4, :]), - unbiased=unbiased, - ) - res = rtot[0, 0, :][0] if rtot[0, 0, :].size == 1 else rtot[0, 0, :] - if unbiased: - res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3.0 * (n - 1.0)) + 3.0 - return res - - # ---------------------------------------------------------------------------------------------- if axis is None or (isinstance(axis, int) and x.split == axis): # no axis given - # todo: determine if this is a valid (and fast implementation) + # TODO: determine if this is a valid (and fast implementation) mu = mean(x, axis=axis) if axis is not None and axis > 0: mu = mu.expand_dims(axis) @@ -528,10 +486,10 @@ def __reduce_kurts_elementwise(output_shape_i): if Fischer: res -= 3.0 return res.item() if res.numel == 1 else res + elif isinstance(axis, (list, tuple)): + raise TypeError("axis cannot be a list or a tuple, currently {}".format(type(axis))) else: - return __moment_w_axis( - __torch_kurtosis, x, axis, __reduce_kurts_elementwise, unbiased, Fischer - ) + return __moment_w_axis(__torch_kurtosis, x, axis, None, unbiased, Fischer) def max(x, axis=None, out=None, keepdim=None): @@ -855,12 +813,15 @@ def __merge_moments(m1, m2, unbiased=True): Merge two statistical moments. If the length of m1/m2 (must be equal) is == 3 then the second moment (variance) is merged. This function can be expanded to merge other moments according to Reference 1 as well. Note: all tensors/arrays must be either the same size or individual values + TODO: Type annotation: + def __merge_moments(m1 : Tuple, m2 : Tuple, unbiased : bool=True) -> Tuple: Parameters ---------- m1 : tuple Tuple of the moments to merge together, the 0th element is the moment to be merged. The tuple must be sorted in descending order of moments + Can be m2 : tuple Tuple of the moments to merge together, the 0th element is the moment to be merged. The tuple must be sorted in descending order of moments @@ -899,27 +860,28 @@ def __merge_moments(m1, m2, unbiased=True): if len(m1) == 3: # merge vars return var_m, mu, n - sk1, sk2 = m1[-4], m2[-4] - dn = delta / n - if all(var_m != 0): # Skewness does not exist if var is 0 - s1 = sk1 + sk2 - s2 = dn * (n1 * var2 - n2 * var1) / 6.0 - s3 = (dn ** 3) * n1 * n2 * (n1 ** 2 - n2 ** 2) - skew_m = s1 + s2 + s3 - else: - skew_m = None - if len(m1) == 4: - return skew_m, var_m, mu, n - - k1, k2 = m1[-5], m2[-5] - if skew_m is None: - return None, skew_m, var_m, mu, n - s1 = (1 / 24.0) * dn * (n1 * sk2 - n2 * sk1) - s2 = (1 / 12.0) * (dn ** 2) * ((n1 ** 2) * var2 + (n2 ** 2) * var1) - s3 = (dn ** 4) * n1 * n2 * ((n1 ** 3) + (n2 ** 3)) - k = k1 + k2 + s1 + s2 + s3 - if len(m1) == 5: - return k, skew_m, var_m, mu, n + # TODO: This code block can be added if skew or kurtosis support multiple axes: + # sk1, sk2 = m1[-4], m2[-4] + # dn = delta / n + # if all(var_m != 0): # Skewness does not exist if var is 0 + # s1 = sk1 + sk2 + # s2 = dn * (n1 * var2 - n2 * var1) / 6.0 + # s3 = (dn ** 3) * n1 * n2 * (n1 ** 2 - n2 ** 2) + # skew_m = s1 + s2 + s3 + # else: + # skew_m = None + # if len(m1) == 4: + # return skew_m, var_m, mu, n + # + # k1, k2 = m1[-5], m2[-5] + # if skew_m is None: + # return None, skew_m, var_m, mu, n + # s1 = (1 / 24.0) * dn * (n1 * sk2 - n2 * sk1) + # s2 = (1 / 12.0) * (dn ** 2) * ((n1 ** 2) * var2 + (n2 ** 2) * var1) + # s3 = (dn ** 4) * n1 * n2 * ((n1 ** 3) + (n2 ** 3)) + # k = k1 + k2 + s1 + s2 + s3 + # if len(m1) == 5: + # return k, skew_m, var_m, mu, n def min(x, axis=None, out=None, keepdim=None): @@ -1121,6 +1083,76 @@ def minimum(x1, x2, out=None): return lresult +def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None, Fischer=None): + # TODO: type annotations: + # def __moment_w_axis(function: Callable, x, axis: Union[None, int, list, tuple], elementwise_function: Callable, + # unbiased: bool = None, Fischer: bool = None) -> DNDarray: + + # helper for calculating a statistical moment with a given axis + kwargs = {"dim": axis} + if unbiased: + kwargs["unbiased"] = unbiased + if Fischer: + kwargs["Fischer"] = Fischer + + output_shape = list(x.shape) + if isinstance(axis, int): + if axis >= len(x.shape): + raise ValueError("axis must be < {}, currently is {}".format(len(x.shape), axis)) + axis = stride_tricks.sanitize_axis(x.shape, axis) + # only one axis given + output_shape = [output_shape[it] for it in range(len(output_shape)) if it != axis] + output_shape = output_shape if output_shape else (1,) + + if x.split is None: # x is *not* distributed -> no need to distributed + return factories.array( + function(x._DNDarray__array, **kwargs), dtype=x.dtype, device=x.device + ) + elif axis == x.split: # x is distributed and axis chosen is == to split + return elementwise_function(output_shape) + # singular axis given (axis) not equal to split direction (x.split) + lcl = function(x._DNDarray__array, **kwargs) + return factories.array( + lcl, is_split=x.split if axis > x.split else x.split - 1, dtype=x.dtype, device=x.device + ) + elif not isinstance(axis, (list, tuple, torch.Tensor)): + raise TypeError( + f"axis must be an int, tuple, list, or torch.Tensor; currently it is {type(axis)}." + ) + # else: + if isinstance(axis, torch.Tensor): + axis = axis.tolist() + + if isinstance(axis, (list, tuple)) and len(set(axis)) != len(axis): # most common case + raise ValueError("duplicate value in axis") + if any(not isinstance(j, int) for j in axis): + raise ValueError( + f"items in axis iterable must be integers, axes: {[type(q) for q in axis]}" + ) + + if any(d < 0 for d in axis): + axis = [stride_tricks.sanitize_axis(x.shape, j) for j in axis] + if any(d > len(x.shape) for d in axis): + raise ValueError(f"axes (axis) must be < {len(x.shape)}, currently are {axis}") + + output_shape = [output_shape[it] for it in range(len(output_shape)) if it not in axis] + # multiple dimensions + if x.split is None: + return factories.array( + function(x._DNDarray__array, **kwargs), is_split=x.split, device=x.device + ) + if x.split in axis: + # merge in the direction of the split + return elementwise_function(output_shape) + # multiple dimensions which does *not* include the split axis + # combine along the split axis + return factories.array( + function(x._DNDarray__array, **kwargs), + is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1, + device=x.device, + ) + + def mpi_argmax(a, b, _): lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64)) rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64)) @@ -1159,59 +1191,26 @@ def mpi_argmin(a, b, _): def skew(x, axis=None, unbiased=True): """ Compute the sample skewness of a data set. + TODO: type annotations + def skew(x : DNDarray, axis : Union[None, int] = None, unbiased : bool = True) -> DNDarray: + Parameters ---------- x : ht.DNDarray Input array - axis : NoneType or Int or iterable + axis : NoneType or Int Axis along which skewness is calculated, Default is to compute over the whole array `x` unbiased : Bool if True (default) the calculations are corrected for bias + Warnings + -------- + UserWarning: Dependent on the axis given and the split configuration a UserWarning may be thrown during this + function as data is transferred between processes """ - - def __reduce_skews_elementwise(output_shape_i): - """ - Function to combine the calculated skews together. This does an element-wise update of the - calculated vars to merge them together using the merge_vars function. This function operates - using x from the var function parameters. - - Parameters - ---------- - output_shape_i : iterable - Iterable with the dimensions of the output of the var function. - """ - - if x.lshape[x.split] != 0: - mu = torch.mean(x._DNDarray__array, dim=axis) - var = torch.var(x._DNDarray__array, dim=axis, unbiased=unbiased) - skl = __torch_skew(x._DNDarray__array, dim=axis, unbiased=unbiased) - else: - mu = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - var = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - skl = factories.zeros(output_shape_i, dtype=x.dtype, device=x.device) - - sk_shape = list(skl.shape) if list(skl.shape) else [1] - - rtot = factories.zeros(([x.comm.size, 4] + sk_shape), dtype=x.dtype, device=x.device) - rtot[x.comm.rank, 0, :] = skl - rtot[x.comm.rank, 1, :] = var - rtot[x.comm.rank, 2, :] = mu - rtot[x.comm.rank, 3, :] = float(x.lshape[x.split]) - x.comm.Allreduce(MPI.IN_PLACE, rtot, MPI.SUM) - - for i in range(1, x.comm.size): - rtot[0] = __merge_moments( - (rtot[0, 0, :], rtot[0, 1, :], rtot[0, 2, :], rtot[0, 3, :]), - (rtot[i, 0, :], rtot[i, 1, :], rtot[i, 2, :], rtot[i, 3, :]), - unbiased=unbiased, - ) - return rtot[0, 0, :][0] if rtot[0, 0, :].size == 1 else rtot[0, 0, :] - - # ---------------------------------------------------------------------------------------------- if axis is None or (isinstance(axis, int) and x.split == axis): # no axis given - # todo: determine if this is a valid (and fast implementation) + # TODO: determine if this is a valid (and fast implementation) mu = mean(x, axis=axis) if axis is not None and axis > 0: mu = mu.expand_dims(axis) @@ -1225,8 +1224,11 @@ def __reduce_skews_elementwise(output_shape_i): if unbiased: res *= ((n * (n - 1.0)) ** 0.5) / (n - 2.0) return res.item() if res.numel == 1 else res + elif isinstance(axis, (list, tuple)): + raise TypeError(f"axis cannot be a list or a tuple, currently {type(axis)}") else: - return __moment_w_axis(__torch_skew, x, axis, __reduce_skews_elementwise, unbiased) + # if multiple axes are required, need to add a reduce_skews_elementwise function + return __moment_w_axis(__torch_skew, x, axis, None, unbiased) def std(x, axis=None, ddof=0, **kwargs): @@ -1277,74 +1279,9 @@ def std(x, axis=None, ddof=0, **kwargs): return exponential.sqrt(var(x, axis, ddof, **kwargs), out=None) -def __moment_w_axis(function, x, axis, elementwise_function, unbiased=None, Fischer=None): - # helper for calculating a statistical moment with a given axis - kwargs = {"dim": axis} - if unbiased: - kwargs["unbiased"] = unbiased - if Fischer: - kwargs["Fischer"] = Fischer - - output_shape = list(x.shape) - if isinstance(axis, (list, tuple, dndarray.DNDarray, torch.Tensor)): - if isinstance(axis, (list, tuple)) and len(set(axis)) != len(axis): - raise ValueError("duplicate value in axis") - if ( - isinstance(axis, (dndarray.DNDarray, torch.Tensor)) - and axis.unique().numel() != axis.numel() - ): - raise ValueError("duplicate value in axis") - if any([not isinstance(j, int) for j in axis]): - raise ValueError( - "items in axis iterable must be integers, axes: {}".format([type(q) for q in axis]) - ) - if any(d < 0 for d in axis): - axis = [stride_tricks.sanitize_axis(x.shape, j) for j in axis] - if any(d > len(x.shape) for d in axis): - raise ValueError( - "axes (axis) must be < {}, currently are {}".format(len(x.shape), axis) - ) - - output_shape = [output_shape[it] for it in range(len(output_shape)) if it not in axis] - # multiple dimensions - if x.split is None: - return factories.array( - function(x._DNDarray__array, **kwargs), is_split=x.split, device=x.device - ) - if x.split in axis: - # merge in the direction of the split - return elementwise_function(output_shape) - # multiple dimensions which does *not* include the split axis - # combine along the split axis - return factories.array( - function(x._DNDarray__array, **kwargs), - is_split=x.split if x.split < len(output_shape) else len(output_shape) - 1, - device=x.device, - ) - elif isinstance(axis, int): - if axis >= len(x.shape): - raise ValueError("axis must be < {}, currently is {}".format(len(x.shape), axis)) - axis = stride_tricks.sanitize_axis(x.shape, axis) - # only one axis given - output_shape = [output_shape[it] for it in range(len(output_shape)) if it != axis] - output_shape = output_shape if output_shape else (1,) - - if x.split is None: # x is *not* distributed -> no need to distributed - return factories.array( - function(x._DNDarray__array, **kwargs), dtype=x.dtype, device=x.device - ) - elif axis == x.split: # x is distributed and axis chosen is == to split - return elementwise_function(output_shape) - # singular axis given (axis) not equal to split direction (x.split) - lcl = function(x._DNDarray__array, **kwargs) - return factories.array( - lcl, is_split=x.split if axis > x.split else x.split - 1, dtype=x.dtype, device=x.device - ) - else: - raise TypeError("axis (axis) must be an int, tuple, list, etc.; currently it is {}. ") - - def __torch_skew(torch_tensor, dim=None, unbiased=False): + # TODO: type annotations: + # def __torch_skew(torch_tensor : torch.Tensor, dim : int = None, unbiased : bool = False) -> torch.Tensor: # calculate the sample skewness of a torch tensor # return the bias corrected Fischer-Pearson standardized moment coefficient by default if dim is not None: @@ -1364,6 +1301,8 @@ def __torch_skew(torch_tensor, dim=None, unbiased=False): def __torch_kurtosis(torch_tensor, dim=None, Fischer=True, unbiased=False): + # TODO: type annotations: + # def __torch_kurtosis(torch_tensor : torch.Tensor, dim : int = None, Fischer : bool = True, unbiased : bool = False) -> torch.Tensor: # calculate the sample kurtosis of a torch tensor, Pearson's definition # returns the excess Kurtosis if excess is True # there is not unbiased estimator for Kurtosis @@ -1450,11 +1389,11 @@ def var(x, axis=None, ddof=0, **kwargs): """ if not isinstance(ddof, int): - raise TypeError("ddof must be integer, is {}".format(type(ddof))) + raise TypeError(f"ddof must be integer, is {type(ddof)}") elif ddof > 1: - raise NotImplementedError("Not implemented for ddof > 1.") + raise NotImplementedError(f"Not implemented for ddof > 1.") elif ddof < 0: - raise ValueError("Expected ddof=0 or ddof=1, got {}".format(ddof)) + raise ValueError(f"Expected ddof=0 or ddof=1, got {ddof}") else: if kwargs.get("bessel"): unbiased = kwargs.get("bessel") diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index 8440f79a48..2fbe85e0eb 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -394,97 +394,90 @@ def test_kurtosis(self): x = ht.zeros((2, 3, 4)) with self.assertRaises(ValueError): x.kurtosis(axis=10) - with self.assertRaises(ValueError): - x.kurtosis(axis=[4]) - with self.assertRaises(ValueError): - x.kurtosis(axis=[-4]) with self.assertRaises(TypeError): ht.kurtosis(x, axis="01") - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): ht.kurtosis(x, axis=(0, "10")) - with self.assertRaises(ValueError): - ht.kurtosis(x, axis=(0, 0)) - with self.assertRaises(ValueError): - ht.kurtosis(x, axis=torch.Tensor([0, 0])) - - # 1 dim, 2 dims, 3 dims, 1 axis, 2 axes, 3 axes, - # split != axis, split == axis, split in axis, split not in axis - # comp with numpy for all, rand and randn def __split_calc(ht_split, axis): - if ht_split is None: - return None - else: - sp = ht_split if axis > ht_split else ht_split - 1 - if axis == ht_split: - sp = None - return sp + sp = ht_split if axis > ht_split else ht_split - 1 + if axis == ht_split: + sp = None + return sp # 1 dim ht_data = ht.random.rand(50) np_data = ht_data.copy().numpy() - np_skew32 = ht.array((ss.kurtosis(np_data, bias=False)), dtype=ht_data.dtype) - self.assertAlmostEqual(ht.kurtosis(ht_data), np_skew32.item(), places=5) + np_kurtosis32 = ht.array((ss.kurtosis(np_data, bias=False)), dtype=ht_data.dtype) + self.assertAlmostEqual(ht.kurtosis(ht_data), np_kurtosis32.item(), places=5) ht_data = ht.resplit(ht_data, 0) - self.assertAlmostEqual(ht.kurtosis(ht_data), np_skew32.item(), places=5) + self.assertAlmostEqual(ht.kurtosis(ht_data), np_kurtosis32.item(), places=5) # 2 dim ht_data = ht.random.rand(50, 30) np_data = ht_data.copy().numpy() - np_skew32 = ss.kurtosis(np_data, axis=None, bias=False) - self.assertAlmostEqual(ht.kurtosis(ht_data) - np_skew32, 0, places=5) + np_kurtosis32 = ss.kurtosis(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.kurtosis(ht_data) - np_kurtosis32, 0, places=5) ht_data = ht.resplit(ht_data, 0) for ax in range(2): - np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=True)), dtype=ht_data.dtype) - ht_skew = ht.kurtosis(ht_data, axis=ax, unbiased=False) - self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + np_kurtosis32 = ht.array( + (ss.kurtosis(np_data, axis=ax, bias=True)), dtype=ht_data.dtype + ) + ht_kurtosis = ht.kurtosis(ht_data, axis=ax, unbiased=False) + self.assertTrue(ht.allclose(ht_kurtosis, np_kurtosis32, atol=1e-5)) sp = __split_calc(ht_data.split, ax) - self.assertEqual(ht_skew.split, sp) + self.assertEqual(ht_kurtosis.split, sp) ht_data = ht.resplit(ht_data, 1) for ax in range(2): - np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=True)), dtype=ht_data.dtype) - ht_skew = ht.kurtosis(ht_data, axis=ax, unbiased=False) - self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + np_kurtosis32 = ht.array( + (ss.kurtosis(np_data, axis=ax, bias=True)), dtype=ht_data.dtype + ) + ht_kurtosis = ht.kurtosis(ht_data, axis=ax, unbiased=False) + self.assertTrue(ht.allclose(ht_kurtosis, np_kurtosis32, atol=1e-5)) sp = __split_calc(ht_data.split, ax) - self.assertEqual(ht_skew.split, sp) + self.assertEqual(ht_kurtosis.split, sp) # 2 dim float64 ht_data = ht.random.rand(50, 30, dtype=ht.float64) np_data = ht_data.copy().numpy() - np_skew32 = ss.kurtosis(np_data, axis=None, bias=False) - self.assertAlmostEqual(ht.kurtosis(ht_data) - np_skew32, 0, places=5) + np_kurtosis32 = ss.kurtosis(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.kurtosis(ht_data) - np_kurtosis32, 0, places=5) ht_data = ht.resplit(ht_data, 0) for ax in range(2): - np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) - ht_skew = ht.kurtosis(ht_data, axis=ax) - self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + np_kurtosis32 = ht.array( + (ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype + ) + ht_kurtosis = ht.kurtosis(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_kurtosis, np_kurtosis32, atol=1e-5)) sp = __split_calc(ht_data.split, ax) - self.assertEqual(ht_skew.split, sp) - self.assertEqual(ht_skew.dtype, ht.float64) + self.assertEqual(ht_kurtosis.split, sp) + self.assertEqual(ht_kurtosis.dtype, ht.float64) ht_data = ht.resplit(ht_data, 1) for ax in range(2): - np_skew32 = ht.array((ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype) - ht_skew = ht.kurtosis(ht_data, axis=ax) - self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + np_kurtosis32 = ht.array( + (ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype + ) + ht_kurtosis = ht.kurtosis(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_kurtosis, np_kurtosis32, atol=1e-5)) sp = __split_calc(ht_data.split, ax) - self.assertEqual(ht_skew.split, sp) - self.assertEqual(ht_skew.dtype, ht.float64) + self.assertEqual(ht_kurtosis.split, sp) + self.assertEqual(ht_kurtosis.dtype, ht.float64) # 3 dim ht_data = ht.random.rand(50, 30, 16) np_data = ht_data.copy().numpy() - np_skew32 = ss.kurtosis(np_data, axis=None, bias=False) - self.assertAlmostEqual(ht.kurtosis(ht_data) - np_skew32, 0, places=5) + np_kurtosis32 = ss.kurtosis(np_data, axis=None, bias=False) + self.assertAlmostEqual(ht.kurtosis(ht_data) - np_kurtosis32, 0, places=5) for split in range(3): ht_data = ht.resplit(ht_data, split) for ax in range(3): - np_skew32 = ht.array( + np_kurtosis32 = ht.array( (ss.kurtosis(np_data, axis=ax, bias=False)), dtype=ht_data.dtype ) - ht_skew = ht.kurtosis(ht_data, axis=ax) - self.assertTrue(ht.allclose(ht_skew, np_skew32, atol=1e-5)) + ht_kurtosis = ht.kurtosis(ht_data, axis=ax) + self.assertTrue(ht.allclose(ht_kurtosis, np_kurtosis32, atol=1e-5)) sp = __split_calc(ht_data.split, ax) - self.assertEqual(ht_skew.split, sp) + self.assertEqual(ht_kurtosis.split, sp) def test_max(self): data = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] @@ -937,34 +930,17 @@ def test_skew(self): x = ht.zeros((2, 3, 4)) with self.assertRaises(ValueError): x.skew(axis=10) - with self.assertRaises(ValueError): - x.skew(axis=[4]) - with self.assertRaises(ValueError): - x.skew(axis=[-4]) with self.assertRaises(TypeError): - ht.skew(x, axis="01") - with self.assertRaises(ValueError): - ht.skew(x, axis=(0, "10")) - with self.assertRaises(ValueError): - ht.skew(x, axis=(0, 0)) - with self.assertRaises(ValueError): - ht.skew(x, axis=torch.Tensor([0, 0])) + x.skew(axis=[1, 0]) a = ht.arange(1, 5) self.assertEqual(a.skew(), 0.0) - # todo: 1 dim, 2 dims, 3 dims, 1 axis, 2 axes, 3 axes, - # split != axis, split == axis, split in axis, split not in axis - # comp with numpy for all, rand and randn - def __split_calc(ht_split, axis): - if ht_split is None: - return None - else: - sp = ht_split if axis > ht_split else ht_split - 1 - if axis == ht_split: - sp = None - return sp + sp = ht_split if axis > ht_split else ht_split - 1 + if axis == ht_split: + sp = None + return sp # 1 dim ht_data = ht.random.rand(50) diff --git a/setup.py b/setup.py index fc5c02344e..a6d3dfcfbd 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering", ], - install_requires=["mpi4py>=3.0.0", "numpy>=1.13.0", "torch>=1.5.0"], + install_requires=["mpi4py>=3.0.0", "numpy>=1.13.0", "torch>=1.5.0", "scipy>=0.14.0"], extras_require={ "hdf5": ["h5py>=2.8.0"], "netcdf": ["netCDF4>=1.4.0,<=1.5.2"], From c0fda35065ee57dca342d547704b020cf3c0b3d8 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 1 Jul 2020 12:55:14 +0200 Subject: [PATCH 12/14] added split=None test for cov --- heat/core/tests/test_statistics.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index 2fbe85e0eb..55ab721fbe 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -342,6 +342,28 @@ def test_cov(self): data = np.loadtxt("heat/datasets/data/iris.csv", delimiter=";") np_cov = np.cov(data[:, 0], data[:, 1:3], rowvar=False) + # split = None tests + htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=None) + ht_cov = ht.cov(htdata[:, 0], htdata[:, 1:3], rowvar=False) + comp = ht.array(np_cov, dtype=ht.float) + self.assertTrue(ht.allclose(comp - ht_cov, 0, atol=1e-4)) + + np_cov = np.cov(data, rowvar=False) + ht_cov = ht.cov(htdata, rowvar=False) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) + + np_cov = np.cov(data, rowvar=False, ddof=1) + ht_cov = ht.cov(htdata, rowvar=False, ddof=1) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) + + np_cov = np.cov(data, rowvar=False, bias=True) + ht_cov = ht.cov(htdata, rowvar=False, bias=True) + self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) + + # split = 0 tests + data = np.loadtxt("heat/datasets/data/iris.csv", delimiter=";") + np_cov = np.cov(data[:, 0], data[:, 1:3], rowvar=False) + htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=0) ht_cov = ht.cov(htdata[:, 0], htdata[:, 1:3], rowvar=False) comp = ht.array(np_cov, dtype=ht.float) @@ -360,6 +382,7 @@ def test_cov(self): self.assertTrue(ht.allclose(ht.array(np_cov, dtype=ht.float) - ht_cov, 0, atol=1e-4)) if 1 < x.comm.size < 5: + # split 1 tests htdata = ht.load("heat/datasets/data/iris.csv", sep=";", split=1) np_cov = np.cov(data, rowvar=False) ht_cov = ht.cov(htdata, rowvar=False) From 9a67be030536c9e1072d196f009b5027b9b22dc3 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Wed, 1 Jul 2020 13:54:12 +0200 Subject: [PATCH 13/14] removed f-string notation for standard string --- heat/core/statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 64f28a7886..85b43f28c1 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -1391,7 +1391,7 @@ def var(x, axis=None, ddof=0, **kwargs): if not isinstance(ddof, int): raise TypeError(f"ddof must be integer, is {type(ddof)}") elif ddof > 1: - raise NotImplementedError(f"Not implemented for ddof > 1.") + raise NotImplementedError("Not implemented for ddof > 1.") elif ddof < 0: raise ValueError(f"Expected ddof=0 or ddof=1, got {ddof}") else: From 1e9c7faf388cac0954d7b3aa63a92739e3d3c930 Mon Sep 17 00:00:00 2001 From: coquelin77 Date: Thu, 2 Jul 2020 09:12:33 +0200 Subject: [PATCH 14/14] changed numel back to gnumel --- heat/core/dndarray.py | 2 +- heat/core/statistics.py | 16 ++++++++-------- heat/core/tests/test_dndarray.py | 10 +++++----- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index 351ad958f2..57d88ffefa 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -150,7 +150,7 @@ def size(self): return torch.prod(torch.tensor(self.gshape, device=self.device.torch_device)).item() @property - def numel(self): + def gnumel(self): """ Returns diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 85b43f28c1..09f55e3fe7 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -314,7 +314,7 @@ def average(x, axis=None, weights=None, returned=False): if weights is None: result = mean(x, axis) - num_elements = x.numel / result.numel + num_elements = x.gnumel / result.gnumel cumwgt = factories.empty(1, dtype=result.dtype) cumwgt._DNDarray__array = num_elements else: @@ -435,7 +435,7 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None): norm = x.shape[1] - ddof # find normalization: if norm <= 0: - raise ValueError(f"ddof >= number of elements in m, {ddof} {m.numel}") + raise ValueError(f"ddof >= number of elements in m, {ddof} {m.gnumel}") x -= avg.expand_dims(1) c = linalg.dot(x, x.T) c /= norm @@ -476,7 +476,7 @@ def kurtosis(x : DNDarray, axis : Union[None, int] = None, unbiased : bool = Tru if axis is not None and axis > 0: mu = mu.expand_dims(axis) diff = x - mu - n = float(x.shape[axis]) if axis is not None else x.numel + n = float(x.shape[axis]) if axis is not None else x.gnumel m4 = arithmetics.sum(arithmetics.pow(diff, 4.0), axis) / n m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n @@ -485,7 +485,7 @@ def kurtosis(x : DNDarray, axis : Union[None, int] = None, unbiased : bool = Tru res = ((n - 1.0) / ((n - 2.0) * (n - 3.0))) * ((n + 1.0) * res - 3 * (n - 1.0)) + 3.0 if Fischer: res -= 3.0 - return res.item() if res.numel == 1 else res + return res.item() if res.gnumel == 1 else res elif isinstance(axis, (list, tuple)): raise TypeError("axis cannot be a list or a tuple, currently {}".format(type(axis))) else: @@ -1216,14 +1216,14 @@ def skew(x : DNDarray, axis : Union[None, int] = None, unbiased : bool = True) - mu = mu.expand_dims(axis) diff = x - mu - n = float(x.shape[axis]) if axis is not None else x.numel + n = float(x.shape[axis]) if axis is not None else x.gnumel m3 = arithmetics.sum(arithmetics.pow(diff, 3.0), axis) / n m2 = arithmetics.sum(arithmetics.pow(diff, 2.0), axis) / n res = m3 / arithmetics.pow(m2, 1.5) if unbiased: res *= ((n * (n - 1.0)) ** 0.5) / (n - 2.0) - return res.item() if res.numel == 1 else res + return res.item() if res.gnumel == 1 else res elif isinstance(axis, (list, tuple)): raise TypeError(f"axis cannot be a list or a tuple, currently {type(axis)}") else: @@ -1290,7 +1290,7 @@ def __torch_skew(torch_tensor, dim=None, unbiased=False): m3 = torch.true_divide(torch.sum(torch.pow(diff, 3), dim=dim), n) m2 = torch.true_divide(torch.sum(torch.pow(diff, 2), dim=dim), n) else: - n = torch_tensor.numel() + n = torch_tensor.gnumel() diff = torch_tensor - torch.mean(torch_tensor) m3 = torch.true_divide(torch.sum(torch.pow(diff, 3)), n) m2 = torch.true_divide(torch.sum(torch.pow(diff, 2)), n) @@ -1312,7 +1312,7 @@ def __torch_kurtosis(torch_tensor, dim=None, Fischer=True, unbiased=False): m4 = torch.true_divide(torch.sum(torch.pow(diff, 4.0), dim=dim), n) m2 = torch.true_divide(torch.sum(torch.pow(diff, 2.0), dim=dim), n) else: - n = torch_tensor.numel() + n = torch_tensor.gnumel() diff = torch_tensor - torch.mean(torch_tensor) m4 = torch.true_divide(torch.pow(diff, 4.0), n) m2 = torch.true_divide(torch.pow(diff, 2.0), n) diff --git a/heat/core/tests/test_dndarray.py b/heat/core/tests/test_dndarray.py index c6174a57c8..12d803ca12 100644 --- a/heat/core/tests/test_dndarray.py +++ b/heat/core/tests/test_dndarray.py @@ -1009,22 +1009,22 @@ def test_setitem_getitem(self): a[0] = ht.array([6, 6, 6, 6, 6]) self.assertTrue((a[ht.array((0,))] == 6).all()) - def test_size_numel(self): + def test_size_gnumel(self): a = ht.zeros((10, 10, 10), split=None) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.numel, 10 * 10 * 10) + self.assertEqual(a.gnumel, 10 * 10 * 10) a = ht.zeros((10, 10, 10), split=0) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.numel, 10 * 10 * 10) + self.assertEqual(a.gnumel, 10 * 10 * 10) a = ht.zeros((10, 10, 10), split=1) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.numel, 10 * 10 * 10) + self.assertEqual(a.gnumel, 10 * 10 * 10) a = ht.zeros((10, 10, 10), split=2) self.assertEqual(a.size, 10 * 10 * 10) - self.assertEqual(a.numel, 10 * 10 * 10) + self.assertEqual(a.gnumel, 10 * 10 * 10) self.assertEqual(ht.array(0).size, 1)