In [None]:
# Convert SMILES strings to DGL bigraphs and use the molecule energies as labels.
# Save and reload the graphs (inputs, X) together with the energies (labels, y).

In [7]:
import torch
from rdkit import Chem
from dgllife.utils import smiles_to_bigraph
# from dgllife.utils import mol_to_bigraph
from dgl.data.utils import load_graphs
from dgl.data.utils import save_graphs
from dgl.dataloading import GraphDataLoader
import pandas as pd
import numpy as np
import os

In [27]:
class MyDataset():
    """Molecular energy dataset: one DGL bigraph per SMILES string, labelled by energy.

    The CSV must contain a ``Smiles`` column (molecule structures) and an
    ``Energy`` column (regression targets). Call :meth:`process` to build the
    graphs before indexing, batching or saving the dataset.

    Attributes:
        df: DataFrame read from the CSV.
        name: Base file name of the CSV (including its extension); used to
            name the saved graph file.
        graphs: DGL graphs, one per successfully parsed SMILES string (X).
        label: 1-D NumPy array of energies aligned with ``graphs`` (y).
        skipped: Positional row indices whose SMILES could not be turned
            into a graph and were therefore left out.
    """

    def __init__(self, df: str):
        """Read the dataset CSV.

        Args:
            df: Path to the CSV file (despite the name, not a DataFrame).
        """
        self.df = pd.read_csv(df)
        self.name = os.path.basename(df)

        self.graphs = []  # X
        # A genuinely empty 1-D array; ``np.ndarray([])`` allocated an
        # uninitialised 0-d array instead.
        self.label = np.empty(0)  # y
        self.skipped = []

    def featurize_atoms(self, mol):
        """Node featurizer: the atomic number of every atom.

        Args:
            mol: RDKit molecule, as handed over by ``smiles_to_bigraph``.

        Returns:
            dict mapping ``'atomic'`` to a float32 tensor of shape (num_atoms, 1).
        """
        atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
        return {'atomic': torch.tensor(atomic_numbers).reshape(-1, 1).float()}

    def featurize_bonds(self, mol):
        """Edge featurizer: an integer bond-type code for every directed edge.

        Codes: 0 = single, 1 = double, 2 = triple, 3 = aromatic. Each bond's
        code is emitted twice, once for each direction the bigraph stores the
        bond in.

        Args:
            mol: RDKit molecule, as handed over by ``smiles_to_bigraph``.

        Returns:
            dict mapping ``'type'`` to a float32 tensor of shape (2 * num_bonds, 1).

        Raises:
            ValueError: if the molecule contains any other bond type.
        """
        bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                      Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
        feats = []
        for bond in mol.GetBonds():
            bond_type = bond.GetBondType()
            if bond_type not in bond_types:
                raise ValueError(
                    f'Unsupported bond type {bond_type}; expected one of {bond_types}')
            code = bond_types.index(bond_type)
            feats.extend([code, code])
        return {'type': torch.tensor(feats).reshape(-1, 1).float()}

    def process(self):
        """Build one bigraph per SMILES string and collect the matching energies.

        ``smiles_to_bigraph`` returns ``None`` for SMILES that cannot be parsed.
        Such rows are dropped from both ``graphs`` and ``label`` so the two stay
        aligned. Their positional indices are recorded in ``skipped`` and a
        warning is issued. Keeping the ``None`` entries would break saving and
        batching later.
        """
        import warnings

        graphs, labels, skipped = [], [], []
        for row, (smiles, energy) in enumerate(zip(self.df['Smiles'], self.df['Energy'])):
            graph = smiles_to_bigraph(str(smiles),
                                      node_featurizer=self.featurize_atoms,
                                      edge_featurizer=self.featurize_bonds,
                                      explicit_hydrogens=True)
            if graph is None:
                skipped.append(row)
                continue
            graphs.append(graph)
            labels.append(energy)

        if skipped:
            warnings.warn(
                f'{len(skipped)} of {len(self.df)} SMILES strings in {self.name} '
                f'could not be converted and were skipped (rows {skipped})')

        self.graphs = graphs
        self.label = np.array(labels)
        self.skipped = skipped

    def __getitem__(self, idx):
        """Return the ``(graph, energy)`` pair at position ``idx``."""
        return self.graphs[idx], self.label[idx]

    def __len__(self):
        """Number of graphs in the dataset (0 until :meth:`process` has run)."""
        return len(self.graphs)

    def save(self, path):
        """Save graphs and labels to ``<path>/<name>_dgl_graph.bin``.

        The output folder is created if it does not exist yet.

        Args:
            path: Folder to write the graph file into.

        Returns:
            The path of the written file.
        """
        os.makedirs(path, exist_ok=True)
        graph_path = os.path.join(path, self.name + '_dgl_graph.bin')
        save_graphs(graph_path, self.graphs, {'labels': torch.tensor(self.label)})
        return graph_path
        

In [28]:
# Build the graph dataset from the energy CSV and write it to a DGL binary file.
ENERGY_DIR = '../../Data/Energy'  # holds the source CSV and the graphs/ output folder
m = MyDataset(f'{ENERGY_DIR}/EnergyDataset-B-Copy1.csv')
m.process()
m.save(f'{ENERGY_DIR}/graphs')

In [29]:
# Mini-batch loader over the processed dataset; shuffle=True draws samples in
# random order (no seed is set here, so the order differs between runs).
BATCH_SIZE = 32
graphdataloader = GraphDataLoader(m, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Marker that the cells above have finished running.
print('done', end='\n')

In [32]:
def load(name, path):
    """Load the graphs and energy labels written by ``MyDataset.save``.

    Args:
        name: Base name the dataset was saved under, i.e. the CSV file name
            including its extension (e.g. ``"EnergyDataset-B-Copy1.csv"``).
        path: Folder containing the ``<name>_dgl_graph.bin`` file.

    Returns:
        ``(graphs, labels)``: the list of graphs (X) and the stored label
        tensor (y).

    Raises:
        FileNotFoundError: if the expected graph file does not exist.
    """
    graph_path = os.path.join(path, name + '_dgl_graph.bin')
    # Fail with a clear message instead of whatever load_graphs raises on a missing file.
    if not os.path.isfile(graph_path):
        raise FileNotFoundError(f'No saved graph file found at {graph_path}')
    graphs, label_dict = load_graphs(graph_path)
    labels = label_dict['labels']
    return graphs, labels

In [5]:
#X, y = load("EnergyDataset-B-Copy1.csv","../../Data/Energy/graphs") # name = CSV file name (including ".csv"), path = folder the graph file was saved to

In [6]:
# X

In [7]:
# y