In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity
from scipy import stats
import pandas as pd

In [None]:
# Shared plot-styling constants for this notebook.

# Font sizes: AX_FONT is applied to axis labels and the legend,
# TICK_FONT to the major tick labels.
AX_FONT = 14
TICK_FONT = 12

# Figure geometry.
# NOTE(review): presumably FIG_DIM is the figure side length in inches and DPI
# the resolution for exported figures — confirm against where they are used.
FIG_DIM = 5
DPI = 300

In [None]:
def getKernelDensityAll(data, species_info, measure, selection_data=None,
                        selection_key=None, weights_key=None, min_val=0,
                        max_val=1, n_steps=1000):
    """Evaluate a Gaussian KDE of one measure for every sample on a shared grid.

    Parameters
    ----------
    data : mapping
        Indexed as ``data[sample_id][measure]`` (and, when ``weights_key`` is
        given, ``data[sample_id][weights_key]``).
    species_info : table with a ``'Sample ID'`` column
        Its ``'Sample ID'`` entries are the sample ids that get processed.
    measure : key
        Key of the values whose density is estimated.
    selection_data : mapping, optional
        Indexed as ``selection_data[sample_id][selection_key]``; the result is
        used to index the values (and weights) before fitting. Required when
        ``selection_key`` is set.
    selection_key : key, optional
        When falsy (the default), all values of a sample are used.
    weights_key : key, optional
        When falsy (the default), the KDE is unweighted.
    min_val, max_val : float
        End points of the evaluation grid (both included).
    n_steps : int
        Number of grid points.

    Returns
    -------
    lines : dict
        Maps each sample id to its density evaluated on the grid.
    grid_spec : list
        ``[min_val, max_val, n_steps]``, enough to rebuild the grid with
        ``np.linspace``.

    Raises
    ------
    TypeError
        If ``selection_key`` is set but ``selection_data`` is ``None``.
    """
    if selection_key and selection_data is None:
        # Fail up front with a readable message instead of the opaque
        # "'NoneType' object is not subscriptable" raised inside the loop.
        raise TypeError('selection_data must be provided when selection_key is set')

    # The evaluation grid is the same for every sample, so build it once.
    test_vals = np.linspace(min_val, max_val, n_steps)

    lines = {}
    for key in species_info['Sample ID']:
        print(key)  # progress indicator: one line per processed sample
        sample = data[key]
        values = sample[measure]
        weight = sample[weights_key] if weights_key else None

        if selection_key:
            selection_ids = selection_data[key][selection_key]
            values = values[selection_ids]
            if weight is not None:
                weight = weight[selection_ids]
        # Without a selection the values are used as they are; building an
        # identity index only copies the data (and breaks for plain lists).

        kde = getKernelDensityEstimator(values, weight)
        lines[key] = kde(test_vals)
    return lines, [min_val, max_val, n_steps]

def getKernelDensityEstimator(values, weights):
    """Fit a Gaussian kernel density estimator to ``values``.

    Parameters
    ----------
    values : array-like
        Data points the estimator is fitted to.
    weights : array-like or None
        Optional per-point weights, forwarded to ``gaussian_kde``; ``None``
        gives every point the same weight.

    Returns
    -------
    scipy.stats.gaussian_kde
        Callable estimator; calling it with an array of points returns the
        estimated density at those points.
    """
    # Use the public ``scipy.stats.gaussian_kde`` entry point: the
    # ``scipy.stats.kde`` submodule is a deprecated private namespace.
    return stats.gaussian_kde(values, weights=weights)
    
def getLikelyPerSpecies(lines, species_info):
    """Average the per-sample density curves within each morphospecies.

    Parameters
    ----------
    lines : mapping
        Maps sample id to a 1-D density curve; all curves of one species must
        have the same length.
    species_info : pandas.DataFrame
        Needs the columns ``'Morphospecies'`` and ``'Sample ID'``.

    Returns
    -------
    dict
        ``{species: {'mean': array, 'SD': array}}`` holding the point-wise
        mean and (population) standard deviation over the samples of that
        species. A species with no samples in ``species_info`` is left out
        instead of raising ``IndexError``.
    """
    specs = ['bifurcata', 'cytherea', 'hyacinthus']
    results = {}
    # Iterate over the species names directly so the loop cannot get out of
    # sync with the list (previously a hard-coded ``range(3)``).
    for spec in specs:
        is_spec = species_info['Morphospecies'] == spec
        sample_ids = species_info.loc[is_spec, 'Sample ID'].values
        if len(sample_ids) == 0:
            # Nothing to average for this species.
            continue
        # One row per sample, one column per grid point.
        vals = np.array([lines[sample_id] for sample_id in sample_ids], dtype=float)
        results[spec] = {
            'mean': np.mean(vals, axis=0),
            'SD': np.std(vals, axis=0),
        }
    return results

def plotKDE(lines, rang, species_info, xlabel, xlim=None):
    """Plot the mean density curve per morphospecies with a +/- 1 SD band.

    Parameters
    ----------
    lines : mapping
        Per-sample density curves, as returned by ``getKernelDensityAll``.
    rang : sequence
        ``[min_val, max_val, n_steps]`` describing the grid the curves were
        evaluated on.
    species_info : pandas.DataFrame
        Passed on to ``getLikelyPerSpecies`` to group samples by species.
    xlabel : str
        Label of the x axis.
    xlim : sequence of two floats, optional
        x-axis limits; defaults to ``(rang[0], rang[1])``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn.
    """
    # Identity check: ``xlim == None`` is evaluated element-wise for numpy
    # arrays and then raises inside the ``if``.
    if xlim is None:
        xlim = (rang[0], rang[1])
    per_spec = getLikelyPerSpecies(lines, species_info)
    x = np.linspace(rang[0], rang[1], rang[2])
    colors = {'bifurcata':'#1f77b4', 'cytherea':'#2ca02c','hyacinthus':'#ff7f0e'}

    # Square figure sized by the notebook-wide FIG_DIM constant.
    fig, ax = plt.subplots(figsize=(FIG_DIM, FIG_DIM))
    for key, data in per_spec.items():
        mean, sd = data['mean'], data['SD']
        ax.plot(x, mean, color=colors[key], label=key)
        # Shaded band: one standard deviation on either side of the mean.
        ax.fill_between(x, mean - sd, mean + sd, alpha=.2, color=colors[key])
    ax.set_xlabel(xlabel, fontsize=AX_FONT)
    ax.set_ylabel('density', fontsize=AX_FONT)
    ax.legend(fontsize=AX_FONT)
    ax.tick_params(axis='both', which='major', labelsize=TICK_FONT)
    ax.grid()
    ax.set_xlim(xlim)
    ax.set_ylim(0)  # only the lower limit is pinned; the upper one stays automatic
    fig.tight_layout()
    return fig



In [None]:
# Load the per-sample species table; the functions in this notebook read its
# 'Sample ID' and 'Morphospecies' columns.
# NOTE(review): `ldir` and `load_data` are not imported in any visible cell —
# confirm they are imported elsewhere, otherwise this cell fails on a fresh kernel.
species_info = pd.read_csv(f'{ldir.DIR_DATA}/species_info_v1.csv')
# NOTE(review): `da_measures` is not used by any visible cell — confirm it is still needed.
da_measures = load_data.readPickle('curvatures.pickle')

# Example

In [None]:
# Load the sphere data and evaluate a per-sample KDE of its 'da' entry on a
# grid running from 0 (the default min_val) to 1.7.
# NOTE(review): 1.7 is a magic number — presumably the largest 'da' value of
# interest; confirm and consider naming it.
sphere_data = load_data.readPickle('spheres_angles.pickle')
lines, rang = getKernelDensityAll(sphere_data, species_info, 'da',max_val = 1.7)


In [None]:
# Plot the per-species mean density of 'da'; a plain string literal is used
# for the label because it contains no placeholders to interpolate.
plotKDE(lines, rang, species_info, xlabel='da(cm)')
