# Set up LogME for Strategies for Pre-training GNNs (Hu et al., ICLR 2020)

In [2]:
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim

from loader import MoleculeDataset
from model import GNN, GNN_graphpred
from splitters import scaffold_split


seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else torch.device("cpu"))
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

## Load Pre-trained GNN model and downstream task (BBBP)

### Set up dataset

In [37]:
dataset_name = "bbbp"
num_tasks = 1
batch_size = 32 # default
num_workers = 4 # default


dataset = MoleculeDataset("dataset/" + dataset_name, dataset=dataset_name)
print(dataset)

# scaffold split:
smiles_list = pd.read_csv('dataset/' + dataset_name + '/processed/smiles.csv', header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
print(train_dataset[0])

bbbp_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)
bbbp_val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)
bbbp_test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)


MoleculeDataset(2039)


[18:28:36] Conflicting single bond directions around double bond at index 1.
[18:28:36]   BondStereo set to STEREONONE and single bond directions set to NONE.


Data(edge_attr=[46, 2], edge_index=[2, 46], id=[1], x=[23, 2], y=[1])


### Set up model

In [59]:
num_layer = 5 # default
emb_dim = 300 # default
JK = 'last' # default (how the node features across laysers are combined)
dropout_ratio = 0.5 # default
graph_pooling = 'mean' # default
gnn_type = 'gin' # default

##########################
input_model_file = './model_gin/supervised.pth'

gin_supervised_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_model.from_pretrained(input_model_file)

gin_supervised_model.to(device)
gin_supervised_model.eval()

###########################

input_model_file = './model_gin/supervised_infomax.pth'

gin_supervised_infomax_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_infomax_model.from_pretrained(input_model_file)

gin_supervised_infomax_model.to(device)
gin_supervised_infomax_model.eval()

###########################

input_model_file = './model_gin/supervised_edgepred.pth'

gin_supervised_edgepred_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_edgepred_model.from_pretrained(input_model_file)

gin_supervised_edgepred_model.to(device)
gin_supervised_edgepred_model.eval()

###########################

input_model_file = './model_gin/supervised_masking.pth'

gin_supervised_masking_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_masking_model.from_pretrained(input_model_file)

gin_supervised_masking_model.to(device)
gin_supervised_masking_model.eval()

###########################

input_model_file = './model_gin/supervised_contextpred.pth'

gin_supervised_contextpred_model = GNN_graphpred(num_layer, emb_dim, num_tasks, JK = JK, drop_ratio = dropout_ratio, graph_pooling = graph_pooling, gnn_type = gnn_type)
gin_supervised_contextpred_model.from_pretrained(input_model_file)

gin_supervised_contextpred_model.to(device)
gin_supervised_contextpred_model.eval()

GNN_graphpred(
  (gnn): GNN(
    (x_embedding1): Embedding(120, 300)
    (x_embedding2): Embedding(3, 300)
    (gnns): ModuleList(
      (0): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): Embedding(6, 300)
        (edge_embedding2): Embedding(3, 300)
      )
      (1): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): Embedding(6, 300)
        (edge_embedding2): Embedding(3, 300)
      )
      (2): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): E

## Extract features of BBBP training set using pre-trained GNN & obtain LogME score

In [50]:
def get_graph_features_labels(loader, model, seed):
    """Extract graph features.
    Args:
        loader: graph dataloader
        model: feature extractor
        seed: integer value for setting random seed
    Returns:
        all_graph_features: list of all graph features from the dataloader 
        all_graph_labels: list of all graph (true) labels from the dataloader
    """
    
    torch.manual_seed(seed)
    np.random.seed(seed)

    all_graph_features = []
    all_graph_labels = []
    for step, batch in enumerate(loader):
        batch = batch.to(device)
        x = batch.x
        edge_index = batch.edge_index
        edge_attr = batch.edge_attr
        y = batch.y
        batch = batch.batch

        # code from GNN_graphpred.forward() #
        node_representation = model.gnn(x, edge_index, edge_attr)
#         print("step {} | node_representation.shape: {}".format(step, node_representation.shape))

        graph_features = model.pool(node_representation, batch)
#         print("graph feature (shape): {}".format(graph_features.shape))
        all_graph_features.extend(graph_features.cpu().detach().numpy())
#         print("len(all_graph_features): {}".format(len(all_graph_features)))

        graph_preds = model.graph_pred_linear(graph_features)
    #     print("graph predictions: {}".format(graph_preds))
    #     print("graph predictions (shape): {}".format(graph_preds.shape))
    #     y_true = y.view(graph_preds.shape)
    #     print("y true: {}".format(y_true))
    #     print("y true (shape): {}".format(y_true.shape))
    #     print("y (shape): {}".format(y.shape))

        all_graph_labels.extend(y.cpu().detach().numpy())
#         print("len(all_graph_labels): {}".format(len(all_graph_labels)))
        
    return all_graph_features, all_graph_labels

In [61]:
from LogME import LogME

## using GIN supervised model ('./model_gin/supervised.pth') ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised.pth): {}".format(score))

#############################

## using GIN supervised infomax model (supervised_infomax.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_infomax_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_infomax.pth): {}".format(score))

#############################

## using GIN supervised edgepred model (supervised_edgepred.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_edgepred_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_edgepred.pth): {}".format(score))

#############################

## using GIN supervised masking model (supervised_masking.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_masking_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_masking.pth): {}".format(score))

#############################

## using GIN supervised contextpred model (supervised_contextpred.pth) ## 

bbbp_graph_features, bbbp_graph_labels = get_graph_features_labels(bbbp_train_loader, gin_supervised_contextpred_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bbbp_graph_features), np.array(bbbp_graph_labels))
print("\n==============")
print("logme score (GIN supervised_contextpred.pth): {}".format(score))


logme score (GIN supervised.pth): 3.983138708929339

logme score (GIN supervised_infomax.pth): 3.904801580066735

logme score (GIN supervised_edgepred.pth): 3.9701614855516447

logme score (GIN supervised_masking.pth): 3.9627838426304236

logme score (GIN supervised_contextpred.pth): 3.9576596769514647


## Extract features of BACE training set using pre-trained GNN & obtain LogME score

In [62]:
dataset_name = "bace"
num_tasks = 1
batch_size = 32 # default
num_workers = 4 # default

dataset = MoleculeDataset("dataset/" + dataset_name, dataset=dataset_name)
print(dataset)

# scaffold split:
smiles_list = pd.read_csv('dataset/' + dataset_name + '/processed/smiles.csv', header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
print(train_dataset[0])

bace_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)
bace_val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)
bace_test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers = num_workers)

###########################

## using GIN supervised model ('./model_gin/supervised.pth') ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised.pth): {}".format(score))

#############################

## using GIN supervised infomax model (supervised_infomax.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_infomax_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_infomax.pth): {}".format(score))

#############################

## using GIN supervised edgepred model (supervised_edgepred.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_edgepred_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_edgepred.pth): {}".format(score))

#############################

## using GIN supervised masking model (supervised_masking.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_masking_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_masking.pth): {}".format(score))

#############################

## using GIN supervised contextpred model (supervised_contextpred.pth) ## 

bace_graph_features, bace_graph_labels = get_graph_features_labels(bace_train_loader, gin_supervised_contextpred_model, seed)
logme = LogME(regression=False)
# f has shape of [N, D], y has shape [N]
score = logme.fit(np.array(bace_graph_features), np.array(bace_graph_labels))
print("\n==============")
print("logme score (GIN supervised_contextpred.pth): {}".format(score))

MoleculeDataset(1513)
Data(edge_attr=[66, 2], edge_index=[2, 66], fold=[1], id=[1], x=[31, 2], y=[1])

logme score (GIN supervised.pth): 3.8149495222876455

logme score (GIN supervised_infomax.pth): 3.76645414668266

logme score (GIN supervised_edgepred.pth): 3.8000357054228586

logme score (GIN supervised_masking.pth): 3.7818508909475215

logme score (GIN supervised_contextpred.pth): 3.794460846581036
