# Dataloaders

> Custom PyTorch datasets and data loaders for spectral data

In [None]:
#| default_exp dataloaders

In [None]:
%load_ext autoreload
%autoreload 2

In [28]:
#| export
from pathlib import Path
from tqdm import tqdm
from collections import namedtuple
import fastcore.all as fc

from torch.utils.data import Dataset, DataLoader

In [20]:
#|eval: false
from lssm.loading import load_ossl
from lssm.preprocessing import ToAbsorbance, ContinuumRemoval, SNV
from lssm.visualization import plot_spectra
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split

In [57]:
#| export
class SpectralDataset(Dataset):
    """PyTorch `Dataset` pairing spectra `X` with targets `y`.

    Item `idx` is the tuple `(x, target)` where:

    - `x` is row `idx` of `X` with a leading singleton axis prepended, i.e.
      shape `(1, n_features)`. Batched by a `DataLoader` this gives
      `(bs, 1, n_features)` — presumably a channel-first layout for models
      such as `nn.Conv1d` (TODO confirm with the model code).
    - `target` is row `idx` of `y`: a 1-element row for 2-D `y`
      `(n_samples, n_targets)`, or a scalar for 1-D `y` `(n_samples,)`.

    Parameters
    ----------
    X : 2-D array-like supporting `X[idx, :]` indexing, assumed
        `(n_samples, n_features)` (e.g. NumPy array or torch tensor).
    y : 1-D or 2-D array-like with the same number of rows as `X`.
    metadata : optional
        Stored as-is on the instance; not used by this class.
    """
    def __init__(self, X, y, metadata=None):
        self.X = X
        self.y = y
        self.metadata = metadata

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # `[None, :]` prepends the singleton (channel) axis -> (1, n_features).
        # `y[idx]` equals `y[idx, :]` for 2-D targets and, unlike it, also
        # works for 1-D targets (where `y[idx, :]` raises IndexError).
        return self.X[idx, :][None, :], self.y[idx]

In [39]:
#| export
# Container for a train/valid pair of dataloaders. Defined once at module
# level so every `get_dls` call returns the same (picklable) type.
DataLoaders = namedtuple('DataLoaders', ['train', 'valid'])

def get_dls(train_ds, valid_ds, bs, valid_bs=None, **kwargs):
    """Wrap train/valid datasets in PyTorch `DataLoader`s.

    Parameters
    ----------
    train_ds, valid_ds : torch `Dataset`
        Training and validation datasets.
    bs : int
        Training batch size. The training loader shuffles every epoch.
    valid_bs : int, optional
        Validation batch size. Defaults to `2 * bs`, presumably because
        evaluation keeps no gradients and so fits larger batches.
    **kwargs
        Forwarded unchanged to both `DataLoader` constructors
        (e.g. `num_workers`, `pin_memory`).

    Returns
    -------
    DataLoaders
        Named tuple with fields `train` and `valid`. The validation loader
        is not shuffled unless `shuffle` is passed via `kwargs`.
    """
    if valid_bs is None:
        valid_bs = bs * 2
    return DataLoaders(
        DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
        DataLoader(valid_ds, batch_size=valid_bs, **kwargs))

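## Example

Load OSSL VisNIR spectra, preprocess them, split them into train/valid sets, and wrap the splits in PyTorch dataloaders: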

In [52]:
#|eval: false
import os

# Location of the OSSL CSV. Defaults to the author's local layout; set the
# `OSSL_FNAME` environment variable to point at your own copy instead.
fname_ossl = Path(os.environ.get(
    'OSSL_FNAME',
    Path.home() / 'pro/data/ossl/gcs_version/ossl_all_L0_v1.2.csv.gz'))
analytes = 'k.ext_usda.a725_cmolc.kg'

# Load dataset
data = load_ossl(fname_ossl, analytes, spectra_type='visnir')
X, y, X_names, smp_idx, ds_name, ds_label = data

100%|██████████| 3725/3725 [00:01<00:00, 2876.28it/s]


First batch X dim: torch.Size([32, 1, 1076])
First batch y dim: torch.Size([32, 1])


In [58]:
#|eval: false

# Transform. The result is bound to a new name so that re-running this cell
# never re-applies the pipeline to already-transformed spectra; the raw `X`
# from the loading cell stays untouched.
X_cr = Pipeline([('to_abs', ToAbsorbance()),
                 ('cr', ContinuumRemoval(X_names))]).fit_transform(X)

# Train/valid split, stratified on `ds_name`
X_train, X_valid, y_train, y_valid = train_test_split(X_cr, y,
                                                      test_size=0.2,
                                                      stratify=ds_name,
                                                      random_state=41)

# Get PyTorch datasets
train_ds = SpectralDataset(X_train, y_train)
valid_ds = SpectralDataset(X_valid, y_valid)

# Then PyTorch dataloaders
dls = get_dls(train_ds, valid_ds, bs=32)

first_batch = next(iter(dls.train))
print(f'First batch X dim: {first_batch[0].shape}')
print(f'First batch y dim: {first_batch[1].shape}')


  0%|          | 0/3725 [00:00<?, ?it/s]

100%|██████████| 3725/3725 [00:00<00:00, 7023.45it/s]

First batch X dim: torch.Size([32, 1, 1076])
First batch y dim: torch.Size([32, 1])



