In [1]:
# Stretch the notebook's `.container` element to the full browser width.
# `display` and `HTML` are taken from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import networkx as nx
import numpy as np
import json 
import math
from collections import defaultdict as ddict
import time 
import random
from itertools import combinations, chain
from sklearn.metrics import roc_auc_score
import os
import plotly.graph_objects as go
from sklearn.metrics import auc

import torch
import torch.nn.functional as F
from torch import nn

import dgl
from dgl.dataloading import GraphDataLoader
from dgl.data import TUDataset

from sklearn import preprocessing

Using backend: pytorch


In [1]:
from run_pipeline import *

Using backend: pytorch


In [2]:
# Folder holding the dataset files. A plain string literal is used because
# there is nothing to interpolate (the original f-string had no placeholders).
dataset_folder = './data/BindingDB_Kd/'

# `load_data` is not defined in this notebook; presumably it is provided by the
# `from run_pipeline import *` cell above. It returns the training graphs and
# labels plus two parallel lists of test graphs and their labels.
graphs, labels, test_graphs1, test_graphs2, test_labels = load_data(dataset_folder)

# Preview the first training graph and the first test graph.
graphs[0], test_graphs1[0]

(Graph(num_nodes=31, num_edges=961,
       ndata_schemes={'node_attr': Scheme(shape=(9,), dtype=torch.float32)}
       edata_schemes={'edge_attr': Scheme(shape=(1,), dtype=torch.float32)}),
 Graph(num_nodes=34, num_edges=1156,
       ndata_schemes={'node_attr': Scheme(shape=(9,), dtype=torch.float32)}
       edata_schemes={'edge_attr': Scheme(shape=(1,), dtype=torch.float32)}))

In [56]:
# Load the training split (graphs + labels) from `dataset_folder`.
(graphs, labels) = get_train_graphs_and_labels(
    dataset_folder,
    datapart='train',
)

# Load the test split: two parallel graph lists plus their labels.
(test_graphs1, test_graphs2, test_labels) = get_test_graphs_and_labels(
    dataset_folder,
    datapart='test',
)

# Preview the first graph of the train list and of the first test list.
(graphs[0], test_graphs1[0])

(Graph(num_nodes=31, num_edges=961,
       ndata_schemes={'node_attr': Scheme(shape=(9,), dtype=torch.float32)}
       edata_schemes={'edge_attr': Scheme(shape=(1,), dtype=torch.float32)}),
 Graph(num_nodes=34, num_edges=1156,
       ndata_schemes={'node_attr': Scheme(shape=(9,), dtype=torch.float32)}
       edata_schemes={'edge_attr': Scheme(shape=(1,), dtype=torch.float32)}))

In [39]:
# Manual loading of a BindingDB split: read graphs and labels from disk,
# re-index the training labels, and min-max scale node features using
# statistics fitted on the training graphs only.
data_suffix = 'Ki'
folder = f'BindingDB_{data_suffix}'


def _read_int_labels(path):
    """Return the integer labels stored one per line in the file at `path`.

    Whitespace-only lines (for example a trailing blank line) are skipped.
    Lines yielded by file iteration keep their newline, so a bare truthiness
    check never filters them and `int()` would raise on a blank line.
    """
    with open(path) as label_file:
        return [int(line) for line in label_file if line.strip()]


def _scale_node_attr(graph_list, scaler):
    """Overwrite each graph's 'node_attr' with its scaled float32 version.

    Features are cast to float before scaling and the result is stored as
    float32, so train and test graphs end up with the same feature dtype
    regardless of the dtype they were saved with.
    """
    for g in graph_list:
        feats = g.ndata['node_attr'].float()
        scaled = scaler.transform(feats.numpy())
        g.ndata['node_attr'] = torch.tensor(scaled, dtype=torch.float32)


graphs = dgl.load_graphs(f'{folder}/train.bin')[0]
labels = _read_int_labels(f'{folder}/train.labels')

# Re-index raw training labels to consecutive ids in order of first appearance.
label2ix = {}
for raw_label in labels:
    label2ix.setdefault(raw_label, len(label2ix))
labels = [label2ix[raw_label] for raw_label in labels]

# Stack the node features of every training graph (graph order preserved)
# and fit the scaler on them.
X = torch.cat([g.ndata['node_attr'].float() for g in graphs])
min_max_scaler = preprocessing.MinMaxScaler()
min_max_scaler.fit(X.numpy())
_scale_node_attr(graphs, min_max_scaler)

test_graphs1 = dgl.load_graphs(f'{folder}/test_graph1.bin')[0]
test_graphs2 = dgl.load_graphs(f'{folder}/test_graph2.bin')[0]
# Test labels are kept as read from the file; they are not passed through
# `label2ix`.
test_labels = _read_int_labels(f'{folder}/test.labels')

# Apply the train-fitted scaler to every graph in both test lists.
_scale_node_attr(test_graphs1, min_max_scaler)
_scale_node_attr(test_graphs2, min_max_scaler)

In [40]:
# Reload `models.models` so that edits to the module are picked up without
# restarting the kernel, then re-import the `GNN` class from the fresh module.
import importlib
import sys

import models.models

importlib.reload(sys.modules['models.models'])
from models.models import GNN

# Run configuration.
with_arcface = False
s, m = 4, 0.5
batch_size = None
num_epochs = 100

# Build the model; `s` and `m` are forwarded unchanged to the constructor.
gnn_model = GNN(
    with_arcface,
    lr=0.01,
    hidden_dim=128,
    dropout=0.0,
    name='gcn',
    residual=True,
    s=s,
    m=m,
)

# Train on the training graphs, passing the paired test graphs and labels too;
# the returned object is what the plotting cells consume.
metrics = gnn_model.fit(
    graphs,
    labels,
    test_graphs1,
    test_graphs2,
    test_labels,
    num_epochs,
    batch_size=batch_size,
)

(0.8266128412149174,) (0.10430219146482123,) (0.0037447458922430263, 0.000687757909215956) (6.114455223083496, 6.112326622009277)
(0.8263514033064205,) (0.10418300653594773,) (0.0037447458922430263, 0.000687757909215956) (6.101343154907227, 6.099581241607666)
(0.8258323721645521,) (0.10398308342945022,) (0.0037447458922430263, 0.000687757909215956) (6.087991237640381, 6.086574077606201)
(0.8253018069973088,) (0.10371972318339102,) (0.0036683225066870464, 0.000687757909215956) (6.074042320251465, 6.072942733764648)
(0.8247327950788157,) (0.10347943098808152,) (0.004661826518914788, 0.002751031636863824) (6.059293746948242, 6.0584797859191895)
(0.824046520569012,) (0.10316224529027299,) (0.00550248376003057, 0.0020632737276478678) (6.043670177459717, 6.043121337890625)
(0.8232987312572086,) (0.10283352556708958,) (0.005273213603362629, 0.002751031636863824) (6.027217864990234, 6.026915073394775)
(0.8225547866205306,) (0.10260476739715496,) (0.010699273977837218, 0.0068775790921595595) (6

(0.7969396386005383,) (0.09562860438292964,) (0.027665265571264808, 0.023383768913342505) (5.354515075683594, 5.422595500946045)
(0.7974509803921569,) (0.09596116878123799,) (0.02797095911348873, 0.024759284731774415) (5.344349384307861, 5.413329124450684)
(0.7980353710111496,) (0.09629373317954634,) (0.028964463125716468, 0.027510316368638238) (5.3342413902282715, 5.403995513916016)
(0.7985332564398308,) (0.0965128796616686,) (0.03049293083683607, 0.028198074277854195) (5.324174880981445, 5.394580364227295)
(0.7989234909650136,) (0.09674932718185315,) (0.031792128391287734, 0.028198074277854195) (5.314135551452637, 5.385080337524414)
(0.7994156093810072,) (0.09697424067666284,) (0.03293847917462744, 0.027510316368638238) (5.304106712341309, 5.375489234924316)
(0.7999077277970013,) (0.09717031910803539,) (0.0338555598012992, 0.030949105914718018) (5.294073104858398, 5.365809440612793)
(0.8004498269896193,) (0.09733179546328337,) (0.0338555598012992, 0.030949105914718018) (5.28402519226

In [43]:
# Inspect the edge feature dictionary stored on the first training graph.
graphs[0].edata

{'edge_attr': tensor([[ 0.],
        [ 3.],
        [15.],
        ...,
        [ 6.],
        [ 6.],
        [ 0.]])}

In [41]:
def _plot_metric(title, metric_name):
    """Plot one tracked metric of `metrics` with the shared Train/Test settings."""
    return gnn_model.plot_interactive(
        [metrics],
        legend=['Train', 'Test'],
        title=title,
        metric_name=metric_name,
        start_from=0,
        output_fn=None,
        to_show=True,
    )


# One plot per tracked metric; only title and metric key differ between calls.
_plot_metric('TAR_FAR AUC', 'tar_auc')
_plot_metric('ROC AUC', 'auc')
_plot_metric('Train acc', 'acc')
_plot_metric('Train loss', 'loss')

In [36]:
# (title, metric key) for each plot, in display order.
plot_specs = [
    ('TAR_FAR AUC', 'tar_auc'),
    ('ROC AUC', 'auc'),
    ('Train acc', 'acc'),
    ('Train loss', 'loss'),
]

last_plot = None
for plot_title, plot_metric in plot_specs:
    last_plot = gnn_model.plot_interactive(
        [metrics],
        legend=['Train', 'Test'],
        title=plot_title,
        metric_name=plot_metric,
        start_from=0,
        output_fn=None,
        to_show=True,
    )

# Keep the final call's return value as the cell's last expression, as in the
# unrolled version of this cell.
last_plot