In [1]:
from mnist1d.data import get_dataset_args, make_dataset

# Create the MNIST1D dataset with mnist1d.data.make_dataset, using the
# package's default generation arguments.
defaults = get_dataset_args()
data = make_dataset(defaults)

# Pull the "x", "y" and "t" entries out of the returned mapping, in that order.
# NOTE(review): presumably signals, labels and the time axis — confirm against
# the mnist1d documentation.
x, y, t = (data[key] for key in ("x", "y", "t"))

In [16]:
import torch
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

from lightning_uq_box.datamodules.utils import collate_fn_tensordataset

In [11]:
import pickle
from urllib.request import urlopen

In [None]:
from lightning import LightningDataModule


class MNIST1DDataModule(LightningDataModule):
    """Lightning datamodule serving a synthetic multi-target regression dataset.

    NOTE(review): despite its name, this module does not currently serve
    MNIST1D data. Samples come from ``sklearn.datasets.make_regression`` and
    are split 64% / 16% / 20% into train / val / test. If MNIST1D is wanted,
    build it locally with ``mnist1d.data.make_dataset`` instead of unpickling a
    downloaded file -- ``pickle.load`` on remote data can execute arbitrary code.

    Args:
        batch_size: Number of samples per batch for every dataloader.
        num_workers: Worker processes passed to each ``DataLoader``.
        n_samples: Total number of generated samples before splitting.
        n_features: Number of input features per sample.
        n_targets: Number of regression targets per sample.
        noise: Standard deviation of the Gaussian noise ``make_regression``
            adds to the targets.
        random_state: Seed forwarded to data generation and to both splits.
            ``None`` (the default) yields a different dataset per instance.
    """

    def __init__(
        self,
        batch_size: int = 64,
        num_workers: int = 4,
        n_samples: int = 1000,
        n_features: int = 10,
        n_targets: int = 5,
        noise: float = 10.0,
        random_state=None,
    ) -> None:
        super().__init__()

        X_all, y_all = make_regression(
            n_samples=n_samples,
            n_features=n_features,
            n_targets=n_targets,
            noise=noise,
            random_state=random_state,
        )

        # Hold out 20% for testing, then 20% of the remainder for validation.
        self.X_train, self.X_test, self.Y_train, self.Y_test = train_test_split(
            X_all, y_all, test_size=0.2, random_state=random_state
        )
        # `y_val` keeps its original lowercase name so existing callers still work.
        self.X_train, self.X_val, self.Y_train, self.y_val = train_test_split(
            self.X_train, self.Y_train, test_size=0.2, random_state=random_state
        )

        self.batch_size = batch_size
        self.num_workers = num_workers

    @staticmethod
    def _to_tensor(array) -> torch.Tensor:
        """Convert an array to a float32 tensor; TensorDataset rejects ndarrays."""
        return torch.as_tensor(array, dtype=torch.float32)

    def setup(self, stage=None):
        """Wrap the pre-split arrays in ``TensorDataset``s for the given stage.

        Args:
            stage: Lightning stage name. ``"fit"`` and ``"validate"`` build the
                train and val datasets, ``"test"`` builds the test dataset,
                and ``None`` builds all three.
        """
        # "validate" must be handled too, otherwise trainer.validate() finds no
        # `self.val` when it calls val_dataloader().
        if stage in ("fit", "validate", None):
            self.train = TensorDataset(
                self._to_tensor(self.X_train), self._to_tensor(self.Y_train)
            )
            self.val = TensorDataset(
                self._to_tensor(self.X_val), self._to_tensor(self.y_val)
            )
        if stage in ("test", None):
            self.test = TensorDataset(
                self._to_tensor(self.X_test), self._to_tensor(self.Y_test)
            )

    def _dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader with the shared batch size, workers and collate_fn."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=collate_fn_tensordataset,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._dataloader(self.train, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (not shuffled)."""
        return self._dataloader(self.val)

    def test_dataloader(self):
        """Return the test dataloader (not shuffled)."""
        return self._dataloader(self.test)

In [15]:
# Sanity-check the shapes produced by the same make_regression arguments the
# datamodule uses by default.
X_all, y_all = make_regression(
    n_samples=1000,
    n_features=10,
    n_targets=5,
    noise=10,
)
(X_all.shape, y_all.shape)

((1000, 10), (1000, 5))