## First exports

When exporting code to a module with `nbdev`, the first thing we need to do is declare the `default_exp` directive. This ensures that, when we run the export, the code is written to the module `dataloaders.py`.

In [None]:
#| default_exp dataloaders

In [None]:
#|export

import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader
import torch
import PIL

In [None]:
#|export

def hf_ds_collate_func(data):
    '''
    Collation function for building a PyTorch DataLoader from a huggingface dataset.

    `data` is the list of samples that make up one batch. Each sample must expose
    `.values()`; the values of a sample are its fields, taken in the order
    `.values()` yields them.

    Every field is converted to a tensor: PIL images go through `TF.to_tensor`
    and are flattened to 1-D with `.view(-1)`; anything else is handed to
    `torch.tensor`.

    Returns a tuple holding one stacked tensor per field, e.g. `(xb, yb)` for
    two-field samples. Any number of fields is supported.
    '''

    def to_tensor(i):
        if isinstance(i, PIL.Image.Image):
            return TF.to_tensor(i).view(-1)   # flatten the image tensor to 1-D
        return torch.tensor(i)

    # Convert every field of every sample, then transpose so that values
    # belonging to the same field end up grouped together:
    # [(x0, y0), (x1, y1), ...] -> (x0, x1, ...), (y0, y1, ...)
    per_field = zip(*(map(to_tensor, el.values()) for el in data))

    # Return a concrete tuple rather than a generator: a generator cannot be
    # pickled (needed when DataLoader worker processes hand batches back to the
    # main process) and it can only be iterated a single time.
    return tuple(torch.stack(i) for i in per_field)

In [None]:
#|export
class DataLoaders:
    '''
    Simple container that keeps a training and a validation DataLoader together
    as the attributes `train` and `valid`.
    '''

    def __init__(self, train, valid):
        self.train, self.valid = train, valid

    @classmethod
    def _get_dls(cls, train_ds, valid_ds, bs, collate_fn):
        '''
        Wrap two datasets in DataLoaders and return them as `(train_dl, valid_dl)`.
        The training loader shuffles and uses batch size `bs`; the validation
        loader is created without `shuffle` and uses batch size `bs*2`.
        Both loaders share the same `collate_fn`.
        '''
        train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn)
        valid_dl = DataLoader(valid_ds, batch_size=bs*2, collate_fn=collate_fn)
        return (train_dl, valid_dl)

    @classmethod
    def from_hf_dd(cls, dd, batch_size):
        '''
        Build a `DataLoaders` from `dd`, collating batches with `hf_ds_collate_func`.
        The values of `dd` are unpacked positionally into the train and valid
        dataset slots, so `dd` must hold exactly two datasets, training one first.
        '''
        loaders = cls._get_dls(*dd.values(), batch_size, hf_ds_collate_func)
        return cls(*loaders)

In [None]:
def fit(epochs):
    '''
    Run `epochs` rounds of training followed by validation and print the
    averaged metrics after every round.

    Relies on names resolved outside the function at call time:
    `model`, `opt`, `loss_func`, `accuracy` and `dls`.
    '''
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        seen_t = loss_sum_t = 0
        for xb, yb in dls.train:
            loss = loss_func(model(xb), yb)
            loss.backward()

            # weight the batch loss by the batch size so the epoch value is a per-sample average
            n_batch = len(xb)
            seen_t += n_batch
            loss_sum_t += loss.item() * n_batch

            opt.step()
            opt.zero_grad()

        # ---- validation pass ----
        model.eval()
        seen_v = loss_sum_v = acc_sum = 0
        for xb, yb in dls.valid:
            with torch.no_grad():   # no gradients are needed while evaluating
                out = model(xb)
                loss = loss_func(out, yb)

                n_batch = len(xb)
                seen_v += n_batch
                loss_sum_v += loss.item() * n_batch
                acc_sum += accuracy(out, yb) * n_batch

        # ---- per-sample averages for this epoch ----
        train_loss = loss_sum_t / seen_t
        valid_loss = loss_sum_v / seen_v
        acc = acc_sum / seen_v
        print(f'{epoch=} | {train_loss=:.3f} | {valid_loss=:.3f} | {acc=:.3f}')

In [None]:
# Build the train/valid DataLoaders from `ds_hf` using batch size `bs`
# (both names are defined outside this excerpt).
dls = DataLoaders.from_hf_dd(ds_hf, bs)

In [None]:
# `get_model_opt` is defined outside this excerpt; its result is unpacked into
# the `model` and `opt` globals that `fit` reads.
model, opt = get_model_opt()

# Run 5 epochs; `fit` prints one metrics line per epoch.
fit(5)

epoch=0 | train_loss=2.175 | valid_loss=2.050 | acc=0.406
epoch=1 | train_loss=1.917 | valid_loss=1.780 | acc=0.536
epoch=2 | train_loss=1.651 | valid_loss=1.532 | acc=0.616
epoch=3 | train_loss=1.427 | valid_loss=1.340 | acc=0.637
epoch=4 | train_loss=1.262 | valid_loss=1.203 | acc=0.648


In [None]:
# Run the nbdev export, which writes the cells marked `#|export` to the module
# named by the `default_exp` directive.
import nbdev

nbdev.nbdev_export()