In [None]:
#default_exp pl_dataloaders
#export

import warnings
from torchmeta.datasets.helpers import omniglot, miniimagenet, ClassSplitter
from torchmeta.datasets import Omniglot
from torchmeta.utils.data import BatchMetaDataLoader

import pytorch_lightning as pl

In [None]:
#export
class OmniglotDataModule(pl.LightningDataModule):
    """LightningDataModule serving few-shot Omniglot tasks built with torchmeta.

    Every split is created with torchmeta's ``omniglot`` helper and wrapped in a
    ``BatchMetaDataLoader``, so each batch yielded by a loader contains
    ``batch_size`` tasks.

    Args:
        data_dir: Root folder passed to ``omniglot``.
        shots: ``shots`` forwarded to ``omniglot`` (examples per class in each
            task's support split).
        ways: Number of classes per task, forwarded to ``omniglot``.
        shuffle_ds: Forwarded to ``omniglot`` as its ``shuffle`` flag for every split.
        test_shots: ``test_shots`` forwarded to ``omniglot`` (query examples per class).
        meta_train: Forwarded as ``meta_train`` when building the training taskset.
        download: Forwarded to ``omniglot``; allows it to download the data.
        batch_size: Number of tasks per batch.
        shuffle: Whether the training loader shuffles tasks (val/test never do).
        num_workers: Worker processes used by every ``BatchMetaDataLoader``.
    """

    def __init__(
        self,
        data_dir: str,
        shots: int,
        ways: int,
        shuffle_ds: bool,
        test_shots: int,
        meta_train: bool,
        download: bool,
        batch_size: int,
        shuffle: bool,
        num_workers: int):
        super().__init__()
        self.data_dir = data_dir
        self.shots = shots
        self.ways = ways
        self.shuffle_ds = shuffle_ds
        self.test_shots = test_shots
        self.meta_train = meta_train
        self.download = download
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def _make_taskset(self, **split):
        """Build an Omniglot taskset for the split selected by ``split``.

        ``split`` is exactly one of ``meta_train=...``, ``meta_val=True`` or
        ``meta_test=True``; all other options come from the module's config.
        """
        return omniglot(
            self.data_dir,
            shots=self.shots,
            ways=self.ways,
            shuffle=self.shuffle_ds,
            test_shots=self.test_shots,
            download=self.download,
            **split
        )

    def _make_loader(self, taskset, shuffle=False):
        """Wrap ``taskset`` in a ``BatchMetaDataLoader`` using the module's batching config."""
        return BatchMetaDataLoader(
            taskset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers
        )

    def setup(self, stage=None):
        """Build the training taskset (``stage`` is accepted for Lightning but unused)."""
        self.task_dataset = self._make_taskset(meta_train=self.meta_train)

    def train_dataloader(self):
        """Return a loader over the training taskset, shuffled if ``shuffle`` is set."""
        return self._make_loader(self.task_dataset, shuffle=self.shuffle)

    def val_dataloader(self):
        """Build the meta-validation taskset and return an unshuffled loader over it."""
        self.val_tasks = self._make_taskset(meta_val=True)
        return self._make_loader(self.val_tasks)

    def test_dataloader(self):
        """Build the meta-test taskset and return an unshuffled loader over it."""
        self.test_tasks = self._make_taskset(meta_test=True)
        return self._make_loader(self.test_tasks)

In [None]:
# Sanity check: 5-way 1-shot Omniglot meta-train tasks with 15 query examples
# per class, stored under data/ (download=True allows fetching it).
ds = omniglot(
    'data/',
    ways=5,
    shots=1,
    test_shots=15,
    shuffle=True,
    meta_train=True,
    download=True,
)

In [None]:
# Group the tasks into batches of 16, loading with 4 worker processes.
dl = BatchMetaDataLoader(ds, batch_size=16, num_workers=4)

In [None]:
omniglot_batch = next(iter(dl))
# Query ("test") images of the first batch; expected 16 tasks x 75 images (5 ways x 15 test_shots) x 1 x 28 x 28.
omniglot_batch['test'][0].shape



torch.Size([16, 75, 1, 28, 28])

In [None]:
#export
class MiniImagenetDataModule(pl.LightningDataModule):
    """LightningDataModule serving few-shot Mini-ImageNet tasks built with torchmeta.

    Tasksets are created with torchmeta's ``miniimagenet`` helper and wrapped in
    ``BatchMetaDataLoader`` so each batch contains ``batch_size`` tasks.

    Args:
        data_dir: Root folder passed to ``miniimagenet``.
        shots: ``shots`` forwarded to ``miniimagenet`` (examples per class in
            each task's support split).
        ways: Number of classes per task, forwarded to ``miniimagenet``.
        shuffle_ds: ``shuffle`` flag forwarded to ``miniimagenet`` for the
            train and validation tasksets (the test taskset is never shuffled).
        test_shots: ``test_shots`` forwarded to ``miniimagenet`` (query
            examples per class).
        meta_train: Stored on the instance; the training taskset here always
            uses the meta-train split. NOTE(review): confirm whether this flag
            should be honoured instead.
        download: Forwarded to ``miniimagenet``; allows it to download the data.
        batch_size: Number of tasks per batch.
        shuffle: Whether the training loader shuffles tasks.
        num_workers: Worker processes used by every ``BatchMetaDataLoader``.
    """

    def __init__(self,
                 data_dir: str,
                 shots: int,
                 ways: int,
                 shuffle_ds: bool,
                 test_shots: int,
                 meta_train: bool,
                 download: bool,
                 batch_size: int,
                 shuffle: bool,
                 num_workers: int):
        # The LightningDataModule base constructor must run before we add our
        # own attributes; it was previously skipped.
        super().__init__()
        self.data_dir = data_dir
        self.shots = shots
        self.ways = ways
        self.shuffle_ds = shuffle_ds
        self.test_shots = test_shots
        self.meta_train = meta_train
        self.download = download
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def _make_taskset(self, shuffle_ds, **split):
        """Build a Mini-ImageNet taskset for ``split`` (``meta_train``/``meta_val``/``meta_test``)."""
        return miniimagenet(
            self.data_dir,
            shots=self.shots,
            ways=self.ways,
            shuffle=shuffle_ds,
            test_shots=self.test_shots,
            download=self.download,
            **split
        )

    def _make_loader(self, taskset, shuffle):
        """Wrap ``taskset`` in a ``BatchMetaDataLoader`` using the module's batching config."""
        return BatchMetaDataLoader(
            taskset,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )

    def setup(self, stage=None):
        """Build the meta-train taskset.

        Lightning calls ``setup`` with a ``stage`` argument, so it must be
        accepted even though it is unused here.
        """
        self.train_taskset = self._make_taskset(self.shuffle_ds, meta_train=True)

    def train_dataloader(self):
        """Return a loader over the meta-train taskset, shuffled if ``shuffle`` is set."""
        return self._make_loader(self.train_taskset, shuffle=self.shuffle)

    def val_dataloader(self):
        """Build the meta-validation taskset and return an unshuffled loader over it."""
        self.val_taskset = self._make_taskset(self.shuffle_ds, meta_val=True)
        # Keep validation deterministic, matching the test loader; Lightning
        # recommends turning shuffling off for val/test dataloaders.
        return self._make_loader(self.val_taskset, shuffle=False)

    def test_dataloader(self):
        """Build the meta-test taskset (never shuffled) and return an unshuffled loader."""
        self.test_taskset = self._make_taskset(False, meta_test=True)
        return self._make_loader(self.test_taskset, shuffle=False)

In [None]:
# Same 5-way 1-shot Omniglot configuration, but with dataset shuffling disabled.
# NOTE(review): this cell follows MiniImagenetDataModule yet builds an Omniglot
# taskset; confirm whether miniimagenet was intended here.
ds = omniglot(
    'data/',
    ways=5,
    shots=1,
    test_shots=15,
    shuffle=False,
    meta_train=True,
    download=True,
)

In [None]:
# 16 tasks per batch; every other loader option is left at its default.
dl = BatchMetaDataLoader(
    ds,
    batch_size=16,
)

In [None]:
first_batch = next(iter(dl))
# Support ("train") labels: one row of class indices per task in the batch.
first_batch['train'][1]



tensor([[0, 3, 1, 4, 2],
        [4, 2, 0, 1, 3],
        [4, 0, 3, 2, 1],
        [1, 4, 3, 2, 0],
        [4, 0, 2, 3, 1],
        [0, 2, 1, 3, 4],
        [2, 1, 0, 4, 3],
        [4, 2, 1, 3, 0],
        [0, 3, 2, 1, 4],
        [0, 1, 4, 3, 2],
        [0, 1, 2, 3, 4],
        [2, 4, 1, 0, 3],
        [4, 0, 3, 1, 2],
        [2, 4, 0, 3, 1],
        [0, 3, 1, 4, 2],
        [3, 1, 2, 0, 4]])

In [None]:
# Export the #export-tagged cells of the project's notebooks into the library.
from nbdev.export import notebook2script

notebook2script()

Converted 01_nn_utils.ipynb.
Converted 01b_data_loaders_pl.ipynb.
Converted 01c_grad_utils.ipynb.
Converted 01d_hessian_free.ipynb.
Converted 02_maml_pl.ipynb.
Converted 02b_iMAML.ipynb.
Converted 03_protonet_pl.ipynb.
Converted 04_cactus.ipynb.
Converted index.ipynb.
