In [1]:
import numpy as np
import pandas as pd
import random
import os.path as osp
import networkx as nx
import pickle
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, from_networkx, train_test_split_edges
import os
import time
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import radius_graph, knnnn_graph
from torch_geometric.nn import GINConv, JumpingKnowledge, GCNConv, Sequential, SAGEConv, GATConv
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from torch_geometric.loader import DataLoader
def pkl_save(dataset_path, data):
    start = time.perf_counter()
    with open(dataset_path, 'wb') as file:
        pickle.dump(data, file)
    end = time.perf_counter()
    print(f"Data save {(end-start):.4f}s")
def pkl_load(dataset_path):
    start = time.perf_counter()
    with open(dataset_path, 'rb') as f:
        dat = pickle.load(f)
    end = time.perf_counter()
    print(f"Data loading {(end-start):.4f}s")
    return dat
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

In [6]:
def read_node_feas(gene_dict):
    proteins = pkl_load('./ath_groupaa_2mer.pkl')
    cds = pkl_load('./ath_cds_2mer.pkl')
    node_feas = []
    gene_exist = []
    for name in gene_dict.keys():
        if name in proteins.keys() and name in cds.keys():
            fea1 = proteins[name]
            fea2 = cds[name]
            fea = np.array([*fea1, *fea2])
            node_feas.append(fea)
            gene_exist.append(name)
    gene_dict = {}
    for i, g in enumerate(gene_exist):
        gene_dict[g] = i
    return np.array(node_feas), gene_dict

In [7]:

def write_edgelist(edgeinfo, gene_dict, edgelist_path='./edgelist'):
    postive_samples = 0
    with open(edgelist_path, 'w') as f: 
        for edge in edgeinfo:
            if edge[-1] == 1:
                postive_samples += 1
                f.write(str(gene_dict[edge[0]]))
                f.write(' ')
                f.write(str(gene_dict[edge[1]]))
                f.write('\n')
    print(f"Total {len(edgeinfo)} edges. Positive: {postive_samples} edges. Negative: {len(edgeinfo) - postive_samples} edges.")

def get_edgeinfo(file="./balanced_learning_matrix.csv"):
    genelist = []
    edgelist = []
    edgeinfo = []
    with open(file, 'r') as f:
        for line in f:
            if not line.startswith("Interaction"):
                tmp = line.strip('\n').strip(' ').split('\t')
                if tmp[0].split('_')[0].startswith('AT') and tmp[0].split('_')[1].startswith('AT'):
                    genes = tmp[0].split('_')
                    genelist.extend(genes)
                    edgelist.append(genes)
                    others = [float(i) for i in tmp[1:9]]
                    edgeinfo.append([genes[0], genes[1], *others, float(tmp[-1])])
    unique_genelist = list(set(genelist))
    gene_dict = {}
    for i, g in enumerate(unique_genelist):
        gene_dict[g] = i
    write_edgelist(edgeinfo, gene_dict)
    return gene_dict, edgelist, edgeinfo
def construct_graph(gene_dict, edgeinfo):
    node_feas, gene_dict = read_node_feas(gene_dict)
    print(f"Unique genelist {len(gene_dict)}, i.e. number of nodes in Graph.")
    G = nx.Graph()
    for i, node in enumerate(gene_dict.values()):
        G.add_node(node)
    for i, e in enumerate(edgeinfo): 
        if e[0] in gene_dict and e[1] in gene_dict:
            if edgeinfo[i][-1] == 1:
                G.add_edge(gene_dict[e[0]], gene_dict[e[1]], edge_attr = edgeinfo[i][2:10])
    for i, node in enumerate(G.nodes):
        nx.set_node_attributes(G, node_feas[i], "x")
    G.remove_nodes_from(list(nx.isolates(G)))
    return G

In [8]:
file = "./balanced_learning_matrix.csv"
gene_dict, edgelist, edgeinfo = get_edgeinfo(file)
G = construct_graph(gene_dict, edgeinfo)
print(G)

Total 22856 edges. Positive: 5722 edges. Negative: 17134 edges.
Data loading 0.0317s
Data loading 0.2612s
Unique genelist 3799, i.e. number of nodes in Graph.
Graph with 3793 nodes and 5450 edges


In [9]:
def pos_sampler(link, hop, x, edge_index, edge_attr):
    subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(link, hop, edge_index, relabel_nodes=True)
    
#     subset, sub_edge_index_, mapping, edge_mask = k_hop_subgraph(link, hop, edge_index, relabel_nodes=False)
#     print(subset.shape, sub_edge_index.shape, mapping, edge_mask.shape)
    assert sub_edge_index[0].unique().shape==subset.shape
    sub_x = x[subset]
    sub_edge_attr = edge_attr[edge_mask]
    sub_y = 1
#     print(torch.max(sub_edge_index),sub_x.shape[0], torch.max(sub_edge_index_) < sub_x.shape[0])
    assert torch.max(sub_edge_index) < sub_x.shape[0]
    tmp = Data(x=sub_x, edge_index=sub_edge_index, edge_attr=sub_edge_attr,y=sub_y)
#     print(torch.max(sub_edge_index), sub_x.shape)
#     print(tmp, link)
    return tmp

def neg_sampler(src, hop, x, edge_index, edge_attr, all_nodes, neg_ratio):
    neglist = []
    neighbor_set, neighbor_edge_index, _, __ = k_hop_subgraph(src, hop, edge_index, relabel_nodes=True)
    neighbor_set = neighbor_set.tolist()
    sample_nodes = list(set(all_nodes)-set(neighbor_set))
    sample_nodes = list(set(sample_nodes)-set([3983]))
    for i in range(0, neg_ratio):
        tar  = random.choice(sample_nodes)
        link = [src, tar]
        if tar >= 3793:
            continue
#         print(link)
        subset, sub_edge_index, mapping, edge_mask = k_hop_subgraph(link, hop, edge_index, relabel_nodes=True)
#         subset, sub_edge_index_, mapping, edge_mask = k_hop_subgraph(link, hop, edge_index, relabel_nodes=False)
#         if sub_edge_index[0].unique().shape!=subset.shape:
#             print(link, sub_edge_index[0].unique().shape, subset.shape)
        assert sub_edge_index[0].unique().shape==subset.shape
        
        sub_x = x[subset]
        sub_edge_attr = edge_attr[edge_mask]
        sub_y = 0
        assert torch.max(sub_edge_index) < sub_x.shape[0]
        tmp = Data(x=sub_x, edge_index=sub_edge_index, edge_attr=sub_edge_attr,y=sub_y)
#         print(tmp, link)
        neglist.append(tmp)
    return neglist

In [10]:
main_data = from_networkx(G)
edge_index=main_data.edge_index
edge_attr=main_data.edge_attr
x=main_data.x
all_nodes = edge_index[0].unique()
print(edge_attr.shape, edge_index.shape, x.shape, len(all_nodes))

torch.Size([10868, 8]) torch.Size([2, 10868]) torch.Size([3793, 113]) 3793


  data[key] = torch.tensor(value)


In [11]:
datalist = []
hop = 2
neg_ratio = 3
# all_nodes = edge_index[0].unique().tolist()
all_nodes = list(G.nodes())
n = 0
for src,tar in tqdm(G.edges()):
# for src,tar in G.edges():
    if tar >= 3793:
        continue
    link = [src, tar]
    pos = pos_sampler(link, hop, x, edge_index, edge_attr)
    neg = neg_sampler(src, hop, x, edge_index, edge_attr, all_nodes, neg_ratio)
    datalist.extend([pos, *neg])
    n+=1
print(len(datalist))

100%|██████████████████████████████████████████████████████████████████████████████| 5450/5450 [00:48<00:00, 113.32it/s]

21706





In [12]:
dataset_path = './pos_neg_link_datalist.pkl'
pkl_save(dataset_path, datalist)

Data save 80.0463s


In [None]:
class GNN(torch.nn.Module):

    def __init__(self, in_fea, hidden_channels, num_layers, dropout, conv_type, out_channels=20):
        super(GNN, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.vns = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                if conv_type=='gcn':
                    conv = GraphConv(in_fea, hidden)
                elif conv_type=='gin':
                    conv = GINConv(nn.Linear(in_fea, hidden_channels))
                bn = torch.nn.BatchNorm1d(hidden_channels)
                vn = VirtualNode(in_fea, hidden_channels, dropout=dropout)
            else:
                if conv_type=='gcn':
                    conv = GraphConv(in_channels=hidden_channels, out_channels=hidden_channels)
                elif conv_type=='gin':
                    conv = GINConv(nn.Linear(hidden_channels, hidden_channels))
                bn = torch.nn.BatchNorm1d(hidden_channels)
                vn = VirtualNode(hidden_channels, hidden_channels, dropout=dropout)
            self.vns.append(vn)
            self.convs.append(conv)
            self.batch_norms.append(bn)
        self.pool = TopKPooling(hidden_channels, 1e-4)
        self.mlp = nn.Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        # if self.mol:
        #     for emb in self.node_encoder.atom_embedding_list:
        #         nn.init.xavier_uniform_(emb.weight.data)
        # else:
        #     nn.init.xavier_uniform_(self.node_encoder.weight.data)

        for i in range(self.num_layers):
            self.convs[i].reset_parameters()
            self.bns[i].reset_parameters()
            self.vns[i].reset_parameters()
        self.pool.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        for i in range(self.num_layers):
            x, vx = self.vns[i].update_node_emb(x, edge_index, batch)
            x = self.convs[i](x, edge_index)
            x = self.batch_norms[i](x)
            x = F.dropout(F.relu(x), p=self.dropout)
        x, edge_index, edge_attr, batch, perm, select_output_weight = self.pool(x, edge_index, batch=batch)
        x = self.mlp(x)
        return x