diff --git a/docs/assets/langsmith-tracing-faithfullness.png b/docs/assets/langsmith-tracing-faithfullness.png new file mode 100644 index 000000000..a588b3de9 Binary files /dev/null and b/docs/assets/langsmith-tracing-faithfullness.png differ diff --git a/docs/assets/langsmith-tracing-overview.png b/docs/assets/langsmith-tracing-overview.png new file mode 100644 index 000000000..ff612ff57 Binary files /dev/null and b/docs/assets/langsmith-tracing-overview.png differ diff --git a/docs/integrations/langsmith.ipynb b/docs/integrations/langsmith.ipynb new file mode 100644 index 000000000..e6c092ac1 --- /dev/null +++ b/docs/integrations/langsmith.ipynb @@ -0,0 +1,176 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "98727749", + "metadata": {}, + "source": [ + "# Langsmith Integrations\n", + "\n", + "[Langsmith](https://docs.smith.langchain.com/) is a platform for building production-grade LLM applications from the langchain team. It helps you with tracing, debugging and evaluating LLM applications.\n", + "\n", + "The langsmith + ragas integrations offer 2 features:\n", + "1. View the traces of ragas `evaluator` \n", + "2. Use ragas metrics in langchain evaluation - (soon)\n", + "\n", + "\n", + "### Tracing ragas metrics\n", + "\n", + "Since ragas uses langchain under the hood, all you have to do is set up langsmith and your traces will be logged.\n", + "\n", + "To set up langsmith, make sure the following env-vars are set (you can read more in the [langsmith docs](https://docs.smith.langchain.com/#quick-start))\n", + "\n", + "```bash\n", + "export LANGCHAIN_TRACING_V2=true\n", + "export LANGCHAIN_ENDPOINT=https://api.smith.langchain.com\n", + "export LANGCHAIN_API_KEY=\n", + "export LANGCHAIN_PROJECT= # if not specified, defaults to \"default\"\n", + "```\n", + "\n", + "Once langsmith is set up, just run the evaluations as you normally would" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "27947474", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset fiqa (/home/jjmachan/.cache/huggingface/datasets/explodinggradients___fiqa/ragas_eval/1.0.0/3dc7b639f5b4b16509a3299a2ceb78bf5fe98ee6b5fee25e7d5e4d290c88efb8)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dc5a62b3aebb45d690d9f0dcc783deea", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00 -Take a look at our experiments [here](/experiments/assesments/metrics_assesments.ipynb) \ No newline at end of file +Take a look at our experiments [here](/experiments/assesments/metrics_assesments.ipynb) diff --git a/docs/quickstart.ipynb b/docs/quickstart.ipynb index 069f4b67f..18453545a 100644 --- a/docs/quickstart.ipynb +++ b/docs/quickstart.ipynb @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "b658e02f", "metadata": {}, "outputs": [ @@ -86,7 +86,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b445c1d1ed654516946e7c7f49850c0b", + "model_id": "986d2c6f72354b10b32d0458fe00a749", "version_major": 2, "version_minor": 0 }, @@ -108,7 +108,7 @@ "})" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -140,12 +140,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "id": "f17bcf9d", "metadata": {}, "outputs": [], "source": [ - "from ragas.metrics import context_relevancy, answer_relevancy, faithfulness" + "from ragas.metrics import 
context_relevancy, answer_relevancy, faithfulness\n", + "from ragas.metrics.critique import harmfulness" ] }, { @@ -153,15 +154,14 @@ "id": "ef8c5e60", "metadata": {}, "source": [ - "here you can see that we are using 3 metrics, but what do the represent?\n", + "here you can see that we are using 4 metrics, but what do they represent?\n", "\n", "1. context_relevancy - a measure of how relevent the retrieved context is to the question. Conveys quality of the retrieval pipeline.\n", "2. answer_relevancy - a measure of how relevent the answer is to the question\n", "3. faithfulness - the factual consistancy of the answer to the context base on the question.\n", + "4. harmfulness (AspectCritique) - in general, `AspectCritique` is a metric that can be used to quantify various aspects of the answer. Aspects like harmfulness, maliciousness, coherence, correctness, conciseness are available by default but you can easily define your own. Check the [docs](./metrics.md) for more info.\n", "\n", - "**Note:** *`faithfulness` using OpenAI's API to compute the score. If you using this metric make sure you set the environment key `OPENAI_API_KEY` with your API key.*\n", - "\n", - "**Note:** *`context_relevancy` and `answer_relevancy` use very small LLMs to compute the score. It will run on CPU but having a GPU is recommended.*\n", + "**Note:** *by default these metrics use OpenAI's API to compute the score. If you're using these metrics, make sure you set the environment key `OPENAI_API_KEY` with your API key. You can also try other LLMs for evaluation; check the [llm guide](./guides/llms.ipynb) to learn more.*\n", "\n", "If you're interested in learning more, feel free to check the [docs](https://github.com/explodinggradients/ragas/blob/main/docs/metrics.md)" ] @@ -178,28 +178,73 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "id": "22eb6f97", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "evaluating with [context_relavency]\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "Loading cached processed dataset at /home/jjmachan/.cache/huggingface/datasets/explodinggradients___fiqa/ragas_eval/1.0.0/3dc7b639f5b4b16509a3299a2ceb78bf5fe98ee6b5fee25e7d5e4d290c88efb8/cache-f5ed219a49e8fb1f.arrow\n", - "100%|████████████████████████████████████████████████████████████| 1/1 [00:18<00:00, 18.28s/it]\n", - "100%|████████████████████████████████████████████████████████████| 2/2 [00:34<00:00, 17.38s/it]\n", - "Loading cached processed dataset at /home/jjmachan/.cache/huggingface/datasets/explodinggradients___fiqa/ragas_eval/1.0.0/3dc7b639f5b4b16509a3299a2ceb78bf5fe98ee6b5fee25e7d5e4d290c88efb8/cache-2a93a2841bc4d586.arrow\n", - "100%|████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 8.00s/it]\n" + "100%|████████████████████████████████████████████████████████████| 1/1 [00:06<00:00, 6.05s/it]\n" ] }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "evaluating with [faithfulness]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.11s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "evaluating with [answer_relevancy]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.20s/it]\n" ] + }, + { + "name": "stdout", 
"output_type": "stream", + "text": [ + "evaluating with [harmfulness]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.75s/it]\n" ] }, { "data": { "text/plain": [ - "{'ragas_score': 0.8629, 'context_relavency': 0.8167, 'faithfulness': 0.9028, 'answer_relevancy': 0.8738}" + "{'ragas_score': 0.1787, 'context_relavency': 0.0689, 'faithfulness': 0.8333, 'answer_relevancy': 0.9347, 'harmfulness': 0.0000}" ] }, - "execution_count": 4, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -208,7 +253,8 @@ "from ragas import evaluate\n", "\n", "result = evaluate(\n", - " fiqa_eval[\"baseline\"], metrics=[context_relevancy, faithfulness, answer_relevancy]\n", + " fiqa_eval[\"baseline\"],\n", + " metrics=[context_relevancy, faithfulness, answer_relevancy, harmfulness],\n", ")\n", "\n", "result" @@ -226,7 +272,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "id": "8686bf53", "metadata": {}, "outputs": [ @@ -258,6 +304,7 @@ " context_relavency\n", " faithfulness\n", " answer_relevancy\n", + " harmfulness\n", " \n", " \n", " \n", @@ -267,9 +314,10 @@ " [Have the check reissued to the proper payee.J...\n", " \\nThe best way to deposit a cheque issued to a...\n", " [Just have the associate sign the back and the...\n", - " 0.867\n", + " 0.132468\n", " 1.0\n", - " 0.922\n", + " 0.978180\n", + " 0\n", " \n", " \n", " 1\n", @@ -277,9 +325,10 @@ " [Sure you can. You can fill in whatever you w...\n", " \\nYes, you can send a money order from USPS as...\n", " [Sure you can. You can fill in whatever you w...\n", - " 0.855\n", + " 0.074175\n", " 1.0\n", - " 0.923\n", + " 0.909481\n", + " 0\n", " \n", " \n", " 2\n", @@ -287,29 +336,10 @@ " [You're confusing a lot of things here. Compan...\n", " \\nYes, it is possible to have one EIN doing bu...\n", " [You're confusing a lot of things here. Compan...\n", - " 0.768\n", - " 1.0\n", - " 0.824\n", - " \n", - " \n", - " 3\n", - " Applying for and receiving business credit\n", - " [\"I'm afraid the great myth of limited liabili...\n", - " \\nApplying for and receiving business credit c...\n", - " [Set up a meeting with the bank that handles y...\n", - " 0.781\n", - " 1.0\n", - " 0.830\n", - " \n", - " \n", - " 4\n", - " 401k Transfer After Business Closure\n", - " [You should probably consult an attorney. Howe...\n", - " \\nIf your employer has closed and you need to ...\n", - " [The time horizon for your 401K/IRA is essenti...\n", - " 0.737\n", - " 1.0\n", - " 0.753\n", + " 0.000000\n", + " 0.5\n", + " 0.916480\n", + " 0\n", " \n", " \n", "\n", @@ -320,39 +350,29 @@ "0 How to deposit a cheque issued to an associate... \n", "1 Can I send a money order from USPS as a business? \n", "2 1 EIN doing business under multiple business n... \n", - "3 Applying for and receiving business credit \n", - "4 401k Transfer After Business Closure \n", "\n", " ground_truths \\\n", "0 [Have the check reissued to the proper payee.J... \n", "1 [Sure you can. You can fill in whatever you w... \n", "2 [You're confusing a lot of things here. Compan... \n", - "3 [\"I'm afraid the great myth of limited liabili... \n", - "4 [You should probably consult an attorney. Howe... \n", "\n", " answer \\\n", "0 \\nThe best way to deposit a cheque issued to a... \n", "1 \\nYes, you can send a money order from USPS as... \n", "2 \\nYes, it is possible to have one EIN doing bu... \n", - "3 \\nApplying for and receiving business credit c... 
\n", - "4 \\nIf your employer has closed and you need to ... \n", "\n", " contexts context_relavency \\\n", - "0 [Just have the associate sign the back and the... 0.867 \n", - "1 [Sure you can. You can fill in whatever you w... 0.855 \n", - "2 [You're confusing a lot of things here. Compan... 0.768 \n", - "3 [Set up a meeting with the bank that handles y... 0.781 \n", - "4 [The time horizon for your 401K/IRA is essenti... 0.737 \n", + "0 [Just have the associate sign the back and the... 0.132468 \n", + "1 [Sure you can. You can fill in whatever you w... 0.074175 \n", + "2 [You're confusing a lot of things here. Compan... 0.000000 \n", "\n", - " faithfulness answer_relevancy \n", - "0 1.0 0.922 \n", - "1 1.0 0.923 \n", - "2 1.0 0.824 \n", - "3 1.0 0.830 \n", - "4 1.0 0.753 " + " faithfulness answer_relevancy harmfulness \n", + "0 1.0 0.978180 0 \n", + "1 1.0 0.909481 0 \n", + "2 0.5 0.916480 0 " ] }, - "execution_count": 5, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } diff --git a/experiments/assesments/metrics_assesments.ipynb b/experiments/assesments/metrics_assesments.ipynb index fd58b7c1b..da84811ff 100644 --- a/experiments/assesments/metrics_assesments.ipynb +++ b/experiments/assesments/metrics_assesments.ipynb @@ -100,7 +100,9 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ['OPENAI_API_KEY'] = json.load(open('/Users/shahules/openai-key.json'))['ikka']" + "os.environ[\"OPENAI_API_KEY\"] = json.load(open(\"/Users/shahules/openai-key.json\"))[\n", + " \"ikka\"\n", + "]" ] }, { @@ -1783,11 +1785,14 @@ } ], "source": [ - "child_safe = AspectCritique(name=\"child safe\",definition=\"Is the submission age-appropriate, free from harmful or inappropriate elements, and designed to be suitable for young audiences without causing any harm or discomfort.\")\n", + "child_safe = AspectCritique(\n", + " name=\"child safe\",\n", + " definition=\"Is the submission age-appropriate, free from harmful or inappropriate elements, and designed to be suitable for young audiences without causing any harm or discomfort.\",\n", + ")\n", "\n", "results = evaluate(\n", " fiqa_eval[\"baseline\"].select(range(0, 3)),\n", - " metrics=[conciseness,child_safe],\n", + " metrics=[conciseness, child_safe],\n", ")" ] }, @@ -1914,8 +1919,8 @@ "metadata": {}, "outputs": [], "source": [ - "i=1\n", - "question, answer = [data['train'][i][k] for k in ['question','grounded_answer']]" + "i = 1\n", + "question, answer = [data[\"train\"][i][k] for k in [\"question\", \"grounded_answer\"]]" ] }, { @@ -1942,7 +1947,7 @@ "metadata": {}, "outputs": [], "source": [ - "output = llm2(Question_gen.format(answer),n=3,temperature=0.5)" + "output = llm2(Question_gen.format(answer), n=3, temperature=0.5)" ] }, { @@ -1952,12 +1957,19 @@ "metadata": {}, "outputs": [], "source": [ - "def get_cosine(question:str, generated_questions:list):\n", + "def get_cosine(question: str, generated_questions: list):\n", " gen_question_vec = get_apiembed(generated_questions)\n", - " question_vec = get_apiembed(question).reshape(1,-1)\n", + " question_vec = get_apiembed(question).reshape(1, -1)\n", " print(question_vec.shape, gen_question_vec.shape)\n", - " norm = np.linalg.norm(gen_question_vec,axis=1)*np.linalg.norm(question_vec,axis=1)\n", - " cosine_sim = np.dot(gen_question_vec,question_vec.T).reshape(-1,)/norm\n", + " norm = np.linalg.norm(gen_question_vec, axis=1) * np.linalg.norm(\n", + " question_vec, axis=1\n", + " )\n", + " cosine_sim = (\n", + " np.dot(gen_question_vec, question_vec.T).reshape(\n", + " 
-1,\n", + " )\n", + " / norm\n", + " )\n", " return cosine_sim" ] }, @@ -1969,11 +1981,10 @@ "outputs": [], "source": [ "def get_apiembed(text):\n", - " response = openai.Embedding.create(\n", - " input=text,\n", - " model=\"text-embedding-ada-002\"\n", - " )\n", - " embeddings = [response['data'][i]['embedding'] for i in range(len(response['data']))]\n", + " response = openai.Embedding.create(input=text, model=\"text-embedding-ada-002\")\n", + " embeddings = [\n", + " response[\"data\"][i][\"embedding\"] for i in range(len(response[\"data\"]))\n", + " ]\n", " return np.asarray(embeddings)" ] }, @@ -1984,15 +1995,17 @@ "metadata": {}, "outputs": [], "source": [ - "def get_relevancy(question,answer):\n", - " \n", - " output = llm2(Question_gen.format(answer),n=3,temperature=0.5)\n", - " generated_questions = [output['choices'][i]['message']['content'] for i in range(len(output['choices']))]\n", + "def get_relevancy(question, answer):\n", + " output = llm2(Question_gen.format(answer), n=3, temperature=0.5)\n", + " generated_questions = [\n", + " output[\"choices\"][i][\"message\"][\"content\"]\n", + " for i in range(len(output[\"choices\"]))\n", + " ]\n", " cosine_sim = get_cosine(question, generated_questions)\n", " sim = cosine_sim.max()\n", - "# print(\"question\",question)\n", - "# print(\"generated_questions\",\",\".join(generated_questions))\n", - "# print(\"similarity\",sim)\n", + " # print(\"question\",question)\n", + " # print(\"generated_questions\",\",\".join(generated_questions))\n", + " # print(\"similarity\",sim)\n", " return sim" ] }, @@ -2014,7 +2027,7 @@ } ], "source": [ - "get_apiembed([question]*2).shape" + "get_apiembed([question] * 2).shape" ] }, { @@ -2050,10 +2063,10 @@ } ], "source": [ - "grounded_scores,answer_scores = [],[]\n", + "grounded_scores, answer_scores = [], []\n", "for item in data[:5]:\n", - " grounded_scores.append(get_relevancy(item['question'], item['grounded_answer']))\n", - " answer_scores.append(get_relevancy(item['question'], item['answer_bad']))" + " grounded_scores.append(get_relevancy(item[\"question\"], item[\"grounded_answer\"]))\n", + " answer_scores.append(get_relevancy(item[\"question\"], item[\"answer_bad\"]))" ] }, { diff --git a/src/ragas/metrics/answer_relevance.py b/src/ragas/metrics/answer_relevance.py index 9af4aded2..33a015be4 100644 --- a/src/ragas/metrics/answer_relevance.py +++ b/src/ragas/metrics/answer_relevance.py @@ -5,6 +5,7 @@ import numpy as np from datasets import Dataset +from langchain.callbacks.manager import trace_as_chain_group from langchain.embeddings import OpenAIEmbeddings from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from tqdm import tqdm @@ -13,7 +14,7 @@ from ragas.metrics.llms import generate if t.TYPE_CHECKING: - pass + from langchain.callbacks.manager import CallbackManager QUESTION_GEN = HumanMessagePromptTemplate.from_template( @@ -55,28 +56,38 @@ def init_model(self: t.Self): self.embedding = OpenAIEmbeddings() # type: ignore def score(self: t.Self, dataset: Dataset) -> Dataset: - scores = [] - for batch in tqdm(self.get_batches(len(dataset))): - score = self._score_batch(dataset.select(batch)) - scores.extend(score) - - return dataset.add_column(f"{self.name}", scores) # type: ignore - - def _score_batch(self: t.Self, dataset: Dataset): + with trace_as_chain_group(f"ragas_{self.name}") as score_group: + scores = [] + for batch in tqdm(self.get_batches(len(dataset))): + score = self._score_batch(dataset.select(batch), callbacks=score_group) + scores.extend(score) + + return 
dataset.add_column(self.name, scores) # type: ignore + + def _score_batch( + self: t.Self, + dataset: Dataset, + callbacks: t.Optional[CallbackManager] = None, + callback_group_name: str = "batch", + ) -> list[float]: questions, answers = dataset["question"], dataset["answer"] + with trace_as_chain_group( + callback_group_name, callback_manager=callbacks + ) as batch_group: + prompts = [] + for ans in answers: + human_prompt = QUESTION_GEN.format(answer=ans) + prompts.append(ChatPromptTemplate.from_messages([human_prompt])) + + results = generate( + prompts, self.llm, n=self.strictness, callbacks=batch_group + ) + results = [[i.text for i in r] for r in results.generations] - prompts = [] - for ans in answers: - human_prompt = QUESTION_GEN.format(answer=ans) - prompts.append(ChatPromptTemplate.from_messages([human_prompt])) - - results = generate(prompts, self.llm, n=self.strictness) - results = [[i.text for i in r] for r in results.generations] - - scores = [] - for question, gen_questions in zip(questions, results): - cosine_sim = self.calculate_similarity(question, gen_questions) - scores.append(cosine_sim.max()) + scores = [] + for question, gen_questions in zip(questions, results): + cosine_sim = self.calculate_similarity(question, gen_questions) + scores.append(cosine_sim.max()) return scores diff --git a/src/ragas/metrics/context_relevance.py b/src/ragas/metrics/context_relevance.py index 108d441b4..3e15bf2b4 100644 --- a/src/ragas/metrics/context_relevance.py +++ b/src/ragas/metrics/context_relevance.py @@ -7,6 +7,7 @@ import numpy as np from datasets import Dataset +from langchain.callbacks.manager import CallbackManager, trace_as_chain_group from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from sentence_transformers import CrossEncoder from tqdm import tqdm @@ -135,39 +136,57 @@ def score(self: t.Self, dataset: Dataset) -> Dataset: """ if self.llm is None: raise ValueError("llm must not be None") + + scores = [] + with trace_as_chain_group(f"ragas_{self.name}") as score_group: + for batch in tqdm(self.get_batches(len(dataset))): + score = self._score_batch(dataset.select(batch), callbacks=score_group) + scores.extend(score) + + return dataset.add_column(self.name, scores) # type: ignore + + def _score_batch( + self: t.Self, + dataset: Dataset, + callbacks: t.Optional[CallbackManager] = None, + callback_group_name: str = "batch", + ) -> list[float]: prompts = [] questions, contexts = dataset["question"], dataset["contexts"] - for q, c in zip(questions, contexts): - human_prompt = CONTEXT_RELEVANCE.format(question=q, context="\n".join(c)) - prompts.append(ChatPromptTemplate.from_messages([human_prompt])) - - responses: list[list[str]] = [] - for batch_idx in tqdm(range(0, len(prompts), 20)): + with trace_as_chain_group( + callback_group_name, callback_manager=callbacks + ) as batch_group: + for q, c in zip(questions, contexts): + human_prompt = CONTEXT_RELEVANCE.format( + question=q, context="\n".join(c) + ) + prompts.append(ChatPromptTemplate.from_messages([human_prompt])) + + responses: list[list[str]] = [] results = generate( - prompts[batch_idx : batch_idx + 20], self.llm, n=self.strictness + prompts, self.llm, n=self.strictness, callbacks=batch_group ) - batch_responses = [[i.text for i in r] for r in results.generations] - responses.extend(batch_responses) # type: ignore - - scores = [] - for context, n_response in zip(contexts, responses): - context = "\n".join(context) - overlap_scores = [] - context_sents = sent_tokenize(context) - for output in 
n_response: - indices = [ - context.find(sent) - for sent in sent_tokenize(output) - if context.find(sent) != -1 - ] - overlap_scores.append(len(indices) / len(context_sents)) - if self.strictness > 1: - agr_score = self.sent_agreement.evaluate(n_response) - else: - agr_score = 1 - scores.append(agr_score * np.mean(overlap_scores)) - - return dataset.add_column(f"{self.name}", scores) # type: ignore + responses = [[i.text for i in r] for r in results.generations] + + scores = [] + for context, n_response in zip(contexts, responses): + context = "\n".join(context) + overlap_scores = [] + context_sents = sent_tokenize(context) + for output in n_response: + indices = [ + context.find(sent) + for sent in sent_tokenize(output) + if context.find(sent) != -1 + ] + overlap_scores.append(len(indices) / len(context_sents)) + if self.strictness > 1: + agr_score = self.sent_agreement.evaluate(n_response) + else: + agr_score = 1 + scores.append(agr_score * np.mean(overlap_scores)) + + return scores context_relevancy = ContextRelevancy() diff --git a/src/ragas/metrics/critique.py b/src/ragas/metrics/critique.py index 5d3962fa9..31b3c7246 100644 --- a/src/ragas/metrics/critique.py +++ b/src/ragas/metrics/critique.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from datasets import Dataset +from langchain.callbacks.manager import CallbackManager, trace_as_chain_group from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate @@ -66,12 +67,14 @@ def __post_init__(self: t.Self): assert self.name != "", "Expects a name" assert self.definition != "", "Expects definition" - def init_model(self: t.Self): # ensure odd number of checks to avoid tie in majority vote. 
self.strictness = ( self.strictness if self.strictness % 2 == 0 else self.strictness + 1 ) + def init_model(self: t.Self): + pass + def prompt_format( self: t.Self, question: str, @@ -90,6 +93,20 @@ def score(self: t.Self, dataset: Dataset) -> Dataset: if self.llm is None: raise ValueError("llm must not be None") + with trace_as_chain_group(f"ragas_{self.name}") as score_group: + scores = [] + for batch in tqdm(self.get_batches(len(dataset))): + score = self._score_batch(dataset.select(batch), callbacks=score_group) + scores.extend(score) + + return dataset.add_column(self.name, scores) # type: ignore + + def _score_batch( + self: t.Self, + dataset: Dataset, + callbacks: t.Optional[CallbackManager], + callback_group_name: str = "batch", + ) -> list[int]: questions, contexts, answers = [ dataset[key] if key in dataset.features else None for key in ("question", "context", "answer") @@ -100,34 +117,37 @@ def score(self: t.Self, dataset: Dataset) -> Dataset: contexts = [None] * len(questions) prompts = [] - for question, context, answer in zip(questions, contexts, answers): - human_prompt = self.prompt_format(question, answer, context) - prompts.append(ChatPromptTemplate.from_messages([human_prompt])) + with trace_as_chain_group( + callback_group_name, callback_manager=callbacks + ) as batch_group: + for question, context, answer in zip(questions, contexts, answers): + human_prompt = self.prompt_format(question, answer, context) + prompts.append(ChatPromptTemplate.from_messages([human_prompt])) - responses: list[list[str]] = [] - for batch_idx in tqdm(range(0, len(prompts), self.batch_size)): results = generate( - prompts[batch_idx : batch_idx + self.batch_size], + prompts, self.llm, n=self.strictness, + callbacks=batch_group, ) - batch_responses = [[i.text for i in r] for r in results.generations] - responses.extend(batch_responses) # type: ignore - - scores = [] - answer_dict = {"Yes": 1, "No": 0} - for response in responses: - response = [(text, text.split("\n\n")[-1]) for text in response] - if self.strictness > 1: - score = Counter( - [answer_dict.get(item[-1], 0) for item in response] - ).most_common(1)[0][0] - else: - score = answer_dict.get(response[0][-1]) - - scores.append(score) - - return dataset.add_column(f"{self.name}", scores) # type: ignore + responses: list[list[str]] = [ + [i.text for i in r] for r in results.generations + ] + + scores = [] + answer_dict = {"Yes": 1, "No": 0} + for response in responses: + response = [(text, text.split("\n\n")[-1]) for text in response] + if self.strictness > 1: + score = Counter( + [answer_dict.get(item[-1], 0) for item in response] + ).most_common(1)[0][0] + else: + score = answer_dict.get(response[0][-1]) + + scores.append(score) + + return scores harmfulness = AspectCritique( diff --git a/src/ragas/metrics/faithfulnes.py b/src/ragas/metrics/faithfulnes.py index 235e03a31..f0a6eada3 100644 --- a/src/ragas/metrics/faithfulnes.py +++ b/src/ragas/metrics/faithfulnes.py @@ -3,7 +3,7 @@ import typing as t from dataclasses import dataclass -from datasets import concatenate_datasets +from langchain.callbacks.manager import CallbackManager, trace_as_chain_group from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate from tqdm import tqdm @@ -70,65 +70,76 @@ def init_model(self: t.Self): pass def score(self: t.Self, dataset: Dataset) -> Dataset: - scores = [] - for batch in tqdm(self.get_batches(len(dataset))): - score = self._score_batch(dataset.select(batch)) - scores.append(score) - - return concatenate_datasets(scores) + 
assert self.llm is not None, "LLM not initialized" - def _score_batch(self: t.Self, ds: Dataset) -> Dataset: + scores = [] + with trace_as_chain_group(f"ragas_{self.name}") as score_group: + for batch in tqdm(self.get_batches(len(dataset))): + score = self._score_batch(dataset.select(batch), callbacks=score_group) + scores.extend(score) + + return dataset.add_column(self.name, scores) # type: ignore + + def _score_batch( + self: t.Self, + ds: Dataset, + callbacks: CallbackManager, + callback_group_name: str = "batch", + ) -> list[float]: """ returns the NLI score for each (q, c, a) pair """ - assert self.llm is not None, "LLM not initialized" question, answer, contexts = ds["question"], ds["answer"], ds["contexts"] prompts = [] - for q, a in zip(question, answer): - human_prompt = LONG_FORM_ANSWER_PROMPT.format(question=q, answer=a) - prompts.append(ChatPromptTemplate.from_messages([human_prompt])) - - result = generate(prompts, self.llm) - list_statements: list[list[str]] = [] - for output in result.generations: - # use only the first generation for each prompt - statements = output[0].text.split("\n") - list_statements.append(statements) - - prompts = [] - for context, statements in zip(contexts, list_statements): - statements_str: str = "\n".join( - [f"{i+1}.{st}" for i, st in enumerate(statements)] - ) - contexts_str: str = "\n".join(context) - human_prompt = NLI_STATEMENTS_MESSAGE.format( - context=contexts_str, statements=statements_str - ) - prompts.append(ChatPromptTemplate.from_messages([human_prompt])) - - result = generate(prompts, self.llm) - outputs = result.generations - scores = [] - final_answer = "Final verdict for each statement in order:" - final_answer = final_answer.lower() - for i, output in enumerate(outputs): - output = output[0].text.lower().strip() - if output.find(final_answer) != -1: - output = output[output.find(final_answer) + len(final_answer) :] - score = sum( - 0 if "yes" in answer else 1 - for answer in output.strip().split(".") - if answer != "" + with trace_as_chain_group( + callback_group_name, callback_manager=callbacks + ) as batch_group: + for q, a in zip(question, answer): + human_prompt = LONG_FORM_ANSWER_PROMPT.format(question=q, answer=a) + prompts.append(ChatPromptTemplate.from_messages([human_prompt])) + + result = generate(prompts, self.llm, callbacks=batch_group) + list_statements: list[list[str]] = [] + for output in result.generations: + # use only the first generation for each prompt + statements = output[0].text.split("\n") + list_statements.append(statements) + + prompts = [] + for context, statements in zip(contexts, list_statements): + statements_str: str = "\n".join( + [f"{i+1}.{st}" for i, st in enumerate(statements)] ) - score = score / len(list_statements[i]) - else: - score = max(0, output.count("verdict: no")) / len(list_statements[i]) - - scores.append(1 - score) - - return ds.add_column(f"{self.name}", scores) # type: ignore + contexts_str: str = "\n".join(context) + human_prompt = NLI_STATEMENTS_MESSAGE.format( + context=contexts_str, statements=statements_str + ) + prompts.append(ChatPromptTemplate.from_messages([human_prompt])) + + result = generate(prompts, self.llm, callbacks=batch_group) + outputs = result.generations + + scores = [] + final_answer = "Final verdict for each statement in order:" + final_answer = final_answer.lower() + for i, output in enumerate(outputs): + output = output[0].text.lower().strip() + if output.find(final_answer) != -1: + output = output[output.find(final_answer) + len(final_answer) :] + score 
= sum( + 0 if "yes" in answer else 1 + for answer in output.strip().split(".") + if answer != "" + ) + score = score / len(list_statements[i]) + else: + score = max(0, output.count("verdict: no")) / len(list_statements[i]) + + scores.append(1 - score) + + return scores faithfulness = Faithfulness() diff --git a/src/ragas/metrics/llms.py b/src/ragas/metrics/llms.py index d7df01514..787caa408 100644 --- a/src/ragas/metrics/llms.py +++ b/src/ragas/metrics/llms.py @@ -9,6 +9,9 @@ from langchain.prompts import ChatPromptTemplate from langchain.schema import LLMResult +if t.TYPE_CHECKING: + from langchain.callbacks.base import Callbacks + def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool: return isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI) @@ -18,6 +21,7 @@ def generate( prompts: list[ChatPromptTemplate], llm: BaseLLM | BaseChatModel, n: t.Optional[int] = None, + callbacks: t.Optional[Callbacks] = None, ) -> LLMResult: old_n = None n_swapped = False @@ -33,11 +37,12 @@ def generate( ) if isinstance(llm, BaseLLM): ps = [p.format() for p in prompts] - result = llm.generate(ps) + result = llm.generate(ps, callbacks=callbacks) elif isinstance(llm, BaseChatModel): ps = [p.format_messages() for p in prompts] - result = llm.generate(ps) + result = llm.generate(ps, callbacks=callbacks) if (isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)) and n_swapped: llm.n = old_n # type: ignore + return result diff --git a/tests/e2e/test_fullflow.py b/tests/e2e/test_fullflow.py index 8061281af..97b894e46 100644 --- a/tests/e2e/test_fullflow.py +++ b/tests/e2e/test_fullflow.py @@ -8,7 +8,7 @@ def test_evaluate_e2e(): ds = load_dataset("explodinggradients/fiqa", "ragas_eval")["baseline"] result = evaluate( - ds.select(range(5)), + ds.select(range(3)), metrics=[answer_relevancy, context_relevancy, faithfulness, harmfulness], ) assert result is not None
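
Below is a minimal, hedged sketch of how the tracing added in this diff is exercised end to end. It mirrors the new `docs/integrations/langsmith.ipynb` and `tests/e2e/test_fullflow.py`; it assumes the LangSmith environment variables described in that notebook plus `OPENAI_API_KEY` are exported, and that the `explodinggradients/fiqa` dataset is reachable.

```python
# Minimal sketch (not part of the diff): run an evaluation with LangSmith tracing on.
# Because every metric's score() now wraps its LLM calls in trace_as_chain_group,
# each run shows up as a "ragas_<metric name>" trace group in the configured project.
from datasets import load_dataset

from ragas import evaluate
from ragas.metrics import answer_relevancy, context_relevancy, faithfulness
from ragas.metrics.critique import harmfulness

# the evaluation split used in the docs and the e2e test
fiqa_eval = load_dataset("explodinggradients/fiqa", "ragas_eval")["baseline"]

# score a small slice; progress is printed per metric ("evaluating with [...]")
result = evaluate(
    fiqa_eval.select(range(3)),
    metrics=[context_relevancy, faithfulness, answer_relevancy, harmfulness],
)
print(result)  # e.g. {'ragas_score': ..., 'context_relavency': ..., ...}
```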
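
And a small sketch of the callback-threading pattern the metrics now follow internally: an outer `ragas_<metric>` trace group, a nested per-batch group, and the group passed through the new `callbacks` parameter of `ragas.metrics.llms.generate`. The group name and prompt text below are made up for illustration; the calls themselves match the ones introduced in this diff.

```python
# Sketch only: shows how the nested trace group reaches generate(), the same way
# the updated _score_batch methods hand their batch group to the LLM calls.
from langchain.callbacks.manager import trace_as_chain_group
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

from ragas.metrics.llms import generate

llm = ChatOpenAI()  # any BaseLLM / BaseChatModel accepted by generate()
human_prompt = HumanMessagePromptTemplate.from_template("Say hello.")  # illustrative prompt
prompts = [ChatPromptTemplate.from_messages([human_prompt])]

# outer group named like the metrics ("ragas_<name>"), inner group per batch
with trace_as_chain_group("ragas_example_metric") as score_group:
    with trace_as_chain_group("batch", callback_manager=score_group) as batch_group:
        result = generate(prompts, llm, callbacks=batch_group)

print(result.generations[0][0].text)
```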