In [59]:
import itertools

import numpy as np
import pandas as pd
import xarray as xr
import seaborn as sns  
import matplotlib.pyplot as plt

from scipy.stats import pearsonr
from sklearn.model_selection import KFold
from sklearn.linear_model import ElasticNet, Lasso
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error, root_mean_squared_error

## Data Loading

In [9]:
# Open the merged `.nc` dataset once; every later cell reads from `ds_hb`.
merged_dataset_path = "../data/20241227/merged-20241227.nc"
ds_hb = xr.open_dataset(merged_dataset_path)
ds_hb

In [10]:
# List the distinct device ids present in the dataset.
device_ids = ds_hb["device"].to_numpy()
np.unique(device_ids)

array([ 2,  7,  9, 11, 12, 13, 15])

## Regressor

In [94]:
def get_data(device_number, top_k=20):
    """Return one device's samples restricted to the features most correlated with ``hb``.

    Rows of the notebook-level dataset ``ds_hb`` whose ``device`` value equals
    ``device_number`` are selected. Each column of ``signal`` is correlated
    (Pearson) with ``hb`` and the ``top_k`` columns with the largest absolute
    correlation are kept.

    NOTE(review): the ranking uses *all* samples of the device. The driver cell
    in this notebook then cross-validates on the same samples, so the reported
    CV scores are optimistically biased by this selection step.

    Parameters
    ----------
    device_number : int
        Value matched against ``ds_hb["device"]``.
    top_k : int, default 20
        Number of feature columns to keep. Clamped to the number of available
        columns.

    Returns
    -------
    X_top : numpy.ndarray
        ``signal`` restricted to the selected columns.
    y : numpy.ndarray
        The ``hb`` values for the selected rows.
    topk_pearson : numpy.ndarray
        Column indices (into ``signal``) of the selected features. The order
        is whatever ``np.argpartition`` produces, not sorted by correlation.
    r_top : numpy.ndarray
        Pearson correlation of each selected column with ``hb``.

    Raises
    ------
    ValueError
        If ``top_k`` is smaller than 1, or fewer than two samples exist for
        the requested device.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    ds_hb_sel = ds_hb.sel(id=ds_hb["device"] == device_number)

    X = ds_hb_sel["signal"].to_numpy()
    y = ds_hb_sel["hb"].to_numpy()

    if X.shape[0] < 2:
        raise ValueError(
            f"device {device_number!r} has {X.shape[0]} sample(s); "
            "at least 2 are required to compute a correlation"
        )

    # y is reshaped to a column so it broadcasts against every column of X,
    # yielding one correlation coefficient per feature.
    r, _ = pearsonr(X, y.reshape(-1, 1))

    # Constant columns give r = NaN. NumPy orders NaN *after* every number, so
    # ranking |r| directly would pick those columns as the "strongest" ones.
    # Push non-finite scores to the bottom of the ranking instead.
    strength = np.abs(r)
    strength = np.where(np.isfinite(strength), strength, -np.inf)

    k = min(top_k, strength.size)
    topk_pearson = np.argpartition(strength, -k)[-k:]

    return X[:, topk_pearson], y, topk_pearson, r[topk_pearson]

In [95]:
def cross_val(model_name, Xt, yt, n_splits=5, random_state=21):
    """Run shuffled K-fold cross-validation and yield one metrics dict per fold.

    A fresh estimator with default hyper-parameters is fitted on each training
    split and evaluated on the matching held-out split.

    Parameters
    ----------
    model_name : str
        ``"elasticnet"`` or ``"lasso"``.
    Xt : numpy.ndarray
        Feature matrix, indexed by row with the fold indices.
    yt : numpy.ndarray
        Target vector, indexed with the same fold indices.
    n_splits : int, default 5
        Number of folds passed to ``KFold``.
    random_state : int, default 21
        Seed for the fold shuffling, so the folds are reproducible.

    Yields
    ------
    dict
        Keys: ``algorithm``, ``fold``, ``r`` (Pearson correlation between the
        held-out targets and predictions), ``r2``, ``mae``, ``mse``, ``rmse``,
        ``intercept`` and ``coeffs`` (the fitted model's parameters).

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names. Because this is a
        generator, the error surfaces when iteration starts.
    """
    # Reject unknown names up front: previously anything other than
    # "elasticnet" silently trained a Lasso while the results were labelled
    # with the unknown name.
    if model_name not in ("elasticnet", "lasso"):
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'elasticnet' or 'lasso'."
        )

    cv = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    for i, (train_index, test_index) in enumerate(cv.split(Xt, yt)):
        X_train, X_test = Xt[train_index], Xt[test_index]
        y_train, y_test = yt[train_index], yt[test_index]

        # New estimator per fold so no fitted state carries over between folds.
        model = ElasticNet() if model_name == "elasticnet" else Lasso()
        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)

        r, _ = pearsonr(y_test, y_pred)
        yield {
            "algorithm": model_name,
            "fold": i,
            "r": r,
            "r2": r2_score(y_test, y_pred),
            "mae": mean_absolute_error(y_test, y_pred),
            "mse": mean_squared_error(y_test, y_pred),
            "rmse": root_mean_squared_error(y_test, y_pred),
            "intercept": model.intercept_,
            "coeffs": model.coef_
        }
    

In [96]:
# Collect one record per (device, algorithm, fold).
scores = []
devices = [2, 9, 11, 13]
models = ["elasticnet", "lasso"]

for device_number in devices:
    # Feature selection depends only on the device, so run it once per device
    # rather than once per (device, model) pair.
    X, y, data_indices, pearson_values = get_data(device_number)
    for model_name in models:
        for val_result in cross_val(model_name, X, y):
            scores.append({"device_number": device_number, **val_result, "topk_pearson": data_indices, "pearson_r": pearson_values})

In [97]:
# Tabulate the per-fold records, ordered from highest to lowest Pearson r.
scores_frame = pd.DataFrame(scores)
df_scores = scores_frame.sort_values(by="r", ascending=False)
df_scores.head()

Unnamed: 0,device_number,algorithm,fold,r,r2,mae,mse,rmse,intercept,coeffs,topk_pearson,pearson_r
14,9,elasticnet,4,0.816524,0.511447,1.152253,1.804585,1.343348,66.150552,"[0.004778670920030104, -0.00561120666278867, 0...","[6778, 814, 6891, 3339, 5480, 3316, 3354, 6687...","[0.19416657181270916, -0.19756963616243337, 0...."
19,9,lasso,4,0.810068,0.467321,1.203633,1.967573,1.402702,63.389331,"[0.003000541904450331, -0.005321932838709503, ...","[6778, 814, 6891, 3339, 5480, 3316, 3354, 6687...","[0.19416657181270916, -0.19756963616243337, 0...."
33,13,elasticnet,3,0.76034,0.506919,1.048882,1.578784,1.256497,-25.822985,"[0.0035521433292702493, -0.0001092229804839676...","[5010, 4642, 6657, 1210, 5665, 6942, 5033, 469...","[0.21690185868124243, 0.21834621175592991, 0.2..."
38,13,lasso,3,0.750889,0.481489,1.079126,1.660206,1.28849,-20.816988,"[0.003553238606157261, -0.0, 0.014797739781817...","[5010, 4642, 6657, 1210, 5665, 6942, 5033, 469...","[0.21690185868124243, 0.21834621175592991, 0.2..."
35,13,lasso,0,0.736455,0.465471,1.073188,1.670759,1.292578,-19.043716,"[0.007364908554587457, -0.004180968886440611, ...","[5010, 4642, 6657, 1210, 5665, 6942, 5033, 469...","[0.21690185868124243, 0.21834621175592991, 0.2..."


In [98]:
# For each device, take the row label of the single fold with the highest Pearson r.
best_row_per_device = df_scores.groupby("device_number")["r"].idxmax()
top_model_indices = best_row_per_device.values
df_scores.loc[top_model_indices]

Unnamed: 0,device_number,algorithm,fold,r,r2,mae,mse,rmse,intercept,coeffs,topk_pearson,pearson_r
8,2,lasso,3,0.674547,0.255907,1.177012,1.702887,1.304947,36.674673,"[0.0, 0.0, -0.0, 0.0, -0.01011461971177791, -0...","[466, 186, 766, 786, 7838, 46, 6, 706, 406, 78...","[-0.3154894140623964, -0.3157128543110338, -0...."
14,9,elasticnet,4,0.816524,0.511447,1.152253,1.804585,1.343348,66.150552,"[0.004778670920030104, -0.00561120666278867, 0...","[6778, 814, 6891, 3339, 5480, 3316, 3354, 6687...","[0.19416657181270916, -0.19756963616243337, 0...."
26,11,lasso,1,0.4122,0.066134,1.495054,3.518277,1.875707,0.341381,"[0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, ...","[348, 168, 768, 248, 7782, 7873, 7827, 7824, 7...","[0.2079164428018387, 0.2096765522602744, 0.210..."
33,13,elasticnet,3,0.76034,0.506919,1.048882,1.578784,1.256497,-25.822985,"[0.0035521433292702493, -0.0001092229804839676...","[5010, 4642, 6657, 1210, 5665, 6942, 5033, 469...","[0.21690185868124243, 0.21834621175592991, 0.2..."


In [99]:
# Print each device's selected model as a linear formula over the raw signal
# array: intercept plus one `arr[index] * coefficient` term per kept feature.
for row in df_scores.loc[top_model_indices].itertuples():
    terms = [f"{row.intercept:.4f}"]
    terms.extend(
        f"arr[{i}] * {coeff:.4f}" for i, coeff in zip(row.topk_pearson, row.coeffs)
    )
    # Joining the terms avoids the trailing space that trimming only two
    # characters of the final " + " used to leave behind.
    print(f"HB-{row.device_number}: y = " + " + ".join(terms))

HB-2: y = 36.6747 + arr[466] * 0.0000 + arr[186] * 0.0000 + arr[766] * -0.0000 + arr[786] * 0.0000 + arr[7838] * -0.0101 + arr[46] * -0.0000 + arr[6] * -0.0000 + arr[706] * -0.0052 + arr[406] * 0.0000 + arr[7834] * -0.0082 + arr[866] * -0.0000 + arr[626] * 0.0000 + arr[286] * 0.0000 + arr[226] * 0.0000 + arr[446] * -0.0000 + arr[346] * -0.0000 + arr[686] * -0.0000 + arr[386] * -0.0000 + arr[366] * -0.0000 + arr[426] * -0.0138 
HB-9: y = 66.1506 + arr[6778] * 0.0048 + arr[814] * -0.0056 + arr[6891] * 0.0077 + arr[3339] * 0.0082 + arr[5480] * -0.0000 + arr[3316] * -0.0074 + arr[3354] * 0.0000 + arr[6687] * 0.0074 + arr[6618] * 0.0000 + arr[7780] * -0.0125 + arr[2387] * -0.0000 + arr[4695] * -0.0148 + arr[2385] * 0.0185 + arr[1068] * -0.0138 + arr[3345] * -0.0271 + arr[2273] * -0.0188 + arr[5064] * -0.0092 + arr[5184] * -0.0043 + arr[5284] * -0.0071 + arr[5244] * -0.0140 
HB-11: y = 0.3414 + arr[348] * 0.0000 + arr[168] * -0.0000 + arr[768] * 0.0000 + arr[248] * -0.0000 + arr[7782] * -0.0

In [100]:
# Save the selected per-device rows as JSON Lines (one record per line).
best_models = df_scores.loc[top_model_indices]
best_models.to_json("coeff.json", orient="records", lines=True)