# dataloaders

> Helper functions and classes for building PyTorch dataloaders, e.g. from Hugging Face datasets

In [None]:
#|default_exp dataloaders

In [None]:
#|export
from functools import partial

import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader
import torch
import PIL

In [None]:
#|export

def hf_ds_collate_fn(data, flatten=True):
    '''
    Collation function for building a PyTorch DataLoader from a huggingface dataset.

    `data` is an iterable of dataset entries; each entry must provide `.values()`
    (e.g. a dict). Every value of every entry is converted to a tensor:
    PIL images go through `TF.to_tensor` (and are flattened to 1-D when `flatten`
    is True), everything else goes through `torch.tensor`.

    Returns a tuple holding one stacked tensor per field, in the order in which
    the fields appear in the entries, so it works for (x, y) as well as (x, y, z).
    '''

    def to_tensor(item):
        if isinstance(item, PIL.Image.Image):
            tensor = TF.to_tensor(item)
            return torch.flatten(tensor) if flatten else tensor
        return torch.tensor(item)

    # Convert every value of every entry, then transpose entries -> per-field columns
    columns = zip(*(map(to_tensor, el.values()) for el in data))

    # Return a concrete tuple instead of a lazy generator: a generator can only be
    # consumed once and cannot be pickled, which breaks DataLoader(num_workers>0)
    # when the worker process hands the batch back to the main process.
    return tuple(torch.stack(column) for column in columns)

In [None]:
#|export
class DataLoaders:
    '''Container that exposes two PyTorch dataloaders as the `train` and `valid` attributes'''

    def __init__(self, train, valid):
        self.train, self.valid = train, valid

    @classmethod
    def _get_dls(cls, train_ds, valid_ds, bs, collate_fn, **kwargs):
        '''Build a (train, valid) tuple of PyTorch DataLoaders from two datasets.

        The training loader shuffles and uses batch size `bs`; the validation loader
        does not set `shuffle` and uses batch size `bs*2`. **kwargs are forwarded to both DataLoaders.'''
        train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
        valid_dl = DataLoader(valid_ds, batch_size=bs*2, collate_fn=collate_fn, **kwargs)
        return (train_dl, valid_dl)

    @classmethod
    def from_hf_dd(cls, dd, batch_size, collate_fn=hf_ds_collate_fn, **kwargs):
        '''Factory method to create a DataLoaders object from a Huggingface Dataset dict.

        The values of `dd` are passed positionally as the train and valid datasets.
        Uses the `hf_ds_collate_fn` collation function by default; **kwargs are passed to the DataLoaders.'''
        loaders = cls._get_dls(*dd.values(), batch_size, collate_fn, **kwargs)
        return cls(*loaders)

In [None]:
#|hide
from nbdev.showdoc import *

In [None]:
show_doc(DataLoaders.from_hf_dd)

---

[source](https://github.com/lucasvw/nntrain/blob/main/nntrain/dataloaders.py#L49){target="_blank" style="float:right; font-size:smaller"}

### DataLoaders.from_hf_dd

>      DataLoaders.from_hf_dd (dd, batch_size, collate_fn=<function
>                              hf_ds_collate_fn>, **kwargs)

Factory method to create a Dataloaders object for a Huggingface Dataset dict,
uses the `hf_ds_collate_func` collation function by default, **kwargs are passes to the DataLoaders

Example usage:

In [None]:
from datasets import load_dataset,load_dataset_builder
import torchvision.transforms.functional as TF
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Dataset identifier handed to the Hugging Face `datasets` loaders below
name = "fashion_mnist"
# Builder object for the dataset; not referenced again in the visible part of this notebook
ds_builder = load_dataset_builder(name)
# The loaded dataset; later passed to `DataLoaders.from_hf_dd`
ds_hf = load_dataset(name)

Downloading builder script:   0%|          | 0.00/2.00k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

Downloading and preparing dataset fashion_mnist/fashion_mnist (download: 29.45 MiB, generated: 34.84 MiB, post-processed: Unknown size, total: 64.29 MiB) to /root/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1...


Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/26.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/29.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.42M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.15k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset fashion_mnist downloaded and prepared to /root/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
def accuracy(preds, targs):
    '''Fraction of rows in `preds` whose argmax along dim 1 equals the matching entry of `targs`.'''
    predicted = preds.argmax(dim=1)
    hits = (predicted == targs)
    return hits.float().mean()

def fit(epochs):
    '''Run `epochs` rounds of training followed by validation, printing metrics after each round.

    Uses the module-level `model`, `opt`, `loss_func`, `dls` and `accuracy`.
    Per-batch losses and accuracies are weighted by the batch length before averaging,
    so the printed numbers are per-sample means (assuming `loss_func` returns a batch mean).
    '''
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        seen_train, train_loss_sum = 0, 0
        for xb, yb in dls.train:
            batch_len = len(xb)
            loss = loss_func(model(xb), yb)
            loss.backward()
            opt.step()
            opt.zero_grad()

            seen_train += batch_len
            train_loss_sum += loss.item() * batch_len

        # ---- validation pass (no gradient tracking) ----
        model.eval()
        seen_valid, valid_loss_sum, acc_sum = 0, 0, 0
        for xb, yb in dls.valid:
            batch_len = len(xb)
            with torch.no_grad():
                preds = model(xb)
                valid_loss_sum += loss_func(preds, yb).item() * batch_len
                acc_sum += accuracy(preds, yb) * batch_len
            seen_valid += batch_len

        # These names are echoed verbatim by the `{name=}` f-string below
        train_loss = train_loss_sum / seen_train
        valid_loss = valid_loss_sum / seen_valid
        acc = acc_sum / seen_valid
        print(f'{epoch=} | {train_loss=:.3f} | {valid_loss=:.3f} | {acc=:.3f}')

def get_model_opt():
    '''Create a Linear -> ReLU -> Linear model and an SGD optimizer over its parameters.

    Layer sizes and learning rate come from the module-level `n_in`, `n_h`, `n_out` and `lr`.
    Returns the tuple `(model, opt)`.
    '''
    model = nn.Sequential(
        nn.Linear(n_in, n_h),
        nn.ReLU(),
        nn.Linear(n_h, n_out),
    )
    return model, torch.optim.SGD(model.parameters(), lr)

# Model dimensions read by get_model_opt(): input size, hidden size, output size
# (n_in presumably matches a flattened 28x28 image - confirm against the dataset)
n_in, n_h, n_out = 28*28, 50, 10

# Optimisation settings: learning rate for SGD and training batch size
lr, bs = 0.01, 1024
loss_func = F.cross_entropy

# Build the model/optimizer first, then the dataloaders, then train for one epoch
model, opt = get_model_opt()
dls = DataLoaders.from_hf_dd(ds_hf, bs)
fit(1)

epoch=0 | train_loss=2.185 | valid_loss=2.070 | acc=0.407


In [None]:
#|hide
import nbdev; nbdev.nbdev_export()