In [1]:
import os

import pandas as pd
import numpy as np
import networkx as nx
import dgl
import torch

from torch.nn import functional as F
from tqdm import tqdm as tqdm

from rdkit import Chem
from rdkit.Chem import AllChem, Draw

from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
def mol_to_nx(mol):
    """Convert an RDKit molecule into an undirected networkx graph.

    Every atom becomes a node keyed by its atom index and carries the
    attributes ``label`` (element symbol), ``atomic_num``, ``formal_charge``,
    ``chiral_tag``, ``hybridization``, ``num_explicit_hs`` and
    ``is_aromatic``. Every bond becomes an edge between its begin and end
    atom indices with a single ``bond_type`` attribute.

    Parameters
    ----------
    mol : molecule object exposing ``GetAtoms()`` and ``GetBonds()``.

    Returns
    -------
    networkx.Graph
    """
    graph = nx.Graph()

    for atom in mol.GetAtoms():
        node_attrs = {
            'label': atom.GetSymbol(),
            'atomic_num': atom.GetAtomicNum(),
            'formal_charge': atom.GetFormalCharge(),
            'chiral_tag': atom.GetChiralTag(),
            'hybridization': atom.GetHybridization(),
            'num_explicit_hs': atom.GetNumExplicitHs(),
            'is_aromatic': atom.GetIsAromatic(),
        }
        graph.add_node(atom.GetIdx(), **node_attrs)

    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        graph.add_edge(begin, end, bond_type=bond.GetBondType())

    return graph

def nx_to_mol(G):
    """Rebuild an RDKit molecule from a graph shaped like ``mol_to_nx`` output.

    Each node must carry the attributes ``atomic_num``, ``chiral_tag``,
    ``formal_charge``, ``is_aromatic``, ``hybridization`` and
    ``num_explicit_hs``; each edge must carry ``bond_type``. Atoms are added
    in node-iteration order, so atom indices in the result follow that order
    rather than the node keys themselves.

    Parameters
    ----------
    G : networkx graph with the node/edge attributes listed above.

    Returns
    -------
    The ``Chem.RWMol`` that was built, after ``Chem.SanitizeMol`` has been
    applied to it.

    Raises
    ------
    KeyError
        If a node or edge is missing one of the required attributes.
    """
    node_attr = {
        key: nx.get_node_attributes(G, key)
        for key in ('atomic_num', 'chiral_tag', 'formal_charge',
                    'is_aromatic', 'hybridization', 'num_explicit_hs')
    }

    rw_mol = Chem.RWMol()
    atom_index = {}  # graph node key -> atom index inside rw_mol
    for node in G.nodes():
        atom = Chem.Atom(node_attr['atomic_num'][node])
        atom.SetChiralTag(node_attr['chiral_tag'][node])
        atom.SetFormalCharge(node_attr['formal_charge'][node])
        atom.SetIsAromatic(node_attr['is_aromatic'][node])
        atom.SetHybridization(node_attr['hybridization'][node])
        atom.SetNumExplicitHs(node_attr['num_explicit_hs'][node])
        atom_index[node] = rw_mol.AddAtom(atom)

    bond_type_of = nx.get_edge_attributes(G, 'bond_type')
    for u, v in G.edges():
        rw_mol.AddBond(atom_index[u], atom_index[v], bond_type_of[u, v])

    Chem.SanitizeMol(rw_mol)
    return rw_mol

In [63]:
def atoms2vec(atoms):
    """One-hot encode per-atom features.

    Parameters
    ----------
    atoms : iterable of atom objects exposing ``GetSymbol()``,
        ``GetFormalCharge()``, ``GetNumRadicalElectrons()``,
        ``GetIsAromatic()`` and ``GetProp(name)``.

    Returns
    -------
    tuple ``(atom_emb, charge_emb, electron_emb, chirality_emb, aromatic_emb)``
        Five lists with one ``torch.long`` tensor per atom, each with a
        leading dimension of 1:

        * ``atom_emb``      -- one-hot over ``atom_list`` plus a final
          "unknown element" class.
        * ``charge_emb``    -- one-hot over formal charges -3..+3.
        * ``electron_emb``  -- one-hot over 0..2 radical electrons.
        * ``chirality_emb`` -- one-hot over ``'R'``, ``'S'`` plus a final
          class used when the atom has no ``_CIPCode`` property or the code
          is neither ``'R'`` nor ``'S'``.
        * ``aromatic_emb``  -- shape ``(1, 1)``, 1 if aromatic else 0.

    Notes
    -----
    A formal charge outside -3..+3 or more than 2 radical electrons yields an
    index that ``F.one_hot`` rejects, so such atoms raise instead of being
    encoded. Only the "not in vocabulary" / "property missing" cases are
    mapped to the fallback class; any other exception from the atom accessors
    propagates to the caller.
    """
    atom_list = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe',
                 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co',
                 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
                 'Cr', 'Pt', 'Hg', 'Pb']
    charge_list = [-3, -2, -1, 0, 1, 2, 3]
    electron_list = [0, 1, 2]
    chirality_list = ['R', 'S']

    atom_emb = []
    charge_emb = []
    electron_emb = []
    chirality_emb = []
    aromatic_emb = []
    for atom in atoms:
        try:
            symbol_idx = atom_list.index(atom.GetSymbol())
        except ValueError:
            # Element outside the vocabulary -> trailing "unknown" class.
            symbol_idx = len(atom_list)
        atom_emb.append(F.one_hot(torch.tensor([symbol_idx]), num_classes=len(atom_list) + 1))

        # Shift so that the smallest supported charge lands on class 0.
        charge_idx = atom.GetFormalCharge() - charge_list[0]
        charge_emb.append(F.one_hot(torch.tensor([charge_idx]), num_classes=len(charge_list)))

        electron_emb.append(F.one_hot(torch.tensor([atom.GetNumRadicalElectrons()]),
                                      num_classes=len(electron_list)))

        aromatic_emb.append(torch.tensor([[int(atom.GetIsAromatic())]]))

        try:
            cip_idx = chirality_list.index(atom.GetProp('_CIPCode'))
        except (KeyError, ValueError):
            # KeyError: property not set on this atom; ValueError: a code
            # other than 'R'/'S'. Both use the trailing "no label" class.
            cip_idx = len(chirality_list)
        chirality_emb.append(F.one_hot(torch.tensor([cip_idx]), num_classes=len(chirality_list) + 1))

    return atom_emb, charge_emb, electron_emb, chirality_emb, aromatic_emb
   
def bonds2vec(bonds):
    """Encode per-bond features as small integer tensors.

    Parameters
    ----------
    bonds : iterable of bond objects exposing ``GetBondType()``,
        ``GetStereo()``, ``GetIsConjugated()`` and ``IsInRing()``.

    Returns
    -------
    tuple ``(bond_emb, conjugated, ring, chirality)``
        Four lists with one 1-D ``torch.long`` tensor per bond:

        * ``bond_emb``   -- length 4 indicator over SINGLE, DOUBLE, TRIPLE,
          AROMATIC (all zeros for any other bond type).
        * ``conjugated`` -- length 1, 1 if the bond is conjugated.
        * ``ring``       -- length 1, 1 if the bond is in a ring.
        * ``chirality``  -- length 4 indicator over STEREONONE, STEREOANY,
          STEREOZ, STEREOE (all zeros for any other stereo label).
    """
    bond_types = (Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC)
    # GetStereo() returns a BondStereo enum member, so it has to be compared
    # against enum members; comparing against the member *names* as strings
    # never matches and would leave this feature all zeros.
    stereo_types = (Chem.rdchem.BondStereo.STEREONONE, Chem.rdchem.BondStereo.STEREOANY,
                    Chem.rdchem.BondStereo.STEREOZ, Chem.rdchem.BondStereo.STEREOE)

    conjugated = []
    ring = []
    bond_emb = []
    chirality = []
    for bond in bonds:
        bt = bond.GetBondType()
        bs = bond.GetStereo()
        bond_emb.append(torch.tensor([int(bt == kind) for kind in bond_types]))
        conjugated.append(torch.tensor([int(bond.GetIsConjugated())]))
        ring.append(torch.tensor([int(bond.IsInRing())]))
        chirality.append(torch.tensor([int(bs == kind) for kind in stereo_types]))
    return bond_emb, conjugated, ring, chirality
              
def _bfs_atom_and_bond_order(num_atoms, bond_ends, root):
    """Return ``(atom_order, bond_order)`` for a breadth-first walk from ``root``.

    ``bond_ends`` is a list of ``(begin_idx, end_idx)`` pairs, one per bond,
    and every bond is treated as undirected. ``atom_order`` lists atoms in the
    order they are first reached (starting with ``root``); ``bond_order``
    lists the index of the bond through which each newly reached atom was
    found, so ring-closing bonds do not appear. Atoms not connected to
    ``root`` are omitted.
    """
    incident = [[] for _ in range(num_atoms)]
    for bond_idx, (begin, end) in enumerate(bond_ends):
        incident[begin].append((bond_idx, end))
        incident[end].append((bond_idx, begin))

    visited = [False] * num_atoms
    visited[root] = True
    atom_order = [root]
    bond_order = []
    frontier = [root]
    while frontier:
        next_frontier = []
        for atom_idx in frontier:
            for bond_idx, neighbour in incident[atom_idx]:
                if not visited[neighbour]:
                    visited[neighbour] = True
                    atom_order.append(neighbour)
                    bond_order.append(bond_idx)
                    next_frontier.append(neighbour)
        frontier = next_frontier
    return atom_order, bond_order


def mol2graph(mol):
    """Build a DGL graph plus BFS-ordered feature tensors for a molecule.

    Parameters
    ----------
    mol : RDKit molecule. Hydrogens are added internally via ``Chem.AddHs``.
        The molecule must contain at least one bond after that step;
        otherwise the torch concatenation of the (empty) bond features raises.

    Returns
    -------
    G : dgl.DGLGraph
        One node per atom and one directed edge (begin atom -> end atom) per
        bond, in RDKit index order. ``G.ndata['feats']`` is the float
        concatenation of the five ``atoms2vec`` encodings per atom;
        ``G.edata['feats']`` is the float concatenation of the four
        ``bonds2vec`` encodings per bond.
    atom_feats : tuple
        ``(atoms_emb, charge_emb, electron_emb, chirality_emb, aromatic_emb)``,
        each a 2-D long tensor whose rows follow the BFS atom order.
    bond_feats : tuple
        ``(bond_emb, conjugated, ring, chirality)`` whose entries follow the
        BFS bond order. ``bond_emb`` and ``chirality`` are 2-D (one row per
        bond); ``conjugated`` and ``ring`` are 1-D.

    Notes
    -----
    The BFS starts at the atom with canonical rank 0 and treats bonds as
    undirected. Only bonds that first reach a new atom are included in the
    bond order, and atoms in a fragment not connected to the root are left
    out of both orderings.
    """
    # Chosen before AddHs, so the root is one of the input molecule's atoms.
    # Assumes AddHs keeps the indices of the existing atoms -- TODO confirm.
    bfs_root = list(Chem.CanonicalRankAtoms(mol)).index(0)

    mol = Chem.AddHs(mol)
    G = dgl.DGLGraph()

    atoms = list(mol.GetAtoms())
    bonds = list(mol.GetBonds())
    bond_ends = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in bonds]
    G.add_nodes(len(atoms))
    G.add_edges([begin for begin, _ in bond_ends], [end for _, end in bond_ends])

    # Per-atom encodings are (1, k) tensors, so cat along dim 0 gives (n_atoms, k).
    atoms_emb, charge_emb, electron_emb, chirality_emb, aromatic_emb = (
        torch.cat(per_atom, dim=0) for per_atom in atoms2vec(atoms)
    )
    G.ndata['feats'] = torch.cat(
        [atoms_emb, charge_emb, electron_emb, chirality_emb, aromatic_emb], dim=1
    ).float()

    bond_emb, conjugated, ring, chirality = bonds2vec(bonds)
    # bond_emb / chirality hold one 1-D multi-element vector per bond: they
    # must be *stacked* into (n_bonds, width). Concatenating them would
    # flatten everything into one long vector and row indexing by bond id
    # would then pick unrelated scalars.
    bond_emb = torch.stack(bond_emb, dim=0)
    chirality = torch.stack(chirality, dim=0)
    # conjugated / ring hold one length-1 vector per bond -> (n_bonds,).
    conjugated = torch.cat(conjugated, dim=0)
    ring = torch.cat(ring, dim=0)
    G.edata['feats'] = torch.cat(
        [bond_emb, conjugated.unsqueeze(1), ring.unsqueeze(1), chirality], dim=1
    ).float()

    # The DGL edges above only run begin -> end, so a traversal over G would
    # miss atoms that are reachable only against bond direction. Order atoms
    # and bonds with an undirected BFS over the bond list instead.
    atom_order, bond_order = _bfs_atom_and_bond_order(len(atoms), bond_ends, bfs_root)

    atom_feats = (atoms_emb[atom_order], charge_emb[atom_order], electron_emb[atom_order],
                  chirality_emb[atom_order], aromatic_emb[atom_order])
    bond_feats = (bond_emb[bond_order], conjugated[bond_order],
                  ring[bond_order], chirality[bond_order])

    return G, atom_feats, bond_feats

In [36]:
# Load the ZINC250K table; later cells read its 'smiles' column.
# Set the ZINC250K_CSV environment variable to point at a local copy of the
# CSV. The fallback is the path this notebook was originally run with.
zinc = pd.read_csv(os.environ.get('ZINC250K_CSV', '/Users/dawood/Datasets/ZINC250K/data.csv'))

In [66]:
# Scratch cell: take the first parsable SMILES from row 10 onward, parse it and
# build its networkx graph. The loop stops after one molecule; `mol` is reused
# by the mol2graph cell further down.
# The three lists are placeholders: nothing in this cell appends to them yet.
atoms_list = []
radical_electrons = []
charges = []
for smiles in tqdm(zinc['smiles'][10:]):
    print(smiles)

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles returns None when the string cannot be parsed; move on
        # to the next row instead of failing inside CanonicalRankAtoms.
        continue
    x = list(Chem.CanonicalRankAtoms(mol))
    G = mol_to_nx(mol)

    break

  0%|          | 0/249445 [00:00<?, ?it/s]

CCOc1ccc(OCC)c([C@H]2C(C#N)=C(N)N(c3ccccc3C(F)(F)F)C3=C2C(=O)CCC3)c1






In [75]:
# Featurise the molecule left in `mol` by the scratch loop cell above.
# NOTE: this rebinds `G` (the networkx graph built in that cell) to the graph
# returned by mol2graph; `atom_feats` and `bond_feats` receive the two feature
# tuples that mol2graph returns alongside it.
G, atom_feats, bond_feats = mol2graph(mol)