/
chat_eval_chain.py
45 lines (38 loc) · 1.5 KB
/
chat_eval_chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""LLM Chain specifically for evaluating question answering."""
from __future__ import annotations
from typing import Any, List
from langchain.chains.llm import ChatModelChain
from langchain.chat_models.base import BaseChatModel
from langchain.evaluation.qa.eval_prompt import CHAT_PROMPT
from langchain.prompts.base import BasePromptTemplate
class QAEvalChain(ChatModelChain):
    """Chain that grades question-answering predictions with a chat model."""

    @classmethod
    def from_llm(
        cls, llm: BaseChatModel, prompt: BasePromptTemplate = CHAT_PROMPT, **kwargs: Any
    ) -> QAEvalChain:
        """Build a ``QAEvalChain`` from a chat model and an evaluation prompt.

        Args:
            llm: Chat model forwarded to the chain constructor.
            prompt: Evaluation prompt. Its ``input_variables`` must be exactly
                ``query``, ``answer`` and ``result`` (order is irrelevant).
                Defaults to ``CHAT_PROMPT``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            The newly constructed chain.

        Raises:
            ValueError: If the prompt's input variables are not exactly the
                three expected names.
        """
        required_vars = {"query", "answer", "result"}
        # Compared as sets so only the names matter, not their order.
        if set(prompt.input_variables) == required_vars:
            return cls(llm=llm, prompt=prompt, **kwargs)
        raise ValueError(
            f"Input variables should be {required_vars}, "
            f"but got {prompt.input_variables}"
        )

    def evaluate(
        self,
        examples: List[dict],
        predictions: List[dict],
        question_key: str = "query",
        answer_key: str = "answer",
        prediction_key: str = "result",
    ) -> List[dict]:
        """Grade each prediction against the example at the same index.

        Args:
            examples: Reference examples; each is read at ``question_key``
                and ``answer_key``.
            predictions: Predictions aligned by position with ``examples``;
                each is read at ``prediction_key``.
            question_key: Key of the question inside an example.
            answer_key: Key of the reference answer inside an example.
            prediction_key: Key of the predicted answer inside a prediction.

        Returns:
            Whatever ``self.apply`` returns for the assembled inputs.

        Raises:
            IndexError: If ``predictions`` has fewer items than ``examples``.
            KeyError: If an example or prediction lacks the requested key.

        Note:
            Pairing is driven by ``examples``; predictions beyond
            ``len(examples)`` are never read.
        """
        # Re-key every pair to the fixed names that ``from_llm`` requires of
        # the prompt, regardless of the caller's own key names.
        chain_inputs = [
            {
                "query": sample[question_key],
                "answer": sample[answer_key],
                "result": predictions[position][prediction_key],
            }
            for position, sample in enumerate(examples)
        ]
        return self.apply(chain_inputs)