In [2]:
import scipy.sparse as sp
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig, ThresholdConfig

from HGGN import *
from utils import *

# Seed helper from the wildcard imports; presumably fixes the RNG state for
# reproducibility — confirm against its definition.
set_seed(10)
# NOTE(review): `device` is selected here but nothing later in this cell moves
# the model or the tensors onto it — confirm whether GPU execution is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fold index forwarded to load_data as k_index.
fold = 0
# Hyperparameters forwarded to the HGGN constructor (r, n_layers, n_features).
hid_r, n_layers, n_features = 128, 4, 512
# Learning rate handed to the Adam optimizer.
lr = 0.0005

# load_data returns ten objects; unpack them in the order it yields them.
(
    adj,
    interaction,
    rna_features,
    drug_features,
    inter_features_rna,
    inter_features_drug,
    train_pos_data,
    test_pos_data,
    train_neg_data,
    test_neg_data,
) = load_data(data_dir='./data/', k_index=fold)

# Number of extra "global" nodes: 10% of the original node count, truncated.
base_node_num = adj.shape[0]
global_node_num = int(base_node_num * 0.1)

# Grow the adjacency matrix by global_node_num rows and columns:
#   [ original adj | ones  ]   every original node is linked to every global node
#   [ ones         | zeros ]   global nodes are not linked to one another
adj = np.block([
    [adj, np.ones(shape=(base_node_num, global_node_num))],
    [np.ones(shape=(global_node_num, base_node_num)),
     np.zeros((global_node_num, global_node_num))],
])

def _binary_labels(n_pos, n_neg):
    """Return an (n_pos + n_neg, 1) float32 tensor: n_pos ones, then n_neg zeros."""
    column = np.concatenate([np.ones(n_pos), np.zeros(n_neg)]).reshape(-1, 1)
    return torch.tensor(column, dtype=torch.float32)


# Positive samples are stacked first and negatives second; the label columns
# are built in the same order so row i of *_data matches row i of *_data_label.
train_data = np.vstack([train_pos_data, train_neg_data])
train_data_label = _binary_labels(train_pos_data.shape[0], train_neg_data.shape[0])
test_data = np.vstack([test_pos_data, test_neg_data])
test_data_label = _binary_labels(test_pos_data.shape[0], test_neg_data.shape[0])

# Keep only the non-zero entries of the dense adjacency matrix and store their
# coordinates as a 2 x num_edges int64 tensor (row indices on top, column
# indices below). From here on `adj` is this edge index, not the dense matrix.
sp_adj = sp.coo_matrix(adj)
indices = np.stack([sp_adj.row, sp_adj.col])
adj = torch.from_numpy(indices).long()

# Convert the remaining loaded inputs to tensors, rebinding the same names.
(
    interaction,
    rna_features,
    drug_features,
    inter_features_rna,
    inter_features_drug,
) = (
    torch.tensor(loaded)
    for loaded in (
        interaction,
        rna_features,
        drug_features,
        inter_features_rna,
        inter_features_drug,
    )
)

# Build the model. The number of global nodes appended to the adjacency matrix
# earlier in this cell is passed through n_global_node.
model = HGGN(
    r=hid_r,
    n_layers=n_layers,
    n_features=n_features,
    num_rna=rna_features.shape[0],
    num_dis=drug_features.shape[0],
    n_global_node=global_node_num,
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# BCELoss only accepts inputs in [0, 1], so the model output is consumed as a
# probability rather than a logit.
loss_function = torch.nn.BCELoss()

# The labels never change during training: convert them to NumPy once instead
# of twice per epoch.
train_labels_np = train_data_label.detach().cpu().numpy()

# Nothing inside the loop leaves training mode, so it is enough to set it once.
model.train()
for i in range(200):
    # train (full-batch: every epoch is one forward/backward pass on all pairs)
    model.zero_grad()
    x = model.projection_and_aggregation(rna_features, drug_features, inter_features_rna, inter_features_drug)
    train_output = model(x, adj, train_data)
    train_loss = loss_function(train_output, train_data_label)

    # Logging-only metrics: detach the predictions once and reuse the array
    # for both the ROC and the precision-recall computations.
    train_scores_np = train_output.detach().cpu().numpy()
    train_auc = metrics.roc_auc_score(train_labels_np, train_scores_np)
    precision, recall, _ = metrics.precision_recall_curve(train_labels_np, train_scores_np)
    train_aupr = metrics.auc(recall, precision)
    print(f'Epoch:{i + 1} Train - Loss: {train_loss.detach().cpu().numpy()}, - AUC: {train_auc} - AUPR: {train_aupr}')

    train_loss.backward()
    optimizer.step()
# Switch to evaluation mode before the model is handed to the explainer.
model.eval()

# The model is trained with torch.nn.BCELoss, which only accepts values in
# [0, 1], so its output is a probability. return_type must therefore be
# 'probs': with 'raw' the explainer would treat the output as a logit and
# optimise its masks against a logits-based loss, i.e. the wrong objective.
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='probs',
)

# 'phenomenon' explains the model output with respect to a caller-supplied
# target label. One scalar mask entry is learned per node and per edge
# ('object'), and the top-k threshold keeps only the 10 highest-scoring
# entries of each mask.
explainer = Explainer(
    model=model,
    explanation_type='phenomenon',
    algorithm=GNNExplainer(epochs=100),
    node_mask_type='object',
    edge_mask_type='object',
    model_config=model_config,
    threshold_config=ThresholdConfig(threshold_type='topk', value=10),
)

# Node feature matrix passed to the explainer as `x`: recomputed from the
# trained model, then detached from autograd and moved to the CPU.
explainer_x = model.projection_and_aggregation(
    rna_features, drug_features, inter_features_rna, inter_features_drug
).detach().cpu()

# The sample to explain is the first row of test_data, kept 2-D (one row).
# Positives are stacked first when test_data is built, so this is presumably a
# positive pair, matching target=1 — verify if the test split can be empty.
explained_pair = test_data[0, :].reshape(1, -1)

explanation = explainer(
    x=explainer_x,
    edge_index=adj,
    coo_data=explained_pair,
    target=torch.tensor(1),
)
print(test_data[0, :])
print(f'Generated model explanations in {explanation.available_explanations}')


Epoch:1 Train - Loss: 0.6914006471633911, - AUC: 0.8981957315609842 - AUPR: 0.8901212704164303
Epoch:2 Train - Loss: 0.6888253688812256, - AUC: 0.8724738921734251 - AUPR: 0.857275021890657
Epoch:3 Train - Loss: 0.6693514585494995, - AUC: 0.9211063074522916 - AUPR: 0.9210189301539611
Epoch:4 Train - Loss: 0.6271985769271851, - AUC: 0.9158867124649627 - AUPR: 0.9122464799441699
Epoch:5 Train - Loss: 0.5837936401367188, - AUC: 0.9131030335558621 - AUPR: 0.9117207717359808
Epoch:6 Train - Loss: 0.5589980483055115, - AUC: 0.9086047755225407 - AUPR: 0.9083291214824418
Epoch:7 Train - Loss: 0.532586395740509, - AUC: 0.9095643250898765 - AUPR: 0.9055980241259644
Epoch:8 Train - Loss: 0.49660295248031616, - AUC: 0.9103171872863426 - AUPR: 0.9062094546955803
Epoch:9 Train - Loss: 0.47924256324768066, - AUC: 0.9099737591763084 - AUPR: 0.9085686662678387
Epoch:10 Train - Loss: 0.45704787969589233, - AUC: 0.9100703052313381 - AUPR: 0.9048413491090098
Epoch:11 Train - Loss: 0.4171416163444519, - AUC