In [2]:
import os,sys
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from stellargraph.mapper import PaddedGraphGenerator
from stellargraph.layer import GCNSupervisedGraphClassification
from reproduce_gcn_utils import *

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy,mean_squared_error
from tensorflow.keras.callbacks import EarlyStopping

In [4]:
# --- Inputs -----------------------------------------------------------------
# Training table, HLA lookup table (tab-separated) and the numeric matrix
# stored in after_pca.txt.
ori_train = pd.read_csv('data/reproduce_gcn_data/remove0123_sample100.csv')
hla = pd.read_csv(
    'data/reproduce_gcn_data/hla2paratopeTable_aligned.txt',
    sep='\t',
)
after_pca = np.loadtxt('data/reproduce_gcn_data/after_pca.txt')

# --- HLA lookups ------------------------------------------------------------
hla_dic = hla_df_to_dic(hla)
inventory = [*hla_dic.keys()]
dic_inventory = dict_inventory(inventory)

# --- Exchange the contents of the two target columns --------------------------
# After this statement 'immunogenicity' holds what was loaded as 'potential'
# and vice versa; column names and positions are unchanged.
(
    ori_train['immunogenicity'],
    ori_train['potential'],
) = (
    ori_train['potential'],
    ori_train['immunogenicity'],
)

# --- 10-fold split over row positions ---------------------------------------
kf = KFold(n_splits=10)
row_positions = np.arange(len(ori_train))
fold_indices = [(train_rows, test_rows) for train_rows, test_rows in kf.split(row_positions)]

# One result list per evaluation set; the training loop appends to these.
holding = {name: [] for name in ('validation', 'dengue', 'cell', 'covid')}

In [5]:
# Cross-validated training and evaluation of the GCN graph classifier.
#
# For each fold a fresh model is trained on the fold's training indices and
# then scored on (1) the fold's held-out indices, (2) the dengue set, (3) the
# cell set and (4) the SARS-CoV-2 set. Every metric is appended to `holding`.
#
# sklearn's `mean_squared_error` is imported under an alias so that it cannot
# shadow the Keras loss of the same name that `model.compile` relies on.
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.metrics import mean_squared_error as sk_mean_squared_error

THRESHOLD = 0.5  # a score at or above this value counts as a positive call

# Start from empty result lists so that re-running this cell does not append
# duplicate entries to the results of an earlier run.
holding = {'validation': [], 'dengue': [], 'cell': [], 'covid': []}


def _score_frame(model, frame):
    """Build graphs for `frame` and return `(scores, labels)`.

    `scores` is the model output flattened to 1-D (the network ends in a
    single-unit Dense layer); `labels` is the second value returned by
    `Graph_Constructor.entrance`.
    """
    frame_graphs, frame_labels = Graph_Constructor.entrance(frame, after_pca, hla_dic, dic_inventory)
    frame_flow = PaddedGraphGenerator(graphs=frame_graphs).flow(frame_graphs)
    return model.predict(frame_flow).ravel(), frame_labels


def _binarize(scores):
    """Turn continuous scores into 0/1 calls using `THRESHOLD`."""
    return [1 if score >= THRESHOLD else 0 for score in scores]


# The inputs to the graph construction are the same for every fold, so the
# training graphs and their generator are built once instead of once per fold.
graphs, graph_labels = Graph_Constructor.entrance(ori_train, after_pca, hla_dic, dic_inventory)
generator = PaddedGraphGenerator(graphs=graphs)

for i, fold in enumerate(fold_indices, start=1):
    train_idx, test_idx = fold

    # ---- build a fresh model for this fold ----
    gc_model = GCNSupervisedGraphClassification(
        layer_sizes=[64, 64],
        activations=["relu", "relu"],
        generator=generator,
        dropout=0.2, )
    x_inp, x_out = gc_model.in_out_tensors()
    predictions = Dense(units=32, activation="relu")(x_out)
    predictions = Dense(units=16, activation="relu")(predictions)
    predictions = Dense(units=1, activation="sigmoid")(predictions)
    model = Model(inputs=x_inp, outputs=predictions)
    model.compile(optimizer=Adam(0.001), loss=mean_squared_error)

    # ---- train ----
    train_gen = generator.flow(
        train_idx,
        targets=graph_labels.iloc[train_idx].values,
        batch_size=256, )
    test_gen = generator.flow(
        test_idx,
        targets=graph_labels.iloc[test_idx].values,
        batch_size=1, )
    epochs = 100
    # Stop when either the training loss or the held-out loss stops improving.
    es1 = EarlyStopping(monitor='loss', patience=2, restore_best_weights=False)
    es2 = EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=False)
    history = model.fit(
        train_gen, epochs=epochs, validation_data=test_gen, shuffle=True, callbacks=[es1, es2, ],
        class_weight={0: 0.5, 1: 0.5})

    # ---- held-out part of this fold: RMSE ----
    # Score against THIS fold's held-out rows (test_idx). The square root is
    # taken explicitly because the `squared=False` argument is deprecated in
    # newer scikit-learn releases.
    pred = model.predict(test_gen)
    result = float(np.sqrt(sk_mean_squared_error(graph_labels.iloc[test_idx], pred)))
    holding['validation'].append(result)

    # ---- dengue: accuracy ----
    ori_test = pd.read_csv('data/dengue_test.csv')
    dengue_scores, dengue_labels = _score_frame(model, ori_test)
    holding['dengue'].append(accuracy_score(dengue_labels, _binarize(dengue_scores)))

    # ---- cell: recall, plus positives among the 20 / 50 top-scored rows ----
    ori_test_cell = pd.read_csv('data/reproduce_gcn_data/ori_test_cells.csv')
    cell_scores, cell_labels = _score_frame(model, ori_test_cell)
    cell_recall = recall_score(cell_labels, _binarize(cell_scores))
    ori_test_cell['result'] = cell_scores
    ori_test_cell = ori_test_cell.sort_values(by='result', ascending=False).reset_index(drop=True)
    ranked_truth = ori_test_cell['immunogenicity'].values
    top20_hits = np.count_nonzero(ranked_truth[:20] == 1)
    top50_hits = np.count_nonzero(ranked_truth[:50] == 1)
    holding['cell'].append((cell_recall, top20_hits, top50_hits))

    # ---- covid: recall and precision against both label columns ----
    ori = pd.read_csv('data/reproduce_gcn_data/sars_cov_2_result.csv')
    # Rows are shuffled without a fixed seed before filtering.
    # NOTE(review): if retain_910 is order-dependent this makes runs non-repeatable - confirm.
    ori = ori.sample(frac=1, replace=False).reset_index(drop=True)
    ori_test_covid = retain_910(ori)
    covid_scores, _ = _score_frame(model, ori_test_covid)
    covid_calls = _binarize(covid_scores)
    convalescent_recall = recall_score(ori_test_covid['immunogenicity-con'], covid_calls)
    unexposed_recall = recall_score(ori_test_covid['immunogenicity'], covid_calls)
    convalescent_precision = precision_score(ori_test_covid['immunogenicity-con'], covid_calls)
    unexposed_precision = precision_score(ori_test_covid['immunogenicity'], covid_calls)
    holding['covid'].append(
        (convalescent_recall, unexposed_recall, convalescent_precision, unexposed_precision))

    print('round {}, finished covid'.format(i))
    # Only the first fold is evaluated; remove this `break` to run all folds.
    break

100%|██████████| 8971/8971 [00:32<00:00, 272.21it/s]


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100


100%|██████████| 408/408 [00:01<00:00, 260.05it/s]
100%|██████████| 522/522 [00:01<00:00, 276.89it/s]
100%|██████████| 92/92 [00:00<00:00, 264.14it/s]


round 1, finished covid


In [6]:
# Show the metrics collected per evaluation set: one list entry per completed fold.
print(holding)

{'validation': [0.1403044852465869], 'dengue': [0.8504901960784313], 'cell': [(0.7428571428571429, 2, 2)], 'covid': [(0.72, 0.625, 0.2465753424657534, 0.0684931506849315)]}
