In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import PandasTools
from torch_geometric.data import Data

In [2]:
# List the contents of the sdf/ directory (IPython automagic for the shell `ls`).
ls sdf

mol_100071.sdf  mol_155576.sdf  mol_217387.sdf  mol_268202.sdf  mol_44900.sdf
mol_101372.sdf  mol_157394.sdf  mol_21800.sdf   mol_268542.sdf  mol_45225.sdf
mol_101795.sdf  mol_157970.sdf  mol_218301.sdf  mol_268672.sdf  mol_45405.sdf
mol_101915.sdf  mol_158733.sdf  mol_218565.sdf  mol_268774.sdf  mol_45900.sdf
mol_102321.sdf  mol_159039.sdf  mol_219115.sdf  mol_268868.sdf  mol_45939.sdf
mol_103251.sdf  mol_159049.sdf  mol_2195.sdf    mol_268887.sdf  mol_4650.sdf
mol_10330.sdf   mol_159285.sdf  mol_220088.sdf  mol_269727.sdf  mol_47262.sdf
mol_103430.sdf  mol_16001.sdf   mol_220291.sdf  mol_270147.sdf  mol_47417.sdf
mol_103699.sdf  mol_160391.sdf  mol_220567.sdf  mol_270152.sdf  mol_47972.sdf
mol_104416.sdf  mol_160744.sdf  mol_220874.sdf  mol_270449.sdf  mol_48008.sdf
mol_105092.sdf  mol_16112.sdf   mol_221492.sdf  mol_270663.sdf  mol_48464.sdf
mol_105528.sdf  mol_16121.sdf   mol_222057.sdf  mol_271040.sdf  mol_49189.sdf
mol_105966.sdf  mol_161843.sdf  mol_222568.sdf  mol_2

In [3]:
def get_drug_graph(mol):
    """Build a PyTorch Geometric ``Data`` graph for a single molecule.

    Parameters
    ----------
    mol : str, path-like or rdkit.Chem.Mol
        Path to an SDF file (only its first record is used), or an
        already-loaded RDKit molecule that carries a conformer.

    Returns
    -------
    torch_geometric.data.Data
        ``x``          float tensor ``[num_atoms, 5]``: atomic number, degree,
                       hybridization (integer value of the RDKit enum),
                       aromatic flag (0/1) and formal charge.
        ``edge_index`` long tensor ``[2, num_bonds]``: one column per bond,
                       stored once as (begin atom, end atom).
        ``edge_attr``  long tensor ``[num_bonds]``: integer value of the RDKit
                       bond-type enum.
        ``pos``        float tensor ``[num_atoms, 3]``: conformer coordinates.

    Raises
    ------
    ValueError
        If ``mol`` is a path whose first SDF record is missing or cannot be
        parsed by RDKit.
    """
    if isinstance(mol, Chem.Mol):
        molecule = mol
    else:
        # Use the argument itself. The path used to be rebuilt from the
        # notebook-global loop variable `i`, which silently ignored `mol`.
        supplier = Chem.SDMolSupplier(str(mol))
        molecule = supplier[0] if len(supplier) > 0 else None
        if molecule is None:
            raise ValueError(f"could not read a molecule from {mol!r}")

    # RDKit enums (hybridization, bond type) are converted to plain ints so
    # the rows can be turned into tensors directly.
    atom_rows = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetFormalCharge(),
        ]
        for atom in molecule.GetAtoms()
    ]
    # reshape keeps a well-formed (0, 5) tensor even for an atom-less molecule.
    node_features = torch.tensor(atom_rows, dtype=torch.float).reshape(-1, 5)

    bond_rows = [
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), int(bond.GetBondType()))
        for bond in molecule.GetBonds()
    ]
    # reshape(-1, 3) makes a bond-free molecule yield empty (2, 0) / (0,)
    # tensors instead of failing on the column slice.
    bond_table = torch.tensor(bond_rows, dtype=torch.long).reshape(-1, 3)
    edge_index = bond_table[:, :2].t().contiguous()
    edges_attr = bond_table[:, 2].contiguous()

    # Coordinates are real-valued; casting them to an integer dtype truncated
    # every position, so they are stored as floats.
    position = torch.tensor(molecule.GetConformer().GetPositions(), dtype=torch.float)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edges_attr, pos=position)

In [4]:
# Collect the SDF file names in pure Python rather than through a shell escape:
# this works without a Unix `ls`, ignores non-SDF entries, and the explicit
# sort keeps the processing order deterministic.
drugs = sorted(name for name in os.listdir('sdf') if name.endswith('.sdf'))

In [5]:
# Convert every listed SDF file into a graph, keyed by the file name up to its
# first '.'. Failures are collected instead of aborting the whole run.
drug_graphs = {}
errors = []          # file names that could not be converted
error_reasons = {}   # file name -> repr of the exception, for later diagnosis
for i in drugs:
    try:
        drug_graphs[i.split('.')[0]] = get_drug_graph('sdf/'+i)
    except Exception as exc:
        # `except Exception` rather than a bare `except`, so a kernel
        # interrupt (KeyboardInterrupt) still stops the loop.
        errors.append(i)
        error_reasons[i] = repr(exc)



In [6]:
# Display the converted graphs (dict keyed by file name up to its first '.').
drug_graphs

{'mol_100071': Data(x=[41, 5], edge_index=[2, 45], edge_attr=[45], pos=[41, 3]),
 'mol_101372': Data(x=[33, 5], edge_index=[2, 37], edge_attr=[37], pos=[33, 3]),
 'mol_101795': Data(x=[17, 5], edge_index=[2, 20], edge_attr=[20], pos=[17, 3]),
 'mol_101915': Data(x=[22, 5], edge_index=[2, 24], edge_attr=[24], pos=[22, 3]),
 'mol_102321': Data(x=[34, 5], edge_index=[2, 37], edge_attr=[37], pos=[34, 3]),
 'mol_103251': Data(x=[38, 5], edge_index=[2, 42], edge_attr=[42], pos=[38, 3]),
 'mol_10330': Data(x=[25, 5], edge_index=[2, 28], edge_attr=[28], pos=[25, 3]),
 'mol_103430': Data(x=[16, 5], edge_index=[2, 17], edge_attr=[17], pos=[16, 3]),
 'mol_103699': Data(x=[21, 5], edge_index=[2, 24], edge_attr=[24], pos=[21, 3]),
 'mol_104416': Data(x=[21, 5], edge_index=[2, 23], edge_attr=[23], pos=[21, 3]),
 'mol_105092': Data(x=[30, 5], edge_index=[2, 32], edge_attr=[32], pos=[30, 3]),
 'mol_105528': Data(x=[11, 5], edge_index=[2, 10], edge_attr=[10], pos=[11, 3]),
 'mol_105966': Data(x=[27, 5]

In [7]:
# Number of SDF files whose conversion raised an exception.
len(errors)

1

In [8]:
# File names whose conversion raised an exception.
errors

['mol_159039.sdf']

In [9]:
import json
import gzip
import zlib
import pickle

In [10]:
# Pickle the graph dictionary and write it gzip-compressed to drug.pkl.gz.
with gzip.open('drug.pkl.gz', mode='wb') as archive:
    pickle.dump(drug_graphs, archive)