# Setup

In [1]:
#default_exp pl_dataloaders
#export
import warnings
import h5py
import os
import io
import numpy as np
import json
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor
from torchvision.utils import make_grid
from torchvision import transforms

from torchmeta.datasets.helpers import omniglot, miniimagenet, ClassSplitter
from torchmeta.datasets import Omniglot
from torchmeta.utils.data import BatchMetaDataLoader

  rank_zero_deprecation(


# Custom unlabelled dataset

In [2]:
#export
class UnlabelledDataset(Dataset):
    """Unlabelled image dataset that returns several views of each image.

    Every item is a dict with a single key ``data`` holding ``n_support``
    support views followed by ``n_query`` query views of the same image,
    stacked along a new leading dimension.
    """

    def __init__(self, dataset, datapath, split, transform=None,
                 n_support=1, n_query=1, n_images=None, n_classes=None,
                 seed=10, no_aug_support=False, no_aug_query=False):
        """
        Args:
            dataset (string): Dataset name.
            datapath (string): Directory containing the datasets.
            split (string): The dataset split to load.
            transform (callable, optional): Optional transform to be applied
                on a sample. When omitted, a dataset-specific augmentation
                pipeline is used.
            n_support (int): Number of support examples
            n_query (int): Number of query examples
            no_aug_support (bool): Whether to not apply any augmentations to the support
            no_aug_query (bool): Whether to not apply any augmentations to the query
            n_images (int): Limit the number of images to load.
            n_classes (int): Limit the number of classes to load.
            seed (int): Random seed for selecting the images/classes to load.

        Raises:
            ValueError: If augmentation is disabled for the support (query)
                views while more than one support (query) view is requested.
        """
        # Un-augmented views of one image are identical, so asking for more
        # than one is rejected up front (before any data is read) instead of
        # through an `assert` on every item access.
        if no_aug_support and n_support > 1:
            raise ValueError('no_aug_support=True requires n_support == 1')
        if no_aug_query and n_query > 1:
            raise ValueError('no_aug_query=True requires n_query == 1')

        self.n_support = n_support
        self.n_query = n_query
        self.img_size = (28, 28) if dataset == 'omniglot' else (84, 84)
        self.no_aug_support = no_aug_support
        self.no_aug_query = no_aug_query

        # Get the data or paths
        self.dataset = dataset
        self.data = self._extract_data_from_hdf5(dataset, datapath, split,
                                                 n_classes, seed)

        # Optionally only load a subset of images
        if n_images is not None:
            random_idxs = np.random.RandomState(seed).permutation(len(self))[:n_images]
            self.data = self.data[random_idxs]

        # The un-augmented transform is always defined, so the no_aug_* paths
        # also work when the caller supplies a custom `transform`.
        if self.dataset == 'cub':
            self.original_transform = transforms.Compose([
                get_cub_default_transform(self.img_size),
                transforms.ToTensor()])
        else:
            self.original_transform = identity_transform(self.img_size)

        # Augmenting transform: the caller's, or a dataset-specific default.
        if transform is not None:
            self.transform = transform
        elif self.dataset == 'cub':
            self.transform = transforms.Compose([
                get_cub_default_transform(self.img_size),
                get_custom_transform(self.img_size)])
        elif self.dataset == 'omniglot':
            self.transform = get_omniglot_transform(self.img_size)
        else:
            self.transform = get_custom_transform(self.img_size)

    def _extract_data_from_hdf5(self, dataset, datapath, split,
                                n_classes, seed):
        """Read the per-class arrays from HDF5 and concatenate them.

        For ``'omniglot'`` the classes listed in
        ``vinyals_<split>_labels.json`` are read from ``data.hdf5``; for any
        other dataset every entry of the ``datasets`` group in
        ``<split>_data.hdf5`` is read. If ``n_classes`` is given, a seeded
        random subset of that many classes is kept.
        """
        datapath = os.path.join(datapath, dataset)

        # Load omniglot
        if dataset == 'omniglot':
            labels_path = os.path.join(
                datapath, 'vinyals_{}_labels.json'.format(split))
            classes = []
            with h5py.File(os.path.join(datapath, 'data.hdf5'), 'r') as f_data, \
                    open(labels_path) as f_labels:
                for img_set, alphabet, character in json.load(f_labels):
                    # `[()]` reads the whole HDF5 dataset into memory.
                    classes.append(f_data[img_set][alphabet][character][()])
        # Load mini-imageNet
        else:
            with h5py.File(os.path.join(datapath, split + '_data.hdf5'), 'r') as f:
                datasets = f['datasets']
                classes = [datasets[k][()] for k in datasets.keys()]

        # Optionally filter out some classes
        if n_classes is not None:
            random_idxs = np.random.RandomState(seed).permutation(len(classes))[:n_classes]
            classes = [classes[i] for i in random_idxs]

        # Collect in single array
        return np.concatenate(classes)

    def __len__(self):
        """Return the number of loaded images."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``{'data': views}`` for the image at ``index``.

        ``views`` stacks the support views first and the query views after
        them; each view is whatever tensor the applicable transform returns.
        """
        if self.dataset == 'cub':
            # CUB entries are decoded from an in-memory byte buffer.
            image = Image.open(io.BytesIO(self.data[index])).convert('RGB')
        else:
            image = Image.fromarray(self.data[index])

        support_tf = self.original_transform if self.no_aug_support else self.transform
        query_tf = self.original_transform if self.no_aug_query else self.transform

        view_list = [support_tf(image) for _ in range(self.n_support)]
        view_list.extend(query_tf(image) for _ in range(self.n_query))

        return dict(data=torch.stack(view_list))

In [3]:
#export
def get_cub_default_transform(size):
    """Build the deterministic CUB preprocessing: enlarge, then centre-crop.

    The image is resized to 1.5x the requested height and width (truncated
    to integers) and the central ``size`` region is kept.
    """
    enlarged = [int(size[0] * 1.5), int(size[1] * 1.5)]
    steps = [transforms.Resize(enlarged), transforms.CenterCrop(size)]
    return transforms.Compose(steps)

def get_simCLR_transform(img_shape):
    """Adapted from https://github.com/sthalles/SimCLR/blob/master/data_aug/dataset_wrapper.py

    Random resized crop to the last two entries of ``img_shape``, random
    horizontal flip, colour jitter (applied with p=0.8), random grayscale
    (p=0.2) and conversion to a tensor. No Gaussian blur step is included.
    """
    jitter = transforms.ColorJitter(brightness=0.8, contrast=0.8,
                                    saturation=0.8, hue=0.2)
    steps = [
        transforms.RandomResizedCrop(size=img_shape[-2:]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def get_omniglot_transform(img_shape):
    """Build the augmentation pipeline used for Omniglot images.

    Resize and random-resized-crop to the last two entries of ``img_shape``,
    random horizontal and vertical flips, conversion to a tensor, dropout
    with p=0.3 on the tensor values, and random erasing.
    """
    target = img_shape[-2:]
    steps = [
        transforms.Resize(target),
        transforms.RandomResizedCrop(size=target, scale=(0.6, 1.4)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        # Dropout is applied to the image tensor itself as an augmentation.
        transforms.Lambda(lambda t: F.dropout(t, p=0.3)),
        transforms.RandomErasing(),
    ]
    return transforms.Compose(steps)

def get_custom_transform(img_shape):
    """Build the default augmentation pipeline for non-Omniglot images.

    Random resized crop (area scale 0.5-1.0) to the last two entries of
    ``img_shape``, random horizontal and vertical flips, colour jitter
    (applied with p=0.8), random grayscale (p=0.2) and conversion to a tensor.
    """
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.1)
    steps = [
        transforms.RandomResizedCrop(size=img_shape[-2:], scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def identity_transform(img_shape):
    """Build the augmentation-free pipeline: resize to ``img_shape``, then to tensor."""
    steps = [transforms.Resize(img_shape), transforms.ToTensor()]
    return transforms.Compose(steps)

In [None]:
#export

class UnlabelledDataModule(pl.LightningDataModule):
    """Lightning data module serving `UnlabelledDataset` train/val loaders."""

    def __init__(self, dataset, datapath, split, transform=None,
                 n_support=1, n_query=1, n_images=None, n_classes=None,
                 seed=10, no_aug_support=False, no_aug_query=False, merge_train_val=False,
                 batch_size=1, num_workers=0):
        """
        Args:
            dataset (string): Dataset name.
            datapath (string): Directory containing the datasets.
            split (string): Stored for interface compatibility only; the
                loaders always read the 'train' and 'val' splits.
            transform (callable, optional): Transform forwarded to every
                `UnlabelledDataset`; ``None`` selects the dataset default.
            n_support (int): Number of support views per image.
            n_query (int): Number of query views per image.
            n_images (int): Limit on the number of training images.
            n_classes (int): Limit on the number of training classes.
            seed (int): Random seed for the image/class sub-selection.
            no_aug_support (bool): Whether to not augment the support views.
            no_aug_query (bool): Whether to not augment the query views.
            merge_train_val (bool): If true, `setup` trains on the
                concatenation of the 'train' and 'val' splits.
            batch_size (int): Batch size of both loaders (defaults to the
                `DataLoader` default).
            num_workers (int): Worker processes of both loaders (defaults to
                the `DataLoader` default).
        """
        super().__init__()
        self.split = split
        self.transform = transform
        self.n_images = n_images
        self.n_classes = n_classes
        self.seed = seed
        self.n_support = n_support
        self.n_query = n_query
        self.img_size = (28, 28) if dataset == 'omniglot' else (84, 84)
        self.no_aug_support = no_aug_support
        self.no_aug_query = no_aug_query

        # Get the data or paths
        self.dataset = dataset
        self.datapath = datapath

        self.merge_train_val = merge_train_val
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_dataset(self, split, n_images=None, n_classes=None):
        """Build an `UnlabelledDataset` for ``split`` with this module's settings."""
        return UnlabelledDataset(self.dataset, self.datapath, split,
                                 transform=self.transform,
                                 n_support=self.n_support,
                                 n_query=self.n_query,
                                 n_images=n_images,
                                 n_classes=n_classes,
                                 seed=self.seed,
                                 no_aug_support=self.no_aug_support,
                                 no_aug_query=self.no_aug_query)

    def _make_loader(self, data):
        """Wrap ``data`` in a shuffling `DataLoader` with this module's settings."""
        return DataLoader(data,
                          batch_size=self.batch_size,
                          shuffle=True,
                          num_workers=self.num_workers,
                          pin_memory=torch.cuda.is_available())

    def setup(self, stage=None):
        """Create the training dataset, optionally merged with the val split."""
        # Only the training split honours the n_images / n_classes limits.
        self.dataset_train = self._make_dataset('train',
                                                n_images=self.n_images,
                                                n_classes=self.n_classes)
        if self.merge_train_val:
            # Imported here because the file-level imports do not provide it.
            from torch.utils.data import ConcatDataset

            dataset_val = self._make_dataset('val')
            self.dataset_train = ConcatDataset([self.dataset_train, dataset_val])

    def train_dataloader(self):
        """Return the shuffled loader over the dataset built by `setup`."""
        return self._make_loader(self.dataset_train)

    def val_dataloader(self):
        """Return a loader over a freshly built 'val' split dataset."""
        # NOTE(review): shuffling of validation data is kept from the original
        # code; confirm it is intended.
        return self._make_loader(self._make_dataset('val'))
        

In [None]:
#export
class OmniglotDataModule(pl.LightningDataModule):
    """Lightning data module for few-shot Omniglot tasks built with torchmeta."""

    def __init__(
        self,
        data_dir: str,
        shots: int,
        ways: int,
        shuffle_ds: bool,
        test_shots: int,
        meta_train: bool,
        download: bool,
        batch_size: int,
        shuffle: bool,
        num_workers: int):
        """
        Args:
            data_dir: Root folder handed to the torchmeta ``omniglot`` helper.
            shots: Forwarded to ``omniglot`` as ``shots``.
            ways: Forwarded to ``omniglot`` as ``ways``.
            shuffle_ds: Forwarded to ``omniglot`` as ``shuffle``.
            test_shots: Forwarded to ``omniglot`` as ``test_shots``.
            meta_train: Forwarded as ``meta_train`` for the dataset built in
                `setup`.
            download: Forwarded to ``omniglot`` as ``download``.
            batch_size: Number of tasks per batch in every loader.
            shuffle: Whether the training loader shuffles.
            num_workers: Worker processes used by every loader.
        """
        super().__init__()
        self.data_dir = data_dir
        self.shots = shots
        self.ways = ways
        self.shuffle_ds = shuffle_ds
        self.test_shots = test_shots
        self.meta_train = meta_train
        self.download = download
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def _build_tasks(self, **split_flag):
        """Call ``omniglot`` with the shared settings plus one split flag."""
        return omniglot(
            self.data_dir,
            shots=self.shots,
            ways=self.ways,
            shuffle=self.shuffle_ds,
            test_shots=self.test_shots,
            download=self.download,
            **split_flag
        )

    def setup(self, stage=None):
        """Build the task dataset served by `train_dataloader`."""
        self.task_dataset = self._build_tasks(meta_train=self.meta_train)

    def train_dataloader(self):
        """Return the task-batch loader over the dataset built by `setup`."""
        return BatchMetaDataLoader(
            self.task_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers
        )

    def val_dataloader(self):
        """Build the ``meta_val`` tasks and return a loader over them."""
        self.val_tasks = self._build_tasks(meta_val=True)
        return BatchMetaDataLoader(
            self.val_tasks,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )

    def test_dataloader(self):
        """Build the ``meta_test`` tasks and return a loader over them."""
        self.test_tasks = self._build_tasks(meta_test=True)
        return BatchMetaDataLoader(
            self.test_tasks,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )

In [None]:
# Scratch cell: 5-way, 1-shot Omniglot meta-train tasks with 15 test shots,
# rooted at data/ (download=True).
ds = omniglot(
            'data/',
            shots=1,
            ways=5,
            shuffle=True,
            test_shots=15,
            meta_train=True,
            download=True
        )

In [None]:
# Batch 16 tasks at a time, loading with 4 workers.
dl = BatchMetaDataLoader(
            ds,
            batch_size=16,
            num_workers=4
        )

In [None]:
# Shape of the 'test' inputs of the first task batch (recorded output below:
# 16 tasks x 75 images x 1 x 28 x 28).
next(iter(dl))['test'][0].shape



torch.Size([16, 75, 1, 28, 28])

In [None]:
#export
class MiniImagenetDataModule(pl.LightningDataModule):
    """Lightning data module for few-shot mini-ImageNet tasks built with torchmeta."""

    def __init__(self,
                 data_dir: str,
                 shots: int,
                 ways: int,
                 shuffle_ds: bool,
                 test_shots: int,
                 meta_train: bool,
                 download: bool,
                 batch_size: int,
                 shuffle: bool,
                 num_workers: int):
        """
        Args:
            data_dir: Root folder handed to the torchmeta ``miniimagenet`` helper.
            shots: Forwarded to ``miniimagenet`` as ``shots``.
            ways: Forwarded to ``miniimagenet`` as ``ways``.
            shuffle_ds: Forwarded as ``shuffle`` for the train and val tasks.
            test_shots: Forwarded to ``miniimagenet`` as ``test_shots``.
            meta_train: Stored on the module; `setup` always requests the
                meta-train split regardless of this value.
            download: Forwarded to ``miniimagenet`` as ``download``.
            batch_size: Number of tasks per batch in every loader.
            shuffle: Whether the train and val loaders shuffle.
            num_workers: Worker processes used by every loader.
        """
        # Required so the LightningDataModule base class initialises its state.
        super().__init__()
        self.data_dir = data_dir
        self.shots = shots
        self.ways = ways
        self.shuffle_ds = shuffle_ds
        self.test_shots = test_shots
        self.meta_train = meta_train
        self.download = download
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def _build_tasks(self, shuffle, **split_flag):
        """Call ``miniimagenet`` with the shared settings plus one split flag."""
        return miniimagenet(
            self.data_dir,
            shots=self.shots,
            ways=self.ways,
            shuffle=shuffle,
            test_shots=self.test_shots,
            download=self.download,
            **split_flag
        )

    def _build_loader(self, tasks, shuffle):
        """Wrap ``tasks`` in a `BatchMetaDataLoader` with this module's settings."""
        return BatchMetaDataLoader(
            tasks,
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.num_workers
        )

    def setup(self, stage=None):
        """Build the meta-train task set; ``stage`` is accepted and ignored."""
        self.train_taskset = self._build_tasks(self.shuffle_ds, meta_train=True)

    def train_dataloader(self):
        """Return the task-batch loader over the set built by `setup`."""
        return self._build_loader(self.train_taskset, self.shuffle)

    def val_dataloader(self):
        """Build the ``meta_val`` tasks and return a loader over them."""
        self.val_taskset = self._build_tasks(self.shuffle_ds, meta_val=True)
        return self._build_loader(self.val_taskset, self.shuffle)

    def test_dataloader(self):
        """Build the ``meta_test`` tasks and return a loader over them.

        Shuffling is disabled for both the task set and the loader.
        """
        self.test_taskset = self._build_tasks(False, meta_test=True)
        return self._build_loader(self.test_taskset, False)

In [None]:
# Scratch cell: same 5-way, 1-shot Omniglot meta-train tasks as before, but
# with shuffle=False.
ds = omniglot(
            'data/',
            shots=1,
            ways=5,
            shuffle=False,
            test_shots=15,
            meta_train=True,
            download=True
        )

In [None]:
# Batch 16 tasks at a time with the loader's default worker settings.
dl = BatchMetaDataLoader(ds, batch_size=16)

In [None]:
# 'train' labels of the first task batch (recorded output below: one row of
# 5 class indices per task).
next(iter(dl))['train'][1]



tensor([[0, 3, 1, 4, 2],
        [4, 2, 0, 1, 3],
        [4, 0, 3, 2, 1],
        [1, 4, 3, 2, 0],
        [4, 0, 2, 3, 1],
        [0, 2, 1, 3, 4],
        [2, 1, 0, 4, 3],
        [4, 2, 1, 3, 0],
        [0, 3, 2, 1, 4],
        [0, 1, 4, 3, 2],
        [0, 1, 2, 3, 4],
        [2, 4, 1, 0, 3],
        [4, 0, 3, 1, 2],
        [2, 4, 0, 3, 1],
        [0, 3, 1, 4, 2],
        [3, 1, 2, 0, 4]])

In [None]:
#export
def get_cub_default_transform(size):
    """Build the deterministic CUB preprocessing: enlarge, then centre-crop.

    The image is resized to 1.5x the requested height and width (truncated
    to integers) and the central ``size`` region is kept.
    """
    enlarged = [int(size[0] * 1.5), int(size[1] * 1.5)]
    steps = [transforms.Resize(enlarged), transforms.CenterCrop(size)]
    return transforms.Compose(steps)

def get_simCLR_transform(img_shape):
    """Adapted from https://github.com/sthalles/SimCLR/blob/master/data_aug/dataset_wrapper.py

    Random resized crop to the last two entries of ``img_shape``, random
    horizontal flip, colour jitter (applied with p=0.8), random grayscale
    (p=0.2) and conversion to a tensor. No Gaussian blur step is included.
    """
    jitter = transforms.ColorJitter(brightness=0.8, contrast=0.8,
                                    saturation=0.8, hue=0.2)
    steps = [
        transforms.RandomResizedCrop(size=img_shape[-2:]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def get_omniglot_transform(img_shape):
    """Build the augmentation pipeline used for Omniglot images.

    Resize and random-resized-crop to the last two entries of ``img_shape``,
    random horizontal and vertical flips, conversion to a tensor, dropout
    with p=0.3 on the tensor values, and random erasing.
    """
    target = img_shape[-2:]
    steps = [
        transforms.Resize(target),
        transforms.RandomResizedCrop(size=target, scale=(0.6, 1.4)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        # Dropout is applied to the image tensor itself as an augmentation.
        transforms.Lambda(lambda t: F.dropout(t, p=0.3)),
        transforms.RandomErasing(),
    ]
    return transforms.Compose(steps)

def get_custom_transform(img_shape):
    """Build the default augmentation pipeline for non-Omniglot images.

    Random resized crop (area scale 0.5-1.0) to the last two entries of
    ``img_shape``, random horizontal and vertical flips, colour jitter
    (applied with p=0.8), random grayscale (p=0.2) and conversion to a tensor.
    """
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.1)
    steps = [
        transforms.RandomResizedCrop(size=img_shape[-2:], scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def identity_transform(img_shape):
    """Build the augmentation-free pipeline: resize to ``img_shape``, then to tensor."""
    steps = [transforms.Resize(img_shape), transforms.ToTensor()]
    return transforms.Compose(steps)

In [None]:
#export
class UnlabelledDataset(Dataset):
    """Unlabelled image dataset that returns several views of each image.

    Every item is a dict with a single key ``data`` holding ``n_support``
    support views followed by ``n_query`` query views of the same image,
    stacked along a new leading dimension.
    """

    def __init__(self, dataset, datapath, split, transform=None,
                 n_support=1, n_query=1, n_images=None, n_classes=None,
                 seed=10, no_aug_support=False, no_aug_query=False):
        """
        Args:
            dataset (string): Dataset name.
            datapath (string): Directory containing the datasets.
            split (string): The dataset split to load.
            transform (callable, optional): Optional transform to be applied
                on a sample. When omitted, a dataset-specific augmentation
                pipeline is used.
            n_support (int): Number of support examples
            n_query (int): Number of query examples
            no_aug_support (bool): Whether to not apply any augmentations to the support
            no_aug_query (bool): Whether to not apply any augmentations to the query
            n_images (int): Limit the number of images to load.
            n_classes (int): Limit the number of classes to load.
            seed (int): Random seed for selecting the images/classes to load.

        Raises:
            ValueError: If augmentation is disabled for the support (query)
                views while more than one support (query) view is requested.
        """
        # Un-augmented views of one image are identical, so asking for more
        # than one is rejected up front (before any data is read) instead of
        # through an `assert` on every item access.
        if no_aug_support and n_support > 1:
            raise ValueError('no_aug_support=True requires n_support == 1')
        if no_aug_query and n_query > 1:
            raise ValueError('no_aug_query=True requires n_query == 1')

        self.n_support = n_support
        self.n_query = n_query
        self.img_size = (28, 28) if dataset == 'omniglot' else (84, 84)
        self.no_aug_support = no_aug_support
        self.no_aug_query = no_aug_query

        # Get the data or paths
        self.dataset = dataset
        self.data = self._extract_data_from_hdf5(dataset, datapath, split,
                                                 n_classes, seed)

        # Optionally only load a subset of images
        if n_images is not None:
            random_idxs = np.random.RandomState(seed).permutation(len(self))[:n_images]
            self.data = self.data[random_idxs]

        # The un-augmented transform is always defined, so the no_aug_* paths
        # also work when the caller supplies a custom `transform`.
        if self.dataset == 'cub':
            self.original_transform = transforms.Compose([
                get_cub_default_transform(self.img_size),
                transforms.ToTensor()])
        else:
            self.original_transform = identity_transform(self.img_size)

        # Augmenting transform: the caller's, or a dataset-specific default.
        if transform is not None:
            self.transform = transform
        elif self.dataset == 'cub':
            self.transform = transforms.Compose([
                get_cub_default_transform(self.img_size),
                get_custom_transform(self.img_size)])
        elif self.dataset == 'omniglot':
            self.transform = get_omniglot_transform(self.img_size)
        else:
            self.transform = get_custom_transform(self.img_size)

    def _extract_data_from_hdf5(self, dataset, datapath, split,
                                n_classes, seed):
        """Read the per-class arrays from HDF5 and concatenate them.

        For ``'omniglot'`` the classes listed in
        ``vinyals_<split>_labels.json`` are read from ``data.hdf5``; for any
        other dataset every entry of the ``datasets`` group in
        ``<split>_data.hdf5`` is read. If ``n_classes`` is given, a seeded
        random subset of that many classes is kept.
        """
        datapath = os.path.join(datapath, dataset)

        # Load omniglot
        if dataset == 'omniglot':
            labels_path = os.path.join(
                datapath, 'vinyals_{}_labels.json'.format(split))
            classes = []
            with h5py.File(os.path.join(datapath, 'data.hdf5'), 'r') as f_data, \
                    open(labels_path) as f_labels:
                for img_set, alphabet, character in json.load(f_labels):
                    # `[()]` reads the whole HDF5 dataset into memory.
                    classes.append(f_data[img_set][alphabet][character][()])
        # Load mini-imageNet
        else:
            with h5py.File(os.path.join(datapath, split + '_data.hdf5'), 'r') as f:
                datasets = f['datasets']
                classes = [datasets[k][()] for k in datasets.keys()]

        # Optionally filter out some classes
        if n_classes is not None:
            random_idxs = np.random.RandomState(seed).permutation(len(classes))[:n_classes]
            classes = [classes[i] for i in random_idxs]

        # Collect in single array
        return np.concatenate(classes)

    def __len__(self):
        """Return the number of loaded images."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``{'data': views}`` for the image at ``index``.

        ``views`` stacks the support views first and the query views after
        them; each view is whatever tensor the applicable transform returns.
        """
        if self.dataset == 'cub':
            # CUB entries are decoded from an in-memory byte buffer.
            image = Image.open(io.BytesIO(self.data[index])).convert('RGB')
        else:
            image = Image.fromarray(self.data[index])

        support_tf = self.original_transform if self.no_aug_support else self.transform
        query_tf = self.original_transform if self.no_aug_query else self.transform

        view_list = [support_tf(image) for _ in range(self.n_support)]
        view_list.extend(query_tf(image) for _ in range(self.n_query))

        return dict(data=torch.stack(view_list))

In [1]:
# Run the nbdev notebook-to-script export; the recorded output below lists the
# notebooks it converted.
from nbdev.export import notebook2script; notebook2script()

Converted 01_nn_utils.ipynb.
Converted 01b_data_loaders_pl.ipynb.
Converted 01c_grad_utils.ipynb.
Converted 01d_proto_utils.ipynb.
Converted 02_maml_pl.ipynb.
Converted 02b_iMAML.ipynb.
Converted 03_protonet_pl.ipynb.
Converted 03b_ProtoCLR.ipynb.
Converted 04_cactus.ipynb.
Converted index.ipynb.
