# 🔥 LightlyTrain - Image Classification with DINOv3 🔥

This notebook demonstrates how to use LightlyTrain for image classification with a state-of-the-art DINOv3 backbone.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lightly-ai/lightly-train/blob/main/examples/notebooks/image_classification.ipynb)

> **Important**: When running on Google Colab, make sure to select a GPU runtime for faster training. You can do this by going to `Runtime` > `Change runtime type` and selecting a GPU hardware accelerator.

## Installation

LightlyTrain can be installed directly via `pip`:

In [None]:
!pip install lightly-train

> **Important**: LightlyTrain is officially supported on
> - Linux: CPU or CUDA
> - macOS: CPU only
> - Windows (experimental): CPU or CUDA
>
> Support for MPS on macOS is planned.
>
> See the [installation instructions](https://docs.lightly.ai/train/stable/installation.html) for more details.
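Before training, it can help to confirm which platform you are on, that the package installed correctly, and whether a CUDA GPU is visible. The snippet below is a minimal sketch using only the standard library and PyTorch; `environment_report` is a hypothetical helper name, and it assumes the installed distribution is named `lightly-train` as in the `pip` command above.

```python
import importlib.metadata
import platform


def environment_report() -> dict:
    """Collect basic facts about the runtime before training."""
    report = {"python": platform.python_version(), "os": platform.system()}
    try:
        report["lightly-train"] = importlib.metadata.version("lightly-train")
    except importlib.metadata.PackageNotFoundError:
        report["lightly-train"] = None  # Package is not installed here.
    try:
        import torch

        report["cuda_available"] = torch.cuda.is_available()
    except ImportError:
        report["cuda_available"] = None  # PyTorch is not installed.
    return report


print(environment_report())
```

On a Colab GPU runtime, `cuda_available` should be `True`; if it is `False`, the runtime type is probably still set to CPU.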

## Train image classification model

Training your own classification model is straightforward with LightlyTrain.

### Download dataset

First, download a dataset. The dataset must be organized in class folders; see the [documentation](https://docs.lightly.ai/train/stable/image_classification.html#data) for more details.

In [None]:
!git clone https://github.com/alexeygrigorev/clothing-dataset-small

The dataset looks like this after the download completes:
```
clothing-dataset-small
├── train
│   ├── dress
│   │   ├── 009b3c31-fb62-45c0-be9a-37a5c238cb88.jpg
│   │   ├── 041c6bde-e737-46fd-9586-984c1503941f.jpg
│   │   └── ...
│   ├── hat
│   │   ├── 00d94e21-5891-492e-be0e-792e7338c077.jpg
│   │   └── ...
│   └── ...
├── validation
│   ├── dress
│   │   ├── 07cddef1-1fc8-47e4-a28a-613e60912590.jpg
│   │   └── ...
│   └── ...
└── test
    ├── dress
    │   ├── 06a00c0f-5f9a-410d-a7da-3881a9df3a71.jpg
    │   └── ...
    └── ...
```
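To check that the download matches this class-folder layout, you can count the images in each class folder of every split. This is a minimal sketch with `pathlib`; the helper name `count_images_per_class` and the set of image extensions are assumptions, not part of LightlyTrain.

```python
from __future__ import annotations

from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # Assumed set of image extensions.


def count_images_per_class(split_dir: str | Path) -> dict[str, int]:
    """Count image files in each class subfolder of one dataset split."""
    counts = {}
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


dataset = Path("clothing-dataset-small")
if dataset.exists():  # Only runs once the repository has been cloned.
    counts = {s: count_images_per_class(dataset / s) for s in ("train", "validation", "test")}
    for split, per_class in counts.items():
        print(split, per_class)
    # Every split should contain the same set of class folders.
    assert counts["train"].keys() == counts["validation"].keys() == counts["test"].keys()
```

A class with very few images in a split is worth spotting early, since it makes the validation metrics for that class noisy.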

Next, start the training with the `train_image_classification` function. You only need to specify the output directory, the model, and the input data; LightlyTrain sets the remaining training parameters and image augmentations automatically. You can still customize these settings if needed. In the cell below, `steps` is reduced to keep the demonstration short.

In [None]:
import lightly_train

lightly_train.train_image_classification(
    out="out/my_experiment",
    model="dinov3/vitt16",
    steps=1000,  # Small number of steps for demonstration, default is 100_000.
    batch_size=16,
    data={
        "train": "clothing-dataset-small/train",
        "val": "clothing-dataset-small/validation",
        "classes": {
            0: "dress",
            1: "hat",
            2: "longsleeve",
            3: "outwear",
            4: "pants",
            5: "shirt",
            6: "shoes",
            7: "shorts",
            8: "skirt",
            9: "t-shirt",
        },
    },
)
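The `classes` dictionary above lists the class folders in alphabetical order. Instead of typing it by hand, you could build it from the folder names. The sketch below is a hypothetical helper (`classes_from_folders`, not a LightlyTrain function); it assumes the folder names are the class names and that consecutive integer IDs in alphabetical order are acceptable. Under those assumptions it reproduces the mapping used above for this dataset.

```python
from __future__ import annotations

from pathlib import Path


def classes_from_folders(train_dir: str | Path, val_dir: str | Path) -> dict[int, str]:
    """Build a {class_id: class_name} mapping from the class folder names."""
    train_names = sorted(p.name for p in Path(train_dir).iterdir() if p.is_dir())
    val_names = sorted(p.name for p in Path(val_dir).iterdir() if p.is_dir())
    # Train and validation splits must share exactly the same classes.
    if train_names != val_names:
        raise ValueError(f"Class folders differ: {train_names} vs {val_names}")
    return dict(enumerate(train_names))
```

For example, `classes_from_folders("clothing-dataset-small/train", "clothing-dataset-small/validation")` can be passed as the `"classes"` entry of `data`. Raising an error on mismatched splits avoids silently training with class IDs that mean different things in different splits.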

Once the training is complete, the output directory looks like this:
```
out/my_experiment
├── checkpoints
│   ├── best.ckpt
│   └── last.ckpt
├── events.out.tfevents.1764251158.ef9b159fe4b8.273.0
├── exported_models
│   ├── exported_best.pt
│   └── exported_last.pt
└── train.log
```
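The `events.out.tfevents.*` file follows the TensorBoard event-log naming scheme, so running `tensorboard --logdir out/my_experiment` should display the logged metrics if TensorBoard is installed. To locate the other files programmatically, the sketch below assumes the directory layout shown above; `find_run_artifacts` is a hypothetical helper, and missing files are reported as `None`.

```python
from __future__ import annotations

from pathlib import Path


def find_run_artifacts(out_dir: str | Path) -> dict:
    """Locate the files of a training run laid out as shown above."""
    out = Path(out_dir)

    def existing(path: Path) -> Path | None:
        # Return the path only if the file was actually written.
        return path if path.is_file() else None

    return {
        "best_exported": existing(out / "exported_models" / "exported_best.pt"),
        "last_exported": existing(out / "exported_models" / "exported_last.pt"),
        "best_checkpoint": existing(out / "checkpoints" / "best.ckpt"),
        "last_checkpoint": existing(out / "checkpoints" / "last.ckpt"),
        "event_files": sorted(out.glob("events.out.tfevents.*")),
        "log": existing(out / "train.log"),
    }
```

For example, `str(find_run_artifacts("out/my_experiment")["best_exported"])` gives the path used for loading the model.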

The best model is exported to `out/my_experiment/exported_models/exported_best.pt`. You can load it for inference like this:

In [None]:
# Load the model for inference
model = lightly_train.load_model("out/my_experiment/exported_models/exported_best.pt")

In [None]:
import matplotlib.pyplot as plt
from torchvision.io import read_image

# Run inference
image_path = (
    "clothing-dataset-small/validation/dress/07cddef1-1fc8-47e4-a28a-613e60912590.jpg"
)
results = model.predict(image_path, topk=1)

image = read_image(image_path)
plt.imshow(image.permute(1, 2, 0))
plt.title(
    f"Predicted class: {model.classes[results['labels'][0].item()]}\nConfidence: {results['scores'][0]:.1%}"
)
plt.axis("off")
plt.show()
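To go beyond a single image, you can estimate accuracy on the held-out `test` split. The sketch below is a hypothetical helper (`topk_accuracy`) that only uses what the single-image example above shows: it assumes `model.predict(path, topk=k)` returns a dict whose `"labels"` entry holds `k` class IDs, and that `model.classes` maps those IDs to the class folder names.

```python
from __future__ import annotations

from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # Assumed set of image extensions.


def topk_accuracy(model, split_dir: str | Path, k: int = 1) -> float:
    """Fraction of images whose class folder is among the top-k predictions."""
    correct = total = 0
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        for image_path in sorted(class_dir.iterdir()):
            if image_path.suffix.lower() not in IMAGE_SUFFIXES:
                continue  # Skip non-image files.
            results = model.predict(str(image_path), topk=k)
            # int() works for plain ints and for 0-d tensors alike.
            predicted = {model.classes[int(label)] for label in results["labels"]}
            correct += class_dir.name in predicted
            total += 1
    return correct / total if total else 0.0


# Example (runs one prediction per image, so it can be slow on CPU):
# print(f"Test top-1 accuracy: {topk_accuracy(model, 'clothing-dataset-small/test'):.1%}")
```

Calling `predict` once per image keeps the sketch simple; for large splits, batching would be faster.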

## Next Steps

* [Image Classification Documentation](https://docs.lightly.ai/train/stable/image_classification.html): If you want to learn more about image classification with LightlyTrain.
* [Quick Start Guide](https://docs.lightly.ai/train/stable/quick_start_object_detection.html): If you want to learn more about LightlyTrain in general.
* [DINOv2 Pretraining](https://docs.lightly.ai/train/stable/pretrain_distill/methods/dinov2.html): If you want to learn how to pretrain foundation models on your data.