In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import matplotlib as mpl
import os
from matplotlib.lines import Line2D
from collections import Counter
import math
from sklearn.decomposition import PCA
from aging.behavior.syllables import relabel_by_usage
from tqdm import tqdm
%matplotlib inline
import warnings
warnings.simplefilter('ignore')
import random
import scipy
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import adjusted_rand_score
from kneed import KneeLocator
from sklearn.metrics import silhouette_score
%matplotlib inline
from aging.plotting import format_plots, PlotConfig, save_factory, figure, legend, format_pizza_plots
from sklearn.linear_model import ElasticNet
from sklearn.preprocessing import MinMaxScaler
from sklearn.feature_selection import mutual_info_classif as MIC
from sklearn.feature_selection import mutual_info_regression as MIR

In [2]:
from sklearn.linear_model import LinearRegression, ElasticNet
from sklearn.preprocessing import OneHotEncoder
import statsmodels.api as sm
from statsmodels.formula.api import ols
from statsmodels.tools.tools import pinv_extended  
from statsmodels.stats.anova import anova_lm
from tqdm.auto import tqdm

In [3]:
# Apply the shared plot formatting from aging.plotting before any figure is
# drawn (presumably global matplotlib rcParams — confirm in aging.plotting).
format_plots()
# format_pizza_plots()  # other formatter from aging.plotting; disabled in this notebook

In [4]:
## update data
def filter_df_long(df, max_syllable=39, min_sessions=8):
    """Drop high-label syllable columns and rarely sampled ages from a usage matrix.

    Parameters
    ----------
    df : pd.DataFrame
        Usage matrix whose column labels are comparable to ``max_syllable``
        and whose index has a level named ``'age'``.
    max_syllable : int, default 39
        Columns whose label is greater than this value are removed.
    min_sessions : int, default 8
        Only ages that occur in more than this many rows of ``df`` are kept.

    Returns
    -------
    pd.DataFrame
        A filtered copy of ``df``; the input frame is not modified.
    """
    too_high = df.columns[df.columns.values > max_syllable]
    trimmed = df.drop(columns=too_high)

    age_counts = trimmed.index.get_level_values('age').value_counts()
    ages_kept = list(age_counts[age_counts > min_sessions].index)

    # Filter on this frame's own index (the previous version read a global
    # `data`, which only worked when the caller passed that same object).
    return trimmed.loc[trimmed.index.get_level_values('age').isin(ages_kept)]

In [5]:
# Load the relabeled longtogeny (males) syllable-usage matrix.
path = Path('/n/groups/datta/win/longtogeny/data/ontogeny/version_11/longtogeny_males_relabeled_usage_matrix_v00.parquet')
df = pd.read_parquet(path)

# Cast numeric columns to float, drop filtered syllables/ages, average rows
# sharing the same (age, uuid, mouse), and keep ages below 100.
data = df.astype(float, errors='ignore')
data = (
    filter_df_long(data)
    .groupby(['age', 'uuid', 'mouse'])
    .mean()
    .query('age<100')
)

# Independent copy with columns in sorted label order.
long_data = data.loc[:, sorted(data.columns)].copy()

In [6]:
# Optional: keep only the first recording of each week in the longtogeny
# dataset. Disabled by default; set FIRST_RECORDING_PER_WEEK = True to apply.
FIRST_RECORDING_PER_WEEK = False


def first_recording_per_week(df):
    """Keep only rows whose 'age' is the lowest-day age recorded in its week.

    Week and day are parsed from the decimal string form of each age
    (``str(age).split('.')``). NOTE(review): this assumes ages are encoded as
    ``<week>.<day>`` floats such as ``10.3`` — confirm against the data source.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with an index level named ``'age'``.

    Returns
    -------
    pd.DataFrame
        Subset of ``df`` restricted to the earliest-day age of each week.
    """
    from collections import defaultdict

    # Group the distinct ages by their week (integer part of the string).
    week_entries = defaultdict(list)
    for age in df.index.get_level_values('age').unique():
        week, _day = str(age).split('.')
        week_entries[week].append(age)

    # Within each week, pick the age with the smallest day component.
    lowest_day_entries = [
        min(entries, key=lambda a: float(str(a).split('.')[1]))
        for entries in week_entries.values()
    ]
    return df[df.index.get_level_values('age').isin(lowest_day_entries)]


if FIRST_RECORDING_PER_WEEK:
    long_data = first_recording_per_week(long_data)

In [7]:
# Log-transform usages (small pseudo-count keeps zero usages finite) and
# z-score each column of long_data across rows.
y_log = np.log(long_data.to_numpy() + 1e-6)
y_std = y_log.std(axis=0, keepdims=True)
# A column with constant values has zero std; dividing by it would produce NaN
# silently (warnings are ignored globally in this notebook) and the ElasticNet
# fit would then reject the input. Such columns are left centred at 0 instead.
y_std[y_std == 0] = 1.0
y = (y_log - y_log.mean(axis=0, keepdims=True)) / y_std  # z-score

In [8]:
# run GLM: penalized linear model (ElasticNet) predicting each z-scored column
# of y from one-hot age regressors followed by one-hot mouse-identity regressors.
def _dense_one_hot(labels):
    """One-hot encode a 1-D label array into a dense (n_samples, n_levels) array.

    scikit-learn renamed ``OneHotEncoder(sparse=...)`` to ``sparse_output`` in
    1.2 and removed the old keyword in 1.4; try the new name first and fall
    back for older installs.
    """
    try:
        encoder = OneHotEncoder(sparse_output=False)
    except TypeError:
        encoder = OneHotEncoder(sparse=False)
    return encoder.fit_transform(np.asarray(labels).reshape(-1, 1))


indv_encoder = _dense_one_hot(long_data.index.get_level_values('mouse'))
age_encoder = _dense_one_hot(long_data.index.get_level_values('age'))
x = np.concatenate((age_encoder, indv_encoder), axis=1)
lr = ElasticNet(alpha=0.01)
lr.fit(x, y)

In [9]:
# plot weights for phase and individual identity
# Heatmap of lr.coef_ (rows = targets, columns = regressors), with a coloured
# band beneath it marking the age block and the individual-identity block.
from mpl_toolkits.axes_grid1 import make_axes_locatable

n_age = age_encoder.shape[1]
n_indv = indv_encoder.shape[1]

fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(lr.coef_, cmap='RdBu_r', vmin=-2, vmax=2)
ax.set_title('Model weights')

# One-row band from y=-1.5 to y=-0.5 labelling each regressor group.
for band_x0, band_width, band_color, band_label in (
    (-0.5, n_age, '#623f99', 'age'),
    (n_age - 0.5, n_indv, 'purple', 'indv_identity'),
):
    rect = plt.Rectangle((band_x0, -1.5), band_width, 1,
                         facecolor=band_color, alpha=1, label=band_label)
    ax.add_patch(rect)

ax.axvline(n_age - 0.5, c='k', ls='--', lw=1)
ax.set_ylim(-1.5, y.shape[1] - 0.5)
ax.legend(frameon=False, loc='upper left', bbox_to_anchor=(1.25, 1))

# Colorbar in a new axes appended to the right of ax: 5% of ax's width,
# 0.05 inch padding.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(im, cax=cax)

c = PlotConfig()
fig.savefig(c.dana_save_path / "fig4" / 'GLM_weights.pdf', bbox_inches='tight')