In [1]:
# Convert SMILES strings to DGL bigraphs and use the correction energies as labels;
# save and load the graphs (input X) and energies (labels y).

In [3]:
import torch
from rdkit import Chem
from dgllife.utils import smiles_to_bigraph
from dgl.data.utils import load_graphs
from dgl.data.utils import save_graphs
from dgl.dataloading import GraphDataLoader as gdl
import pandas as pd
import numpy as np
import os
from torch.utils.data import Dataset
from dgl.data import DGLDataset

Using backend: pytorch


In [4]:
class MyDataset(DGLDataset):
    """DGL dataset of molecular bigraphs built from a CSV of SMILES strings.

    Each CSV row supplies one example: the ``'Smiles'`` column is converted to
    a DGL bigraph (with explicit hydrogens, atomic-number node features and
    bond-type edge features) and the ``'CorrectionEnergy'`` column is the label.

    Parameters
    ----------
    df : str
        Path to the CSV file (despite the name, a path, not a DataFrame). Its
        basename is used as the dataset name and as the prefix of the saved
        ``<name>_dgl_graph.bin`` file.
    url, raw_dir, save_dir, force_reload, verbose
        Forwarded to ``dgl.data.DGLDataset``. ``save_dir`` additionally chooses
        the directory that ``save()``/``load()`` use for the ``.bin`` file;
        when it is ``None`` the file lives in the current working directory.
    """

    # Bond types in encoding order for the 'type' edge feature:
    # SINGLE -> 0, DOUBLE -> 1, TRIPLE -> 2, AROMATIC -> 3.
    _BOND_TYPES = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                   Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]

    def __init__(self,
                 df: str,
                 url=None,
                 raw_dir=None,
                 save_dir=None,
                 force_reload=False,
                 verbose=False):

        self.df = pd.read_csv(df)

        self._name = os.path.basename(df)
        # Keep the caller's save_dir (None = current directory) before
        # DGLDataset substitutes its own default; save()/load() use it.
        self._graph_dir = save_dir

        self.bigraph = []  # X: one graph per CSV row
        self.labels = []   # y: one energy per CSV row

        # Per the DGLDataset contract, this runs process() and save()
        # (or load() when has_cache() reports a usable cache).
        super(MyDataset, self).__init__(name=os.path.basename(df),
                                        url=url,
                                        raw_dir=raw_dir,
                                        save_dir=save_dir,
                                        force_reload=force_reload,
                                        verbose=verbose)

    def featurize_atoms(self, mol):
        """Return node features: each atom's atomic number as a float column of shape (n_atoms, 1)."""
        feats = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
        return {'atomic': torch.tensor(feats).reshape(-1, 1).float()}

    def featurize_bonds(self, mol):
        """Return edge features: each bond's type index (see ``_BOND_TYPES``) as a float column.

        Every bond contributes two identical rows, one per directed edge of
        the bigraph.

        Raises
        ------
        ValueError
            If a bond type is not in ``_BOND_TYPES``.
        """
        feats = []
        for bond in mol.GetBonds():
            bond_type = bond.GetBondType()
            try:
                btype = self._BOND_TYPES.index(bond_type)
            except ValueError:
                raise ValueError(
                    'Unsupported bond type {}; expected one of {}'.format(
                        bond_type, self._BOND_TYPES)) from None
            feats.extend([btype, btype])
        return {'type': torch.tensor(feats).reshape(-1, 1).float()}

    def process(self):
        """Build one bigraph per SMILES string and collect the matching energies.

        Raises
        ------
        ValueError
            If any SMILES string cannot be converted. dgllife yields ``None``
            for these, which would otherwise only fail later inside
            ``save_graphs``.
        """
        graphs = [
            smiles_to_bigraph(str(smiles),
                              node_featurizer=self.featurize_atoms,
                              edge_featurizer=self.featurize_bonds,
                              explicit_hydrogens=True)
            for smiles in self.df['Smiles']
        ]
        bad_rows = [row for row, graph in enumerate(graphs) if graph is None]
        if bad_rows:
            raise ValueError(
                'Could not build graphs for {} SMILES string(s) at row(s) {}'.format(
                    len(bad_rows), bad_rows))
        self.bigraph = graphs
        self.labels = list(self.df['CorrectionEnergy'])

    def __getitem__(self, idx):
        """Return the ``(graph, label)`` pair at position ``idx``."""
        return self.bigraph[idx], self.labels[idx]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.bigraph)

    def _graph_path(self):
        """Return the path of the ``<name>_dgl_graph.bin`` file used by save()/load()."""
        filename = self._name + '_dgl_graph.bin'
        graph_dir = getattr(self, '_graph_dir', None)
        if graph_dir is None:
            return filename
        return os.path.join(graph_dir, filename)

    def save(self):
        """Write the graphs and their labels to ``_graph_path()``."""
        graph_path = self._graph_path()
        directory = os.path.dirname(graph_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        save_graphs(graph_path, self.bigraph, {'labels': torch.tensor(self.labels)})

    def load(self):
        """Load graphs and labels from ``_graph_path()`` into this dataset.

        Returns
        -------
        tuple
            ``(graphs, labels)`` as read from disk (labels as a tensor). They
            are also stored on ``self.bigraph`` / ``self.labels``.
        """
        graphs, label_dict = load_graphs(self._graph_path())
        labels = label_dict['labels']
        # DGLDataset calls load() only for its side effects, so the dataset's
        # own attributes must be populated here. Labels are kept as a list,
        # matching what process() produces.
        self.bigraph = graphs
        self.labels = labels.tolist()
        return graphs, labels


In [1]:
def concatenate(data_names=('B', 'C', 'D', 'H'),
                data_dir='../../Data/Energy',
                output_name='CorrectionDataset-HW8.csv'):
    """Stack the per-source correction-energy CSVs into one combined CSV.

    Reads ``<data_dir>/CorrectionDataset-<name>.csv`` for every name in
    ``data_names``, concatenates the rows in that order with a fresh
    0..N-1 index, and writes the result to ``<data_dir>/<output_name>``.

    Parameters
    ----------
    data_names : iterable of str
        Suffixes of the source files to combine.
    data_dir : str
        Directory holding the source CSVs; the combined CSV is written there too.
    output_name : str
        File name of the combined CSV.

    Returns
    -------
    pandas.DataFrame
        The combined table that was written to disk.

    Raises
    ------
    ValueError
        If ``data_names`` is empty.
    """
    frames = [
        pd.read_csv(os.path.join(data_dir, 'CorrectionDataset-' + name + '.csv'))
        for name in data_names
    ]
    if not frames:
        raise ValueError('concatenate() needs at least one dataset name')

    combined = pd.concat(frames, axis=0, ignore_index=True)
    # index=False: the RangeIndex from ignore_index=True carries no information,
    # and writing it would add an 'Unnamed: 0' column when the CSV is read back.
    combined.to_csv(os.path.join(data_dir, output_name), index=False)
    return combined

In [4]:
# Set to True to rebuild the combined CorrectionDataset-HW8.csv from the
# per-source CSVs; left off so Run All does not overwrite it.
REBUILD_COMBINED_CSV = False
if REBUILD_COMBINED_CSV:
    concatenate()

In [5]:
m = MyDataset(df='../../Data/Energy/FinalCorrectionDataset.csv')  # CSV with 'Smiles' and 'CorrectionEnergy' columns

In [6]:
# DGLDataset.__init__ already ran process() (and save()) when `m` was built, so
# calling it again here only rebuilt identical graphs. Sanity-check instead.
assert len(m) == len(m.df), 'every CSV row should yield exactly one graph'

In [8]:
dataloader = gdl(dataset=m, batch_size=32)  # batches of 32 (graph, label) examples

In [9]:
print("done")  # completion marker: every cell above ran without error

done
