In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch_geometric
import copy
from sklearn.metrics import mean_squared_error

### Data

In [None]:
# Basename of the edge-list TSV read from ../result/network/<network_name>.tsv.
network_name: str = "DAGMA_DAG"

In [None]:
def _read_tsv(path):
    """Read a header-less, tab-separated file and return its contents as a NumPy array."""
    return pd.read_csv(path, sep="\t", header=None).values


# Feature matrices: one row per node, stacked below into data.x.
X_train = _read_tsv("../result/data/X_train")
X_valid = _read_tsv("../result/data/X_valid")
X_test = _read_tsv("../result/data/X_test")

# Targets are flattened to 1-D and multiplied by 1000.
# NOTE(review): the reason for the x1000 scaling is not stated in this notebook
# (presumably to keep MSE values away from zero) - it must be undone before
# interpreting predictions in the original units.
Y_train = _read_tsv("../result/data/Y_train").reshape(-1) * 1000
Y_valid = _read_tsv("../result/data/Y_valid").reshape(-1) * 1000
Y_test = _read_tsv("../result/data/Y_test").reshape(-1) * 1000

# Each split must have one target per feature row, otherwise X and Y misalign after stacking.
for split_name, X_split, Y_split in (
    ("train", X_train, Y_train),
    ("valid", X_valid, Y_valid),
    ("test", X_test, Y_test),
):
    assert len(X_split) == len(Y_split), "%s split: %d feature rows vs %d targets" % (
        split_name, len(X_split), len(Y_split))

# Nodes are stacked in train -> valid -> test order; the masks below rely on this order.
X = np.concatenate([X_train, X_valid, X_test])
Y = np.concatenate([Y_train, Y_valid, Y_test])

# Boolean node masks. np.repeat keeps dtype bool even when a split is empty
# (concatenating an empty Python list would silently promote the mask to float64).
split_sizes = [len(X_train), len(X_valid), len(X_test)]
train_mask = np.repeat([True, False, False], split_sizes)
valid_mask = np.repeat([False, True, False], split_sizes)
test_mask = np.repeat([False, False, True], split_sizes)

# Edge list: each TSV row is assumed to be one (source, target) pair of node indices
# into the stacked X rows - TODO confirm. Transposed to PyG's [2, num_edges] layout.
edge_index = torch.tensor(_read_tsv("../result/network/%s.tsv" % network_name).T)
assert edge_index.dim() == 2 and edge_index.size(0) == 2, \
    "expected a two-column edge list, got edge_index shape %s" % (tuple(edge_index.shape),)

data = torch_geometric.data.Data(x=torch.tensor(X).float(), edge_index=edge_index, y=torch.tensor(Y).float())
data.train_mask = torch.tensor(train_mask)
data.valid_mask = torch.tensor(valid_mask)
data.test_mask = torch.tensor(test_mask)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

In [None]:
from torch_geometric.nn import SGConv

class GCN(torch.nn.Module):
    """SGConv encoder followed by a two-layer tanh MLP regression head.

    Args:
        GCN_size1: output width of the SGConv layer (also the embedding width).
        GCN_size2: hidden width of the first fully connected layer.
        in_channels: number of input node features; defaults to 1969, the value
            previously hard-coded here.
        K: number of propagation hops used by SGConv; defaults to 2 as before.

    ``forward(data)`` returns ``(prediction, embedding)``:
        prediction: output of the final Linear(GCN_size2, 1) layer, one value per node.
        embedding: the raw SGConv output (before tanh), width GCN_size1.
    """

    def __init__(self, GCN_size1, GCN_size2, in_channels=1969, K=2):
        super().__init__()
        self.GCN_size1 = GCN_size1
        self.GCN_size2 = GCN_size2
        self.in_channels = in_channels

        self.conv1 = SGConv(in_channels, self.GCN_size1, K=K)
        self.fc1 = torch.nn.Linear(self.GCN_size1, self.GCN_size2)
        self.fc2 = torch.nn.Linear(self.GCN_size2, 1)

    def forward(self, data):
        """Run the model on a PyG ``Data`` object exposing ``x`` and ``edge_index``."""
        x, edge_index = data.x, data.edge_index

        # Keep the pre-activation graph embedding; it is returned alongside the prediction.
        x1 = self.conv1(x, edge_index)
        x = torch.tanh(x1)
        x = self.fc1(x)
        x = torch.tanh(x)
        x = self.fc2(x)
        return x, x1

In [None]:
torch_geometric.seed.seed_everything(100)

# Hidden sizes (GCN_size1, GCN_size2) taken from earlier hyper-parameter runs
# (those runs are not part of this notebook).
best_para = (256, 128)

n_epochs = 200  # hard upper bound on the number of training epochs
patience = 5    # stop once the loss has failed to improve for more than `patience` epochs

model = GCN(best_para[0], best_para[1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)

# NOTE(review): the loss is computed on the union of the train, valid and test masks,
# so this cell fits on every labelled node. It also early-stops on that same loss.
# Any later evaluation on `test_mask` is therefore NOT held-out.
fit_mask = data.train_mask | data.valid_mask | data.test_mask

model.train()

# Lowest loss seen so far (name kept for compatibility with later cells).
# Starts at +inf: targets are scaled x1000, so the first MSE can easily exceed any
# fixed sentinel such as 1000, which would make early stopping fire before any improvement.
es_test_loss = float("inf")
best_state = copy.deepcopy(model.state_dict())
stop_count = 0
train_loss_list = []  # per-epoch loss as Python floats

for epoch in range(n_epochs):
    optimizer.zero_grad()
    out, _ = model(data)
    loss = F.mse_loss(out[fit_mask].view(-1), data.y[fit_mask])

    # Store a float: keeping the tensor would also keep its autograd graph / device memory alive.
    loss_value = loss.item()
    train_loss_list.append(loss_value)

    if loss_value < es_test_loss:
        es_test_loss = loss_value
        # Snapshot BEFORE optimizer.step(): these are the weights that produced `loss`.
        best_state = copy.deepcopy(model.state_dict())
        stop_count = 0
    else:
        stop_count += 1

    if stop_count > patience:
        break

    loss.backward()
    optimizer.step()

# Roll back to the weights with the lowest loss, so the model saved afterwards is the best one.
model.load_state_dict(best_state)

In [None]:
# Recompute outputs with the final weights; no gradients are needed for export.
model.eval()
with torch.no_grad():
    x, x1 = model(data)
X1 = x1.detach().cpu().numpy()  # SGConv embedding, one row per node
np.save("../result/X1.npy", X1)  # save emb

# The model directory may not exist yet; torch.save does not create it.
os.makedirs("../result/model", exist_ok=True)
torch.save(model.state_dict(), "../result/model/%s_best" % network_name)  # save model