In [9]:
import numpy as np
from statsmodels.stats.multitest import multipletests

In [2]:
def permutation_test(data1, data2, random_state=None, nsteps=100000, keep_vals=False):
    """
    2-tailed Permutation test: test if the difference between two groups is significant.
    """
    len1, len2 = len(data1), len(data2)
    k = np.zeros(nsteps)
    np.random.seed(random_state)

    diff = np.abs(np.mean(data1) - np.mean(data2))  # observed difference
    z = np.concatenate([data1, data2])
    for i in range(nsteps):
        np.random.shuffle(z)
        k[i] = np.abs(np.mean(z[:len1]) - np.mean(z[len1:]))

    # get index of sample mean difference that larger than or equal to observed difference
    p_value = len(np.where(k >= diff)[0]) / nsteps

    if keep_vals:
        return {"diff": diff, "k": k, "p_value": p_value}
    else:
        return {"diff": diff, "p_value": p_value}


In [3]:
msm4 = np.load('SASA_agg_states/sasa_agg_state_4.npy') 
native = np.load('SASA_agg_states/sasa_agg_state_5.npy')

In [4]:
permutation_test(msm4, native)

{'diff': 12.685387363535142, 'p_value': 0.12369}

In [6]:
p_values = []
native = np.load('SASA_agg_states/sasa_agg_state_5.npy')
for i in range(5):
    msm = np.load(f'SASA_agg_states/sasa_agg_state_{i}.npy')
    res = permutation_test(msm, native)
    p_values.append(res['p_value'])
    print(f"State {i}: p-value = {res['p_value']}, diff: {res['diff']}")

State 0: p-value = 0.00015, diff: 15.608220537713805
State 1: p-value = 0.0, diff: 45.88256036206451
State 2: p-value = 0.0, diff: 9.608756217977088
State 3: p-value = 0.0, diff: 119.123324180733
State 4: p-value = 0.1229, diff: 12.685387363535142


In [7]:
p_values

[0.00015, 0.0, 0.0, 0.0, 0.1229]

In [10]:
adjusted_pvalues = multipletests(p_values, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)


In [14]:
for pval, adj_pval in zip(p_values, adjusted_pvalues[1]):
    print(f"{pval:.5f} {adj_pval:.5f}")

0.00015 0.00019
0.00000 0.00000
0.00000 0.00000
0.00000 0.00000
0.12290 0.12290


In [12]:
adjusted_pvalues

(array([ True,  True,  True,  True, False]),
 array([0.0001875, 0.       , 0.       , 0.       , 0.1229   ]),
 0.010206218313011495,
 0.01)