# Custom Domains

This library is designed to be extensible to new domains. To evaluate a model on a custom domain, you need to:

1. Define your dataset
2. Implement an evaluator
3. Implement a model adaptor
4. Run the evaluation

The following example walks through these steps for a pattern matching domain, in which the model must generate a string that fully matches a given regular expression.


## 1. Define your dataset

A dataset is an iterable of instances that conform to a schema. The schema is a class that inherits from `Instance` and declares the fields of each instance.

```python
from genlm.eval import Instance


class PatternMatchingInstance(Instance):
    """Schema for a pattern matching instance."""

    pattern: str
    instance_id: int

    def __repr__(self):
        return f"pattern: {self.pattern} (id: {self.instance_id})"
```

Given a dataset schema, you can define a dataset by subclassing `Dataset` and implementing an `__iter__` method that yields instances of the schema. The example below also implements a `schema` property that returns the schema class.


```python
from genlm.eval import Dataset


class PatternMatchingDataset(Dataset[PatternMatchingInstance]):
    """Dataset for pattern matching evaluation."""

    def __init__(self, patterns):
        self.patterns = patterns

    def __iter__(self):
        """Iterate over regex patterns.

        Returns:
            (Iterator[PatternMatchingInstance]): Iterator over regex instances.
        """
        for pattern_id, pattern in enumerate(self.patterns):
            yield PatternMatchingInstance(pattern=pattern, instance_id=pattern_id)

    @property
    def schema(self):
        """Get the schema class for this dataset."""
        return PatternMatchingInstance
```
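Because `__iter__` is written as a generator method, each `for` loop over the dataset gets a fresh iterator, so the same dataset object can be traversed more than once. The sketch below illustrates this with a plain stand-in class, `ToyPatternDataset` (hypothetical, not part of `genlm.eval`), so that it runs without the library; it yields `(instance_id, pattern)` tuples instead of `PatternMatchingInstance` objects.

```python
class ToyPatternDataset:
    """Hypothetical stand-in for PatternMatchingDataset (no genlm dependency)."""

    def __init__(self, patterns):
        self.patterns = patterns

    def __iter__(self):
        # A generator method: every call to iter() starts a new pass.
        for pattern_id, pattern in enumerate(self.patterns):
            yield pattern_id, pattern


toy_dataset = ToyPatternDataset([r"xy|xz", r"ab|c(e|f)"])
first_pass = list(toy_dataset)
second_pass = list(toy_dataset)
print(first_pass == second_pass)  # True: the iterable restarts on each loop

# By contrast, a single generator object is an iterator and is exhausted after one pass.
gen = iter(toy_dataset)
print(len(list(gen)), len(list(gen)))  # 2 0
```

This is the practical difference between an iterable and an iterator: returning a new generator from `__iter__` keeps the dataset reusable.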


## 2. Implement an evaluator

An evaluator is the class responsible for scoring model outputs. At a minimum, subclasses must implement the `evaluate_sample` method, which takes an instance and a response string and returns an `EvaluationResult`.

```python
import regex
from genlm.eval import Evaluator, EvaluationResult


class PatternMatchingEvaluator(Evaluator[PatternMatchingInstance]):
    """Evaluator for pattern matching."""

    def evaluate_sample(self, instance, response):
        """Evaluate if a response matches the regex pattern."""
        is_valid = regex.compile(instance.pattern).fullmatch(response) is not None
        return EvaluationResult(
            score=int(is_valid), desc="valid" if is_valid else "invalid"
        )
```
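The evaluator uses `fullmatch`, so a response scores 1 only if the entire string matches the pattern; a response that merely starts with a match scores 0. The sketch below shows the difference using the standard-library `re` module, whose behavior agrees with the third-party `regex` package for simple patterns like these. `score_response` is a hypothetical helper that mirrors `evaluate_sample` without the `genlm.eval` types.

```python
import re

PATTERN = r"xy|xz"


def score_response(pattern: str, response: str) -> int:
    """Hypothetical helper mirroring PatternMatchingEvaluator.evaluate_sample."""
    # fullmatch requires the whole response to match, not just a prefix.
    return int(re.compile(pattern).fullmatch(response) is not None)


for response in ["xy", "xz", "xyz", "x"]:
    # re.match only anchors at the start, so "xyz" still counts as a prefix match.
    prefix_match = re.match(PATTERN, response) is not None
    print(f"{response!r}: prefix match={prefix_match}, score={score_response(PATTERN, response)}")
```

Here `"xyz"` is a prefix match but scores 0, which is the behavior you want when the constraint is that the whole generated string must match.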

## 3. Implement a model adaptor

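A model adaptor is an async callable that takes an instance (here, a `PatternMatchingInstance`), an output directory, and a replicate number, and returns a `ModelOutput` containing weighted responses. For this example, we use a `genlm.control.PromptedLLM`, constrained to the instance's pattern, to generate responses.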

```python
from genlm.control import PromptedLLM, AWRS
from genlm.eval import ModelOutput, ModelResponse
from genlm.eval.domains.pattern_matching import (
    default_prompt_formatter,
    PatternPotential,
)

# Load GPT-2, treating newline tokens as end-of-sequence tokens.
LLM = PromptedLLM.from_name("gpt2", eos_tokens=[b"\n", b"\n\n"])


async def model(instance, output_dir, replicate):
    # Set the prompt for the LLM.
    LLM.prompt_ids = default_prompt_formatter(
        LLM.model.tokenizer, instance, use_chat_format=False
    )

    # Define a potential that ensures the generated text matches the pattern.
    # `coerce` adapts it to the LLM's token vocabulary; `b"".join` concatenates
    # the byte tokens into a single byte string before the pattern is checked.
    potential = PatternPotential(instance.pattern).coerce(LLM, f=b"".join)

    # Define an adaptive weighted rejection sampler to sample tokens from the constrained model.
    sampler = AWRS(LLM, potential)

    # Run sequential Monte Carlo (SMC) to sample sequences from the constrained model.
    sequences = await sampler.smc(
        n_particles=5,
        ess_threshold=0.5,
        max_tokens=100,
    )

    # Return the sampled sequences and their posterior probabilities as a ModelOutput.
    return ModelOutput(
        responses=[
            ModelResponse(response=sequence, weight=prob)
            for sequence, prob in sequences.decoded_posterior.items()
        ],
    )
```
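Before loading an LLM, it can help to check the adaptor contract with a stub. The sketch below is hypothetical: `ToyResponse` and `ToyOutput` are stand-in dataclasses for `ModelResponse` and `ModelOutput` (assuming only that they carry a response string and a weight), and `stub_model` returns fixed candidates whose weights sum to 1, as posterior probabilities would. It keeps the same `(instance, output_dir, replicate)` signature as the real adaptor.

```python
import asyncio
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyResponse:
    """Hypothetical stand-in for genlm.eval.ModelResponse."""
    response: str
    weight: float


@dataclass
class ToyOutput:
    """Hypothetical stand-in for genlm.eval.ModelOutput."""
    responses: List[ToyResponse] = field(default_factory=list)


async def stub_model(instance, output_dir, replicate):
    # Same signature as the real adaptor; the arguments are unused here.
    candidates = {"xy": 0.75, "xz": 0.25}  # fixed, already-normalized weights
    return ToyOutput(
        responses=[ToyResponse(response=s, weight=w) for s, w in candidates.items()]
    )


output = asyncio.run(stub_model(instance=None, output_dir=None, replicate=0))
print([(r.response, r.weight) for r in output.responses])  # [('xy', 0.75), ('xz', 0.25)]
```

In a notebook, where an event loop is already running, `asyncio.run` raises an error; use `output = await stub_model(...)` there instead.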



## 4. Run the evaluation

Using the dataset, evaluator, and model adaptor, we can now run the evaluation:

```python
from genlm.eval import run_evaluation

dataset = PatternMatchingDataset([r"xy|xz", r"ab|c(e|f)"])
evaluator = PatternMatchingEvaluator()

# Top-level `await` works in Jupyter; in a plain script, run this inside an
# `async` function with asyncio.run(...).
results = await run_evaluation(
    dataset=dataset,
    evaluator=evaluator,
    model=model,
    n_replicates=1,
    verbosity=1,
    # output_dir="results", # uncomment to save results
)
```

```
Instance instance_id=0 pattern='xy|xz'
Mean weighted accuracy (instance): 0.9999999999999999
Mean weighted accuracy (total): 0.9999999999999999

Instance instance_id=1 pattern='ab|c(e|f)'
Mean weighted accuracy (instance): 1.0
Mean weighted accuracy (total): 1.0
```
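The reported mean weighted accuracy is consistent with weighting each response's score by its posterior weight and summing, which would also explain why a run where every sampled response is valid can report 0.9999999999999999 instead of exactly 1.0 (floating-point round-off in the normalized weights). The sketch below computes that quantity under this assumption; `weighted_accuracy` is a hypothetical helper, not part of `genlm.eval`, and the library's exact aggregation may differ. The weighted responses are made up for illustration and are not the run shown above.

```python
import re


def weighted_accuracy(pattern, responses):
    """Hypothetical: sum of weight * score over (response, weight) pairs.

    Assumes the weights are normalized (sum to 1), as posterior probabilities are.
    """
    compiled = re.compile(pattern)
    return sum(w * int(compiled.fullmatch(r) is not None) for r, w in responses)


# Made-up weighted responses, one list per instance.
per_instance = [
    weighted_accuracy(r"xy|xz", [("xy", 0.75), ("xz", 0.25)]),
    weighted_accuracy(r"ab|c(e|f)", [("ab", 0.5), ("cx", 0.5)]),
]
print(per_instance)                           # [1.0, 0.5]
print(sum(per_instance) / len(per_instance))  # 0.75

# Ten valid responses with weight 0.1 each: round-off keeps the sum just below 1.
print(weighted_accuracy(r"xy|xz", [("xy", 0.1)] * 10))  # 0.9999999999999999
```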

