# Quickstart: Cat Breed Classification

Learn how to classify cat breeds using our pre-trained TinyDiT classifier model.

## What You'll Learn

- Load the classifier model from HuggingFace Hub
- Preprocess images for inference
- Run classification and interpret results
- Visualize predictions

## Prerequisites

Install required packages:

In [None]:
!pip install onnxruntime huggingface_hub pillow numpy matplotlib requests -q

## Step 1: Load Model from HuggingFace

We'll download the quantized ONNX classifier from HuggingFace Hub.

In [None]:
from huggingface_hub import hf_hub_download
import onnxruntime as ort

# Download model from HuggingFace
print("Downloading classifier model...")
model_path = hf_hub_download(
    repo_id="d4oit/tiny-cats-model",
    filename="classifier/model.onnx"
)
print(f"Model downloaded to: {model_path}")

# Load ONNX model
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("Model loaded successfully!")
print(f"Input shape: {session.get_inputs()[0].shape}")
print(f"Output shape: {session.get_outputs()[0].shape}")
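
The printed input shape is worth checking before feeding the model. ONNX Runtime may report dynamic axes as symbolic names or `None` rather than integers, so a plain equality test against a tensor shape can fail even when the tensor is valid. The sketch below is one way to compare the two; `shape_matches` and the example shape are hypothetical and are not taken from the model itself.

```python
def shape_matches(model_shape, tensor_shape):
    """Return True if tensor_shape is compatible with model_shape.

    Non-integer entries in model_shape (symbolic names such as "batch",
    or None) are treated as dynamic dimensions that accept any size.
    """
    if len(model_shape) != len(tensor_shape):
        return False
    for expected, actual in zip(model_shape, tensor_shape):
        if isinstance(expected, int) and expected != actual:
            return False
    return True


# Hypothetical shape, as a model with a dynamic batch axis might report it
example_model_shape = ["batch", 3, 224, 224]

print(shape_matches(example_model_shape, (1, 3, 224, 224)))  # NCHW layout
print(shape_matches(example_model_shape, (1, 224, 224, 3)))  # NHWC layout
```

With a real session, the first argument would be `session.get_inputs()[0].shape` and the second the `.shape` of the preprocessed array.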

## Step 2: Define Breed Names

The model classifies images into 13 categories (12 cat breeds + other).

In [None]:
# Cat breed names (in order of class indices)
BREED_NAMES = [
    "Abyssinian",
    "Bengal",
    "Birman",
    "Bombay",
    "British Shorthair",
    "Egyptian Mau",
    "Maine Coon",
    "Persian",
    "Ragdoll",
    "Russian Blue",
    "Siamese",
    "Sphynx",
    "Other"  # Non-cat or unknown breed
]

print(f"Supported breeds: {len(BREED_NAMES)}")
for i, breed in enumerate(BREED_NAMES):
    print(f"  {i}: {breed}")
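
Because predictions are mapped to names purely by position, a mismatch between the label list and the model's output size would silently mislabel results. A minimal sketch of a reverse lookup and a length check follows; `BREED_TO_INDEX` and `check_labels` are hypothetical helper names, and the expected output size of 13 is taken from the category count stated above.

```python
BREED_NAMES = [
    "Abyssinian", "Bengal", "Birman", "Bombay", "British Shorthair",
    "Egyptian Mau", "Maine Coon", "Persian", "Ragdoll", "Russian Blue",
    "Siamese", "Sphynx", "Other",
]

# Reverse lookup: breed name -> class index
BREED_TO_INDEX = {name: i for i, name in enumerate(BREED_NAMES)}


def check_labels(num_model_outputs: int) -> None:
    """Raise if the label list does not line up with the model's output size."""
    if num_model_outputs != len(BREED_NAMES):
        raise ValueError(
            f"Model has {num_model_outputs} outputs but "
            f"{len(BREED_NAMES)} labels are defined"
        )


check_labels(13)  # Passes: 12 breeds + "Other"
print(BREED_TO_INDEX["Siamese"])
```

With a loaded session, the argument could come from the last dimension of `session.get_outputs()[0].shape`, provided that dimension is a fixed integer.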

## Step 3: Image Preprocessing

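Each image is resized to 224x224, scaled to [0, 1], normalized with ImageNet mean and standard deviation, and converted to a batched CHW tensor so that it matches the training data distribution.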

In [None]:
import numpy as np
from PIL import Image

def preprocess_image(image_path: str, size: int = 224) -> np.ndarray:
    """
    Preprocess an image for classification.
    
    Args:
        image_path: Path to the image file
        size: Target side length in pixels (default 224, giving 224x224)
    
    Returns:
        Preprocessed image as numpy array (1, 3, size, size)
    """
    # Load image
    image = Image.open(image_path).convert("RGB")
    
    # Resize
    image = image.resize((size, size), Image.Resampling.LANCZOS)
    
    # Convert to numpy and normalize to [0, 1]
    image_array = np.array(image).astype(np.float32) / 255.0
    
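    # Normalize with ImageNet statistics.
    # Keep everything float32: float64 constants would upcast the array,
    # and ONNX Runtime rejects float64 input for a float32 model.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    image_array = (image_array - mean) / std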
    
    # Convert to CHW format and add batch dimension
    image_array = np.transpose(image_array, (2, 0, 1))
    image_array = np.expand_dims(image_array, axis=0)
    
    return image_array

print("Preprocessing function defined!")
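
To make the normalization concrete, here is the same per-channel arithmetic applied to a single pixel in plain Python. It is only a worked example: `normalize_pixel` is a hypothetical helper, not part of the notebook's pipeline.

```python
# ImageNet statistics, same values as in preprocess_image
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, MEAN, STD))


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))

print("white:", [round(v, 3) for v in white])
print("black:", [round(v, 3) for v in black])
```

White and black pixels mark the extremes, so every value in the preprocessed tensor falls between roughly -2.12 and 2.64.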

## Step 4: Classification Function

Create a helper function to classify images.

In [None]:
def classify_image(session, image_path: str, top_k: int = 5) -> list:
    """
    Classify an image and return predictions.
    
    Args:
        session: ONNX runtime session
        image_path: Path to the image file
        top_k: Number of top predictions to return
    
    Returns:
        List of dicts with "breed", "confidence" (percent), and "index",
        sorted from most to least confident
    """
    # Preprocess
    input_tensor = preprocess_image(image_path)
    
    # Run inference (look up the input name from the session, not a global)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    probabilities = outputs[0][0]
    
    # Get top-k predictions
    top_indices = np.argsort(probabilities)[::-1][:top_k]
    
    predictions = [
        {
            "breed": BREED_NAMES[idx],
            "confidence": float(probabilities[idx] * 100),
            "index": int(idx)  # Plain int rather than numpy.int64
        }
        for idx in top_indices
    ]
    
    return predictions

print("Classification function defined!")
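
The function above treats the first model output as probabilities. Whether the exported model already applies softmax is not stated here, so it is worth checking that the values are non-negative and sum to about 1. If they turn out to be raw logits, a softmax step would be needed before reporting percentages. The following is a minimal pure-Python sketch of that step and of the top-k ranking; the logits and labels are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(logits)  # Subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probabilities, labels, k=3):
    """Return the k most likely (label, percent) pairs, best first."""
    ranked = sorted(
        range(len(probabilities)),
        key=lambda i: probabilities[i],
        reverse=True,
    )
    return [(labels[i], probabilities[i] * 100) for i in ranked[:k]]


# Made-up logits for three classes
example_logits = [2.0, 1.0, 0.1]
example_labels = ["Bengal", "Persian", "Other"]

probs = softmax(example_logits)
for label, percent in top_k(probs, example_labels, k=2):
    print(f"{label:10s} {percent:.1f}%")
```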

## Step 5: Download and Classify a Sample Image

Let's test with a sample cat image.

In [None]:
import requests
from io import BytesIO

# Download a sample cat image
print("Downloading sample image...")
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/4d/Cat_November_2010-1a.jpg/800px-Cat_November_2010-1a.jpg"
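# Wikimedia asks clients to send a descriptive User-Agent; this string is only a placeholder
response = requests.get(image_url, headers={"User-Agent": "tiny-cats-quickstart/1.0"}, timeout=30)
response.raise_for_status()  # Fail early on HTTP errors instead of passing an error page to PIL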
image = Image.open(BytesIO(response.content))

# Save locally
sample_path = "sample_cat.jpg"
image.save(sample_path)
print(f"Sample image saved to: {sample_path}")

# Display the image
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.title("Sample Cat Image")
plt.axis("off")
plt.tight_layout()
plt.show()

## Step 6: Run Classification

In [None]:
# Classify the sample image
print("Classifying image...")
predictions = classify_image(session, sample_path, top_k=5)

# Display results
print("\n" + "="*50)
print("CLASSIFICATION RESULTS")
print("="*50)
for i, pred in enumerate(predictions, 1):
    print(f"{i}. {pred['breed']:20s} - {pred['confidence']:.2f}%")
print("="*50)

## Step 7: Visualize Predictions

In [None]:
# Create visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Show image
ax1.imshow(image)
ax1.set_title("Input Image", fontsize=14, fontweight='bold')
ax1.axis("off")

# Show bar chart of predictions
breeds = [p['breed'] for p in predictions]
confidences = [p['confidence'] for p in predictions]
colors = ['#2ecc71' if i == 0 else '#95a5a6' for i in range(len(predictions))]

bars = ax2.barh(breeds, confidences, color=colors)
ax2.set_xlabel('Confidence (%)', fontsize=12)
ax2.set_title(f'Top {len(predictions)} Predictions', fontsize=14, fontweight='bold')
ax2.invert_yaxis()

# Add value labels
for bar, conf in zip(bars, confidences):
    ax2.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height()/2,
             f'{conf:.1f}%', va='center', fontsize=10)

plt.tight_layout()
plt.show()

# Print best prediction
best = predictions[0]
print(f"\n🐱 Predicted Breed: {best['breed']}")
print(f"📊 Confidence: {best['confidence']:.2f}%")

## Step 8: Try Your Own Images

Upload your own cat photos and classify them. The cell below uses `google.colab.files`, so it runs only in Google Colab.

In [None]:
from google.colab import files  # Available only in Google Colab

# Upload your images (Colab only)
print("Upload your cat images:")
uploaded = files.upload()

# Classify each uploaded image
for filename in uploaded.keys():
    if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
        print(f"\n{'='*50}")
        print(f"Classifying: {filename}")
        print('='*50)
        
        predictions = classify_image(session, filename, top_k=3)
        
        for i, pred in enumerate(predictions, 1):
            print(f"{i}. {pred['breed']:20s} - {pred['confidence']:.2f}%")
        
        # Display image
        img = Image.open(filename)
        plt.figure(figsize=(6, 6))
        plt.imshow(img)
        plt.title(f"Prediction: {predictions[0]['breed']} ({predictions[0]['confidence']:.1f}%)")
        plt.axis("off")
        plt.tight_layout()
        plt.show()
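
Outside Colab there is no upload widget, so one option is to point the classifier at a local folder instead. The sketch below only collects the image paths, using the same extensions as the upload cell; `find_images` is a hypothetical helper, and `"."` stands in for whatever folder holds your photos.

```python
from pathlib import Path

# Same extensions the upload cell accepts
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}


def find_images(folder):
    """Return the image files directly inside folder, sorted by path."""
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


# List candidate images in the current directory
for path in find_images("."):
    print(path.name)
```

Each returned path can then be passed to the classifier as `classify_image(session, str(path), top_k=3)`.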

## Common Issues & Solutions

### Issue 1: "File not found"
**Solution:** Make sure the image path is correct and the file exists.

### Issue 2: "Invalid image size"
**Solution:** The preprocessing function automatically resizes to 224x224.

### Issue 3: "Low confidence predictions"
**Solution:** 
- Image may not contain a cat
- Cat may be occluded or in poor lighting
- Breed may not be in the 12 supported breeds
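
One way to act on these cases is to flag weak results instead of reporting the top breed unconditionally. The sketch below assumes the prediction format returned by `classify_image`; `summarize` is a hypothetical helper, and the 50% threshold is an arbitrary choice that has not been tuned for this model.

```python
def summarize(predictions, threshold=50.0):
    """Turn a ranked prediction list into a short, cautious message."""
    best = predictions[0]
    if best["breed"] == "Other":
        return "No supported breed detected"
    if best["confidence"] < threshold:
        return (
            f"Uncertain (best guess: {best['breed']}, "
            f"{best['confidence']:.1f}%)"
        )
    return f"{best['breed']} ({best['confidence']:.1f}%)"


# Made-up predictions in the same format classify_image returns
example = [
    {"breed": "Ragdoll", "confidence": 38.2, "index": 8},
    {"breed": "Birman", "confidence": 35.9, "index": 2},
]
print(summarize(example))
```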

### Issue 4: "Model not loading"
**Solution:** Check your internet connection. The model file is downloaded from HuggingFace Hub the first time the notebook runs and is cached locally afterwards.

## Summary

✅ You've learned how to:
- Load a pre-trained classifier from HuggingFace
- Preprocess images for inference
- Run classification and interpret results
- Visualize predictions with matplotlib

## Next Steps

- Try [Notebook 02: Conditional Generation](02_conditional_generation.ipynb) to generate cat images
- Try [Notebook 03: Training & Fine-Tuning](03_training_fine_tuning.ipynb) to train your own model
- Read the [model card](https://huggingface.co/d4oit/tiny-cats-model) on HuggingFace
- Check out [ADR-008](../plans/ADR-008-adapt-tiny-models-architecture-for-cats-classifier-with-web-frontend.md) for architecture details

## References

- Model Repository: https://huggingface.co/d4oit/tiny-cats-model
- ONNX Runtime: https://onnxruntime.ai/
- HuggingFace Hub: https://huggingface.co/docs/huggingface_hub
- Oxford IIIT Pet Dataset: https://www.robots.ox.ac.uk/~vgg/data/pets/