## Load libraries

In [120]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import torch
import mltools.dataset as dtools

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Build a dataset from random data

In [121]:
class RandomDataset(torch.utils.data.Dataset):
    """Map-style dataset of uniformly random (input, target) row pairs.

    Two tensors are drawn once with ``torch.rand`` at construction time:
    inputs of shape ``(N, n)`` and targets of shape ``(N, m)``. Indexing
    with ``idx`` returns row ``idx`` of each tensor.

    Args:
        n: Width of each input row.
        m: Width of each target row.
        N: Number of rows (the dataset length).
        device: Device on which both tensors are allocated.
        seed: Optional seed. When given, the data is drawn from a dedicated
            ``torch.Generator`` seeded with this value, so two datasets
            built with the same arguments hold identical data. When ``None``
            (the default) the global torch RNG is used.
    """

    def __init__(self, n: int, m: int, N: int, device: str = 'cpu', seed: "int | None" = None):
        super().__init__()
        self._n = n
        self._m = m
        self._N = N
        self.device = device
        self._seed = seed
        self.__make()

    def __make(self):
        """Draw the input and target tensors."""
        # A private generator keeps seeded construction independent of (and
        # without side effects on) the global RNG state; it must live on the
        # same device as the tensors it fills.
        generator = None
        if self._seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(self._seed)
        self._X = torch.rand([self._N, self._n], generator=generator, device=self.device)
        self._Y = torch.rand([self._N, self._m], generator=generator, device=self.device)

    def __len__(self):
        """Return the number of rows."""
        return self._N

    def __getitem__(self, idx: int):
        """Return the ``(input, target)`` rows stored at position ``idx``."""
        return self._X[idx, :], self._Y[idx, :]

    def __repr__(self):
        """Return a readable summary instead of the default memory address."""
        return (
            f"{type(self).__name__}(n={self._n}, m={self._m}, N={self._N}, "
            f"device={self.device!r}, seed={self._seed!r})"
        )


# 1000 random samples: 5 values per input row, 2 values per target row.
unbatched_dset = RandomDataset(n=5, m=2, N=1000)
# Show the dataset object's textual representation.
print(unbatched_dset)

<__main__.RandomDataset object at 0x0000017E54866F70>


## Split into subsets

In [122]:
# Batch size for the loaders built later, and a seeded generator so that the
# random splits below are repeatable from run to run.
batch_size = 32
seed = 42
generator = torch.Generator()
generator.manual_seed(seed)

# 80/20 split of the full dataset. The two parts are disjoint, so no leakage
# is expected between train and val.
train_set, val_set = torch.utils.data.random_split(
    dataset=unbatched_dset,
    lengths=[0.8, 0.2],
    generator=generator,
)

# Halve the val subset into two test sets. Both are drawn from val_set, so
# each one is expected to leak with val (but not with the other half).
test_set1, test_set2 = torch.utils.data.random_split(
    dataset=val_set,
    lengths=[0.5, 0.5],
    generator=generator,
)



## Create DataLoaders from subsets

In [123]:
# Only the training loader reshuffles its samples.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

# The evaluation loaders share one configuration: same batch size, fixed order.
val_loader, test_loader1, test_loader2 = (
    torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=False)
    for subset in (val_set, test_set1, test_set2)
)

## Check leakage

In [124]:
# Pass every loader, keyed by a readable label, to the leakage helper and
# print the value it returns.
print(
    dtools.torch_are_dataloaders_leaking(
        {
            "train": train_loader,
            "val": val_loader,
            "test1": test_loader1,
            "test2": test_loader2,
        }
    )
)

No data leakage detected between train and val.
No data leakage detected between train and test1.
No data leakage detected between train and test2.
Data leakage detected between val and test1: 100 common elements found.
Percentage of overlapping elements: 33.33%
Data leakage detected between val and test2: 100 common elements found.
Percentage of overlapping elements: 33.33%
No data leakage detected between test1 and test2.
True
