In [1]:
import os
import numpy as np
import pandas as pd
import torch
import h5py
from sklearn.metrics.pairwise import cosine_similarity, paired_distances
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [2]:
class EmbeddingGraph:
    '''Similarity graph over nodes built from their embedding vectors.

    Row ``i`` of ``embedding_matrix`` is the embedding of ``nodes[i]``.  Typical
    pipeline: ``create_distance_matrix`` -> ``knn_create_adj_matrix`` ->
    ``norm_adj_matrix`` / ``save_h5_file``.
    '''

    # Accepted ``mode`` values for ``create_distance_matrix``.  'eculidean' is the
    # original misspelling, kept so existing calls keep working.
    _DISTANCE_MODES = ('euclidean', 'eculidean', 'cosine')

    def __init__(self, nodes, embedding_matrix):
        '''
        Args:
            nodes: sequence of node identifiers (stored as an object array).
            embedding_matrix: 2-D array-like with one row per node, assumed to be
                aligned with ``nodes`` (row i <-> nodes[i]).
        '''
        self.nodes = np.array(nodes, dtype=object)
        self.nodes_num = len(nodes)
        self.embedding_matrix = embedding_matrix
        self.adj_matrix = np.zeros(shape=(self.nodes_num, self.nodes_num), dtype=float)
        self.distance_matrix = np.zeros_like(self.adj_matrix)
        self.degrees = self.adj_matrix.sum(1)

    def preprocess(self, n_components=32):
        '''Reduce the embeddings with PCA, then standardise each component.

        Replaces ``self.embedding_matrix`` and returns the new matrix.

        Args:
            n_components: number of PCA components to keep (default 32, the value
                that used to be hard-coded).
        '''
        reduced = PCA(n_components=n_components).fit_transform(self.embedding_matrix)
        self.embedding_matrix = StandardScaler().fit_transform(reduced)
        return self.embedding_matrix

    def cal_eucli_distance(self, node_idx1, node_idx2):
        '''Return the Euclidean distance between the embeddings of two node indices.'''
        diff = np.asarray(self.embedding_matrix[node_idx1]) - np.asarray(self.embedding_matrix[node_idx2])
        # np.sum instead of the builtin sum: vectorised and correct for any array shape.
        return np.sqrt(np.sum(diff ** 2))

    def cal_cosine_distance(self, node_idx1, node_idx2):
        '''Return the cosine distance between two nodes as a 1-element array.

        Uses ``sklearn.metrics.pairwise.paired_distances(metric='cosine')`` on the two
        embeddings reshaped to single-row matrices.
        '''
        embedding1 = self.embedding_matrix[node_idx1].reshape(1, -1)
        embedding2 = self.embedding_matrix[node_idx2].reshape(1, -1)
        cosine_dist = paired_distances(embedding1, embedding2, metric='cosine')
        return cosine_dist

    def create_distance_matrix(self, mode='cosine'):
        '''Fill ``self.distance_matrix`` (in place) with pairwise node distances and return it.

        Args:
            mode: 'cosine' (default) or 'euclidean'; the legacy spelling 'eculidean'
                is still accepted.

        The cosine distance uses the same formula as sklearn's paired cosine distance
        (what ``cal_cosine_distance`` calls): rows are L2-normalised (all-zero rows stay
        zero) and ``dist = 0.5 * ||a_hat - b_hat||**2``.
        '''
        assert mode in self._DISTANCE_MODES, f'unknown distance mode: {mode!r}'
        embeddings = np.asarray(self.embedding_matrix, dtype=float)[:self.nodes_num]
        if mode == 'cosine':
            norms = np.sqrt((embeddings ** 2).sum(axis=1, keepdims=True))
            norms[norms == 0] = 1.0   # avoid division by zero; zero rows stay zero
            embeddings = embeddings / norms
        # One vectorised row per iteration: O(n) Python iterations instead of the
        # O(n^2) per-pair calls, with only O(n * dim) temporary memory.
        for i in range(self.nodes_num):
            squared = ((embeddings - embeddings[i]) ** 2).sum(axis=1)
            self.distance_matrix[i] = 0.5 * squared if mode == 'cosine' else np.sqrt(squared)
        return self.distance_matrix

    def knn_create_adj_matrix(self, k):
        '''Build a symmetric 0/1 adjacency matrix from each node's k nearest nodes.

        For every row of ``self.distance_matrix`` the k smallest entries are selected and
        the edge is set in both directions, so a node's degree can exceed k.  The node
        itself (distance 0) is normally one of its own k nearest, which yields a
        self-loop; ties between equal distances are broken arbitrarily by
        ``np.argpartition``.  If ``k >= nodes_num`` every node is linked to every node.

        Returns ``self.adj_matrix``; ``self.degrees`` is updated to its row sums.
        '''
        # need to create distance matrix first, call class.create_distance_matrix()
        assert self.distance_matrix.sum() != 0, 'call create_distance_matrix() first'
        assert k >= 1, f'k must be a positive integer, got {k!r}'
        n = self.nodes_num
        self.adj_matrix = np.zeros(shape=(n, n), dtype=float)
        if k >= n:
            # argpartition would reject kth >= n; "k nearest" is then simply everyone.
            self.adj_matrix[:] = 1.0
        else:
            nearest = np.argpartition(self.distance_matrix, k, axis=1)[:, :k]
            rows = np.repeat(np.arange(n), k)
            cols = nearest.ravel()
            self.adj_matrix[rows, cols] = 1
            self.adj_matrix[cols, rows] = 1
        self.degrees = self.adj_matrix.sum(1)
        return self.adj_matrix

    def embedding_out(self, name, mode='csv'):
        '''Write ``self.embedding_matrix`` to the text file ``name``.

        Args:
            name: output file path.
            mode: 'csv' (comma-delimited, default) or 'tsv' (tab-delimited).
        '''
        delimiters = {'csv': ',', 'tsv': '\t'}
        assert mode in delimiters    # output the embedding matrix as a csv file or tsv file
        np.savetxt(name, self.embedding_matrix, delimiter=delimiters[mode])

    def norm_adj_matrix(self):
        '''Return the symmetrically normalised adjacency ``D^-1/2 A D^-1/2``.

        ``D`` is the diagonal degree matrix (row sums) of ``self.adj_matrix``;
        ``self.degrees`` is refreshed as a side effect.  Nodes with degree 0 get an
        all-zero row/column instead of the inf/nan produced by ``0 ** -0.5``.
        '''
        # need to create distance matrix first, call class.create_distance_matrix()
        assert self.distance_matrix.sum() != 0, 'call create_distance_matrix() first'
        # need to create adj matrix first, call class.knn_create_adj_matrix()
        assert self.adj_matrix.sum() != 0, 'call knn_create_adj_matrix() first'
        degree = np.array(self.adj_matrix.sum(1))
        self.degrees = degree
        inv_sqrt = np.zeros_like(degree, dtype=float)
        nonzero = degree > 0
        inv_sqrt[nonzero] = np.power(degree[nonzero], -0.5)
        # Same values as diag(inv_sqrt) @ A @ diag(inv_sqrt), without dense diagonals.
        return inv_sqrt[:, None] * self.adj_matrix * inv_sqrt[None, :]

    def save_h5_file(self, name):
        '''Save nodes, embeddings, adjacency, distances and normalised adjacency to HDF5.

        Any existing file at ``name`` is replaced.  The normalised adjacency is computed
        before the old file is touched, so a failed precondition no longer destroys it
        or leaves a partial file; the HDF5 handle is closed even if a write fails.
        '''
        normed_adj_matrix = self.norm_adj_matrix()
        if os.path.exists(name):   # replace the old h5 file
            os.remove(name)
        dt_str = h5py.special_dtype(vlen=str)
        with h5py.File(name, mode='w') as f:
            f.create_dataset('nodes', data=self.nodes, dtype=dt_str)
            f.create_dataset('embedding_matrix', data=self.embedding_matrix, dtype=float)
            f.create_dataset('adj_matrix', data=self.adj_matrix, dtype=float)
            f.create_dataset('distance_matrix', data=self.distance_matrix, dtype=float)
            f.create_dataset('normed_adj_matrix', data=np.array(normed_adj_matrix), dtype=float)

In [13]:
# create adj table for gnn
# Load one embedding per patient; node order is train, then validation, then test.
# NOTE: absolute paths from the original machine — adjust before running elsewhere.
patho_gene_embedding_dir = '/amax/data/luad/embedding/split8/normal/'
split_file_dir = '/home/xieyuzhang/mtmcat/dataset/survival/luad/luad_splits/luad_split8.csv'
split_file = pd.read_csv(split_file_dir)
# dropna() removes the NaN cells pandas reads where a split column is shorter.
train_patient_list = list(split_file['train'].dropna())
val_patient_list = list(split_file['validation'].dropna())
test_patient_list = list(split_file['test'].dropna())
case_ids = np.array(train_patient_list + val_patient_list + test_patient_list)
assert len(set(case_ids)) == len(case_ids), 'a case id appears more than once across the splits'

patho_gene_embeddings = []
for case_id in case_ids:
    # map_location='cpu' lets tensors saved on a GPU load on CPU-only hosts
    # (calling .to('cpu') after torch.load is too late for that).
    # NOTE: torch.load unpickles its input — only load trusted files.
    tensor = torch.load(os.path.join(patho_gene_embedding_dir, case_id + '.pt'), map_location='cpu')
    patho_gene_embeddings.append(np.asarray(tensor.detach().squeeze().numpy(), dtype=float))
# np.stack raises a clear error if the per-case embeddings differ in shape.
patho_gene_embeddings = np.stack(patho_gene_embeddings)
# person_graph is built in the next cell.

In [14]:
# Build the kNN patient graph and save it for the GNN.
K_NEIGHBOURS = 4        # k nearest by cosine distance; the node itself is normally one of them
USE_PCA = False         # True -> PCA(32) + standardisation via preprocess() before building the graph
GRAPH_H5_PATH = '/home/xieyuzhang/mtmcat/dataset/survival/luad/inputs/embedding/split8/split8.h5'

person_graph = EmbeddingGraph(case_ids, patho_gene_embeddings)
if USE_PCA:
    person_graph.preprocess()
person_graph.create_distance_matrix(mode='cosine')
person_graph.knn_create_adj_matrix(K_NEIGHBOURS)
adj = person_graph.norm_adj_matrix()
# h5py.File(mode='w') does not create missing parent directories.
os.makedirs(os.path.dirname(GRAPH_H5_PATH), exist_ok=True)
person_graph.save_h5_file(GRAPH_H5_PATH)

In [29]:
# Recompute and display the pairwise cosine distances (cosine is the default mode).
person_graph.create_distance_matrix(mode='cosine')

array([[0.00000000e+00, 1.99583990e+00, 3.26077066e-03, ...,
        4.79257273e-03, 8.20340509e-03, 1.76020914e-02],
       [1.99583990e+00, 0.00000000e+00, 1.99573130e+00, ...,
        1.99646679e+00, 1.99536476e+00, 1.98047804e+00],
       [3.26077066e-03, 1.99573130e+00, 0.00000000e+00, ...,
        1.52509690e-03, 5.93473973e-03, 2.72131462e-02],
       ...,
       [4.79257273e-03, 1.99646679e+00, 1.52509690e-03, ...,
        0.00000000e+00, 6.18176540e-03, 2.67902869e-02],
       [8.20340509e-03, 1.99536476e+00, 5.93473973e-03, ...,
        6.18176540e-03, 0.00000000e+00, 3.40280102e-02],
       [1.76020914e-02, 1.98047804e+00, 2.72131462e-02, ...,
        2.67902869e-02, 3.40280102e-02, 0.00000000e+00]])

In [130]:
# Rebuild and display the 4-nearest-neighbour adjacency matrix.
person_graph.knn_create_adj_matrix(k=4)

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 1., 0., 0.],
       ...,
       [0., 0., 1., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

In [5]:
# Display the symmetrically normalised adjacency D^-1/2 A D^-1/2 (this also refreshes person_graph.degrees).
person_graph.norm_adj_matrix()

array([[0.1       , 0.        , 0.        , ..., 0.        , 0.        ,
        0.14142136],
       [0.        , 0.2       , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.125     , ..., 0.125     , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.125     , ..., 0.125     , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.2       ,
        0.        ],
       [0.14142136, 0.        , 0.        , ..., 0.        , 0.        ,
        0.2       ]])

In [6]:
# Node degrees: row sums of adj_matrix (self-loops included), as last set by knn_create_adj_matrix / norm_adj_matrix.
person_graph.degrees

array([ 4.,  6.,  7.,  5.,  5.,  9.,  5.,  7., 10.,  7.,  5.,  4.,  5.,
        4.,  4.,  4.,  6.,  6.,  4.,  4.,  5.,  6.,  5.,  6.,  6.,  4.,
        4.,  6.,  4.,  5.,  5.,  9.,  4.,  8.,  6.,  6.,  4.,  4.,  4.,
        9.,  4.,  4.,  4.,  4.,  5.,  5.,  6.,  4.,  5.,  4.,  7.,  4.,
        6.,  5.,  4.,  7.,  8.,  4.,  4.,  4.,  5.,  5.,  4.,  6.,  8.,
        8., 16.,  4.,  5.,  4.,  4.,  4.,  4.,  4.,  5.,  4.,  7.,  4.,
        4.,  5.,  6.,  6.,  7.,  6.,  5.,  5.,  7.,  4.,  7.,  4.,  6.,
        5.,  5.,  6.,  4.,  5.,  7.,  5.,  9.,  6., 10.,  8.,  4.,  5.,
        5.,  6.,  9.,  4.,  7.,  9., 15.,  4.,  4.,  4.,  5.,  6.,  4.,
        5.,  4.,  4., 12.,  7.,  6.,  4.,  5.,  5.,  5.,  5.,  6., 11.,
        5.,  5.,  4.,  6.,  5.,  6.,  6.,  5.,  4.,  5.,  4.,  6.,  8.,
        6.,  6.,  6.,  6.,  4.,  8.,  7.,  4.,  4.,  4.,  9.,  5.,  7.,
        4.,  5.,  4.,  4.,  4.,  7.,  8.,  4.,  4.,  8.,  5.,  5.,  4.,
        7.,  4.,  6.,  6.,  4.,  4.,  6.,  4.,  5.,  4.,  4.,  8

In [7]:
# Flag embedding columns whose mean over all nodes is exactly 0.0 (exact float comparison).
np.mean(person_graph.embedding_matrix, axis=0) == 0

array([ True, False, False, False, False, False,  True,  True, False,
       False, False,  True, False, False,  True, False,  True,  True,
       False, False, False, False, False,  True, False, False, False,
        True, False,  True, False, False,  True, False, False, False,
       False, False,  True, False,  True,  True, False, False, False,
       False, False, False, False,  True, False, False,  True, False,
       False,  True, False, False, False, False, False, False, False,
       False])

In [8]:
# (number of nodes, embedding dimension)
np.shape(person_graph.embedding_matrix)

(437, 64)