In [1]:
import numpy as np
import pandas as pd
import torch
import random
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

### Data in data_cube

In [2]:
# Lookup from integer position to modality name.
# NOTE(review): presumably the positions follow the channel order of the
# data cube - confirm against the code that builds the pickle.
pi_data_dict = dict(enumerate((
    "3D",
    "VNIR 1",
    "VNIR 2",
    "VNIR 3",
    "VNIR 4",
    "VNIR 5",
    "VNIR 6",
    "VNIR 7",
    "VNIR 8",
    "XRT 1",
    "XRT 2",
)))

### Creation of a dataset class

In [3]:
class PickItDataset(Dataset):
    """Dataset backed by a pickled mapping of data cubes, masks and class labels.

    The pickle must provide the keys ``"data_cube"``, ``"masks"`` and
    ``"class"``; the three entries are indexed in parallel by sample position.

    Args:
        pickle_path: Path of the pickle file, read with ``pd.read_pickle``.
        transform: Optional callable that receives the sample dict
            (``"cube"``, ``"mask"``, ``"label"``) and returns the transformed
            sample.
    """

    def __init__(self, pickle_path, transform=None):
        # NOTE(review): unpickling can execute arbitrary code - only load
        # pickle files that come from a trusted source.
        data = pd.read_pickle(pickle_path)
        self.cube = data["data_cube"]
        self.masks = data["masks"]
        self.labels = data["class"]
        self.transform = transform

    def __len__(self):
        """Return the number of samples, taken from the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with cube, mask and label.

        Args:
            idx: Sample index; a tensor index is converted to plain Python
                values before it is used.

        Returns:
            ``{"cube": torch.Tensor, "mask": ..., "label": ...}``, passed
            through ``self.transform`` when one was given.
        """
        # Tensor indices are converted to plain Python values before indexing.
        # The previous code assigned to an undefined name here and raised
        # UnboundLocalError for every tensor index.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Copy before wrapping so the tensor does not share memory with the
        # stored cube.
        cube = torch.from_numpy(self.cube[idx].copy())
        mask = self.masks[idx]
        label = self.labels[idx]

        sample = {"cube": cube, "mask": mask, "label": label}

        if self.transform:
            sample = self.transform(sample)

        return sample


### Definition of the transforms

In [14]:
class Rescale(object):
    """Resize the cube and the mask of a sample to a fixed ``(h, w)`` size.

    The cube is resized with ``torchvision.transforms.Resize``. The mask is
    resampled with nearest-neighbour lookup so that it keeps its original
    values and stays spatially aligned with the resized cube.

    Args:
        h: Target height in pixels.
        w: Target width in pixels.
    """

    def __init__(self, h, w):
        self.h = h
        self.w = w

    def __call__(self, sample):
        """Return a new sample dict whose cube and mask have size ``(h, w)``.

        Args:
            sample: Dict with the keys ``"cube"``, ``"mask"`` and ``"label"``.

        Returns:
            Dict with the same keys; the label is passed through unchanged.
        """
        cube, mask, label = sample["cube"], sample["mask"], sample["label"]
        cube = transforms.Resize((self.h, self.w))(cube)
        mask = self._resize_mask(np.asarray(mask))
        return {"cube": cube, "mask": mask, "label": label}

    def _resize_mask(self, mask):
        """Nearest-neighbour resample of the two trailing axes of ``mask``.

        ``np.resize`` must not be used for this: it only repeats or truncates
        the flattened data and therefore scrambles the spatial layout.
        Leading axes, if any, are kept as they are.
        """
        if mask.ndim < 2 or mask.size == 0:
            # No spatial layout to preserve; keep the legacy fill behaviour.
            return np.resize(mask, (self.h, self.w))

        src_h, src_w = mask.shape[-2:]
        # Source index of every target row/column: floor(i * src / dst).
        rows = np.arange(self.h) * src_h // self.h
        cols = np.arange(self.w) * src_w // self.w
        return mask[..., rows[:, None], cols]


class RandomHorizontalFlip(object):
    """Flip cube and mask of a sample left-to-right with probability ``p``.

    Args:
        p: Probability of applying the flip; compared against
            ``random.random()`` on every call.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, sample):
        """Return a new sample dict, flipped or not; the mask is always copied."""
        cube = sample["cube"]
        mask = sample["mask"]
        label = sample["label"]

        apply_flip = random.random() < self.p
        if apply_flip:
            # p=1 makes the torchvision transform flip unconditionally, so the
            # cube and the mask are always flipped together.
            always_flip = transforms.RandomHorizontalFlip(p=1)
            cube, mask = always_flip(cube), np.fliplr(mask)

        # NOTE(review): the copy presumably turns the flipped view into a
        # regular array for later batching - confirm.
        return dict(cube=cube, mask=mask.copy(), label=label)
    
class RandomVerticalFlip(object):
    """Flip cube and mask of a sample top-to-bottom with probability ``p``.

    Args:
        p: Probability of applying the flip; compared against
            ``random.random()`` on every call.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, sample):
        """Return a new sample dict, flipped or not; the mask is always copied."""
        cube = sample["cube"]
        mask = sample["mask"]
        label = sample["label"]

        apply_flip = random.random() < self.p
        if apply_flip:
            # p=1 makes the torchvision transform flip unconditionally, so the
            # cube and the mask are always flipped together.
            always_flip = transforms.RandomVerticalFlip(p=1)
            cube, mask = always_flip(cube), np.flipud(mask)

        # NOTE(review): the copy presumably turns the flipped view into a
        # regular array for later batching - confirm.
        return dict(cube=cube, mask=mask.copy(), label=label)

### Loading the dataset, applying the transforms and creating the dataloader

In [15]:
# Pickled dataset, addressed relative to the notebook's working directory.
pickle_path = "./data/original/dataset_3d_vnir_xrt.pkl"

# Transform pipeline: every sample is rescaled to 256 x 256.
composed = transforms.Compose([Rescale(h=256, w=256)])

pickit_dataset = PickItDataset(pickle_path=pickle_path, transform=composed)

# Shuffled batches of four samples, loaded in the main process.
pickit_dataloader = DataLoader(
    dataset=pickit_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)


In [17]:
# Sanity check: draw one batch from the dataloader and display the shape of its "cube" entry.
next(iter(pickit_dataloader))["cube"].shape

torch.Size([4, 11, 256, 256])