From 0a82a1e1a2d8e6d089bb7d3426c46f1a7c7230b9 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Wed, 9 Apr 2025 17:20:07 -0700 Subject: [PATCH] patch for fixing few shot optim --- nbs/metric/base.ipynb | 42 ++++++++++++++++++++++++++++------ ragas_annotator/__init__.py | 2 +- ragas_annotator/_modidx.py | 2 ++ ragas_annotator/metric/base.py | 22 +++++++++++++----- 4 files changed, 54 insertions(+), 14 deletions(-) diff --git a/nbs/metric/base.ipynb b/nbs/metric/base.ipynb index c31ecaa..6366fee 100644 --- a/nbs/metric/base.ipynb +++ b/nbs/metric/base.ipynb @@ -42,8 +42,9 @@ "from dataclasses import dataclass, field\n", "from pydantic import BaseModel\n", "import typing as t\n", - "import json\n", "from tqdm import tqdm\n", + "import string\n", + "\n", "\n", "from ragas_annotator.prompt.base import Prompt\n", "from ragas_annotator.embedding.base import BaseEmbedding\n", @@ -76,7 +77,14 @@ " @abstractmethod\n", " def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:\n", " pass\n", - " \n", + " \n", + " def get_variables(self) -> t.List[str]:\n", + " if isinstance(self.prompt, Prompt):\n", + " fstr = self.prompt.instruction\n", + " else:\n", + " fstr = self.prompt\n", + " vars = [field_name for _, field_name, _, _ in string.Formatter().parse(fstr) if field_name]\n", + " return vars\n", " \n", " def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n", " responses = []\n", @@ -130,13 +138,15 @@ " datasets.append(experiment_data)\n", " \n", " total_items = sum([len(dataset) for dataset in datasets])\n", + " input_vars = self.get_variables()\n", + " output_vars = [self.name, f'{self.name}_reason']\n", " with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n", " for dataset in datasets:\n", " for row in dataset:\n", - " if hasattr(row, f'{self.name}_traces'):\n", - " traces = json.loads(getattr(row, f'{self.name}_traces'))\n", - " if traces:\n", - " self.prompt.add_example(traces['input'],traces['output'])\n", + " inputs = {var: getattr(row, var) for var in input_vars if hasattr(row, var)}\n", + " output = {var: getattr(row, var) for var in output_vars if hasattr(row, var)}\n", + " if output:\n", + " self.prompt.add_example(inputs,output)\n", " pbar.update(1)\n", " \n", " \n", @@ -160,7 +170,18 @@ "execution_count": null, "id": "fcf208fa", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "100" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "#| eval: false\n", "\n", @@ -189,6 +210,13 @@ "my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=llm)\n", "my_metric.score(input=\"test\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/ragas_annotator/__init__.py b/ragas_annotator/__init__.py index 0d55343..cb36591 100644 --- a/ragas_annotator/__init__.py +++ b/ragas_annotator/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.1" +__version__ = "0.0.2" # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/init_module.ipynb. # %% auto 0 diff --git a/ragas_annotator/_modidx.py b/ragas_annotator/_modidx.py index e9a1ff1..67117fd 100644 --- a/ragas_annotator/_modidx.py +++ b/ragas_annotator/_modidx.py @@ -183,6 +183,8 @@ 'ragas_annotator/metric/base.py'), 'ragas_annotator.metric.base.Metric.batch_score': ( 'metric/base.html#metric.batch_score', 'ragas_annotator/metric/base.py'), + 'ragas_annotator.metric.base.Metric.get_variables': ( 'metric/base.html#metric.get_variables', + 'ragas_annotator/metric/base.py'), 'ragas_annotator.metric.base.Metric.score': ( 'metric/base.html#metric.score', 'ragas_annotator/metric/base.py'), 'ragas_annotator.metric.base.Metric.train': ( 'metric/base.html#metric.train', diff --git a/ragas_annotator/metric/base.py b/ragas_annotator/metric/base.py index 78d1104..1c64ce4 100644 --- a/ragas_annotator/metric/base.py +++ b/ragas_annotator/metric/base.py @@ -11,8 +11,9 @@ from dataclasses import dataclass, field from pydantic import BaseModel import typing as t -import json from tqdm import tqdm +import string + from ..prompt.base import Prompt from ..embedding.base import BaseEmbedding @@ -45,7 +46,14 @@ def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]: @abstractmethod def _ensemble(self, results: t.List[MetricResult]) -> MetricResult: pass - + + def get_variables(self) -> t.List[str]: + if isinstance(self.prompt, Prompt): + fstr = self.prompt.instruction + else: + fstr = self.prompt + vars = [field_name for _, field_name, _, _ in string.Formatter().parse(fstr) if field_name] + return vars def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any: responses = [] @@ -99,13 +107,15 @@ def train(self,project:Project, experiment_names: t.List[str], model:NotionModel datasets.append(experiment_data) total_items = sum([len(dataset) for dataset in datasets]) + input_vars = self.get_variables() + output_vars = [self.name, f'{self.name}_reason'] with tqdm(total=total_items, desc="Processing examples") as pbar: for dataset in datasets: for row in dataset: - if hasattr(row, f'{self.name}_traces'): - traces = json.loads(getattr(row, f'{self.name}_traces')) - if traces: - self.prompt.add_example(traces['input'],traces['output']) + inputs = {var: getattr(row, var) for var in input_vars if hasattr(row, var)} + output = {var: getattr(row, var) for var in output_vars if hasattr(row, var)} + if output: + self.prompt.add_example(inputs,output) pbar.update(1)