In [None]:
import tensorflow as tf
import pandas as pd

from tensorflow import keras


from tensorflow import feature_column
from tensorflow.keras import layers

import tensorflow.keras.backend as K



In [None]:
def get_allowed_snp_names():
    """Return the sorted SNP identifiers listed in ``allowed_snp.csv``.

    Each entry of the ``snp`` column is split on ``'_'`` and the second
    piece is kept as the SNP name.

    Returns:
        list: SNP names in ascending string order.
    """
    raw_names = pd.read_csv('allowed_snp.csv')['snp']
    # Keep only the part after the first underscore-delimited prefix.
    return sorted(raw.split('_')[1] for raw in raw_names)

- 首先对 34 个位点进行分块，对每个块做 self-attention 或者 LSTM；未被分块的位点直接进入 dense 网络。
- 然后有针对性地依次移除位点，来检验各位点的重要性。
- 为了使训练更充分，可以先进行预训练。

In [None]:
# Returns the feature columns defined for the model.
def get_feature_columns():
    """Assemble the full list of feature columns.

    The returned list is ordered as: ``CNE``, ``age``, the one-hot
    covariates (``sex``, ``smoke``, ``drink``, ``excercise``, ``HP``),
    ``BMI``, the one-hot hashed SNP crosses (pairs first, then triples),
    and finally one one-hot genotype column per SNP returned by
    ``get_allowed_snp_names()``.

    Returns:
        list: TensorFlow feature columns in the order described above.
    """
    fc = feature_column

    def one_hot(categorical):
        # Wrap a categorical (or crossed) column into a dense indicator column.
        return fc.indicator_column(categorical)

    columns = [
        fc.numeric_column('CNE'),  # cumulative noise exposure
        fc.numeric_column('age'),
    ]

    # Categorical covariates together with their vocabularies.
    covariates = (
        ('sex', [1, 2]),
        ('smoke', [1, 0]),
        ('drink', [1, 0]),
        ('excercise', [1, 0]),
        ('HP', [1, 0]),
    )
    for name, vocabulary in covariates:
        columns.append(one_hot(fc.categorical_column_with_vocabulary_list(name, vocabulary)))

    columns.append(fc.numeric_column('BMI'))
    # Note: an 'HL' numeric column is deliberately not part of the feature set.

    # Crossed SNP features: pairs are hashed into 9 buckets, triples into 27.
    pair_crosses = (
        ['rs1358714', 'rs1200130'],
        ['rs17412009', 'rs1200130'],
        ['rs2070703', 'rs1200130'],
        ['rs6458080', 'rs1200130'],
        ['rs17412009', 'rs1200135'],
        ['rs1200137', 'rs6458080'],
        ['rs17412009', 'rs1358714'],
        ['rs6458080', 'rs1678690'],
        ['rs17412009', 'rs6458080'],
    )
    triple_crosses = (
        ['rs1200137', 'rs1358714', 'rs1200130'],
        ['rs1200137', 'rs17412009', 'rs1200135'],
        ['rs17412009', 'rs1200135', 'rs1358714'],
        ['rs1200137', 'rs17412009', 'rs1358714'],
    )
    sized_crosses = [(keys, 9) for keys in pair_crosses] + [(keys, 27) for keys in triple_crosses]
    for keys, bucket_count in sized_crosses:
        columns.append(one_hot(fc.crossed_column(keys, bucket_count)))

    # SNPs listed here use the G/A genotype vocabulary; all others use C/T.
    ga_snps = ('rs10091503', 'rs1026435', 'rs10503675',
               'rs11778205', 'rs1200135', 'rs1358714', 'rs1678674',
               'rs1738254', 'rs3737094', 'rs3807154', 'rs3823430',
               'rs4452640', 'rs874808', 'rs9357283')

    for snp in get_allowed_snp_names():
        vocabulary = ['G/A', 'G/G', 'A/A'] if snp in ga_snps else ['C/C', 'T/C', 'T/T']
        columns.append(one_hot(fc.categorical_column_with_vocabulary_list(snp, vocabulary)))
    return columns
    

In [None]:
# Group the SNP loci into blocks.
def get_feature_block(features):
    """Return the SNP groups, one list of SNP names per block.

    Args:
        features: Accepted for interface compatibility; it is not read.

    Returns:
        list[list[str]]: Nine two-SNP blocks followed by four three-SNP
        blocks. A fresh list is built on every call.
    """
    return [
        ['rs1358714', 'rs1200130'],
        ['rs17412009', 'rs1200130'],
        ['rs2070703', 'rs1200130'],
        ['rs6458080', 'rs1200130'],
        ['rs17412009', 'rs1200135'],
        ['rs1200137', 'rs6458080'],
        ['rs17412009', 'rs1358714'],
        ['rs6458080', 'rs1678690'],
        ['rs17412009', 'rs6458080'],
        ['rs1200137', 'rs1358714', 'rs1200130'],
        ['rs1200137', 'rs17412009', 'rs1200135'],
        ['rs17412009', 'rs1200135', 'rs1358714'],
        ['rs1200137', 'rs17412009', 'rs1358714'],
    ]

def get_block_features():
    """Return every SNP name that appears in a block, flattened.

    Returns:
        list[str]: SNP names in block order; a name is repeated once for
        each block it belongs to.
    """
    grouped = (
        ('rs1358714', 'rs1200130'),
        ('rs17412009', 'rs1200130'),
        ('rs2070703', 'rs1200130'),
        ('rs6458080', 'rs1200130'),
        ('rs17412009', 'rs1200135'),
        ('rs1200137', 'rs6458080'),
        ('rs17412009', 'rs1358714'),
        ('rs6458080', 'rs1678690'),
        ('rs17412009', 'rs6458080'),
        ('rs1200137', 'rs1358714', 'rs1200130'),
        ('rs1200137', 'rs17412009', 'rs1200135'),
        ('rs17412009', 'rs1200135', 'rs1358714'),
        ('rs1200137', 'rs17412009', 'rs1358714'),
    )
    return [snp for group in grouped for snp in group]

In [None]:
def df_to_dataset(dataframe, shuffle=True, batch_size=10, drop_col = None):
    """Convert a pandas DataFrame into a batched ``tf.data.Dataset``.

    The ``caco`` column becomes the label and the ``Simple Name`` column is
    discarded; every remaining column becomes an entry of the feature dict.

    Args:
        dataframe: Source frame; it is copied, so the caller's frame is
            left untouched.
        shuffle: When true, shuffle with a buffer covering the whole frame.
        batch_size: Number of rows per batch.
        drop_col: Optional column name or iterable of column names whose
            values are overwritten with the constant ``'G/A'`` so that the
            column carries no per-sample information.

    Returns:
        tf.data.Dataset: Batches of ``(features_dict, labels)``. The dataset
        covers the data exactly once per iteration (it does not repeat).
    """
    dataframe = dataframe.copy()
    if drop_col is not None:
        # A bare string would otherwise be iterated character by character.
        if isinstance(drop_col, str):
            drop_col = [drop_col]
        for col in drop_col:
            dataframe[col] = 'G/A'
    labels = dataframe.pop('caco')
    dataframe.pop('Simple Name')
    ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
    # Dataset transformations return new datasets; the previous bare
    # `ds.repeat()` call had no effect and has been removed.
    if shuffle and len(dataframe) > 0:
        ds = ds.shuffle(buffer_size=len(dataframe))
    # Batch regardless of `shuffle`: previously batching only happened inside
    # the shuffle branch, so shuffle=False returned an unbatched dataset.
    ds = ds.batch(batch_size)
    return ds

In [None]:
def build_input_fn(file, drop_col=None):
    """Build a dataset of input samples from a single CSV file.

    Args:
        file: Path of the CSV file to read.
        drop_col: Forwarded unchanged to ``df_to_dataset``.

    Returns:
        The dataset produced by ``df_to_dataset`` from the rows of ``file``
        that contain no missing values.
    """
    complete_rows = pd.read_csv(file).dropna()
    return df_to_dataset(complete_rows, drop_col=drop_col)

In [None]:
def build_all_input_fn(drop_col=None):
    """Pool ``1.csv`` and ``2.csv`` and split them 80/20 into train/test.

    Rows with missing values are removed from each file before pooling.
    Both the shuffle and the split are seeded, so repeated calls yield the
    same partition; previously only the split was seeded, which did not
    make the partition reproducible because the rows had already been
    shuffled without a seed.

    Args:
        drop_col: Optional column name(s) forwarded to ``df_to_dataset`` for
            both the training and the test dataset.

    Returns:
        tuple: ``(train_dataset, test_dataset)`` as built by
        ``df_to_dataset``.
    """
    df = pd.read_csv('1.csv')
    df = df.dropna()
    df1 = pd.read_csv('2.csv')
    df1 = df1.dropna()
    final_df = pd.concat([df, df1], ignore_index=True)
    # Seeded shuffle so the positions drawn below always map to the same rows.
    final_df = final_df.sample(frac=1, random_state=0).reset_index(drop=True)
    final_train_df = final_df.sample(frac=0.8, random_state=0, axis=0)

    # Everything not drawn for training forms the held-out set.
    final_test_df = final_df[~final_df.index.isin(final_train_df.index)]
    print(final_test_df.head())
    return (df_to_dataset(final_train_df, drop_col=drop_col),
            df_to_dataset(final_test_df, drop_col=drop_col))

In [None]:
# Takes one feature block and returns that block's tower embedding.
def build_lstm_tower(features, feature_columns):
    """Build the LSTM tower for one block of feature columns.

    The block's columns are densified, linearly projected to
    ``len(feature_columns) * 8`` values, reshaped to
    ``[-1, len(feature_columns), 8]`` and fed through an LSTM with 4 units.

    Args:
        features: Dict of Keras inputs handed to the ``DenseFeatures`` layer.
        feature_columns: Feature columns that make up this block.

    Returns:
        The output tensor of the LSTM layer.
    """
    n_steps = len(feature_columns)
    dense_block = tf.keras.layers.DenseFeatures(feature_columns)(features)
    projected = tf.keras.layers.Dense(n_steps * 8)(dense_block)
    # One sequence step of width 8 per column in the block.
    sequence = tf.reshape(projected, [-1, n_steps, 8])
    return tf.keras.layers.LSTM(4, activation = 'relu')(sequence)

In [None]:
def get_feature_name(feature_name):
    """Return the part of ``feature_name`` before its first underscore.

    Names without an underscore are returned unchanged.
    """
    head, _, _ = feature_name.partition('_')
    return head

In [None]:
def get_feature_type(feature_name):
    """Classify a feature column name as ``'int'``, ``'float'`` or ``'string'``.

    The name is first reduced with ``get_feature_name`` and then looked up
    in the two lists below; anything not listed is treated as a string
    feature.

    Args:
        feature_name: Name of a feature column.

    Returns:
        str: ``'int'``, ``'float'`` or ``'string'``.
    """
    integer_features = ['sex', 'smoke', 'drink', 'excercise', 'HP']
    # 'age' is a numeric column and the data carries fractional ages, so it is
    # typed as float; an integer model input would drop the fractional part.
    # NOTE(review): assumes Keras casts fed values to the declared input dtype
    # — confirm for the TensorFlow version in use.
    float_features = ['CNE', 'BMI', 'age']
    key = get_feature_name(feature_name)
    if key in integer_features:
        return 'int'
    if key in float_features:
        return 'float'
    return 'string'

In [None]:
def build_model(feature_columns):
    """Build and compile the block-LSTM plus dense binary classifier.

    Steps:
      0/1. Create one Keras ``Input`` per feature key and a key -> column
           lookup. The key is ``get_feature_name(column.name)``; when several
           columns reduce to the same key, the last one in ``feature_columns``
           wins in both mappings.
      2.   For every block from ``get_feature_block`` build an LSTM tower
           over that block's columns and concatenate the tower outputs.
      4.   Feed every column whose key is not part of any block through a
           ``DenseFeatures`` layer.
      Finally concatenate both parts and apply Dense(64, relu) ->
      Dense(16, relu) -> Dense(1, sigmoid).

    Args:
        feature_columns: Feature columns, e.g. from ``get_feature_columns()``.

    Returns:
        tf.keras.Model: Model compiled with Adam (learning rate 0.001),
        binary cross-entropy loss and the AUC / accuracy / Recall / Precision
        metrics.

    Raises:
        KeyError: If a SNP named in a block has no column in
            ``feature_columns``.
    """
    # 0/1. Placeholders (symbolic inputs) and the feature_name -> column dict.
    input_dtypes = {'int': tf.int32, 'float': tf.float32}
    feature_layer_input = {}
    feature_map = {}
    for fc in feature_columns:
        key = get_feature_name(fc.name)
        input_dtype = input_dtypes.get(get_feature_type(fc.name), tf.string)
        feature_layer_input[key] = tf.keras.Input(shape=(1,), dtype=input_dtype, name=key)
        feature_map[key] = fc

    # 2. Build one LSTM tower per block.
    feature_blocks = get_feature_block(feature_columns)
    blocks_output = []
    for block in feature_blocks:
        block_feature_columns = [feature_map[block_feature] for block_feature in block]
        feature_tower_output = build_lstm_tower(feature_layer_input, block_feature_columns)
        blocks_output.append(feature_tower_output)
    # Towers are joined along the feature axis: (batch_size, embedding_size).
    attention_output = tf.concat(blocks_output, axis=1)

    # 4. Remaining features (those whose key is in no block) go to the DNN.
    attention_features = set(get_block_features())
    dnn_features = [fc for fc in feature_columns
                    if get_feature_name(fc.name) not in attention_features]

    feature_layer = tf.keras.layers.DenseFeatures(dnn_features)
    dnn_input = feature_layer(feature_layer_input)

    base_input = tf.concat([attention_output, dnn_input], axis=1)

    dnn_dense_layer = tf.keras.layers.Dense(64, activation = 'relu', name='layer')
    dnn_output = dnn_dense_layer(base_input)

    dnn_dense_layer1 = tf.keras.layers.Dense(16, activation = 'relu', name='layer1')
    dnn_output1 = dnn_dense_layer1(dnn_output)

    dnn_dense_layer2 = tf.keras.layers.Dense(1, activation = 'sigmoid', name='layer2')
    dnn_output2 = dnn_dense_layer2(dnn_output1)

    model = tf.keras.Model(inputs=feature_layer_input, outputs = dnn_output2)

    # `learning_rate` is the supported argument name; `lr` is a deprecated alias.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  loss=tf.keras.losses.BinaryCrossentropy(),
                  metrics=['AUC','accuracy','Recall','Precision'])
    return model

In [None]:
# Train and evaluate the model.
def train_and_eval_model(batch_size=20, drop_col = None):
    """Train a freshly built model for 30 epochs and validate each epoch.

    Args:
        batch_size: Currently unused; the datasets returned by
            ``build_all_input_fn`` are already batched. Kept so existing
            callers keep working.
        drop_col: Optional column name or iterable of column names to
            neutralise. Each named feature is overwritten with the constant
            ``'G/A'`` in both the training and the validation data.
            Previously this argument was accepted but never applied, so every
            run trained on the full feature set.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    feature_columns = get_feature_columns()
    model = build_model(feature_columns)
    train_dataset, eval_dataset = build_all_input_fn()

    if drop_col:
        # A bare string would otherwise be iterated character by character.
        masked_cols = [drop_col] if isinstance(drop_col, str) else list(drop_col)

        def _mask_dropped(features, label):
            # Replace each dropped feature with a constant of the same shape.
            masked = dict(features)
            for col in masked_cols:
                if col in masked:
                    masked[col] = tf.fill(tf.shape(masked[col]), 'G/A')
            return masked, label

        train_dataset = train_dataset.map(_mask_dropped)
        eval_dataset = eval_dataset.map(_mask_dropped)

    result = model.fit(x=train_dataset,
              epochs=30,
              validation_data = eval_dataset,
              callbacks=[tf.keras.callbacks.TensorBoard(log_dir='./logs')])
    return result
    

In [None]:
def model_summary():
    """Build the model from the standard feature columns and print its summary."""
    build_model(get_feature_columns()).summary()

In [None]:
# Build the model once and print its layer summary.
model_summary()

In [None]:
def get_feature_importance(snps=None):
    """Estimate SNP importance by retraining with one SNP dropped at a time.

    For every SNP, ``train_and_eval_model`` is called with that SNP as the
    only entry of ``drop_col``, and the validation AUC of the final epoch is
    recorded.

    Args:
        snps: Optional iterable of SNP names to evaluate. Defaults to the
            six SNPs that were previously hard-coded here.

    Returns:
        dict: SNP name -> validation AUC after the last training epoch.
    """
    # 1. Decide which features to score.
    if snps is None:
        snps = ['rs919390','rs11681642','rs2289273','rs10503675','rs1200130','rs6458080']
    snp_importance = {}
    # 2. Retrain once per SNP.
    for snp in snps:
        # Announce the run before it starts (it was printed after training).
        print('begin snp:{} train'.format(snp))
        model_history = train_and_eval_model(drop_col=[snp])
        val_auc = model_history.history['val_AUC']
        snp_importance[snp] = val_auc[-1]
    return snp_importance

In [None]:
# Run the per-SNP ablation; the model is retrained once for every SNP evaluated.
result = get_feature_importance()

In [None]:
# Show the SNP -> final validation AUC mapping.
result

In [55]:
# Train and validate the model with no columns dropped.
train_and_eval_model()

   Simple Name  caco         CNE        age  sex  smoke  drink  excercise  HP  \
11     5703003     1  102.536841  39.250000    1      0      1          1   0   
19     5206009     1   95.248409  37.083333    1      0      1          0   0   
23     1705002     0   88.937321  25.583333    1      1      1          0   0   
24     9100119     1   94.731949  40.166667    1      1      1          1   1   
25     5212008     1  102.216398  50.000000    1      1      1          0   1   

          BMI  ...  rs3766031 rs3807154 rs3823430 rs4452640 rs4714192  \
11  20.069204  ...        C/C       G/G       A/A       A/A       C/C   
19  24.609734  ...        C/C       G/G       G/A       A/A       T/C   
23  27.041644  ...        C/C       G/G       G/A       A/A       T/C   
24  30.116213  ...        C/C       G/G       G/A       A/A       T/C   
25  31.861629  ...        C/C       G/G       G/G       A/A       T/T   

   rs6458080 rs751122 rs874808 rs919390 rs9357283  
11       T/C      T/T 

2022-02-07 00:15:35.928744: I tensorflow/core/profiler/lib/profiler_session.cc:184] Profiler session started.


     90/Unknown - 38s 427ms/step - loss: 0.7444 - AUC: 0.4994 - accuracy: 0.5880 - Recall: 0.1693 - Precision: 0.3252

2022-02-07 00:15:41.215070: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 2/30

2022-02-07 00:15:56.711688: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 3/30

2022-02-07 00:16:01.110420: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 4/30

2022-02-07 00:16:05.669653: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 5/30

2022-02-07 00:16:10.420291: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 6/30

2022-02-07 00:16:15.525341: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 7/30

2022-02-07 00:16:20.300193: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 8/30

2022-02-07 00:16:24.855871: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 9/30

2022-02-07 00:16:29.440756: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 10/30

2022-02-07 00:16:34.240715: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 11/30

2022-02-07 00:16:38.915539: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 12/30

2022-02-07 00:16:43.694501: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 13/30

2022-02-07 00:16:48.611939: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 14/30

2022-02-07 00:16:53.334604: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 15/30

2022-02-07 00:16:58.041635: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 16/30

2022-02-07 00:17:03.155466: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 17/30

2022-02-07 00:17:08.015071: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 18/30

2022-02-07 00:17:12.971060: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 19/30

2022-02-07 00:17:18.307886: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 20/30

2022-02-07 00:17:23.021542: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 21/30

2022-02-07 00:17:27.922211: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 22/30

2022-02-07 00:17:32.793249: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 23/30

2022-02-07 00:17:37.197978: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 24/30

2022-02-07 00:17:41.709533: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 25/30

2022-02-07 00:17:46.378753: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 26/30

2022-02-07 00:17:50.674005: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 27/30

2022-02-07 00:17:55.055041: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 28/30

2022-02-07 00:17:59.708070: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 29/30

2022-02-07 00:18:04.230371: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]


Epoch 30/30

2022-02-07 00:18:08.370640: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
	 [[{{node IteratorGetNext}}]]




<tensorflow.python.keras.callbacks.History at 0x7ffb72eb8190>