In [7]:
# imports
import csv
import os
import pathlib
from collections import namedtuple
from time import time
from typing import List, Tuple, Optional
from random import randint, seed

import networkx as nx
import numpy as np
import torch
import torch_geometric.utils as tutils
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
from tqdm import tqdm_notebook as tqdm

from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
from grakel.kernels import WeisfeilerLehman, VertexHistogram, ShortestPath
from grakel import GraphKernel
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import StratifiedShuffleSplit

from utils import load_singleton_graphs_from_TUDataset

import warnings
warnings.filterwarnings('ignore')

In [23]:
# Parameters
datasets = ['AIDS', 'BZR', 'BZR_MD', 'COX2', 'COX2_MD', 'DHFR', 'DHFR_MD', 'ER_MD', 'MUTAG', 'Mutagenicity', 'NCI1', 'NCI109', 'PTC_FM', 'PTC_FR', 'PTC_MM', 'PTC_MR',
            'DD', 'ENZYMES', 'KKI', 'OHSU', 'Peking_1', 'PROTEINS_full',
            'MSRC_9', 'MSRC_21', 
            ]
            #'COLLAB', 'REDDIT-BINARY', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-12K']
# 'NCI-H23H''AIDS', 
seed(42)
seeds = [randint(0, 10000) for _ in range(10)]

In [24]:
# CONSTANTS
NODE_ATTRIBUTE = 'x'

CLF_METHODS = {
    'knn': (KNeighborsClassifier, {'kneighborsclassifier__n_neighbors': [3, 5, 7, 9, 11]}),
    'rbf': (SVC, {#'svc__gamma': np.logspace(-6, 2, 9),
                  'svc__C': np.logspace(-2, 2, 5)})
}


In [25]:
from torch_geometric.utils import degree

def load_special_singleton(root: str,
                                         dataset: str,
                                         node_attr: str = 'x') -> Tuple[List[np.ndarray], np.ndarray]:
    """
    Use the Pytorch Geometric (PyG) loader to download the graph dataset from the TUDataset repo.
    The raw graphs from PyG are saved in `root`.
    
    A singleton graph is created by summing all the node attributes `x` of the original graph.
    The corresponding class of each graph is also retrieved from TUDataset dataset.
    
    Args:
        root: Path where to save the raw graph dataset
        name_dataset: Name of the graph dataset to load
    Returns:
        List of the loaded `np.ndarray` singleton graphs and `np.ndarray` of the corresponding class of each graph
    """
    dataset = TUDataset(root=root, name=dataset)

    tmp_graph = dataset[0]
    is_graph_labelled = node_attr in tmp_graph.keys

    # Convert the PyG graphs into singleton graphs
    graphs = []
    graph_labels = []
    for graph in tqdm(dataset, desc='Convert graph to singleton'):
        
        if not is_graph_labelled:
            # Create graph with dummy node vector
            graph = Data(x=torch.tensor(np.ones((graph.num_nodes, 1))),
                         y=graph.y,
                         edge_index=graph.edge_index)
        degrees = degree(graph.edge_index[0], graph.num_nodes)
        node_feature = torch.mul(graph.x, degrees.view(-1, 1))
        
        graphs.append(np.array(node_feature.sum(axis=0)))
        graph_labels.append(int(graph.y))

    graph_cls = np.array(graph_labels)

    return graphs, graph_cls

In [26]:
def distinct_graphs(graphs):
    unique_graphs = []
    for graph in graphs:
        
        if not any(np.array_equal(graph, g_) for g_ in unique_graphs):
            unique_graphs.append(graph)
            # unique_graphs.append(graph)
        
            
    return unique_graphs

In [30]:
from sklearn.metrics import accuracy_score, f1_score, balanced_accuracy_score, precision_score, recall_score, roc_auc_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from collections import Counter
import matplotlib.pyplot as plt


n_jobs = 10
cv = 5
verbose = False

results = {}
folder = './predictions'

for dataset in datasets:
    print(dataset)

    graphs, labels = load_singleton_graphs_from_TUDataset(root=os.path.join('root', dataset),
                                                          dataset=dataset,
                                                          node_attr=NODE_ATTRIBUTE)
#     graphs, labels = load_special_singleton(root=os.path.join('root', dataset),
#                                                           dataset=dataset,
#                                                           node_attr=NODE_ATTRIBUTE)
    # print(f'Distinct graphs {len(distinct_graphs(graphs))}/{len(graphs)}')
    
    print("class proportions", np.unique(labels, return_counts=True)[1] / len(labels))
    # plt.hist(labels)
    
    for clf_name, clf in CLF_METHODS.items():
        clf_method, param_grid = clf

        f1_scores = []
        accuracies = []
        precisions, recalls = [], []
        balanced_accuracies = []
        aucrocs = []
        

        for c_seed in seeds:
            
            G_train, G_test, y_train, y_test = train_test_split(graphs,
                                                                labels,
                                                                test_size=0.2,
                                                                random_state=c_seed,
                                                                stratify=labels)
            pipe_clf = make_pipeline(StandardScaler(),
                                     clf_method())
            clf = GridSearchCV(estimator=pipe_clf,
                                param_grid=param_grid,
                                n_jobs=n_jobs,
                                cv=cv,
                               scoring='f1_micro',
                                verbose=int(verbose)*3)
            
            
            clf.fit(G_train, y_train)
            y_predictions = clf.predict(G_test)
            
            
            
            accuracies.append(accuracy_score(y_test, y_predictions))
            balanced_accuracies.append(balanced_accuracy_score(y_test, y_predictions))
            average = 'binary' if len(set(labels)) <= 2 else 'macro'
            f1_scores.append(f1_score(y_test, y_predictions, average=average))
            precisions.append(precision_score(y_test, y_predictions))
            recalls.append(recall_score(y_test, y_predictions))
            aucrocs.append(roc_auc_score(y_test, y_predictions, multi_class='ovr'))
        
        print(f'{dataset}, {clf_name} (acc): {np.mean(accuracies)*100:.2f} +- {np.std(accuracies)*100:.2f}')
        print(f'{dataset}, {clf_name} (balanced acc): {np.mean(balanced_accuracies)*100:.2f} +- {np.std(balanced_accuracies)*100:.2f}')
        print(f'{dataset}, {clf_name} (f1): {np.mean(f1_scores)*100:.2f} +- {np.std(f1_scores)*100:.2f}')
        print(f'{dataset}, {clf_name} (precision): {np.mean(precisions)*100:.2f} +- {np.std(precisions)*100:.2f}')
        print(f'{dataset}, {clf_name} (recall): {np.mean(recalls)*100:.2f} +- {np.std(recalls)*100:.2f}')
        print(f'{dataset}, {clf_name} (AUCROC): {np.mean(aucrocs)*100:.2f} +- {np.std(aucrocs)*100:.2f}')
        print('*' * 50)
            
    print('='*60)
    
  

AIDS


Convert graph to singleton: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 6833.23it/s]


class proportions [0.2 0.8]
AIDS, knn (acc): 98.15 +- 0.65
AIDS, knn (balanced acc): 96.03 +- 1.33
AIDS, knn (f1): 98.85 +- 0.40
AIDS, knn (precision): 98.15 +- 0.61
AIDS, knn (recall): 99.56 +- 0.32
AIDS, knn (AUCROC): 96.03 +- 1.33
**************************************************
AIDS, rbf (acc): 99.05 +- 0.40
AIDS, rbf (balanced acc): 97.91 +- 0.89
AIDS, rbf (f1): 99.41 +- 0.25
AIDS, rbf (precision): 99.01 +- 0.43
AIDS, rbf (recall): 99.81 +- 0.21
AIDS, rbf (AUCROC): 97.91 +- 0.89
**************************************************
BZR


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 405/405 [00:00<00:00, 17653.71it/s]

class proportions [0.78765432 0.21234568]





BZR, knn (acc): 83.33 +- 2.16
BZR, knn (balanced acc): 64.40 +- 4.98
BZR, knn (f1): 43.54 +- 11.38
BZR, knn (precision): 75.66 +- 11.92
BZR, knn (recall): 31.76 +- 10.26
BZR, knn (AUCROC): 64.40 +- 4.98
**************************************************
BZR, rbf (acc): 84.44 +- 2.77
BZR, rbf (balanced acc): 67.26 +- 5.94
BZR, rbf (f1): 49.28 +- 12.96
BZR, rbf (precision): 78.62 +- 16.11
BZR, rbf (recall): 37.65 +- 12.67
BZR, rbf (AUCROC): 67.26 +- 5.94
**************************************************
BZR_MD


Convert graph to singleton: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 306/306 [00:00<00:00, 8482.47it/s]

class proportions [0.4869281 0.5130719]





BZR_MD, knn (acc): 66.61 +- 6.04
BZR_MD, knn (balanced acc): 66.50 +- 6.22
BZR_MD, knn (f1): 68.46 +- 4.43
BZR_MD, knn (precision): 67.83 +- 7.29
BZR_MD, knn (recall): 70.00 +- 6.73
BZR_MD, knn (AUCROC): 66.50 +- 6.22
**************************************************
BZR_MD, rbf (acc): 68.71 +- 3.83
BZR_MD, rbf (balanced acc): 68.48 +- 3.97
BZR_MD, rbf (f1): 71.35 +- 3.11
BZR_MD, rbf (precision): 68.18 +- 5.05
BZR_MD, rbf (recall): 75.62 +- 6.82
BZR_MD, rbf (AUCROC): 68.48 +- 3.97
**************************************************
COX2


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 467/467 [00:00<00:00, 14914.19it/s]

class proportions [0.78158458 0.21841542]





COX2, knn (acc): 78.83 +- 2.25
COX2, knn (balanced acc): 58.89 +- 3.03
COX2, knn (f1): 32.15 +- 7.53
COX2, knn (precision): 58.91 +- 13.35
COX2, knn (recall): 22.86 +- 6.67
COX2, knn (AUCROC): 58.89 +- 3.03
**************************************************
COX2, rbf (acc): 77.77 +- 2.45
COX2, rbf (balanced acc): 55.67 +- 3.58
COX2, rbf (f1): 23.19 +- 10.39
COX2, rbf (precision): 52.56 +- 25.79
COX2, rbf (recall): 15.71 +- 7.98
COX2, rbf (AUCROC): 55.67 +- 3.58
**************************************************
COX2_MD


Convert graph to singleton: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 303/303 [00:00<00:00, 7745.03it/s]

class proportions [0.51155116 0.48844884]





COX2_MD, knn (acc): 65.90 +- 5.86
COX2_MD, knn (balanced acc): 65.81 +- 5.88
COX2_MD, knn (f1): 63.11 +- 7.43
COX2_MD, knn (precision): 67.31 +- 6.94
COX2_MD, knn (recall): 60.00 +- 9.78
COX2_MD, knn (AUCROC): 65.81 +- 5.88
**************************************************
COX2_MD, rbf (acc): 66.72 +- 5.78
COX2_MD, rbf (balanced acc): 66.62 +- 5.81
COX2_MD, rbf (f1): 63.87 +- 7.75
COX2_MD, rbf (precision): 68.26 +- 6.88
COX2_MD, rbf (recall): 60.67 +- 10.31
COX2_MD, rbf (AUCROC): 66.62 +- 5.81
**************************************************
DHFR


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 756/756 [00:00<00:00, 16139.99it/s]

class proportions [0.39021164 0.60978836]





DHFR, knn (acc): 73.03 +- 2.73
DHFR, knn (balanced acc): 71.85 +- 3.03
DHFR, knn (f1): 77.74 +- 2.43
DHFR, knn (precision): 78.58 +- 2.99
DHFR, knn (recall): 77.10 +- 4.14
DHFR, knn (AUCROC): 71.85 +- 3.03
**************************************************
DHFR, rbf (acc): 72.17 +- 3.48
DHFR, rbf (balanced acc): 70.91 +- 3.95
DHFR, rbf (f1): 77.06 +- 3.01
DHFR, rbf (precision): 77.86 +- 3.95
DHFR, rbf (recall): 76.56 +- 5.04
DHFR, rbf (AUCROC): 70.91 +- 3.95
**************************************************
DHFR_MD


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 393/393 [00:00<00:00, 14762.20it/s]

class proportions [0.67938931 0.32061069]





DHFR_MD, knn (acc): 67.22 +- 3.64
DHFR_MD, knn (balanced acc): 56.69 +- 5.02
DHFR_MD, knn (f1): 34.66 +- 9.62
DHFR_MD, knn (precision): 46.07 +- 9.57
DHFR_MD, knn (recall): 28.00 +- 9.12
DHFR_MD, knn (AUCROC): 56.69 +- 5.02
**************************************************
DHFR_MD, rbf (acc): 67.34 +- 2.64
DHFR_MD, rbf (balanced acc): 49.47 +- 1.30
DHFR_MD, rbf (f1): 1.11 +- 3.33
DHFR_MD, rbf (precision): 1.82 +- 5.45
DHFR_MD, rbf (recall): 0.80 +- 2.40
DHFR_MD, rbf (AUCROC): 49.47 +- 1.30
**************************************************
ER_MD


Convert graph to singleton: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 446/446 [00:00<00:00, 8581.95it/s]

class proportions [0.5941704 0.4058296]





ER_MD, knn (acc): 71.00 +- 3.80
ER_MD, knn (balanced acc): 70.48 +- 4.14
ER_MD, knn (f1): 65.56 +- 5.25
ER_MD, knn (precision): 63.99 +- 4.27
ER_MD, knn (recall): 67.57 +- 7.74
ER_MD, knn (AUCROC): 70.48 +- 4.14
**************************************************
ER_MD, rbf (acc): 72.44 +- 3.17
ER_MD, rbf (balanced acc): 73.42 +- 3.16
ER_MD, rbf (f1): 70.18 +- 3.41
ER_MD, rbf (precision): 63.35 +- 3.58
ER_MD, rbf (recall): 78.92 +- 5.38
ER_MD, rbf (AUCROC): 73.42 +- 3.16
**************************************************
MUTAG


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 188/188 [00:00<00:00, 10815.25it/s]

class proportions [0.33510638 0.66489362]





MUTAG, knn (acc): 79.74 +- 6.23
MUTAG, knn (balanced acc): 77.77 +- 8.32
MUTAG, knn (f1): 84.58 +- 4.51
MUTAG, knn (precision): 85.69 +- 7.58
MUTAG, knn (recall): 84.00 +- 5.37
MUTAG, knn (AUCROC): 77.77 +- 8.32
**************************************************
MUTAG, rbf (acc): 79.21 +- 4.63
MUTAG, rbf (balanced acc): 75.15 +- 4.44
MUTAG, rbf (f1): 84.68 +- 3.87
MUTAG, rbf (precision): 81.84 +- 2.82
MUTAG, rbf (recall): 88.00 +- 6.69
MUTAG, rbf (AUCROC): 75.15 +- 4.44
**************************************************
Mutagenicity


Convert graph to singleton: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 4337/4337 [00:00<00:00, 15885.92it/s]


class proportions [0.55360849 0.44639151]
Mutagenicity, knn (acc): 71.07 +- 1.12
Mutagenicity, knn (balanced acc): 70.46 +- 1.14
Mutagenicity, knn (f1): 66.62 +- 1.40
Mutagenicity, knn (precision): 68.62 +- 1.53
Mutagenicity, knn (recall): 64.78 +- 2.15
Mutagenicity, knn (AUCROC): 70.46 +- 1.14
**************************************************
Mutagenicity, rbf (acc): 72.19 +- 1.51
Mutagenicity, rbf (balanced acc): 71.35 +- 1.48
Mutagenicity, rbf (f1): 67.10 +- 1.68
Mutagenicity, rbf (precision): 71.05 +- 2.27
Mutagenicity, rbf (recall): 63.62 +- 1.97
Mutagenicity, rbf (AUCROC): 71.35 +- 1.48
**************************************************
NCI1


Convert graph to singleton: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 4110/4110 [00:00<00:00, 10372.73it/s]


class proportions [0.49951338 0.50048662]
NCI1, knn (acc): 68.44 +- 1.07
NCI1, knn (balanced acc): 68.44 +- 1.07
NCI1, knn (f1): 68.13 +- 1.09
NCI1, knn (precision): 68.82 +- 1.26
NCI1, knn (recall): 67.47 +- 1.46
NCI1, knn (AUCROC): 68.44 +- 1.07
**************************************************
NCI1, rbf (acc): 68.45 +- 1.44
NCI1, rbf (balanced acc): 68.45 +- 1.44
NCI1, rbf (f1): 64.51 +- 1.87
NCI1, rbf (precision): 73.71 +- 1.64
NCI1, rbf (recall): 57.37 +- 2.20
NCI1, rbf (AUCROC): 68.45 +- 1.44
**************************************************
NCI109


Convert graph to singleton: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 4127/4127 [00:00<00:00, 18482.63it/s]


class proportions [0.49624425 0.50375575]
NCI109, knn (acc): 69.36 +- 1.17
NCI109, knn (balanced acc): 69.36 +- 1.16
NCI109, knn (f1): 69.35 +- 1.31
NCI109, knn (precision): 69.88 +- 1.26
NCI109, knn (recall): 68.87 +- 2.05
NCI109, knn (AUCROC): 69.36 +- 1.16
**************************************************
NCI109, rbf (acc): 68.37 +- 1.04
NCI109, rbf (balanced acc): 68.40 +- 1.04
NCI109, rbf (f1): 67.11 +- 1.64
NCI109, rbf (precision): 70.45 +- 1.54
NCI109, rbf (recall): 64.21 +- 3.36
NCI109, rbf (AUCROC): 68.40 +- 1.04
**************************************************
PTC_FM


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 349/349 [00:00<00:00, 16027.55it/s]

class proportions [0.59025788 0.40974212]





PTC_FM, knn (acc): 56.86 +- 3.77
PTC_FM, knn (balanced acc): 54.14 +- 3.71
PTC_FM, knn (f1): 42.12 +- 5.62
PTC_FM, knn (precision): 47.63 +- 5.32
PTC_FM, knn (recall): 38.28 +- 7.31
PTC_FM, knn (AUCROC): 54.14 +- 3.71
**************************************************
PTC_FM, rbf (acc): 62.57 +- 3.71
PTC_FM, rbf (balanced acc): 56.19 +- 3.86
PTC_FM, rbf (f1): 28.58 +- 11.60
PTC_FM, rbf (precision): 65.43 +- 29.53
PTC_FM, rbf (recall): 18.97 +- 8.34
PTC_FM, rbf (AUCROC): 56.19 +- 3.86
**************************************************
PTC_FR


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 351/351 [00:00<00:00, 14238.74it/s]

class proportions [0.65527066 0.34472934]





PTC_FR, knn (acc): 65.63 +- 5.16
PTC_FR, knn (balanced acc): 56.10 +- 4.89
PTC_FR, knn (f1): 34.27 +- 8.41
PTC_FR, knn (precision): 50.10 +- 13.54
PTC_FR, knn (recall): 26.67 +- 7.26
PTC_FR, knn (AUCROC): 56.10 +- 4.89
**************************************************
PTC_FR, rbf (acc): 68.59 +- 2.19
PTC_FR, rbf (balanced acc): 54.15 +- 3.19
PTC_FR, rbf (f1): 16.37 +- 10.83
PTC_FR, rbf (precision): 70.14 +- 38.47
PTC_FR, rbf (recall): 9.58 +- 6.73
PTC_FR, rbf (AUCROC): 54.15 +- 3.19
**************************************************
PTC_MM


Convert graph to singleton: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 336/336 [00:00<00:00, 7306.16it/s]

class proportions [0.61607143 0.38392857]





PTC_MM, knn (acc): 61.47 +- 5.42
PTC_MM, knn (balanced acc): 57.01 +- 5.72
PTC_MM, knn (f1): 42.66 +- 8.90
PTC_MM, knn (precision): 49.87 +- 9.51
PTC_MM, knn (recall): 38.08 +- 9.95
PTC_MM, knn (AUCROC): 57.01 +- 5.72
**************************************************
PTC_MM, rbf (acc): 65.15 +- 3.66
PTC_MM, rbf (balanced acc): 57.87 +- 4.40
PTC_MM, rbf (f1): 36.55 +- 9.94
PTC_MM, rbf (precision): 59.18 +- 11.43
PTC_MM, rbf (recall): 26.92 +- 8.60
PTC_MM, rbf (AUCROC): 57.87 +- 4.40
**************************************************
PTC_MR


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:00<00:00, 14291.49it/s]

class proportions [0.55813953 0.44186047]





PTC_MR, knn (acc): 58.12 +- 5.16
PTC_MR, knn (balanced acc): 56.68 +- 5.13
PTC_MR, knn (f1): 48.49 +- 6.72
PTC_MR, knn (precision): 52.45 +- 7.67
PTC_MR, knn (recall): 45.67 +- 8.31
PTC_MR, knn (AUCROC): 56.68 +- 5.13
**************************************************
PTC_MR, rbf (acc): 58.99 +- 5.46
PTC_MR, rbf (balanced acc): 55.79 +- 5.14
PTC_MR, rbf (f1): 37.85 +- 13.19
PTC_MR, rbf (precision): 61.42 +- 17.97
PTC_MR, rbf (recall): 31.33 +- 14.77
PTC_MR, rbf (AUCROC): 55.79 +- 5.14
**************************************************
DD


Convert graph to singleton: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1178/1178 [00:00<00:00, 13729.16it/s]

class proportions [0.58658744 0.41341256]





DD, knn (acc): 77.63 +- 2.34
DD, knn (balanced acc): 76.42 +- 2.36
DD, knn (f1): 71.99 +- 2.86
DD, knn (precision): 75.11 +- 3.75
DD, knn (recall): 69.29 +- 3.80
DD, knn (AUCROC): 76.42 +- 2.36
**************************************************
DD, rbf (acc): 78.18 +- 1.94
DD, rbf (balanced acc): 77.36 +- 2.05
DD, rbf (f1): 73.38 +- 2.52
DD, rbf (precision): 74.46 +- 3.22
DD, rbf (recall): 72.55 +- 4.41
DD, rbf (AUCROC): 77.36 +- 2.05
**************************************************
ENZYMES


Convert graph to singleton: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:00<00:00, 18112.47it/s]

class proportions [0.16666667 0.16666667 0.16666667 0.16666667 0.16666667 0.16666667]





ValueError: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].