In [1]:
from rdkit import Chem
import torch
from torch_geometric.data import Dataset, Data
import numpy as np
import os
from rdkit.Chem.rdchem import HybridizationType, ChiralType

[18:36:33] Enabling RDKit 2019.09.3 jupyter extensions


In [2]:
class AtomTopbased(Dataset):
    """Per-molecule graph dataset built from an SD file.

    Every record of the raw SD file is turned into one ``Data`` object
    (node features ``x``, ``edge_index``, ``edge_attr`` and a per-atom 0/1
    label vector ``y``) and saved as its own ``.pt`` file in
    ``processed_dir``.

    Args:
        root: Dataset root directory, forwarded to the base ``Dataset``.
        filename: Name of the raw SD file.
        test: If true, processed files are named ``data_test_<i>.pt``
            instead of ``data_<i>.pt``.
        transform: Forwarded unchanged to the base ``Dataset``.
        pre_transform: Forwarded unchanged to the base ``Dataset``.
        pre_filter: Forwarded unchanged to the base ``Dataset``.
    """

    # Element symbols with a dedicated one-hot slot; anything else lands in
    # the trailing 'other' slot.
    _ATOM_SYMBOLS = ('C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'other')

    # The SD properties read for labels are '<LEVEL>_SOM_<ISOFORM>'.
    # NOTE(review): presumably site-of-metabolism annotations per CYP
    # isoform -- confirm against the data source.
    _SOM_LEVELS = ('PRIMARY', 'SECONDARY', 'TERTIARY')
    _SOM_ISOFORMS = ('1A2', '2A6', '2B6', '2C8', '2C9', '2C19', '2D6', '2E1', '3A4')

    def __init__(self, root, filename, test=False, transform=None, pre_transform=None, pre_filter=None):
        # These attributes must exist before the base constructor runs,
        # because it evaluates the file-name properties below.
        self.filename = filename
        self.test = test
        super().__init__(root, transform, pre_transform, pre_filter)

    # ------------------------------------------------------------------
    # File naming / discovery
    # ------------------------------------------------------------------
    def _processed_file_name(self, idx):
        """Return the processed file name for molecule ``idx``.

        Single source of truth shared by ``processed_file_names``,
        ``process`` and ``get`` so the three can never disagree.
        """
        prefix = 'data_test' if self.test else 'data'
        return f'{prefix}_{idx}.pt'

    def _load_mols(self):
        """Open the raw SD file and cache the supplier on ``self.mols``."""
        self.mols = Chem.SDMolSupplier(self.raw_paths[0])
        return self.mols

    @property
    def raw_file_names(self):
        """Name of the raw SD file given to the constructor."""
        return self.filename

    @property
    def processed_file_names(self):
        """List the processed file names, one per record of the SD file.

        Side effect: (re)opens the SD file and stores the supplier in
        ``self.mols``, which ``len()`` reads.
        """
        mols = self._load_mols()
        # The test split previously lacked the '.pt' suffix here while
        # ``process``/``get`` used it, so the names never matched the files
        # actually written.
        return [self._processed_file_name(i) for i in range(len(mols))]

    def download(self):
        """Nothing to download; the raw SD file is expected to be present."""
        pass

    # ------------------------------------------------------------------
    # Processing
    # ------------------------------------------------------------------
    def process(self):
        """Convert every molecule of the SD file into a saved ``Data`` object."""
        for idx, mol in enumerate(self._load_mols()):
            if mol is None:
                # The supplier yields None for records it cannot parse; fail
                # with a clear message rather than an AttributeError later on.
                raise ValueError(
                    f'Record {idx} of {self.raw_paths[0]} could not be parsed into a molecule'
                )
            data = Data(
                x=self._get_node_features(mol),
                edge_index=self._get_adjacency_info(mol),
                edge_attr=self._get_edge_features(mol),
                y=self._get_labels(mol),
            )
            torch.save(data, os.path.join(self.processed_dir, self._processed_file_name(idx)))

    @staticmethod
    def _one_hot(value, choices, fallback):
        """One-hot encode ``value`` over ``choices``.

        A value that is not in ``choices`` is encoded as ``fallback`` (which
        must itself be a member of ``choices``), so the result always has
        exactly one bit set and length ``len(choices)``.
        """
        encoding = [0] * len(choices)
        try:
            encoding[choices.index(value)] = 1
        except ValueError:
            encoding[choices.index(fallback)] = 1
        return encoding

    def _get_node_features(self, mol):
        """Build the node feature matrix, one row per atom.

        Row layout (29 values): element one-hot (10), implicit valence,
        formal charge, radical electrons, aromatic flag, chirality one-hot
        (4), hybridization one-hot (8), degree, total H count, ring flag.

        Returns:
            Float tensor with one row per atom.
        """
        chiral_tags = (
            ChiralType.CHI_TETRAHEDRAL_CCW,
            ChiralType.CHI_TETRAHEDRAL_CW,
            ChiralType.CHI_OTHER,
            ChiralType.CHI_UNSPECIFIED,
        )
        hybridizations = (
            HybridizationType.S,
            HybridizationType.SP,
            HybridizationType.SP2,
            HybridizationType.SP3,
            HybridizationType.SP3D,
            HybridizationType.SP3D2,
            HybridizationType.OTHER,
            HybridizationType.UNSPECIFIED,
        )

        all_node_feats = []
        for atom in mol.GetAtoms():
            node_feats = self._one_hot(atom.GetSymbol(), self._ATOM_SYMBOLS, 'other')
            node_feats.append(atom.GetImplicitValence())
            node_feats.append(atom.GetFormalCharge())
            node_feats.append(atom.GetNumRadicalElectrons())
            node_feats.append(1 if atom.GetIsAromatic() else 0)
            # Values outside the listed ones go to the 'other' slot. The
            # previous chain of independent ifs left the vector unset (or
            # carrying the previous atom's value) in that case.
            node_feats.extend(
                self._one_hot(atom.GetChiralTag(), chiral_tags, ChiralType.CHI_OTHER)
            )
            node_feats.extend(
                self._one_hot(atom.GetHybridization(), hybridizations, HybridizationType.OTHER)
            )
            node_feats.append(atom.GetDegree())
            node_feats.append(atom.GetTotalNumHs())
            node_feats.append(1 if atom.IsInRing() else 0)
            all_node_feats.append(node_feats)

        return torch.tensor(np.asarray(all_node_feats), dtype=torch.float)

    def _get_edge_features(self, mol):
        """Build the edge feature matrix.

        Each bond contributes two identical rows (one per direction), in the
        same order as the columns produced by ``_get_adjacency_info``.
        Features per row: bond type as a double, ring membership (0/1).

        Returns:
            Float tensor of shape ``[number of directed edges, 2]``; this is
            ``[0, 2]`` for a molecule without bonds.
        """
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [bond.GetBondTypeAsDouble(), float(bond.IsInRing())]
            all_edge_feats += [edge_feats, edge_feats]

        # view(-1, 2) keeps the feature axis even when there are no bonds.
        return torch.tensor(all_edge_feats, dtype=torch.float).view(-1, 2)

    def _get_adjacency_info(self, mol):
        """Build the edge index from the molecule's bonds.

        Bonds are walked in the same order as in ``_get_edge_features`` and
        each one is emitted in both directions, so column ``k`` here
        corresponds to row ``k`` of the edge feature matrix.

        Returns:
            Long tensor of shape ``[2, number of directed edges]``.
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        pairs = torch.tensor(edge_indices, dtype=torch.long).view(-1, 2)
        return pairs.t().contiguous()

    def _get_labels(self, mol):
        """Build the per-atom label vector from the ``*_SOM_*`` SD properties.

        Every property value is read as whitespace-separated 1-based atom
        numbers; the union over all properties is marked with 1.

        Returns:
            int64 tensor with one entry per atom (1 = listed in any property).

        Raises:
            IndexError: if a listed atom number is outside ``1..num_atoms``.
        """
        labelled_atoms = set()
        for level in self._SOM_LEVELS:
            for isoform in self._SOM_ISOFORMS:
                prop = f'{level}_SOM_{isoform}'
                if not mol.HasProp(prop):
                    continue
                for token in mol.GetProp(prop).split():
                    try:
                        labelled_atoms.add(int(token))
                    except ValueError:
                        # Best effort, as before: non-integer tokens are
                        # ignored, but no longer discard the rest of the
                        # property or hide unrelated errors.
                        continue

        num_atoms = mol.GetNumAtoms()
        y = np.zeros(num_atoms)
        for atom_number in labelled_atoms:
            if not 1 <= atom_number <= num_atoms:
                # Without this check 0 would silently wrap around and label
                # the last atom.
                raise IndexError(
                    f'Labelled atom number {atom_number} is outside 1..{num_atoms}'
                )
            y[atom_number - 1] = 1
        return torch.tensor(y, dtype=torch.int64)

    # ------------------------------------------------------------------
    # Access
    # ------------------------------------------------------------------
    def len(self):
        """Number of records in the raw SD file."""
        if getattr(self, 'mols', None) is None:
            self._load_mols()
        return len(self.mols)

    def get(self, idx):
        """Load the processed ``Data`` object for molecule ``idx`` from disk."""
        return torch.load(os.path.join(self.processed_dir, self._processed_file_name(idx)))

In [3]:
# Both splits share one dataset root; the `test` flag only switches the
# processed file naming to `data_test_<i>.pt`.
DATASET_ROOT = '../Dataset/atom_top_based/'

train_dataset = AtomTopbased(root=DATASET_ROOT, filename='train.sdf')
test_dataset = AtomTopbased(root=DATASET_ROOT, filename='test.sdf', test=True)

Processing...
Done!
Processing...
Done!
