# Construct synthetic graph

In [1]:
import numpy as np
import networkx as nx
import scipy.sparse as sp
from itertools import combinations
import dgl
import torch
from src.model import Model

Using backend: pytorch


In [2]:
n_genes = 500 # дб кратно 10
n_drugs = 400
n_drugdrug_rel_types = 3
gene_net = nx.planted_partition_graph(n_genes // 10, 10, 0.2, 0.05, seed=42, directed=False)

In [3]:
gene_drug_adj = sp.csr_matrix(
            (10 * np.random.randn(n_genes, n_drugs) > 15).astype(int))
drug_gene_adj = gene_drug_adj.transpose(copy=True)
gene_adj = nx.adjacency_matrix(gene_net)

In [4]:
drug_drug_adj_list = []
tmp = np.dot(drug_gene_adj, gene_drug_adj)

for i in range(n_drugdrug_rel_types):
    mat = np.zeros((n_drugs, n_drugs))
    for d1, d2 in combinations(list(range(n_drugs)), 2):
        if tmp[d1, d2] == i + 4:
            mat[d1, d2] = mat[d2, d1] = 1.
    drug_drug_adj_list.append(sp.csr_matrix(mat))

In [5]:
graph = dgl.heterograph(
    {('protein', 'association', 'protein') : gene_adj.nonzero(),
     ('drug', 'interaction', 'protein') : drug_gene_adj.nonzero(), 
     ('protein', 'interaction_by', 'drug') : gene_drug_adj.nonzero(), 
     ('drug', 'side_effect_0', 'drug'): drug_drug_adj_list[0].nonzero(), 
     ('drug', 'side_effect_1', 'drug'): drug_drug_adj_list[1].nonzero(), 
     ('drug', 'side_effect_2', 'drug'): drug_drug_adj_list[2].nonzero()})

Добавим петли в граф, чтобы при пересчете эмбеддингов учесть эмбеддинги с предудыщего слоя.

In [6]:
for edge_type in ['association', 'side_effect_0', 'side_effect_1', 'side_effect_2']:
    graph = dgl.add_self_loop(graph, edge_type)

In [7]:
graph

Graph(num_nodes={'drug': 400, 'protein': 500},
      num_edges={('drug', 'interaction', 'protein'): 13329, ('drug', 'side_effect_0', 'drug'): 18014, ('drug', 'side_effect_1', 'drug'): 8210, ('drug', 'side_effect_2', 'drug'): 3240, ('protein', 'association', 'protein'): 13778, ('protein', 'interaction_by', 'drug'): 13329},
      metagraph=[('drug', 'protein', 'interaction'), ('drug', 'drug', 'side_effect_0'), ('drug', 'drug', 'side_effect_1'), ('drug', 'drug', 'side_effect_2'), ('protein', 'protein', 'association'), ('protein', 'drug', 'interaction_by')])

Добавим фичи

In [8]:
graph.nodes['protein'].data['feature'] = torch.eye(n_genes)
graph.nodes['drug'].data['feature'] = torch.eye(n_drugs)

# Toy train

In [13]:
def construct_negative_graph(graph, k, etype):
    'etype - Tuple[str, str, str]'
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.number_of_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.number_of_nodes(ntype) for ntype in graph.ntypes})

In [15]:
def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - neg_score.view(n_edges, -1) + pos_score.unsqueeze(1)).clamp(min=0).mean()

node2features = {'drug': graph.nodes['drug'].data['feature'], 
                 'protein': graph.nodes['protein'].data['feature']}
node2in_feat_dim = {'drug': n_drugs, 'protein': n_genes}
rel2nodes = {'interaction': ['drug', 'protein'], 
             'interaction_by': ['protein', 'drug'],
             'association': ['protein', 'protein'],
             'side_effect_0': ['drug', 'drug'], 
             'side_effect_1': ['drug', 'drug'], 
             'side_effect_2': ['drug', 'drug']}
model = Model(node2in_feat_dim=node2in_feat_dim, hidden_dim=64, embed_dim=8, rel2nodes=rel2nodes, bias=0.0, dropout=0.0)

opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    negative_graph = construct_negative_graph(graph, 1, ('drug', 'side_effect_0', 'drug'))
    pos_score, neg_score = model(graph, negative_graph, node2features, 'side_effect_0')
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())


1.000003457069397
0.9999546408653259
0.9997059106826782
0.9995090961456299
0.999215841293335
0.9986886382102966
0.9979740977287292
0.9972043633460999
0.9962499141693115
0.9954308867454529
