In [2]:
import pandas as pd
import numpy as np
from helper.graphfeat import StructureEncoder
from helper.preprocess import *
from helper.trainer import *
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.data as data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphConv, GATConv, global_max_pool, norm
from helper.load_dataset import load_bace_classification

In [None]:
# Load and split dataset.
# Seed the global RNGs first so the random split, DataLoader shuffling and
# model weight initialisation are reproducible under Restart & Run All.
# NOTE(review): assumes split_train_valid_test draws from numpy's global RNG;
# if it accepts its own seed/random_state argument, pass SEED there instead.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

bace = load_bace_classification()
train, valid, test = split_train_valid_test(bace, type='random')

# Generate graph features from the SMILES column, labelled by the Class column.
SMILES_COL, LABEL_COL = 'SMILES', 'Class'
train_dataset = generate_graph_dataset(train, SMILES_COL, LABEL_COL)
valid_dataset = generate_graph_dataset(valid, SMILES_COL, LABEL_COL)
test_dataset = generate_graph_dataset(test, SMILES_COL, LABEL_COL)

# Create data loaders; only the training set is shuffled.
BATCH_SIZE = 128
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Build network
class GraConv(nn.Module):
    """Three-layer GraphConv network with a max-pool readout, one output per graph.

    Each GraphConv layer is followed by a SELU activation. Node embeddings are
    max-pooled per graph (grouped by ``batch``) and then passed through two
    linear layers (hidden -> 20 -> 1). No sigmoid is applied in ``forward``;
    the raw output is returned.

    Args:
        hidden_channel: Width of the hidden GraphConv layers.
        use_edge_weight: If False, any ``edge_weight`` given to ``forward`` is
            ignored and the convolutions run unweighted.
        in_channels: Number of input node features. When None (default), falls
            back to ``train_dataset.num_node_features`` from the notebook scope,
            matching the original behaviour.
    """

    def __init__(self, hidden_channel=64, use_edge_weight=True, in_channels=None):
        super().__init__()
        self.use_edge_weight = use_edge_weight
        if in_channels is None:
            # Backward-compatible fallback to the notebook-global training set.
            in_channels = train_dataset.num_node_features
        self.conv1 = GraphConv(in_channels, hidden_channel, bias=True)
        self.conv2 = GraphConv(hidden_channel, hidden_channel, bias=True)
        self.conv3 = GraphConv(hidden_channel, hidden_channel, bias=True)
        self.lin1 = nn.Linear(hidden_channel, 20)
        self.lin2 = nn.Linear(20, 1)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Run the three conv layers, max-pool per graph and project to one value.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity, as consumed by ``GraphConv``.
            batch: Graph-membership vector used by ``global_max_pool``.
            edge_weight: Optional per-edge weights; discarded when the model
                was built with ``use_edge_weight=False``.

        Returns:
            The output of ``lin2``, one value per graph in the batch.
        """
        if not self.use_edge_weight:
            # The original check (`use_edge_weight and edge_weight is None`)
            # was a no-op, so weights leaked through even when disabled.
            edge_weight = None
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.selu(conv(x, edge_index, edge_weight=edge_weight))
        x = global_max_pool(x, batch)
        # NOTE(review): there is no nonlinearity between lin1 and lin2, so they
        # compose into a single affine map; kept to preserve the architecture.
        x = self.lin1(x)
        return self.lin2(x)
    
# Train model, then report metrics on the held-out test loader.
EPOCHS = 200
LEARNING_RATE = 0.005
# Passed to fit_model as `patience`; presumably early-stopping patience in
# epochs — confirm against helper.trainer.fit_model.
PATIENCE = 15

model = GraConv(use_edge_weight=False)
history = fit_model(
    model,
    train_loader,
    valid_loader,
    epochs=EPOCHS,
    lr=LEARNING_RATE,
    patience=PATIENCE,
    task='classification',
)
print('Test metrics:')
print(evaluate_test(model, test_loader, task='classification'))

Epoch 1/200 stats: {'loss_train': 0.87100106, 'loss_valid': 0.7381363809108734, 'roc_auc': 0.5}
Epoch 2/200 stats: {'loss_train': 0.6487581, 'loss_valid': 0.6042793989181519, 'roc_auc': 0.72188013136289}
Epoch 3/200 stats: {'loss_train': 0.6024455, 'loss_valid': 0.5529862642288208, 'roc_auc': 0.7559523809523809}
Epoch 4/200 stats: {'loss_train': 0.55796206, 'loss_valid': 0.5498924553394318, 'roc_auc': 0.7052545155993433}
Epoch 5/200 stats: {'loss_train': 0.5690731, 'loss_valid': 0.53008833527565, 'roc_auc': 0.735016420361248}
Epoch 6/200 stats: {'loss_train': 0.52642184, 'loss_valid': 0.5022647827863693, 'roc_auc': 0.7516420361247947}
Epoch 7/200 stats: {'loss_train': 0.4992941, 'loss_valid': 0.46604542434215546, 'roc_auc': 0.8119868637110015}
Epoch 8/200 stats: {'loss_train': 0.47872862, 'loss_valid': 0.4520668089389801, 'roc_auc': 0.7801724137931034}
Epoch 9/200 stats: {'loss_train': 0.47723526, 'loss_valid': 0.40654900670051575, 'roc_auc': 0.8304597701149425}
Epoch 10/200 stats: {'l