In [7]:
import pkasolver as ps
from pkasolver import util
from pkasolver import analysis
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
import random

def mol_to_pyg(prot):
    """Take protonated molecules and return a Pytorch Geometric Data object."""
    i = 0
    num_atoms = prot.GetNumAtoms()
    nodes = []
    edges = []
    edges_attr = []

    for mol in [prot]:

        # ComputeGasteigerCharges(mol)

        for atom in mol.GetAtoms():
            nodes.append(
                np.array(
                    [
                        #atom.GetIdx() + num_atoms * i,
                        #float(atom.GetProp("_GasteigerCharge")),
                        atom.GetSymbol() == "C",
                        atom.GetSymbol() == "O",
                        atom.GetSymbol() == "N",
                        atom.GetSymbol() == "P",
                        atom.GetSymbol() == "F",
                        atom.GetSymbol() == "Cl",
                        atom.GetSymbol() == "I",
                        atom.GetFormalCharge(),
                        atom.GetChiralTag(),
                        atom.GetHybridization(),
                        atom.GetNumExplicitHs(),
                        atom.GetIsAromatic(),
                        atom.GetTotalValence(),
                        atom.GetTotalDegree()
                    ]
                )
            )

        for bond in mol.GetBonds():
            edges.append(
                np.array(
                    [
                        [bond.GetBeginAtomIdx() + num_atoms * i],
                        [bond.GetEndAtomIdx() + num_atoms * i],
                    ]
                )
            )
            bond_type = [bond.GetBondTypeAsDouble(), bond.GetIsConjugated()]
            edges_attr.append(bond_type)

        i += 1

    X = torch.tensor(np.array([np.array(xi) for xi in nodes]), dtype=torch.float)
    edge_index = torch.tensor(np.hstack(np.array(edges)), dtype=torch.long)
    edge_attr = torch.tensor(np.array(edges_attr), dtype=torch.float)

    return Data(x=X, edge_index=edge_index, edge_attr=edge_attr).to(device=device)

In [8]:
data_folder_Bal = "../data/Baltruschat/"
SDFfile1 = data_folder_Bal + "combined_training_datasets_unique.sdf"
SDFfile2 = data_folder_Bal + "novartis_cleaned_mono_unique_notraindata.sdf"
SDFfile3 = data_folder_Bal + "AvLiLuMoVe_cleaned_mono_unique_notraindata.sdf"
# specify device
device = 'cpu'
#device = 'cuda'


df1 = ps.util.import_sdf(SDFfile1)
df2 = ps.util.import_sdf(SDFfile2)
df3 = ps.util.import_sdf(SDFfile3)

#Data corrections:
df1.marvin_atom[90] = "3"

df1 = util.conjugates_to_DataFrame(df1)
df1 = util.sort_conjugates(df1)
df1 = util.pka_to_ka(df1)
print(df1.head())

    pKa marvin_pKa marvin_atom marvin_pKa_type original_dataset       ID  \
0  6.21       6.09          10           basic     ['chembl25']  1702768   
1  7.46        8.2           9           basic     ['chembl25']   273537   
2   4.2       3.94           9           basic  ['datawarrior']     7175   
3  3.73       5.91           8          acidic  ['datawarrior']      998   
4  11.0       8.94          13           basic     ['chembl25']   560562   

                                  smiles  \
0     Brc1c(NC2CC2)nc(C2CC2)nc1N1CCCCCC1   
1      Brc1cc(Br)c(NC2=[NH+]CCN2)c(Br)c1   
2                 Brc1cc2cccnc2c2ncccc12   
3                Brc1ccc(-c2nn[n-]n2)cc1   
4  Brc1ccc(Br)c(N(CC2CC2)C2=[NH+]CCN2)c1   

                                          protonated  \
0  <img data-content="rdkit/molecule" src="data:i...   
1  <img data-content="rdkit/molecule" src="data:i...   
2  <img data-content="rdkit/molecule" src="data:i...   
3  <img data-content="rdkit/molecule" src="data:i...  

In [9]:
#create pyG Dataset
dataset = []
for i in range(len(df1.index)):
    dataset.append(mol_to_pyg(df1.protonated[i]))
    dataset[i].y = torch.tensor([float(df1.pKa[i])], dtype=torch.float32, device=device)
print(dataset[0], '\n\n' ,dataset[0].x,'\n\n', dataset[0].edge_index)

Data(edge_attr=[24, 2], edge_index=[2, 24], x=[21, 14], y=[1]) 

 tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 3., 0., 0., 3., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 3., 1., 1., 4., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 3., 0., 1., 3., 2.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1

In [10]:
print(dataset[0], '\n\n' ,dataset[0].x,'\n\n', dataset[0].edge_index,'\n\n', dataset[0].edge_attr)

Data(edge_attr=[24, 2], edge_index=[2, 24], x=[21, 14], y=[1]) 

 tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 3., 0., 0., 3., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 3., 1., 1., 4., 3.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 3., 0., 1., 4., 3.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 3., 0., 1., 3., 2.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 4., 0., 0., 4., 4.],
        [1

In [11]:
#set Hyperparameters
train_test_split = 0.8
hidden_channels = 64
learning_rate = 0.001
batch_size = 64
num_epochs = 10000

#split train and test set

random.shuffle(dataset)

split_length=int(len(dataset)*train_test_split)
train_dataset = dataset[:split_length]
test_dataset = dataset[split_length:]
#create Dataloader objects that contain batches 

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_geometric.nn import GCNConv
from torch_geometric.nn import NNConv
from torch_geometric.nn import GraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import global_max_pool
from torch_geometric.nn import global_add_pool
from torch import optim

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(1)
        
        num_features = dataset[0].num_features
        num_edge_features = dataset[0].num_edge_features
        
        
        nn = Seq(Lin(num_edge_features, 16), ReLU(), Lin(16, dataset[0].num_node_features* 96))
        self.conv1 = NNConv(dataset[0].num_node_features, 96, nn=nn)
        nn = Seq(Lin(num_edge_features, 16), ReLU(), Lin(16, 96* hidden_channels))
        self.conv2 = NNConv(96, hidden_channels, nn=nn)
        nn = Seq(Lin(num_edge_features, 16), ReLU(), Lin(16, hidden_channels* hidden_channels))
        self.conv3 = NNConv(hidden_channels, hidden_channels, nn=nn)
        self.conv4 = NNConv(hidden_channels, hidden_channels, nn=nn)
        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch, edge_attr):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_attr)
        x = x.relu()
        x = self.conv4(x, edge_index, edge_attr)
        x = x.relu()
        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(hidden_channels=hidden_channels).to(device=device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.MSELoss()
criterion_v = torch.nn.L1Loss() # that's the MAE Loss
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True)

from scipy.stats.stats import pearsonr

def train(loader):
    model.train()
    for data in loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch,  data.edge_attr)  # Perform a single forward pass.
        loss = criterion(out.flatten(), data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def test(loader):
    model.eval()
    loss = torch.Tensor([0]).to(device=device)
    for data in loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch, data.edge_attr) 
        loss += criterion_v(out.flatten(), data.y)
    return loss/len(loader) # MAE loss of batches can be summed and divided by the number of batches

for epoch in range(1, num_epochs+1):
    train(train_loader)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Train MAE: {train_acc.item():.4f}, Test MAE: {test_acc.item():.4f}')

GCN(
  (conv1): NNConv(14, 96)
  (conv2): NNConv(96, 64)
  (conv3): NNConv(64, 64)
  (conv4): NNConv(64, 64)
  (lin): Linear(in_features=64, out_features=1, bias=True)
)
Epoch: 010, Train MAE: 1.5370, Test MAE: 1.5334
Epoch: 020, Train MAE: 1.5033, Test MAE: 1.4953
Epoch: 030, Train MAE: 1.4335, Test MAE: 1.4479
Epoch: 040, Train MAE: 1.1432, Test MAE: 1.2002
Epoch: 050, Train MAE: 1.0085, Test MAE: 1.0972
Epoch: 060, Train MAE: 1.0152, Test MAE: 1.1110
Epoch: 070, Train MAE: 0.9228, Test MAE: 1.0448
Epoch: 080, Train MAE: 1.0349, Test MAE: 1.1487
Epoch: 090, Train MAE: 0.8310, Test MAE: 0.9904
Epoch: 100, Train MAE: 0.8498, Test MAE: 0.9978
Epoch: 110, Train MAE: 0.8452, Test MAE: 1.0068
Epoch: 120, Train MAE: 0.8988, Test MAE: 1.0765
Epoch: 130, Train MAE: 0.7878, Test MAE: 1.0059
Epoch: 140, Train MAE: 0.7541, Test MAE: 0.9789
Epoch: 150, Train MAE: 0.7173, Test MAE: 0.9762
Epoch: 160, Train MAE: 0.7202, Test MAE: 0.9901
Epoch: 170, Train MAE: 0.7000, Test MAE: 0.9457
Epoch: 180, Tr