In [2]:
# Import data and mutate such that each SMILES representation has one row with all recorded kinases values
import pandas as pd
from rdkit import Chem, RDLogger
import dgl
import torch
import numpy as np

df = pd.read_csv('data/kinase_JAK.csv')
kinases = ['JAK1', 'JAK2', 'JAK3', 'TYK2']
df_kinase = df.loc[df['measurement_type'] == 'pKi']
df_smiles = df_kinase.pivot(index="SMILES", columns="Kinase_name", values='measurement_value')
df_smiles = df_smiles.reset_index()
display(df_smiles)

Kinase_name,SMILES,JAK1,JAK2,JAK3,TYK2
0,Brc1cnc2[nH]cc(-c3ccccc3)c2c1,,6.20,6.3,
1,C#Cc1cc2c(cc1OC)-c1[nH]nc(-c3ccc(C#N)nc3)c1C2,,6.20,,
2,C=CC(=O)N1CC(Nc2ncnc3[nH]ccc23)CCC1C,8.20,,8.2,
3,C=CCN(CCOc1ccc(C)cc1)C1CCN(C(=O)Cn2cc(NC(=O)c3...,8.49,8.14,,
4,CC(=NNC(=N)N)c1cc(NC(=O)NCCCCCCNC(=O)Nc2cc(C(C...,,6.00,6.0,
...,...,...,...,...,...
979,c1ccc(Cn2cc(-c3ccc4[nH]ncc4c3)nn2)cc1,,6.55,6.0,
980,c1ccc(Cn2nnc(-c3ccc4[nH]ncc4c3)c2C2CC2)cc1,,6.20,,
981,c1cnc2[nH]cc(-c3ccnc(NC4CCCCC4)n3)c2c1,,6.90,6.8,6.5
982,c1cncc(CN2CCC(n3nnc4cnc5[nH]ccc5c43)CC2)c1,7.80,7.55,,


## Data Processing
The first thing that we need to do is turn each molecule into a graph representation using DGL. Each node of the graph will be an atom, and information will be recorded about that atom. In the future, I would like to perform backwards feature elimination to figure out which atom features are not necessary, but as I have a limited amount of time and haven't yet built and optimized the model I will just be including all available atom features.

Each node will be connected to another node if there is a bond between those atoms. That will represent an edge in the graph

In [5]:
# Build graph from molecule with atoms as nodes and bonds as edges. Not currently using bond features, so commented out
def mol_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    bond_inds_front = []
    bond_inds_back = []
#     bond_feats = []
    for bond in mol.GetBonds():
        atom_1 = bond.GetBeginAtomIdx()
        atom_2 = bond.GetEndAtomIdx()
        bond_inds_front += [atom_1, atom_2]
        bond_inds_back += [atom_2, atom_1]
        
#         features = []
#         features.append(bond.GetBondType())
#         features.append(bond.GetStereo())
#         features.append(bond.GetIsConjugated())

    g = dgl.graph((bond_inds_front, bond_inds_back))

    return g

# Helper method, get all atom features in a molecule
def mol_to_atom_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_features_list.append(atom_featurizer(atom))
    return np.array(atom_features_list)

# Grabbing all item features for now
def atom_featurizer(atom):
    feats = []
    feats.append(atom.GetAtomicNum()) 
    feats.append(atom.GetDegree())
    feats.append(atom.GetTotalDegree())
    feats.append(atom.GetTotalNumHs())
    feats.append(atom.GetTotalValence())
    feats.append(atom.GetNumImplicitHs())
    feats.append(atom.GetExplicitValence())
    feats.append(atom.GetImplicitValence())
    feats.append(atom.GetNumRadicalElectrons())
    feats.append(atom.GetFormalCharge())
    feats.append(atom.GetNumExplicitHs())
    feats.append(atom.GetIsAromatic())
    feats.append(atom.IsInRing())
    feats.append(atom.GetMass())
    feats.append(atom.GetIsotope())
    feats.append(atom.GetChiralTag())
    feats.append(atom.GetHybridization())
    return feats

In [6]:
# Setup the data for each different kinase predictor
kin_graphs = {}
kin_targets = {}
kinases = ['JAK1', 'JAK2', 'JAK3', 'TYK2']
for kin in kinases:
    df_kin = df_smiles.dropna(subset=[kin])
    smiles = df_kin['SMILES'].to_list()
    targets = df_kin[kin].to_list()
    index=0
    graphs=[]
    for mol in smiles:
        g_mol=mol_to_graph(mol)
        g_mol.ndata['feat']=torch.tensor(mol_to_atom_features(mol))
        graphs.append(g_mol)
        index+=1
    kin_graphs[kin] = graphs
    kin_targets[kin] = targets

In [7]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from dgl.data import DGLDataset

# Setup DGL dataset for 
class MyGraphDataset(DGLDataset):
    def __init__(self):
        super().__init__(name='gcn')

    def process(self):
        self.graphs = graphs
        self.targets = targets

    def __getitem__(self, i):
        return self.graphs[i], self.targets[i]
    def __len__(self):
        return len(self.graphs)

# Model Development
The general idea of a GCN is to apply a convolution over the graphs that represent the molecules that we have created. These convolution layers allow for the spatial representation of the molecule to be captured by adding up the adjacent local neighborhood of each atom. We will then apply a linear layer after the convolution to achieve our end prediction value.

In [10]:
from dgl.nn import GraphConv
from dgl.nn import AvgPooling
import torch.nn as nn
import torch
import numpy as np
import random
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
import torch.utils.data as data

class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, h_feats)
        self.conv3 = GraphConv(h_feats, h_feats)
        self.lin1 = torch.nn.Linear(h_feats, 1)


    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        
        h = self.conv2(g, h)
        h = F.relu(h)
        
        h = self.conv3(g, h)
        h = F.relu(h)
        
        h = self.lin1(h)

        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')
    
def get_loss(preds, pkis):
    preds = np.array(preds).squeeze()
    pkis = np.array(pkis).squeeze() 
    mses = ((preds-pkis)**2).mean(axis=0)
    return mses

# Create the model with given dimensions
for kin in kinases:
    device = torch.device('cpu')
    graphs=kin_graphs[kin]
    targets=kin_targets[kin]
    dataset = MyGraphDataset()
    num_examples = len(dataset)
    num_train = int(len(dataset) * 0.8)
    train_dataset, test_dataset = data.random_split(dataset, [num_train, int(num_examples-num_train)], generator=torch.Generator().manual_seed(42))

    train_dataloader = GraphDataLoader(
        train_dataset, batch_size=16)
    test_dataloader = GraphDataLoader(
        test_dataset, batch_size=16)
    model = GCN(17,  6, 1)
    model = model.float()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(100):
        for batched_graph, pkis in train_dataloader:
            pkis = pkis.float()
            pred = model(batched_graph, batched_graph.ndata['feat'].float())
            pred = pred.squeeze()
            loss = F.mse_loss(pred, pkis)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    all_preds = []
    all_targets = []
    num_tests = 0
    for batched_graph, pkis in test_dataloader:
        pred = model(batched_graph, batched_graph.ndata['feat'].float())
        all_preds.extend(pred.tolist())
        all_targets.extend(pkis.tolist())
    mses = get_loss(all_preds, all_targets)
    print('MSE for ' + kin + ': ' + str(mses))

MSE for JAK1: 1.0833782829126917
MSE for JAK2: 0.8823662153617801
MSE for JAK3: 1.650273600580743
MSE for TYK2: 0.741984074542709


### Notes on Evaluation Metrics
While I conceptually understood GCNs, this is my first time implementing one and execution took a bit longer than expected. For that reason, it is not fully optimized. I performed optimization of the size of the convolutional layers, but I will go further into more optimization techniques in the future ideas.

## Future Ideas
1. Optimization \
There are many hyperparameters here that still require optimization. A few off of the top of my head are number and size of convolutional layers, number and size of prediction layers, optimizer type and learning rate, and many more.
2. Bond Features \
Right now, bonds are simply being treated as edges in the graph. This does not do them justice, as there are different types of bonds, conjugations, and so much more. A molecules structure is strongly dependent on this information, so it should be incorporated in future iterations.
3. Backward Atom Feature Elimination \
As I have little to no understanding on how atomic features would effect pKi, I've just included all of the atomic features that I could find. Some of these could be confounding, so it would be possible (but computationally expensive) to perform backward feature elimination by comparing test evaluation scores of subsets of the features.