In [1]:
import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing, GATConv
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros

class GAT(nn.Module):
    def __init__(self, in_feats=1,
                 h_feats=[8, 8, 1],
                 heads=[8, 8, 4],
                 dropout=0.6,
                 negative_slope=0.2,
                 linear_layer=None,
                 **kwargs):
        super(GAT, self).__init__()
        self.dropout = dropout
        self.layers = nn.ModuleList()

        self.linear_layer = linear_layer
        if self.linear_layer is not None:
            print('Applying linear')
            self.linear = nn.Linear(in_feats, linear_layer)

        in_feats = in_feats if linear_layer is None else linear_layer
        for i, h_feat in enumerate(h_feats):
            last = i + 1 == len(h_feats)
            self.layers.append(GATConv(in_feats, h_feat,
                                       heads=heads[i],
                                       dropout=dropout,
                                       concat=False if last else True))
            in_feats = h_feat * heads[i]

    def forward(self, X, A, edge_attr=None, return_alphas=False):
        if self.linear_layer is not None:
            X = self.linear(X)
            #X = F.relu(X)

        alphas = []
        for layer in self.layers[:-1]:
            if return_alphas:
                X, alpha, _ = layer(
                    X, A, edge_attr=edge_attr, return_alpha=True)
                alphas.append(alpha)
            else:
                X = layer(X, A, edge_attr=edge_attr)
            X = F.relu(X)
            X = F.dropout(X, self.dropout)

        if return_alphas:
            X, alpha, edge_index = self.layers[-1](
                X, A, edge_attr=edge_attr, return_alpha=True)
            alphas.append(alpha)
            return X, alphas, edge_index

        X = self.layers[-1](X, A, edge_attr=edge_attr)
        return X

In [14]:
import os
import random
import torch
import numpy as np
import pandas as pd
import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn import metrics
import scipy.sparse as sparse
from os.path import join
import pickle
from tqdm import tqdm

def data(label_path, ppi_path, expr_path=None, ortho_path=None, subloc_path=None, string_thr=0, seed=42, weights=False):

    print(f'PPI: {os.path.basename(ppi_path)}.')

    edge_weights = None

    edges = pd.read_csv(ppi_path)
    key = 'combined_score'
    edges = edges[edges.loc[:, key] > string_thr].reset_index()
    edge_weights = edges['combined_score'] # / 1000
    edges = edges[['A', 'B']]
    print('Filtered String network with thresh:', string_thr)

    edges = edges.dropna()
    index, edges = edges.index, edges.values
    ppi_genes = np.union1d(edges[:, 0], edges[:, 1])
    if edge_weights is not None:
        edge_weights = edge_weights.iloc[index.values].values

    labels = pd.read_csv(label_path, index_col=0)

    # filter labels not in the PPI network
    print('Number of labels before filtering:', len(labels))
    labels = labels.loc[np.intersect1d(labels.index, ppi_genes)].copy()
    print('Number of labels after filtering:', len(labels))

    genes = np.union1d(labels.index, ppi_genes)
    print('Total number of genes:', len(genes))

    X = np.zeros((len(genes), 0))
    X = pd.DataFrame(X, index=genes)

    if ortho_path is not None:
        orths = pd.read_csv(ortho_path, index_col=0)
        columns = [f'ortholog_{i}' for i in range(orths.shape[1])]
        orths.columns = columns
        X = X.join(orths, how="left")
        print('Orthologs dataset shape:', orths.shape)

    if expr_path is not None:
        expression = pd.read_csv(expr_path, index_col=0)
        columns = [f'expression_{i}' for i in range(expression.shape[1])]
        expression.columns = columns
        X = X.join(expression, how="left")
        print('Gene expression dataset shape:', expression.shape)

    if subloc_path is not None:
        subloc = pd.read_csv(subloc_path, index_col=0)
        columns = [f'subloc_{i}' for i in range(subloc.shape[1])]
        subloc.columns = columns
        X = X.join(subloc, how="left")
        print('Subcellular Localizations dataset shape:', subloc.shape)

    X = X.fillna(0)

    train, test = train_test_split(
        labels, test_size=0.2, random_state=seed, stratify=labels)

    print(f'Num nodes {len(genes)} ; num edges {len(edges)}')
    print(f'X.shape: {None if X is None else X.shape}.')
    print(f'Train labels. Num: {len(train)} ; Num pos: {train.label.sum()}')
    print(f'Test labels. Num: {len(test)} ; Num pos: {test.label.sum()}')
    print(X.tail())
    return (edges, edge_weights), X, train, test, genes

In [15]:
path = '../../data'
ipath = './data'
data(os.path.join(path, 'Kidney_HELP.csv'), 
     os.path.join(path, 'Kidney_PPI.csv'),
     os.path.join(ipath, 'GTEX_expr_kidney.csv'),
     os.path.join(ipath, 'Orthologs_kidney.csv'),
     os.path.join(ipath, 'Sublocs_kidney.csv'))

PPI: Kidney_PPI.csv.
Filtered String network with thresh: 0
Number of labels before filtering: 17931
Number of labels after filtering: 17236
Total number of genes: 19314
Orthologs dataset shape: (18541, 162)
Gene expression dataset shape: (18609, 89)
Subcellular Localizations dataset shape: (12064, 11)
Num nodes 19314 ; num edges 1110251
X.shape: (19315, 262).
Train labels. Num: 13788 ; Num pos: sNEaEsNEsNEsNEsNEsNEsNEaEsNEsNEsNEsNEsNEsNEsNEsNEsNEaEaEsNEaEsNEsNEEsNEEsNEsNEEsNEsNEEsNEsNEsNEaEsNEsNEsNEsNEsNEsNEsNEaEsNEsNEsNEaEaEaEsNEaEsNEsNEsNEsNEsNEaEaEsNEsNEaEsNEsNEsNEsNEsNEaEEsNEsNEsNEEEsNEsNEsNEaEsNEsNEsNEaEsNEEsNEsNEsNEaEsNEsNEsNEEsNEsNEsNEsNEsNEsNEaEsNEEsNEsNEsNEsNEsNEsNEsNEsNEsNEaEsNEsNEsNEaEaEsNEsNEsNEsNEsNEsNEsNEsNEsNEsNEaEEsNEaEsNEsNEsNEaEsNEsNEsNEsNEsNEaEsNEsNEaEEsNEsNEsNEaEsNEaEsNEsNEsNEsNEsNEaEsNEaEaEsNEaEsNEsNEsNEsNEsNEsNEsNEsNEEsNEaEaEaEsNEaEsNEsNEsNEsNEsNEsNEEsNEsNEsNEsNEsNEsNEsNEEsNEsNEsNEsNEsNEEsNEaEaEsNEsNEaEsNEsNEsNEaEsNEsNEsNEsNEsNEaEsNEsNEaEaEsNEsNEsNEsNEsNEaEsNEsNE

((array([['MAP2K4', 'FLNC'],
         ['FNTA', 'ACVR1'],
         ['GATA2', 'PML'],
         ...,
         ['FBLN1', 'COL18A1'],
         ['SNAI1', 'LOXL3'],
         ['FBLN1', 'LAMA2']], dtype=object),
  array([2, 1, 1, ..., 1, 1, 1])),
                                                 ortholog_0  ortholog_1  \
 (clone tec14)                                          0.0         0.0   
 100 kDa coactivator                                    0.0         0.0   
 14-3-3 tau splice variant                              0.0         0.0   
 3'-phosphoadenosine-5'-phosphosulfate synthase         0.0         0.0   
 3-beta-hydroxysteroid dehydrogenase                    0.0         0.0   
 ...                                                    ...         ...   
 pp10122                                                0.0         0.0   
 tRNA-uridine aminocarboxypropyltransferase             0.0         0.0   
 tmp_locus_54                                           0.0         0.0   
 urf-ret    