Skip to content

Commit

Permalink
added median downsampling method
Browse files Browse the repository at this point in the history
  • Loading branch information
jcheong0428 committed Mar 2, 2017
1 parent a3a1614 commit 6541044
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
12 changes: 9 additions & 3 deletions nltools/stats.py
Expand Up @@ -193,22 +193,25 @@ def calc_bpm(beat_interval, sampling_freq):
'''
return 60*sampling_freq*(1/(beat_interval))

def downsample(data,sampling_freq=None, target=None, target_type='samples'):
def downsample(data,sampling_freq=None, target=None, target_type='samples',method='mean'):
''' Downsample pandas to a new target frequency or number of samples by taking the mean or median of each group of samples.
Args:
data: Pandas DataFrame or Series
sampling_freq: Sampling frequency of data
target: downsampling target
target_type: type of target can be [samples,seconds,hz]
method: (str) type of downsample method ['mean','median'], default: mean
Returns:
downsampled pandas object
'''

if not isinstance(data,(pd.DataFrame,pd.Series)):
raise ValueError('Data must by a pandas DataFrame or Series instance.')

if not (method=='median') | (method=='mean'):
raise ValueError("Metric must be either 'mean' or 'median' ")

if target_type is 'samples':
n_samples = target
elif target_type is 'seconds':
Expand All @@ -222,7 +225,10 @@ def downsample(data,sampling_freq=None, target=None, target_type='samples'):
# if data.shape[0] % n_samples:
if data.shape[0] > len(idx):
idx = np.concatenate([idx, np.repeat(idx[-1]+1,data.shape[0]-len(idx))])
return data.groupby(idx).mean()
if method=='mean':
return data.groupby(idx).mean()
elif method=='median':
return data.groupby(idx).median()

def fisher_r_to_z(r):
''' Use Fisher transformation to convert correlation to z score '''
Expand Down
11 changes: 9 additions & 2 deletions nltools/tests/test_stats.py
@@ -1,6 +1,6 @@
import numpy as np
import pandas as pd
from nltools.stats import one_sample_permutation,two_sample_permutation,correlation_permutation
from nltools.stats import one_sample_permutation,two_sample_permutation,correlation_permutation, downsample

def test_permutation():
dat = np.random.multivariate_normal([2,6],[[.5,2],[.5,3]],100)
Expand All @@ -19,4 +19,11 @@ def test_permutation():
assert stats['p']< .001
stats = correlation_permutation(x,y,metric='kendall')
assert (stats['correlation']>.4) & (stats['correlation']<.85)
assert stats['p']< .001
assert stats['p']< .001

def test_downsample():
    ''' Check downsample against a plain pandas groupby for both methods.

    Builds a 100-sample column 'x' and a label column 'y' that marks
    consecutive runs of 10 samples, then compares the per-run mean/median
    of 'x' with the output of downsample when going from 10 hz to 1 hz.
    '''
    dat = pd.DataFrame()
    dat['x'] = range(0,100)
    dat['y'] = np.repeat(range(1,11),10)

    expected_mean = dat.groupby('y').mean().values.ravel()
    observed_mean = downsample(data=dat['x'],sampling_freq=10,target=1,target_type='hz',method='mean').values
    # .all must be *called*; asserting on the bound method itself is always
    # truthy and would let any mismatch pass silently.
    assert (expected_mean == observed_mean).all()

    expected_median = dat.groupby('y').median().values.ravel()
    observed_median = downsample(data=dat['x'],sampling_freq=10,target=1,target_type='hz',method='median').values
    assert (expected_median == observed_median).all()

0 comments on commit 6541044

Please sign in to comment.