# CIFAR10: TrainDataLoaderIter and ValDataLoaderIter

This example demonstrates how to use `TrainDataLoaderIter` and `ValDataLoaderIter` with a ResNet56 on the CIFAR10 dataset.


`LRFinder.range_test()` assumes that the `DataLoader` objects passed in yield batches of the form `input, label, *`, where `input` is the tensor passed to the `model` and `label` is the tensor passed to the `criterion`. If that is not the case for your application, you'll probably need to use `TrainDataLoaderIter` and `ValDataLoaderIter`. These two classes wrap the `DataLoader` for the training set and validation set, respectively, and let you customize how the inputs and labels are extracted from each batch.
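
As a rough illustration of that contract, the sketch below shows how a batch of the form `input, label, *` can be unpacked, and why a `dict` batch fails. `default_unpack` is a hypothetical stand-in written for this example, not the library's actual code, and plain lists stand in for tensors.

```python
def default_unpack(batch):
    """Hypothetical stand-in for the default batch handling (not library code)."""
    # Only sequences shaped like (input, label, *extras) are understood.
    if not isinstance(batch, (list, tuple)):
        raise ValueError("Unsupported batch type: {}".format(type(batch)))
    inputs, labels, *_ = batch  # anything after the first two items is ignored
    return inputs, labels


# Plain lists stand in for tensors here.
tuple_batch = ([[0.1, 0.2], [0.3, 0.4]], [1, 0], "extra metadata")
inputs, labels = default_unpack(tuple_batch)

# A dict batch is rejected, which is the situation this notebook sets up.
try:
    default_unpack({"img": [[0.1, 0.2]], "target": [1]})
    dict_supported = True
except ValueError:
    dict_supported = False
```

Anything that is not a list or tuple needs a custom rule for finding the inputs and labels, which is what the two iterator classes let you supply.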


Common use-cases for `TrainDataLoaderIter` and `ValDataLoaderIter`:
1. Your `DataLoader`/`Dataset.__getitem__()` returns a `dict` or another container that differs from the form above (`input, label, *`).
2. Your `Dataset.__getitem__()` doesn't perform all the data processing/handling/conversion needed for `input` and `label` to be ready to be passed to the `model` and `criterion`, respectively.
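
A minimal sketch of use-case 2, assuming a hypothetical batch whose labels arrive as class-name strings and must be mapped to integer indices before they reach the criterion. The function body is what an overridden `inputs_labels_from_batch()` could contain; `CLASS_TO_INDEX` and the `"image"` and `"class_name"` keys are made up for illustration, and plain lists stand in for tensors.

```python
# Hypothetical label vocabulary, made up for this example.
CLASS_TO_INDEX = {"airplane": 0, "automobile": 1, "bird": 2}


def inputs_labels_from_batch(batch_data):
    """Body an overridden method could use; written as a free function here."""
    inputs = batch_data["image"]
    # Convert string labels to the integer indices a criterion would expect.
    labels = [CLASS_TO_INDEX[name] for name in batch_data["class_name"]]
    return inputs, labels


batch = {"image": [[0.0, 1.0], [0.5, 0.5]], "class_name": ["bird", "airplane"]}
inputs, labels = inputs_labels_from_batch(batch)
```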

In [1]:
%matplotlib inline

import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import cifar10_resnet as rc10
from PIL import Image

try:
    from torch_lr_finder import LRFinder, TrainDataLoaderIter, ValDataLoaderIter
except ImportError:
    # Run from source
    import sys
    sys.path.insert(0, '..')
    from torch_lr_finder import LRFinder, TrainDataLoaderIter, ValDataLoaderIter

## Loading CIFAR10

To demonstrate how `TrainDataLoaderIter` and `ValDataLoaderIter` can be used, we'll create a new CIFAR10 dataset that instead of returning `img, target` like [the one from torchvision](https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py#L118), returns a dictionary. We are essentially provoking use-case 1 from above.

In [2]:
# All we need to do here is inherit from CIFAR10 and change its __getitem__()
# method to return a dictionary
class CIFAR10WithDict(CIFAR10):
    def __getitem__(self, index):
        # IMPORTANT: this code works with torchvision 0.6.0; it's not guaranteed to
        # work with other versions
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)
            
        ret_dict = {}
        ret_dict["img"] = img
        ret_dict["target"] = target

        return ret_dict

In [3]:
cifar_pwd = "../data"
batch_size = 256

In [4]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = CIFAR10WithDict(root=cifar_pwd, train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

testset = CIFAR10WithDict(root=cifar_pwd, train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size * 2, shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


## Model

In [5]:
model = rc10.resnet56()

## Training loss (fastai)



In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
lr_finder = LRFinder(model, optimizer, criterion, device="cuda")

Let's try to do a `range_test` and see what happens:

In [7]:
lr_finder.range_test(trainloader, end_lr=100, num_iter=100, step_mode="exp")

ValueError: Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.

The problem is that `LRFinder` doesn't know how to extract the inputs and labels from a batch that is a dictionary. This is where `TrainDataLoaderIter` comes in:

In [8]:
class CustomTrainIter(TrainDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        return batch_data["img"], batch_data["target"]
    
custom_train_iter = CustomTrainIter(trainloader)

`inputs_labels_from_batch()` receives each batch yielded by the `DataLoader`, which here is a dictionary built from the dictionaries returned by `CIFAR10WithDict.__getitem__()`, and must return a tuple `input, label` or, in this case, `img, target`.
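
Note that the method sees one batch rather than one sample: with default collation, a `DataLoader` over a dataset that returns dicts yields a dict whose values are batched per key. The sketch below imitates that behaviour with plain lists; `collate_dicts` is a simplified, hypothetical stand-in, not PyTorch's collate function.

```python
def collate_dicts(samples):
    """Simplified stand-in for default collation of dict samples (not PyTorch code)."""
    # Group the values of every key across the samples of the batch.
    return {key: [sample[key] for sample in samples] for key in samples[0]}


def inputs_labels_from_batch(batch_data):
    # Same body as the override above.
    return batch_data["img"], batch_data["target"]


# Two fake samples shaped like the output of CIFAR10WithDict.__getitem__().
samples = [
    {"img": [0.1, 0.2], "target": 3},
    {"img": [0.3, 0.4], "target": 7},
]
batch = collate_dicts(samples)
inputs, labels = inputs_labels_from_batch(batch)
```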

In [9]:
lr_finder.reset()
lr_finder.range_test(custom_train_iter, end_lr=100, num_iter=100, step_mode="exp")

Stopping early, the loss has diverged
Learning rate search finished. See the graph with {finder_name}.plot()


## Validation loss (Leslie N. Smith)

We find something similar when using the validation loss test. If we try to run `range_test` with `trainloader` and `testloader` we get an error:

In [10]:
lr_finder.reset()
lr_finder.range_test(trainloader, val_loader=testloader, end_lr=100, num_iter=100, step_mode="exp")

ValueError: Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.

Even if we replace `trainloader` with `custom_train_iter`, we still get the error, because `testloader` still yields `dict` batches:

In [11]:
lr_finder.reset()
lr_finder.range_test(custom_train_iter, val_loader=testloader, end_lr=100, num_iter=100, step_mode="exp")

ValueError: Your batch type not supported: <class 'dict'>. Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.

The solution is to wrap `testloader` in a `ValDataLoaderIter` subclass that overrides `inputs_labels_from_batch()`:

In [12]:
class CustomValIter(ValDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        return batch_data["img"], batch_data["target"]
    
custom_val_iter = CustomValIter(testloader)

In [13]:
lr_finder.reset()
lr_finder.range_test(custom_train_iter, val_loader=custom_val_iter, end_lr=100, num_iter=100, step_mode="exp")

Learning rate search finished. See the graph with {finder_name}.plot()
