# Tutorial 04: Multiclass Classification

🟡 **Intermediate** - Familiarity with ML concepts is helpful

Learn how to train a multiclass classifier using the softmax objective.
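
For intuition, softmax turns one raw score per class into probabilities that sum to 1. The sketch below is a plain NumPy illustration with made-up scores. It is not the boosters implementation, which this tutorial does not show, and the `softmax` helper and `raw_scores` array are hypothetical.

```python
import numpy as np


def softmax(scores):
    """Convert raw per-class scores (one row per sample) into probabilities."""
    # Subtracting the row maximum avoids overflow and does not change the result.
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


# Hypothetical raw scores for 2 samples and 3 classes
raw_scores = np.array([
    [2.0, 1.0, 0.1],
    [0.5, 0.5, 3.0],
])

probs = softmax(raw_scores)
print(probs)
print(probs.sum(axis=1))  # each row sums to 1
```

The class with the largest raw score always receives the largest probability, so taking the argmax of either gives the same predicted class.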

## What you'll learn

1. Train a multiclass classifier
2. Interpret multiclass probabilities
3. Visualize confusion matrices
4. Analyze per-class metrics

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

from boosters.sklearn import GBDTClassifier

## Load Data

We'll use the classic Iris dataset: 150 samples, 4 features, and 3 balanced classes (50 samples each).

In [None]:
# Load Iris dataset
iris = load_iris()
X, y = iris.data, iris.target

# Stratify so each class keeps the same proportion in the train and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Classes: {iris.target_names}")
print(f"Features: {iris.feature_names}")
print(f"Class distribution: {np.bincount(y_train)}")

## Train Multiclass Classifier

In [None]:
# Train multiclass classifier
clf = GBDTClassifier(
    n_estimators=100,
    max_depth=4,
    learning_rate=0.1,
)
clf.fit(X_train, y_train)

print(f"Accuracy: {clf.score(X_test, y_test):.4f}")

## Multiclass Probabilities

In [None]:
# Get probability predictions
y_proba = clf.predict_proba(X_test)

print(f"Probability shape: {y_proba.shape}")
print("\nSample probabilities (first 5 samples):")
for i in range(5):
    print(f"  Sample {i}: {y_proba[i]} -> predicted: {clf.classes_[np.argmax(y_proba[i])]}")
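
One way to read these outputs is to look at how far apart the two most likely classes are. The sketch below uses a small hand-written array, `example_proba`, as a stand-in for `y_proba` so that it runs on its own. The values and the 0.2 margin threshold are illustrative assumptions, not recommendations.

```python
import numpy as np

# Hypothetical probabilities for 4 samples and 3 classes (stand-in for y_proba)
example_proba = np.array([
    [0.97, 0.02, 0.01],
    [0.05, 0.55, 0.40],
    [0.01, 0.09, 0.90],
    [0.34, 0.33, 0.33],
])

# Index of the most likely class and the probability assigned to it
predicted = example_proba.argmax(axis=1)
confidence = example_proba.max(axis=1)

# Margin between the two most likely classes; small margins mean a close call
sorted_proba = np.sort(example_proba, axis=1)
margin = sorted_proba[:, -1] - sorted_proba[:, -2]

# Flag samples whose margin falls below an arbitrary threshold
uncertain = np.where(margin < 0.2)[0]

print(f"Predicted class indices: {predicted}")
print(f"Confidence: {confidence}")
print(f"Uncertain samples: {uncertain}")
```

The same steps should apply to `y_proba` directly, since it has one row per sample and one column per class.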

## Confusion Matrix

In [None]:
# Get predictions
y_pred = clf.predict(X_test)

# Plot confusion matrix
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)

fig, ax = plt.subplots(figsize=(8, 6))
disp.plot(ax=ax, cmap='Blues')
ax.set_title('Confusion Matrix')
plt.show()
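
Raw counts can be hard to compare when classes have different sizes, so it often helps to normalize each row by the number of true samples in that class. The sketch below does this with NumPy on a hypothetical matrix, `example_cm`, whose counts are made up for illustration.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows: true class, columns: predicted class)
example_cm = np.array([
    [10, 0, 0],
    [0, 8, 2],
    [0, 1, 9],
])

# Divide each row by its total so every row sums to 1
row_totals = example_cm.sum(axis=1, keepdims=True)
cm_normalized = example_cm / row_totals

# The diagonal of the row-normalized matrix is the recall of each class
per_class_recall = np.diag(cm_normalized)

print(cm_normalized)
print(f"Per-class recall: {per_class_recall}")
```

scikit-learn can produce the same row-normalized matrix with `confusion_matrix(y_test, y_pred, normalize="true")`.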

## Classification Report

In [None]:
print(classification_report(y_test, y_pred, target_names=iris.target_names))
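
The numbers in the report can be derived from the confusion matrix. The sketch below recomputes precision, recall, F1, and support with NumPy from a hypothetical matrix, `example_cm`, whose counts are made up for illustration. It assumes every class is predicted at least once. Otherwise the precision for that class would involve a division by zero.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (rows: true class, columns: predicted class)
example_cm = np.array([
    [10, 0, 0],
    [0, 8, 2],
    [0, 1, 9],
])

true_positives = np.diag(example_cm)
support = example_cm.sum(axis=1)           # true samples per class
predicted_counts = example_cm.sum(axis=0)  # predictions per class

precision = true_positives / predicted_counts
recall = true_positives / support
f1 = 2 * precision * recall / (precision + recall)

# Macro average: unweighted mean over classes
macro_f1 = f1.mean()
accuracy = true_positives.sum() / example_cm.sum()

for k in range(len(support)):
    print(f"class {k}: precision={precision[k]:.2f} recall={recall[k]:.2f} "
          f"f1={f1[k]:.2f} support={support[k]}")
print(f"accuracy={accuracy:.2f} macro_f1={macro_f1:.2f}")
```

Precision answers "of the samples predicted as this class, how many were right?", while recall answers "of the samples that truly belong to this class, how many were found?".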

## Summary

In this tutorial, you learned how to:

1. ✅ Train a multiclass classifier with boosters
2. ✅ Interpret multiclass probability outputs
3. ✅ Visualize confusion matrices
4. ✅ Analyze per-class metrics

## Next Steps

- [Tutorial 05: Early Stopping](05-early-stopping.ipynb) - Prevent overfitting
- [Tutorial 07: Hyperparameter Tuning](07-hyperparameter-tuning.ipynb) - Optimize performance