In [1]:
import pandas as pd

# Let printed frames use up to 200 characters per cell and per line, so the
# long "Settings" strings are not truncated or wrapped in the text output.
pd.options.display.max_colwidth = 200
pd.options.display.width = 200

# Raw hyperparameter tuning results, as read from the results CSV.
df = pd.read_csv("../result/gene_classification_gnn_hp_tune.csv")

### Define default and optimal parameters for tuning

The tuning procedure consists of two main stages:

- Stage 1: Tuning the GNN architecture (hidden dimension, number of layers, residual connection)
    - Default training parameters: `lr = 1e-3`, `dropout = 0.1`, `weight_decay = 1e-5`
- Stage 2: Tuning the training parameters (learning rate, dropout rate, weight decay)
    - Using the Stage 1 results, select the GNN architecture that gives the best validation
    performance across the different networks and datasets, and keep it fixed during this stage.

In [2]:
# All values are kept as strings because they are spliced verbatim into the
# "Settings" identifiers used to select rows of the results table.

# Training parameters held fixed while the architecture is tuned
default_lr, default_dropout, default_weight_decay = "1e-03", "0.1", "1e-05"

# Architectures held fixed while the training parameters are tuned,
# given per method as (hidden dimension, number of layers, residual flag)
optim_gcn_dim, optim_gcn_num_layers, optim_gcn_residual = "128", "5", "1"
optim_sage_dim, optim_sage_num_layers, optim_sage_residual = "128", "5", "0"

In [3]:
def print_best(
    name,
    df,
    sep_width=100,
    top_k=10,
    sortby="Validation score",
):
    """Print the ``top_k`` best rows of ``df`` together with the winning settings.

    Parameters
    ----------
    name : str or tuple
        Label printed as the report title. A tuple (e.g. a multi-column
        groupby key) is joined with underscores.
    df : pandas.DataFrame
        Results table. Must contain a "Settings" column and the ``sortby``
        column; a "RunID" column is dropped from the report when present.
    sep_width : int
        Width, in characters, of the separator lines.
    top_k : int
        Number of best rows to print.
    sortby : str
        Column used to rank the rows, highest value first.

    Raises
    ------
    IndexError
        If no rows remain after ranking (empty ``df`` or ``top_k=0``).
    """
    df_top = (
        df
        # Rank by the requested column (previously hard-coded, which made
        # the ``sortby`` argument a no-op).
        .sort_values(sortby, ascending=False)
        .head(top_k)
        # RunID carries no meaning in the printed report; tolerate inputs
        # that do not have it.
        .drop(columns="RunID", errors="ignore")
        .reset_index(drop=True)
    )
    optim_settings = df_top.iloc[0]["Settings"]

    # Group keys may contain non-string parts, so stringify before joining.
    name_str = "_".join(map(str, name)) if isinstance(name, tuple) else name
    print(name_str)
    print("=" * sep_width)
    print(f"Optimal settings: {optim_settings}")
    print(df_top)
    print("-" * sep_width)
    print()
    
    
def summarize(df):
    """Collapse per-run results into one median row per (Method, Settings).

    The aggregation is done in two steps so that every network/task pair
    contributes equally, regardless of how many runs it has:

    1. median across runs for each (Network, Method, Settings, Task);
    2. median of those values across all networks and tasks for each
       (Method, Settings).

    Parameters
    ----------
    df : pandas.DataFrame
        Results table with "Network", "Method", "Settings" and "Task"
        columns plus the numeric columns to aggregate.

    Returns
    -------
    pandas.DataFrame
        One row per (Method, Settings) holding "Method", "Settings" and the
        aggregated numeric columns.
    """
    # numeric_only=True skips leftover label columns (e.g. "Dataset", and
    # "Network"/"Task" in the second step), which cannot be aggregated with
    # a median.
    per_task = (
        df
        .groupby(["Network", "Method", "Settings", "Task"], as_index=False)
        .median(numeric_only=True)  # Take the median across different runs
    )
    return (
        per_task
        .groupby(["Method", "Settings"], as_index=False)
        .median(numeric_only=True)  # Take the median across all networks and tasks
    )

## GNN architecture tuning

In [4]:
# Architecture tuning results: keep only the runs trained with the default
# training parameters, then strip that shared suffix so that "Settings"
# describes the architecture alone.
default_settings = f"_lr={default_lr}_dropout={default_dropout}_weight-decay={default_weight_decay}"

uses_default_training = df["Settings"].str.endswith(default_settings)
df_architecture = (
    df.loc[uses_default_training]
    .reset_index(drop=True)
    .assign(Settings=lambda frame: frame["Settings"].str.replace(default_settings, "", regex=False))
)
df_architecture

Unnamed: 0,Training score,Validation score,Testing score,Task,Dataset,Network,Method,Settings,RunID
0,0.492660,0.850263,0.134878,GO:0006650,GOBP,HumanBaseTop-global,sage,dim=128_num-layers=4_residual=1,2
1,2.105350,0.501582,2.799723,GO:0006402,GOBP,HumanBaseTop-global,sage,dim=128_num-layers=4_residual=1,2
2,0.211608,0.638999,0.880677,GO:0010498,GOBP,HumanBaseTop-global,sage,dim=128_num-layers=4_residual=1,2
3,0.450761,0.338191,0.662027,GO:0006979,GOBP,HumanBaseTop-global,sage,dim=128_num-layers=4_residual=1,2
4,0.791923,0.699040,1.629855,GO:0097190,GOBP,HumanBaseTop-global,sage,dim=128_num-layers=4_residual=1,2
...,...,...,...,...,...,...,...,...,...
62779,0.182574,0.189197,0.164578,DOID:12798,DisGeNet,HumanBaseTop-global,gcn,dim=16_num-layers=3_residual=0,3
62780,0.095949,0.631950,1.070099,DOID:0050889,DisGeNet,HumanBaseTop-global,gcn,dim=16_num-layers=3_residual=0,3
62781,0.320472,0.683901,1.090627,DOID:898,DisGeNet,HumanBaseTop-global,gcn,dim=16_num-layers=3_residual=0,3
62782,0.856812,-0.174031,0.504259,DOID:10126,DisGeNet,HumanBaseTop-global,gcn,dim=16_num-layers=3_residual=0,3


In [5]:
# Rank the architectures of each method on results pooled over all networks and tasks
architecture_summary = summarize(df_architecture)
for method, g in architecture_summary.groupby("Method"):
    print_best(method, g.drop(columns="Method"))

gcn
Optimal settings: dim=128_num-layers=5_residual=1
                          Settings  Training score  Validation score  Testing score
0  dim=128_num-layers=5_residual=1        0.927532          0.525010       0.582700
1  dim=128_num-layers=4_residual=1        0.884140          0.523634       0.595597
2   dim=64_num-layers=5_residual=1        0.750898          0.518476       0.569804
3   dim=64_num-layers=4_residual=1        0.762423          0.497992       0.562140
4  dim=128_num-layers=3_residual=1        0.751208          0.494621       0.546420
5  dim=128_num-layers=5_residual=0        0.688257          0.472200       0.485478
6   dim=32_num-layers=5_residual=1        0.640385          0.462043       0.476116
7  dim=128_num-layers=4_residual=0        0.697047          0.455927       0.465337
8   dim=64_num-layers=3_residual=1        0.702701          0.451847       0.572226
9   dim=32_num-layers=4_residual=1        0.616940          0.436825       0.502254
----------------------

In [6]:
# Optimal GNN architecture specific to each network and dataset
if True:  # switch to False to skip this per-network/dataset breakdown
    grouped = (
        df_architecture
        .groupby(["Network", "Method", "Dataset", "Task", "Settings"], as_index=False)
        .median(numeric_only=True)  # Take the median across different runs
        .groupby(["Network", "Method", "Dataset", "Settings"], as_index=False)
        # Take the median across the tasks of each network/dataset pair.
        # numeric_only=True skips the leftover "Task" label column, which
        # cannot be aggregated with a median.
        .median(numeric_only=True)
        .groupby(["Method", "Network", "Dataset"])
    )

    # Each group key is a (Method, Network, Dataset) tuple; print_best joins
    # it with underscores to build the report title.
    for method, g in grouped:
        print_best(
            method,
            g.drop(["Method", "Network", "Dataset"], axis=1),
            top_k=10,
        )

gcn_HumanBase-global_DisGeNet
Optimal settings: dim=16_num-layers=3_residual=1
                          Settings  Training score  Validation score  Testing score
0   dim=16_num-layers=3_residual=1        0.414677          0.208876       0.234782
1   dim=16_num-layers=5_residual=1        0.308689          0.192509       0.213903
2   dim=32_num-layers=3_residual=1        0.335791          0.172183       0.211414
3   dim=64_num-layers=3_residual=1        0.268951          0.163877       0.274138
4  dim=128_num-layers=3_residual=1        0.031681          0.163488       0.261805
5   dim=64_num-layers=4_residual=1        0.279045          0.162812       0.250758
6   dim=16_num-layers=5_residual=0        0.413477          0.160257       0.272812
7   dim=32_num-layers=5_residual=0        0.422808          0.159230       0.272519
8  dim=128_num-layers=3_residual=0        0.412031          0.158366       0.273257
9   dim=16_num-layers=4_residual=0        0.423311          0.156767       0.2498

## Training parameters tuning

In [7]:
# Training parameter tuning results for GCN and SAGE, each restricted to the
# runs that use that method's fixed optimal architecture.
gcn_default_settings = f"dim={optim_gcn_dim}_num-layers={optim_gcn_num_layers}_residual={optim_gcn_residual}_"
sage_default_settings = f"dim={optim_sage_dim}_num-layers={optim_sage_num_layers}_residual={optim_sage_residual}_"

is_gcn_optimal = df["Method"].eq("gcn") & df["Settings"].str.startswith(gcn_default_settings)
is_sage_optimal = df["Method"].eq("sage") & df["Settings"].str.startswith(sage_default_settings)
df_params_gcn = df.loc[is_gcn_optimal]
df_params_sage = df.loc[is_sage_optimal]

# Stack both methods and strip the architecture prefix so that "Settings"
# describes the training parameters alone.
df_params = (
    pd.concat([df_params_gcn, df_params_sage], ignore_index=True)
    .assign(
        Settings=lambda frame: (
            frame["Settings"]
            .str.replace(gcn_default_settings, "", regex=False)
            .str.replace(sage_default_settings, "", regex=False)
        )
    )
)
df_params

Unnamed: 0,Training score,Validation score,Testing score,Task,Dataset,Network,Method,Settings,RunID
0,0.441580,0.252272,-0.138474,GO:0006650,GOBP,HumanBaseTop-global,gcn,lr=1e-04_dropout=0.3_weight-decay=1e-05,2
1,1.112974,0.416632,1.227576,GO:0006402,GOBP,HumanBaseTop-global,gcn,lr=1e-04_dropout=0.3_weight-decay=1e-05,2
2,0.418688,-0.191268,0.612162,GO:0010498,GOBP,HumanBaseTop-global,gcn,lr=1e-04_dropout=0.3_weight-decay=1e-05,2
3,0.671823,0.539591,0.165061,GO:0006979,GOBP,HumanBaseTop-global,gcn,lr=1e-04_dropout=0.3_weight-decay=1e-05,2
4,0.039112,0.560334,-0.188806,GO:0097190,GOBP,HumanBaseTop-global,gcn,lr=1e-04_dropout=0.3_weight-decay=1e-05,2
...,...,...,...,...,...,...,...,...,...
125563,3.271680,1.041323,-0.005436,DOID:12798,DisGeNet,HumanBaseTop-global,sage,lr=1e-03_dropout=0.1_weight-decay=1e-05,3
125564,3.703126,1.555133,-0.354161,DOID:0050889,DisGeNet,HumanBaseTop-global,sage,lr=1e-03_dropout=0.1_weight-decay=1e-05,3
125565,3.701968,1.312620,1.575077,DOID:898,DisGeNet,HumanBaseTop-global,sage,lr=1e-03_dropout=0.1_weight-decay=1e-05,3
125566,4.556964,3.318602,0.981964,DOID:10126,DisGeNet,HumanBaseTop-global,sage,lr=1e-03_dropout=0.1_weight-decay=1e-05,3


In [8]:
# Rank the training parameters of each method on results pooled over all networks and tasks
params_summary = summarize(df_params)
for method, g in params_summary.groupby("Method"):
    print_best(method, g.drop(columns="Method"))

gcn
Optimal settings: lr=1e-02_dropout=0.0_weight-decay=1e-07
                                  Settings  Training score  Validation score  Testing score
0  lr=1e-02_dropout=0.0_weight-decay=1e-07        1.436468          0.666042       0.770542
1  lr=1e-02_dropout=0.1_weight-decay=1e-05        1.308662          0.661572       0.749599
2  lr=1e-02_dropout=0.3_weight-decay=1e-06        1.482503          0.653373       0.674426
3  lr=1e-02_dropout=0.1_weight-decay=1e-07        1.284914          0.648079       0.701588
4  lr=1e-02_dropout=0.0_weight-decay=1e-06        1.557433          0.646852       0.676283
5  lr=1e-02_dropout=0.1_weight-decay=1e-06        1.345290          0.645083       0.822977
6  lr=1e-02_dropout=0.3_weight-decay=1e-05        1.163370          0.623447       0.649762
7  lr=1e-02_dropout=0.3_weight-decay=1e-07        1.302024          0.611894       0.723543
8  lr=1e-02_dropout=0.0_weight-decay=1e-05        1.182503          0.596668       0.665247
9  lr=1e-03_dropou

In [9]:
# Optimal training parameters specific to each network and dataset
if True:  # switch to False to skip this per-network/dataset breakdown
    grouped = (
        df_params
        .groupby(["Network", "Method", "Dataset", "Task", "Settings"], as_index=False)
        .median(numeric_only=True)  # Take the median across different runs
        .groupby(["Network", "Method", "Dataset", "Settings"], as_index=False)
        # Take the median across the tasks of each network/dataset pair.
        # numeric_only=True skips the leftover "Task" label column, which
        # cannot be aggregated with a median.
        .median(numeric_only=True)
        .groupby(["Method", "Network", "Dataset"])
    )

    # Each group key is a (Method, Network, Dataset) tuple; print_best joins
    # it with underscores to build the report title.
    for method, g in grouped:
        print_best(
            method,
            g.drop(["Method", "Network", "Dataset"], axis=1),
            top_k=10,
        )

gcn_HumanBase-global_DisGeNet
Optimal settings: lr=1e-02_dropout=0.1_weight-decay=1e-06
                                  Settings  Training score  Validation score  Testing score
0  lr=1e-02_dropout=0.1_weight-decay=1e-06        0.724978          0.262921       0.420333
1  lr=1e-02_dropout=0.1_weight-decay=1e-07        0.704256          0.248826       0.436120
2  lr=1e-02_dropout=0.0_weight-decay=1e-07        0.701434          0.247944       0.435271
3  lr=1e-02_dropout=0.0_weight-decay=1e-06        0.669217          0.233586       0.397931
4  lr=1e-01_dropout=0.1_weight-decay=1e-07        0.463783          0.210342       0.110865
5  lr=1e-02_dropout=0.1_weight-decay=1e-05        0.638662          0.208608       0.372844
6  lr=1e-02_dropout=0.3_weight-decay=1e-07        0.653443          0.199543       0.398563
7  lr=1e-01_dropout=0.1_weight-decay=1e-05        0.409466          0.196484       0.243173
8  lr=1e-02_dropout=0.0_weight-decay=1e-05        0.612657          0.187597       0