In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import matplotlib as mpl
import os
from matplotlib.lines import Line2D
from collections import Counter
import math
from sklearn.decomposition import PCA
from aging.behavior.syllables import relabel_by_usage
from tqdm import tqdm
%matplotlib inline
import warnings
warnings.simplefilter('ignore')
import random

In [4]:
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from matplotlib.gridspec import GridSpec
from collections import defaultdict
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import ShuffleSplit
from sklearn.svm import LinearSVC
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut, LeaveOneOut, KFold
from sklearn.metrics import accuracy_score

In [5]:
# plot/colors definitions
# Base directory of this notebook; `data_loc` is the sub-directory intended for
# saved figures (kept as a string ending in '/' so it can be concatenated with
# a file name).
cpath = '/n/groups/datta/win/longtogeny/code/notebooks/exploration/Dana'
data_loc=cpath+'/figs/'
# makedirs(exist_ok=True) is a no-op when the directory already exists and,
# unlike a bare os.mkdir, also creates missing parent directories instead of
# raising FileNotFoundError.
os.makedirs(data_loc, exist_ok=True)

sns.set_style('white')

In [6]:
def _plot_cm(y_true, y_pred, ax, ax_labels, title):
    '''
    draw a row-normalised confusion matrix on a given axes

    Args:
        y_true ([np.array]): array for true label
        y_pred ([np.array]): array for predicted label
        ax ([matplotlib axes]): axes the matrix is drawn on
        ax_labels ([sequence]): tick labels used for both axes
        title ([str]): title of the axes

    Returns:
        the image returned by ax.imshow (usable as a colorbar mappable)
    '''
    cm = confusion_matrix(y_true, y_pred)
    # normalise every row (true label) so its entries sum to 1
    cm = cm / cm.sum(axis=1, keepdims=True)
    im = ax.imshow(cm, cmap='magma', vmin=0, vmax=1)
    # Set the ticks on the axes that was passed in. The previous
    # plt.xticks/plt.yticks calls acted on whatever axes happened to be
    # "current", which is only `ax` when it was the most recently created one.
    tick_positions = list(range(len(ax_labels)))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(ax_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(ax_labels)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Real')
    ax.set_title(title)
    return im

def plot_cm(y_true, y_pred, y_shuffle_true, y_shuffle_pred,name):
    '''
    plot the real and the shuffled confusion matrix side by side with a shared colorbar

    Args:
        y_true ([np.array]): array for true label
        y_pred ([np.array]): array for predicted label
        y_shuffle_true ([np.array]): array for shuffled label
        y_shuffle_pred ([np.array]): array for shuffled predicted label
        name ([str]): file name stem for the saved figure; currently unused
            because the savefig call at the end is commented out
    '''
    # tick labels of both panels come from the real labels
    class_labels = np.unique(y_true)
    real_title = f'Real Accuracy {accuracy_score(y_true, y_pred):0.2f}'
    shuffle_title = f'Shuffle Accuracy {accuracy_score(y_shuffle_true, y_shuffle_pred):0.2f}'

    fig = plt.figure(figsize=(23, 10), facecolor='white')
    # two wide panels for the matrices plus a narrow one for the colorbar
    grid = GridSpec(ncols=3, nrows=1, wspace=0.1, figure=fig, width_ratios=[10, 10, 0.3])

    # left panel: real labels
    real_ax = fig.add_subplot(grid[0, 0])
    _plot_cm(y_true, y_pred, real_ax, class_labels, real_title)

    # middle panel: shuffled labels, y axis decoration removed
    shuffle_ax = fig.add_subplot(grid[0, 1])
    shuffle_im = _plot_cm(y_shuffle_true, y_shuffle_pred, shuffle_ax, class_labels, shuffle_title)
    shuffle_ax.set_ylabel('')
    shuffle_ax.set_yticklabels([])

    # right panel: colorbar driven by the shuffled-matrix image
    colorbar_ax = fig.add_subplot(grid[0, 2])
    fig.colorbar(mappable=shuffle_im, cax=colorbar_ax, label='Fraction of labels')

    fig.tight_layout()
    plt.show()
    #fig.savefig(data_loc +name+'.pdf', bbox_inches='tight')

In [7]:
## for males

In [8]:
## load the usage data frames from parquet
# NOTE(review): the original comment here said "females", but the preceding
# cell is titled "for males" and the variables are named mdf/xmale -- confirm
# which cohort these files contain.
# mdf_all is read from the "..._mtx_all" file, mdf from the "..._mtx_most_used" file.
path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_02/longtogeny_musages_mtx_all.parquet')
mdf_all = pd.read_parquet(path)
path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_02/longtogeny_musages_mtx_most_used.parquet')
mdf = pd.read_parquet(path)

In [9]:
## filter out bad days -
# labels to remove from mdf; mdf itself is left untouched
bad_session = [221, 228]   # removed from index level 0
#bad_syllable=[44,89]
bad_syllable = [44]        # removed from the columns
bad_mouse = '04_01'        # removed from index level 1

# each drop returns a new frame, so the chain is equivalent to copying mdf
# and dropping in place
new_df = (
    mdf
    .drop(bad_session, level=0, axis=0)
    .drop(bad_mouse, level=1, axis=0)
    .drop(bad_syllable, axis=1)
)

In [14]:
# Average all rows that share the same (mouse, age, cage) and turn those keys
# back into ordinary columns.
# NOTE(review): this aggregates the unfiltered `mdf`; the filtered `new_df`
# built in the previous cell is not used here -- confirm whether the bad
# sessions / mouse / syllable were meant to be excluded from the decoder input.
df_indv=mdf.groupby(['mouse','age','cage']).mean().reset_index()

In [15]:
#days =df_indv['age'].to_numpy()
#df_indv['wks'] = np.floor(days/7).astype(int)
#df_indv['months'] = np.ceil(days/30).astype(int)
#df_indv.set_index(['wks','months'])

In [16]:
# working copy of the per-(mouse, age, cage) averages; df_indv stays untouched
df=df_indv.copy()

In [17]:
# display the working frame (key columns mouse/age/cage followed by one column per syllable)
df

syllables,mouse,age,cage,0,3,6,7,10,13,14,...,79,83,86,88,89,92,94,96,98,99
0,01_01,21,01,0.048960,0.020797,0.000867,0.050260,0.022964,0.005633,0.011698,...,0.006066,0.013432,0.051560,0.052860,0.039861,0.002166,0.003033,0.037262,0.003466,0.004333
1,01_01,22,01,0.042926,0.011827,0.001314,0.043802,0.013141,0.006132,0.009636,...,0.004818,0.015769,0.045116,0.065265,0.027157,0.003942,0.000876,0.034166,0.001752,0.005694
2,01_01,25,01,0.056874,0.013912,0.000409,0.060147,0.019231,0.013912,0.012275,...,0.005319,0.027823,0.036825,0.047054,0.025368,0.008592,0.000000,0.036416,0.003682,0.003273
3,01_01,26,01,0.051745,0.011633,0.000401,0.056558,0.021661,0.013237,0.011231,...,0.009226,0.027677,0.044525,0.027677,0.017649,0.008022,0.000401,0.042118,0.003209,0.007621
4,01_01,27,01,0.051433,0.010601,0.002356,0.051826,0.029839,0.021201,0.010993,...,0.004319,0.017668,0.043973,0.016490,0.019631,0.007460,0.001570,0.032980,0.006282,0.011386
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
951,04_04,460,04,0.024750,0.034745,0.015231,0.016183,0.010947,0.033793,0.011899,...,0.013803,0.006188,0.044265,0.011423,0.002380,0.003808,0.038553,0.029510,0.019990,0.038553
952,04_04,468,04,0.031191,0.040170,0.028355,0.023157,0.005671,0.027883,0.009452,...,0.009452,0.011815,0.044423,0.013705,0.005671,0.009924,0.027410,0.034499,0.017958,0.029773
953,04_04,479,04,0.029297,0.033203,0.011719,0.019043,0.008301,0.034668,0.015625,...,0.013184,0.007324,0.037109,0.009766,0.002441,0.001465,0.031738,0.027832,0.020996,0.028320
954,04_04,488,04,0.045588,0.036275,0.029412,0.018627,0.006373,0.024020,0.015686,...,0.015686,0.012745,0.043627,0.006863,0.002451,0.006863,0.033824,0.029412,0.017647,0.039706


In [18]:
## decoder

In [19]:
# choose a model: the estimator passed to cross_val_predict / cross_val_score
# in the decoder cell (active choice: linear-kernel SVM)
clf = svm.SVC(kernel='linear') 
#clf = RandomForestClassifier(n_estimators = 250)
#clf = LinearRegression()

# choose cross validation scheme (active choice: leave-one-out, i.e. every
# sample is held out once)
# NOTE(review): the decoder cell also calls cross_val_predict with this cv;
# presumably that requires each sample to be tested exactly once, which would
# rule out the ShuffleSplit alternative below -- confirm before switching.
cv = LeaveOneOut()
#cv = ShuffleSplit(n_splits=5, test_size=0.25, random_state=0)
#cv=5

#number of iterations for shuffle (label permutations evaluated per age window)
it=100

In [20]:
# Sliding-window cage decoder.
# For every window start `a`, take sessions a..a+n-1 of each mouse (rows are
# per-mouse, in the order produced by the groupby/reset_index that built
# df_indv), keep only mice that contribute n distinct ages, and decode `cage`
# from the log-transformed usage columns. For each window the real
# cross-validated accuracy is stored in `acc`/`age` and `it` label-shuffled
# accuracies in `sh_acc`/`sh_ages`.
a = 0          # index of the first session of the current window
n = 5          # sessions per mouse in each window
age = []       # floor(mean age) of every decoded window
acc = []       # mean cross-validated accuracy with the real labels
sh_acc = []    # mean cross-validated accuracy of every shuffled run
sh_ages = []   # window age matching each entry of sh_acc

# Seeded generator so the shuffle control is reproducible between runs.
rng = np.random.default_rng(0)

# Read from df_indv (df is an unmodified copy of it): a later cell rebinds the
# name `df` to the accuracy table, which would break a re-run of this cell.
mouse_frames = [mouse_df for _, mouse_df in df_indv.groupby('mouse')]

while True:
    # build the current age matrix in one concat instead of growing a frame
    xmale = pd.concat([mouse_df.iloc[a:a + n] for mouse_df in mouse_frames])
    rep_counts = xmale.groupby('mouse')['age'].nunique()
    keep_mice = list(rep_counts[rep_counts >= n].index) # keep mice that have at least n sessions in this window
    xmale = xmale[xmale['mouse'].isin(keep_mice)]
    if xmale.empty:
        break

    ymale = xmale['cage']
    if ymale.nunique() < 2:
        # A classifier cannot be fit on a single class. Later windows can only
        # contain a subset of these mice, so stop instead of raising.
        break

    # run identity decoder for this matrix
    window_age = np.floor(np.mean(xmale['age'])).astype(int)
    age.append(window_age)
    x = xmale.drop(['mouse', 'age', 'cage'], axis=1).to_numpy()
    x = np.log(x + 1e-6) # log-transform; the small offset avoids log(0) for unused syllables
    # y_pred is not used inside this loop; it is kept as a notebook global
    # holding the predictions of the most recent window.
    y_pred = cross_val_predict(clf, x, ymale, cv=cv)
    temp_acc = np.mean(cross_val_score(clf, x, ymale, cv=cv))
    acc.append(temp_acc)

    #shuffle: same decoder on permuted labels as a chance-level control
    for i in tqdm(range(it)):
        ysh_temp = rng.permutation(ymale.to_numpy())
        y_pred_temp = cross_val_score(clf, x, ysh_temp, cv=cv)
        sh_acc.append(np.mean(y_pred_temp))
        sh_ages.append(window_age)

    # advance the window by one session
    a = a + 1

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:30<00:00,  3.26it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:29<00:00,  3.35it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:27<00:00,  3.59it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:30<00:00,  3.25it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.18it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:29<00:00,  3.43it/s]
100%|█████████████████████████████████████████

In [21]:
# inspect the window frame left over from the decoder loop above
xmale

In [None]:
# create df for plotting
# one row per shuffled run: accuracy and the age label of its window
df_sh = pd.DataFrame({'acc': sh_acc, 'ages': sh_ages})

# one row per window: accuracy with the real labels and the window age
# NOTE: this rebinds `df`, which earlier held the per-(mouse, age, cage) usage
# frame; cells that expect that usage frame must not be re-run against this `df`.
df = pd.DataFrame({'acc': acc, 'ages': age})

In [None]:
fig, ax = plt.subplots(figsize=(12, 12))

# gray boxes: distribution of the shuffled-label accuracies per window age
#sns.violinplot(data=df_sh, x="ages", y="acc", ax=ax, color='gray')
sns.boxplot(x="ages", y="acc", data=df_sh, color='gray', ax=ax)

# points: accuracy obtained with the real labels, drawn on the same axes
real_point_style = dict(
    dodge=0,
    join=False,
    scale=1,
    hue='ages',
    palette='Blues',
)
sns.pointplot(x="ages", y="acc", data=df, ax=ax, **real_point_style)

# hide the legend created by `hue` and tilt the x tick labels
ax.legend([], [], frameon=False)
plt.setp(ax.get_xticklabels(), rotation=45)
ax.set_title('cage decoding in different ages')