In [None]:
#| default_exp data.torch

# PyTorch data loaders and transforms

> PyTorch `DataLoaders`, `Dataset` and transforms

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#|export
#nbdev_comment from __future__ import annotations
import numpy as np

from fastcore.test import *

from mirzai.data.loading import load_kssl
from mirzai.data.selection import (select_y, select_tax_order, select_X)
from mirzai.data.transform import (log_transform_y, SNV)

from sklearn.model_selection import train_test_split

from fastcore.transform import compose

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

  from .autonotebook import tqdm as notebook_tqdm


## Loaders & datasets

In [None]:
#|export
class SpectralDataset(Dataset):
    """PyTorch `Dataset` yielding one `(spectrum, target, taxonomic order)` sample at a time.

    Args:
        X: numpy array of spectra indexed as `X[sample, feature]` (2D assumed).
        y: 1D numpy array of targets, one per sample.
        tax_order: 1D numpy array of taxonomic-order ids, one per sample.
        transform: optional callable applied to each spectrum (shape `(1, n_features)`)
            before it is returned; `None` means no transform.
    """
    def __init__(self, X, y, tax_order, transform=None):
        self.X = X
        self.y = y
        self.tax_order = tax_order
        self.transform = transform

    def __len__(self):
        "Number of samples, taken from the length of the target array."
        return len(self.y)

    def __getitem__(self, idx):
        """Return sample `idx` as `(X, y, tax_order)` numpy arrays.

        The leading `None` index keeps a singleton axis, so `X` has shape
        `(1, n_features)` (one "channel") and `y` / `tax_order` have shape `(1,)`.
        Dtypes are cast to float32, float32 and C int respectively.
        """
        X = self.X[None, idx, :]
        y = self.y[None, idx]
        tax_order = self.tax_order[None, idx]
        # Explicit `is not None`: a callable transform object that happens to be
        # falsy (e.g. defines `__bool__` or `__len__`) must still be applied.
        if self.transform is not None:
            X = self.transform(X)
        return X.astype(np.float32), y.astype(np.float32), tax_order.astype(np.intc)

In [None]:
#|export
class DataLoaders():
    """Wrap one or more numpy `(X, y, tax_order)` splits as PyTorch `DataLoader`s.

    Args:
        *args: one or more tuples `(X, y, tax_order)`, e.g.
            `(X_train, y_train, tax_order_train), (X_valid, y_valid, tax_order_valid)`.
        transform: optional callable applied to each spectrum by `SpectralDataset`;
            defaults to `Noop()` (identity).
        batch_size: mini-batch size shared by all loaders.
    """
    def __init__(self, *args, transform=None, batch_size=32):
        self.data = args
        self.batch_size = batch_size
        # `is not None` so a valid but falsy callable is not silently replaced.
        self.transform = transform if transform is not None else Noop()

    def loaders(self):
        """Build one `DataLoader` per split, in the order the splits were given.

        Returns:
            tuple of `DataLoader`s, e.g. `(train_dl, valid_dl, test_dl)`. Batches are
            produced in dataset order (DataLoader's default, no shuffling) and the last
            partial batch is kept (`drop_last=False`).
        """
        return tuple(DataLoader(SpectralDataset(X, y, tax_order, transform=self.transform),
                                batch_size=self.batch_size,
                                drop_last=False)
                     for X, y, tax_order in self.data)

## Transforms

In [None]:
#|export
class SNV_transform:
    """Callable transform applying the project's `SNV` scaling to a spectrum.

    A fresh `SNV` instance is created and fitted on every call, so each
    spectrum is transformed using only the data passed in that call.
    """
    def __call__(self, spectrum):
        scaler = SNV()
        return scaler.fit_transform(spectrum)

In [None]:
#|export
class Noop:
    "Identity transform: returns its input object unchanged (default `transform` of `DataLoaders`)."
    def __call__(self, X):
        return X

## Example usage

### Load and preprocess data

In [None]:
src_dir = 'test'
fnames = ['spectra-features-smp.npy', 'spectra-wavenumbers-smp.npy', 
          'depth-order-smp.npy', 'target-smp.npy', 
          'tax-order-lu-smp.pkl', 'spectra-id-smp.npy']

X, X_names, depth_order, y, tax_lookup, X_id = load_kssl(src_dir, fnames=fnames)

# Named `preprocessing` rather than `transforms` so it does not shadow the
# `torchvision.transforms` module imported in the export cell above.
preprocessing = [select_y, select_tax_order, select_X, log_transform_y]

# fastcore's `compose` chains the steps in list order on the (X, y, X_id, depth_order) tuple.
data = X, y, X_id, depth_order
X, y, X_id, depth_order = compose(*preprocessing)(data)

### Train/test split

In [None]:
# Hold out 10% of all samples as the test set. Column 1 of `depth_order` is split
# alongside as the per-sample taxonomic-order id (assumed from the variable names;
# TODO confirm against `load_kssl`).
(X_train, X_test, y_train, y_test,
 tax_order_train, tax_order_test) = train_test_split(X, y, depth_order[:, 1],
                                                     test_size=0.1, random_state=42)

# Carve a validation set out of the remaining training data (10% of it, i.e. ~9% of all samples).
(X_train, X_valid, y_train, y_valid,
 tax_order_train, tax_order_valid) = train_test_split(X_train, y_train, tax_order_train,
                                                      test_size=0.1, random_state=42)

### Create the data loaders

In [None]:
# One loader per split, in train / valid / test order; each spectrum is
# SNV-transformed when its sample is fetched from the dataset.
dls = DataLoaders(
    (X_train, y_train, tax_order_train),
    (X_valid, y_valid, tax_order_valid),
    (X_test, y_test, tax_order_test),
    transform=SNV_transform(),
)
training_generator, validation_generator, test_generator = dls.loaders()

### Iterate over mini-batches of (features, targets, taxonomic orders)

In [None]:
# Each batch is a (spectra, targets, taxonomic-order ids) triple.
for features, target, tax in training_generator:
    print(f'Batch of features (spectra): {features.shape}',
          f'Batch of targets: {target.shape}',
          f'Batch of Soil taxonomy orders id: {tax.shape}',
          sep='\n')

Batch of features (spectra): torch.Size([32, 1, 1764])
Batch of targets: torch.Size([32, 1])
Batch of Soil taxonomy orders id: torch.Size([32, 1])
Batch of features (spectra): torch.Size([31, 1, 1764])
Batch of targets: torch.Size([31, 1])
Batch of Soil taxonomy orders id: torch.Size([31, 1])


In [None]:
# Taxonomic-order ids are ignored here; only spectra and targets are shown.
for features, target, _ in validation_generator:
    print(f'Batch of features (spectra): {features.shape}',
          f'Batch of targets: {target.shape}',
          sep='\n')

Batch of features (spectra): torch.Size([8, 1, 1764])
Batch of targets: torch.Size([8, 1])


In [None]:
# Taxonomic-order ids are ignored here; only spectra and targets are shown.
for features, target, _ in test_generator:
    print(f'Batch of features (spectra): {features.shape}',
          f'Batch of targets: {target.shape}',
          sep='\n')

Batch of features (spectra): torch.Size([8, 1, 1764])
Batch of targets: torch.Size([8, 1])
