# Birder - Getting Started

In this notebook we will explore some of the APIs provided by Birder.

Before we start, if you're running in Colab, make sure to install Birder first.
NumPy 2.0 and above are not yet supported on Colab, so you might have to downgrade NumPy as well.

In [None]:
# When running in Colab
# !pip install birder

# When running in a cloned repository (instead of pip installation)
# %cd ..

In [None]:
import birder
import torch
from birder.inference.classification import infer_image
from birder.results.gui import show_top_k

In [None]:
birder.__version__

## Exploring Models

Birder uses a systematic naming convention that helps you identify the key characteristics of each model. The naming pattern includes:

* Architecture prefix (e.g., xcit, resnext, mobilenet)
* Optional: Net parameter value indicating model configuration
* Optional: Training indicator tags (intermediate, mim)
* Optional: Geographical tags indicating data source (il-common, eu-all)
* Optional: Optimization tags (quantized, reparameterized)
* Optional: Epoch number
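To make the convention concrete, here is a minimal sketch that splits a model name into its parts. `describe_model_name` is a hypothetical helper, not part of Birder's API: it assumes the components are separated by underscores and recognizes only the example tags listed above.

```python
# Hypothetical helper for illustration only; not part of Birder's API.
# The tag vocabularies are just the examples from the list above (not exhaustive).
GEO_TAGS = {"il-common", "eu-all"}
TRAINING_TAGS = {"intermediate", "mim"}
OPTIMIZATION_TAGS = {"quantized", "reparameterized"}


def describe_model_name(name: str) -> dict:
    """Bucket the underscore-separated parts of a model name by known tag type."""
    parts = name.split("_")
    geo = [p for p in parts if p in GEO_TAGS]
    return {
        "architecture": parts[0],  # the prefix, e.g. "xcit"
        "geo": geo[0] if geo else None,
        "training": [p for p in parts if p in TRAINING_TAGS],
        "optimization": [p for p in parts if p in OPTIMIZATION_TAGS],
    }


print(describe_model_name("xcit_nano12_p16_il-common"))
# {'architecture': 'xcit', 'geo': 'il-common', 'training': [], 'optimization': []}
```

Real names may combine tags differently (for example, several tags joined by hyphens), so treat this as a reading aid rather than a parser.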

We can list all pretrained models, optionally narrowed down by a filter. The filter uses glob-style pattern matching, where '*' matches any sequence of characters.

Let's look at all models that were trained on the *il-common* dataset and load one of them.
The pattern below will match any XCiT model trained on the *il-common* dataset:
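Assuming the filter follows the same shell-style rules as Python's standard `fnmatch` module, the pattern can be tried offline against a list of names. The names below are illustrative placeholders, not necessarily actual Birder weights; the real list comes from `birder.list_pretrained_models()`.

```python
from fnmatch import fnmatchcase

# Illustrative placeholder names; the real list comes from birder.list_pretrained_models()
names = [
    "xcit_nano12_p16_il-common",
    "xcit_nano12_p16_eu-all",
    "mobilenet_v3_large_il-common",
]

# '*' matches any run of characters (including none), so the pattern requires
# an "xcit" prefix followed, anywhere later, by "il-common"
matches = [name for name in names if fnmatchcase(name, "xcit*il-common*")]
print(matches)  # ['xcit_nano12_p16_il-common']
```

`fnmatchcase` keeps the match case-sensitive on every platform, whereas plain `fnmatch` normalizes case on Windows.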

In [None]:
birder.list_pretrained_models("xcit*il-common*")

In [None]:
(net, model_info) = birder.load_pretrained_model("xcit_nano12_p16_il-common", inference=True)

# Get the image size the model was trained on
size = birder.get_size_from_signature(model_info.signature)

# Create an inference transform
transform = birder.classification_transform(size, model_info.rgb_stats)

## Inference

Now we shall fetch an example image (of a Eurasian teal) and try to classify it.
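As its name suggests, `show_top_k` presents the model's highest-scoring classes. The ranking step itself is simple. The sketch below assumes the model output is a vector of per-class scores indexed as in `class_to_idx`, and converts raw logits to probabilities with a softmax (skip that step if the scores are already probabilities). The class names and logits are made-up toy values, and `top_k` is a hypothetical helper, not Birder's function.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Convert raw scores to probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: list[float], class_to_idx: dict[str, int], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most probable (class name, probability) pairs, best first."""
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(idx_to_class[i], probs[i]) for i in ranked[:k]]


# Toy values for illustration only
class_to_idx = {"Eurasian teal": 0, "Mallard": 1, "Garganey": 2}
probs = softmax([2.0, 0.5, 1.0])
for name, p in top_k(probs, class_to_idx, k=2):
    print(f"{name}: {p:.3f}")
```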

In [None]:
# Create the data directory (-p avoids an error if it already exists)
!mkdir -p data

In [None]:
image_path = "data/example.jpeg"
birder.common.cli.download_file("https://huggingface.co/spaces/birder-project/birder-image-classification/resolve/main/Eurasian%20teal.jpeg", image_path)

In [None]:
(out, _) = infer_image(net, image_path, transform)
show_top_k(image_path, out.squeeze(), model_info.class_to_idx, "Eurasian teal")

In [None]:
# The model failed to classify the image correctly, probably because the bird is small
# relative to the frame and we are using a compact, low-resolution model.
#
# We will try again using an aggressive center crop.
transform = birder.classification_transform(size, model_info.rgb_stats, center_crop=0.5)
(out, _) = infer_image(net, image_path, transform)
show_top_k(image_path, out.squeeze(), model_info.class_to_idx, "Eurasian teal")

## Fine-tuning

We shall now fine-tune the model on an example dataset.

For this example we will use the Caltech-UCSD Birds-200-2011 dataset - <https://authors.library.caltech.edu/records/cvm3y-5hh21>.

It contains about 12K images of 200 bird species.
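The full fine-tuning run below sets `smoothing_alpha=0.1`. Assuming this is the ε of standard label smoothing (Birder's exact formulation is an assumption here), each one-hot target is mixed with a uniform distribution over the 200 classes. A minimal sketch, with `smooth_labels` as a hypothetical helper:

```python
import math


def smooth_labels(target: int, num_classes: int, alpha: float = 0.1) -> list[float]:
    """Standard label smoothing: (1 - alpha) * one_hot + alpha / num_classes."""
    off_value = alpha / num_classes
    return [(1.0 - alpha + off_value) if i == target else off_value for i in range(num_classes)]


dist = smooth_labels(target=0, num_classes=200, alpha=0.1)
assert math.isclose(sum(dist), 1.0)  # still a valid probability distribution
print(f"target: {dist[0]:.4f}, every other class: {dist[1]:.4f}")
# target: 0.9005, every other class: 0.0005
```

With 200 classes the true class keeps about 90% of the probability mass, and the remaining 10% is spread thinly across all classes.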

We will first do simple linear probing (training only the classification head on a frozen body) and then full fine-tuning.
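Both runs use `lr_scheduler="cosine"`. The sketch below shows the textbook cosine annealing formula, which decays the learning rate from its base value to a floor (compare `lr_cosine_min` in the full fine-tuning run). It is a per-epoch approximation that assumes no warmup; Birder may step the schedule per iteration or add warmup, so treat the numbers as indicative only.

```python
import math


def cosine_lr(epoch: float, total_epochs: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine annealing: min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * epoch / total_epochs))."""
    progress = epoch / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Linear probing settings from this notebook: lr=0.1 over 5 epochs
# (a floor of 0 is assumed here, since that run does not set lr_cosine_min)
for epoch in range(6):
    print(epoch, f"{cosine_lr(epoch, total_epochs=5, base_lr=0.1):.4f}")
```

The schedule spends more time near the high and low ends than a linear decay would, which is the usual motivation for choosing it.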

In [None]:
from birder.datahub.classification import CUB_200_2011
from birder.scripts import train

In [None]:
training_dataset = CUB_200_2011(download=True, split="training")  # Will download all splits
validation_dataset = CUB_200_2011(split="validation")

In [None]:
# Linear probing
args = train.args_from_dict(
    network="xcit_nano12_p16",
    pretrained=True,  # Implies "reset head"
    freeze_body=True,
    tag="il-common",
    num_workers=2,
    lr=0.1,
    lr_scheduler="cosine",
    epochs=5,
    size=256,
    data_path=training_dataset.root,
    val_path=validation_dataset.root,
)

In [None]:
train.train(args)

In [None]:
# Full fine-tuning for 10 more epochs (resumes from the epoch 5 linear-probing checkpoint, ends at epoch 15)
args = train.args_from_dict(
    network="xcit_nano12_p16",
    tag="il-common",
    num_workers=2,
    opt="adamw",
    lr=0.0001,
    lr_scheduler="cosine",
    lr_cosine_min=1e-7,
    epochs=15,
    resume_epoch=5,
    size=256,
    wd=0.05,
    norm_wd=0,
    grad_accum_steps=2,
    smoothing_alpha=0.1,
    mixup_alpha=0.2,
    cutmix=True,
    aug_level=4,
    clip_grad_norm=1,
    fast_matmul=True,
    # compile=True,
    data_path=training_dataset.root,
    val_path=validation_dataset.root,
)

In [None]:
train.train(args)

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
# Examine the training
%tensorboard --logdir runs

## Evaluation
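The results object reports per-class metrics such as the `f1-score` used for sorting further down. For reference, here is a minimal sketch of the standard per-class precision, recall, and F1 computed from true and predicted labels. `per_class_scores` is a hypothetical helper, not Birder's implementation, and the labels are toy values.

```python
from collections import Counter


def per_class_scores(y_true: list[str], y_pred: list[str]) -> dict[str, dict[str, float]]:
    """Standard per-class precision, recall and F1 from paired label lists."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for actual, predicted in zip(y_true, y_pred):
        if actual == predicted:
            tp[actual] += 1
        else:
            fp[predicted] += 1  # the predicted class gained a false positive
            fn[actual] += 1  # the actual class missed a sample

    scores = {}
    for c in set(y_true) | set(y_pred):
        precision = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        recall = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[c] = {"precision": precision, "recall": recall, "f1-score": f1}
    return scores


y_true = ["teal", "teal", "mallard", "mallard"]
y_pred = ["teal", "mallard", "mallard", "mallard"]
for name, s in sorted(per_class_scores(y_true, y_pred).items()):
    print(name, {k: round(v, 3) for k, v in s.items()})
```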

In [None]:
import polars as pl
from birder.common.fs_ops import load_model
from torch.utils.data import DataLoader

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
(net, model_info) = load_model(
    device, "xcit_nano12_p16", tag="il-common", epoch=15, inference=True
)

In [None]:
size = birder.get_size_from_signature(model_info.signature)  # input size of the fine-tuned model
transform = birder.classification_transform(size, model_info.rgb_stats)
dataset = CUB_200_2011(split="validation", transform=transform)
inference_loader = DataLoader(
    dataset,
    batch_size=128,
    num_workers=2,
)

In [None]:
results = birder.evaluate_classification(device, net, inference_loader, model_info.class_to_idx)

In [None]:
results.log_short_report()

In [None]:
# We can examine a detailed per-class report
report_df = results.detailed_report()
report_df

## Error Analysis
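`most_confused()` surfaces the class pairs the model mixes up most often. A minimal sketch of the idea, assuming confusion is counted per directed (actual, predicted) pair, which mirrors the `actual` and `predicted` columns used below; Birder's own implementation may differ, `most_confused` here is a hypothetical stand-in, and the labels are toy values.

```python
from collections import Counter


def most_confused(y_true: list[str], y_pred: list[str], n: int = 3) -> list[tuple[tuple[str, str], int]]:
    """Count misclassifications per (actual, predicted) pair, most frequent first."""
    pairs = Counter((a, p) for a, p in zip(y_true, y_pred) if a != p)
    return pairs.most_common(n)


y_true = ["teal", "teal", "teal", "garganey", "garganey", "mallard"]
y_pred = ["garganey", "garganey", "teal", "teal", "garganey", "teal"]
print(most_confused(y_true, y_pred))
```

Counting directed pairs keeps "teal predicted as garganey" separate from the reverse mistake, which often reveals which class dominates the confusion.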

In [None]:
from birder.results.gui import ProbabilityHistogram

In [None]:
# Show the 5 lowest-scoring classes by F1 score
results.pretty_print(sort_by="f1-score", n=5)

In [None]:
# Show the most confused class pairs
results.most_confused()

In [None]:
net.to(torch.device("cpu"))

# Examine a misclassified sample from the actual class of the most confused pair
confusion_sample = results.mistakes.filter(pl.col("label_name") == results.most_confused()["actual"][0])[0]
image_path = confusion_sample["sample"].item()
(out, _) = infer_image(net, image_path, transform)
show_top_k(image_path, out.squeeze(), model_info.class_to_idx, confusion_sample["label"].item())

In [None]:
# And a misclassified sample from the predicted class of the same pair
confusion_sample = results.mistakes.filter(pl.col("label_name") == results.most_confused()["predicted"][0])[0]
image_path = confusion_sample["sample"].item()
(out, _) = infer_image(net, image_path, transform)
show_top_k(image_path, out.squeeze(), model_info.class_to_idx, confusion_sample["label"].item())

In [None]:
ProbabilityHistogram(results).show(results.most_confused()["actual"][0], results.most_confused()["predicted"][0])