In [73]:
import numpy as np
from rdkit import Chem
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import SGD
import pandas as pd
from torch_geometric.data import Data

In [3]:
sider = pd.read_csv("../data/sider.csv")
sider

Unnamed: 0,smiles,Hepatobiliary disorders,Metabolism and nutrition disorders,Product issues,Eye disorders,Investigations,Musculoskeletal and connective tissue disorders,Gastrointestinal disorders,Social circumstances,Immune system disorders,...,"Congenital, familial and genetic disorders",Infections and infestations,"Respiratory, thoracic and mediastinal disorders",Psychiatric disorders,Renal and urinary disorders,"Pregnancy, puerperium and perinatal conditions",Ear and labyrinth disorders,Cardiac disorders,Nervous system disorders,"Injury, poisoning and procedural complications"
0,C(CNCCNCCNCCN)N,1,1,0,0,1,1,1,0,0,...,0,0,1,1,0,0,1,1,1,0
1,CC(C)(C)C1=CC(=C(C=C1NC(=O)C2=CNC3=CC=CC=C3C2=...,0,1,0,0,1,1,1,0,0,...,0,1,1,0,0,0,1,0,1,0
2,CC[C@]12CC(=C)[C@H]3[C@H]([C@@H]1CC[C@]2(C#C)O...,0,1,0,1,1,0,1,0,1,...,0,0,0,1,0,0,0,0,1,0
3,CCC12CC(=C)C3C(C1CC[C@]2(C#C)O)CCC4=CC(=O)CCC34,1,1,0,1,1,1,1,0,1,...,1,1,1,1,1,1,0,0,1,1
4,C1C(C2=CC=CC=C2N(C3=CC=CC=C31)C(=O)N)O,1,1,0,1,1,1,1,0,1,...,0,1,1,1,0,0,1,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1422,C[C@H]1CN(CC[C@@]1(C)C2=CC(=CC=C2)O)C[C@H](CC3...,0,1,0,0,0,1,1,0,0,...,0,0,0,0,1,0,0,0,0,0
1423,CC[C@@H]1[C@@]2([C@@H]([C@@H](C(=O)[C@@H](C[C@...,1,1,0,1,1,1,1,0,1,...,0,1,1,1,1,0,1,1,1,1
1424,CCOC1=CC=C(C=C1)CC2=C(C=CC(=C2)[C@H]3[C@@H]([C...,1,1,0,0,1,1,1,0,1,...,0,1,0,0,1,0,0,1,1,1
1425,C1CN(CCC1N2C3=CC=CC=C3NC2=O)CCCC(C4=CC=C(C=C4)...,0,1,0,1,1,1,1,0,0,...,0,0,0,1,1,0,0,1,1,1


In [23]:
test_smi = sider["smiles"][3]
mol = Chem.MolFromSmiles(test_smi)

edges = []
for bond in mol.GetBonds():
    i = bond.GetBeginAtomIdx()
    j = bond.GetEndAtomIdx()
    edges.extend([(i,j), (j,i)])

edge_idx = list(zip(*edges))

In [46]:
def atom_feats(atom):
    return [
        atom.GetAtomicNum(),
        atom.GetDegree(),
        atom.GetNumImplicitHs(),
        atom.GetIsAromatic()
    ]

def bond_feats(bond):
    return [
        bond.GetBondType(),
        bond.GetStereo()
    ]

In [47]:
node_f = [atom_feats(a) for a in mol.GetAtoms()]
edge_f = [bond_feats(b) for b in mol.GetBonds()]

In [58]:
# Put this into a graph to use in PyG

g = Data(x=torch.Tensor(node_f),
        edge_index=torch.LongTensor(edge_idx),
        edge_attr=torch.Tensor(edge_f),
        smiles=test_smi,
        mol=mol)
g

Data(x=[24, 4], edge_index=[2, 54], edge_attr=[27, 2], smiles='CCC12CC(=C)C3C(C1CC[C@]2(C#C)O)CCC4=CC(=O)CCC34', mol=<rdkit.Chem.rdchem.Mol object at 0x7f7d0bbfdd20>)

In [75]:
def smi_to_PyG(smi, task_):
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None
    edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.extend([(i,j), (j,i)])

    edge_idx = list(zip(*edges))
    node_f = [atom_feats(a) for a in mol.GetAtoms()]
    edge_f = [bond_feats(b) for b in mol.GetBonds()]

    return Data(x=torch.Tensor(node_f),
               edge_index=torch.LongTensor(edge_idx),
               edge_attr=torch.Tensor(edge_f),
               y=torch.Tensor([task_]))

In [76]:
class GetData(Dataset):
    def __init__(self, smiles, task):
        mols = [smi_to_PyG(smi, task_) for smi, task_\
               in zip(smiles, task)]
        self.X = [mol for mol in mols if mol] #If valid graph

    def __getitem__(self, idx):
        return self.X[idx]

    def __len__(self):
        return len(self.X)

In [77]:
sider_PyG = GetData(sider["smiles"], sider["Hepatobiliary disorders"])



In [83]:
len(sider_PyG)
g_test = sider_PyG[1000]
g_test

Data(x=[21, 4], edge_index=[2, 42], edge_attr=[21, 2], y=[1])