In [None]:
import torch
import delu
import os, sys
import pandas as pd
import lightgbm as lgb

import importlib
import tuned_model_predictions
importlib.reload(tuned_model_predictions)
from tuned_model_predictions import *

import tomli
from models import MLP, ResNet
from models import Model as TabR
from typing import Optional, Union, Dict, Tuple
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error


In [None]:
from sklearn.linear_model import Ridge, Lasso
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split, cross_val_score
from scipy.stats import pearsonr
from sklearn import model_selection
from sklearn.linear_model import RidgeCV, LassoCV, ElasticNet, ElasticNetCV, Lasso
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import (KFold, StratifiedKFold)

In [None]:
# Install the torch 1.13.1 ROCm 5.2 wheel (cp310) from pypiserver.app.finngen.fi into the user site (--user).
# NOTE(review): `torch` is already imported in the first cell of this notebook; on a fresh
# kernel this install presumably has to run before that import — confirm the intended cell order.
%pip --trusted-host pypiserver.app.finngen.fi install --default-timeout=60 https://pypiserver.app.finngen.fi/packages/python3/torch-1.13.1+rocm5.2-cp310-cp310-linux_x86_64.whl --user

In [None]:
# Install delu pinned to 0.0.15 from the pypiserver.app.finngen.fi index into the user site (--user).
# NOTE(review): `delu` is imported in the first cell, i.e. before this install cell — confirm cell order.
%pip --trusted-host pypiserver.app.finngen.fi install -i https://pypiserver.app.finngen.fi/simple/ delu==0.0.15 --user

In [None]:
# Install libzero pinned to 0.0.3.dev7 from the pypiserver.app.finngen.fi index into the user site (--user).
%pip --trusted-host pypiserver.app.finngen.fi install -i https://pypiserver.app.finngen.fi/simple/ libzero==0.0.3.dev7 --user

In [None]:
# Install faiss from the pypiserver.app.finngen.fi index into the user site (--user).
# NOTE(review): unlike the other install cells this one has no version pin — consider pinning for reproducibility.
%pip --trusted-host pypiserver.app.finngen.fi install -i https://pypiserver.app.finngen.fi/simple/ faiss --user

In [None]:
# Raw third-batch Olink NPX export; the file is semicolon-separated.
# NOTE(review): the two cells that follow also assign `finngen_olink`, so whichever of the
# three load cells ran last determines the data used downstream — confirm the intended source.
file_path = '/finngen/library-red/EA5/proteomics/olink/third_batch/original_data/Q-05765_Rodosthenous_NPX_2023-06-09.csv'
finngen_olink = pd.read_csv(file_path, sep=';')

In [None]:
# QC'd third-batch proteomics table (tab-separated); overwrites any previously loaded `finngen_olink`.
file_path = '/finngen/library-red/EA5/proteomics/olink/third_batch/QCd/proteomics_QC_all.txt'
finngen_olink = pd.read_csv(file_path, delimiter='\t')

In [None]:
# Imputed Olink table (comma-separated CSV); overwrites any previously loaded `finngen_olink`.
file_path = '/home/ivm/Documents/olink_imputed_jan_25_2024.csv'
finngen_olink = pd.read_csv(file_path, sep=',')

In [None]:
import gzip

# FinnGen R12 minimum-extended phenotype table: gzipped, tab-separated text.
covars_path = '/finngen/library-red/finngen_R12/phenotype_1.0/data/finngen_R12_minimum_extended_1.0.txt.gz'
with gzip.open(covars_path, mode='rt') as f:
    covars = pd.read_csv(f, sep='\t')

# Rename the ID column to 'FID', the key used when merging with the Olink table.
covars = covars.rename(columns={'FINNGENID': 'FID'})

In [None]:
# Covariates shipped alongside the QC'd third-batch proteomics data (tab-separated).
file_path = '/finngen/library-red/EA5/proteomics/olink/third_batch/QCd/covars.txt'
covars2 = pd.read_csv(file_path, delimiter='\t')

In [None]:
# Show which columns the QC'd covariate table provides.
covars2.columns

In [None]:
# Inner-join the Olink table with the phenotype covariates on FID (rows without a match
# on both sides are dropped), then index the result by FID.
finngen_all = finngen_olink.merge(covars, on='FID', how='inner').set_index('FID')

In [None]:
# Number of rows per SEX value in the QC'd covariate table.
covars2['SEX'].value_counts()

In [None]:
# finngen_all['BMI'].describe()
# Number of rows per DEATH value in the merged table.
# NOTE(review): a later cell rebuilds `finngen_all` from `finngen_olink` plus only
# covars2[['FID', 'FU_END_AGE']]; after that cell runs, 'DEATH' is presumably no longer
# present — verify this cell is meant to run against the first merge.
finngen_all['DEATH'].value_counts()

In [None]:
# Rebuild `finngen_all`: attach FU_END_AGE from the QC'd covariates to the Olink table
# (inner join on FID) and index the result by FID. This replaces the earlier merge.
age_covars = covars2[['FID', 'FU_END_AGE']]
finngen_all = finngen_olink.merge(age_covars, on='FID', how='inner').set_index('FID')

In [None]:
import datetime as dt 

# Today's date as a 'YYYY-MM-DD' string; used to stamp output file names.
now = dt.datetime.now().strftime('%Y-%m-%d')

# Protein names

In [None]:
# Protein panel: header-less CSV whose first column holds one protein name per row.
list_path = '/finngen/green/austina/olink_names_oct_30_2023.csv'
olink_names = pd.read_csv(list_path, header=None)[0].tolist()

# Proteins excluded from the panel, grouped by the reason for exclusion.
remove_prots = [
    # proteins in UKB but not CKB
    'HLA_A', 'ERVV_1',

    # proteins in CKB but not UKB
    'CD97', 'FGFR1OP', 'LRMP', 'CASC4', 'DARS', 'HARS', 'WISP2', 'FOPNL', 'WISP1',

    # proteins not in FinnGen
    'EDEM2', 'EP300', 'CGA', 'CDHR1', 'CPLX2', 'CLSTN1', 'IFIT1',
    'FGF3', 'TAGLN3', 'YAP1', 'ADIPOQ', 'BCL2L11', 'BMP6', 'BID',
    'SH3GL3', 'ARL13B', 'ANGPTL7', 'MGLL', 'MPI', 'MAGEA3', 'KCNH2',

    # proteins missing > 20%
    'GLIPR1', 'NPM1', 'PCOLCE',
]

# Keep the original ordering of the panel, dropping every excluded protein.
excluded_prots = frozenset(remove_prots)
olink_names = [p for p in olink_names if p not in excluded_prots]

In [None]:
from sklearn.preprocessing import MinMaxScaler

# Normalise the FinnGen proteins: min-max scale each protein on its own, then centre it
# on its median so every protein column has median zero.
scaler = MinMaxScaler()

# Work on a copy so `finngen_all` keeps the untransformed values.
olink_normalized = finngen_all.copy()

for protein in olink_names:
    # The scaler is re-fitted for every protein; sklearn expects a 2-D (n_samples, 1) input.
    value = scaler.fit_transform(olink_normalized[protein].values.reshape(-1, 1))
    olink_normalized[protein] = value

    # Median-centre the scaled column (pandas' median skips missing values).
    median = olink_normalized[protein].median()
    olink_normalized[protein] -= median


# Load models

In [None]:
import lightgbm as lgb
import pickle

def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    Unpickling can execute arbitrary code, so only use this on files from a trusted location.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Tuned models saved from earlier training runs
lasso_tuned = _load_pickle("/finngen/green/austina/lasso_model_2023-11-28.p")
elastic_net_tuned = _load_pickle("/finngen/green/austina/elastic_net_model_2023-11-28.p")
lgbm_model = _load_pickle("/finngen/green/austina/pAge_UKB_3k_dart_minMax_pre_boruta_2023-12-23.p")

# Data dictionary used to train the LightGBM model on the server
# (read below via its 'X_train', 'y_train', 'X_test' and 'y_test' entries)
server_data = _load_pickle('/finngen/green/austina/UKB_data_dict_dart_2023-12-23.p')

# Load CKB data

In [None]:
# Read the CKB Olink table and index it by `csid`.
data_path = '/finngen/green/austina/ckb_coded_olink_oct_17_2023.csv'
ckb_data = pd.read_csv(data_path).set_index('csid')

# Same normalisation as for the other cohorts: min-max scale each protein on its own,
# then subtract its median (NaN-aware) so the column is centred at zero.
scaler = MinMaxScaler()
ckb_normalized = ckb_data.copy()

for protein in olink_names:
    raw_column = ckb_normalized[protein].values.reshape(-1, 1)
    val = scaler.fit_transform(raw_column)
    median = np.nanmedian(val)
    ckb_normalized[protein] = val - median

# Expose recruitment age under the column name used for UKB.
ckb_normalized['age_granular'] = ckb_normalized['recruitment_age'].copy()

# Two views of the normalised data: every participant, and only the rows flagged
# as members of the `olinkexp1536_chd_b1_subcohort` (flag == 1).
ckb_data_all = ckb_normalized.copy()
in_subcohort = ckb_normalized['olinkexp1536_chd_b1_subcohort'] == 1
ckb_data_random = ckb_normalized.loc[in_subcohort].copy()

# All models

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.gridspec as gridspec

from scipy.stats import pearsonr
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from matplotlib.ticker import FuncFormatter

import importlib
import tuned_model_predictions
importlib.reload(tuned_model_predictions)
from tuned_model_predictions import *

import torch
import tomli
import models
importlib.reload(models)
from models import MLP, ResNet
from models import Model as TabR
from typing import Optional, Union, Dict, Tuple
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import delu

# Benchmark figure: one row per model, one column per cohort (UKB test split, CKB, FinnGen).
# Every panel is a predicted-vs-observed age scatter with a regression line and
# r / R² / RMSE / MAE printed in the top-left corner.

# Colour palette; every model row uses a fixed index into it (see `model_rows`).
pal = sns.color_palette('tab10')
sns.set_style("ticks")

# 6 x 3 grid of equally sized panels
fig = plt.figure(figsize=(12, 18))
gs = gridspec.GridSpec(6, 3, figure=fig, width_ratios=[1, 1, 1], height_ratios=[1, 1, 1, 1, 1, 1])

# CKB evaluation set (swap `ckb_data_random` for `ckb_data_all` to evaluate on all of CKB)
ckb_y = ckb_data_random['age_granular']
ckb_X = ckb_data_random[olink_names]

# FinnGen evaluation set
FinnGen_X = olink_normalized[olink_names].copy()
FinnGen_y = olink_normalized['FU_END_AGE'].copy()

# Checkpoints of the tuned neural networks
random_seed = 3456
tuned_model_paths = {
    'MLP': 'ckpts/mlp_checkpoint.pt',
    'ResNet': 'ckpts/resnet_checkpoint.pt',
    'TabR': 'ckpts/tabr_checkpoint.pt'
}
cpu_device = torch.device('cpu')


def _nn_inputs(X_eval):
    """Build the input dict passed to `nn_predict` for one evaluation matrix.

    The UKB training split is always supplied alongside the evaluation matrix, and both
    are restricted to `olink_names` so the feature columns are in the same order.

    Parameters
    ----------
    X_eval : pandas.DataFrame
        Evaluation features; must contain every column in `olink_names`.

    Returns
    -------
    dict
        Keys 'X_test', 'y_test' (always None), 'X_train' and 'y_train'.
    """
    return {
        'X_test': X_eval[olink_names].values,
        'y_test': None,
        'X_train': server_data['X_train'][olink_names].values,
        'y_train': server_data['y_train'].values,
    }


def _plot_benchmark_panel(ax, y_true, y_pred, title, color, xlim, ylim, letter=None):
    """Draw one predicted-vs-observed panel on `ax` and return its metrics.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    y_true, y_pred : array-like
        Observed ages and model predictions.
    title : str
        Centred panel title.
    color : colour spec
        Colour of the scatter points (the regression line is always red).
    xlim, ylim : tuple
        Axis limits.
    letter : str, optional
        Bold panel letter placed at the top left; omitted when None.

    Returns
    -------
    dict
        Pearson 'r', 'r2', 'rmse' and 'mae' of `y_pred` against `y_true`.
    """
    r, _pvalue = pearsonr(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    # sqrt(MSE) rather than `squared=False`, which newer scikit-learn releases no longer accept
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)

    sns.regplot(
        x=y_true,
        y=y_pred,
        ax=ax,
        scatter_kws=dict(color=color, s=10, alpha=0.8),
        line_kws=dict(color='red')
    )

    if letter is not None:
        ax.set_title(letter, fontweight='bold', fontsize=22, loc='left')
    ax.set_title(title)

    annotation_text = f'r = {r:.4f}\nR² = {r2:.4f}\nRMSE = {rmse:.4f}\nMAE = {mae:.4f}'
    ax.text(.05, .95, annotation_text, ha='left', va='top', transform=ax.transAxes)

    # NOTE(review): the FinnGen column plots FU_END_AGE on the x-axis but carries the same
    # "Age at recruitment" label as the other cohorts — confirm the label is intended.
    ax.set(xlabel='Age at recruitment (years)', ylabel='ProtAge')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return {'r': r, 'r2': r2, 'rmse': rmse, 'mae': mae}


# --- predictions: UKB test split
data = _nn_inputs(server_data['X_test'])
predictions_lasso = lasso_tuned.predict(server_data['X_test'])
predictions_elastic_net = elastic_net_tuned.predict(server_data['X_test'])
predictions_lgbm = lgbm_model.predict(server_data['X_test'])
predictions_resnet = nn_predict(tuned_model_paths['ResNet'], data, device=cpu_device)
predictions_mlp = nn_predict(tuned_model_paths['MLP'], data, device=cpu_device)
predictions_tabr = nn_predict(tuned_model_paths['TabR'], data, device=cpu_device).flatten()

# --- predictions: CKB
data_ckb = _nn_inputs(ckb_X)
predictions_lasso_ckb = lasso_tuned.predict(ckb_X)
predictions_enet_ckb = elastic_net_tuned.predict(ckb_X)
predictions_lgbm_ckb = lgbm_model.predict(ckb_X)
predictions_resnet_ckb = nn_predict(tuned_model_paths['ResNet'], data_ckb, device=cpu_device)
predictions_mlp_ckb = nn_predict(tuned_model_paths['MLP'], data_ckb, device=cpu_device)
predictions_tabr_ckb = nn_predict(tuned_model_paths['TabR'], data_ckb, device=cpu_device).flatten()

# --- predictions: FinnGen
data_fg = _nn_inputs(FinnGen_X)
predictions_lasso_fg = lasso_tuned.predict(FinnGen_X)
predictions_enet_fg = elastic_net_tuned.predict(FinnGen_X)
predictions_lgbm_fg = lgbm_model.predict(FinnGen_X)
predictions_resnet_fg = nn_predict(tuned_model_paths['ResNet'], data_fg, device=cpu_device)
predictions_mlp_fg = nn_predict(tuned_model_paths['MLP'], data_fg, device=cpu_device)
predictions_tabr_fg = nn_predict(tuned_model_paths['TabR'], data_fg, device=cpu_device).flatten()

# --- figure layout
# Rows, top to bottom: (model label, index into `pal`)
model_rows = [
    ('LASSO', 4),
    ('Elastic Net', 5),
    ('LightGBM', 0),
    ('ResNet', 1),
    ('MLP', 2),
    ('TabR', 3),
]

# Columns, left to right:
# (cohort label, panel letter, observed age, predictions in `model_rows` order, xlim, ylim)
cohort_columns = [
    (
        'UKB', 'a', server_data['y_test'],
        [predictions_lasso, predictions_elastic_net, predictions_lgbm,
         predictions_resnet, predictions_mlp, predictions_tabr],
        (38, 72), (32, 82),
    ),
    (
        'CKB', 'b', ckb_y,
        [predictions_lasso_ckb, predictions_enet_ckb, predictions_lgbm_ckb,
         predictions_resnet_ckb, predictions_mlp_ckb, predictions_tabr_ckb],
        (28, 82), (20, 90),
    ),
    (
        'FinnGen', 'c', FinnGen_y,
        [predictions_lasso_fg, predictions_enet_fg, predictions_lgbm_fg,
         predictions_resnet_fg, predictions_mlp_fg, predictions_tabr_fg],
        (15, 82), (0, 110),
    ),
]

metric_rows = []
for col, (cohort, letter, y_true, cohort_predictions, xlim, ylim) in enumerate(cohort_columns):
    for row, ((model_label, colour_idx), y_pred) in enumerate(zip(model_rows, cohort_predictions)):
        ax = fig.add_subplot(gs[row, col], aspect='auto')
        panel_metrics = _plot_benchmark_panel(
            ax,
            y_true,
            y_pred,
            title=f'{model_label} - {cohort}',
            color=pal[colour_idx],
            xlim=xlim,
            ylim=ylim,
            # only the top panel of each column carries the bold letter
            letter=letter if row == 0 else None,
        )
        metric_rows.append({'cohort': cohort, 'model': model_label, **panel_metrics})

# One row per panel, for inspecting the numbers without reading them off the figure
benchmark_metrics = pd.DataFrame(metric_rows)

### final steps
fig.tight_layout()

# Save to disk and close the figure so it is not also rendered inline
name = f'/home/ivm/Documents/model_benchmark_plots_with_NN_{now}.png'
fig.savefig(
    name,
    dpi=600,
    facecolor='white',
    transparent=False,
    bbox_inches="tight"
)
plt.close(fig)

# End

In [None]:
# Summary statistics of the LASSO predictions for FinnGen (computed in the benchmark-figure cell).
pd.Series(predictions_lasso_fg).describe()

In [None]:
# Load the pickled regression plots referred to as UKB and CKB (file names
# follow the pattern regplot_<cohort>_<model>.pkl).
import pickle

# Imported before unpickling; presumably required so the pickled matplotlib
# objects can be restored - TODO confirm.
import matplotlib.backend_bases

# SECURITY: pickle.load can execute arbitrary code. Only load files from this
# trusted project directory.
_PLOT_DIR = '/finngen/green/austina'


def _load_pickled_plot(cohort, model_name, plot_dir=_PLOT_DIR):
    """Unpickle and return the object stored in regplot_<cohort>_<model_name>.pkl.

    Parameters
    ----------
    cohort : str
        Cohort tag used in the file name (e.g. 'ukb', 'ckb').
    model_name : str
        Model tag used in the file name (e.g. 'lasso', 'enet', 'lgbm').
    plot_dir : str
        Directory containing the pickle files.

    Raises
    ------
    FileNotFoundError
        If the expected pickle file does not exist.
    """
    with open(f'{plot_dir}/regplot_{cohort}_{model_name}.pkl', 'rb') as f:
        return pickle.load(f)


plot_ukb_lasso = _load_pickled_plot('ukb', 'lasso')
plot_ukb_enet = _load_pickled_plot('ukb', 'enet')
plot_ukb_lgbm = _load_pickled_plot('ukb', 'lgbm')

plot_ckb_lasso = _load_pickled_plot('ckb', 'lasso')
plot_ckb_enet = _load_pickled_plot('ckb', 'enet')
plot_ckb_lgbm = _load_pickled_plot('ckb', 'lgbm')

In [None]:
from scipy.stats import pearsonr
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error

# Score `lgbm_ukb` on the FinnGen Olink table and plot predicted age (ProtAge)
# against the recorded age column FU_END_AGE.
# NOTE(review): `lgbm_ukb`, `olink_normalized` and `olink_names` are defined
# outside this cell - make sure those cells have run first.
subset = olink_normalized.copy()
# subset = subset[(subset['FU_END_AGE'] > 30) & (subset['FU_END_AGE'] < 80)]

X = subset[olink_names]
age_observed = subset['FU_END_AGE']

y_pred = lgbm_ukb.predict(X)
r, pvalue = pearsonr(age_observed, y_pred)
r2 = r2_score(age_observed, y_pred)
# RMSE as sqrt(MSE): the `squared=False` keyword of mean_squared_error was
# deprecated in scikit-learn 1.4 and removed in 1.6.
rmse = np.sqrt(mean_squared_error(age_observed, y_pred))
mae = mean_absolute_error(age_observed, y_pred)

# y_pred = model.predict(ckb_normalized[olink_names])
# r, pvalue = pearsonr(ckb_normalized['recruitment_age'], y_pred)
# r2 = r2_score(ckb_normalized['recruitment_age'], y_pred)

# Draw on an explicit Axes instead of relying on pyplot's "current axes".
fig, ax = plt.subplots()

# plot predicted age against age
regplot = sns.regplot(
    x=age_observed,
    y=y_pred,
    scatter_kws=dict(color='orange', s=10, alpha=0.8),
    line_kws=dict(color='red'),
    ax=ax
)

# add annotation (axes-fraction coordinates: top-left metrics, bottom-right label)
annotation_text = f'r = {r:.2f}\nR² = {r2:.2f}\nRMSE = {rmse:.2f}\nMAE = {mae:.2f}'
ax.text(.05, .95, annotation_text, ha='left', va='top', transform=ax.transAxes)
ax.text(.95, .1, 'LightGBM - FinnGen', ha='right', va='top', transform=ax.transAxes)

regplot.set(xlabel='Age at recruitment (years)', ylabel='ProtAge')

plt.show()

# # Save the regplot object using pickle
# # with open('/finngen/red/austin_FinnGen_lgbm_plot.pkl', 'wb') as file:
# with open('/home/ivm/Documents/austin_FinnGen_lgbm_plot.pkl', 'wb') as file:
#     pickle.dump(regplot, file)

# fig.savefig(
# #     '/finngen/red/austin_FinnGen_protAge_dec_5_2023.png',
#     '/home/ivm/Documents/austin_FinnGen_protAge_dec_5_2023.png',
#     dpi=600,
#     facecolor='white',
#     transparent=False,
#     bbox_inches="tight"
# )
# plt.close(fig)

In [None]:
import pickle

import numpy as np

# Load the pickled LASSO model and score it on the FinnGen Olink table.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
path = '/finngen/green/austina/lasso_model_2023-11-28.p'
with open(path, 'rb') as f:
    lasso_ukb = pickle.load(f)

# NOTE(review): `olink_normalized` and `olink_names` are defined outside this cell.
subset = olink_normalized.copy()
# subset = subset[(subset['FU_END_AGE'] > 30) & (subset['FU_END_AGE'] < 80)]

X = subset[olink_names]
age_observed = subset['FU_END_AGE']

y_pred = lasso_ukb.predict(X)
r, pvalue = pearsonr(age_observed, y_pred)
r2 = r2_score(age_observed, y_pred)
# RMSE as sqrt(MSE): the `squared=False` keyword of mean_squared_error was
# deprecated in scikit-learn 1.4 and removed in 1.6.
rmse = np.sqrt(mean_squared_error(age_observed, y_pred))
mae = mean_absolute_error(age_observed, y_pred)

# y_pred = model.predict(ckb_normalized[olink_names])
# r, pvalue = pearsonr(ckb_normalized['recruitment_age'], y_pred)
# r2 = r2_score(ckb_normalized['recruitment_age'], y_pred)

# Draw on an explicit Axes instead of relying on pyplot's "current axes".
fig, ax = plt.subplots()

# plot predicted age against age
regplot = sns.regplot(
    x=age_observed,
    y=y_pred,
    scatter_kws=dict(color='orange', s=10, alpha=0.8),
    line_kws=dict(color='red'),
    ax=ax
)

# add annotation (axes-fraction coordinates: top-left metrics, bottom-right label)
annotation_text = f'r = {r:.2f}\nR² = {r2:.2f}\nRMSE = {rmse:.2f}\nMAE = {mae:.2f}'
ax.text(.05, .95, annotation_text, ha='left', va='top', transform=ax.transAxes)
ax.text(.95, .1, 'LASSO - FinnGen', ha='right', va='top', transform=ax.transAxes)

regplot.set(xlabel='Age at recruitment (years)', ylabel='ProtAge')

plt.show()

# # Save the regplot object using pickle
# with open('/finngen/green/austina/FinnGen_lasso_plot.pkl', 'wb') as file:
#     pickle.dump(regplot, file)

# fig.savefig(
#     '/finngen/green/austina/FinnGen_protAge_lasso_dec_5_2023.png',
#     dpi=600,
#     facecolor='white',
#     transparent=False,
#     bbox_inches="tight"
# )
# plt.close(fig)

In [None]:
# Number of rows in the binned-average frame.
# NOTE(review): within this part of the notebook `df_rolled` is only assigned in
# later cells, so on a fresh kernel this depends on an earlier definition - confirm.
df_rolled.shape[0]

In [None]:
import pickle

import numpy as np

# Load the pickled elastic-net model, score it on the FinnGen Olink table and
# plot binned averages of predicted age (ProtAge) against recorded age.
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
path = '/finngen/green/austina/elastic_net_model_2023-11-28.p'
with open(path, 'rb') as f:
    enet_ukb = pickle.load(f)

# NOTE(review): `olink_normalized` and `olink_names` are defined outside this cell.
subset = olink_normalized.copy()
# subset = subset[(subset['FU_END_AGE'] > 30) & (subset['FU_END_AGE'] < 80)]

X = subset[olink_names]
age_observed = subset['FU_END_AGE']

y_pred = enet_ukb.predict(X)
r, pvalue = pearsonr(age_observed, y_pred)
r2 = r2_score(age_observed, y_pred)
# RMSE as sqrt(MSE): the `squared=False` keyword of mean_squared_error was
# deprecated in scikit-learn 1.4 and removed in 1.6.
rmse = np.sqrt(mean_squared_error(age_observed, y_pred))
mae = mean_absolute_error(age_observed, y_pred)

# Create a DataFrame for seaborn
data = pd.DataFrame({'Age at recruitment (years)': age_observed, 'ProtAge': y_pred})

# Average every PROTAGE_BIN_SIZE consecutive samples after sorting by ProtAge.
# These are non-overlapping bins, not a rolling window. drop=True keeps the old
# index out of the frame so it is not averaged as if it were a measurement.
PROTAGE_BIN_SIZE = 5
data_avg = data.sort_values("ProtAge").reset_index(drop=True)
df_rolled = data_avg.groupby(data_avg.index // PROTAGE_BIN_SIZE).mean()

# Draw on an explicit Axes instead of relying on pyplot's "current axes".
# (KDE-contour and hexbin renderings of the raw points were tried previously;
# the binned scatter below is the variant that is kept.)
fig, ax = plt.subplots()

regplot = sns.regplot(
    data=df_rolled,
    x='Age at recruitment (years)',
    y='ProtAge',
    scatter_kws=dict(color='orange', s=20, alpha=0.8),
    line_kws=dict(color='red'),
    ax=ax
)

# Add annotation. The metrics are computed on the per-sample predictions, not on
# the binned averages that are drawn. The axes transform comes from the regplot
# axes: the former `contourplot` handle is no longer created in this cell, so
# referencing it raised NameError.
annotation_text = f'r = {r:.4f}\nR² = {r2:.4f}\nRMSE = {rmse:.4f}\nMAE = {mae:.4f}'
ax.text(.05, .95, annotation_text, ha='left', va='top', transform=ax.transAxes)
ax.text(.95, .1, 'Elastic net - FinnGen', ha='right', va='top', transform=ax.transAxes)

# Show the plot
plt.show()

# # Save the regplot object using pickle
# # (file names changed from the LASSO ones so enabling this does not overwrite
# #  the LASSO outputs)
# with open('/finngen/green/austina/FinnGen_enet_plot.pkl', 'wb') as file:
#     pickle.dump(regplot, file)

# fig.savefig(
#     '/finngen/green/austina/FinnGen_protAge_enet_dec_5_2023.png',
#     dpi=600,
#     facecolor='white',
#     transparent=False,
#     bbox_inches="tight"
# )
# plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import matplotlib.gridspec as gridspec

from scipy.stats import pearsonr
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from matplotlib.ticker import FuncFormatter

# Three stacked panels (a: LASSO, b: Elastic net, c: LightGBM), each showing
# binned averages of predicted age (ProtAge) against recorded age on FinnGen,
# annotated with per-sample r, R², RMSE and MAE. The figure is written to disk.
# NOTE(review): `lasso_ukb`, `enet_ukb`, `lgbm_ukb`, `olink_normalized` and
# `olink_names` are defined outside this cell - make sure those cells have run.

sns.set_style("ticks")

# Number of consecutive (ProtAge-sorted) samples averaged into one plotted point.
PROTAGE_BIN_SIZE = 5

# Create a new figure and specify the layout using gridspec
fig = plt.figure(figsize=(3.8, 10))
# fig = plt.figure(figsize=(11, 10))

gs = gridspec.GridSpec(3, 1, figure=fig, width_ratios=[1], height_ratios=[2, 2, 2])
# gs = gridspec.GridSpec(3, 3, figure=fig, width_ratios=[1, 1, 1], height_ratios=[2, 2, 2])

subset = olink_normalized.copy()
# subset = subset[(subset['FU_END_AGE'] > 30) & (subset['FU_END_AGE'] < 80)]

X = subset[olink_names]
age_observed = subset['FU_END_AGE']

# (panel letter, bottom-right label, fitted model), top to bottom.
comparison_panels = [
    ('a', 'LASSO - FinnGen', lasso_ukb),
    ('b', 'Elastic net - FinnGen', enet_ukb),
    ('c', 'LightGBM - FinnGen', lgbm_ukb),
]

# The loop runs at cell level (not inside a function) so that y_pred, r, r2,
# rmse, mae, data, df_rolled, regplot and ax remain available afterwards,
# holding the values of the last panel.
for panel_row, (panel_letter, panel_label, panel_model) in enumerate(comparison_panels):
    ax = fig.add_subplot(gs[panel_row, 0], aspect='auto')
    ax.set_title(panel_letter, fontweight='bold', loc='left')

    y_pred = panel_model.predict(X)
    r, pvalue = pearsonr(age_observed, y_pred)
    r2 = r2_score(age_observed, y_pred)
    # RMSE as sqrt(MSE): the `squared=False` keyword of mean_squared_error was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    rmse = np.sqrt(mean_squared_error(age_observed, y_pred))
    mae = mean_absolute_error(age_observed, y_pred)

    # Create a DataFrame for seaborn
    data = pd.DataFrame({'Age at recruitment (years)': age_observed, 'ProtAge': y_pred})

    # Average every PROTAGE_BIN_SIZE consecutive samples after sorting by
    # ProtAge (non-overlapping bins). drop=True keeps the old index out of the
    # frame so it is not averaged as if it were a measurement.
    data_avg = data.sort_values("ProtAge").reset_index(drop=True)
    df_rolled = data_avg.groupby(data_avg.index // PROTAGE_BIN_SIZE).mean()

    # plot binned predicted age against binned age on this panel's axes
    regplot = sns.regplot(
        data=df_rolled,
        x='Age at recruitment (years)',
        y='ProtAge',
        scatter_kws=dict(color='orange', s=30, alpha=0.8),
        line_kws=dict(color='red'),
        ax=ax
    )

    # add annotation; metrics are computed on per-sample predictions, not on the bins
    annotation_text = f'r = {r:.2f}\nR² = {r2:.2f}\nRMSE = {rmse:.2f}\nMAE = {mae:.2f}'
    ax.text(.05, .95, annotation_text, ha='left', va='top', transform=ax.transAxes)
    ax.text(.95, .1, panel_label, ha='right', va='top', transform=ax.transAxes)

    regplot.set(xlabel='Age at recruitment (years)', ylabel='ProtAge')


### final steps
# Adjust the layout
fig.tight_layout()

# Show the plot
# plt.show()

# save
name = '/home/ivm/Documents/finngen_model_comparisons_averaged.png'
fig.savefig(
    name,
    dpi=600,
    facecolor='white',
    transparent=False,
    bbox_inches="tight"
)
plt.close(fig)