In [6]:
import warnings
warnings.filterwarnings("ignore")

In [7]:
import numpy as np
import os
import re
import pandas as pd
import scipy.sparse as sp
import torch as th

# import dgl
# from dgl.data.utils import download, extract_archive, get_download_dir

from itertools import product
from collections import Counter
from copy import deepcopy
from sklearn.model_selection import KFold
from tqdm import tqdm

import random
random.seed(1234)
np.random.seed(1234)

In [8]:
def load_data(directory):
    GSSM = np.loadtxt(directory + '\GSSM_.txt',dtype=np.float32)
    PESSM = np.loadtxt(directory + '\PSSM.txt',dtype=np.float32,delimiter='\t')

    IPE = pd.DataFrame(PESSM).reset_index()
    IG = pd.DataFrame(GSSM).reset_index()
    IPE.rename(columns = {'index':'id'}, inplace = True)
    IG.rename(columns = {'index':'id'}, inplace = True)
    IPE['id'] = IPE['id']
    IG['id'] = IG['id']
    
    return IPE, IG

In [10]:
def sample(directory, random_seed):
    all_associations = pd.read_csv(directory + '/all_gep_pairs.csv')
    known_associations = all_associations.loc[all_associations['label'] == 1]
    unknown_associations = all_associations.loc[all_associations['label'] == 0]
    
    random_negative = unknown_associations.sample(n=known_associations.shape[0], random_state=random_seed, axis=0)

    sample_df = known_associations.append(random_negative)
    sample_df.reset_index(drop=True, inplace=True)

    return sample_df

In [12]:
def obtain_data(directory, isbalance):
    IPE, IG = load_data(directory)
    
    if isbalance:
        dtp = sample(directory, random_seed = 1234)
    else:
        dtp = pd.read_csv(directory + '/all_gep_pairs.csv')
        
    gene_ids = list(set(dtp['gene_idx']))
    peco_ids = list(set(dtp['peco_idx']))
    random.shuffle(gene_ids)
    random.shuffle(peco_ids)
    print('# gene = {} | peco = {}'.format(len(gene_ids), len(peco_ids)))

    gene_test_num = int(len(gene_ids) / 5)
    peco_test_num = int(len(peco_ids) / 5)
    print('# Test: gene = {} | peco = {}'.format(gene_test_num, peco_test_num))
    
    knn_x = pd.merge(dtp, IPE, left_on = 'peco_idx', right_on = 'id')
    knn_x = pd.merge(knn_x, IG, left_on = 'gene_idx', right_on = 'id')

    label = dtp['label']
    knn_x.drop(labels = ['gene_idx', 'peco_idx', 'label', 'id_x', 'id_y'], axis = 1, inplace = True)
    
    return IPE, IG, dtp, gene_ids, peco_ids, gene_test_num, peco_test_num, knn_x, label

In [14]:
def generate_task_Tg_Tpe_train_test_idx(item, ids, dtp):
    
    test_num = int(len(ids) / 5)
    
    train_index_all, test_index_all = [], []
    train_id_all, test_id_all = [], []
    
    for fold in range(5):
        print('-------Fold ', fold)
        if fold != 4:
            test_ids = ids[fold * test_num : (fold + 1) * test_num]
        else:
            test_ids = ids[fold * test_num :]

        train_ids = list(set(ids) ^ set(test_ids))
        print('# {}: Train = {} | Test = {}'.format(item, len(train_ids), len(test_ids)))

        test_idx = dtp[dtp[item].isin(test_ids)].index.tolist()
        train_idx = dtp[dtp[item].isin(train_ids)].index.tolist()
        random.shuffle(test_idx)
        random.shuffle(train_idx)
        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))
        assert len(train_idx) + len(test_idx) == len(dtp)

        train_index_all.append(train_idx) 
        test_index_all.append(test_idx)
        
        train_id_all.append(train_ids)
        test_id_all.append(test_ids)

        print('train_index_all',train_index_all)
        print('test_index_all',test_index_all)
        print('train_id_all',train_id_all)
        print('test_id_all',test_id_all)

    return train_index_all, test_index_all, train_id_all, test_id_all

In [15]:
def generate_task_Tp_train_test_idx(knn_x):
    kf = KFold(n_splits = 5, shuffle = True, random_state = 1234)

    train_index_all, test_index_all, n = [], [], 0
    train_id_all, test_id_all = [], []
    fold = 0
    for train_idx, test_idx in tqdm(kf.split(knn_x)):
        print('-------Fold ', fold)
        train_index_all.append(train_idx)
        test_index_all.append(test_idx)

        train_id_all.append(np.array(dtp.iloc[train_idx][['gene_idx', 'peco_idx']]))
        test_id_all.append(np.array(dtp.iloc[test_idx][['gene_idx', 'peco_idx']]))

        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))
        fold += 1
    return train_index_all, test_index_all, train_id_all, test_id_all

In [16]:
from sklearn.neighbors import KNeighborsClassifier

In [17]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report

In [18]:
def generate_knn_graph_save(knn_x, label, n_neigh, train_index_all, test_index_all, pwd, task, balance):
    
    fold = 0
    for train_idx, test_idx in zip(train_index_all, test_index_all): 
        print('-------Fold ', fold)
        
        knn_y = deepcopy(label)
        knn_y[test_idx] = 0
        print('Label: ', Counter(label))
        print('knn_y: ', Counter(knn_y))

        knn = KNeighborsClassifier(n_neighbors = n_neigh)
        knn.fit(knn_x, knn_y)

        knn_y_pred = knn.predict(knn_x)
        knn_y_prob = knn.predict_proba(knn_x)
        knn_neighbors_graph = knn.kneighbors_graph(knn_x, n_neighbors = n_neigh)
        
        prec_reca_f1_supp_report = classification_report(knn_y, knn_y_pred, target_names = ['label_0', 'label_1'])
        tn, fp, fn, tp = confusion_matrix(knn_y, knn_y_pred).ravel()

        pos_acc = tp / sum(knn_y)
        neg_acc = tn / (len(knn_y_pred) - sum(knn_y_pred)) # [y_true=0 & y_pred=0] / y_pred=0
        accuracy = (tp+tn)/(tn+fp+fn+tp)

        recall = tp / (tp+fn)
        precision = tp / (tp+fp)
        f1 = 2*precision*recall / (precision+recall)

        roc_auc = roc_auc_score(knn_y, knn_y_prob[:, 1])
        prec, reca, _ = precision_recall_curve(knn_y, knn_y_prob[:, 1])
        aupr = auc(reca, prec)

        print('acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}|pos_acc={:.4f}|neg_acc={:.4f}'.format(accuracy, precision, recall, f1, roc_auc, aupr, pos_acc, neg_acc))
        print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))
        print('y_pred: ', Counter(knn_y_pred))
        print('y_true: ', Counter(knn_y))
#         print('knn_score = {:.4f}'.format(knn.score(knn_x, knn_y)))

        sp.save_npz(pwd + 'task_' + task + balance + '__testlabel0_knn' + str(n_neigh) + 'neighbors_edge__fold' + str(fold) + '.npz', knn_neighbors_graph)
        fold += 1
    return knn_x, knn_y, knn, knn_neighbors_graph

# Run

In [19]:
for isbalance in [True]:
    print('************isbalance = ', isbalance)
    #[]
    for task in ['Tp', 'Tpe', 'Tg']:
        print('=================task = ', task)
        
        IPE, IG, dtp, gene_ids, peco_ids, gene_test_num, peco_test_num, knn_x, label = obtain_data(r'E:\MDA-GCNFTG-main\GDA\data', isbalance)

        if task == 'Tp':
            train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tp_train_test_idx(knn_x)

        elif task == 'Tg':
            item = 'gene_idx'
            ids =gene_ids
            train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tg_Tpe_train_test_idx(item, ids, dtp)
        elif task == 'Tpe':
            item = 'peco_idx'
            ids = peco_ids
            train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tg_Tpe_train_test_idx(item, ids, dtp)

        if isbalance:
            balance = ''
        else:
            balance = '__nobalance'


        np.savez_compressed(r'..\data\task_' + task + balance + '__testlabel0_knn_edge_train_test_index_all.npz',
                               train_index_all = train_index_all, 
                               test_index_all = test_index_all,
                               train_id_all = train_id_all, 
                               test_id_all = test_id_all)
        pwd = r'../0_data/'
        for n_neigh in [1, 3, 5, 7, 10, 15]: 
            print('--------------------------n_neighbors = ', n_neigh)
            knn_x, knn_y, knn, knn_neighbors_graph = generate_knn_graph_save(knn_x, label, n_neigh, train_index_all, test_index_all, pwd, task, balance)

************isbalance =  True
# gene = 11177 | peco = 24
# Test: gene = 2235 | peco = 4
12219
42
(47116, 42) Counter({1: 23558, 0: 23558})
-------Fold  0
# peco_idx: Train = 20 | Test = 4
# Pairs: Train = 41104 | Test = 6012
train_index_all [[30186, 28425, 15724, 14758, 41465, 34689, 3740, 24303, 26790, 1229, 17172, 33878, 23126, 22973, 9864, 4118, 29300, 32583, 29267, 18695, 42084, 36262, 38929, 22830, 7117, 5161, 2709, 1347, 5996, 24794, 21707, 3929, 26300, 11418, 4271, 19130, 25398, 35195, 15793, 2939, 34, 6625, 16006, 17617, 18123, 39269, 17748, 2010, 30074, 31894, 15933, 36640, 29720, 34794, 37815, 8255, 26612, 35221, 21633, 23133, 4672, 31262, 28095, 40405, 15247, 7307, 41507, 3336, 41532, 43356, 22517, 44923, 19165, 5320, 33501, 24530, 21413, 29258, 36152, 35306, 11598, 28136, 4639, 18534, 8799, 44636, 3173, 43650, 17970, 42557, 16282, 45, 21431, 28039, 9657, 33873, 17975, 44227, 13593, 11206, 44958, 1951, 461, 13328, 36277, 26943, 19825, 39310, 40731, 29848, 4833, 4802, 45264, 

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [20]:
node_feature_label = pd.concat([dtp, knn_x], axis = 1)
node_feature_label

Unnamed: 0,peco_idx,gene_idx,label,0_x,1_x,2_x,3_x,4_x,5_x,6_x,...,0_y,1_y,2_y,3_y,4_y,5_y,6_y,7_y,8_y,9_y
0,0,1,1,1.000000,-0.013850,-0.052308,-0.115104,-0.026206,-0.056888,0.019795,...,50.011002,-0.427,-5.062,13.100,3.009,1.893,1.318,-0.740,-2.141,2.156
1,0,24,1,-0.013850,1.000000,-0.023418,0.055074,0.008011,0.021873,-0.101540,...,50.011002,-0.427,-5.062,13.100,3.009,1.893,1.318,-0.740,-2.141,2.156
2,0,28,1,-0.052308,-0.023418,1.000000,-0.013061,0.170718,0.064524,0.146808,...,50.011002,-0.427,-5.062,13.100,3.009,1.893,1.318,-0.740,-2.141,2.156
3,0,32,1,-0.026206,0.008011,0.170718,-0.041546,1.000000,0.135609,0.031272,...,50.011002,-0.427,-5.062,13.100,3.009,1.893,1.318,-0.740,-2.141,2.156
4,0,33,1,0.094465,-0.113927,-0.002729,-0.011421,-0.028804,-0.059783,0.003187,...,50.011002,-0.427,-5.062,13.100,3.009,1.893,1.318,-0.740,-2.141,2.156
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47111,14,11660,0,-0.004300,-0.029589,-0.084950,0.033342,-0.102425,0.058110,-0.032627,...,-10.218000,5.064,0.665,3.266,-0.428,1.694,2.332,0.475,6.168,-1.964
47112,24,10270,0,-0.004300,-0.029589,-0.084950,0.033342,-0.102425,0.058110,-0.032627,...,-21.438999,-0.025,-0.429,4.004,-0.680,1.648,0.063,1.728,-2.282,-1.559
47113,31,9349,0,-0.004300,-0.029589,-0.084950,0.033342,-0.102425,0.058110,-0.032627,...,-23.811001,0.680,1.543,1.089,-3.024,-2.133,2.434,0.240,0.399,0.383
47114,13,1419,0,-0.004300,-0.029589,-0.084950,0.033342,-0.102425,0.058110,-0.032627,...,-24.548000,-2.288,1.928,2.478,1.391,-1.339,0.900,1.656,1.143,-1.213


In [21]:
pwd = r'../data/0_data/'
node_feature_label.to_csv(pwd + 'node_feature_label.csv')