In [None]:
import numpy as np
import scipy.stats as spy
import pandas as pd
import copy
from ..helpers import load_data
from ..helpers import local_directories as ldir

In [None]:
def getAll(measure_array, id_name):
    """Identity selector: keep every value of ``measure_array``.

    ``id_name`` is ignored; it is accepted so that all selector functions
    share the ``(measure_array, id_name, **kwargs)`` call signature used by
    ``getMannWhit``. The very same object is returned (no copy).
    """
    selected = measure_array
    return selected

def getSelection(measure_array, id_name, selection):
    """Keep only the entries of ``measure_array`` chosen for sample ``id_name``.

    ``selection[id_name]`` is used directly as a numpy index, so it may be an
    integer index array or a boolean mask.
    """
    return measure_array[selection[id_name]]

def getPerc(measure_array, id_name, perc=2.5):
    """Trim both tails of ``measure_array`` at the ``perc``-th percentiles.

    Values are kept only if they lie *strictly* between the ``perc`` and
    ``100 - perc`` percentiles, so values equal to either cut-off are dropped
    as well. ``id_name`` is ignored (shared selector signature).
    """
    lower, upper = np.percentile(measure_array, [perc, 100 - perc])
    inside = (measure_array > lower) & (measure_array < upper)
    return measure_array[inside]

def getZ(measure_array, id_name, SD=3):
    """Remove values lying ``SD`` or more standard deviations from the mean.

    Parameters
    ----------
    measure_array : numpy.ndarray
        Measurement values (assumed 1-D — TODO confirm against callers).
    id_name : hashable
        Sample identifier; unused, accepted so all selector functions share
        the ``(measure_array, id_name, **kwargs)`` call signature.
    SD : float, default 3
        Z-score cut-off; values with ``|z| < SD`` are kept.

    Returns
    -------
    numpy.ndarray
        A new array with the retained values. Empty input and constant input
        (zero standard deviation) are returned in full, as a copy.
    """
    if len(measure_array) == 0:
        # Nothing to filter; avoids mean/std-of-empty RuntimeWarnings.
        return measure_array.copy()
    mean = np.mean(measure_array)
    std = np.std(measure_array)  # population std (ddof=0)
    if std == 0:
        # Every value equals the mean, so none is an outlier. Dividing by a
        # zero std would yield NaN z-scores and silently drop the whole sample.
        return measure_array.copy()
    Z = np.abs((measure_array - mean) / std)
    return measure_array[Z < SD]

def getSample(measure_array, n=10000, rng=None):
    """Bootstrap-resample ``measure_array``: draw ``n`` values with replacement.

    Parameters
    ----------
    measure_array : array-like
        Values to draw from (must be non-empty when ``n > 0``).
    n : int, default 10000
        Number of values to draw.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When None (the default) numpy's legacy global
        RNG is used, so results are then reproducible only via
        ``np.random.seed``. Passing an explicit generator makes the draw
        reproducible without touching global state.

    Returns
    -------
    numpy.ndarray
        Array of length ``n``.
    """
    if rng is None:
        return np.random.choice(measure_array, size=n, replace=True)
    return rng.choice(measure_array, size=n, replace=True)

def getMannWhit(input_data, ids, measure_name, get_func=getAll, sample_size=None, **kwargs):
    """Run pairwise Mann-Whitney U tests on one measure across samples.

    For every sample id in ``ids`` the array ``input_data[id][measure_name]``
    is passed through ``get_func(array, id, **kwargs)`` (e.g. ``getAll``,
    ``getPerc``, ``getZ``, ``getSelection``). If ``sample_size`` is truthy,
    the result is then bootstrap-resampled to that many values with
    ``getSample``.

    Returns
    -------
    dict
        ``'statistic'`` and ``'p-values'``: ``len(ids) x len(ids)`` matrices
        from ``calcMannWhitArray``, rows/columns in the order of ``ids``;
        ``'sizes'``: number of values used for each sample.
    """
    n_ids = len(ids)
    samples = []
    counts = np.zeros(n_ids, dtype=int)
    for idx, sample_id in enumerate(ids):
        values = get_func(input_data[sample_id][measure_name], sample_id, **kwargs)
        if sample_size:
            values = getSample(values, sample_size)
        # Deep copy so the stored sample never aliases input_data
        # (getAll returns the caller's array itself).
        samples.append(copy.deepcopy(values))
        counts[idx] = len(values)
    stat_matrix, pval_matrix = calcMannWhitArray(samples, n_ids)
    return {'statistic': stat_matrix, 'p-values': pval_matrix, 'sizes': counts}

def calcMannWhitArray(data_arrays, size=None):
    """Compute pairwise two-sided Mann-Whitney U tests between samples.

    Parameters
    ----------
    data_arrays : sequence of array-like
        One 1-D sample per entry.
    size : int, optional
        Number of leading samples to compare; defaults to ``len(data_arrays)``.

    Returns
    -------
    (stat_arr, pval_arr) : tuple of numpy.ndarray, each ``(size, size)``
        ``stat_arr[i, j]`` for ``i < j`` is the U statistic returned by
        ``scipy.stats.mannwhitneyu(data_arrays[i], data_arrays[j])``; the
        lower triangle mirrors the upper triangle (it is NOT recomputed with
        the samples swapped). The diagonal holds each sample tested against
        itself, so every entry comes from the same test.
    """
    if size is None:
        size = len(data_arrays)
    stat_arr = np.zeros((size, size))
    pval_arr = np.zeros((size, size))
    for i in range(size):
        d1 = data_arrays[i]
        # Use the same test on the diagonal as off it; a different test here
        # (e.g. KS) would put a different kind of statistic into the matrix.
        stat_arr[i, i], pval_arr[i, i] = spy.mannwhitneyu(d1, d1)
        for j in range(i + 1, size):
            stat, pval = spy.mannwhitneyu(d1, data_arrays[j])
            stat_arr[i, j] = stat_arr[j, i] = stat
            pval_arr[i, j] = pval_arr[j, i] = pval
    return stat_arr, pval_arr


def getKeys(data):
    """List the names of the 1-D, float-valued measures of the first sample.

    Only the first entry of ``data`` is inspected (presumably every sample
    stores the same measures — TODO confirm). A measure qualifies when its
    array is 1-D, non-empty, and its first element is a Python ``float`` or
    any numpy floating type (float16/32/64, ...).

    Parameters
    ----------
    data : mapping
        ``{sample_id: {measure_name: array}}``.

    Returns
    -------
    list
        Qualifying measure names in iteration order; ``[]`` if ``data`` is
        empty.
    """
    for coral in data:
        data_arr = data[coral]
        keys = []
        for key in data_arr:
            numbers = data_arr[key]
            # Guard the length before indexing numbers[0] (empty arrays would
            # raise IndexError); np.floating also admits float32 values, which
            # are not subclasses of the builtin float.
            if (len(numbers.shape) == 1 and numbers.shape[0] > 0
                    and isinstance(numbers[0], (float, np.floating))):
                keys.append(key)
        # Deliberately stop after the first sample.
        return keys
    return []

    

In [None]:
# Sample IDs ordered by morphospecies, so the rows/columns of the pairwise
# result matrices place samples of the same species next to each other.
# The species table itself is not kept in the namespace.
ids = list(
    pd.read_csv(f'{ldir.DIR_DATA}/species_info_v1.csv')
    .sort_values(by='Morphospecies')['Sample ID']
)

In [None]:
def prepareSelection(fn_selection, selection_name, ids):
    """Load a selection file and pull out one named selection per sample.

    Parameters
    ----------
    fn_selection : str
        File passed to ``load_data.readPickle`` (presumably a pickle file).
    selection_name : hashable
        Key of the selection to extract for every sample.
    ids : iterable
        Sample ids; each must be a key of the loaded object.

    Returns
    -------
    dict
        ``{sample_id: loaded[sample_id][selection_name]}``, usable as the
        ``selection`` argument of ``getSelection``.
    """
    # NOTE(review): if readPickle unpickles, loading runs arbitrary code —
    # only use selection files from a trusted source.
    all_selections = load_data.readPickle(fn_selection)
    return {coral_id: all_selections[coral_id][selection_name] for coral_id in ids}
        
    

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

# Example: pairwise Mann-Whitney U tests on one curvature measure (next cell)

In [None]:
# getSample draws from numpy's global RNG; seed it so the bootstrap (and
# therefore the printed statistics) is reproducible on Restart & Run All.
np.random.seed(0)
ks_measures = {}  # NOTE(review): unused in this cell
n_sample = 1000
data = load_data.readPickle('curvatures.pickle')

measure = 'Gauss'
# Pass `measure`; the previously referenced name `key` is not defined anywhere
# in the notebook and raised NameError on a fresh kernel.
# NOTE: despite the name, ks_test holds Mann-Whitney U results.
ks_test = getMannWhit(data, ids, measure, sample_size=n_sample)
print(ks_test['p-values'])
print(ks_test['statistic'])
