In [None]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch_geometric
import pickle
import optuna
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error
from torch_geometric.nn import GCNConv
import networkx as nx
import warnings
warnings.filterwarnings("ignore")

In [None]:
class GCN(torch.nn.Module):
    """Two graph-convolution layers followed by a two-layer fully connected head.

    Args:
        size1: Output width of the first ``GCNConv`` layer.
        size2: Output width of the second ``GCNConv`` layer (also the width of
            the embedding returned by ``forward``).
        size3: Width of the hidden fully connected layer.
        in_channels: Number of input features per node. Defaults to 1969, the
            value that was previously hard-coded, so existing three-argument
            calls build exactly the same architecture.
    """

    def __init__(self, size1, size2, size3, in_channels=1969):
        super().__init__()

        self.conv1 = GCNConv(in_channels, size1)
        self.conv2 = GCNConv(size1, size2)
        self.fc1 = torch.nn.Linear(size2, size3)
        self.fc2 = torch.nn.Linear(size3, 1)

    def forward(self, data):
        """Run the network on a graph object exposing ``x`` and ``edge_index``.

        Returns:
            A tuple ``(x, x1)`` where ``x`` is the output of the final linear
            layer (last dimension of size 1) and ``x1`` is the ReLU-activated
            output of the second graph convolution.
        """
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        # Dropout (default p=0.5) is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x1 = torch.relu(x)
        x = self.fc1(x1)
        x = torch.relu(x)
        x = self.fc2(x)
        return x, x1

def train(model, data, optimizer, mask):
    """Perform one optimisation step on the nodes selected by ``mask``.

    The model is switched to training mode, gradients are cleared, the mean
    squared error between the flattened predictions and ``data.y`` on the
    masked nodes is back-propagated, and the optimizer is stepped once.

    Returns:
        The loss value (Python float) computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    # The model returns (predictions, embedding); only predictions are needed here.
    predictions, _embedding = model(data)
    selected_pred = predictions[mask].view(-1)
    selected_target = data.y[mask]

    mse = F.mse_loss(selected_pred, selected_target)
    mse.backward()
    optimizer.step()
    return mse.item()

def evaluate(model, data, mask):
    """Compute the mean squared error on the nodes selected by ``mask``.

    The model is put in evaluation mode and the forward pass runs without
    gradient tracking; no parameters are modified.

    Returns:
        The loss value as a Python float.
    """
    model.eval()
    with torch.no_grad():
        predictions, _embedding = model(data)
        selected_pred = predictions[mask].view(-1)
        selected_target = data.y[mask]
        mse_value = F.mse_loss(selected_pred, selected_target).item()
    return mse_value

def objective(trial: optuna.Trial):
    """Optuna objective: train a ``GCN`` and return its final validation loss.

    Layer widths, the number of epochs and the learning rate are sampled from
    ``trial``. The model is trained on ``data.train_mask`` and scored on
    ``data.valid_mask`` after every epoch; the per-epoch validation loss is
    reported so the study's pruner can stop unpromising trials early.

    Note:
        This function reads the module-level variable ``data``; it must be
        defined (and already moved to the target device) before
        ``study.optimize`` is called.

    Returns:
        The validation loss after the last epoch.

    Raises:
        optuna.exceptions.TrialPruned: If the pruner decides to stop the trial.
    """
    size1 = trial.suggest_int("size1", 32, 512)
    size2 = trial.suggest_int("size2", 4, 512)
    size3 = trial.suggest_int("size3", 4, 512)
    num_epochs = trial.suggest_int("num_epochs", 1, 200)
    # The range spans three orders of magnitude. Uniform sampling would put
    # ~90% of the draws in [1e-3, 1e-2], so sample on a log scale instead.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCN(size1, size2, size3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # find best hyperparameter by training on the training set and get loss on the validating set
    for epoch in range(1, num_epochs + 1):
        train(model, data, optimizer, data.train_mask)
        val_loss = evaluate(model, data, data.valid_mask)
        trial.report(val_loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    # num_epochs is at least 1, so val_loss is always bound here.
    return val_loss

In [None]:
# The gene ordering file depends on neither the network nor the split, so it is
# read once here instead of once per (network, split) iteration.
valid_genes = pd.read_csv("../../result/network_perturb_go/valid_genes", sep="\t")
gene_order = valid_genes['genes'].values

# The device is the same for every model built below; resolve it once.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for network_name in ["DAGMA_thresholdAdaptive", "NOTEARS_thresholdAdaptive", "Combine", "STRING", "BIOGRID", "ChIP_hTFtarget", "ChIP_TIP", "ChIP_TIP_K562", "CoExpr_ENCODE_K562_0.75", "CoExpr_GTEx_WholeBlood_0.75", "CoExpr_perturb_0.5", "Random_ER", "Random_SF"]:
    # Result accumulators; they are not read in the visible part of this cell.
    loss_list = []
    p_list = []
    loss_net_list = []
    p_net_list = []

    for rs in range(1, 6):

        # read data
        split_dir = "../../result/input_perturb_go/stratified_%d" % rs
        X_train = pd.read_csv("%s/X_train" % split_dir, sep="\t", index_col=0)
        X_valid = pd.read_csv("%s/X_valid" % split_dir, sep="\t", index_col=0)
        X_test = pd.read_csv("%s/X_test" % split_dir, sep="\t", index_col=0)
        Y_train = pd.read_csv("%s/Y_train" % split_dir, sep="\t", index_col=0)
        Y_valid = pd.read_csv("%s/Y_valid" % split_dir, sep="\t", index_col=0)
        Y_test = pd.read_csv("%s/Y_test" % split_dir, sep="\t", index_col=0)
        X = pd.concat([X_train, X_valid, X_test])
        Y = pd.concat([Y_train, Y_valid, Y_test])

        # split masks: rows of X are laid out as [train | valid | test], so label
        # each row with its split id and derive one boolean mask per split
        # (this stays boolean even if a split happens to be empty)
        split_id = np.repeat([0, 1, 2], [len(X_train), len(X_valid), len(X_test)])
        train_mask = split_id == 0
        valid_mask = split_id == 1
        test_mask = split_id == 2
        mask = pd.DataFrame({0: train_mask, 1: valid_mask, 2: test_mask}, index=X.index)

        # re-order data to match the network index
        X = X.loc[gene_order]
        Y = Y.loc[gene_order]
        mask = mask.loc[gene_order]

        # read network (two-column edge list, transposed to shape (2, num_edges))
        g_all = pd.read_csv("../../result/network_perturb_go/%s.tsv" % network_name, sep="\t", header=None)
        edge_index = torch.tensor(g_all.values.T)

        # convert to pyg data object; the first column of Y is the regression target
        data = torch_geometric.data.Data(x=torch.tensor(X.values).float(), y=torch.tensor(Y.iloc[:,0].values).float(), edge_index=edge_index)
        data.train_mask = torch.tensor(mask[0].values)
        data.valid_mask = torch.tensor(mask[1].values)
        data.test_mask = torch.tensor(mask[2].values)
        # NOTE(review): `dag_mask` is not defined in any visible cell and
        # `test_net_mask` is not read in the visible code. On a fresh kernel this
        # line raises NameError unless `dag_mask` is defined elsewhere; confirm
        # where it is supposed to come from.
        data.test_net_mask = torch.tensor(mask[2].values & dag_mask)
        # `objective` reads this module-level `data`, so it must live on `device`.
        data = data.to(device)

        # repeat the whole process multiple times for mean and std of the loss
        os.makedirs("../../result/model_perturb_go_GCN/%s/%d/" % (network_name, rs), exist_ok=True)
        for rep in range(10):
            # tune hyperparameters
            study = optuna.create_study(direction="minimize")
            study.optimize(objective, n_trials=20, timeout=600)
            best_params = study.best_params

            # final model with best hyperparameters
            final_model = GCN(best_params["size1"], best_params["size2"], best_params["size3"]).to(device)
            optimizer = torch.optim.Adam(final_model.parameters(), lr=best_params["lr"])

            # train final model with both train and valid data
            final_mask = data.train_mask | data.valid_mask
            for epoch in range(1, best_params["num_epochs"] + 1):
                train_loss = train(final_model, data, optimizer, final_mask)

            # prediction from final model (for all nodes, not only the test split)
            final_model.eval()
            with torch.no_grad():
                out, X1 = final_model(data)

            # save model, hyperparameters and results
            prefix = "../../result/model_perturb_go_GCN/%s/%d/model%d" % (network_name, rs, rep)
            torch.save(final_model.state_dict(), prefix+".model")
            with open(prefix+".para", "wb") as f:
                # stored as [best hyperparameters, second model output, first model output]
                pickle.dump([best_params, X1.cpu().numpy(), out.cpu().numpy()], f)