In [None]:
import numpy as np
import pandas as pd
import torch
import random
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

### Data exploration

In [None]:
# NOTE(security): unpickling can execute arbitrary code -- only load this file
# when it comes from a trusted source.
# Forward slashes work on every OS and avoid the invalid "\o" / "\d" escape
# sequences that the backslash-separated literal contained.
df = pd.read_pickle("data/original/dataset_3d_vnir_xrt.pkl")

# s1 / s2 are drawn as *positions*, so every lookup below is positional
# (.iloc) and does not depend on the index labels stored in the pickle.
cubes = df["data_cube"]

print("Nb samples: ",len(df))
print("Nb chanels in data_cube: ",len(cubes.iloc[0]))
print("")
print("Shape of 2 random samples: ")
s1, s2 = random.choice(range(len(df))), random.choice(range(len(df)))
print("     sample 1: ",cubes.iloc[s1].shape)
print("     sample 2: ",cubes.iloc[s2].shape)
print("Nb categories: ",len(df["class"].unique()))

# Show layer 0 of both randomly chosen samples side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, position, title in zip(axes, (s1, s2), ("Sample 1: layer 0", "Sample 2: layer 0")):
    ax.imshow(cubes.iloc[position][0])
    ax.set_title(title)

# Display name for each integer key 0..10: one "3D" entry, eight "VNIR" entries
# and two "XRT" entries.
# NOTE(review): presumably the keys are the layer indices of a `data_cube` -- confirm.
pi_data_dict = {
    0: "3D",
    **{band: f"VNIR {band}" for band in range(1, 9)},
    **{8 + xrt: f"XRT {xrt}" for xrt in (1, 2)},
}


### Creation of a dataset class

In [None]:
class PickItDataset(Dataset):
    """Train or test split of a pickled DataFrame of data cubes, masks and labels.

    The pickle must contain a DataFrame with the columns ``data_cube``
    (numpy arrays, converted to float32 here), ``masks`` and ``class``.
    A seeded random sample of 20 % of the rows is the test split; all other
    rows, in their original order, are the train split.

    Args:
        pickle_path: Path of the pickled DataFrame.
        transform: Optional callable applied to each sample dict before it is
            returned.
        train: If True expose the train split, otherwise the test split.
        seed: Seed of the split. Use the same value for the train and the test
            dataset so that the two splits are complementary.
    """

    def __init__(self, pickle_path, transform=None,train = True, seed = 42):
        # NOTE(security): unpickling can execute arbitrary code -- only load
        # pickles that come from a trusted source.
        data = pd.read_pickle(pickle_path)

        # Deliberately seeds the *global* `random` module: existing callers may
        # rely on the generator being reset when a dataset is created.
        random.seed(seed)
        test_positions = random.sample(range(0, len(data)), int(len(data) * 0.2))

        # Split by position so the result does not depend on the index labels
        # stored in the pickle.
        if train:
            held_out = set(test_positions)
            positions = [pos for pos in range(len(data)) if pos not in held_out]
        else:
            positions = test_positions

        # Renumber the rows 0..n-1 so that dataset indices map onto them.
        split = data.iloc[positions].reset_index(drop=True)

        self.data = split["data_cube"].apply(lambda x: x.astype(np.float32))
        self.masks = split["masks"]
        self.labels = split["class"]
        self.transform = transform

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{"data": tensor, "mask": tensor, "label": ...}``.

        ``idx`` is a position within the split (an int or a tensor holding one).
        The configured transform, if any, is applied to the dict.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Positional access: immune to whatever labels the series carry.
        data = torch.from_numpy(self.data.iloc[idx])
        mask = torch.tensor(self.masks.iloc[idx])
        label = self.labels.iloc[idx]

        sample = {"data": data,"mask":mask, "label": label}

        if self.transform:
            sample = self.transform(sample)

        return sample


### Definition of the transforms

In [None]:
class Rescale(object):
    """Resize the data cube and the mask of a sample to the same target size.

    Args:
        scale: Target size, forwarded unchanged to
            ``torchvision.transforms.Resize``.
    """

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, sample):
        """Return a new sample dict with resized ``data`` and ``mask``; ``label`` is passed through."""
        data, mask ,label = sample["data"], sample["mask"],sample["label"]
        data = transforms.Resize(self.scale)(data)
        # Nearest-neighbour keeps the mask values discrete; the default bilinear
        # mode would blend neighbouring values along object borders.
        resize_mask = transforms.Resize(
            self.scale, interpolation=transforms.InterpolationMode.NEAREST
        )
        # Resize is applied to the mask with a temporary leading dimension,
        # which is removed again afterwards.
        mask = resize_mask(mask.unsqueeze(0)).squeeze(0)
        return {"data": data,"mask":mask ,"label": label}


class RandomHorizontalFlip(object):
    """Horizontally flip the data cube and its mask together with probability ``p``."""

    def __init__(self, p):
        self.p = p

    def __call__(self, sample):
        """Return a new ``data`` / ``mask`` / ``label`` dict, flipped or untouched."""
        result = {
            "data": sample["data"],
            "mask": sample["mask"],
            "label": sample["label"],
        }
        # One draw decides for both tensors, so cube and mask stay aligned.
        if random.random() < self.p:
            always_flip = transforms.RandomHorizontalFlip(p=1)
            result["data"] = always_flip(result["data"])
            result["mask"] = always_flip(result["mask"])
        return result

class RandomVerticalFlip(object):
    """Vertically flip the data cube and its mask together with probability ``p``."""

    def __init__(self, p):
        self.p = p

    def __call__(self, sample):
        """Return a new ``data`` / ``mask`` / ``label`` dict, flipped or untouched."""
        result = {
            "data": sample["data"],
            "mask": sample["mask"],
            "label": sample["label"],
        }
        # One draw decides for both tensors, so cube and mask stay aligned.
        if random.random() < self.p:
            always_flip = transforms.RandomVerticalFlip(p=1)
            result["data"] = always_flip(result["data"])
            result["mask"] = always_flip(result["mask"])
        return result

class Normalize(object):
    """Standardise the data cube of a sample with its own global statistics.

    The mean and the standard deviation are computed over *all* elements of
    ``sample["data"]`` (not per channel); the cube is then shifted and scaled
    to zero mean and unit standard deviation. Mask and label are passed
    through unchanged and the input tensor is not modified.
    """

    def __call__(self, sample):
        """Return a new sample dict whose ``data`` entry is standardised."""
        data, mask, label = sample["data"], sample["mask"], sample["label"]
        mean = torch.mean(data)
        std = torch.std(data)
        # A constant cube has std == 0 and a single-element cube has std == NaN.
        # Dividing by either would produce inf/NaN (torchvision's Normalize
        # raises ValueError for a zero std), so such cubes are only centred.
        if not bool(std > 0):
            std = torch.ones_like(std)
        data = (data - mean) / std
        return {"data": data,"mask":mask ,"label": label}


In [None]:
def plot_sample(sample, layer):
    """Display one layer of a sample's data cube as an image with a colorbar.

    Args:
        sample: Mapping with a ``"data"`` entry that can be indexed by ``layer``.
        layer: Index of the layer of ``sample["data"]`` to draw.
    """
    layer_image = sample["data"][layer]
    plt.imshow(layer_image)
    plt.colorbar()  # value scale of the image drawn just above
    plt.show()

### Loading the dataset, applying the transforms and creating the dataloaders

In [None]:
# --- Configuration ---------------------------------------------------------
# Forward slashes work on every OS and avoid the invalid "\o" / "\d" escape
# sequences that the backslash-separated literal contained.
pickle_path = "data/original/dataset_3d_vnir_xrt.pkl"
new_scale = (256, 256)   # target size handed to Rescale
p_flip = 0.5             # probability of each random flip (train only)
batch_size = 4
num_workers = 0
seed = 42                # shared by both datasets so the splits are complementary

# --- Transforms ------------------------------------------------------------
# Random flips are augmentation and therefore applied to the train split only.
tr_train = transforms.Compose([Rescale(new_scale),RandomHorizontalFlip(p_flip),RandomVerticalFlip(p_flip),Normalize()])
tr_test = transforms.Compose([Rescale(new_scale),Normalize()])

# --- Datasets --------------------------------------------------------------
trainset = PickItDataset(pickle_path, transform=tr_train,train = True, seed = seed)
testset = PickItDataset(pickle_path, transform=tr_test,train = False, seed = seed)

# --- Loaders ---------------------------------------------------------------
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# Evaluation gains nothing from shuffling; a fixed order keeps per-batch
# results reproducible between runs.
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
