In [1]:
from rdkit import Chem
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect
from rdkit.DataStructs.cDataStructs import ConvertToNumpyArray
import numpy as np
import torch
import torch_geometric

# SMILES string of atorvastatin (aka Lipitor), one of the world's best-selling drugs.
atorvastatin_smiles = 'O=C(O)C[C@H](O)C[C@H](O)CCn2c(c(c(c2c1ccc(F)cc1)c3ccccc3)C(=O)Nc4ccccc4)C(C)C'
# Parse the SMILES string into an RDKit molecule object.
atorvastatin = Chem.MolFromSmiles(atorvastatin_smiles)

# Morgan fingerprint of the molecule as a bit vector (radius 2, 2048 bits).
fingerprint = GetMorganFingerprintAsBitVect(atorvastatin, radius=2, nBits=2048)

# Placeholder array that receives the fingerprint bits. It starts with shape (1,),
# yet the shape printed below is (2048,), so ConvertToNumpyArray resizes it to the
# fingerprint length while filling it in place.
fp_array = np.zeros((1, ))
ConvertToNumpyArray(fingerprint, fp_array)

# Show the fingerprint as a NumPy array of 0/1 values.
print(fp_array)
# [0. 1. 0. ... 0. 0. 0.]

# Show the array shape: one entry per fingerprint bit.
print(fp_array.shape)
# (2048,)

[0. 1. 0. ... 0. 0. 0.]
(2048,)


![alt text](../../metadata/Snipaste_2021-05-21_20-26-00.png "explanation")

## 2a. Atom Features and bond connections (edge indices)

We will use these atom features:

a) The atomic number (which also determines the atom type).

b) The number of hydrogens attached to the atom.

These features are basic, but they are sufficient for our purposes.

In [2]:
def get_atom_features(mol):
    """Build the per-atom feature matrix of a molecule.

    Two features are collected for every atom, in the order the atoms are
    returned by ``mol.GetAtoms()``: the atomic number and the total number of
    hydrogens attached to the atom (``GetTotalNumHs(includeNeighbors=True)``).

    Args:
        mol: Molecule object exposing ``GetAtoms()`` (an RDKit molecule in
            this notebook).

    Returns:
        torch.Tensor: Tensor of shape ``[num_atoms, 2]``. Column 0 holds the
        atomic numbers, column 1 the hydrogen counts.
    """
    atoms = list(mol.GetAtoms())
    atomic_numbers = [atom.GetAtomicNum() for atom in atoms]
    hydrogen_counts = [atom.GetTotalNumHs(includeNeighbors=True) for atom in atoms]

    # Stack the features as two rows, then transpose so each row describes one atom.
    stacked = torch.tensor([atomic_numbers, hydrogen_counts])
    return stacked.t()

def get_edge_index(mol):
    """Build the bond connectivity of a molecule as a ``[2, E]`` index tensor.

    Every bond contributes two directed edges, begin -> end followed by
    end -> begin, so each connection is listed in both directions.

    Args:
        mol: Molecule object exposing ``GetBonds()`` (an RDKit molecule in
            this notebook).

    Returns:
        torch.Tensor: ``torch.long`` tensor of shape ``[2, 2 * num_bonds]``.
        Row 0 holds the source atom indices, row 1 the target atom indices.
    """
    directed_edges = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        # Forward direction first, immediately followed by the reverse one.
        directed_edges.append((begin, end))
        directed_edges.append((end, begin))

    sources = [src for src, _ in directed_edges]
    targets = [dst for _, dst in directed_edges]
    return torch.tensor([sources, targets], dtype=torch.long)

from torch_geometric.data.dataloader import DataLoader

def prepare_dataloader(mol_list, batch_size=3, shuffle=False):
    """Convert molecules into PyTorch Geometric graphs and wrap them in a loader.

    Each molecule becomes one ``Data`` object whose ``x`` comes from
    ``get_atom_features`` and whose ``edge_index`` comes from ``get_edge_index``.

    Args:
        mol_list: Iterable of molecule objects accepted by
            ``get_atom_features`` and ``get_edge_index``.
        batch_size: Number of graphs per mini-batch. Defaults to 3, the value
            that used to be hard-coded.
        shuffle: Whether the loader shuffles the graphs. Defaults to ``False``,
            which keeps the batches in input order.

    Returns:
        tuple: ``(loader, data_list)`` where ``data_list`` holds one ``Data``
        object per molecule in input order and ``loader`` is a ``DataLoader``
        over that list.
    """
    data_list = [
        torch_geometric.data.Data(
            x=get_atom_features(mol),
            edge_index=get_edge_index(mol),
        )
        for mol in mol_list
    ]

    loader = DataLoader(data_list, batch_size=batch_size, shuffle=shuffle)
    return loader, data_list

In [3]:
# Example molecules, given as SMILES strings.
smiles_list = [
    'Cc1cc(c(C)n1c2ccc(F)cc2)S(=O)(=O)NCC(=O)N',
    'CN(CC(=O)N)S(=O)(=O)c1c(C)n(c(C)c1S(=O)(=O)N(C)CC(=O)N)c2ccc(F)cc2',
    'Fc1ccc(cc1)n2cc(COC(=O)CBr)nn2',
    'CCOC(=O)COCc1cn(nn1)c2ccc(F)cc2',
    'COC(=O)COCc1cn(nn1)c2ccc(F)cc2',
    'Fc1ccc(cc1)n2cc(COCC(=O)OCc3cn(nn3)c4ccc(F)cc4)nn2',
]

# Parse every SMILES string into an RDKit molecule.
mol_list = [Chem.MolFromSmiles(smi) for smi in smiles_list]
# Chem.MolFromSmiles returns None for a string it cannot parse; fail here with a
# clear message instead of deep inside the feature extraction.
assert all(mol is not None for mol in mol_list), "at least one SMILES string failed to parse"

dloader, dlist = prepare_dataloader(mol_list)
print(dlist)
#[Data(edge_index=[2, 46], x=[22, 2]),
# Data(edge_index=[2, 66], x=[32, 2]),
# Data(edge_index=[2, 38], x=[18, 2]),
# Data(edge_index=[2, 42], x=[20, 2]),
# Data(edge_index=[2, 40], x=[19, 2]),
# Data(edge_index=[2, 68], x=[31, 2])]

# Only the first mini-batch is needed for inspection: the first three molecules
# merged into one graph (22 + 32 + 18 = 72 atoms, 46 + 66 + 38 = 150 edges).
batch = next(iter(dloader))

print(batch)
#Batch(batch=[72], edge_index=[2, 150], ptr=[4], x=[72, 2])

[Data(edge_index=[2, 46], x=[22, 2]), Data(edge_index=[2, 66], x=[32, 2]), Data(edge_index=[2, 38], x=[18, 2]), Data(edge_index=[2, 42], x=[20, 2]), Data(edge_index=[2, 40], x=[19, 2]), Data(edge_index=[2, 68], x=[31, 2])]
Batch(batch=[72], edge_index=[2, 150], ptr=[4], x=[72, 2])


![alt text](../../metadata/Snipaste_2021-05-24_15-15-53.png "png")

# Define the model

In [19]:
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add
from torch_geometric.utils import add_self_loops, degree
import torch.nn as nn 

class NeuralLoop(MessagePassing):
    """One message-passing round of the neural fingerprint.

    The features of every node are summed with those of its neighbours in
    ``edge_index`` (self-loops are added first). The sum goes through the linear
    layer ``H`` and a sigmoid to give the new atom features, which are then
    projected by the linear layer ``W`` and a softmax to give the per-atom
    fingerprint contribution.

    Args:
        atom_features: Number of features per atom (input and output width of ``H``).
        fp_size: Length of the fingerprint vector (output width of ``W``).
    """

    def __init__(self, atom_features, fp_size):
        # 'add' aggregation: incoming messages are summed per target node.
        super().__init__(aggr='add')
        self.H = nn.Linear(atom_features, atom_features)
        self.W = nn.Linear(atom_features, fp_size)

    def forward(self, x, edge_index):
        """Run one round of message passing.

        Args:
            x: Atom features of shape ``[N, atom_features]``, N being the number of atoms.
            edge_index: Edge indices of shape ``[2, E]``, E being the number of edges.

        Returns:
            tuple: ``(updated_atom_features, updated_fingerprint)`` as produced
            by ``update``.
        """
        num_atoms = x.size(0)
        # Self-loops make every atom's own features part of its aggregated sum.
        looped_edge_index, _ = add_self_loops(edge_index, num_nodes=num_atoms)
        return self.propagate(looped_edge_index, size=(num_atoms, num_atoms), x=x)

    def message(self, x_j, edge_index, size):
        """Return the source-node features unchanged as the message.

        ``edge_index`` and ``size`` are accepted but not used. The summation over
        neighbours (including the self-loops) is performed by PyTorch Geometric's
        'add' aggregation, not here.
        """
        return x_j

    def update(self, v):
        """Turn the aggregated features ``v`` into new atom features and a fingerprint.

        Returns:
            tuple: ``(updated_atom_features, updated_fingerprint)``; the first is
            the sigmoid of ``H(v)``, the second the softmax over the last
            dimension of ``W`` applied to those updated features.
        """
        updated_atom_features = torch.sigmoid(self.H(v))
        fingerprint_logits = self.W(updated_atom_features)
        updated_fingerprint = torch.softmax(fingerprint_logits, dim=-1)

        return updated_atom_features, updated_fingerprint
    
class NeuralFP(nn.Module):
    """Neural fingerprint model: two ``NeuralLoop`` rounds pooled per molecule.

    Each round updates the atom features and emits a per-atom fingerprint
    contribution. The contributions of both rounds are added up and then summed
    over the atoms of each molecule.

    Args:
        atom_features: Number of features per atom. Defaults to 52.
            NOTE(review): ``get_atom_features`` in this notebook yields 2
            features per atom, so pass ``atom_features=2`` when feeding its output.
        fp_size: Length of the fingerprint vector. Defaults to 2048.
    """

    def __init__(self, atom_features=52, fp_size=2048):
        super(NeuralFP, self).__init__()

        # Store the sizes that were actually requested. They used to be fixed to
        # 52 / 2048 regardless of the arguments, which broke ``forward`` for any
        # other ``fp_size`` (accumulator width != layer output width).
        self.atom_features = atom_features
        self.fp_size = fp_size

        self.loop1 = NeuralLoop(atom_features=atom_features, fp_size=fp_size)
        self.loop2 = NeuralLoop(atom_features=atom_features, fp_size=fp_size)
        self.loops = nn.ModuleList([self.loop1, self.loop2])

    def forward(self, data):
        """Compute one fingerprint per molecule in ``data``.

        Args:
            data: Graph or batch of graphs providing ``x`` (atom features),
                ``edge_index`` and, for batches, ``batch`` (the molecule index
                of every atom). If ``batch`` is missing or ``None``, all atoms
                are treated as belonging to a single molecule.

        Returns:
            torch.Tensor: Tensor of shape ``[num_molecules, fp_size]``.
        """
        out = data.x
        # nn.Linear only accepts floating-point input, while integer features
        # (e.g. atomic numbers and hydrogen counts) arrive as an integer tensor.
        if not out.is_floating_point():
            out = out.float()

        batch = getattr(data, 'batch', None)
        if batch is None:
            # Single un-batched graph: every atom belongs to molecule 0.
            batch = torch.zeros(out.size(0), dtype=torch.long, device=out.device)

        # Allocate the accumulator on the same device as the features so the
        # model also works when the data lives on a GPU.
        fingerprint = torch.zeros(
            (batch.shape[0], self.fp_size), dtype=torch.float, device=out.device
        )

        for loop in self.loops:
            # Each round feeds its updated atom features into the next one.
            out, updated_fingerprint = loop(out, data.edge_index)
            fingerprint += updated_fingerprint

        # Sum the per-atom contributions within each molecule.
        return scatter_add(fingerprint, batch, dim=0)