# Cifar10

This plugin is part of `openpifpaf.contrib`. It demonstrates the plugin architecture.
There already is a nice dataset for CIFAR10 in `torchvision` and a related [PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).
The plugin adds a `DataModule` that uses this dataset.
Let's start with the setup for this notebook and register all available OpenPifPaf plugins:

In [None]:
%matplotlib inline

import openpifpaf
import torchvision

openpifpaf.plugins.register()
print(openpifpaf.plugins.REGISTERED)
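
Registering plugins lets each plugin add its own components, such as a `DataModule`, to OpenPifPaf. The snippet below is a minimal, self-contained sketch of that registry idea in plain Python. The `DATAMODULES` dict, the `register()` hook and the `Cifar10DataModule` class are hypothetical stand-ins for illustration, not OpenPifPaf's actual internals.

```python
# Hypothetical registry: maps a dataset name (as used by --dataset) to a class.
DATAMODULES = {}


class Cifar10DataModule:
    """Hypothetical stand-in for a plugin's DataModule."""

    def __init__(self, batch_size=1):
        self.batch_size = batch_size


def register():
    """Hypothetical plugin hook: add this plugin's components to the registry."""
    DATAMODULES['cifar10'] = Cifar10DataModule


def register_all(register_hooks):
    """Call the register() hook of every discovered plugin."""
    for hook in register_hooks:
        hook()


register_all([register])
# A CLI flag such as --dataset=cifar10 can then be resolved by name.
dm = DATAMODULES['cifar10'](batch_size=16)
print(sorted(DATAMODULES), dm.batch_size)  # → ['cifar10'] 16
```

The benefit of this design is that the core library never imports a plugin by name; it only looks up whatever the plugins have put into the registry.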

Next, we configure and instantiate the Cifar10 datamodule and look at the configured head metas:

In [None]:
# configure 
openpifpaf.contrib.cifar10.datamodule.Cifar10.debug = True 
openpifpaf.contrib.cifar10.datamodule.Cifar10.batch_size = 1

# instantiate and inspect
datamodule = openpifpaf.contrib.cifar10.datamodule.Cifar10()
datamodule.head_metas

We see here that CIFAR10 is being treated as a detection dataset (`CifDet`) and has 10 categories.
To create a network, we use `openpifpaf.network.factory()`, which takes the name of the base network, `cifar10net`, and the list of head metas.

In [None]:
# factory() returns a tuple; the first element is the network
net, _ = openpifpaf.network.factory(base_name='cifar10net', head_metas=datamodule.head_metas)
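
The `CifDet` head has one field per category. Assuming the categories follow the class order of torchvision's `CIFAR10` dataset (its `classes` attribute), a zero-based field index maps to a class name as sketched below. This is the mapping behind the "category 9 = truck" comment in the visualization cell; compare it with the printed `head_metas` to confirm the assumption.

```python
# Class order of torchvision.datasets.CIFAR10 (its `classes` attribute).
# Assumption: the CifDet head meta lists its categories in the same order.
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]


def category_name(field_index):
    """Map a zero-based field index of the CifDet head to a class name."""
    return CIFAR10_CLASSES[field_index]


print(len(CIFAR10_CLASSES), category_name(9))  # → 10 truck
```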

We can inspect the training data that is returned from `datamodule.train_loader()`:

In [None]:
# configure visualization
openpifpaf.visualizer.Base.all_indices = [('cifdet', 9)]  # category 9 = truck
openpifpaf.visualizer.CifDet.show_regressions = True

# Create a wrapper for a data loader that iterates over a set of matplotlib axes.
# Its only purpose is to set a different matplotlib axis before each call that
# retrieves the next image from the data loader, so that the debug images are
# drawn side-by-side on one canvas.
def loop_over_axes(axs, data_loader):
    previous_common_ax = openpifpaf.visualizer.Base.common_ax
    data_iter = iter(data_loader)
    try:
        for ax in axs.reshape(-1):
            openpifpaf.visualizer.Base.common_ax = ax
            yield next(data_iter)
    finally:
        # restore the previous axis even if the loop is interrupted
        openpifpaf.visualizer.Base.common_ax = previous_common_ax

# create a canvas and loop over the first few entries in the training data
with openpifpaf.show.canvas(ncols=6, nrows=3, figsize=(10, 5)) as axs:
    for images, targets, meta in loop_over_axes(axs, datamodule.train_loader()):
        # print([t.shape for t in targets])
        pass
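
The `loop_over_axes()` helper relies on a general pattern: a generator that changes a piece of shared state before handing out each item and restores it afterwards. Below is a stripped-down sketch with a hypothetical `Canvas` class standing in for `openpifpaf.visualizer.Base`. The `try`/`finally` restores the state on normal exhaustion, on an early `return`, and when the generator is closed. The explicit `StopIteration` check matters because, since PEP 479, a bare `next()` that runs out inside a generator surfaces as a `RuntimeError`.

```python
class Canvas:
    """Hypothetical stand-in for a class with a shared, class-level axis."""
    common_ax = None


def loop_with_state(targets, items):
    """Yield items, setting Canvas.common_ax to the next target before each one."""
    previous = Canvas.common_ax
    items_iter = iter(items)
    try:
        for target in targets:
            Canvas.common_ax = target
            try:
                item = next(items_iter)
            except StopIteration:
                return  # fewer items than targets: stop cleanly
            yield item
    finally:
        # Runs on exhaustion, on early return, and when the generator is closed.
        Canvas.common_ax = previous


seen = []
for value in loop_with_state(['ax0', 'ax1', 'ax2'], range(10)):
    seen.append((Canvas.common_ax, value))

print(seen)              # → [('ax0', 0), ('ax1', 1), ('ax2', 2)]
print(Canvas.common_ax)  # → None
```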

## Training

We train a very small network, `cifar10net`, for only one epoch. Afterwards, we will investigate its predictions.

In [None]:
!python -m openpifpaf.train --dataset=cifar10 --basenet=cifar10net --epochs=1 --log-interval=500 --lr-warm-up-epochs=0.1 --lr=3e-3 --batch-size=16 --loader-workers=2 --output=cifar10_tutorial.pkl
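
To get a feeling for what these flags mean, here is a back-of-the-envelope count of optimizer steps. It assumes the datamodule trains on the 50,000 images of the standard CIFAR10 training split and that `--log-interval` and `--lr-warm-up-epochs` are measured in batches and epochs respectively; both are assumptions, not read from the OpenPifPaf source.

```python
import math

# Assumption: standard CIFAR10 training split size.
train_images = 50_000
batch_size = 16        # --batch-size=16
epochs = 1             # --epochs=1
log_interval = 500     # --log-interval=500 (assumed to count batches)
warm_up_epochs = 0.1   # --lr-warm-up-epochs=0.1

steps_per_epoch = math.ceil(train_images / batch_size)
total_steps = steps_per_epoch * epochs
warm_up_steps = math.ceil(warm_up_epochs * steps_per_epoch)
log_lines = total_steps // log_interval

print(steps_per_epoch, warm_up_steps, log_lines)  # → 3125 313 6
```

Under these assumptions the learning rate is warmed up over roughly the first tenth of the single epoch, and the training log shows about six progress lines.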

## Prediction

First, using the command-line interface:

In [None]:
!python -m openpifpaf.predict --checkpoint cifar10_tutorial.pkl.epoch001 images/cifar10_*.png --seed-threshold=0.1 --json-output .
!cat cifar10_*.json
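
The JSON files can also be read back in Python. This sketch assumes each `cifar10_*.json` file holds a list of prediction objects with at least `category` and `score` keys; check the field names against the `cat` output above before relying on it.

```python
import glob
import json


def top_prediction(path):
    """Return (category, score) of the highest-scoring prediction in a file, or None."""
    with open(path) as f:
        predictions = json.load(f)  # assumed: a list of dicts with 'category' and 'score'
    if not predictions:
        return None
    best = max(predictions, key=lambda p: p['score'])
    return best['category'], best['score']


for path in sorted(glob.glob('cifar10_*.json')):
    print(path, top_prediction(path))
```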

Then, using the Python API:

In [None]:
net_cpu, _ = openpifpaf.network.factory(checkpoint='cifar10_tutorial.pkl.epoch001')
preprocess = openpifpaf.transforms.Compose([
    openpifpaf.transforms.NormalizeAnnotations(),
    openpifpaf.transforms.CenterPadTight(16),
    openpifpaf.transforms.EVAL_TRANSFORM,
])

openpifpaf.decoder.CifDetSeeds.threshold = 0.1
decode = openpifpaf.decoder.factory([hn.meta for hn in net_cpu.head_nets])

data = openpifpaf.datasets.ImageList([
    'images/cifar10_airplane4.png',
    'images/cifar10_automobile10.png',
    'images/cifar10_ship7.png',
    'images/cifar10_truck8.png',
], preprocess=preprocess)
for image, _, meta in data:
    predictions = decode.batch(net_cpu, image.unsqueeze(0))[0]
    print(['{} {:.0%}'.format(pred.category, pred.score) for pred in predictions])
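
The `CenterPadTight(16)` step pads each image around its center before it enters the network. The sketch below computes the padded size and offsets under our assumption that it rounds each side up to the form `16 * n + 1` and centers the original image inside the padded canvas. This mirrors our reading of the OpenPifPaf source but should be verified against your installed version.

```python
import math


def center_pad_tight(width, height, multiple=16):
    """Assumed CenterPadTight arithmetic: pad each side to multiple * n + 1, centered.

    Returns (target_width, target_height, left, top), where left/top are the
    offsets of the original image inside the padded canvas.
    """
    target_width = math.ceil((width - 1) / multiple) * multiple + 1
    target_height = math.ceil((height - 1) / multiple) * multiple + 1
    left = (target_width - width) // 2
    top = (target_height - height) // 2
    return target_width, target_height, left, top


# CIFAR10 images are 32x32 pixels.
print(center_pad_tight(32, 32))  # → (33, 33, 0, 0)
```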

## Evaluation

TODO