142 changes: 142 additions & 0 deletions nbs/metric/base.ipynb
@@ -0,0 +1,142 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "00ef8db1",
"metadata": {},
"outputs": [],
"source": [
"#| default_exp metric.base"
]
},
{
"cell_type": "markdown",
"id": "2eb8f806",
"metadata": {},
"source": [
"# BaseMetric\n",
 base class">
"> base class for all types of metrics in ragas"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8ccff58",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/random/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"#| export\n",
"\n",
"from abc import ABC, abstractmethod\n",
"import asyncio\n",
"from dataclasses import dataclass, field\n",
"from pydantic import BaseModel\n",
"import typing as t\n",
"from ragas_annotator.metric import MetricResult\n",
"from ragas_annotator.metric import LLM\n",
"\n",
"@dataclass\n",
"class Metric(ABC):\n",
" \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n",
" name: str\n",
" prompt: str\n",
" llm: LLM\n",
" _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n",
" default_factory=dict, init=False, repr=False\n",
" )\n",
" \n",
" @abstractmethod\n",
" def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:\n",
" \"\"\"Get the appropriate response model.\"\"\"\n",
" pass\n",
"\n",
" @abstractmethod\n",
" def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:\n",
" pass\n",
" \n",
" \n",
" def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n",
" responses = []\n",
" prompt_input = self.prompt.format(**kwargs)\n",
" for _ in range(n):\n",
" response = self.llm.generate(prompt_input, response_model = self._get_response_model(reasoning)) \n",
" response = MetricResult(**response.model_dump())\n",
" responses.append(response)\n",
" return self._ensemble(responses)\n",
"\n",
"\n",
" async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult:\n",
" responses = [] # Added missing initialization\n",
" prompt_input = self.prompt.format(**kwargs)\n",
" for _ in range(n):\n",
" response = await self.llm.agenerate(prompt_input, response_model = self._get_response_model(reasoning))\n",
" response = MetricResult(**response.model_dump()) # Fixed missing parentheses\n",
" responses.append(response)\n",
" return self._ensemble(responses)\n",
" \n",
" def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]:\n",
" return [self.score(reasoning, n, **input_dict) for input_dict in inputs]\n",
" \n",
" async def abatch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[MetricResult]:\n",
" async_tasks = []\n",
" for input_dict in inputs:\n",
" # Add reasoning and n to the input parameters\n",
" async_tasks.append(self.ascore(reasoning=reasoning, n=n, **input_dict))\n",
" \n",
" # Run all tasks concurrently and return results\n",
" return await asyncio.gather(*async_tasks)"
]
},
{
"cell_type": "markdown",
"id": "fc4b7458",
"metadata": {},
"source": [
"### Example\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcf208fa",
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"\n",
"@dataclass\n",
"class CustomMetric(Metric):\n",
" values: t.List[str] = field(default_factory=lambda: [\"pass\", \"fail\"])\n",
" \n",
" def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:\n",
" \"\"\"Get or create a response model based on reasoning parameter.\"\"\"\n",
" \n",
" class mymodel(BaseModel):\n",
" result: int\n",
" reason: t.Optional[str] = None\n",
" \n",
" return mymodel \n",
"\n",
" def _ensemble(self,results:t.List[MetricResult]) -> MetricResult:\n",
" \n",
" return results[0] # Placeholder for ensemble logic\n",
"\n",
"my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=LLM())\n",
"my_metric.score(input=\"test\")"
]
}
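,
{
"cell_type": "markdown",
"id": "7f3a9c21",
"metadata": {},
"source": [
"The base class also provides `batch_score` and `abatch_score` for scoring many inputs. Below is a minimal sketch of how they could be called with the `CustomMetric` and `LLM()` from the cell above; the input dicts are illustrative and not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d8b5e47",
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"\n",
"# Each dict is formatted into the prompt template and scored individually.\n",
"my_metric.batch_score([{\"input\": \"first test\"}, {\"input\": \"second test\"}])\n",
"\n",
"# The async variant gathers the individual ascore() calls concurrently.\n",
"# Top-level await works in a notebook; in a script, wrap this in asyncio.run().\n",
"await my_metric.abatch_score([{\"input\": \"first test\"}, {\"input\": \"second test\"}])"
]
}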
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}
215 changes: 215 additions & 0 deletions nbs/metric/decorator.ipynb
@@ -0,0 +1,215 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp metric.decorator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# decorator factory for metrics\n",
"> decorator factory for creating custom metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"import typing as t\n",
"import inspect\n",
"import asyncio\n",
"from dataclasses import dataclass\n",
"from ragas_annotator.metric import MetricResult\n",
"\n",
"\n",
"\n",
"\n",
"def create_metric_decorator(metric_class):\n",
" \"\"\"\n",
" Factory function that creates decorator factories for different metric types.\n",
" \n",
" Args:\n",
" metric_class: The metric class to use (DiscreteMetrics, NumericMetrics, etc.)\n",
" \n",
" Returns:\n",
" A decorator factory function for the specified metric type\n",
" \"\"\"\n",
" def decorator_factory(llm, prompt, name: t.Optional[str] = None, **metric_params):\n",
" \"\"\"\n",
" Creates a decorator that wraps a function into a metric instance.\n",
" \n",
" Args:\n",
" llm: The language model instance to use\n",
" prompt: The prompt template\n",
" name: Optional name for the metric (defaults to function name)\n",
" **metric_params: Additional parameters specific to the metric type\n",
" (values for DiscreteMetrics, range for NumericMetrics, etc.)\n",
" \n",
" Returns:\n",
" A decorator function\n",
" \"\"\"\n",
" def decorator(func):\n",
" # Get metric name and check if function is async\n",
" metric_name = name or func.__name__\n",
" is_async = inspect.iscoroutinefunction(func)\n",
" \n",
" @dataclass\n",
" class CustomMetric(metric_class):\n",
" def _extract_result(self, result, reasoning: bool):\n",
" \"\"\"Extract score and reason from the result.\"\"\"\n",
" if isinstance(result, tuple) and len(result) == 2:\n",
" score, reason = result\n",
" else:\n",
" score, reason = result, None\n",
" \n",
" # Use \"result\" instead of \"score\" for the new MetricResult implementation\n",
" return MetricResult(result=score, reason=reason if reasoning else None)\n",
" \n",
" def _run_sync_in_async(self, func, *args, **kwargs):\n",
" \"\"\"Run a synchronous function in an async context.\"\"\"\n",
" # For sync functions, just run them normally\n",
" return func(*args, **kwargs)\n",
" \n",
" def _execute_metric(self, is_async_execution, reasoning, **kwargs):\n",
" \"\"\"Execute the metric function with proper async handling.\"\"\"\n",
" try:\n",
" if is_async:\n",
" # Async function implementation\n",
" if is_async_execution:\n",
" # In async context, await the function directly\n",
" result = func(self.llm, self.prompt, **kwargs)\n",
" else:\n",
" # In sync context, run the async function in an event loop\n",
" try:\n",
" loop = asyncio.get_event_loop()\n",
" except RuntimeError:\n",
" loop = asyncio.new_event_loop()\n",
" asyncio.set_event_loop(loop)\n",
" result = loop.run_until_complete(func(self.llm, self.prompt, **kwargs))\n",
" else:\n",
" # Sync function implementation\n",
" result = func(self.llm, self.prompt, **kwargs)\n",
" \n",
" return self._extract_result(result, reasoning)\n",
" except Exception as e:\n",
" # Handle errors gracefully\n",
" error_msg = f\"Error executing metric {self.name}: {str(e)}\"\n",
" return MetricResult(result=None, reason=error_msg)\n",
" \n",
" def score(self, reasoning: bool = True, n: int = 1, **kwargs):\n",
" \"\"\"Synchronous scoring method.\"\"\"\n",
" return self._execute_metric(is_async_execution=False, reasoning=reasoning, **kwargs)\n",
" \n",
" async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs):\n",
" \"\"\"Asynchronous scoring method.\"\"\"\n",
" if is_async:\n",
" # For async functions, await the result\n",
" result = await func(self.llm, self.prompt, **kwargs)\n",
" return self._extract_result(result, reasoning)\n",
" else:\n",
" # For sync functions, run normally\n",
" result = self._run_sync_in_async(func, self.llm, self.prompt, **kwargs)\n",
" return self._extract_result(result, reasoning)\n",
" \n",
" # Create the metric instance with all parameters\n",
" metric_instance = CustomMetric(\n",
" name=metric_name,\n",
" prompt=prompt,\n",
" llm=llm,\n",
" **metric_params\n",
" )\n",
" \n",
" # Preserve metadata\n",
" metric_instance.__name__ = metric_name\n",
" metric_instance.__doc__ = func.__doc__\n",
" \n",
" return metric_instance\n",
" \n",
" return decorator\n",
" \n",
" return decorator_factory\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example usage\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"high\n",
"reason\n"
]
}
],
"source": [
"#| eval: false\n",
"\n",
"\n",
"from ragas_annotator.metric import DiscreteMetric\n",
"from ragas_annotator.metric.llm import LLM\n",
"from pydantic import BaseModel\n",
"\n",
"discrete_metric = create_metric_decorator(DiscreteMetric)\n",
"\n",
"@discrete_metric(llm=LLM(),\n",
" prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
" name='new_metric',values=[\"low\",\"med\",\"high\"])\n",
"def my_metric(llm,prompt,**kwargs):\n",
"\n",
" class response_model(BaseModel):\n",
" output: t.List[bool]\n",
" reason: str\n",
" \n",
" response = llm.generate(prompt.format(**kwargs),response_model=response_model)\n",
" total = sum(response.output)\n",
" if total < 1:\n",
" score = 'low'\n",
" else:\n",
" score = 'high'\n",
" return score,\"reason\"\n",
"\n",
"result = my_metric.score(response='my response') # result\n",
"print(result)\n",
"print(result.reason)"
]
},
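{
"cell_type": "markdown",
"metadata": {},
"source": [
"The decorated object is a metric instance, so the async `ascore` method defined in the export cell above is also available. A minimal sketch using the `my_metric` defined above; illustrative only, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"\n",
"# ascore() runs the wrapped synchronous function and wraps its output in a MetricResult.\n",
"# Top-level await works in a notebook; in a script, wrap this in asyncio.run().\n",
"result = await my_metric.ascore(response='my response')\n",
"print(result)\n",
"print(result.reason)"
]
},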
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}