Skip to content

Commit

Permalink
Simplified nan_weighted_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
morganjwilliams committed Jul 11, 2018
1 parent 2c35aca commit 65a51f2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 35 deletions.
36 changes: 13 additions & 23 deletions pyrolite/compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,25 @@ def weights_from_array(arr:np.ndarray):
"""
wts = np.ones((arr.shape[0]))
wts = wts/np.sum(wts)
wts = wts.T
wts = wts
return wts


def nan_weighted_mean(arr:np.ndarray, weights=None,):
if weights is None:
weights = weights or weights_from_array(arr)

#if arr.ndim == 1: arr = arr.reshape((1, *arr.shape))
#if weights.ndim == 1: weights = weights.reshape((*weights.shape, 1))

if np.isnan(arr).any():
mean = np.nanmean(arr, axis=0)
if not (weights == weights[0]).all(): # if weights needed
cs = np.arange(arr.shape[1])
nonnan_idx = np.nonzero(~np.isnan(arr[:, cs]))
if len(nonnan_idx[0]): # if there are any non-nan elements
c_weights = weights.copy()
c_weights = c_weights[nonnan_idx] / \
c_weights[nonnan_idx].sum()
mean[c] = arr[:, cs][nonnan_idx] @ c_weights
weights = weights_from_array(arr)
weights = np.array(weights)/np.nansum(weights)

mask = (np.isnan(arr) + np.isinf(arr)) > 0
if not mask.any():
return np.average(arr,
weights=weights,
axis=0)
else:
mean = arr.T @ weights
mean = mean.T.squeeze()
mean = mean.reshape(arr.shape[1:]) # this should be compatible
return mean
return np.ma.average(np.ma.array(arr, mask=mask),
weights=weights,
axis=0)



def compositional_mean(df, weights=[], **kwargs):
Expand Down Expand Up @@ -101,9 +94,6 @@ def nan_weighted_compositional_mean(arr: np.ndarray,
else:
weights = np.array(weights)/np.sum(weights, axis=-1)

if weights.ndim == 1:
weights = weights.reshape((*weights.shape, 1))

if ind is None: # take the first column which has no nans
ind = get_nonnan_column(arr)

Expand Down
64 changes: 52 additions & 12 deletions test/test_compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
import logging
log = logging.getLogger(__name__)


def test_df(cols=['SiO2', 'CaO', 'MgO', 'FeO', 'TiO2'],
index_length=10):
return pd.DataFrame({k: v for k,v in zip(cols,
np.random.rand(len(cols), index_length))})


class TestClose(unittest.TestCase):
"""Tests array closure operator."""

Expand Down Expand Up @@ -36,8 +43,7 @@ class TestCompositionalMean(unittest.TestCase):

def setUp(self):
self.cols = ['SiO2', 'CaO', 'MgO', 'FeO', 'TiO2']
self.df = pd.DataFrame({k: v for k,v in zip(self.cols,
np.random.rand(len(self.cols), 10))})
self.df = test_df(cols=self.cols)

def test_1D(self):
"""Checks results on single records."""
Expand Down Expand Up @@ -80,8 +86,7 @@ class TestWeightsFromArray(unittest.TestCase):

def setUp(self):
self.cols = ['SiO2', 'CaO', 'MgO', 'FeO', 'TiO2']
self.df = pd.DataFrame({k: v for k,v in zip(self.cols,
np.random.rand(len(self.cols), 10))})
self.df = test_df(cols=self.cols)

def test_single(self):
"""Checks results on single records."""
Expand All @@ -101,8 +106,7 @@ class TestGetNonNanColumn(unittest.TestCase):

def setUp(self):
self.cols = ['SiO2', 'CaO', 'MgO', 'FeO', 'TiO2']
self.df = pd.DataFrame({k: v for k,v in zip(self.cols,
np.random.rand(len(self.cols), 10))})
self.df = test_df(cols=self.cols)
nans = 10
self.df.iloc[np.random.randint(1, 10, size=nans),
np.random.randint(1, len(self.cols), size=nans)] = np.nan
Expand All @@ -126,8 +130,7 @@ class TestNANWeightedMean(unittest.TestCase):

def setUp(self):
self.cols = ['SiO2', 'CaO', 'MgO', 'FeO', 'TiO2']
self.df = pd.DataFrame({k: v for k,v in zip(self.cols,
np.random.rand(len(self.cols), 10))})
self.df = test_df(cols=self.cols)

def test_single(self):
"""Checks results on single records."""
Expand All @@ -141,15 +144,50 @@ def test_multiple(self):
out = nan_weighted_mean(df.values)
self.assertTrue(np.allclose(out, np.mean(df.values, axis=0)))

def test_multiple_equal_weights(self):
"""Checks results on multiple records with equal weights."""
df = self.df
weights = np.array([1./ len(df.index)] * len(df.index))
out = nan_weighted_mean(df.values, weights=weights)
self.assertTrue(np.allclose(out, np.average(df.values,
weights=weights,
axis=0))
)

def test_multiple_unequal_weights(self):
"""Checks results on multiple records with unequal weights."""
df = self.df
weights = np.random.rand(1, df.index.size).squeeze()
out = nan_weighted_mean(df.values, weights=weights)
check = np.average(df.values.T, weights=weights, axis=1)
self.assertTrue(np.allclose(out, np.average(df.values,
weights=weights,
axis=0))
)

def test_multiple_unequal_weights_withnan(self):
"""
Checks results on multiple records with unequal weights,
where the data includes some null data.
"""
df = self.df
df.iloc[0, :] = np.nan # make one record nan
# Some non-negative weights

weights = np.random.rand(1, df.index.size).squeeze()
weights = np.array(weights)/np.nansum(weights)
out = nan_weighted_mean(df.values, weights=weights)
check = np.average(df.iloc[1:, :].values, weights=weights[1:], axis=0)
self.assertTrue(np.allclose(out, check))



class TestNANWeightedCompositionalMean(unittest.TestCase):
"""Tests numpy weighted compositonal NaN-mean operator."""

def setUp(self):
self.cols = ['SiO2', 'CaO', 'MgO', 'FeO', 'TiO2']
self.df = pd.DataFrame({k: v for k,v in zip(self.cols,
np.random.rand(len(self.cols), 10))})
self.df = test_df(cols=self.cols)
self.df = self.df.apply(lambda x: x/np.sum(x), axis='columns')

def test_single(self):
Expand Down Expand Up @@ -196,8 +234,7 @@ def setUp(self):
self.cols = ['SiO2', 'CaO', 'MgO', 'FeO', 'TiO2']
self.d = len(self.cols)
self.n = 10
self.df = pd.DataFrame({k: v for k,v in zip(self.cols,
np.random.rand(len(self.cols), self.n))})
self.df = test_df(cols=self.cols, index_length=self.n)

def test_single(self):
"""Checks results on single record."""
Expand Down Expand Up @@ -342,6 +379,9 @@ def test_fixed_record(self):
class TestComplexStandardiseAggregate(unittest.TestCase):
"""Tests pandas complex internal standardisation aggregation method."""

def setUp(self):
pass

def test_single(self):
"""Checks results on single records."""
pass
Expand Down

0 comments on commit 65a51f2

Please sign in to comment.