# In [None]:
"""
Hierarchical macro F1 metric for the CMI 2025 Challenge.

This script defines a single entry point `score(solution, submission, row_id_column_name)`
that the Kaggle metrics orchestrator will call.
It validates the required columns and the submitted gesture labels, then computes a combined binary & multiclass F1 score.
"""

import pandas as pd
from sklearn.metrics import f1_score


class ParticipantVisibleError(Exception):
    """Error type whose message is shown directly to the competitor.

    Raise this (rather than a generic exception) for problems with the
    submission itself, such as a missing column or an unknown gesture label.
    """


class CompetitionMetric:
    """Hierarchical macro F1 scorer for the CMI 2025 challenge.

    The final score is the unweighted mean of two F1 values computed on the
    ``gesture`` column: a binary F1 (target vs. non-target gesture) and a
    macro F1 in which every non-target gesture is collapsed into one
    ``"non_target"`` class.
    """

    def __init__(self):
        """Set up the lists of known gesture labels."""
        # Gestures that form the positive class of the binary F1 and keep
        # their own label in the macro F1.
        self.target_gestures = [
            "Above ear - pull hair",
            "Cheek - pinch skin",
            "Eyebrow - pull hair",
            "Eyelash - pull hair",
            "Forehead - pull hairline",
            "Forehead - scratch",
            "Neck - pinch skin",
            "Neck - scratch",
        ]
        # Gestures that are valid in a submission but are all scored as the
        # single "non_target" class.
        self.non_target_gestures = [
            "Write name on leg",
            "Wave hello",
            "Glasses on/off",
            "Text on phone",
            "Write name in air",
            "Feel around in tray and pull out an object",
            "Scratch knee/leg skin",
            "Pull air toward your face",
            "Drink from bottle/cup",
            "Pinch knee/leg skin",
        ]
        # Every label a submission is allowed to contain (targets first).
        self.all_classes = [*self.target_gestures, *self.non_target_gestures]

    def calculate_hierarchical_f1(self, sol: pd.DataFrame, sub: pd.DataFrame) -> float:
        """Return the mean of the binary F1 and the collapsed macro F1.

        Args:
            sol: Ground-truth frame; only its ``gesture`` column is read.
            sub: Submitted frame; only its ``gesture`` column is read.

        Returns:
            ``0.5 * binary_f1 + 0.5 * macro_f1``.

        Raises:
            ParticipantVisibleError: If ``sub["gesture"]`` contains a value
                that is not in ``self.all_classes``.
        """
        predicted = sub["gesture"]

        # Reject any submitted label outside the known gesture vocabulary.
        unknown_labels = set()
        for label in predicted.unique():
            if label not in self.all_classes:
                unknown_labels.add(label)
        if unknown_labels:
            raise ParticipantVisibleError(f"Invalid gesture values in submission: {unknown_labels}")

        expected = sol["gesture"]
        targets = self.target_gestures

        # Level 1: is the gesture a target gesture at all? (True = target)
        truth_is_target = expected.isin(targets).values
        guess_is_target = predicted.isin(targets).values
        binary_f1 = f1_score(
            truth_is_target,
            guess_is_target,
            average="binary",
            pos_label=True,
            zero_division=0,
        )

        def collapse(label):
            # Target gestures keep their name; everything else shares one class.
            if label in targets:
                return label
            return "non_target"

        # Level 2: macro-averaged F1 over the collapsed gesture labels.
        truth_collapsed = expected.apply(collapse)
        guess_collapsed = predicted.apply(collapse)
        macro_f1 = f1_score(truth_collapsed, guess_collapsed, average="macro", zero_division=0)

        return 0.5 * binary_f1 + 0.5 * macro_f1


def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str) -> float:
    """
    Compute hierarchical macro F1 for the CMI 2025 challenge.

    Expected input:
      - solution and submission as pandas.DataFrame
      - Column `row_id_column_name`: identifier of each row; it must be present
        in both frames but is not used here to align them
      - Column 'gesture': one of the eight target or ten non-target gesture
        names listed in `CompetitionMetric`

    Rows are compared by position, so both frames must already be in the same
    order and have the same number of rows.

    This metric averages:
    1. Binary F1 on gesture group (target vs non-target)
    2. Macro F1 on gesture (mapping every non-target gesture to "non_target")

    Raises ParticipantVisibleError for invalid submissions: a missing required
    column, a row count that differs from the solution, or a gesture value
    that is not a known gesture name.


    Examples
    --------
    >>> import pandas as pd
    >>> row_id_column_name = "id"
    >>> solution = pd.DataFrame({'id': range(4), 'gesture': ['Eyebrow - pull hair']*4})
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Forehead - pull hairline']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.5
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Text on phone']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.0
    >>> score(solution, solution, row_id_column_name=row_id_column_name)
    1.0
    """
    # Validate required columns
    for col in (row_id_column_name, "gesture"):
        if col not in solution.columns:
            raise ParticipantVisibleError(f"Solution file missing required column: '{col}'")
        if col not in submission.columns:
            raise ParticipantVisibleError(f"Submission file missing required column: '{col}'")

    # The F1 computation pairs rows positionally; without this check a size
    # mismatch would surface as an opaque scikit-learn ValueError instead of
    # a message the competitor can see.
    if len(submission) != len(solution):
        raise ParticipantVisibleError(
            f"Submission has {len(submission)} rows but {len(solution)} rows were expected"
        )

    metric = CompetitionMetric()
    return metric.calculate_hierarchical_f1(solution, submission)