# Mosquito Segmentation Tutorial

This notebook demonstrates how to use the CulicidaeLab library for segmenting mosquitoes in images using the Segment Anything Model (SAM).

In [None]:
import cv2
import numpy as np
from pathlib import Path
from culicidaelab.segmentation import MosquitoSegmenter

# For visualization
import matplotlib.pyplot as plt

%matplotlib inline

## Initialize the Segmenter

First, we'll create a `MosquitoSegmenter` instance. You'll need to provide the path to your SAM model checkpoint, and `model_type` (`vit_h`, `vit_l`, or `vit_b`) must match the architecture of that checkpoint.

In [None]:
# SAM backbone and its weights. The checkpoint file should correspond to the
# chosen architecture ("vit_h", "vit_l" or "vit_b").
SAM_MODEL_TYPE = "vit_h"  # or 'vit_l', 'vit_b'
SAM_CHECKPOINT_PATH = "path/to/your/sam_checkpoint.pth"

segmenter = MosquitoSegmenter(model_type=SAM_MODEL_TYPE, checkpoint_path=SAM_CHECKPOINT_PATH)

## Load and Process an Image

Let's load an example image and segment a mosquito in it by giving the segmenter a foreground point as a prompt.

In [None]:
# Load an example image.
image_path = "path/to/your/image.jpg"
image_bgr = cv2.imread(image_path)
if image_bgr is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising; without this check the failure surfaces as an obscure cvtColor error.
    raise FileNotFoundError(f"Could not read image: {image_path!r}")
# OpenCV decodes images in BGR channel order; convert to RGB for SAM and matplotlib.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

# Define prompt points for segmentation.
# You can get these from the detector or specify them manually.
# Coordinates are in pixels, presumably (x, y) as in SAM's predictor — TODO confirm.
points = np.array([[100, 100]])  # Example point
labels = np.array([1])  # 1 for foreground

# Perform segmentation
masks = segmenter.segment(image, points=points, point_labels=labels)


# Visualize results
def plot_segmentation(image, masks, color=(1.0, 0.0, 0.0), alpha=0.5, title="Mosquito segmentation"):
    """Display ``image`` with every segmentation mask drawn as a translucent overlay.

    Parameters
    ----------
    image : np.ndarray
        RGB image to draw, e.g. as produced by the loading cells in this notebook.
    masks : np.ndarray or iterable of np.ndarray
        Either a single 2-D mask or a collection of 2-D masks with the same
        height/width as ``image``. Non-zero / True pixels are foreground.
    color : tuple of float, default (1.0, 0.0, 0.0)
        RGB overlay colour with components in [0, 1] (red by default).
    alpha : float, default 0.5
        Opacity of the overlay, applied to masked pixels only.
    title : str, default "Mosquito segmentation"
        Figure title.
    """
    if isinstance(masks, np.ndarray) and masks.ndim == 2:
        # A single mask: wrap it so the loop iterates over masks, not image rows.
        masks = [masks]

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(image)

    for mask in masks:
        # Accept boolean, 0/1 or 0/255 masks; a non-boolean array would otherwise
        # be used as integer fancy indices rather than as a selection mask.
        mask = np.asarray(mask, dtype=bool)
        # RGBA overlay that is fully transparent outside the mask, so unmasked
        # pixels keep their original brightness (a zero-filled RGB overlay with a
        # global alpha darkens the whole image once per mask).
        overlay = np.zeros(mask.shape + (4,), dtype=float)
        overlay[mask] = (*color, alpha)
        ax.imshow(overlay)

    ax.set_title(title)
    ax.axis("off")
    plt.show()


# Overlay the prompt-based masks on the RGB image.
plot_segmentation(image=image, masks=masks)

## Combine Detection and Segmentation

Instead of choosing prompt points by hand, you can use the detector's bounding boxes to generate them, which makes segmentation fully automatic:

In [None]:
from culicidaelab.detection import MosquitoDetector


def detect_and_segment(image_path, detector=None, mosquito_segmenter=None):
    """Detect mosquitoes in an image, then segment them using box-centre point prompts.

    The centre of every detected bounding box becomes a foreground point prompt
    for the segmenter.

    Parameters
    ----------
    image_path : str or Path
        Path of the image file to process.
    detector : MosquitoDetector, optional
        Detector to reuse across calls. If omitted, a new detector is built from
        the placeholder YOLOv8 weights path below, which reloads the model on
        every call.
    mosquito_segmenter : MosquitoSegmenter, optional
        Segmenter to use. Defaults to the notebook-level ``segmenter`` created
        in the initialisation cell.

    Returns
    -------
    image : np.ndarray
        The loaded image, converted from BGR to RGB.
    masks
        The value returned by ``segment`` for the prompts, or an empty
        ``(0, H, W)`` boolean array when nothing was detected.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read ``image_path``.
    """
    if mosquito_segmenter is None:
        mosquito_segmenter = segmenter  # notebook-global from the init cell

    # Load image; cv2.imread returns None instead of raising on failure.
    image_bgr = cv2.imread(str(image_path))
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path!r}")
    image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # Detect mosquitoes.
    if detector is None:
        detector = MosquitoDetector("path/to/your/yolov8_model.pt")
    detections = detector.detect(image)

    # Use the centre of each detection, assumed to be (x1, y1, x2, y2, conf),
    # as a foreground point prompt.
    centers = [[(x1 + x2) / 2, (y1 + y2) / 2] for x1, y1, x2, y2, _conf in detections]

    if not centers:
        # Nothing to prompt with: return an empty mask stack instead of passing
        # empty point arrays to the segmenter.
        return image, np.zeros((0, *image.shape[:2]), dtype=bool)

    points = np.array(centers)
    labels = np.ones(len(centers), dtype=int)  # 1 = foreground

    # NOTE(review): all points go into a single segment() call. If they are
    # forwarded to SAM as one prompt, SAM treats them as a single object; with
    # several mosquitoes, per-detection calls or box prompts may be needed — confirm.
    masks = mosquito_segmenter.segment(image, points=points, point_labels=labels)

    return image, masks


# Example usage (set the model and image paths above first):
# image, masks = detect_and_segment('path/to/your/image.jpg')
# plot_segmentation(image, masks)