# Construct synthetic graph

In [1]:
import numpy as np
import networkx as nx
import scipy.sparse as sp
from itertools import combinations
import dgl
import torch
from src.model import Model

Using backend: pytorch


In [2]:
n_genes = 500 # дб кратно 10
n_drugs = 400
n_drugdrug_rel_types = 3
gene_net = nx.planted_partition_graph(n_genes // 10, 10, 0.2, 0.05, seed=42, directed=False)

In [3]:
gene_drug_adj = sp.csr_matrix(
            (10 * np.random.randn(n_genes, n_drugs) > 15).astype(int))
drug_gene_adj = gene_drug_adj.transpose(copy=True)
gene_adj = nx.adjacency_matrix(gene_net)

In [4]:
drug_drug_adj_list = []
tmp = np.dot(drug_gene_adj, gene_drug_adj)

for i in range(n_drugdrug_rel_types):
    mat = np.zeros((n_drugs, n_drugs))
    for d1, d2 in combinations(list(range(n_drugs)), 2):
        if tmp[d1, d2] == i + 4:
            mat[d1, d2] = mat[d2, d1] = 1.
    drug_drug_adj_list.append(sp.csr_matrix(mat))

In [5]:
graph = dgl.heterograph(
    {('protein', 'association', 'protein') : gene_adj.nonzero(),
     ('drug', 'interaction', 'protein') : drug_gene_adj.nonzero(), 
     ('protein', 'interaction_by', 'drug') : gene_drug_adj.nonzero(), 
     ('drug', 'side_effect_0', 'drug'): drug_drug_adj_list[0].nonzero(), 
     ('drug', 'side_effect_1', 'drug'): drug_drug_adj_list[1].nonzero(), 
     ('drug', 'side_effect_2', 'drug'): drug_drug_adj_list[2].nonzero()})

Добавим петли в граф, чтобы при пересчете эмбеддингов учесть эмбеддинги с предудыщего слоя.

In [6]:
for edge_type in ['association', 'side_effect_0', 'side_effect_1', 'side_effect_2']:
    graph = dgl.add_self_loop(graph, edge_type)

In [7]:
graph

Graph(num_nodes={'drug': 400, 'protein': 500},
      num_edges={('drug', 'interaction', 'protein'): 13274, ('drug', 'side_effect_0', 'drug'): 17854, ('drug', 'side_effect_1', 'drug'): 8086, ('drug', 'side_effect_2', 'drug'): 3108, ('protein', 'association', 'protein'): 13778, ('protein', 'interaction_by', 'drug'): 13274},
      metagraph=[('drug', 'protein', 'interaction'), ('drug', 'drug', 'side_effect_0'), ('drug', 'drug', 'side_effect_1'), ('drug', 'drug', 'side_effect_2'), ('protein', 'protein', 'association'), ('protein', 'drug', 'interaction_by')])

Добавим фичи

In [8]:
graph.nodes['protein'].data['feature'] = torch.eye(n_genes)
graph.nodes['drug'].data['feature'] = torch.eye(n_drugs)

# EdgesLoader

In [9]:
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

In [10]:
train_eid_dict = {'interaction': torch.arange(graph.num_edges('interaction') - 500), 
                  'interaction_by': torch.arange(graph.num_edges('interaction_by') - 500), 
                  'association': torch.arange(graph.num_edges('association') - 500), 
                  'side_effect_0': torch.arange(graph.num_edges('side_effect_0') - 500), 
                  'side_effect_1': torch.arange(graph.num_edges('side_effect_1') - 500), 
                  'side_effect_2': torch.arange(graph.num_edges('side_effect_2') - 500)}

In [11]:
dataloader = dgl.dataloading.EdgeDataLoader(
    graph, train_eid_dict, sampler,
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=4)

# Toy train

In [12]:
def compute_loss(pos_score, neg_score):
    # Margin loss
    score = 0.
    for etype in pos_score.keys():
        if not len(pos_score[etype]):
            continue
        pos, neg = None, None

        if pos_score[etype].shape == torch.Size([1]):
            pos = pos_score[etype].unsqueeze(0)
            neg = neg_score[etype].unsqueeze(0)
        else:
            pos = pos_score[etype].squeeze()
            neg = neg_score[etype].squeeze()
        n_edges = pos_score[etype].shape[0]
        score += (1 - neg.view(n_edges, -1) + pos.unsqueeze(1)).clamp(min=0).mean()
    return score / len(pos_score)

In [13]:
node2features = {'drug': graph.nodes['drug'].data['feature'], 
                 'protein': graph.nodes['protein'].data['feature']}
node2in_feat_dim = {'drug': n_drugs, 'protein': n_genes}
rel2nodes = {'interaction': ['drug', 'protein'], 
             'interaction_by': ['protein', 'drug'],
             'association': ['protein', 'protein'],
             'side_effect_0': ['drug', 'drug'], 
             'side_effect_1': ['drug', 'drug'], 
             'side_effect_2': ['drug', 'drug']}
model = Model(node2in_feat_dim=node2in_feat_dim, hidden_dim=64, embed_dim=8, rel2nodes=rel2nodes, bias=False, dropout=0.0)
opt = torch.optim.Adam(model.parameters())

In [14]:
for input_nodes, positive_graph, negative_graph, blocks in dataloader:
    input_features = blocks[0].srcdata['feature']
    pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

1.0000889301300049
1.0002793073654175
0.9996232986450195
0.9994733929634094
1.0001286268234253
0.998995304107666
0.9982969164848328
0.998706042766571
1.0000535249710083
0.9988147616386414
0.9975137710571289
0.9984338283538818
0.9979334473609924
0.9984638094902039
0.9977039694786072
0.9973219037055969
0.9963573813438416
0.996618926525116
0.9968838691711426
0.995826780796051
0.9964146614074707
0.9968047738075256
0.995299756526947
0.9951732158660889
0.9977521896362305
0.9947855472564697
0.9966996312141418
0.9956080913543701
0.9952062964439392
0.9957213401794434
0.9926124215126038
0.9946722984313965
0.9928417801856995
0.9936997890472412
0.9921159148216248
0.9917944073677063
0.9909383654594421
0.993380069732666
0.9903073906898499
0.9883115291595459
0.9937024116516113
0.9855095744132996
0.9908328056335449
0.988882839679718
0.9896569848060608
0.9883001446723938
0.9841164946556091
0.9855642318725586
0.9838072657585144
0.978958785533905
0.9803254008293152
0.9768226742744446
0.9817459583282471
0