In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from aging.plotting import format_plots, PlotConfig, save_factory, figure, legend, format_pizza_plots
from collections import Counter
from matplotlib.lines import Line2D
from tqdm import tqdm

In [2]:
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from matplotlib.gridspec import GridSpec
from collections import defaultdict
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import ShuffleSplit
from sklearn.svm import LinearSVC
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut, LeaveOneOut, KFold
from sklearn.metrics import accuracy_score
import random

In [3]:
def mm_norm_col(column):
    """Min-max scale ``column`` so its minimum maps to 0 and its maximum to 1.

    ``column`` may be any object supporting ``.min()``, ``.max()`` and
    arithmetic (e.g. a pandas Series, or a DataFrame, in which case each
    column is scaled independently by pandas broadcasting).

    Note: a column whose values are all identical has a zero range, so the
    division yields NaN for that column.
    """
    lowest = column.min()
    value_range = column.max() - lowest
    return (column - lowest) / value_range

In [4]:
# Syllable ids to keep, read as integers from a text file.
keep_syllables = np.loadtxt(
    '/n/groups/datta/win/longtogeny/data/ontogeny/version_11/to_keep_syllables_raw.txt',
    dtype=int,
)

# Raw usage matrix (female cohort, per the file name), cast to float and
# averaged within each (age, uuid, mouse) combination.
df = pd.read_parquet(
    '/n/groups/datta/win/longtogeny/data/ontogeny/version_11-1/longtogeny_v2_females_raw_usage_matrix_v00.parquet'
).astype(float)
df = df[keep_syllables].groupby(['age', 'uuid', 'mouse']).mean()

# Bin each session's age: multiply by 7 and floor-divide by 7.
# NOTE(review): presumably `age` is in weeks, making this a floor to whole weeks - confirm.
ages = df.index.get_level_values('age')
weeks = (ages * 7) // 7
df['binned_age'] = weeks

# For every (mouse, bin) take the first and the last row; the last row is
# placed half a bin later so both can coexist on the same age axis.
# NOTE(review): when a mouse has a single row in a bin, first() and last()
# return the same row, so it appears twice (at bin and bin + 0.5). Verify this
# is intended, since duplicated rows can inflate leave-one-out accuracy.
by_mouse_and_bin = df.groupby(['mouse', 'binned_age'])
sample = by_mouse_and_bin.first().reset_index()
sample2 = by_mouse_and_bin.last().reset_index()
sample2['binned_age'] += 0.5

# NOTE(review): the name `df_male` is kept for downstream cells even though
# the parquet file name says "females".
#df_male = sample.copy()
df_male = (
    pd.concat([sample, sample2])
    .rename(columns={'binned_age': 'age'})
    .set_index(['age', 'mouse'])
)

In [5]:
# Per-(age, mouse) averages and their min-max normalized version
# (each column is scaled to [0, 1] by mm_norm_col).
m_df = df_male.groupby(['age', 'mouse']).mean()
m_norm = mm_norm_col(m_df)

# Per-age averages (mean over mice) and their min-max normalized version.
avg_m_df = m_df.groupby(['age']).mean()
avg_m_norm = mm_norm_col(avg_m_df)

In [6]:
# Normalized usage re-grouped by (mouse, age); groupby sorts the keys, so rows
# end up ordered by mouse and then by age.
df_indv = m_norm.groupby(['mouse', 'age']).mean()

# Flat copy with `mouse` and `age` as ordinary columns (reset_index already
# returns a new frame, so df_indv itself is left untouched).
df = df_indv.reset_index()

In [7]:
# choose a model
# Support vector classifier with a linear kernel.
clf = svm.SVC(kernel='linear') 

# choose cross validation scheme
# Leave-one-out: every sample is held out once as the test set.
cv = LeaveOneOut()
# Alternatives that were tried and left disabled:
#cv = ShuffleSplit(n_splits=5, test_size=0.25, random_state=0)
#cv=5

#number of iterations for shuffle
# i.e. how many label permutations are evaluated per window in the decoding loop
it=100

In [8]:
# prepare colors for plot
from matplotlib.colors import LinearSegmentedColormap
import matplotlib

# Two-stop gradient used to colour points in the accuracy plot.
colors = ['#fee6ce', '#d94801']
cmap = LinearSegmentedColormap.from_list("custom_purples", colors, N=256)

# Register the colormap under the name "dana" so seaborn can look it up.
# `matplotlib.cm.register_cmap` is deprecated (and removed in recent
# matplotlib), and registering an existing name raises ValueError, which made
# this cell fail when re-run. Use the colormap registry when available and
# drop any previous registration first so the cell is idempotent.
if hasattr(matplotlib, "colormaps"):
    matplotlib.colormaps.unregister("dana")  # no-op if "dana" is not registered
    matplotlib.colormaps.register(cmap, name="dana")
else:
    # Older matplotlib without the registry API.
    matplotlib.cm.register_cmap("dana", cmap)

# 50 colours sampled from the registered colormap.
pl = sns.color_palette("dana", n_colors=50)

In [9]:
# Sliding-window identity decoder.
#
# For each window offset `a`, take sessions [a, a + n) of every mouse, keep
# the mice that have a full window, fit `clf` to predict mouse identity from
# the feature columns, and record:
#   age               - floor of the mean age in the window
#   acc               - mean cross-validated accuracy (scheme given by `cv`)
#   coefficients_list - clf.coef_ averaged over its first axis
#   sh_acc / sh_ages  - accuracy for `it` label permutations and the matching age
#
# Changes relative to the previous version of this cell:
#   * the unused cross_val_predict call was removed (it repeated the whole
#     cross-validation only to discard the result);
#   * label permutations come from a seeded generator so results are reproducible;
#   * windows are concatenated once instead of growing a frame from an empty
#     DataFrame inside a loop;
#   * the loop stops when fewer than two mice remain, because a classifier
#     cannot be fitted on a single class.

# Fail early if `df` is not the per-session table (a later cell reuses the name `df`).
assert {'mouse', 'age'} <= set(df.columns), (
    "df must contain 'mouse' and 'age' columns; re-run the cell that builds it from df_indv"
)

SHUFFLE_SEED = 0
rng = np.random.default_rng(SHUFFLE_SEED)

xmale = pd.DataFrame()
a = 0  # index of the first session in the current window
n = 8  # number of sessions per mouse in each window
age = []
acc = []
sh_acc = []
sh_ages = []
coefficients_list = []

# The per-mouse tables do not change between windows, so split only once.
sessions_by_mouse = [mouse_sessions for _, mouse_sessions in df.groupby('mouse')]

while True:
    # Build the current age matrix: one positional slice of n sessions per
    # mouse, keeping only mice that have all n sessions in this window.
    windows = [mouse_sessions.iloc[a:a + n] for mouse_sessions in sessions_by_mouse]
    windows = [window for window in windows if len(window) >= n]
    if len(windows) < 2:
        # Nothing left to decode (no mice, or a single class only).
        break
    xmale = pd.concat(windows)

    window_age = np.floor(np.mean(xmale['age'])).astype(int)
    age.append(window_age)
    ymale = xmale['mouse']
    x = xmale.drop(['mouse', 'age'], axis=1).to_numpy()

    # Fit the model on the full window and collect coefficients
    clf.fit(x, ymale)
    coefficients_list.append(np.mean(clf.coef_, axis=0))

    # Cross-validated accuracy on the true labels
    acc.append(np.mean(cross_val_score(clf, x, ymale, cv=cv, n_jobs=-1)))

    # Shuffle labels and calculate shuffled accuracy
    for _ in tqdm(range(it)):
        ysh_temp = rng.permutation(ymale.to_numpy())
        sh_acc.append(np.mean(cross_val_score(clf, x, ysh_temp, cv=cv, n_jobs=-1)))
        sh_ages.append(window_age)

    # Slide the window forward by one session
    a += 1


In [10]:
# create df for plotting
# Shuffled-label accuracies, one row per permutation, with the window age.
df_sh = pd.DataFrame({'acc': sh_acc, 'ages': sh_ages})

# True-label accuracies, one row per window.
# NOTE: this rebinds `df`, which previously held the per-session table used by
# the decoding loop; that loop cannot be re-run until `df` is rebuilt.
df = pd.DataFrame({'acc': acc, 'ages': age})

In [11]:
# Call the plotting setup helper imported from aging.plotting (defined outside
# this notebook; presumably sets matplotlib style defaults - confirm there).
format_plots()

In [12]:
fig, ax = plt.subplots(figsize=(1.3, 1.3))

# Shuffled-label accuracy (mean with standard-error bars), drawn first in grey
# and faded so it sits behind the true accuracy.
sns.pointplot(data=df_sh, x="ages", y="acc", ax=ax, color='grey', estimator='mean', errorbar='se', join=True)
plt.setp(ax.collections, alpha=.3)  # for the markers
plt.setp(ax.lines, alpha=.3)        # for the lines

# True-label accuracy, coloured by age with the custom palette.
sns.pointplot(data=df, x="ages", y="acc",
              dodge=0,
              join=True,
              ax=ax,
              scale=0.5,
              hue='ages',
              palette=pl,
              )
ax.legend([], [], frameon=False)
ax.set_ylim([0, 1.1])

# The x axis of a pointplot is categorical: tick positions are category
# indices (0, 1, 2, ...), not ages. Calling set_xticks with age values
# therefore put the ticks at the wrong places and could attach the wrong
# labels. Map the desired ages onto their category positions (numeric
# categories are plotted in sorted order) and label the ticks explicitly.
age_categories = sorted(df['ages'].unique())
wanted_ages = [5, 10, 15, 20, 25, 30, 35, 40, 45]
shown_ages = [wanted for wanted in wanted_ages if wanted in age_categories]
ax.set_xticks([age_categories.index(shown) for shown in shown_ages])
ax.set_xticklabels(shown_ages)

sns.despine()
# The data file and the saved figure name both refer to the female cohort; the
# title previously said "male".
ax.set_title('identity classifier female')
c = PlotConfig()
#fig.savefig(c.dana_save_path / "fig5"/ 'long_id_decoder.pdf')

In [13]:
# Write the most recently created figure (`fig`) to the fig4 folder under the
# save path exposed by PlotConfig.
c = PlotConfig()
decoder_pdf = c.dana_save_path / "fig4" / 'longv2_indv_decoder_female.pdf'
fig.savefig(decoder_pdf)

In [14]:
fig, ax = plt.subplots(figsize=(1.3, 1.3))

# One row per window, one column per feature; transposed for plotting so that
# features run along the y axis and windows along the x axis.
coefficients_df = pd.DataFrame(coefficients_list)
sns.heatmap(
    coefficients_df.T,
    cmap='coolwarm',
    center=0,
    annot=False,
    fmt='.2f',
    vmin=-0.25,
    vmax=0.25,
    ax=ax,
)
ax.set_ylabel('Syllables')
ax.set_xlabel('Ages')

# Label every 6th window with its age
tick_positions = range(0, len(age), 6)
tick_labels = [age[i] for i in tick_positions]
ax.set_xticks(list(tick_positions))
ax.set_xticklabels(tick_labels)

plt.show()

In [15]:
# Write the most recently created figure (`fig`) to the fig4 folder under the
# save path exposed by PlotConfig.
c = PlotConfig()
coefficient_pdf = c.dana_save_path / "fig4" / 'longv2_indv_decoder_coefficient_female.pdf'
fig.savefig(coefficient_pdf)

In [16]:
from sklearn.feature_selection import mutual_info_classif as MIC
from sklearn.feature_selection import mutual_info_regression as MIR

# Rank syllables by their mutual information with age, using the
# un-normalized per-(age, mouse) averages.
long_data = m_df.copy()
X = long_data.to_numpy()
y = list(long_data.index.get_level_values('age'))

# mutual_info_regression is stochastic, so fix random_state; otherwise the
# syllable ranking below can change from run to run.
mi_score_long = MIR(X, y, random_state=0)

# Column positions sorted by decreasing MI score (at most len(keep_syllables)
# entries), and the corresponding column labels.
long_indx = np.argsort(mi_score_long)[::-1][0:len(keep_syllables)]  # syllable index
impsyl_long = list(long_data.columns[long_indx])  # syllable id

In [17]:
fig, ax = plt.subplots(figsize=(1.3, 1.3))

# One row per window, one column per feature.
coefficients_df = pd.DataFrame(coefficients_list)

# The previous version called .pivot(index='syll', columns='age', values='coef'),
# but this frame is wide and has no such columns, so the call raised KeyError.
# Instead, label the feature columns with the syllable ids: the decoder's
# feature matrix has the same columns, in the same order, as `long_data`.
assert coefficients_df.shape[1] == len(long_data.columns), (
    "number of decoder coefficients does not match the number of syllable columns"
)
coefficients_df.columns = long_data.columns

# Rows = syllable ids, columns = window number.
heatmap_df = coefficients_df.T

# Ensure all indices in impsyl_long are present in the DataFrame
assert set(impsyl_long) <= set(heatmap_df.index), "impsyl_long contains rows not present in heatmap_df"

# Reorder the rows according to impsyl_long (decreasing mutual information with age)
heatmap_df = heatmap_df.loc[impsyl_long]

# Plot with syllables on the y axis and windows on the x axis, which is what
# the axis labels and x ticks below describe (the earlier extra transpose put
# ages on the y axis).
sns.heatmap(heatmap_df, cmap='coolwarm', center=0, annot=False, fmt='.2f', vmin=-0.25, vmax=0.25, ax=ax)
ax.set_ylabel('Syllables')
ax.set_xlabel('Ages')

# Label every 6th window with its age
tick_positions = range(0, len(age), 6)
tick_labels = [age[i] for i in tick_positions]
ax.set_xticks(list(tick_positions))
ax.set_xticklabels(tick_labels)
plt.show()

In [None]:
# Write the most recently created figure (`fig`) to the fig4 folder under the
# save path exposed by PlotConfig.
c = PlotConfig()
resorted_pdf = c.dana_save_path / "fig4" / 'longv2_indv_decoder_coefficient_resorted_female.pdf'
fig.savefig(resorted_pdf)