In [2]:
from rdkit import Chem
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect
from rdkit.DataStructs.cDataStructs import ConvertToNumpyArray
import numpy as np
import torch
import torch_geometric

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

atorvastatin_smiles = 'O=C(O)C[C@H](O)C[C@H](O)CCn2c(c(c(c2c1ccc(F)cc1)c3ccccc3)C(=O)Nc4ccccc4)C(C)C'
atorvastatin = Chem.MolFromSmiles(atorvastatin_smiles) # Atorvastatin (aka Lipitor) is one of the world's best-selling drugs.

fingerprint = GetMorganFingerprintAsBitVect(atorvastatin, radius=2, nBits=2048)

fp_array = np.zeros((1, ))
ConvertToNumpyArray(fingerprint, fp_array)

# Fingerprints
print(fp_array)
# [0. 1. 0. ... 0. 0. 0.]

print(fp_array.shape)
# (2048,)

[0. 1. 0. ... 0. 0. 0.]
(2048,)


In [3]:
device

device(type='cuda', index=0)

![alt text](../metadata/Snipaste_2021-05-21_20-26-00.png "explanation")

## 2a. Atom Features and bond connections (edge indices)

We will use these atom features:

a) Atomic number (which determines atom type as well)

b) The number of hydrogens attached to the atom.

These are basic features but sufficient for our purposes.

In [4]:
def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(
            x, allowable_set))
    return [x == s for s in allowable_set]


def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]


In [11]:
def get_atom_features(mol):
    atomic_number = []
    num_hs = []
    
    # for atom in mol.GetAtoms():
    #     # print(atom.GetAtomicNum())
    #     # 原子编号
    #     atomic_number.append(atom.GetAtomicNum())
    #     num_hs.append(atom.GetTotalNumHs(includeNeighbors=True))
    
    AttentiveFP = []

    for atom in mol.GetAtoms():
        AttentiveFP += [
            one_of_k_encoding_unk(
            atom.GetSymbol(),
            [
                'B',
                'C',
                'N',
                'O',
                'F',
                'Si',
                'P',
                'S',
                'Cl',
                'As',
                'Se',
                'Br',
                'Te',
                'I',
                'At',
                'other'
            ]) + one_of_k_encoding(atom.GetDegree(),
                                    [0, 1, 2, 3, 4, 5]) + \
                    [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
                    one_of_k_encoding_unk(atom.GetHybridization(), [
                        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.
                                            SP3D, Chem.rdchem.HybridizationType.SP3D2,'other'
                    ]) + [atom.GetIsAromatic()]
                    + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                      [0, 1, 2, 3, 4])
        ]

    # return torch.tensor([atomic_number, num_hs], dtype=torch.float).t()
    return torch.tensor(AttentiveFP, dtype=torch.float)

def get_edge_index(mol):
    row, col = [], []
    
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        
    return torch.tensor([row, col], dtype=torch.long)

from torch_geometric.data.dataloader import DataLoader

def prepare_dataloader(mol_list, batch_size=3):
    data_list = []

    for i, mol in enumerate(mol_list):

        x = get_atom_features(mol)
        edge_index = get_edge_index(mol)

        data = torch_geometric.data.Data(x=x, edge_index=edge_index)
        data_list.append(data)

    return DataLoader(data_list, batch_size=batch_size, shuffle=False), data_list

In [12]:
smiles_list = ['Cc1cc(c(C)n1c2ccc(F)cc2)S(=O)(=O)NCC(=O)N',
'CN(CC(=O)N)S(=O)(=O)c1c(C)n(c(C)c1S(=O)(=O)N(C)CC(=O)N)c2ccc(F)cc2',
'Fc1ccc(cc1)n2cc(COC(=O)CBr)nn2',
'CCOC(=O)COCc1cn(nn1)c2ccc(F)cc2',
'COC(=O)COCc1cn(nn1)c2ccc(F)cc2',
'Fc1ccc(cc1)n2cc(COCC(=O)OCc3cn(nn3)c4ccc(F)cc4)nn2']

mol_list = [Chem.MolFromSmiles(smi) for smi in smiles_list]

dloader, dlist = prepare_dataloader(mol_list)
print(dlist)
#[Data(edge_index=[2, 46], x=[22, 2]),
# Data(edge_index=[2, 66], x=[32, 2]),
# Data(edge_index=[2, 38], x=[18, 2]),
# Data(edge_index=[2, 42], x=[20, 2]),
# Data(edge_index=[2, 40], x=[19, 2]),
# Data(edge_index=[2, 68], x=[31, 2])]

for batch in dloader:
  break

print(batch)
#把前三个原子拼接在一起训练，batch的作用
#Batch(batch=[72], edge_index=[2, 150], x=[72, 2])

[Data(edge_index=[2, 46], x=[22, 36]), Data(edge_index=[2, 66], x=[32, 36]), Data(edge_index=[2, 38], x=[18, 36]), Data(edge_index=[2, 42], x=[20, 36]), Data(edge_index=[2, 40], x=[19, 36]), Data(edge_index=[2, 68], x=[31, 36])]
Batch(batch=[72], edge_index=[2, 150], ptr=[4], x=[72, 36])


![alt text](../metadata/Snipaste_2021-05-24_15-15-53.png "png")

# Define the model

![alt text](../metadata/model.png "png")

In [13]:
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add
from torch_geometric.utils import add_self_loops, degree
import torch.nn as nn 

class NeuralLoop(MessagePassing):
    def __init__(self, atom_features, fp_size):
        super(NeuralLoop, self).__init__(aggr='add')
        self.H = nn.Linear(atom_features, atom_features)
        self.W = nn.Linear(atom_features, fp_size)
        
    def forward(self, x, edge_index):
        # x shape: [Number of atoms in molecule, Number of atom features]; [N, in_channels]
        # edge_index shape: [2, E]; E is the number of edges
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
    
    def message(self, x_j, edge_index, size):
        # We simply sum all the neighbouring nodes (including self-loops)
        # This is done implicitly by PyTorch-Geometric :)
        return x_j 
    
    def update(self, v):
        
        updated_atom_features = self.H(v).sigmoid()
        updated_fingerprint = self.W(updated_atom_features).softmax(dim=-1)
        
        return updated_atom_features, updated_fingerprint # shape [N, atom_features]
    
class NeuralFP(nn.Module):
    def __init__(self, atom_features=52, fp_size=2048):
        super(NeuralFP, self).__init__()
        
        self.atom_features = atom_features
        self.fp_size = fp_size
        
        self.loop1 = NeuralLoop(atom_features=atom_features, fp_size=fp_size)
        self.loop2 = NeuralLoop(atom_features=atom_features, fp_size=fp_size)
        self.loops = nn.ModuleList([self.loop1, self.loop2])
        
    def forward(self, data):
        fingerprint = torch.zeros((data.batch.shape[0], self.fp_size), dtype=torch.float).cuda()
        out = data.x
        for idx, loop in enumerate(self.loops):
            updated_atom_features, updated_fingerprint = loop(out, data.edge_index)
            out = updated_atom_features
            fingerprint += updated_fingerprint
            
        return scatter_add(fingerprint, data.batch, dim=0)

INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2


In [15]:
neural_fp = NeuralFP(atom_features=36, fp_size=2048).cuda()
fps = neural_fp(batch.cuda()) # remember, batch size was 3
print(fps.shape)
# torch.Size([3, 2048])

torch.Size([3, 2048])


# Work on a real dataset

In [17]:
import deepchem as dc
_, (train, valid, test), _ = dc.molnet.load_delaney(featurizer='Raw')

bs = 32

train_loader, _ = prepare_dataloader(list(train.X), batch_size=bs)
valid_loader, _ = prepare_dataloader(valid.X, bs)
test_loader, _ = prepare_dataloader(test.X, bs)

train_labels_loader = torch.utils.data.DataLoader(train.y, batch_size=bs)
valid_labels_loader = torch.utils.data.DataLoader(valid.y, batch_size=bs)
test_labels_loader = torch.utils.data.DataLoader(test.y, batch_size=bs)

## build a smal MLP on top of our neural fingerprint.

In [18]:
import torch.nn.functional as F

class MLP_Regressor(nn.Module):
    def __init__(self, atom_features=2, fp_size=2048, hidden_size=100):
        super(MLP_Regressor, self).__init__()
        self.neural_fp = neural_fp.cuda()
        self.lin1 = nn.Linear(fp_size, hidden_size)
        self.lin2 = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(p=0.3)
    
    def forward(self, batch):
        fp = self.neural_fp(batch)
        hidden = F.relu(self.dropout(self.lin1(fp)))
        out = self.lin2(hidden)
        return out

## Define our utility functions for training and validation:


In [19]:
def train_step(batch, labels, reg):
    batch.cuda()
    labels.cuda()
    out = reg(batch).cuda()
    loss = torch.sqrt(F.mse_loss(out, labels.to(torch.float).cuda(), reduction='mean').cuda())
    loss.backward()
    return loss

def valid_step(batch, labels, reg):
    batch.cuda()
    labels.cuda()
    out = reg(batch).cuda()
    loss = torch.sqrt(F.mse_loss(out, labels.to(torch.float).cuda(), reduction='mean').cuda())
    return loss

def train_fn(train_loader, train_labels_loader, reg, opt):
    reg.train()
    total_loss = 0
    for idx, (batch, labels) in enumerate(zip(train_loader, train_labels_loader)):
        loss = train_step(batch, labels, reg).cuda()
        total_loss += loss.item()

    torch.nn.utils.clip_grad_norm_(reg.parameters(), 1)    
    opt.step()
    opt.zero_grad()
    return total_loss/len(train_loader)

def valid_fn(valid_loader, valid_labels_loader, reg):
    reg.eval()
    total_loss = 0
    with torch.no_grad():
        for idx, (batch, labels) in enumerate(zip(valid_loader, valid_labels_loader)):
            loss = valid_step(batch, labels, reg)
            total_loss += loss.item()
    
    total_loss /= len(valid_loader)
        
    return total_loss

In [20]:
reg = MLP_Regressor(atom_features=36, fp_size=2048, hidden_size=100).cuda()
optimizer = torch.optim.SGD(reg.parameters(), lr=0.001, weight_decay=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=100)

total_epochs = 1000
for epoch in range(1, total_epochs+1):
    train_loss = train_fn(train_loader, train_labels_loader, reg, opt=optimizer)
    valid_loss = valid_fn(valid_loader, valid_labels_loader, reg)
    scheduler.step(valid_loss)

    if epoch % 10 == 0:
        print(f'Epoch:{epoch}, Train loss: {train_loss}, Valid loss: {valid_loss}')

Epoch:10, Train loss: 0.9496783737478585, Valid loss: 1.0468420386314392
Epoch:20, Train loss: 0.9487133601616169, Valid loss: 1.0484953075647354
Epoch:30, Train loss: 0.9480453832396145, Valid loss: 1.0499243587255478
Epoch:40, Train loss: 0.9466879347275043, Valid loss: 1.0511573404073715
Epoch:50, Train loss: 0.946238885665762, Valid loss: 1.0523102283477783
Epoch:60, Train loss: 0.9450802001459845, Valid loss: 1.0534222424030304
Epoch:70, Train loss: 0.9447108383836418, Valid loss: 1.0544604808092117
Epoch:80, Train loss: 0.9439025562385033, Valid loss: 1.0552864521741867
Epoch:90, Train loss: 0.9430742037707361, Valid loss: 1.0559005737304688
Epoch:100, Train loss: 0.9427291216521427, Valid loss: 1.0563311874866486
Epoch:110, Train loss: 0.9422031599899818, Valid loss: 1.0564383417367935
Epoch:120, Train loss: 0.9422001633150824, Valid loss: 1.056469202041626
Epoch:130, Train loss: 0.9422807796248074, Valid loss: 1.0564972460269928
Epoch:140, Train loss: 0.9419742781540443, Valid 

2 features：0.9302953