In [1]:
# sklearn, condusion matrix, mcc and auc
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score
# build dataset
from rdkit import Chem
import networkx as nx
import pickle
import numpy as np
from torch_geometric.utils import from_networkx
# torch
import torch
import torch.nn as nn
from torch_geometric.nn import ChebConv
# tensorboard
from torch.utils.tensorboard import SummaryWriter
# random
import random

[20:08:15] Enabling RDKit 2019.09.3 jupyter extensions


In [2]:
train_writer = SummaryWriter("./runs/3g1lCheb/train")
val_writer = SummaryWriter("./runs/3g1lCheb/val")

In [3]:
identity = {
    'C':[1,0,0,0,0,0,0,0,0,0],
    'N':[0,1,0,0,0,0,0,0,0,0],
    'O':[0,0,1,0,0,0,0,0,0,0],
    'F':[0,0,0,1,0,0,0,0,0,0],
    'P':[0,0,0,0,1,0,0,0,0,0],
    'S':[0,0,0,0,0,1,0,0,0,0],
    'Cl':[0,0,0,0,0,0,1,0,0,0],
    'Br':[0,0,0,0,0,0,0,1,0,0],
    'I':[0,0,0,0,0,0,0,0,1,0],
    'other':[0,0,0,0,0,0,0,0,0,1],
}

zero_five = {
    0:[1,0,0,0,0,0],
    1:[0,1,0,0,0,0],
    2:[0,0,1,0,0,0],
    3:[0,0,0,1,0,0],
    4:[0,0,0,0,1,0],
    5:[0,0,0,0,0,1]
}

num_H = {
    0:[1,0,0,0,0],
    1:[0,1,0,0,0],
    2:[0,0,1,0,0],
    3:[0,0,0,1,0],
    4:[0,0,0,0,1]
}

def mol2graph(mol):
    # mol = Chem.MolFromSmiles(smiles)
    # mol = add_atom_index(mol)
    # graph
    g = nx.Graph()
    for atom in mol.GetAtoms():
        # atom number
        idx = atom.GetIdx()
        # print(idx)
        feature = []
        # identity one-hot 10
        feature.extend(identity.get(atom.GetSymbol(),[0,0,0,0,0,0,0,0,0,1]))
        # degree of atom one-hot 6
        feature.extend(zero_five[atom.GetDegree()])
        # number of hydrogen atoms attached one-hot 5
        feature.extend(num_H[atom.GetNumImplicitHs()])
        # implicit valence electrons one-hot 6
        feature.extend(zero_five[atom.GetImplicitValence()])
        # aromatic 0 or 1
        if atom.GetIsAromatic():
            feature.append(1)
        else:
            feature.append(0)
        # total feature 28d
        g.add_node(idx, feature=feature)
    # add edge
    bonds_info = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()]
    # add self_loop
    for atom in mol.GetAtoms():
        bonds_info.append((atom.GetIdx(), atom.GetIdx()))
    g.add_edges_from(bonds_info)
    # print(g.nodes.data)
    return g


def mol2y(mol):
    _y = []
    som = ['PRIMARY_SOM_1A2', 'PRIMARY_SOM_2A6','PRIMARY_SOM_2B6','PRIMARY_SOM_2C8','PRIMARY_SOM_2C9','PRIMARY_SOM_2C19','PRIMARY_SOM_2D6','PRIMARY_SOM_2E1','PRIMARY_SOM_3A4',
           'SECONDARY_SOM_1A2', 'SECONDARY_SOM_2A6','SECONDARY_SOM_2B6','SECONDARY_SOM_2C8','SECONDARY_SOM_2C9','SECONDARY_SOM_2C19','SECONDARY_SOM_2D6','SECONDARY_SOM_2E1','SECONDARY_SOM_3A4',
           'TERTIARY_SOM_1A2', 'TERTIARY_SOM_2A6','TERTIARY_SOM_2B6','TERTIARY_SOM_2C8','TERTIARY_SOM_2C9','TERTIARY_SOM_2C19','TERTIARY_SOM_2D6','TERTIARY_SOM_2E1','TERTIARY_SOM_3A4'
          ]
    result = []
    for k in som:
        try:
            _res = mol.GetProp(k)
            if ' ' in _res:
                res = _res.split(' ')
                for s in res:
                    result.append(int(s))
                # res = [int(temp) for temp in res]
            else:
                # res = [int(_res)]
                result.append(int(_res))
        except:
            pass

    for data in result:
        _y.append(data)
    _y = list(set(_y))

    y = np.zeros(len(mol.GetAtoms()))
    for i in _y:
        y[i-1] = 1
    return y

In [4]:
mols = Chem.SDMolSupplier('../../raw_database/merged.sdf')
dataset = []
for mol in mols:
    g = mol2graph(mol)
    y = mol2y(mol)
    graph = from_networkx(g)
    graph.feature = graph.feature.float()
    label = torch.tensor(y, dtype=torch.float)
    dataset.append((graph, label))

In [5]:
random.seed('42')
random.shuffle(dataset)

In [6]:
total = len(dataset)
ratio = 0.8
training_set = dataset[:int(total * 0.8)]
test_set = dataset[int(total * 0.8):]

In [7]:
validation_set = training_set[int(len(training_set) * 0.8):]
training_set = training_set[:int(len(training_set) * 0.8)]

In [9]:
len(training_set), len(test_set), len(validation_set)

(435, 136, 109)

In [10]:
# evaluation
def top2(output, label):
    preds = torch.sigmoid(output)
    _, indices = torch.topk(preds, 2)
    pos_index = []
    for i in range(label.shape[0]):
        if label[i] == 1:
            pos_index.append(i)
    # print(pos_index)      
    for li in pos_index:
        if li in indices:
            return True
    return False
    
def MCC(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    # print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    return up / down

In [11]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = ChebConv(28, 1024, 2)
        self.conv2 = ChebConv(1024, 1024, 2)
        self.conv3 = ChebConv(1024, 1024, 2)
        self.linear1 = nn.Linear(1024, 1)
        self.relu = nn.ReLU()


    
    def forward(self, mol):
        res = self.conv1(mol.feature, mol.edge_index)
        res = self.relu(res)
        res = self.conv2(res, mol.edge_index)
        res = self.relu(res)
        res = self.conv3(res, mol.edge_index)
        res = self.relu(res)
        res = self.linear1(res)
        return res

In [12]:
def train(args, model, device, training_set, optimizer, criterion, epoch):
    model.train()
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    # training_set.extend(validation_set)
    for mol, target in training_set:
        mol, target = mol.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(mol)
        # squeeze
        output = torch.squeeze(output)
        # output.squeeze_(1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # tracking
        top2n += top2(output, target)
        total_loss += loss.item()
        all_pred.append(np.rint(torch.sigmoid(output).cpu().detach().numpy()))
        all_pred_raw.append(torch.sigmoid(output).cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    train_writer.add_scalar('Ave Loss', total_loss / len(training_set), epoch)
    train_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    train_writer.add_scalar('Top2', top2n / len(training_set), epoch)
    train_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    train_writer.add_scalar('MCC', mcc, epoch)
    print(f'Train Epoch: {epoch}, Ave Loss: {total_loss / len(training_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(training_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')

In [13]:
def val(args, model, device, val_set, optimizer, criterion, epoch):
    model.eval()
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol, target in val_set:
        mol, target = mol.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(mol)
        # squeeze
        output = torch.squeeze(output)
        # output.squeeze_(1)
        loss = criterion(output, target)
        # tracking
        top2n += top2(output, target)
        total_loss += loss.item()
        all_pred.append(np.rint(torch.sigmoid(output).cpu().detach().numpy()))
        all_pred_raw.append(torch.sigmoid(output).cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    val_writer.add_scalar('Ave Loss', total_loss / len(val_set), epoch)
    val_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    val_writer.add_scalar('Top2', top2n / len(val_set), epoch)
    val_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    val_writer.add_scalar('MCC', mcc, epoch)
    print(f'Val Epoch: {epoch}, Ave Loss: {total_loss / len(val_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(val_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')
    return top2n / len(val_set)

In [14]:
def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(args['seed'])
    model = Model().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'], weight_decay=args['weight_decay'])
    criterion = nn.BCEWithLogitsLoss(torch.tensor(args['pos_weight']))
    max_top2 = 0
    for epoch in range(1, args['epoch'] + 1):
        train(args, model, device, training_set, optimizer, criterion, epoch)
        top2acc = val(args, model, device, validation_set, optimizer, criterion, epoch)
        random.shuffle(training_set)
        if top2acc > max_top2:
            max_top2 = top2acc
            print('Saving model (epoch = {:4d}, top2acc = {:.4f})'
                .format(epoch, max_top2))
            torch.save(model.state_dict(), args['save_path'])
    model.load_state_dict(torch.load(args['save_path']))
    # test(model, device, test_set)

In [15]:
args = {
    'lr': 0.01,
    'epoch':500,
    'seed': 42,
    'save_path': './retrain_model',
    'momentum':0.9,
    'weight_decay': 1e-7,
    'pos_weight': 3
}

In [16]:
main(args)

Train Epoch: 1, Ave Loss: 1.053682095874315 ACC: 0.8943157019640445 Top2: 0.6367816091954023 AUC: 0.750053446776485 MCC: 0.18780554007420672
Val Epoch: 1, Ave Loss: 0.931671139843967 ACC: 0.8963051251489869 Top2: 0.7706422018348624 AUC: 0.8171754125838061 MCC: 0.2436982066685073
Saving model (epoch =    1, top2acc = 0.7706)
Train Epoch: 2, Ave Loss: 0.9373206773023496 ACC: 0.8960823028161696 Top2: 0.6873563218390805 AUC: 0.8069832922813046 MCC: 0.2821868615108739
Val Epoch: 2, Ave Loss: 0.893361100636491 ACC: 0.9006754072308304 Top2: 0.7706422018348624 AUC: 0.8486607142857143 MCC: 0.3637290927845773
Train Epoch: 3, Ave Loss: 0.886433260536742 ACC: 0.8983684921542139 Top2: 0.7241379310344828 AUC: 0.8352636035143123 MCC: 0.3031971375707193
Val Epoch: 3, Ave Loss: 0.900269382043716 ACC: 0.8955105284068335 Top2: 0.7522935779816514 AUC: 0.8428563370293966 MCC: 0.3391392654791964
Train Epoch: 4, Ave Loss: 0.8620172794694188 ACC: 0.8989919983373168 Top2: 0.7471264367816092 AUC: 0.838882490880

In [26]:
def _mol2y(mol):
    _y = []
    som = ['PRIMARY_SOM', 
           'SECONDARY_SOM', 
           'TERTIARY_SOM', 
          ]
    result = []
    for k in som:
        try:
            _res = mol.GetProp(k)
            if ' ' in _res:
                res = _res.split(' ')
                for s in res:
                    result.append(int(s))
                # res = [int(temp) for temp in res]
            else:
                # res = [int(_res)]
                result.append(int(_res))
        except:
            pass

    for data in result:
        _y.append(data)
    _y = list(set(_y))

    y = np.zeros(len(mol.GetAtoms()))
    for i in _y:
        y[i-1] = 1
    return y

In [27]:
mols = Chem.SDMolSupplier('../../raw_database/2C9.sdf')
cyp2c9 = []
for mol in mols:
    g = mol2graph(mol)
    y = _mol2y(mol)
    graph = from_networkx(g)
    graph.feature = graph.feature.float()
    label = torch.tensor(y, dtype=torch.float)
    cyp2c9.append((graph, label))

In [28]:
mols = Chem.SDMolSupplier('../../raw_database/2D6.sdf')
cyp2d6 = []
for mol in mols:
    g = mol2graph(mol)
    y = _mol2y(mol)
    graph = from_networkx(g)
    graph.feature = graph.feature.float()
    label = torch.tensor(y, dtype=torch.float)
    cyp2d6.append((graph, label))

In [29]:
mols = Chem.SDMolSupplier('../../raw_database/3A4.sdf')
cyp3a4 = []
for mol in mols:
    g = mol2graph(mol)
    y = _mol2y(mol)
    graph = from_networkx(g)
    graph.feature = graph.feature.float()
    label = torch.tensor(y, dtype=torch.float)
    cyp3a4.append((graph, label))

In [30]:
def test(model, device, test_set):
    model.eval()
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    with torch.no_grad():
        for mol, target in test_set:
            mol, target = mol.to(device), target.to(device)
            output = model(mol)
            # squeeze
            output = torch.squeeze(output)
            # tracking
            top2n += top2(output, target)
            all_pred.append(np.rint(torch.sigmoid(output).cpu().detach().numpy()))
            all_pred_raw.append(torch.sigmoid(output).cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    print(f'ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(test_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')

In [21]:
model = Model().to("cuda")
model.load_state_dict(torch.load(args['save_path']))
# model.load_state_dict(torch.load('./model/model80'))

<All keys matched successfully>

In [22]:
test(model, "cuda", test_set)

ACC: 0.9153084052412912 Top2: 0.8308823529411765 AUC: 0.8722380816443048 MCC: 0.48071529885557335


In [31]:
test(model, "cuda", cyp2c9)

ACC: 0.9255989911727617 Top2: 0.8539823008849557 AUC: 0.9386742021169395 MCC: 0.5366507631346034


In [32]:
test(model, "cuda", cyp2d6)

ACC: 0.9307528409090909 Top2: 0.8666666666666667 AUC: 0.9297754401497325 MCC: 0.5498012637959943


In [33]:
test(model, "cuda", cyp3a4)

ACC: 0.9337547408343868 Top2: 0.8610526315789474 AUC: 0.9291867990113111 MCC: 0.5497159854321518
