In [None]:
from kinetics_modelling.config import PROCESSED_DATA_DIR
from sklearn.linear_model import LassoCV 
from sklearn.model_selection import train_test_split
import polars as pl
import numpy as np
import altair as alt
# Switch Altair to its "vegafusion" data transformer for every chart in this notebook.
# NOTE(review): presumably enabled so the full-frame charts further down are not
# blocked by Altair's default row limit — confirm against the size of `df`.
alt.data_transformers.enable("vegafusion")


DataTransformerRegistry.enable('vegafusion')

In [42]:
# Read the processed sample parquet and preview its first rows.
sample_path = PROCESSED_DATA_DIR / 'ob006-run1-sample.parquet'
df = pl.read_parquet(sample_path)
df.head()

center_qual_2,center_seq,fn,center_sx_0,center_sx_1,center_sx_2,center_sm_0,center_sm_1,center_sm_2,center_ipd_fwd_0,center_ipd_rev_0,center_pw_fwd_0,center_pw_rev_0,center_ipd_fwd_1,center_ipd_rev_1,center_pw_fwd_1,center_pw_rev_1,center_ipd_fwd_2,center_ipd_rev_2,center_pw_fwd_2,center_pw_rev_2,center_qual_0,center_qual_1
u8,str,i64,u8,u8,u8,u8,u8,u8,u16,u16,u8,u8,u16,u16,u8,u8,u16,u16,u8,u8,u8,u8
36,"""GTA""",2,1,0,0,5,6,5,21,112,13,14,49,65,9,12,6,20,7,16,39,69
36,"""TTT""",2,0,0,0,6,6,6,19,16,11,16,16,25,24,7,19,29,6,19,73,67
36,"""TAT""",2,0,0,0,6,6,6,14,16,17,31,31,58,9,11,21,27,7,19,49,38
36,"""ATT""",5,1,0,0,9,11,11,16,10,9,6,40,30,14,8,21,9,10,11,82,93
36,"""CAA""",2,0,0,0,4,3,4,2,36,16,22,8,17,3,18,30,24,17,6,50,28


In [13]:

# Feature columns for the design matrix. The order here is the column order of X,
# so keep it stable when editing.
feature_columns = [
    'fn',
    'center_sx_0', 'center_sx_1', 'center_sx_2',
    'center_sm_0', 'center_sm_1', 'center_sm_2',
    'center_ipd_fwd_0', 'center_ipd_rev_0', 'center_pw_fwd_0', 'center_pw_rev_0',
    'center_ipd_fwd_1', 'center_ipd_rev_1', 'center_pw_fwd_1', 'center_pw_rev_1',
    'center_ipd_fwd_2', 'center_ipd_rev_2', 'center_pw_fwd_2', 'center_pw_rev_2',
]
X = df.select(feature_columns).to_numpy()

# Regression target, flattened from a single-column 2-D array to 1-D.
y = df.select('center_qual_2').to_numpy().ravel()

In [37]:

# Candidate regularisation strengths: 100 evenly spaced values in [0.001, 1].
alphas = np.linspace(0.001, 1, 100)

# Hold out 20% of the rows; the seed is fixed so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=1337,
)

# 5-fold cross-validated Lasso over the alpha grid, using all available cores.
lasso = LassoCV(alphas=alphas, cv=5, n_jobs=-1, random_state=0)
lasso.fit(X_train, y_train);  # trailing ';' keeps the estimator repr out of the cell output

In [None]:
# Cross-validation error along the alpha grid for the LassoCV fitted above:
# one dotted line per fold, the fold average in black, and a dashed rule at the
# alpha that CV selected.
#
# This cell previously depended on names that are never defined in the notebook
# (`model`, `plt`, `ymin`, `ymax`, `fit_time`) and on `cv_alphas_`, which is a
# LassoLarsCV attribute; LassoCV exposes the grid as `alphas_`. It is drawn with
# altair/polars, which the notebook already imports.

# `mse_path_` has one row per alpha and one column per fold. Flatten it to long
# form; `ravel()` walks a row at a time, so each alpha is repeated once per fold.
n_alphas, n_folds = lasso.mse_path_.shape
fold_mse_df = pl.DataFrame({
    'alpha': np.repeat(lasso.alphas_, n_folds),
    'fold': np.tile(np.arange(n_folds), n_alphas),
    'mse': lasso.mse_path_.ravel(),
})
mean_mse_df = pl.DataFrame({
    'alpha': lasso.alphas_,
    'mse': lasso.mse_path_.mean(axis=-1),
})
best_alpha_df = pl.DataFrame({'alpha': [float(lasso.alpha_)]})

# Shared encodings so every layer agrees on the log-scaled x axis and axis titles.
log_alpha = alt.X('alpha:Q', scale=alt.Scale(type='log'), title='alpha')
mse_axis = alt.Y('mse:Q', title='Mean square error')

fold_lines = alt.Chart(fold_mse_df).mark_line(strokeDash=[2, 2], opacity=0.6).encode(
    log_alpha,
    mse_axis,
    alt.Color('fold:N', title='CV fold'),
)
mean_line = alt.Chart(mean_mse_df).mark_line(color='black', strokeWidth=2).encode(
    log_alpha,
    mse_axis,
)
best_alpha_rule = alt.Chart(best_alpha_df).mark_rule(color='black', strokeDash=[6, 4]).encode(
    log_alpha,
)

(fold_lines + mean_line + best_alpha_rule).properties(
    title='Mean square error on each fold: LassoCV (black: fold average, dashed rule: alpha CV)'
)

In [38]:
# Show the alpha grid stored on the fitted model (listed from largest to smallest).
lasso.alphas_

array([1.        , 0.98990909, 0.97981818, 0.96972727, 0.95963636,
       0.94954545, 0.93945455, 0.92936364, 0.91927273, 0.90918182,
       0.89909091, 0.889     , 0.87890909, 0.86881818, 0.85872727,
       0.84863636, 0.83854545, 0.82845455, 0.81836364, 0.80827273,
       0.79818182, 0.78809091, 0.778     , 0.76790909, 0.75781818,
       0.74772727, 0.73763636, 0.72754545, 0.71745455, 0.70736364,
       0.69727273, 0.68718182, 0.67709091, 0.667     , 0.65690909,
       0.64681818, 0.63672727, 0.62663636, 0.61654545, 0.60645455,
       0.59636364, 0.58627273, 0.57618182, 0.56609091, 0.556     ,
       0.54590909, 0.53581818, 0.52572727, 0.51563636, 0.50554545,
       0.49545455, 0.48536364, 0.47527273, 0.46518182, 0.45509091,
       0.445     , 0.43490909, 0.42481818, 0.41472727, 0.40463636,
       0.39454545, 0.38445455, 0.37436364, 0.36427273, 0.35418182,
       0.34409091, 0.334     , 0.32390909, 0.31381818, 0.30372727,
       0.29363636, 0.28354545, 0.27345455, 0.26336364, 0.25327

In [45]:
# Show the fitted coefficients: one value per column of X, in the order the
# feature columns were selected.
lasso.coef_

array([ 1.78499351e+00, -1.61003001e+00, -4.32538704e+00,  3.05618577e+00,
       -5.17033122e-01, -9.24857286e-01,  5.67047909e+00,  2.71775306e-02,
        9.14411639e-03, -2.46752655e-02, -3.71186534e-02,  2.56929539e-03,
        7.23575073e-03, -7.36347126e-02, -8.42157803e-02, -6.77332181e-03,
       -7.48344865e-03, -1.52259272e-01, -1.81913241e-02])

In [39]:
# Fold-averaged CV error paired with the alpha that produced it, ready for plotting.
model_df = pl.DataFrame(
    {
        'mean_mse': np.mean(lasso.mse_path_, axis=-1),
        'alphas': lasso.alphas_,
    }
)

In [40]:
# Line chart of the fold-averaged CV error against alpha.
alt.Chart(model_df).mark_line().encode(x='alphas', y='mean_mse:Q')

In [44]:
# Scatter of `center_qual_1` against `fn` over the whole frame; the points are
# semi-transparent so overlapping ones stay visible.
alt.Chart(df).mark_circle(opacity=0.4).encode(x='fn', y='center_qual_1')

In [None]:
from 