In [85]:
# sklearn, condusion matrix, mcc and auc
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score
# build dataset
from rdkit import Chem
import networkx as nx
import pickle
import numpy as np
from torch_geometric.utils import from_networkx
# torch
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
# tensorboard
from torch.utils.tensorboard import SummaryWriter
# random
import random
import matplotlib.pyplot as plt

In [87]:
train_writer = SummaryWriter("./GCN2/runs/3g2lbceloss/train")
val_writer = SummaryWriter("./GCN2/runs/3g2lbceloss/val")

In [88]:
identity = {
    'C':[1,0,0,0,0,0,0,0,0,0],
    'N':[0,1,0,0,0,0,0,0,0,0],
    'O':[0,0,1,0,0,0,0,0,0,0],
    'F':[0,0,0,1,0,0,0,0,0,0],
    'P':[0,0,0,0,1,0,0,0,0,0],
    'S':[0,0,0,0,0,1,0,0,0,0],
    'Cl':[0,0,0,0,0,0,1,0,0,0],
    'Br':[0,0,0,0,0,0,0,1,0,0],
    'I':[0,0,0,0,0,0,0,0,1,0],
    'other':[0,0,0,0,0,0,0,0,0,1],
}

zero_five = {
    0:[1,0,0,0,0,0],
    1:[0,1,0,0,0,0],
    2:[0,0,1,0,0,0],
    3:[0,0,0,1,0,0],
    4:[0,0,0,0,1,0],
    5:[0,0,0,0,0,1]
}

num_H = {
    0:[1,0,0,0,0],
    1:[0,1,0,0,0],
    2:[0,0,1,0,0],
    3:[0,0,0,1,0],
    4:[0,0,0,0,1]
}

In [89]:
def mol2graph(mol):
    # mol = Chem.MolFromSmiles(smiles)
    # mol = add_atom_index(mol)
    # graph
    g = nx.Graph()
    for atom in mol.GetAtoms():
        # atom number
        idx = atom.GetIdx()
        # print(idx)
        feature = []
        # identity one-hot 10
        feature.extend(identity.get(atom.GetSymbol(),[0,0,0,0,0,0,0,0,0,1]))
        # degree of atom one-hot 6
        feature.extend(zero_five[atom.GetDegree()])
        # number of hydrogen atoms attached one-hot 5
        feature.extend(num_H[atom.GetNumImplicitHs()])
        # implicit valence electrons one-hot 6
        feature.extend(zero_five[atom.GetImplicitValence()])
        # aromatic 0 or 1
        if atom.GetIsAromatic():
            feature.append(1)
        else:
            feature.append(0)
        # total feature 28d
        g.add_node(idx, feature=feature)
    # add edge
    bonds_info = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()]
    # add self_loop
    for atom in mol.GetAtoms():
        bonds_info.append((atom.GetIdx(), atom.GetIdx()))
    g.add_edges_from(bonds_info)
    # print(g.nodes.data)
    return g


def mol2y(mol):
    _y = []
    som = ['PRIMARY_SOM_1A2', 'PRIMARY_SOM_2A6','PRIMARY_SOM_2B6','PRIMARY_SOM_2C8','PRIMARY_SOM_2C9','PRIMARY_SOM_2C19','PRIMARY_SOM_2D6','PRIMARY_SOM_2E1','PRIMARY_SOM_3A4',
           'SECONDARY_SOM_1A2', 'SECONDARY_SOM_2A6','SECONDARY_SOM_2B6','SECONDARY_SOM_2C8','SECONDARY_SOM_2C9','SECONDARY_SOM_2C19','SECONDARY_SOM_2D6','SECONDARY_SOM_2E1','SECONDARY_SOM_3A4',
           'TERTIARY_SOM_1A2', 'TERTIARY_SOM_2A6','TERTIARY_SOM_2B6','TERTIARY_SOM_2C8','TERTIARY_SOM_2C9','TERTIARY_SOM_2C19','TERTIARY_SOM_2D6','TERTIARY_SOM_2E1','TERTIARY_SOM_3A4'
          ]
    result = []
    for k in som:
        try:
            _res = mol.GetProp(k)
            if ' ' in _res:
                res = _res.split(' ')
                for s in res:
                    result.append(int(s))
                # res = [int(temp) for temp in res]
            else:
                # res = [int(_res)]
                result.append(int(_res))
        except:
            pass

    for data in result:
        _y.append(data)
    _y = list(set(_y))

    y = np.zeros(len(mol.GetAtoms()))
    for i in _y:
        y[i-1] = 1
    return y

In [90]:
mols = Chem.SDMolSupplier('./raw_database/merged.sdf')

In [91]:
dataset = []
for mol in mols:
    g = mol2graph(mol)
    y = mol2y(mol)
    graph = from_networkx(g)
    graph.feature = graph.feature.float()
    label = torch.tensor(y, dtype=torch.float)
    dataset.append((graph, label))

In [92]:
def som2vec(mol, som):
    y = np.zeros(len(mol.GetAtoms()))
    # try:
    for i in som:
        y[int(i) - 1] = 1
    return y

In [93]:
# load test set
filepath = './raw_database/new2.txt'
with open(filepath) as f:
    for line in f.readlines():
        raw = eval(line)
        smiles = raw[1]
        som = raw[-1]
        mol = Chem.MolFromSmiles(smiles)
        g = mol2graph(mol)
        y = som2vec(mol, som)
        mol = from_networkx(g)
        mol.feature = mol.feature.float()
        label = torch.tensor(y).float()
        dataset.append((mol,label))

In [94]:
random.seed('42')
for i in range(100):
    random.shuffle(dataset)

In [95]:
total = len(dataset)
ratio = 0.8
training_set = dataset[:int(total * 0.8)]
lea = dataset[int(total * 0.8):]
validation_set = lea[:int(len(lea) * 0.5)]
test_set = lea[int(len(lea) * 0.5):]

In [96]:
len(training_set), len(validation_set), len(test_set)

(754, 94, 95)

In [97]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = GCNConv(28, 56)
        self.conv2 = GCNConv(56, 112)
        self.conv3 = GCNConv(112, 224)
        self.linear1 = nn.Linear(224, 56)
        self.linear2 = nn.Linear(56, 1)
        self.relu = nn.ReLU()
        self.ln1 = nn.LayerNorm(56)
        self.ln2 = nn.LayerNorm(112)
        self.ln3 = nn.LayerNorm(224)

    
    def forward(self, mol):
        res = self.ln1(self.conv1(mol.feature, mol.edge_index))
        res = self.relu(res)
        res = self.ln2(self.conv2(res, mol.edge_index))
        res = self.relu(res)
        res = self.ln3(self.conv3(res, mol.edge_index))
        res = self.relu(res)
        res = self.linear1(res)
        res = self.relu(res)
        res = self.linear2(res)
        return res

In [98]:
# evaluation
def top2(output, label):
    preds = torch.sigmoid(output)
    _, indices = torch.topk(preds, 2)
    pos_index = []
    for i in range(label.shape[0]):
        if label[i] == 1:
            pos_index.append(i)
    # print(pos_index)      
    for li in pos_index:
        if li in indices:
            return True
    return False
    
def MCC(output, label):
    tn,fp,fn,tp=confusion_matrix(label, output).ravel()
    print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}")
    up = (tp * tn) - (fp * fn)
    down = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    return up / down

In [99]:
def val(args, model, device, val_set, optimizer, criterion, epoch):
    model.eval()
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol, target in val_set:
        mol, target = mol.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(mol)
        # squeeze
        output = torch.squeeze(output)
        # output.squeeze_(1)
        loss = criterion(output, target)
        # tracking
        top2n += top2(output, target)
        total_loss += loss.item()
        all_pred.append(np.rint(torch.sigmoid(output).cpu().detach().numpy()))
        all_pred_raw.append(torch.sigmoid(output).cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    val_writer.add_scalar('Ave Loss', total_loss / len(val_set), epoch)
    val_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    val_writer.add_scalar('Top2', top2n / len(val_set), epoch)
    val_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    val_writer.add_scalar('MCC', mcc, epoch)
    print(f'Val Epoch: {epoch}, Ave Loss: {total_loss / len(val_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(val_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')
    return top2n / len(val_set)

In [100]:
def train(args, model, device, training_set, optimizer, criterion, epoch):
    model.train()
    total_loss = 0
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    for mol, target in training_set:
        mol, target = mol.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(mol)
        # squeeze
        output = torch.squeeze(output)
        # output.squeeze_(1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # tracking
        top2n += top2(output, target)
        total_loss += loss.item()
        all_pred.append(np.rint(torch.sigmoid(output).cpu().detach().numpy()))
        all_pred_raw.append(torch.sigmoid(output).cpu().detach().numpy())
        all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    train_writer.add_scalar('Ave Loss', total_loss / len(training_set), epoch)
    train_writer.add_scalar('ACC', accuracy_score(all_labels, all_pred), epoch)
    train_writer.add_scalar('Top2', top2n / len(training_set), epoch)
    train_writer.add_scalar('AUC', roc_auc_score(all_labels, all_pred_raw), epoch)
    train_writer.add_scalar('MCC', mcc, epoch)
    print(f'Train Epoch: {epoch}, Ave Loss: {total_loss / len(training_set)} ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(training_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')

In [101]:
def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(args['seed'])
    model = Model().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args['lr'], momentum=args['momentum'], weight_decay=args['weight_decay'])
    criterion = nn.BCEWithLogitsLoss()
    max_top2 = 0
    for epoch in range(1, args['epoch'] + 1):
        train(args, model, device, training_set, optimizer, criterion, epoch)
        top2acc = val(args, model, device, validation_set, optimizer, criterion, epoch)
        random.shuffle(training_set)
        if top2acc > max_top2:
            max_top2 = top2acc
            print('Saving model (epoch = {:4d}, top2acc = {:.4f})'
                .format(epoch, max_top2))
            torch.save(model.state_dict(), args['save_path'])

In [102]:
args = {
    'lr': 0.01,
    'epoch': 1000,
    'seed': 42,
    'save_path': './GCN2/model/32bce/model',
    'momentum':0,
    'weight_decay': 0,
    'pos_weight': 3
}

In [103]:
main(args)

TN: 15363, FP: 0, FN: 1701, TP: 0
Train Epoch: 1, Ave Loss: 0.35497212282503 ACC: 0.9003164556962026 Top2: 0.41379310344827586 AUC: 0.6465178578842722 MCC: nan


  return up / down


TN: 2036, FP: 0, FN: 222, TP: 0
Val Epoch: 1, Ave Loss: 0.3273379722649747 ACC: 0.9016829052258636 Top2: 0.5106382978723404 AUC: 0.7371955698330944 MCC: nan
Saving model (epoch =    1, top2acc = 0.5106)


  return up / down


TN: 15362, FP: 1, FN: 1688, TP: 13
Train Epoch: 2, Ave Loss: 0.3298630436946606 ACC: 0.9010196905766527 Top2: 0.4960212201591512 AUC: 0.7311916025672743 MCC: 0.07928449936037156
TN: 2027, FP: 9, FN: 204, TP: 18
Val Epoch: 2, Ave Loss: 0.33349018718333956 ACC: 0.9056687333923826 Top2: 0.5212765957446809 AUC: 0.7663907768279085 MCC: 0.2099936671030089
Saving model (epoch =    2, top2acc = 0.5213)
TN: 15316, FP: 47, FN: 1620, TP: 81
Train Epoch: 3, Ave Loss: 0.3205649224591666 ACC: 0.9023089545241444 Top2: 0.5305039787798409 AUC: 0.7511107353332902 MCC: 0.15471146677480255
TN: 2034, FP: 2, FN: 206, TP: 16
Val Epoch: 3, Ave Loss: 0.3016253761313063 ACC: 0.9078830823737821 Top2: 0.5212765957446809 AUC: 0.789979689906016 MCC: 0.23801927198780376
TN: 15300, FP: 63, FN: 1603, TP: 98
Train Epoch: 4, Ave Loss: 0.3148277216845387 ACC: 0.9023675574308486 Top2: 0.5596816976127321 AUC: 0.7610877321437326 MCC: 0.1658249311536414
TN: 2016, FP: 20, FN: 184, TP: 38
Val Epoch: 4, Ave Loss: 0.303224231968

In [104]:
def test(model, device, test_set):
    model.eval()
    all_pred = []
    all_pred_raw = []
    all_labels = []
    top2n = 0
    with torch.no_grad():
        for mol, target in test_set:
            mol, target = mol.to(device), target.to(device)
            output = model(mol)
            # squeeze
            output = torch.squeeze(output)
            # tracking
            top2n += top2(output, target)
            all_pred.append(np.rint(torch.sigmoid(output).cpu().detach().numpy()))
            all_pred_raw.append(torch.sigmoid(output).cpu().detach().numpy())
            all_labels.append(target.cpu().detach().numpy())
    all_pred = np.concatenate(all_pred).ravel()
    all_pred_raw = np.concatenate(all_pred_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    mcc = MCC(all_pred, all_labels)
    print(f'ACC: {accuracy_score(all_labels, all_pred)} Top2: {top2n / len(test_set)} AUC: {roc_auc_score(all_labels, all_pred_raw)} MCC: {mcc}')

In [105]:
model = Model().to("cuda")
model.load_state_dict(torch.load(args['save_path']))

<All keys matched successfully>

In [106]:
test(model, "cuda", test_set)

TN: 2010, FP: 73, FN: 121, TP: 64
ACC: 0.9144620811287478 Top2: 0.6526315789473685 AUC: 0.8580957818115763 MCC: 0.3571903241263952
