Add llm as judge in metrics (#146)
Now allows launching an OpenAI LLM as a judge simply by selecting it in the metrics.
NathanHB committed Apr 11, 2024
1 parent 23987aa commit 7f6f807
Showing 4 changed files with 135 additions and 94 deletions.
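As a quick, hedged sketch of what this enables (only names that appear in the diff below are used; the key value is a made-up placeholder):

import os

# Assumption: a real key is provided; the new JudgeLLM wrapper added in
# metrics_sample.py reads it via os.getenv("OPENAI_API_KEY") when the judge
# metrics are instantiated.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")

# The two metric names this commit registers in src/lighteval/metrics/metrics.py:
#   "llm_judge"              single-turn judging, reported as "judge_score"
#   "llm_judge_multi_turn"   two-turn judging, reported as "single_turn" / "multi_turn"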
@@ -20,28 +20,19 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Inspired by the FastChat Codebase: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/README.md


import ast
import json
import re
import time
-from abc import ABC
from typing import Optional

from openai import OpenAI

from lighteval.logging.hierarchical_logger import hlog_warn


-# Abstract class for a judge
-class Judge(ABC):
-    def evaluate_answer(answers, questions, references) -> tuple[str, list[dict[str, str]], str]:
-        pass
-
-
-class JudgeOpenAI(Judge):
+class JudgeOpenAI:
    """
    A class representing a judge for evaluating answers using the OpenAI API.
@@ -70,11 +61,20 @@ class JudgeOpenAI(Judge):
    __process_judge_response: Processes the judge's response and extracts the score.
    """

-    def __init__(self, model: str, seed: int, temperature: float, templates_path: str, openai_api_key: str):
+    def __init__(
+        self,
+        model: str,
+        seed: int,
+        temperature: float,
+        templates_path: str,
+        openai_api_key: str,
+        multi_turn: bool = False,
+    ):
        self.client = OpenAI(api_key=openai_api_key)
        self.model = model
        self.seed = seed
        self.temperature = temperature
+        self.multi_turn = multi_turn

        data = []
        with open(templates_path, "r") as f:
@@ -95,7 +95,7 @@ def __init__(self, model: str, seed: int, temperature: float, templates_path: st
        self.max_tokens = 2048

    def evaluate_answer(
-        self, questions: list[str], answers: list[str], references: list[str], single_turn: bool
+        self, questions: list[str], answers: list[str], references: list[str]
    ) -> tuple[int, list[dict[str, str]], str]:
        """
        Evaluates an answer using the OpenAI API.
@@ -112,36 +112,43 @@
        Raises:
            Exception: If an error occurs during the API call.
        """
-        if single_turn:
-            prompts = self.__get_prompts_single_turn(
-                questions[0], answers[0], references[0] if len(references) > 0 else None
+        prompts = [
+            self.__get_prompts_single_turn(
+                questions[0], answers[0], references[0] if references is not None and len(references) > 0 else None
            )
-        else:
-            prompts = self.__get_prompts_multi_turn(questions, answers, references if len(references) > 1 else None)

-        for _ in range(self.API_MAX_RETRY):
-            try:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    seed=self.seed,
-                    temperature=self.temperature,
-                    messages=prompts,
-                    max_tokens=self.max_tokens,
-                    n=1,
-                )
-                break
-            except Exception as e:
-                hlog_warn(f"{type(e), e}")
-                time.sleep(self.API_RETRY_SLEEP)
-                response = None

-        if response is None:
+        ]

+        if self.multi_turn:
+            prompts_multi_turn = self.__get_prompts_multi_turn(
+                questions, answers, references if len(references) > 1 else None
+            )
+            prompts.append(prompts_multi_turn)

+        responses = []
+        for prompt in prompts:
+            for _ in range(self.API_MAX_RETRY):
+                try:
+                    response = self.client.chat.completions.create(
+                        model=self.model,
+                        seed=self.seed,
+                        temperature=self.temperature,
+                        messages=prompt,
+                        max_tokens=self.max_tokens,
+                        n=1,
+                    )
+                    responses.append(response)
+                    break
+                except Exception as e:
+                    hlog_warn(f"{type(e), e}")
+                    time.sleep(self.API_RETRY_SLEEP)

+        if len(responses) == 0:
            raise Exception("Failed to get response from the API")

-        judgment = response.choices[0].message.content
-        score = self.__process_judge_response(judgment)
+        judgments = [response.choices[0].message.content for response in responses]
+        scores = [self.__process_judge_response(judgment) for judgment in judgments]

-        return score, prompts, judgment
+        return scores, prompts, judgments

    def __get_prompts_multi_turn(
        self, questions: list[str], answers: list[str], references: Optional[list[str]]
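To make the reworked control flow above concrete, here is a minimal, non-authoritative sketch of calling the judge directly. The model, seed, temperature, template path, and import path are taken from this diff (the import path is confirmed by the metrics_sample.py change below); the question/answer/reference strings and the API key are invented placeholders:

from lighteval.metrics.llm_as_judge import JudgeOpenAI

judge = JudgeOpenAI(
    model="gpt-3.5-turbo",
    seed=42,
    temperature=0.0,
    templates_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
    openai_api_key="sk-...",  # placeholder: a real OpenAI key is required
    multi_turn=False,
)

# evaluate_answer now always returns lists: index 0 holds the single-turn result,
# and a second entry is appended when multi_turn=True.
scores, prompts, judgments = judge.evaluate_answer(
    questions=["What does a binary search do?"],
    answers=["It repeatedly halves a sorted list to locate a target value."],
    references=["Binary search halves the search interval at each step."],
)
print(scores[0], judgments[0])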
30 changes: 30 additions & 0 deletions src/lighteval/metrics/metrics.py
@@ -39,6 +39,7 @@
    BertScore,
    ExactMatches,
    F1_score,
+    JudgeLLM,
    LoglikelihoodAcc,
    Recall,
    StringDistance,
@@ -224,6 +225,35 @@ class Metrics(Enum):
        corpus_level_fn=np.mean,
        higher_is_better=True,
    )
+    llm_judge_multi_turn = SampleLevelMetricGrouping(
+        metric=["single_turn", "multi_turn"],
+        higher_is_better=True,
+        category=MetricCategory.GENERATIVE_MULTI_TURN,
+        use_case=MetricUseCase.SUMMARIZATION,
+        sample_level_fn=JudgeLLM(
+            judge_model_name="gpt-3.5-turbo",
+            template_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
+            multi_turn=True,
+        ).compute,
+        corpus_level_fn={
+            "single_turn": np.mean,
+            "multi_turn": np.mean,
+        },
+    )
+    llm_judge = SampleLevelMetricGrouping(
+        metric=["judge_score"],
+        higher_is_better=True,
+        category=MetricCategory.GENERATIVE,
+        use_case=MetricUseCase.SUMMARIZATION,
+        sample_level_fn=JudgeLLM(
+            judge_model_name="gpt-3.5-turbo",
+            template_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
+            multi_turn=False,
+        ).compute,
+        corpus_level_fn={
+            "judge_score": np.mean,
+        },
+    )
    loglikelihood_acc = SampleLevelMetric(
        metric="acc",
        sample_level_fn=LoglikelihoodAcc().compute,
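For orientation, a small hedged sketch of reading the new entries back from the Metrics enum; it assumes the SampleLevelMetricGrouping values expose the fields they are constructed with, and that OPENAI_API_KEY is set so the embedded JudgeLLM instances initialize cleanly:

from lighteval.metrics import Metrics

single = Metrics.llm_judge.value
multi = Metrics.llm_judge_multi_turn.value

print(single.metric)  # ["judge_score"]
print(multi.metric)   # ["single_turn", "multi_turn"]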
59 changes: 59 additions & 0 deletions src/lighteval/metrics/metrics_sample.py
@@ -23,6 +23,7 @@
"""This module manages all the metrics occurring at the sample level. The results of said metrics are then aggregated
using simple functions (min, mean, max, ...) at the corpus level. Most metrics fall under this category.
"""
+import os
from typing import Union

import nltk
@@ -38,6 +39,7 @@
from lighteval.metrics.imports.bert_scorer import BERTScorer
from lighteval.metrics.imports.data_stats_metric import DataStatsMetric
from lighteval.metrics.imports.summac import SummaCZS
+from lighteval.metrics.llm_as_judge import JudgeOpenAI
from lighteval.metrics.normalizations import remove_braces, remove_braces_and_strip
from lighteval.tasks.requests import Doc
from lighteval.utils import as_list
@@ -616,3 +618,60 @@ def edit_similarity(self, s1, s2):
        """
        edist = edit_distance(s1, s2)
        return 1.0 - edist / max(len(s1), len(s2)) if len(s1) > 0 and len(s2) > 0 else 0


+class JudgeLLM:
+    available_models = ["gpt-3.5-turbo"]
+
+    def __init__(self, judge_model_name: str, template_path: str, multi_turn: bool = False):
+        if judge_model_name not in self.available_models:
+            raise ValueError(f"{judge_model_name} not in available models for llm as a judge metric")
+
+        OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+        self.multi_turn = multi_turn
+
+        try:
+            self.judge = JudgeOpenAI(
+                model=judge_model_name,
+                seed=42,
+                temperature=0.0,
+                templates_path=template_path,
+                openai_api_key=OPENAI_API_KEY,
+                multi_turn=multi_turn,
+            )
+        except Exception as e:
+            print(f"Could not initialize the JudgeOpenAI model:\n{e}")
+            self.judge = None
+
+    def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
+        """
+        Compute the score of a generative task using an LLM as a judge.
+        The generative task can be multi-turn with up to 2 turns; in that case, we
+        return scores for turns 1 and 2. Also returns the user_prompt and judgement,
+        which are ignored later by the aggregator.
+        """
+
+        # If we are evaluating a multi-turn task, we need specific fields in the formatted doc
+        if self.multi_turn:
+            questions = formatted_doc.specific["multi_turn_queries"]
+            ref_answers = formatted_doc.specific.get("reference", None) if formatted_doc.specific is not None else None
+        else:
+            questions = [formatted_doc.query]
+            ref_answers = [formatted_doc.choices[formatted_doc.gold_index]]
+
+        scores, messages, judgements = self.judge.evaluate_answer(questions, predictions, ref_answers)
+
+        # Multi-turn tasks only have 2 turns
+        if self.multi_turn:
+            return {
+                "single_turn": scores[0],
+                "multi_turn": scores[1],
+                "user_prompt": [messages[0], messages[1]],
+                "judgement": [judgements[0], judgements[1]],
+            }
+
+        return {
+            "judge_score": scores[0],
+            "user_prompt": messages[0],
+            "judgement": judgements[0],
+        }
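A hedged usage sketch of the new wrapper in its single-turn form; the Doc constructor arguments are assumed from the attributes compute() reads (query, choices, gold_index), the strings are invented, and OPENAI_API_KEY must be set in the environment:

from lighteval.metrics.metrics_sample import JudgeLLM
from lighteval.tasks.requests import Doc

# Hypothetical document: the gold answer is choices[gold_index].
doc = Doc(
    query="Explain what a binary search does.",
    choices=["It halves the search interval of a sorted list at each step."],
    gold_index=0,
)

judge_metric = JudgeLLM(
    judge_model_name="gpt-3.5-turbo",
    template_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
    multi_turn=False,
)

# Returns "judge_score" plus the "user_prompt" and "judgement" used to produce it.
result = judge_metric.compute(
    predictions=["Binary search repeatedly splits a sorted list in half to find a target."],
    formatted_doc=doc,
)
print(result["judge_score"])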
57 changes: 1 addition & 56 deletions src/lighteval/tasks/extended/mt_bench/main.py
@@ -26,7 +26,6 @@
from aenum import extend_enum
from transformers import AutoModelForCausalLM, AutoTokenizer

-from lighteval.tasks.extended.mt_bench.judges import JudgeOpenAI
from lighteval.metrics import Metrics
from lighteval.metrics.utils import MetricCategory, MetricUseCase, SampleLevelMetric, SampleLevelMetricGrouping
from lighteval.tasks.lighteval_task import LightevalTaskConfig
@@ -35,15 +34,6 @@
from colorama import Fore, Style
import os

-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-
-if OPENAI_API_KEY is None:
-    # Using print here because hlog_warn is not yet available in this context
-    print(
-        Fore.YELLOW
-        + "No OpenAI API key found. If you are using the OpenAI judge, please set the OPENAI_API_KEY environment variable."
-        + Style.RESET_ALL
-    )

task = LightevalTaskConfig(
    name="mt_bench",
@@ -55,7 +45,7 @@
    evaluation_splits=["train"],
    few_shots_split="",
    few_shots_select="random",
-    metric=["mt_bench_metric"],
+    metric=["llm_judge_multi_turn"],
    generation_size=1024,
    stop_sequence=[],
)
@@ -81,54 +71,9 @@ def prompt_fn(line, task_name: str = None):
    )


-def mt_bench_metric(predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
-    """Defines how to go from a list of predictions to a score.
-    Follow examples in src/lighteval/metrics/metrics.py, or get more info
-    about what this function should do in the README.
-    """
-
-    judge = JudgeOpenAI(
-        model="gpt-3.5-turbo",
-        seed=42,
-        temperature=0.0,
-        templates_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
-        openai_api_key=OPENAI_API_KEY,
-    )
-
-    questions = formatted_doc.specific["multi_turn_queries"]
-    ref_answers = formatted_doc.specific["reference"]
-
-    score, messages, judgement = judge.evaluate_answer(questions, predictions, ref_answers, single_turn=True)
-    score_mt, messages_mt, judgement_mt = judge.evaluate_answer(questions, predictions, ref_answers, single_turn=False)
-
-    return {
-        "single_turn": score,
-        "multi_turn": score_mt,
-        "user_prompt": [messages, messages_mt],
-        "judgement": [judgement, judgement_mt],
-    }
-
-
-mt_bench_metric = SampleLevelMetricGrouping(
-    metric="mt_bench_metric",
-    higher_is_better=True,
-    category=MetricCategory.GENERATIVE_MULTI_TURN,
-    use_case=MetricUseCase.SUMMARIZATION,
-    sample_level_fn=mt_bench_metric,
-    corpus_level_fn={
-        "single_turn": np.mean,
-        "multi_turn": np.mean,
-    },
-)
-
_TASKS = [task]

TASKS_TABLE = [task.as_dict() for task in _TASKS]
-extend_enum(
-    Metrics,
-    "mt_bench_metric",
-    mt_bench_metric,
-)

if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
