In [2]:
from torch.utils.data import DataLoader
from torchdrug.datasets import ZINC250k, ZINC2m
import pytorch_lightning as pl

In [None]:
from typing import Callable, Optional

In [None]:
# Build the ZINC250k molecule dataset under the relative directory 'ZINC250K'
# with lazy loading, default atom/bond features, no explicit hydrogens and
# no kekulization. (Per the original note, ZINC2M or MOSES could be
# substituted here.)
dataset = ZINC250k(
    'ZINC250K',
    transform=None,
    lazy=True,
    kekulize=False,
    with_hydrogen=False,
    atom_feature='default',
    bond_feature='default',
)
class ZINCDataModule(pl.LightningDataModule):
    """Lightning data module over the torchdrug ZINC250k / ZINC2m datasets.

    For the "fit"/"test" stages the dataset is split 80/10/10 into
    ``train_set`` / ``valid_set`` / ``test_set`` (the test subset takes the
    remainder left by integer truncation). For the "predict" stage the whole
    dataset is exposed as ``predict_set``.
    """

    def __init__(
        self,
        data_dir: str = "/data/ongh0068/",
        partial_dataset: bool = True,
        transforms: Optional[Callable] = None,
        batch_size: int = 32,
        seed: int = 42,
        num_workers: int = 0,
    ):
        """
        Args:
            data_dir: Directory the dataset is downloaded to and loaded from.
            partial_dataset: Use ZINC250k when True, ZINC2m when False.
            transforms: Forwarded as the torchdrug dataset's ``transform``
                argument. NOTE(review): torchdrug presumably calls this on
                each sample, so it should be a single callable (e.g. a
                composed transform), not a list -- confirm against callers.
            batch_size: Batch size used by every dataloader.
            seed: Seed for the train/valid/test split, so the partition is
                reproducible across runs.
            num_workers: Worker processes used by every dataloader.
        """
        super().__init__()
        self.data_dir = data_dir
        self.partial_dataset = partial_dataset
        self.transforms = transforms
        self.batch_size = batch_size
        self.seed = seed
        self.num_workers = num_workers
        # Populated lazily by ``setup``; None means "not built yet".
        self.train_set = None
        self.valid_set = None
        self.test_set = None
        self.predict_set = None

    def _dataset_cls(self):
        """Return the torchdrug dataset class selected by ``partial_dataset``."""
        return ZINC250k if self.partial_dataset else ZINC2m

    def prepare_data(self):
        """Download the selected dataset into ``data_dir``.

        The dataset object is discarded; constructing it is only done for its
        download side effect. It must target ``data_dir`` so that ``setup``,
        which loads from ``data_dir``, finds the downloaded files.
        """
        self._dataset_cls()(
            self.data_dir,
            lazy=True,  # presumably defers molecule parsing; only files are needed
            transform=None,
            atom_feature='default',
            bond_feature='default',
            with_hydrogen=False,
            kekulize=False,
        )

    def setup(self, stage: Optional[str] = None):
        """Load the dataset and build the subsets needed for ``stage``.

        Args:
            stage: "fit", "test", "predict" or None (None behaves like
                "fit"/"test").
        """
        # Only ``DataLoader`` is imported at module level, not ``torch`` itself.
        import torch
        from torch.utils.data import random_split

        # Build the split once and reuse it: re-splitting on a later "test"
        # call would re-draw the partition and leak training samples into
        # the test subset. The seeded generator also makes it reproducible.
        if stage in ("fit", "test", None) and self.train_set is None:
            dataset = self._dataset_cls()(self.data_dir, transform=self.transforms)
            train_len = int(0.8 * len(dataset))
            val_len = int(0.1 * len(dataset))
            test_len = len(dataset) - train_len - val_len
            generator = torch.Generator().manual_seed(self.seed)
            self.train_set, self.valid_set, self.test_set = random_split(
                dataset, [train_len, val_len, test_len], generator=generator
            )

        if stage == "predict" and self.predict_set is None:
            # Same transform as the other stages so predictions see the
            # same preprocessing as training.
            self.predict_set = self._dataset_cls()(
                self.data_dir, transform=self.transforms
            )

    def _loader(self, dataset, shuffle: bool = False):
        """Wrap ``dataset`` in a DataLoader that batches torchdrug samples."""
        # NOTE(review): torchdrug samples hold graph objects that torch's
        # default collate cannot batch; ``graph_collate`` is torchdrug's
        # collate function for them -- confirm against the installed version.
        from torchdrug.data import graph_collate

        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=graph_collate,
        )

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return self._loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation subset, in fixed order."""
        return self._loader(self.valid_set)

    def test_dataloader(self):
        """Loader over the test subset, in fixed order."""
        return self._loader(self.test_set)

    def predict_dataloader(self):
        """Loader over the full prediction dataset, in fixed order."""
        return self._loader(self.predict_set)