In [None]:
from IPython import get_ipython
from IPython.display import display

from typing import Generator, Iterable, List, TypeVar

import numpy as np
import supervision as sv
import torch
import umap
from sklearn.cluster import KMeans
from tqdm import tqdm
from transformers import AutoProcessor, SiglipVisionModel

# Generic type variable standing in for the element type in annotations
V = TypeVar("V")

# Hugging Face Hub identifier of the pretrained SigLIP checkpoint
SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"

In [None]:
def create_batches(
    sequence: Iterable[V], batch_size: int
) -> Generator[List[V], None, None]:
    """
    Split a sequence into consecutive batches of at most `batch_size` elements.

    A batch is yielded as soon as it is full, so a lazy or blocking source is
    never read past the end of the batch currently being emitted.

    Args:
        sequence (Iterable[V]): The input sequence to be batched.
        batch_size (int): The size of each batch. Values below 1 are treated
            as 1.

    Yields:
        List[V]: The next batch, in input order. Every batch holds exactly
            `batch_size` elements except possibly the last one, which holds
            the remainder. Nothing is yielded for an empty input.
    """
    # Ensure batch size is at least 1
    batch_size = max(batch_size, 1)
    current_batch: List[V] = []

    for element in sequence:
        current_batch.append(element)
        # Emit right after the element that fills the batch instead of waiting
        # for the next element to arrive from the source
        if len(current_batch) == batch_size:
            yield current_batch
            current_batch = []

    # Yield any remaining elements as the last (shorter) batch
    if current_batch:
        yield current_batch

In [None]:
class TeamClassifier:
    """
    A classifier that uses a pre-trained SiglipVisionModel for feature extraction,
    UMAP for dimensionality reduction, and KMeans for clustering.

    `fit` must be called on a non-empty list of crops before `predict`, because
    `predict` relies on the UMAP reducer and KMeans model fitted there.
    """
    def __init__(self, device: str = 'cpu', batch_size: int = 32):
        """
        Initialize the TeamClassifier with device and batch size.

        Args:
            device (str): The device to run the model on ('cpu' or 'cuda').
            batch_size (int): The batch size for processing images.
        """
        self.device = device
        self.batch_size = batch_size

        # Load the pre-trained Siglip vision model and move it to the specified device
        self.features_model = SiglipVisionModel.from_pretrained(
            SIGLIP_MODEL_PATH).to(device)

        # Initialize the processor for image preprocessing
        self.processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)

        # Initialize UMAP for reducing feature dimensions to 3 components
        self.reducer = umap.UMAP(n_components=3)

        # Initialize KMeans clustering model with 2 clusters
        self.cluster_model = KMeans(n_clusters=2)

    def extract_features(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Extract features from a list of image crops using the pre-trained
        SiglipVisionModel.

        Each crop's embedding is the mean of the model's last hidden state
        over dimension 1.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Returns:
            np.ndarray: Extracted features as a numpy array with one row per
                crop. For an empty input, an empty array with zero rows and
                the model's hidden size as the second dimension is returned.
        """
        # np.concatenate cannot handle an empty list of batches, so short-circuit
        # with an empty feature matrix of the matching width
        if len(crops) == 0:
            hidden_size = self.features_model.config.hidden_size
            return np.empty((0, hidden_size), dtype=np.float32)

        # Convert OpenCV images to PIL images for processing
        crops = [sv.cv2_to_pillow(crop) for crop in crops]

        # The batches come from a generator, so tqdm cannot infer how many there
        # are; compute the count up front (mirroring create_batches' clamp of
        # the batch size to at least 1) to get a meaningful progress bar
        effective_batch_size = max(self.batch_size, 1)
        batch_count = (len(crops) + effective_batch_size - 1) // effective_batch_size

        # Create batches of images for efficient processing
        batches = create_batches(crops, self.batch_size)
        data = []

        # Disable gradient calculation for inference
        with torch.no_grad():
            # Iterate over batches with progress bar
            for batch in tqdm(
                batches, desc='Embedding extraction', total=batch_count
            ):
                # Process images and prepare tensors on the device
                inputs = self.processor(
                    images=batch, return_tensors="pt").to(self.device)

                # Extract features from the model
                outputs = self.features_model(**inputs)

                # Compute mean pooling of the last hidden state as embeddings
                embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()

                # Collect embeddings from all batches
                data.append(embeddings)

        # Concatenate all batch embeddings into one array
        return np.concatenate(data)

    def fit(self, crops: List[np.ndarray]) -> None:
        """
        Fit the classifier model on a list of image crops.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Raises:
            ValueError: If `crops` is empty, since neither the reducer nor the
                clustering model can be fitted without samples.
        """
        # Fail early with a clear message instead of an opaque downstream error
        if len(crops) == 0:
            raise ValueError(
                "Cannot fit TeamClassifier on an empty list of crops.")

        # Extract feature embeddings from crops
        data = self.extract_features(crops)

        # Apply UMAP dimensionality reduction
        projections = self.reducer.fit_transform(data)

        # Fit KMeans clustering on reduced projections
        self.cluster_model.fit(projections)

    def predict(self, crops: List[np.ndarray]) -> np.ndarray:
        """
        Predict the cluster labels for a list of image crops.

        Args:
            crops (List[np.ndarray]): List of image crops.

        Returns:
            np.ndarray: Predicted cluster labels, one per crop. For an empty
                input, an empty integer array is returned.
        """
        # Return an empty integer array for empty input, so the result keeps an
        # integer label dtype when combined with non-empty predictions
        if len(crops) == 0:
            return np.array([], dtype=int)

        # Extract feature embeddings from crops
        data = self.extract_features(crops)

        # Project embeddings using fitted UMAP reducer
        projections = self.reducer.transform(data)

        # Predict cluster labels using trained KMeans model
        return self.cluster_model.predict(projections)