---
title: "Image Classification: From Training to Inference within the PyTorch ecosystem (timm)"
description: "In this tutorial we are going to learn how to train an image classification model and use it for inference. We are going to explore timm models, with Kornia augmentations, on a dataset from the Hugging Face Hub."
author:
    - "João G. A. Amorim"
date: 02-10-2024
categories:
    - Basic
    - Training
    - Inference
    - kornia.augmentation
image: "../tutorials/assets/image_classification.png"
---

<a href="https://colab.sandbox.google.com/github/kornia/tutorials/blob/master/nbs/image_classification.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in google colab"></a>

In [None]:
%%capture
!pip install kornia
!pip install kornia-rs
!pip install timm
!pip install datasets[vision]

## Setup

In [None]:
from pprint import pprint
import random

import kornia
import numpy as np
import timm
from datasets import Image, load_dataset

## Exploring timm models

> [PyTorch Image Models (timm)](https://github.com/huggingface/pytorch-image-models) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results.

Let's list all the models that are available with pre-trained weights.

In [None]:
model_names_pretrained = timm.list_models(pretrained=True)
model_names = timm.list_models()

k = 15

print(
    f"timm has {len(model_names_pretrained)} pretrained models ({len(model_names)} models in total). Here are some of them:"
)
# random.sample draws without replacement, so no model is listed twice
for m in random.sample(model_names_pretrained, k=k):
    print(f"\t-> `{m}`")
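
`timm.list_models` also accepts a wildcard pattern as its first argument (for example `timm.list_models("resnet*", pretrained=True)`), which narrows the listing to one model family. The sketch below imitates that kind of filtering with Python's `fnmatch` module on a small, hand-written list of names, so it runs without timm installed. The list and the `filter_models` helper are illustrative assumptions, not part of timm's API.

```python
from fnmatch import fnmatch

# Hand-written sample of model names (illustrative only; the real list
# comes from timm.list_models()).
SAMPLE_MODEL_NAMES = [
    "convnext_tiny",
    "efficientnet_b0",
    "resnet18",
    "resnet50",
    "vit_base_patch16_224",
]


def filter_models(names, pattern):
    """Return the names matching a shell-style wildcard pattern, sorted."""
    return sorted(name for name in names if fnmatch(name, pattern))


resnets = filter_models(SAMPLE_MODEL_NAMES, "resnet*")
print(resnets)
```

Filtering by family first and sampling afterwards keeps the printed list focused on the architectures you actually want to compare.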

## Loading the dataset and creating the dataloader

### Datasets

For datasets, we will use the [Hugging Face `datasets` library](https://github.com/huggingface/datasets). You can browse the available image classification datasets on the [🤗 Hub](https://huggingface.co/datasets?task_categories=task_categories:image-classification&sort=trending).

For this example, let's use the [beans dataset](https://huggingface.co/datasets/beans).

> Beans leaf dataset with images of diseased and healthy leaves. Based on a leaf image, the goal of this task is to predict the disease type (Angular Leaf Spot and Bean Rust), if any.


> Within the 🤗 `datasets` library, your data can be stored in various places; they can be on your local machine’s disk, in a GitHub repository, and in in-memory data structures like Python dictionaries and Pandas DataFrames.

Let's load the beans dataset from the 🤗 Hub using the 🤗 `datasets` library. We will use `load_dataset`, whose `split` parameter selects which part of the dataset to load (usually `train`, `validation` and `test` are available). We will also pass `streaming=True`, which lets us work with the dataset without downloading all of it up front: the data is downloaded as you iterate over the dataset. The `cache_dir` parameter controls where the `datasets` library reads and writes data on your disk; here we use a temporary directory. You can read more about this function in the official docs: https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset


In [None]:
data_directory = "/tmp/hf/datasets/"
dataset = load_dataset("beans", split="train", streaming=True, cache_dir=data_directory)
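
To build intuition for what `streaming=True` means, here is a minimal sketch of lazy iteration using a plain Python generator. It is not how the `datasets` library is implemented; `fake_stream` and the `fetched` list are hypothetical names that stand in for a remote dataset and its download log.

```python
fetched = []  # records which items were "downloaded" (hypothetical log)


def fake_stream(num_items):
    """Yield samples one at a time, fetching each only when requested."""
    for index in range(num_items):
        fetched.append(index)  # stands in for a network download
        yield {"image": f"image_{index}", "labels": index % 3}


stream = fake_stream(1000)
assert fetched == []  # nothing is fetched until we start iterating

first = next(stream)
second = next(stream)
print(first, second, fetched)
```

Only the two requested items are fetched, even though the stream could produce a thousand. This is the property that makes streaming convenient for a quick look at a large dataset.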

Let's retrieve the dataset information from the Hub and check its features.

In [None]:
pprint(dataset.info)

In [None]:
pprint(dataset.features)

By default, for image datasets, the library automatically decodes (reads) the image files using [Pillow](https://pypi.org/project/pillow/). If desired, you can disable this decoding.

With the `decode` option disabled, instead of returning a PIL Image object, the dataset returns the path of the image and its bytes (as a dictionary with the keys `path` and `bytes`).

At the moment (kornia v0.7.1), the `io` module, powered by [kornia-rs](https://github.com/kornia/kornia-rs), does not support reading data from byte strings. So we keep the default decode option, but you can see below how to disable it:

In [None]:
dataset = dataset.cast_column("image", Image(decode=False))

In [None]:
pprint(dataset.features)
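
With decoding disabled, each `image` entry is a dictionary holding `path` and `bytes`. The sketch below shows one way to inspect such an entry using only the standard library: it checks the leading signature bytes to tell JPEG from PNG. The `sample` dictionary is hand-made for illustration (its path and payload are invented), and `sniff_format` is a hypothetical helper, not a `datasets` or kornia function.

```python
from typing import Optional

JPEG_SIGNATURE = b"\xff\xd8\xff"
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def sniff_format(data: Optional[bytes]) -> str:
    """Guess the image format from the leading signature bytes."""
    if data is None:  # an entry may carry only a path and no bytes
        return "unknown"
    if data.startswith(JPEG_SIGNATURE):
        return "jpeg"
    if data.startswith(PNG_SIGNATURE):
        return "png"
    return "unknown"


# Hand-made stand-in for an undecoded sample (not real image data).
sample = {"image": {"path": "example.jpg", "bytes": JPEG_SIGNATURE + b"\xe0" + b"\x00" * 16}}

image_entry = sample["image"]
print(image_entry["path"], sniff_format(image_entry["bytes"]))
```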

In [None]:
dataset = dataset.cast_column("image", Image(decode=True))

Let's check two examples:

In [None]:
n = 2
device = "cpu"

for counter, sample in enumerate(dataset):
    # Stop once `n` examples have been shown
    if counter >= n:
        break

    print(f"dataset item {counter}.")
    print(f'-> `image`: {sample["image"]}')
    print(f'-> `labels`: {sample["labels"]}')

    # HxWxC uint8 array -> 1xCxHxW float tensor, so the per-channel statistics over dims (2, 3) are defined
    image = kornia.image_to_tensor(np.array(sample["image"]), keepdim=False).float()
    print(f"-> Image: Shape={image.shape} | Mean={image.mean(dim=(2, 3))} | Std={image.std(dim=(2, 3))}")
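
The `dim=(2, 3)` reduction in the loop averages over the height and width axes, leaving one value per channel. As a worked example of that arithmetic, the sketch below computes the per-channel mean and standard deviation of a tiny hand-made image (2 channels, 2x2 pixels each) using only the standard library. It uses `statistics.stdev`, the sample standard deviation, which is also what `torch.Tensor.std` computes by default. The `tiny_image` values and the `channel_stats` helper are invented for illustration.

```python
from statistics import mean, stdev

# Hand-made image in channels-first layout: 2 channels, each 2x2 pixels.
tiny_image = [
    [[0.0, 2.0], [4.0, 6.0]],      # channel 0
    [[10.0, 10.0], [10.0, 10.0]],  # channel 1 (constant)
]


def channel_stats(image):
    """Return a (mean, std) pair per channel, reducing over height and width."""
    stats = []
    for channel in image:
        # Flatten the HxW grid into a single list of pixel values
        pixels = [value for row in channel for value in row]
        stats.append((mean(pixels), stdev(pixels)))
    return stats


stats = channel_stats(tiny_image)
print(stats)
```

A constant channel has zero standard deviation, which is a quick sanity check when inspecting the statistics of real images.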