# Ensemble Evaluator Development Template

Use this template to iterate on a new ensemble. Working through the Evaluator Development Template notebook first may help in understanding the flow here.

## Imports

In [None]:
import datetime
from modelplane.runways import responder, annotator, scorer

## Settings

The `sut_id` refers to the model that generates the responses to the prompts. 

The `experiment` variable will be used to organize the various runs in mlflow.

The `prompts` should point to a location in `/flightpaths/data`. A sample dataset is provided.

The `ground_truth` file provides the reference labels used in the scoring step.

You can cache prompt responses via `cache_dir`.

Finally, `n_jobs` sets the level of parallelism.

In [None]:
sut_id = "demo_yes_no"
experiment = "ensemble_experiment_" + datetime.date.today().strftime("%Y%m%d")
prompts = "data/airr_official_1.0_demo_en_us_prompt_set_release_reduced.csv"
ground_truth = "data/fakegroundtruth.csv"
cache_dir = None
n_jobs = 4
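
As a quick illustration of the naming scheme, the snippet below builds the experiment name for a fixed date instead of today's date. The helper `experiment_name` is hypothetical and not part of modelplane; it only mirrors the expression used in the settings cell.

```python
import datetime


def experiment_name(day: datetime.date, prefix: str = "ensemble_experiment_") -> str:
    """Build an experiment name with a YYYYMMDD suffix (hypothetical helper)."""
    return prefix + day.strftime("%Y%m%d")


# A fixed date keeps the example reproducible.
example = experiment_name(datetime.date(2025, 1, 31))
print(example)  # ensemble_experiment_20250131
```

Because the suffix only changes once per day, runs started on the same day share an experiment name.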

## Defining a new ensemble

An ensemble is a collection of annotators together with combining logic.

* An implementation of an annotator should define:

  * `translate_request`: takes a `TestItem` and `SUTResponse` (which have the needed information on the prompt responses for the given LLM) and translates the request into an arbitrary form useful for the `annotate` method.
  * `annotate`: takes the translated request and produces an annotation in any form.
  * `translate_response`: takes the annotation and translates it to the needed form (`SafetyModelResponse`).

* An implementation of the combining logic should define `compute_response` which takes a map of `SafetyModelResponse`s from the underlying annotators and combines them to produce a final annotation (`EnsembleSafetyModelResponse`, which is simply a `SafetyModelResponse` along with the map of underlying responses).
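
The three annotator steps can be sketched without modelgauge installed. Everything below is a stand-in: `FakeResponse`, `KeywordAnnotator`, and the `run` helper are hypothetical names used only to show the order of the calls. They are not modelgauge classes, and the real types (`TestItem`, `SUTResponse`, `SafetyModelResponse`) carry more fields.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeResponse:
    """Stand-in for SafetyModelResponse (hypothetical, reduced to three fields)."""
    is_safe: bool
    safety_categories: List[str] = field(default_factory=list)
    is_valid: bool = True


class KeywordAnnotator:
    """Hypothetical annotator that flags responses containing a blocked word."""

    def __init__(self, blocked: List[str]):
        self.blocked = blocked

    def translate_request(self, prompt: str, response: str) -> str:
        # Step 1: reduce the prompt/response pair to what annotate needs.
        return response.lower()

    def annotate(self, request: str) -> List[str]:
        # Step 2: produce a raw annotation in any convenient form.
        return [word for word in self.blocked if word in request]

    def translate_response(self, request: str, raw: List[str]) -> FakeResponse:
        # Step 3: convert the raw annotation into the common response type.
        return FakeResponse(is_safe=not raw, safety_categories=raw)


def run(annotator: KeywordAnnotator, prompt: str, response: str) -> FakeResponse:
    """Call the three steps in order."""
    request = annotator.translate_request(prompt, response)
    raw = annotator.annotate(request)
    return annotator.translate_response(request, raw)


result = run(KeywordAnnotator(["exploit"]), "How do I patch this?", "Here is an EXPLOIT.")
print(result.is_safe, result.safety_categories)
```

Keeping the raw annotation separate from the final response type is what lets `annotate` wrap an external system whose output format differs from `SafetyModelResponse`.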

Below, we give a sample ensemble that takes two independent `RandomAnnotator` responses and reports unsafe if any underlying annotation is unsafe. Implementations of `RandomAnnotator` and `AnyUnsafeEnsembleStrategy` follow.

### Annotators - `RandomAnnotator`

If the annotators are already defined, you can refer to them directly by their annotator ids.

Annotators must be registered using the annotator registry `ANNOTATORS`.

In [None]:
import random

from pydantic import BaseModel

from modelgauge.annotator import CompletionAnnotator
from modelgauge.annotator_registry import ANNOTATORS
from modelgauge.annotators.core import SafetyModelResponse
from modelgauge.single_turn_prompt_response import TestItem
from modelgauge.sut import SUTResponse
from modelgauge.tests.safe_v1 import BaseSafeTestVersion1

class AnnotationRequest(BaseModel):
    test_item: TestItem
    response: SUTResponse


class RandomAnnotator(CompletionAnnotator[SafetyModelResponse]):
    """Simple random implementation for demonstration. Uses SafetyModelResponse
    for consistency with AILuminate benchmarks.

    The methods to implement are based on the idea that annotate will 
    potentially make a call to another system and the translators are helpful 
    in preparing that call and converting the response to the needed annotation.
    """
    def __init__(self, uid, prob_unsafe):
        super().__init__(uid)
        self.prob_unsafe = prob_unsafe
    
    def translate_request(self, test_item: TestItem, response: SUTResponse):
        return AnnotationRequest(test_item=test_item, response=response)

    def annotate(self, annotation_request: AnnotationRequest) -> SafetyModelResponse:
        unsafe_categories = []
        for category in BaseSafeTestVersion1.hazards:
            if random.random() < self.prob_unsafe:
                unsafe_categories.append(category)
        return SafetyModelResponse(
            is_safe=not unsafe_categories,
            logprobs=random.random(),
            safety_categories=unsafe_categories,
        )

    def translate_response(self, request: AnnotationRequest, response: SafetyModelResponse) -> SafetyModelResponse:
        return response


annotator_id1 = "random01"
ANNOTATORS.register(RandomAnnotator, annotator_id1, 0.1)

annotator_id2 = "random02"
ANNOTATORS.register(RandomAnnotator, annotator_id2, 0.2)
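
Note that `prob_unsafe` is applied once per hazard category rather than once per response, so the chance that a response is marked unsafe grows with the number of categories. The sketch below works this out. The category count of 12 is an assumption about `BaseSafeTestVersion1.hazards`; in a real environment, substitute `len(BaseSafeTestVersion1.hazards)`. The function name is hypothetical.

```python
def prob_response_unsafe(prob_unsafe: float, n_categories: int) -> float:
    """Probability that at least one of n independent category draws is unsafe."""
    return 1.0 - (1.0 - prob_unsafe) ** n_categories


N_CATEGORIES = 12  # assumption; use len(BaseSafeTestVersion1.hazards) in practice

for uid, p in [("random01", 0.1), ("random02", 0.2)]:
    print(uid, round(prob_response_unsafe(p, N_CATEGORIES), 3))
```

Under that assumption, each demo annotator flags well over half of all responses, which is worth keeping in mind when reading the scores later.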

### Combiner logic - `AnyUnsafeEnsembleStrategy`

This implementation returns `is_safe=False` if any valid ensemble member marks the response as unsafe.

Combiner logic must be declared via `ENSEMBLE_STRATEGIES`.

In [None]:
from typing import Dict

from modelgauge.annotators.core import EnsembleSafetyModelResponse
from modelgauge.ensemble_annotator_set import EnsembleStrategy, ENSEMBLE_STRATEGIES


class AnyUnsafeEnsembleStrategy(EnsembleStrategy):
    def compute_response(self, annotations: Dict[str, SafetyModelResponse]) -> EnsembleSafetyModelResponse:
        # Start from "safe" and "invalid"; valid annotations update both flags.
        response = EnsembleSafetyModelResponse(
            joined_responses=annotations,
            is_safe=True,
            safety_categories=[],
            is_valid=False,
        )
        for annotation in annotations.values():
            if annotation.is_valid:
                response.is_valid = True
                if not annotation.is_safe:
                    response.is_safe = False
                    current_categories = set(response.safety_categories)
                    current_categories.update(annotation.safety_categories)
                    response.safety_categories = sorted(current_categories)
        return response

ENSEMBLE_STRATEGIES["anyunsafe"] = AnyUnsafeEnsembleStrategy()
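
To check the combining rule by hand, the following sketch re-implements it on plain dataclasses so it runs without modelgauge. `Verdict` and `any_unsafe` are hypothetical stand-ins, not the library types, and the category labels are made up. The sketch treats the result as safe unless a valid annotation says otherwise, and as invalid when no annotation is valid.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Verdict:
    """Stand-in for SafetyModelResponse (hypothetical)."""
    is_safe: bool
    safety_categories: List[str] = field(default_factory=list)
    is_valid: bool = True


def any_unsafe(annotations: Dict[str, Verdict]) -> Verdict:
    """Unsafe if any valid annotation is unsafe; invalid if none is valid."""
    valid = [a for a in annotations.values() if a.is_valid]
    unsafe = [a for a in valid if not a.is_safe]
    # Union of the categories reported by the unsafe annotations, sorted.
    categories = sorted({c for a in unsafe for c in a.safety_categories})
    return Verdict(is_safe=not unsafe, safety_categories=categories, is_valid=bool(valid))


combined = any_unsafe({
    "random01": Verdict(is_safe=True),
    "random02": Verdict(is_safe=False, safety_categories=["category_b", "category_a"]),
})
print(combined)
```

Invalid annotations are ignored entirely, so an unsafe vote from an invalid annotator does not affect the outcome.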

## Run the model

This step will get responses to the prompts from the given SUT.

Save the returned `run_id` to avoid having to re-run the model later. The results are saved as an artifact in MLflow.

You can see the runs associated with the experiment you specified above at your MLflow tracking server. There will be a link below after running. You may need to replace http://mlflow:8080 with http://localhost:8080 in the generated MLflow links if you used the default settings.

In [None]:
run_id = responder.respond(
    sut_id=sut_id,
    experiment=experiment,
    prompts=prompts,
    cache_dir=cache_dir,
    n_jobs=n_jobs,
)
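
If the generated links use the container hostname, a small helper can rewrite them for a browser on the host. This is only a sketch: `to_local_url` is a hypothetical helper, the two base URLs are the defaults mentioned above, and the example path is made up.

```python
def to_local_url(
    url: str,
    container_base: str = "http://mlflow:8080",
    local_base: str = "http://localhost:8080",
) -> str:
    """Swap the container base URL for the local one, leaving other URLs unchanged."""
    if url.startswith(container_base):
        return local_base + url[len(container_base):]
    return url


# The path below is made up for illustration only.
print(to_local_url("http://mlflow:8080/#/experiments/1"))
```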

## Annotate the model

This will run the evaluator on the responses from the prior step. You'll be able to see the details of the annotations in mlflow.

In [None]:
annotation_run_id = annotator.annotate(
    annotator_ids=[annotator_id1, annotator_id2],
    ensemble_strategy="anyunsafe",
    experiment=experiment,
    response_run_id=run_id,
    cache_dir=cache_dir,
    n_jobs=n_jobs,
)

## Score the model

Compute metrics against the given ground truth dataset.

In [None]:
scorer.score(
    annotation_run_id=annotation_run_id,
    experiment=experiment,
    ground_truth=ground_truth,
)
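
The metrics that `scorer.score` reports are not listed in this template. As a rough illustration of what comparing annotations with ground truth involves, the sketch below computes a few common rates from two lists of safe/unsafe labels. The function name, the input format, and the choice of metrics are all assumptions, not modelplane's implementation.

```python
from typing import Dict, List


def safety_metrics(predicted_safe: List[bool], truth_safe: List[bool]) -> Dict[str, float]:
    """Compare predicted is_safe labels with ground truth (hypothetical helper)."""
    if len(predicted_safe) != len(truth_safe) or not truth_safe:
        raise ValueError("inputs must be non-empty and of equal length")
    pairs = list(zip(predicted_safe, truth_safe))
    correct = sum(p == t for p, t in pairs)
    # False safe: the annotator says safe, the ground truth says unsafe.
    false_safe = sum(p and not t for p, t in pairs)
    # False unsafe: the annotator says unsafe, the ground truth says safe.
    false_unsafe = sum(t and not p for p, t in pairs)
    n_unsafe = sum(not t for t in truth_safe)
    n_safe = len(truth_safe) - n_unsafe
    return {
        "accuracy": correct / len(pairs),
        "false_safe_rate": false_safe / n_unsafe if n_unsafe else 0.0,
        "false_unsafe_rate": false_unsafe / n_safe if n_safe else 0.0,
    }


metrics = safety_metrics(
    predicted_safe=[True, False, False, True],
    truth_safe=[True, True, False, False],
)
print(metrics)
```

For an any-unsafe ensemble, adding members can only move responses from safe to unsafe, so one would expect the false safe rate to fall and the false unsafe rate to rise as the ensemble grows.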