In [None]:
#default_exp lightning_helper

# lightning_helper

> Helper classes for training models with the PyTorch Lightning framework

In [None]:
#export
import recsys_slates_dataset.dataset_torch as dataset_torch
import recsys_slates_dataset.datahelper as datahelper
import pytorch_lightning as pl
import logging
class SlateDataModule(pl.LightningDataModule):
    """
    LightningDataModule that exposes the dataloaders built by
    ``dataset_torch.load_dataloaders``.

    Parameters
    ----------
    data_dir:
        Directory handed to ``datahelper.download_data_files`` and to
        ``dataset_torch.load_dataloaders``.
    batch_size, num_workers, sample_uniform_slate, valid_pct, test_pct, t_testsplit:
        Stored on the instance and forwarded unchanged to
        ``dataset_torch.load_dataloaders`` when ``setup`` runs.
    *args, **kwargs:
        Accepted but not used by this class.

    Attributes populated by ``setup``
    ---------------------------------
    ind2val, attributes, dataloaders:
        The three values returned by ``dataset_torch.load_dataloaders``.
        ``dataloaders`` is indexed with the keys "train", "valid" and "test".
    num_items:
        Largest value found in the training ``'slate'`` tensor, plus one.
    num_interactions, maxlen_slate:
        Second and third dimension of the training ``'slate'`` tensor.
    num_users:
        Largest value found in the training ``'userId'`` tensor, plus one.
    num_interaction_types:
        Number of entries in ``ind2val['interaction_type']``.
    """

    def __init__(
        self,
        data_dir="dat",
        batch_size=1024,
        num_workers=0,
        sample_uniform_slate=False,
        valid_pct=0.05,
        test_pct=0.05,
        t_testsplit=5,
        *args,
        **kwargs
    ):
        super().__init__()

        # Where the data files live on disk.
        self.data_dir = data_dir

        # Dataloader configuration.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.sample_uniform_slate = sample_uniform_slate

        # Split configuration, passed straight through to load_dataloaders.
        self.valid_pct = valid_pct
        self.test_pct = test_pct
        self.t_testsplit = t_testsplit

    def prepare_data(self):
        """Fetch the data files into ``self.data_dir`` via ``datahelper.download_data_files``."""
        datahelper.download_data_files(data_dir=self.data_dir)

    def setup(self, stage=None, num_negative_queries=0):
        """
        Build the dataloaders and cache descriptive statistics of the training data.

        Neither ``stage`` nor ``num_negative_queries`` is used in the body;
        the same loading is performed regardless of their values.
        """
        logging.info('Load data..')

        loaded = dataset_torch.load_dataloaders(
            data_dir=self.data_dir,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sample_uniform_slate=self.sample_uniform_slate,
            valid_pct=self.valid_pct,
            test_pct=self.test_pct,
            t_testsplit=self.t_testsplit,
        )
        self.ind2val, self.attributes, self.dataloaders = loaded

        # Add some descriptive stats to the dataset as variables for easy access later.
        train_data = self.train_dataloader().dataset.data
        slates = train_data['slate']

        # The "+ 1" turns the largest observed index into a count.
        self.num_items = slates.max().item() + 1
        # The slate tensor is unpacked as three dimensions; the first is discarded.
        _, self.num_interactions, self.maxlen_slate = slates.size()
        self.num_users = train_data['userId'].max().item() + 1
        self.num_interaction_types = len(self.ind2val['interaction_type'])

    def train_dataloader(self):
        """Return the dataloader stored under the "train" key."""
        return self.dataloaders["train"]

    def val_dataloader(self):
        """Return the dataloader stored under the "valid" key."""
        return self.dataloaders["valid"]

    def test_dataloader(self):
        """Return the dataloader stored under the "test" key."""
        return self.dataloaders["test"]

In [None]:
# slow
# Smoke test: fetch the data files, build the dataloaders, and check the first training batch.
dm = SlateDataModule()
dm.prepare_data()
dm.setup()

# Sum every entry of the first training batch's slate tensor and compare it with the known value.
first_batch = next(iter(dm.train_dataloader()))
checksum = first_batch['slate'].sum().item()
assert checksum == 98897096275, "Data error: Checksum of first batch is not expected value. Seed error?"

  and should_run_async(code)
2021-07-06 08:46:02,891 Downloading data.npz
2021-07-06 08:46:02,892 Downloading ind2val.json
2021-07-06 08:46:02,893 Downloading itemattr.npz
2021-07-06 08:46:02,893 Done downloading all files.
2021-07-06 08:46:02,894 Load data..
2021-07-06 08:46:02,894 Download data if not in data folder..
2021-07-06 08:46:02,895 Downloading data.npz
2021-07-06 08:46:02,895 Downloading ind2val.json
2021-07-06 08:46:02,896 Downloading itemattr.npz
2021-07-06 08:46:02,896 Done downloading all files.
2021-07-06 08:46:02,897 Load data..
2021-07-06 08:46:24,423 Loading dataset with slate size=torch.Size([2277645, 20, 25]) and uniform candidate sampling=False
2021-07-06 08:46:24,510 In train: num_users: 2277645, num_batches: 2225
2021-07-06 08:46:24,511 In valid: num_users: 113882, num_batches: 112
2021-07-06 08:46:24,511 In test: num_users: 113882, num_batches: 112
