In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import matplotlib as mpl
import os
from toolz import concat
from matplotlib.lines import Line2D
from collections import Counter
import math
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsRegressor
from aging.behavior.syllables import relabel_by_usage
from tqdm import tqdm
%matplotlib inline
import warnings
warnings.simplefilter('ignore')
import random
import scipy
from sklearn.pipeline import make_pipeline
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import adjusted_rand_score
from kneed import KneeLocator
from sklearn.metrics import silhouette_score
from scipy.spatial.distance import jensenshannon
%matplotlib inline
from scipy.ndimage import gaussian_filter1d
from aging.plotting import format_plots, PlotConfig, save_factory, figure, legend, format_pizza_plots

In [2]:
# Apply the shared plot styling from aging.plotting (presumably global rcParams / seaborn style — see aging.plotting).
format_plots()
# Alternative styling helper from aging.plotting; currently disabled.
#format_pizza_plots()

In [3]:
from matplotlib.colors import LinearSegmentedColormap

# Two 256-level colormaps, each a linear ramp between two hex stops:
# `cmf` uses the same orange stops as the "dana2" palette below,
# `cmm` the same teal stops as the "dana" palette below.
orange_stops = ['#fee6ce', '#d94801']
teal_stops = ['#c7eae5', '#008C8D']

cmf = LinearSegmentedColormap.from_list("custom_purples", orange_stops, N=256)
cmm = LinearSegmentedColormap.from_list("custom_purples", teal_stops, N=256)

# Leave behind the same scratch names the cell has always defined.
colors, custom_cmap = teal_stops, cmm

In [4]:
# Build two 27-colour seaborn palettes from two-stop linear colormaps:
# `pm` (teal ramp, registered as "dana") and `pf` (orange ramp, registered as "dana2").
# The colormaps are registered by name because `sns.color_palette` is called with the name.
import matplotlib
from matplotlib.colors import LinearSegmentedColormap


def _register_named_cmap(name, cmap):
    """Register ``cmap`` under ``name`` in Matplotlib's colormap registry.

    An existing registration under ``name`` is replaced, so this cell can be
    re-run. ``matplotlib.cm.register_cmap`` is deprecated (removed in
    Matplotlib 3.9) and, on Matplotlib >= 3.6, raises ``ValueError`` when the
    name is already registered; the registry API is used when available.
    """
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        registry.register(cmap, name=name, force=True)
    else:
        # Older Matplotlib: only the legacy module-level function exists.
        matplotlib.cm.register_cmap(name, cmap)


# prepare colors for plot (teal ramp; `pm` is used for the male plots)
colors = ['#c7eae5', '#008C8D']
cmap = LinearSegmentedColormap.from_list("custom_purples", colors, N=256)
_register_named_cmap("dana", cmap)
pm = sns.color_palette("dana", n_colors=27)

# prepare colors for plot (orange ramp; `pf` is used for the female plots)
colors = ['#fee6ce', '#d94801']
cmap = LinearSegmentedColormap.from_list("custom_purples", colors, N=256)
_register_named_cmap("dana2", cmap)
pf = sns.color_palette("dana2", n_colors=27)

In [5]:
## update data
def filter_df(df, max_syllable=39, exclude='CRL'):
    """Keep syllables ``<= max_syllable`` and drop sessions whose name contains ``exclude``.

    Parameters
    ----------
    df : pd.DataFrame
        Usage matrix whose column labels are numeric syllable ids and whose
        index has a ``session_name`` level.
    max_syllable : int, default 39
        Columns whose label is greater than this are removed.
    exclude : str, default 'CRL'
        Rows whose ``session_name`` contains this literal substring are
        removed. Rows with a missing (NaN) session name are kept.

    Returns
    -------
    pd.DataFrame
        A filtered frame. ``df`` itself is not modified.
    """
    # Column labels are syllable ids; drop every syllable above the cutoff.
    keep_cols = ~(df.columns.to_numpy() > max_syllable)
    # Literal substring match; na=False keeps NaN names instead of breaking `~`.
    session_names = df.index.get_level_values('session_name')
    keep_rows = ~np.asarray(
        session_names.str.contains(exclude, regex=False, na=False), dtype=bool
    )
    return df.loc[keep_rows, keep_cols]

In [6]:
# Load the male usage matrices from both cohorts, average per (age, uuid), tag the
# source cohort in a `data` index level ('min' / 'dana'), pool them, and keep at
# most `n` randomly chosen rows per age.
path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/dana_ontogeny_males_relabeled_usage_matrix_v00.parquet')
df_dana = pd.read_parquet(path)

path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/ontogeny_males_relabeled_usage_matrix_v00.parquet')
df = pd.read_parquet(path)

# arrange data
data = (
    filter_df(df.astype(float, errors='ignore'))
    .groupby(['age', 'uuid']).mean()
    .assign(data='min')
    .set_index('data', append=True)
)
data_dana = (
    filter_df(df_dana.astype(float, errors='ignore'))
    .groupby(['age', 'uuid']).mean()
    .assign(data='dana')
    .set_index('data', append=True)
)

data_all = pd.concat([data, data_dana])
# choose up to 16 mice per age from the pooled dataset (all of them, shuffled, if fewer)
random_seed = 0
n = 16
data_all = (
    data_all
    # group_keys=True is passed explicitly so the prepended `age` key level is
    # always present (pandas >= 1.5 honours it for every apply result); that
    # duplicate level is what reset_index(level=0) removes below.
    .groupby('age', group_keys=True)
    .apply(lambda grp: grp.sample(n=min(n, len(grp)), replace=False, random_state=random_seed))
    .reset_index(level=0, drop=True)
)
data_all = data_all[sorted(data_all.columns)]
m_data = data_all.query('age<100')
assert m_data.groupby(level='age').size().le(n).all(), "more than n rows for some age"

In [7]:
# Load the female usage matrices from both cohorts, average per (age, uuid), tag the
# source cohort in a `data` index level ('min' / 'dana'), pool them, and keep at
# most `n` randomly chosen rows per age.
path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/dana_ontogeny_females_relabeled_usage_matrix_v00.parquet')
df_dana = pd.read_parquet(path)

path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/ontogeny_females_relabeled_usage_matrix_v00.parquet')
df = pd.read_parquet(path)

# arrange data
data = (
    filter_df(df.astype(float, errors='ignore'))
    .groupby(['age', 'uuid']).mean()
    .assign(data='min')
    .set_index('data', append=True)
)
data_dana = (
    filter_df(df_dana.astype(float, errors='ignore'))
    .groupby(['age', 'uuid']).mean()
    .assign(data='dana')
    .set_index('data', append=True)
)

data_all = pd.concat([data, data_dana])
# choose up to 16 mice per age from the pooled dataset (all of them, shuffled, if fewer)
random_seed = 0
n = 16
data_all = (
    data_all
    # group_keys=True is passed explicitly so the prepended `age` key level is
    # always present (pandas >= 1.5 honours it for every apply result); that
    # duplicate level is what reset_index(level=0) removes below.
    .groupby('age', group_keys=True)
    .apply(lambda grp: grp.sample(n=min(n, len(grp)), replace=False, random_state=random_seed))
    .reset_index(level=0, drop=True)
)
data_all = data_all[sorted(data_all.columns)]
f_data = data_all.query('age<100')
assert f_data.groupby(level='age').size().le(n).all(), "more than n rows for some age"

In [8]:
# use only common ages: restrict both sexes to ages present in both, tag each
# frame with a constant `sex` index level, and stack them into `data`.
f_ages = f_data.index.get_level_values('age').unique().to_numpy()
m_ages = m_data.index.get_level_values('age').unique().to_numpy()
cages = list(np.intersect1d(f_ages, m_ages))


def _restrict_and_tag_sex(frame, sex, ages):
    """Return rows of ``frame`` whose ``age`` is in ``ages``, with a ``sex`` index level appended.

    ``assign`` + ``set_index`` replaces column assignment on a ``.loc`` slice
    (a SettingWithCopy pattern; this notebook silences warnings globally).
    If ``frame`` already carries a ``sex`` level, for example because this cell
    was re-run, no second level is stacked.
    """
    frame = frame.loc[frame.index.get_level_values('age').isin(ages)]
    if 'sex' in frame.index.names:
        return frame
    return frame.assign(sex=sex).set_index('sex', append=True)


m_data = _restrict_and_tag_sex(m_data, 'm', cages)
f_data = _restrict_and_tag_sex(f_data, 'f', cages)

data = pd.concat([m_data, f_data])

In [9]:
# One panel per age: mean syllable usage (bars, SE error bars, plus a line through
# the means), and collect one smoothed usage entropy per mouse into `ages`/`ent`.
fig = plt.figure(figsize=(12, 12))
plt.subplots_adjust(hspace=0.5)
plt.suptitle("Entropy by age females", fontsize=10, y=0.95)

# Smoothing width for the per-mouse entropies. `sigma` is also read by the male
# cell below. `window_size` is only for the moving-average alternative
# np.convolve(arr, np.ones(window_size) / window_size, mode='valid').
sigma = 2
window_size = 3

ages = []  # one age per mouse, aligned with `ent`
ent = []   # one (smoothed) usage entropy per mouse

age_groups = list(f_data.groupby('age'))
n_cols = 6
# Same 6x6 grid as before, but add rows instead of failing when there are more than 36 ages.
n_rows = max(6, math.ceil(len(age_groups) / n_cols))

for panel, (age, age_df) in enumerate(age_groups):
    usage = age_df.to_numpy()  # rows: sampled mice of this age, columns: syllables
    ages.extend(age_df.index.get_level_values('age').to_numpy())

    # Shannon entropy (natural log) of each row's syllable-usage distribution.
    arr = scipy.stats.entropy(usage, axis=1)
    # NOTE(review): this smooths the entropies *across mice* of this age, in the
    # random order produced by the subsampling. The stored values are therefore
    # neighbour-weighted averages, not per-mouse entropies, and their spread in
    # the boxplots / SE below is reduced. Confirm this is intended.
    ent.extend(gaussian_filter1d(arr, sigma=sigma))

    ax = plt.subplot(n_rows, n_cols, panel + 1)
    sns.barplot(usage, errorbar='se', ax=ax)
    sns.lineplot(np.mean(usage, axis=0), ax=ax, color='#f16913')
    ax.set_title('syllable usage ' + str(age))
    ax.set_ylim(0, 0.08)
    ax.set_xticks([])
    ax.set_ylabel('fraction of use')
    sns.despine()

In [10]:
# Plot configuration; exposes `dana_save_path`, the root folder figures are saved under.
c = PlotConfig()
# Saving the female per-age usage grid is currently disabled:
#fig.savefig(c.dana_save_path / "simons"/ 'female_entropy_examples.pdf')

In [11]:
# Long-form table of the female entropies, one row per mouse. Ages are turned
# into strings so seaborn treats them as categories (ordered by first appearance,
# i.e. the ascending groupby order used when they were collected).
ages = list(map(str, ages))
dff = pd.DataFrame({
    'ages': ages,
    'ent': ent,
    'sex': ['f'] * len(ages),
})

fig, ax = plt.subplots()
fig.set_size_inches(2, 2)
sns.boxplot(
    data=dff, x="ages", y="ent", ax=ax,
    palette=pf, dodge=True, showfliers=False, linewidth=0.5, showcaps=False,
)
ax.set_title('entropy females')
ax.set_ylim(3, 3.6)
sns.despine()

In [12]:
# Save the female entropy boxplot (`fig` from the previous cell) under <dana_save_path>/fig3.
c = PlotConfig()
out_dir = c.dana_save_path / "fig3"
# savefig does not create folders; make sure the target exists.
# (assumes dana_save_path is a pathlib.Path — the `/` join already requires it to be path-like)
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / 'female_entropy_boxplots.pdf')

In [13]:
# One panel per age: mean syllable usage (bars, SE error bars, plus a line through
# the means), and collect one smoothed usage entropy per mouse into `ages`/`ent`.
# `sigma` (smoothing width) is defined in the female cell above.
fig = plt.figure(figsize=(12, 12))
plt.subplots_adjust(hspace=0.5)
plt.suptitle("Entropy by age males", fontsize=10, y=0.95)

ages = []  # one age per mouse, aligned with `ent`
ent = []   # one (smoothed) usage entropy per mouse

age_groups = list(m_data.groupby('age'))
n_cols = 6
# Same 6x6 grid as before, but add rows instead of failing when there are more than 36 ages.
n_rows = max(6, math.ceil(len(age_groups) / n_cols))

for panel, (age, age_df) in enumerate(age_groups):
    usage = age_df.to_numpy()  # rows: sampled mice of this age, columns: syllables
    ages.extend(age_df.index.get_level_values('age').to_numpy())

    # Shannon entropy (natural log) of each row's syllable-usage distribution.
    arr = scipy.stats.entropy(usage, axis=1)
    # NOTE(review): this smooths the entropies *across mice* of this age, in the
    # random order produced by the subsampling. The stored values are therefore
    # neighbour-weighted averages, not per-mouse entropies, and their spread in
    # the boxplots / SE below is reduced. Confirm this is intended.
    ent.extend(gaussian_filter1d(arr, sigma=sigma))

    ax = plt.subplot(n_rows, n_cols, panel + 1)
    sns.barplot(usage, errorbar='se', ax=ax)
    sns.lineplot(np.mean(usage, axis=0), ax=ax, color='#20B2AA')
    ax.set_title('syllable usage ' + str(age))
    ax.set_ylim(0, 0.08)
    ax.set_xticks([])
    ax.set_ylabel('fraction of use')
    sns.despine()

In [14]:
# Plot configuration; exposes `dana_save_path`, the root folder figures are saved under.
c = PlotConfig()
# Saving the male per-age usage grid is currently disabled:
#fig.savefig(c.dana_save_path / "simons"/ 'male_entropy_examples.pdf')

In [15]:
# Long-form table of the male entropies, one row per mouse. Ages are turned
# into strings so seaborn treats them as categories (ordered by first appearance,
# i.e. the ascending groupby order used when they were collected).
ages = list(map(str, ages))
dfm = pd.DataFrame({
    'ages': ages,
    'ent': ent,
    'sex': ['m'] * len(ages),
})

fig, ax = plt.subplots()
fig.set_size_inches(2, 2)
sns.boxplot(
    data=dfm, x="ages", y="ent", ax=ax,
    palette=pm, dodge=True, showfliers=False, linewidth=0.5, showcaps=False,
)
# NOTE(review): these restyle ax.collections, which a scatter overlay would
# populate; confirm they still affect anything now that only the boxplot is drawn.
for prop, value in (('alpha', 0.2), ('edgecolor', 'none')):
    plt.setp(ax.collections, **{prop: value})
ax.set_title('entropy males')
ax.set_ylim(3, 3.6)
sns.despine()

In [16]:
# Save the male entropy boxplot (`fig` from the previous cell) under <dana_save_path>/fig3.
c = PlotConfig()
out_dir = c.dana_save_path / "fig3"
# savefig does not create folders; make sure the target exists.
# (assumes dana_save_path is a pathlib.Path — the `/` join already requires it to be path-like)
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / 'male_entropy_boxplots.pdf')

In [17]:
## show mean: per-age mean entropy ± SE for each sex, points only (no connecting line)
fig, ax = plt.subplots()
fig.set_size_inches(2, 2)
df_all = pd.concat([dfm, dff])

# `join=False` is deprecated since seaborn 0.13 (slated for removal in 0.15) in
# favour of `linestyle='none'`, which older seaborn does not accept.
_sns_major_minor = tuple(int(part) for part in sns.__version__.split('.')[:2])
no_line_kws = {'linestyle': 'none'} if _sns_major_minor >= (0, 13) else {'join': False}

sns.pointplot(
    data=df_all,
    x="ages",
    y="ent",
    estimator='mean',
    errorbar='se',
    hue='sex',
    # Colours keyed by sex label, so reordering the concat cannot swap them.
    palette={'m': '#20B2AA', 'f': '#f16913'},
    ax=ax,
    **no_line_kws,
)

#ax.set_ylim(3.2,3.7)
plt.setp(ax.collections, alpha=.7)
plt.setp(ax.collections, edgecolor='none')
sns.despine()

In [18]:
# Plot configuration; exposes `dana_save_path`, the root folder figures are saved under.
c = PlotConfig()
# Saving the female-vs-male entropy pointplot is currently disabled:
#fig.savefig(c.dana_save_path / "fig3"/ 'fvm_entropy_pointplot.pdf')

In [19]:
## compare sexes per age: boxplots (median / IQR) of the per-mouse entropies
fig, ax = plt.subplots()
fig.set_size_inches(2, 2)
df_all = pd.concat([dfm, dff])
sns.boxplot(
    data=df_all,
    x="ages",
    y="ent",
    hue='sex',
    # Colours keyed by sex label, so reordering the concat cannot swap them.
    palette={'m': '#20B2AA', 'f': '#f16913'},
    showfliers=False,
    ax=ax,
)

#ax.set_ylim(3.2,3.7)
plt.setp(ax.collections, alpha=.7)
plt.setp(ax.collections, edgecolor='none')
sns.despine()