In [1]:
# Run configuration: everything a reader may want to change lives in this cell.
output_dir = '/vol/bmd/yanyul/UKB/ptrs-tf/models'  # fitted models and the log file are written here
population = 'British'  # for test 'Chinese'
pred_expr_name = 'MESA_CAU'  # tag embedded in the log and model file names
# Derive the log path from output_dir instead of repeating the directory literal,
# so the log always lands next to the saved models when output_dir is changed.
logfile = f'{output_dir}/elastic_net_{pred_expr_name}.log'

In [2]:
import sys
sys.path.append("../code/")
import util_ElasticNet, lib_LinearAlgebra, util_hdf5, lib_ElasticNet, lib_Checker
import util_misc
import tensorflow as tf
import numpy as np
import pandas as pd
import h5py, yaml, functools
import matplotlib.pyplot as plt
from importlib import reload
# Re-execute the project helper modules (presumably the ones found via "../code/")
# so that edits made on disk take effect without restarting the kernel.
# importlib.reload returns the module object, which is re-bound to the same name.
# NOTE(review): the reload order is kept as written; it may matter if these
# modules import one another.
lib_LinearAlgebra = reload(lib_LinearAlgebra)
util_ElasticNet = reload(util_ElasticNet)
util_hdf5 = reload(util_hdf5)
lib_ElasticNet = reload(lib_ElasticNet)
lib_Checker = reload(lib_Checker)
util_misc = reload(util_misc)
import util_hdf5
import logging, sys
import seaborn as sns
# Send INFO-and-above records from the root logger to `logfile`.
# Each record is written as "<timestamp>  <message>" with a 12-hour clock and AM/PM.
# Note: logging.basicConfig does nothing if the root logger already has handlers.
log_settings = {
    'level': logging.INFO,
    # 'stream': sys.stderr,  # alternative sink; cannot be combined with 'filename'
    'filename': logfile,
    'format': '%(asctime)s  %(message)s',
    'datefmt': '%Y-%m-%d %I:%M:%S %p',
}
logging.basicConfig(**log_settings)

# Analysis overview

Building PTRS using Elastic Net. 

1. Split British data into 3 sets: training, test, validation.
2. Train a sequence of elastic net predictors along regularization path using British training data.
3. Repeat step 2 for $\alpha = 0.1, 0.5, 0.9$

Setup details: 
$\frac{\lambda_{max}}{\lambda_{min}} = 10^6$. 
The number of $\lambda$ values (`nlambda`) is 50.
The maximum number of epochs is 100.
The training batch size is roughly 1/4 of the sample size.

For the MESA dataset, we use only the genes that are present in both the EUR models and the AFR+HIS models. 

# Load data

In [3]:
# Paths to the predicted-expression HDF5 files for the selected population,
# one per MESA model family (CAU and AFHI).
mesa_cau = f'/vol/bmd/yanyul/UKB/predicted_expression_tf2/ukb_imp_x_MESA_CAU_{population}.hdf5'
mesa_afhi = f'/vol/bmd/yanyul/UKB/predicted_expression_tf2/ukb_imp_x_MESA_AFHI_{population}.hdf5'

# data scheme specifying which are traits and covariates
scheme_yaml = '../misc_files/data_scheme.yaml'
feature_dic = util_hdf5.read_yaml(scheme_yaml)

# Read everything needed from the CAU file in a single open (it was previously
# opened twice): gene names, trait/covariate column names, row count and y.
# The order of `features` is matched with the data being loaded.
with h5py.File(mesa_cau, 'r') as f:
    genes_cau = f['columns_x'][...].astype(str)
    features = f['columns_y'][:].astype('str')
    sample_size = f['y'].shape[0]
    # NOTE(review): this loads the full y matrix into memory; it is not used by the
    # cells visible here -- confirm whether later cells still need it.
    y = f['y'][:]
with h5py.File(mesa_afhi, 'r') as f:
    genes_afhi = f['columns_x'][...].astype(str)

# Indices (one set per file) of the genes that occur in both models.
x_indice_cau, x_indice_afhi = util_misc.intersect_indice(genes_cau, genes_afhi)

# Positions in `features` of the covariate and outcome (trait) columns named in the scheme.
# np.nonzero(mask)[0] is what one-argument np.where(mask)[0] is shorthand for.
covar_indice = np.nonzero(np.isin(features, feature_dic['covar_names']))[0]
trait_indice = np.nonzero(np.isin(features, feature_dic['outcome_names']))[0]

In [4]:
# sample_size

In [5]:
# Record the trait/covariate column names (in file order) in the log,
# so the column order used by this run can be checked afterwards.
logging.info('Features in order')
logging.info(features)

In [6]:
# Load the data scheme for training.
# The file is first streamed in batches of `batch_size_to_load`; the first two of
# those batches are reserved (validation, then test) and removed from the training
# stream, which is then re-batched into larger training batches.
# NOTE(review): assumes build_data_scheme batches the dataset by `batch_size` and
# returns the total number of samples -- confirm against util_hdf5.
batch_size_to_load = 2 ** 12  # int(sample_size / 8) + 1
n_heldout_batches = 2  # first batch = validation set, second batch = test set
logging.info(f'batch_size in {population} set is {batch_size_to_load}')
data_scheme, sample_size = util_hdf5.build_data_scheme(
    mesa_cau,
    scheme_yaml,
    batch_size = batch_size_to_load,
    inv_norm_y = True,
    x_indice = x_indice_cau
)

# Fail early when the cohort is too small to leave any training sample after the
# held-out batches are removed (e.g. a small test population with a large
# batch_size_to_load); otherwise ntrain is non-positive and the training stream is empty.
ntrain = sample_size - batch_size_to_load * n_heldout_batches
if ntrain <= 0:
    raise ValueError(
        f'sample_size = {sample_size} leaves no training samples after holding out '
        f'{n_heldout_batches} batches of size {batch_size_to_load}; '
        'reduce batch_size_to_load'
    )

# Drop the validation and test batches from the training stream.
# To materialise them instead:
#   dataset_valid = data_scheme.dataset.take(1)
#   dataset_test = data_scheme.dataset.skip(1).take(1)
data_scheme.dataset = data_scheme.dataset.skip(n_heldout_batches)

# Re-batch the remaining samples into training batches.
batch_size = int(sample_size / 4) + 1
data_scheme.dataset = data_scheme.dataset.unbatch().batch(batch_size)
# dataset_insample = data_scheme.dataset.take(1)
train_batch = batch_size
logging.info(f'train_batch = {train_batch}, ntrain = {ntrain}')
# data_scheme.dataset = data_scheme.dataset.take(10)

# Training

In [7]:
# Fit one elastic net regularization path per value of alpha and save each fit
# to its own HDF5 file under output_dir.
alpha_list = [0.1, 0.5, 0.9]
learning_rate = 1

# The stop rule does not depend on alpha, so build it once. The same callable was
# already shared by every checker within one alpha.
my_stop_rule = functools.partial(lib_Checker.diff_stop_rule, threshold = 1e-3)

for alpha in alpha_list:
    logging.info(f'alpha = {alpha} starts')

    # Settings for the lambda sequence: 50 values with lambda_max / lambda_min = 1e6.
    # Built fresh for every alpha in case the estimator modifies it.
    # NOTE(review): the exact meaning of 'data_init' and 'prefactor_of_lambda_max'
    # is defined by lib_LinearAlgebra -- confirm there.
    lambda_init_dict = {
        'data_init': None,
        'prefactor_of_lambda_max': 1.5,
        'lambda_max_over_lambda_min': 1e6,
        'nlambda': 50
    }

    # A new proximal updater (with line search enabled) for every alpha.
    updater = lib_ElasticNet.ProximalUpdater(learning_rate = learning_rate, line_search = True)
    update_dic = {
        'updater': updater,
        'update_fun': updater.proximal_train_step
    }

    elastic_net_estimator = lib_LinearAlgebra.ElasticNetEstimator(
        data_scheme,
        alpha,
        normalizer = True,
        learning_rate = learning_rate,
        lambda_init_dict = lambda_init_dict,
        updater = update_dic
    )

    # One checker per outcome.
    ny = len(data_scheme.outcome_indice)
    checker = [
        lib_Checker.Checker(ntrain, train_batch, lib_Checker.my_stat_fun, my_stop_rule)
        for _ in range(ny)
    ]

    # The logging module itself is handed to the solver through its `logging` argument.
    elastic_net_estimator.solve(checker, nepoch = 100, logging = logging)

    outfile = f'{output_dir}/elastic_net_{pred_expr_name}_alpha_{alpha}_{population}.hdf5'
    logging.info(f'alpha = {alpha} saving to {outfile}')
    elastic_net_estimator.minimal_save(outfile)
    logging.info(f'alpha = {alpha} ends')

Saving lambda_seq
Saving beta_hat_path
Saving covar_hat_path
Saving intercept_path
Saving normalizer
Saving alpha
Saving data_scheme.dataset
Saving data_scheme.X_index
Saving data_scheme.Y_index
Saving data_scheme.outcome_indice
Saving data_scheme.covariate_indice
Saving data_scheme.x_indice
Saving data_scheme.num_predictors
Saving lambda_seq
Saving beta_hat_path
Saving covar_hat_path
Saving intercept_path
Saving normalizer
Saving alpha
Saving data_scheme.dataset
Saving data_scheme.X_index
Saving data_scheme.Y_index
Saving data_scheme.outcome_indice
Saving data_scheme.covariate_indice
Saving data_scheme.x_indice
Saving data_scheme.num_predictors
Saving lambda_seq
Saving beta_hat_path
Saving covar_hat_path
Saving intercept_path
Saving normalizer
Saving alpha
Saving data_scheme.dataset
Saving data_scheme.X_index
Saving data_scheme.Y_index
Saving data_scheme.outcome_indice
Saving data_scheme.covariate_indice
Saving data_scheme.x_indice
Saving data_scheme.num_predictors


Saving lambda_seq
Saving beta_hat_path
Saving covar_hat_path
Saving intercept_path
Saving normalizer
Saving alpha
Saving data_scheme.dataset
Saving data_scheme.X_index
Saving data_scheme.Y_index
Saving data_scheme.outcome_indice
Saving data_scheme.covariate_indice
Saving data_scheme.x_indice
Saving data_scheme.num_predictors


In [8]:
#################### for test below ########################

In [9]:
# elastic_net_estimator.minimal_save('test.hdf5')

In [10]:
# # load data_scheme for training
# batch_size_to_load = 2 ** 8 # 2 ** 12  # int(sample_size / 8) + 1
# logging.info(f'batch_size in {population} set is {batch_size_to_load}')
# data_scheme, sample_size = util_hdf5.build_data_scheme(
#     mesa_afhi, 
#     scheme_yaml, 
#     batch_size = batch_size_to_load, 
#     inv_norm_y = True,
#     x_indice = x_indice_afhi  # x_indice_afhi
# )

# # set validation and test set as the first and second batch
# # dataset_valid = data_scheme.dataset.take(1)
# data_scheme.dataset = data_scheme.dataset.skip(1)
# # dataset_test = data_scheme.dataset.take(1)
# data_scheme.dataset = data_scheme.dataset.skip(1)
# batch_size = int(sample_size / 4) + 1
# data_scheme.dataset = data_scheme.dataset.unbatch().batch(batch_size)
# # dataset_insample = data_scheme.dataset.take(1)
# ntrain = sample_size - batch_size_to_load * 2
# train_batch = batch_size
# logging.info(f'train_batch = {train_batch}, ntrain = {ntrain}')
# # data_scheme.dataset = data_scheme.dataset.take(10)

In [11]:
# model = lib_LinearAlgebra.ElasticNetEstimator('', None, minimal_load = True)
# model.minimal_load('test.hdf5')

In [12]:
# model.data_scheme.x_indice = data_scheme.x_indice

In [13]:
# out = model.predict_x(data_scheme.dataset, model.beta_hat_path)

In [14]:
# # out = o1
# fig, aes = plt.subplots(nrows = 3, ncols = 3, figsize = (15, 10))
# seq = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 19]
# for i in range(3):
#     for j in range(3):
#         idx = seq[(i * 3 + j)] + 40
#         if idx < len(model.lambda_seq[0]):
#             for k in range(1):
#                 aes[i][j].scatter(out['y'][:,k], out['y_pred_from_x'][:, k, idx])
#             aes[i][j].set_title(
#                 'lambda = ' + "{:.3E} cor = {:.3E}".format(
#                     model.lambda_seq[0][idx], 
#                     np.corrcoef(out['y'][:, 0], out['y_pred_from_x'][:, 0, idx])[0,1]
#                 ) # + '\n' +
# #                 'lambda = ' + "{:.3E} cor = {:.3E}".format(
# #                     model_list[alpha].lambda_seq[1][idx], 
# #                     np.corrcoef(out['y'][:, 1], out['y_pred_from_x'][:, 1, idx])[0,1]
# #                 )
#             )