In [1]:
import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing, GATConv
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Label identifying this model family; main() hands it to save_preds().
modelname = 'GAT'

# Hyper-parameters used by main() for training.
# train() reads 'lr' and 'weight_decay' for the Adam optimizer and splats the
# whole dict into the GAT constructor.
gat_human = dict(
    lr=0.005,
    weight_decay=5e-4,
    h_feats=[8, 1],
    heads=[8, 1],
    dropout=0.4,
    negative_slope=0.2,
)

class Loss():
    """Class-balanced binary cross-entropy over a fixed subset of nodes.

    The positive and negative samples are scored separately with
    ``F.binary_cross_entropy_with_logits`` (mean reduction) and the two
    terms are added, so each class contributes equally regardless of how
    many samples it has.

    Args:
        y: 1-D tensor of 0/1 labels, one per entry of ``idx``.
        idx: node indices (same length and order as ``y``) used to pick the
            matching rows out of the model output.
    """

    def __init__(self, y, idx):
        self.y = y
        idx = np.array(idx)
        # Build the masks as NumPy arrays so that indexing the NumPy index
        # array does not rely on implicit tensor -> array conversion.
        labels = y.detach().cpu().numpy()

        # Targets per class (kept on the same device as `y`).
        self.y_pos = y[y == 1]
        self.y_neg = y[y == 0]

        # Node indices per class, used to select rows of the model output.
        self.pos = idx[labels == 1]
        self.neg = idx[labels == 0]

    def __call__(self, out):
        """Return the summed per-class BCE-with-logits for the logits ``out``.

        ``out`` is indexed by node id along its first dimension; the selected
        rows are flattened before being compared with the targets.
        """
        # reshape(-1) instead of squeeze(): squeeze() turns a single selected
        # row into a 0-d tensor, which no longer matches the (1,)-shaped
        # target and makes the loss raise when a class has exactly one sample.
        loss_p = F.binary_cross_entropy_with_logits(
            out[self.pos].reshape(-1), self.y_pos)
        loss_n = F.binary_cross_entropy_with_logits(
            out[self.neg].reshape(-1), self.y_neg)
        return loss_p + loss_n

def train(params, X, A, edge_weights, train_y, train_idx, val_y, val_idx, save_best_only=True, savepath='', epochs=1000):
    """Train a GAT full-batch on the graph and optionally save a snapshot.

    Args:
        params: hyper-parameter dict. ``'lr'`` and ``'weight_decay'`` configure
            Adam; the whole dict is also splatted into the ``GAT`` constructor.
        X: node feature tensor; ``X.shape[1]`` is used as the input width.
        A: graph connectivity tensor passed to the model as its second argument.
        edge_weights: optional tensor passed to the model as ``edge_attr``.
        train_y, train_idx: training labels and the matching node indices.
        val_y, val_idx: validation labels and the matching node indices.
        save_best_only: accepted for compatibility but currently unused; the
            snapshot always holds the weights after the final epoch.
        savepath: if non-empty, the snapshot dict is written there with
            ``torch.save``.
        epochs: number of full-batch optimisation steps (default 1000, the
            value that used to be hard-coded).

    Returns:
        The trained model (left on ``DEVICE``).
    """
    model = GAT(in_feats=X.shape[1], **params)
    model.to(DEVICE)
    X = X.to(DEVICE)
    A = A.to(DEVICE)
    train_y = train_y.to(DEVICE)
    val_y = val_y.to(DEVICE)
    if edge_weights is not None:
        edge_weights = edge_weights.to(DEVICE)

    optimizer = optim.Adam(
        model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
    loss_fnc = Loss(train_y, train_idx)
    val_loss_fnc = Loss(val_y, val_idx)

    val_auc = None  # only stays None when epochs == 0
    progress = tqdm(range(epochs))
    for _ in progress:
        model.train()
        logits = model(X, A, edge_attr=edge_weights)

        optimizer.zero_grad()
        loss = loss_fnc(logits)
        loss.backward()
        optimizer.step()

        # NOTE(review): the progress-bar validation loss/AUC reuse the logits
        # of the training-mode forward pass above (model.train() is active),
        # so they are not computed in eval mode.
        logits = logits.detach()
        val_loss = val_loss_fnc(logits)
        # evalAUC is defined outside this file section; it is called here with
        # placeholder model/graph arguments and the logits as the last argument.
        train_auc = evalAUC(None, 0, 0, train_y, 0, logits[train_idx])
        val_auc = evalAUC(None, 0, 0, val_y, 0, logits[val_idx])

        progress.set_description(
            'Loss: %.4f ; Val Loss %.4f ; Train AUC %.4f. Validation AUC: %.4f' % (
                loss, val_loss, train_auc, val_auc))

    # Score stored in the snapshot: evalAUC is given the model itself here
    # rather than precomputed logits.
    score = evalAUC(model, X, A, val_y, val_idx)
    print(f'Last validation AUC: {val_auc}')

    if savepath:
        save = {
            'auc': score,
            'model_params': params,
            'model_state_dict': model.state_dict()
        }
        torch.save(save, savepath)

    return model


def test(model, X, A, test_ds=None):
    """Run the model in eval mode and optionally score it on labelled nodes.

    Args:
        model: network called as ``model(X, A)``; moved to ``DEVICE``.
        X: node feature tensor.
        A: graph connectivity tensor.
        test_ds: optional ``(test_idx, test_y)`` pair. When given, metrics are
            computed on ``probs[test_idx]`` against ``test_y``.

    Returns:
        With ``test_ds``: ``(probs, auc, accuracy, balanced_accuracy, mcc)``.
        Without it: ``(probs, None, None, None)``.
        ``probs`` is a NumPy array of sigmoid outputs for every node.
    """
    model.to(DEVICE).eval()
    features = X.to(DEVICE)
    graph = A.to(DEVICE)

    with torch.no_grad():
        logits = model(features, graph)
    probs = torch.sigmoid(logits).cpu().numpy()

    if test_ds is None:
        # NOTE(review): this branch yields 4 values while the labelled branch
        # yields 5; kept as-is so existing callers that unpack 4 keep working.
        return probs, None, None, None

    test_idx, test_y = test_ds
    labels = test_y.cpu().numpy()
    node_probs = probs[test_idx]

    auc = metrics.roc_auc_score(labels, node_probs)
    # Threshold at 0.5 and convert the boolean mask to 0/1 integers.
    preds = (node_probs > 0.5) * 1
    score = metrics.accuracy_score(labels, preds)
    ba = metrics.balanced_accuracy_score(labels, preds)
    mcc = metrics.matthews_corrcoef(labels, preds)
    return probs, auc, score, ba, mcc

class GAT(nn.Module):
    """Stack of ``GATConv`` layers with an optional linear input projection.

    Args:
        in_feats: width of the input node features.
        h_feats: output width per head for each layer; its length sets the
            number of layers.
        heads: number of attention heads for each layer (indexed in parallel
            with ``h_feats``).
        dropout: probability passed to every ``GATConv`` and also applied to
            the hidden activations between layers.
        negative_slope: LeakyReLU slope forwarded to every ``GATConv``.
        linear_layer: if not ``None``, width of a linear projection applied to
            the input before the first attention layer.
        **kwargs: ignored. Allows a hyper-parameter dict with extra keys
            (e.g. ``'lr'``) to be splatted into the constructor.
    """

    def __init__(self, in_feats=1,
                 h_feats=[8, 8, 1],
                 heads=[8, 8, 4],
                 dropout=0.6,
                 negative_slope=0.2,
                 linear_layer=None,
                 **kwargs):
        super().__init__()
        self.dropout = dropout
        self.layers = nn.ModuleList()

        self.linear_layer = linear_layer
        if self.linear_layer is not None:
            print('Applying linear')
            self.linear = nn.Linear(in_feats, linear_layer)

        in_feats = in_feats if linear_layer is None else linear_layer
        # NOTE(review): the layers are built without `edge_dim`; per the PyG
        # GATConv docs edge features are only used when `edge_dim` is set, so
        # confirm that the `edge_attr` passed to forward() has any effect.
        for i, h_feat in enumerate(h_feats):
            last = i + 1 == len(h_feats)
            self.layers.append(GATConv(in_feats, h_feat,
                                       heads=heads[i],
                                       dropout=dropout,
                                       # Previously accepted but never forwarded.
                                       negative_slope=negative_slope,
                                       # Hidden layers concatenate their heads;
                                       # the last layer does not.
                                       concat=not last))
            # Input width of the next layer assumes concatenated heads.
            in_feats = h_feat * heads[i]

    def forward(self, X, A, edge_attr=None, return_alphas=False):
        """Compute the output of the last layer for every node.

        Args:
            X: node feature tensor.
            A: graph connectivity passed to each ``GATConv``.
            edge_attr: optional edge features passed to each ``GATConv``.
            return_alphas: if True, also return the attention weights.

        Returns:
            ``X`` alone, or ``(X, alphas, edge_index)`` when ``return_alphas``
            is set: ``alphas`` holds one attention tensor per layer and
            ``edge_index`` is the one reported by the last layer.
        """
        if self.linear_layer is not None:
            # No non-linearity is applied after the projection.
            X = self.linear(X)

        alphas = []
        edge_index = None
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if return_alphas:
                # PyG's GATConv exposes attention through
                # `return_attention_weights` and returns
                # (out, (edge_index, alpha)); the former `return_alpha`
                # keyword is not part of that API.
                X, (edge_index, alpha) = layer(
                    X, A, edge_attr=edge_attr, return_attention_weights=True)
                alphas.append(alpha)
            else:
                X = layer(X, A, edge_attr=edge_attr)

            if i != last:
                X = F.relu(X)
                # F.dropout defaults to training=True; tie it to the module
                # mode so that eval() really disables dropout at inference.
                X = F.dropout(X, self.dropout, training=self.training)

        if return_alphas:
            return X, alphas, edge_index
        return X

In [2]:
import os
import sys
from pprint import pprint
sys.path.append('.')

from utils import *
import pandas as pd
import numpy as np
from tqdm import tqdm
import optuna
from sklearn import metrics
import torch
import torch.nn as nn
import torch.optim as optim

def main(name, label_path, ppi_path=None,
         expr_path=None, ortho_path=None, subloc_path=None, no_ppi=False, 
         weights=False, seed=0, train_mode=False, savedir='.', predsavedir='.'):
    """Optionally train a GAT, then load its snapshot and evaluate on the test split.

    Args:
        name: base name of the snapshot file.
        label_path, ppi_path, expr_path, ortho_path, subloc_path: input paths
            forwarded to ``data``. The expression / orthology / sub-localisation
            paths also add ``_expr`` / ``_ortho`` / ``_subl`` to the snapshot
            name when they are not ``None``.
        no_ppi: forwarded to ``data``; when False ``_ppi`` is appended to the
            snapshot name.
        weights: forwarded to ``data``.
        seed: passed to ``set_seed`` and to ``save_preds``.
        train_mode: if True, train and save a model before loading it. If
            False, a snapshot must already exist at the computed path.
        savedir: directory holding the model snapshot.
        predsavedir: directory passed to ``save_preds``.

    Returns:
        ``(preds, auc, score, ba, mcc)`` where ``preds`` pairs each test gene
        with its predicted probability and the rest are the test metrics.
    """
    set_seed(seed)

    # The snapshot name encodes which feature sources were used.
    snapshot_name = f'{name}'
    snapshot_name += '_expr' if expr_path is not None else ''
    snapshot_name += '_ortho' if ortho_path is not None else ''
    snapshot_name += '_subl' if subloc_path is not None else ''
    snapshot_name += '_ppi' if not no_ppi else ''

    savepath = os.path.join(savedir, snapshot_name)

    # Getting the data ----------------------------------
    (edge_index, edge_weights), X, (train_idx, train_y), \
        (val_idx, val_y), (test_idx, test_y), genes = data(
            label_path, ppi_path, expr_path, ortho_path, subloc_path,
            no_ppi=no_ppi, weights=weights)
    print('Fetched data')

    # Train the model -----------------------------------
    if train_mode:
        print('\nTraining the model')
        gat_params = gat_human
        model = train(gat_params, X, edge_index, edge_weights,
                      train_y, train_idx, val_y, val_idx, savepath=savepath)
    # ---------------------------------------------------

    # Load trained model --------------------------------
    print(f'\nLoading the model from: {savepath}')
    # map_location lets a snapshot saved on a GPU be loaded on a CPU-only host.
    snapshot = torch.load(savepath, map_location=DEVICE)
    model = GAT(in_feats=X.shape[1], **snapshot['model_params'])
    model.load_state_dict(snapshot['model_state_dict'])
    print('Model loaded. Val AUC: {}'.format(snapshot['auc']))
    # ---------------------------------------------------

    # Test the model ------------------------------------
    # NOTE(review): training passes edge_weights to the model but evaluation
    # does not; confirm this is intended.
    preds, auc, score, ba, mcc = test(model, X, edge_index, (test_idx, test_y))
    preds = np.concatenate(
        [genes[test_idx].reshape((-1, 1)), preds[test_idx]], axis=1)
    save_preds(modelname, preds, predsavedir, snapshot, seed=seed)
    print('Test AUC:', auc)
    print('Test Accuracy:', score)
    print('Test BA:', ba)
    print('Test MCC:', mcc)
    # ---------------------------------------------------

    # The original ended with a bare expression, so the function returned None.
    return preds, auc, score, ba, mcc

# Kidney experiment: train and evaluate using labels and PPI from the shared
# data directory plus expression / orthology / sub-localisation tables from
# the local one.
path = "../../data"
ipath = "./data"
main(
    'kidney',
    label_path=os.path.join(path, 'Kidney_HELP_2.csv'),
    ppi_path=os.path.join(path, 'Kidney_PPI.csv'),
    expr_path=os.path.join(ipath, 'GTEX_expr_kidney.csv'),
    ortho_path=os.path.join(ipath, 'Orthologs_kidney.csv'),
    subloc_path=os.path.join(ipath, 'Sublocs_kidney.csv'),
    no_ppi=False,
    weights=True,
    train_mode=True,
)

  from .autonotebook import tqdm as notebook_tqdm


PPI: Kidney_PPI.csv.
Filtered String network with thresh: 0
