# Mosquito Species Classification Tutorial

This notebook demonstrates how to use the CulicidaeLab library for classifying mosquito species using FastAI and timm models.

In [None]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from culicidaelab.classification import MosquitoClassifier

%matplotlib inline

## Initialize the Classifier

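First, create a `MosquitoClassifier` instance. You can either load a pre-trained model or train a new one (see the next section). Replace the placeholder paths below with paths to your own files before running the cell.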

In [None]:
# Create the classifier. `arch` selects the timm backbone; `model_path` and
# `config_path` are both optional, so replace the placeholders with real files
# or drop them.
classifier = MosquitoClassifier(
    arch="convnext_base",  # any other timm architecture name can go here
    model_path="path/to/your/model.pth",  # optional: pre-trained weights
    config_path="path/to/species_config.yaml",  # optional: species config file
)

## Training a New Model

If you don't have a pre-trained model, you can train one using your dataset:

In [None]:
# Train the model
data_path = "path/to/your/training/data"
# Data directory structure should be:
# data_path/
#   ├── species1/
#   │   ├── image1.jpg
#   │   └── image2.jpg
#   └── species2/
#       ├── image3.jpg
#       └── image4.jpg

metrics = classifier.train(data_path=data_path, epochs=10, batch_size=16, learning_rate=1e-4)

# Plot the loss curves. Each list is assumed to hold one value per epoch
# (TODO confirm against MosquitoClassifier.train). Epochs are numbered from 1
# so the x-axis matches `epochs=10` instead of starting at 0.
train_loss = metrics["train_loss"]
val_loss = metrics["val_loss"]

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_loss) + 1), train_loss, label="Train Loss")
ax.plot(range(1, len(val_loss) + 1), val_loss, label="Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.legend()
plt.show()

## Classify Images

Now let's use the trained model to classify some mosquito images:

In [None]:
def classify_and_plot(image_path, model=None):
    """Classify one image and show it next to a bar chart of its predictions.

    Args:
        image_path: Path to the image file. It is passed unchanged to both
            ``model.classify`` and ``plt.imread``.
        model: Object with a ``classify(image_path)`` method returning a
            sequence of ``(species, score)`` pairs. Defaults to the
            notebook-level ``classifier``, so existing calls keep working.
    """
    if model is None:
        model = classifier  # fall back to the instance created earlier in the notebook
    predictions = model.classify(image_path)

    species = [p[0] for p in predictions]
    scores = [p[1] for p in predictions]

    image = plt.imread(image_path)
    fig, (ax_image, ax_scores) = plt.subplots(1, 2, figsize=(10, 5))

    # Left panel: the input image.
    ax_image.imshow(image)
    ax_image.axis("off")
    ax_image.set_title("Input Image")

    # Right panel: one horizontal bar per predicted species.
    positions = range(len(species))
    ax_scores.barh(positions, scores)
    ax_scores.set_yticks(positions)
    ax_scores.set_yticklabels(species)
    # barh draws index 0 at the bottom. Invert the axis so the first prediction
    # (presumably the most likely one) is shown at the top.
    ax_scores.invert_yaxis()
    ax_scores.set_xlabel("Confidence")
    ax_scores.set_title("Species Predictions")

    fig.tight_layout()
    plt.show()


# Example usage:
# classify_and_plot('path/to/your/mosquito_image.jpg')

## Batch Processing

Classify every image in a directory and print the most likely species for each:

In [None]:
def process_directory(image_dir, model=None, extensions=(".jpg",)):
    """Classify every image in a directory and print the top prediction for each.

    Args:
        image_dir: Directory to scan. Subdirectories are not searched.
        model: Object with a ``classify(path: str)`` method returning
            ``(species, score)`` pairs with the most likely first. Defaults to
            the notebook-level ``classifier``.
        extensions: File suffixes to include, with or without the leading dot.
            Matching is case-insensitive, so files such as ``IMG_001.JPG`` are
            not silently skipped on case-sensitive file systems.

    Returns:
        dict mapping each file name to its top ``(species, score)`` prediction,
        in sorted file-name order.

    Raises:
        FileNotFoundError: if ``image_dir`` does not exist. A mistyped path is
            reported instead of yielding an empty result.
    """
    if model is None:
        model = classifier  # fall back to the instance created earlier in the notebook
    image_dir = Path(image_dir)
    wanted = {"." + ext.lstrip(".").lower() for ext in extensions}

    # Sort so the printed report and the result order are reproducible.
    image_paths = sorted(
        p for p in image_dir.iterdir() if p.is_file() and p.suffix.lower() in wanted
    )

    results = {}
    for img_path in image_paths:
        predictions = model.classify(str(img_path))
        top_prediction = predictions[0]  # Get the most likely species
        results[img_path.name] = top_prediction

        print(f"{img_path.name}: {top_prediction[0]} ({top_prediction[1]:.2%} confidence)")

    return results


# Example usage:
# results = process_directory('path/to/image/directory')

## Model Evaluation

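Evaluate the model's performance on a labeled test dataset. The test set must use the same layout as the training data: one subdirectory per species, named after that species.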

In [None]:
def evaluate_model(test_data_path, model=None, extensions=(".jpg",)):
    """Classify a labeled test set and print the model's evaluation metrics.

    ``test_data_path`` must contain one subdirectory per species. The
    directory name is used as the true label. Hidden directories, such as
    Jupyter's ``.ipynb_checkpoints``, are skipped so they are not mistaken for
    a species.

    Args:
        test_data_path: Root directory of the test set.
        model: Object with ``classify(path: str)`` returning
            ``(species, score)`` pairs with the most likely first, and
            ``evaluate(y_true, y_pred)`` returning a mapping of metric name to
            value. Defaults to the notebook-level ``classifier``.
        extensions: Image suffixes to include, with or without the leading
            dot. Matching is case-insensitive.

    Returns:
        The mapping returned by ``model.evaluate``.

    Raises:
        ValueError: if no matching images are found.
    """
    import numbers

    if model is None:
        model = classifier  # fall back to the instance created earlier in the notebook
    test_dir = Path(test_data_path)
    wanted = {"." + ext.lstrip(".").lower() for ext in extensions}
    y_true = []
    y_pred = []

    # Sorted traversal keeps the label/prediction order reproducible across runs.
    species_dirs = sorted(
        d for d in test_dir.iterdir() if d.is_dir() and not d.name.startswith(".")
    )
    for species_dir in species_dirs:
        image_paths = sorted(
            p for p in species_dir.iterdir() if p.is_file() and p.suffix.lower() in wanted
        )
        for img_path in image_paths:
            predictions = model.classify(str(img_path))
            y_true.append(species_dir.name)
            y_pred.append(predictions[0][0])

    if not y_true:
        raise ValueError(
            f"No images with suffixes {sorted(wanted)} found in the species "
            f"subdirectories of {test_dir}"
        )

    # Calculate metrics
    metrics = model.evaluate(y_true, y_pred)

    print("\nModel Evaluation Metrics:")
    for metric, value in metrics.items():
        # Scalars get fixed precision. Anything else (e.g. a per-class report)
        # is printed as-is rather than crashing on the float format spec.
        if isinstance(value, numbers.Real):
            print(f"{metric}: {value:.4f}")
        else:
            print(f"{metric}: {value}")

    return metrics


# Example usage:
# metrics = evaluate_model('path/to/test/data')