In [1]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchdrug.datasets import ZINC250k, ZINC2m


In [14]:
from typing import Optional
import torch

In [15]:
# dataset = ZINC250k(path = 'ZINC250K', 
#                             lazy = True,
#                             transform = None,
#                             atom_feature = 'default',
#                             bond_feature = 'default',
#                             with_hydrogen = False,
#                             kekulize = False) #ZINC2M, MOSES
class ZINCDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "/data/ongh0068/", partial_dataset: bool = True, transforms: Optional[list] = None):
        super().__init__()
        self.data_dir = data_dir
        self.partial_dataset = partial_dataset
        self.transforms = transforms
    def prepare_data(self):
        # download
        if self.partial_dataset:
            ZINC250k(path = self.data_dir, 
                  lazy = True,
                  transform = None,
                  atom_feature = 'default',
                  bond_feature = 'default',
                  with_hydrogen = False,
                  kekulize = False) #ZINC2M, MOSES
        else:
            ZINC2m(path = self.data_dir, 
                  lazy = True,
                  transform = None,
                  atom_feature = 'default',
                  bond_feature = 'default',
                  with_hydrogen = False,
                  kekulize = False) #ZINC2M, MOSES
    def setup(self, stage: Optional[str] = None):

        
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage == "test" or stage is None:
            dataset = ZINC250k(self.data_dir, lazy= True,transform=self.transforms) if self.partial_dataset else ZINC2m(self.data_dir, transform=self.transforms)
            train_len, val_len = int(0.8*len(dataset)), int(0.1*len(dataset))
            test_len = len(dataset) - train_len - val_len

            self.train_set, self.valid_set, self.test_set = torch.utils.data.random_split(dataset, [train_len, val_len, test_len])

        if stage == "predict":
            self.predict_set = ZINC250k(self.data_dir) if self.partial_dataset else ZINC2m(self.data_dir)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.valid_set, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.predict_set, batch_size=32)

In [16]:
dataset = ZINCDataModule(data_dir = '/data/ongh0068/zinc250k')


Loading /data/ongh0068/zinc250k/250k_rndm_zinc_drugs_clean_3.csv:  50%|▌
Constructing molecules from SMILES: 100%|█| 249455/249455 [00:00<00:00, 


In [None]:
dataset.setup()

In [19]:
len(dataset.train_dataloader())

6237

In [23]:
dir(dataset.train_dataloader())

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '_IterableDataset_len_called',
 '__annotations__',
 '__class__',
 '__class_getitem__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__orig_bases__',
 '__parameters__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_get_iterator',
 '_get_shared_seed',
 '_index_sampler',
 '_is_protocol',
 '_iterator',
 'batch_sampler',
 'batch_size',
 'check_worker_number_rationality',
 'collate_fn',
 'dataset',
 'drop_last',
 'generator',
 'multiprocessing_context',
 'num_workers',
 'persistent_workers',
 'pin_memory',
 'pin_memory_device',
 'prefetch_factor',
 'sampler',
 'timeout',
 'worker_i

In [32]:
for i in iter(dataset.train_dataloader()):
    print(i)

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'torchdrug.data.molecule.Molecule'>