In [2]:
import torch
# Show the installed PyTorch build and the CUDA version it was built with; the
# wheel index URL used by the install cell is chosen to match this pair.
print(torch.__version__, torch.version.cuda)

2.6.0+cu124 12.4


In [3]:
# PyG companion wheels are looked up on the index matching the running
# torch/CUDA build (2.6.0+cu124). `-f/--find-links` applies to the whole
# command, so it is given once. `%pip` installs into the kernel's environment.
%pip install --quiet torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

# OGB, RDKit, LibAUC, and helpers
%pip install --quiet ogb rdkit libauc tqdm scikit-learn

# Clone the DeeperGCN reference code only when it is not there yet, so that
# re-running this cell does not fail with "destination path already exists".
!test -d /content/deep_gcns_torch || git clone https://github.com/lightaime/deep_gcns_torch.git /content/deep_gcns_torch

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m106.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m92.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m104.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m64.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.3/34.3 MB[0m [31m73.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
import os,sys, time, pathlib, logging, random
import numpy as np, pandas as pd
import copy
import torch_geometric
import torch.nn.functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
from tqdm.auto import tqdm
from rdkit import Chem,RDLogger
from rdkit.Chem import AllChem
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from torch_geometric.data import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch.serialization import safe_globals
from torch_geometric.data.data import DataEdgeAttr
from libauc.losses import AUCMLoss
from libauc.optimizers import PESG
from google.colab import drive
from rdkit.Chem import rdMolDescriptors, MACCSkeys
# Make the cloned DeeperGCN repository importable (the `gcn_lib` imports below
# resolve through this path entry).
sys.path.append("/content/deep_gcns_torch")
# Turn off RDKit logging for every channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')
# Mount Google Drive (Colab-only API); force_remount=True asks for a fresh mount.
drive.mount('/content/drive', force_remount=True)
# Drive folder that holds the project's files.
project_root = pathlib.Path('/content/drive/MyDrive/MLNS')
# Use the GPU when one is visible, otherwise fall back to the CPU.
# NOTE: `torch` itself is imported by an earlier cell, not by this one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from gcn_lib.sparse.torch_vertex import GENConv
from gcn_lib.sparse.torch_nn import norm_layer, MLP

# mimic ArgsInit
class CFG:
    """Static hyper-parameter bundle that stands in for the repo's ArgsInit namespace.

    All values are plain class attributes, so both ``CFG.x`` and ``CFG().x`` work.
    """

    # -- data / loading --
    dataset = "ogbg-molhiv"
    num_tasks = 1
    batch_size = 256
    num_workers = 4
    random_seed = 0

    # -- optimisation --
    optimizer = "pesg"
    lr = 0.01
    weight_decay = 1e-4
    gamma = 500
    margin = 1.0
    epochs_pre = 300
    epochs_ft = 100

    # -- backbone --
    conv = 'gen'
    block = "res+"
    num_layers = 14
    hidden_channels = 256
    dropout = 0.5
    activations = "relu"    # or "elu"
    norm = "batch"          # normalisation type handed to GENConv / norm_layer
    mlp_layers = 1          # MLP depth in each GENConv
    graph_pooling = "mean"  # could be "sum" or "max"
    conv_encode_edge = False
    add_virtual_node = False

    # -- GENConv aggregation --
    gcn_aggr = "softmax"
    t = 1.0
    p = 1.0
    y = 0.0
    learn_t = True
    learn_p = False
    learn_y = False
    msg_norm = False
    learn_msg = False
    learn_msg_scale = False

cfg = CFG()

# Seed every RNG this notebook imports so a fresh Restart & Run All starts from
# the same state: Python's `random` (imported above but previously left
# unseeded), NumPy's legacy global generator, and PyTorch.
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
torch.manual_seed(cfg.random_seed)

Mounted at /content/drive


In [5]:
import torch
# Keep a handle on the pristine loader. If this cell has already run,
# torch.load is our wrapper, so recover the original from the attribute stored
# on it. Binding `_torch_load` to the wrapper itself would make the wrapper call
# itself and recurse forever on the next torch.load.
_torch_load = getattr(torch.load, "_original_load", torch.load)

def _torch_load_override(f, *args, **kwargs):
    """Call the original torch.load, defaulting ``weights_only`` to False.

    An explicit ``weights_only=...`` from the caller is passed through untouched.

    SECURITY: weights_only=False allows full pickle deserialisation, which can
    execute arbitrary code. Only load files from sources you trust.
    """
    kwargs.setdefault("weights_only", False)
    return _torch_load(f, *args, **kwargs)

# Remember the original on the wrapper so the cell can be re-run safely.
_torch_load_override._original_load = _torch_load

# override torch.load globally
torch.load = _torch_load_override
from torch.serialization import add_safe_globals

# import the class that needs to be allow-listed
from torch_geometric.data.data import DataEdgeAttr

# Allow-list DataEdgeAttr for all future torch.load calls.
# NOTE(review): the override earlier in this cell already defaults
# weights_only=False, so this allow-list looks redundant — confirm before removing.
add_safe_globals([DataEdgeAttr])

# Build the OGB graph dataset named in cfg.dataset (the captured output shows
# it being downloaded and processed on first use).
from ogb.graphproppred import PygGraphPropPredDataset
dataset = PygGraphPropPredDataset(name=cfg.dataset)
# Flatten the stored label tensor into a 1-D NumPy array.
all_labels = dataset.data.y.view(-1).numpy()

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip


Downloaded 0.00 GB: 100%|██████████| 3/3 [00:01<00:00,  2.70it/s]
Processing...


Extracting dataset/hiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 41127/41127 [00:00<00:00, 110429.35it/s]


Converting graphs into PyG objects...


100%|██████████| 41127/41127 [00:01<00:00, 26998.77it/s]


Saving...


Done!


In [6]:
# HIV table kept on Drive; later cells pair `smiles` positionally with the OGB labels.
# NOTE(review): that pairing assumes HIV.csv rows are in the same order as the
# OGB dataset — confirm.
project_root = pathlib.Path('/content/drive/MyDrive/MLNS')
df = pd.read_csv(project_root / 'HIV.csv')
smiles = df['smiles'].tolist()
# Binary activity label. Any code missing from the mapping becomes NaN, so 'CA'
# is mapped as well as 'CM'.
# NOTE(review): assumes the MoleculeNet convention CI = inactive and
# CM / CA = active — verify against the file's `activity` values.
labels = df['activity'].map({'CI': 0, 'CM': 1, 'CA': 1}).tolist()

In [19]:
# Deep copy of the dataset's stored data object, kept so the original labels can
# be recovered after a later cell overwrites `dataset.data.y`.
saved_data = copy.deepcopy(dataset._data)   # or dataset.data

# Recovery recipe (left commented out): restore the snapshot, then re-derive
# the label tensors from it.
# dataset._data = copy.deepcopy(saved_data)   # restore the original
# orig_y = dataset._data.y.view(-1,1)         # (N,1) true labels
# all_labels = dataset.data.y.view(-1).numpy()

torch.Size([41127, 1])

In [7]:
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors, MACCSkeys

def fp_vect(smi):
    """
    Compute Morgan and MACCS fingerprints for a SMILES string using
    RDKit's rdMolDescriptors.GetMorganFingerprintAsBitVect API.

    Args:
      smi: SMILES string. Anything that is not a ``str`` (e.g. the float NaN
           pandas produces for an empty CSV field) is treated as unparsable.
    Returns:
      - morgan: numpy array of shape (2048,) with 0/1 entries
      - maccs: numpy array with 0/1 entries, one per bit of the vector returned
               by MACCSkeys.GenMACCSKeys (RDKit returns 167 bits; bit 0 is unused)
    If SMILES fails (wrong type, parsing or sanitization), returns (None, None).
    """
    # Guard first: handing a non-string to Chem.MolFromSmiles raises instead of
    # returning None, which would abort the caller's loop.
    if not isinstance(smi, str):
        return None, None

    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None, None

    # attempt sanitization (catches valence errors, etc.)
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        return None, None

    # Morgan fingerprint (radius=2, 2048 bits, no chirality)
    morgan_bv = rdMolDescriptors.GetMorganFingerprintAsBitVect(
        mol, radius=2, nBits=2048, useChirality=False
    )
    morgan = np.array(morgan_bv, dtype=int)

    # MACCS keys
    maccs_bv = MACCSkeys.GenMACCSKeys(mol)
    maccs = np.array(maccs_bv, dtype=int)

    return morgan, maccs

In [8]:
# Positions whose label is present and whose SMILES yields fingerprints.
# The `and` short-circuits, so fp_vect is never called for a NaN label.
valid_entry_idxs = [
    pos
    for pos, (smi, y) in enumerate(zip(smiles, all_labels))
    if not np.isnan(y) and fp_vect(smi)[0] is not None
]
valid_set = set(valid_entry_idxs)

In [9]:
split_idx = dataset.get_idx_split()
orig_train, orig_val, orig_test = (
    split_idx[name] for name in ("train", "valid", "test")
)

# Keep only split members that also passed the SMILES/label screening;
# np.intersect1d returns the sorted, de-duplicated common values.
train_filt, val_filt, test_filt = (
    np.intersect1d(part, valid_entry_idxs).tolist()
    for part in (orig_train, orig_val, orig_test)
)

# Size of each split before → after filtering.
print(len(orig_train), "→", len(train_filt),
      len(orig_val),   "→", len(val_filt),
      len(orig_test),  "→", len(test_filt))

32901 → 32898 4113 → 4111 4113 → 4111


In [10]:
# Fingerprints and labels for every valid entry, in valid_entry_idxs order.
fingerprints = [fp_vect(smiles[i]) for i in valid_entry_idxs]
morgan_feats = [mg for mg, _ in fingerprints]
maccs_feats = [mc for _, mc in fingerprints]
labels_tv = [int(all_labels[i]) for i in valid_entry_idxs]

# One row per molecule: Morgan bits followed by MACCS bits.
X_all = np.hstack([np.stack(morgan_feats), np.stack(maccs_feats)])
y_all = np.array(labels_tv, dtype=int)

# Map each original dataset index to its row in X_all / y_all.
idx_map = dict(zip(valid_entry_idxs, range(len(valid_entry_idxs))))
train_loc, val_loc, test_loc = (
    [idx_map[i] for i in members]
    for members in (train_filt, val_filt, test_filt)
)

X_train, y_train = X_all[train_loc], y_all[train_loc]
X_val, y_val = X_all[val_loc], y_all[val_loc]
X_test, y_test = X_all[test_loc], y_all[test_loc]


In [11]:
# Random forest on the concatenated fingerprints; class 1 is weighted 10x.
rf = RandomForestClassifier(
    n_estimators=1000,
    class_weight={0: 1, 1: 10},
    random_state=cfg.random_seed,
    n_jobs=-1,
)
rf.fit(X_train, y_train)

# Keep column index 1 of predict_proba for every valid molecule.
# NOTE(review): rows that were used for fitting receive in-sample scores here —
# confirm that is intended before they are fed to the fine-tuning stage.
probs_all = rf.predict_proba(X_all)[:, 1]

In [18]:
# Sanity check: shape of the label tensor currently stored on the dataset.
dataset.data.y.shape

torch.Size([41127, 1])

In [12]:
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = (
    split_idx["train"], split_idx["valid"], split_idx["test"]
)

# One loader per split; only the training loader shuffles.
loaders = {
    phase: DataLoader(
        dataset[indices],
        batch_size=cfg.batch_size,
        shuffle=(phase == "train"),
        num_workers=cfg.num_workers,
    )
    for phase, indices in (
        ("train", train_idx),
        ("valid", valid_idx),
        ("test", test_idx),
    )
}



In [22]:
N = len(dataset)    # total number of graphs in the dataset

# Scatter the RF scores back to their dataset positions in one vectorised
# assignment; graphs that were screened out keep a score of 0.0.
rf_full = np.zeros(N, dtype=float)
rf_full[np.asarray(valid_entry_idxs, dtype=int)] = probs_all

# Always rebuild from column 0 so this cell can be re-run: if `y` already has
# the RF column, view(-1, 1) would give 2N rows and the concatenation below
# would fail, whereas view(N, -1)[:, 0:1] picks the true labels either way.
orig_y = dataset.data.y.view(N, -1)[:, 0:1].clone()     # (N,1) true labels
rf_col = torch.from_numpy(rf_full).view(-1, 1).float()  # (N,1) RF pos-probs

# y becomes (N,2): column 0 = true label, column 1 = RF score.
dataset.data.y = torch.cat([orig_y, rf_col], dim=1)



In [13]:
# ---- DeeperGCN model ----

class DeeperGCN(torch.nn.Module):
    """Stack of GENConv layers with a graph-level linear head.

    The constructor reads its hyper-parameters from ``args`` (any object with
    the attributes accessed below, e.g. ``CFG``). ``forward`` returns raw
    (pre-sigmoid) graph-level outputs with one column per task.
    """

    def __init__(self, args):
        """Build encoders, ``num_layers`` GENConv layers with norms, pooling and head.

        Raises:
            NotImplementedError: if ``args.block == 'dense'``.
            Exception: for an unknown block, conv or pooling type.
        """
        super(DeeperGCN, self).__init__()

        self.num_layers = args.num_layers
        self.dropout = args.dropout
        self.block = args.block
        self.conv_encode_edge = args.conv_encode_edge
        self.add_virtual_node = args.add_virtual_node

        hidden_channels = args.hidden_channels
        num_tasks = args.num_tasks
        conv = args.conv
        aggr = args.gcn_aggr
        t = args.t
        self.learn_t = args.learn_t
        p = args.p
        self.learn_p = args.learn_p
        y = args.y
        self.learn_y = args.learn_y

        self.msg_norm = args.msg_norm
        learn_msg_scale = args.learn_msg_scale
        self.activation_func = F.relu if args.activations=='relu' else F.elu

        norm = args.norm
        mlp_layers = args.mlp_layers

        graph_pooling = args.graph_pooling

        print('The number of layers {}'.format(self.num_layers),
              'Aggr aggregation method {}'.format(aggr),
              'block: {}'.format(self.block))
        if self.block == 'res+':
            print('LN/BN->ReLU->GraphConv->Res')
        elif self.block == 'res':
            print('GraphConv->LN/BN->ReLU->Res')
        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')
        elif self.block == "plain":
            print('GraphConv->LN/BN->ReLU')
        else:
            raise Exception('Unknown block Type')

        self.gcns = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        if self.add_virtual_node:
            # Single learned embedding (initialised to zero) shared by all graphs,
            # plus one update MLP per layer transition.
            self.virtualnode_embedding = torch.nn.Embedding(1, hidden_channels)
            torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

            self.mlp_virtualnode_list = torch.nn.ModuleList()

            for layer in range(self.num_layers - 1):
                self.mlp_virtualnode_list.append(MLP([hidden_channels]*3,
                                                     norm=norm))

        for layer in range(self.num_layers):
            if conv == 'gen':
                # learn_y must follow args.learn_y; it was previously wired to
                # learn_p, so the two flags could not be set independently.
                gcn = GENConv(hidden_channels, hidden_channels,
                              aggr=aggr,
                              t=t, learn_t=self.learn_t,
                              p=p, learn_p=self.learn_p,
                              y=y, learn_y=self.learn_y,
                              msg_norm=self.msg_norm, learn_msg_scale=learn_msg_scale,
                              encode_edge=self.conv_encode_edge, bond_encoder=True,
                              norm=norm, mlp_layers=mlp_layers)
            else:
                raise Exception('Unknown Conv Type')
            self.gcns.append(gcn)
            self.norms.append(norm_layer(norm, hidden_channels))

        self.atom_encoder = AtomEncoder(emb_dim=hidden_channels)

        # When the conv layers do not encode edges themselves, bond features are
        # embedded once here and the embedding is shared by every layer.
        if not self.conv_encode_edge:
            self.bond_encoder = BondEncoder(emb_dim=hidden_channels)

        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise Exception('Unknown Pool Type')

        self.graph_pred_linear = torch.nn.Linear(hidden_channels, num_tasks)

    def forward(self, input_batch):
        """Return the linear head's output on the pooled node embeddings.

        ``input_batch`` must expose ``x``, ``edge_index``, ``edge_attr`` and
        ``batch``. The output is not passed through a sigmoid.
        """

        x = input_batch.x
        edge_index = input_batch.edge_index
        edge_attr = input_batch.edge_attr
        batch = input_batch.batch

        h = self.atom_encoder(x)

        if self.add_virtual_node:
            # Look up embedding index 0 once per graph (batch[-1] + 1 graphs).
            virtualnode_embedding = self.virtualnode_embedding(
                torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))
            h = h + virtualnode_embedding[batch]

        if self.conv_encode_edge:
            edge_emb = edge_attr
        else:
            edge_emb = self.bond_encoder(edge_attr)

        if self.block == 'res+':
            # Pre-activation residual: norm -> act -> dropout -> conv, then add input.

            h = self.gcns[0](h, edge_index, edge_emb)

            for layer in range(1, self.num_layers):
                h1 = self.norms[layer - 1](h)
                h2 = self.activation_func(h1)
                h2 = F.dropout(h2, p=self.dropout, training=self.training)

                if self.add_virtual_node:
                    virtualnode_embedding_temp = global_add_pool(h2, batch) + virtualnode_embedding
                    virtualnode_embedding = F.dropout(
                        self.mlp_virtualnode_list[layer-1](virtualnode_embedding_temp),
                        self.dropout, training=self.training)

                    h2 = h2 + virtualnode_embedding[batch]

                h = self.gcns[layer](h2, edge_index, edge_emb) + h

            h = self.norms[self.num_layers - 1](h)
            h = F.dropout(h, p=self.dropout, training=self.training)

        elif self.block == 'res':
            # Post-activation residual: conv -> norm -> act, then add input.

            h = self.activation_func(self.norms[0](self.gcns[0](h, edge_index, edge_emb)))
            h = F.dropout(h, p=self.dropout, training=self.training)

            for layer in range(1, self.num_layers):
                h1 = self.gcns[layer](h, edge_index, edge_emb)
                h2 = self.norms[layer](h1)
                h = self.activation_func(h2) + h
                h = F.dropout(h, p=self.dropout, training=self.training)

        elif self.block == 'dense':
            raise NotImplementedError('To be implemented')

        elif self.block == 'plain':
            # No residuals; the last layer skips the activation.

            h = self.activation_func(self.norms[0](self.gcns[0](h, edge_index, edge_emb)))
            h = F.dropout(h, p=self.dropout, training=self.training)

            for layer in range(1, self.num_layers):
                h1 = self.gcns[layer](h, edge_index, edge_emb)
                h2 = self.norms[layer](h1)
                if layer != (self.num_layers - 1):
                    h = self.activation_func(h2)
                else:
                    h = h2
                h = F.dropout(h, p=self.dropout, training=self.training)
        else:
            raise Exception('Unknown block Type')

        # Collapse node embeddings to one vector per graph.
        h_graph = self.pool(h, batch)
        return self.graph_pred_linear(h_graph)

In [14]:
# Use the GPU when present but keep the CPU fallback: an unconditional
# torch.device("cuda") here would discard the fallback chosen at start-up and
# make `.to(device)` fail on CPU-only runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeeperGCN(cfg).to(device)

# AUC-margin loss with the PESG optimizer from LibAUC. The loss's a / b / alpha
# are passed to the optimizer explicitly.
criterion = AUCMLoss()
optimizer = PESG(model.parameters(), loss_fn=criterion,
                 a=criterion.a, b=criterion.b, alpha=criterion.alpha,
                 lr=cfg.lr, gamma=cfg.gamma, margin=cfg.margin,
                 weight_decay=cfg.weight_decay)

def train_epoch():
    """Run one pass over the training loader and return the mean batch loss."""
    model.train()
    batch_losses = []
    for graphs in loaders["train"]:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        logits = model(graphs)
        # Column 0 of y carries the ground-truth label.
        targets = graphs.y[:, 0:1]
        batch_loss = criterion(logits, targets).to(torch.float32)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return np.mean(batch_losses)

@torch.no_grad()
def eval_split(phase):
    """Return the OGB evaluator's ROC-AUC of `model` on ``loaders[phase]``.

    Args:
        phase: key into ``loaders`` ("train", "valid" or "test").
    """
    model.eval()
    ys, ps = [], []
    for batch in loaders[phase]:
        batch = batch.to(device)
        out = model(batch)
        # Take the label column explicitly, as train_epoch does. view(-1, 1)
        # would fold any additional y column (e.g. an appended RF score) into
        # the labels and give y_true more rows than y_pred.
        ys.append(batch.y[:, 0:1].cpu().numpy())
        ps.append(torch.sigmoid(out).cpu().numpy())
    y_true = np.vstack(ys)
    y_pred = np.vstack(ps)
    return Evaluator(cfg.dataset).eval({"y_true":y_true,"y_pred":y_pred})["rocauc"]

best_val, best_ckpt = 0, None
for epoch in range(1, cfg.epochs_pre + 1):
    epoch_loss = train_epoch()
    val_auc = eval_split("valid")

    # Checkpoint whenever validation AUC improves; best_ckpt names the newest file.
    if val_auc > best_val:
        best_val, best_ckpt = val_auc, f"pretrained_{epoch}.pth"
        torch.save(model.state_dict(), best_ckpt)

    # Progress line every 50 epochs.
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: train loss {epoch_loss:.4f}, val AUC {val_auc:.4f}")
print("Best pretrain AUC:", best_val)

The number of layers 14 Aggr aggregation method softmax block: res+
LN/BN->ReLU->GraphConv->Res
Epoch 50: train loss 0.0246, val AUC 0.7635




Epoch 100: train loss 0.0227, val AUC 0.7662
Epoch 150: train loss 0.0214, val AUC 0.7787




Epoch 200: train loss 0.0205, val AUC 0.7722




Epoch 250: train loss 0.0194, val AUC 0.7642
Epoch 300: train loss 0.0191, val AUC 0.7683
Best pretrain AUC: 0.7855060748579267


In [15]:
class DeeperGCNAtt(DeeperGCN):
    """DeeperGCN whose sigmoid output is blended with a per-graph score read from ``y``.

    The backbone (encoders, GENConv stack, norms, pooling, linear head) is
    inherited from DeeperGCN instead of being copied, so parameter names are
    the same and a DeeperGCN ``state_dict`` loads into this model. The only
    additional parameter is the scalar mixing weight ``beta``.
    """

    def __init__(self, args):
        """Build the DeeperGCN backbone from ``args`` and add ``beta`` (initialised to 0.5)."""
        super(DeeperGCNAtt, self).__init__(args)
        self.beta = torch.nn.Parameter(torch.Tensor([0.5]), requires_grad=True)

    def forward(self, input_batch, mode='train'):
        """Return ``(1 - beta) * sigmoid(backbone output) + beta * y[:, 1]`` as a column.

        Args:
            input_batch: batch accepted by ``DeeperGCN.forward`` whose ``y`` has
                at least two columns; column 1 is the score to blend in.
            mode: accepted for call-site compatibility; not used.
        """
        # The parent forward already returns the linear head's raw output.
        dcn_pred = super(DeeperGCNAtt, self).forward(input_batch)
        rf_pred = input_batch.y[:, 1]
        return (1-self.beta)*torch.sigmoid(dcn_pred).reshape(-1, 1) + (self.beta) * rf_pred.reshape(-1,1)

In [24]:
# Build the fused model and warm-start it from the best pretraining checkpoint.
# strict=False lets the load proceed even though that checkpoint, saved from
# the plain DeeperGCN model, has no entry for the extra `beta` parameter.
finetune_model = DeeperGCNAtt(cfg).to(device)
finetune_model.load_state_dict(torch.load(best_ckpt,map_location = device), strict=False)

# Separate PESG optimizer for fine-tuning at one tenth of the pretraining learning rate.
# NOTE(review): `criterion` is the same loss object used during pretraining, so
# its a / b / alpha values carry over — confirm that is intended.
optimizer_ft = PESG(
    finetune_model.parameters(), loss_fn=criterion,
    a=criterion.a, b=criterion.b, alpha=criterion.alpha,
    lr=cfg.lr * 0.1, gamma=cfg.gamma, margin=cfg.margin,
    weight_decay=cfg.weight_decay
)

def train_epoch_ft():
    """Run one fine-tuning pass over the training loader; return the mean batch loss as a float."""
    finetune_model.train()
    batch_losses = []
    for graphs in loaders["train"]:
        graphs = graphs.to(device)
        optimizer_ft.zero_grad()

        # The model reads the second y column itself; the loss only sees column 0.
        fused = finetune_model(graphs, mode="train")      # (B,1)
        targets = graphs.y[:, 0:1].to(torch.float32)      # (B,1)

        batch_loss = criterion(fused, targets)
        batch_loss.backward()
        optimizer_ft.step()
        batch_losses.append(batch_loss.item())

    return float(np.mean(batch_losses))


@torch.no_grad()
def eval_split_ft(phase):
    """Return the OGB evaluator's ROC-AUC of `finetune_model` on ``loaders[phase]``."""
    finetune_model.eval()
    label_chunks, score_chunks = [], []
    for graphs in loaders[phase]:
        graphs = graphs.to(device)
        fused = finetune_model(graphs, mode="test")           # (B,1) fused prediction
        label_chunks.append(graphs.y[:, 0:1].cpu().numpy())   # (B,1) true labels
        score_chunks.append(fused.cpu().numpy())              # (B,1)

    scores = Evaluator(cfg.dataset).eval({
        "y_true": np.vstack(label_chunks),
        "y_pred": np.vstack(score_chunks),
    })
    return scores["rocauc"]


# def train_epoch_ft():
#     finetune_model.train()
#     losses = []
#     for batch in loaders["train"]:
#         batch = batch.to(device)
#         optimizer_ft.zero_grad()
#         pred = finetune_model(batch, mode="train")     # now uses pretrained backbone
#         true = batch.y[:, 0:1].float()
#         loss = criterion(pred, true)
#         loss.backward()
#         optimizer_ft.step()
#         losses.append(loss.item())
#     return float(np.mean(losses))

# @torch.no_grad()
# def eval_split_ft(phase):
#     finetune_model.eval()
#     ys, ps = [], []
#     for batch in loaders[phase]:
#         batch = batch.to(device)
#         out = finetune_model(batch, mode="test")          # (B,1)
#         ys.append(batch.y[:, 0:1].cpu().numpy())          # (B,1)
#         ps.append(out.cpu().numpy())                      # (B,1)
#     y_true = np.vstack(ys)
#     y_pred = np.vstack(ps)
#     return Evaluator(cfg.dataset).eval({"y_true": y_true, "y_pred": y_pred})["rocauc"]

The number of layers 14 Aggr aggregation method softmax block: res+
LN/BN->ReLU->GraphConv->Res


In [25]:
best_val2, best_ckpt2 = 0.0, None

for epoch in range(1, cfg.epochs_ft + 1):
    # Train for one epoch, then score the validation split.
    train_loss = train_epoch_ft()
    val_auc = eval_split_ft("valid")

    # Save the weights whenever validation AUC improves.
    if val_auc > best_val2:
        best_val2, best_ckpt2 = val_auc, f"finetuned_epoch{epoch}.pth"
        torch.save(finetune_model.state_dict(), best_ckpt2)

    # Progress line every 20 epochs.
    if epoch % 20 == 0:
        print(f"[FT] Epoch {epoch:3d}: train loss {train_loss:.4f}, valid AUC {val_auc:.4f}")

# Final report
print(f"✅ Best finetune valid AUC: {best_val2:.4f} → {best_ckpt2}")



[FT] Epoch  20: train loss 0.0174, valid AUC 0.8066
[FT] Epoch  40: train loss 0.0131, valid AUC 0.8060
[FT] Epoch  60: train loss 0.0100, valid AUC 0.8067




[FT] Epoch  80: train loss 0.0076, valid AUC 0.8077
[FT] Epoch 100: train loss 0.0060, valid AUC 0.8096
✅ Best finetune valid AUC: 0.8098 → finetuned_epoch99.pth


In [26]:
# Reload the best fine-tuned weights. map_location=device matches the earlier
# checkpoint load and keeps this working when the tensors were saved from a
# different device than the one in use now.
finetune_model.load_state_dict(torch.load(best_ckpt2, map_location=device), strict=False)
finetune_model.to(device)

# 3) Now evaluate on the test split
test_auc = eval_split_ft("test")
print("✔️ Test ROC-AUC of best finetuned model:", test_auc)

✔️ Test ROC-AUC of best finetuned model: 0.8121960640414067
