In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from aging.plotting import format_plots, PlotConfig, save_factory, figure, legend, format_pizza_plots
from collections import Counter
from matplotlib.lines import Line2D
from tqdm import tqdm

In [2]:
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from matplotlib.gridspec import GridSpec
from collections import defaultdict
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import ShuffleSplit
from sklearn.svm import LinearSVC
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut, LeaveOneOut, KFold
from sklearn.metrics import accuracy_score
import random

In [3]:
def mm_norm_col(column):
    return (column - column.min()) / (column.max() - column.min())

In [4]:
## update data
def filter_df(df):
    thresh=8
    age_counts = df.index.get_level_values('age').value_counts()
    ages_greater = list(age_counts[age_counts > thresh].index)
    return df.loc[df.index.get_level_values('age').isin(ages_greater)]

In [5]:
from matplotlib.colors import LinearSegmentedColormap
colors = ['#c7eae5','#008C8D']
custom_cmap = LinearSegmentedColormap.from_list("custom_purples", colors, N=256)
cmm=custom_cmap

image_ctx = {'image.cmap': 'cubehelix', 'image.interpolation': 'none'}

from matplotlib.colors import LinearSegmentedColormap
# Define the color map
colors = ['#dadaeb','#6a51a3']
custom_cmap = LinearSegmentedColormap.from_list("custom_purples", colors, N=256)
cmm=custom_cmap

In [6]:
## arrange data
keep_syllables = np.loadtxt('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/to_keep_syllables_raw.txt', dtype=int)

male_df = pd.read_parquet('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/longtogeny_males_raw_usage_matrix_v00.parquet').astype(float)
male_df = male_df[keep_syllables]
male_df=filter_df(male_df)
male_df = male_df.query('age<100')

In [7]:
long_df=male_df.copy()
# find first recording per week in the longtogeny dataset
from collections import defaultdict
ages = list(long_df.index.get_level_values('age').unique())
week_entries = defaultdict(list)
for age in ages:
    week, day = str(age).split('.')
    week_entries[week].append(age)

# Find the entry with the lowest day for each week
lowest_day_entries = [min(entries, key=lambda x: float(str(x).split('.')[1])) for entries in week_entries.values()]
long_df = long_df[long_df.index.get_level_values('age').isin(lowest_day_entries)]
long_df = long_df.rename_axis(index={'age': 'age_old'})
long_df['age'] = np.floor(long_df.index.get_level_values('age_old')).astype(int)
long_df.set_index('age',inplace=True, append=True)

In [8]:
# Group by 'bin' and 'mouse' and count the occurrences
#long_df.groupby(['age', 'mouse']).size().unstack(fill_value=0)
long_df = long_df.drop(index=['03_03', '02_04'], level='mouse')
#long_df.groupby(['age', 'mouse']).size().unstack(fill_value=0)

In [9]:
# avarage per age or session
m_df= long_df.groupby(['age','uuid','mouse']).mean()
avg_m_df = m_df.groupby(['age']).mean()

# normalize the data
m_norm = mm_norm_col(m_df)
avg_m_norm = mm_norm_col(avg_m_df)
df_indv=m_norm.groupby(['mouse','age']).mean()
df=df_indv.copy()

In [10]:
## get data to test on - all mice that have the last final 5 datapints:
#create test data - average of the 5 final data point of each mouse
nend=4
xtest=pd.DataFrame()
df=df_indv.copy()
# Identify the last 3 ages
last_5_ages = df.index.get_level_values('age').unique()[-nend:]

# Filter the DataFrame to include only the last 5 ages
last_5_df = df[df.index.get_level_values('age').isin(last_5_ages)]

xtest = df[df.index.get_level_values('age').isin(last_5_ages)]
xtest = mm_norm_col(xtest)
ytest = xtest.index.get_level_values('mouse').to_numpy()

In [11]:
# avarage per age or session
m_df= long_df.groupby(['age','uuid','mouse']).mean()
avg_m_df = m_df.groupby(['age']).mean()

# normalize the data
m_norm = mm_norm_col(m_df)
avg_m_norm = mm_norm_col(avg_m_df)

In [12]:
df_indv=m_norm.groupby(['mouse','age']).mean()
df=df_indv.copy()
df= df[~df.index.get_level_values('age').isin(last_5_ages)].reset_index()


In [13]:
# choose a model
clf = svm.SVC(kernel='linear') 

# choose cross validation scheme
cv = LeaveOneOut()
#cv = ShuffleSplit(n_splits=5, test_size=0.25, random_state=0)
#cv=5

#number of iterations for shuffle
it=1000

In [14]:
# prepare colors for plot
colors = ['#dadaeb','#6a51a3']
import matplotlib
cmap = LinearSegmentedColormap.from_list("custom_purples", colors, N=256)
matplotlib.cm.register_cmap("dana", cmap)
pl = sns.color_palette("dana", n_colors=50)

  matplotlib.cm.register_cmap("dana", cmap)


In [15]:
xmale = pd.DataFrame()
a = 0
n = 10  # size of sliding window
nmin = 8  # number of minimum sessions 
nmouse = 10  # number of minimum mice
age = []
acc = []
sh_acc = []
sh_ages = []
coefficients_list = []

while True:
    # Build the current age matrix
    for m, _df in enumerate(df.groupby('mouse')):
        xmale_temp = _df[1][a:a+n]
        xmale = pd.concat([xmale, xmale_temp])
    
    rep_counts = xmale.groupby('mouse')['age'].size()
    keep_mice = list(rep_counts[rep_counts >= nmin].index)  # keep mice that have at least nmin sessions
    xmale = xmale[xmale['mouse'].isin(keep_mice)]
    
    print(len(keep_mice))
    if xmale.empty or len(keep_mice) < nmouse:
        break
    
    # Select random nmin sessions per each mouse to balance the training set
    #selected_mice = random.sample(keep_mice, nmouse)  # select random 10 mice out of all mice
    #xmale = xmale[xmale['mouse'].isin(selected_mice)]
    xmale = xmale.groupby('mouse').apply(lambda x: x.sample(n=nmin, replace=False)).reset_index(level=0, drop=True)
    age.append(np.floor(np.mean(xmale['age'])).astype(int))
    
    ymale = xmale['mouse']
    x = xmale.drop(['mouse', 'age'], axis=1).to_numpy()
    
    # Fit the model and collect coefficients
    clf.fit(x, ymale)
    coefficients_list.append(np.mean(clf.coef_,axis=0))
    
    # Predict labels for the test data
    y_pred = clf.predict(xtest)

    # Calculate accuracy score
    accuracy = accuracy_score(ytest, y_pred)

    acc.append(accuracy)
    
    # Shuffle labels and calculate shuffled accuracy
    for i in tqdm(range(it)):
        ysh_temp = np.random.permutation(ymale)
        y_pred_temp = cross_val_score(clf, x, ysh_temp, cv=cv)
        sh_acc.append(np.mean(y_pred_temp))
        sh_ages.append(np.floor(np.mean(xmale['age'])).astype(int))
    
    # Handle variables for the loop
    xmale = pd.DataFrame()
    a = a + 1

  xmale = xmale.groupby('mouse').apply(lambda x: x.sample(n=nmin, replace=False)).reset_index(level=0, drop=True)


14


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:19<00:00,  2.28it/s]
  xmale = xmale.groupby('mouse').apply(lambda x: x.sample(n=nmin, replace=False)).reset_index(level=0, drop=True)


14


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:10<00:00,  2.32it/s]
  xmale = xmale.groupby('mouse').apply(lambda x: x.sample(n=nmin, replace=False)).reset_index(level=0, drop=True)


14


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:08<00:00,  2.33it/s]
  xmale = xmale.groupby('mouse').apply(lambda x: x.sample(n=nmin, replace=False)).reset_index(level=0, drop=True)


14


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:15<00:00,  2.30it/s]
  xmale = xmale.groupby('mouse').apply(lambda x: x.sample(n=nmin, replace=False)).reset_index(level=0, drop=True)


14


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [07:19<00:00,  2.28it/s]
  xmale = xmale.groupby('mouse').apply(lambda x: x.sample(n=nmin, replace=False)).reset_index(level=0, drop=True)


14


 10%|████████████████                                                                                                                                                       | 96/1000 [00:41<06:34,  2.29it/s]

In [16]:
# create df for plotting
df_sh = pd.DataFrame()
df_sh['acc']=sh_acc
df_sh['ages'] = sh_ages

df = pd.DataFrame()
df['acc']=acc
df['ages'] = age

In [17]:
format_plots()

In [18]:
fig, ax = plt.subplots(figsize=(1.3, 1.3))
sns.pointplot(data=df_sh, x="ages", y="acc", ax=ax, color='grey',estimator='mean', errorbar='se',join=True,scale=0.5)
plt.setp(ax.collections, alpha=.3) #for the markers
plt.setp(ax.lines, alpha=.3)       #for the lines
sns.pointplot(data=df, x="ages", y="acc",
              dodge=0, 
              join=True, 
              ax=ax, 
              scale=0.5,
              hue='ages',
              palette=pl,
              #edgecolor='gray',
              #linewidth=0.3
             )
plt.legend([],[], frameon=False)
#plt.xticks([])
ax.set_ylim([0,1])

# Set x-axis ticks to show only every 10 ages
tick_positions = range(0, len(df['ages']), 6)
tick_labels = [df['ages'].iloc[i] for i in tick_positions]
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)

sns.despine()
#ax.set_title('identity decoding in different ages bin=' + str(n))
c = PlotConfig()
fig.savefig(c.dana_save_path / "fig5"/ 'long_id_decoder_last_data_points.pdf')

In [19]:
fig, ax = plt.subplots(figsize=(1.3, 1.3))
# Convert coefficients list to a DataFrame
coefficients_df = pd.DataFrame(coefficients_list)
sns.heatmap(coefficients_df.T, cmap='coolwarm', center=0, annot=False, fmt='.2f', vmin=-0.25, vmax=0.25)
plt.ylabel('Syllables')
plt.xlabel('Ages')

# Set x-axis ticks to show only every 6 ages
tick_positions = range(0, len(age), 6)
tick_labels = [age[i] for i in tick_positions]
plt.xticks(tick_positions, tick_labels)

plt.show()

In [20]:
fig.savefig(c.dana_save_path / "fig5"/ 'coeef_long_id_decoder_last_data_points.pdf')