In [None]:
!pip install jsonlines
!pip install accelerate

In [None]:
import pandas as pd
import jsonlines
from tqdm import tqdm

import transformers
import torch

In [None]:
# Make CUDA the default device for newly created tensors, but only when a GPU
# is actually usable. Selecting "cuda" unconditionally makes every later tensor
# creation fail on a CPU-only machine.
if torch.cuda.is_available():
    torch.set_default_device("cuda")

## Load Document

In [None]:
# Lookup table from document id to document text, filled from the JSON Lines
# corpus file.
docid_to_doc = {}

with jsonlines.open('./data/llm4eval_document_2024.jsonl', 'r') as reader:
  # Each record is read for its 'docid' and 'doc' fields; a repeated docid
  # overwrites the earlier entry.
  docid_to_doc.update((record['docid'], record['doc']) for record in reader)

## Load Query

In [None]:
# The query file is read as tab-separated with no header row; the two columns
# are named qid and qtext.
query_data = pd.read_csv(
    "./data/llm4eval_query_2024.txt",
    sep="\t",
    names=["qid", "qtext"],
    header=None,
)
# Lookup table from query id to query text.
qid_to_query = {
    qid: qtext for qid, qtext in zip(query_data["qid"], query_data["qtext"])
}

In [None]:
# System prompt sent with every request: sets the rater role and spells out the
# 0-3 relevance scale the model must answer with.
system_message = """You are a search quality rater evaluating the relevance of passages. Given a query and passage, you must provide a score on an integer scale of 0 to 3 with the following meanings:

    3 = Perfectly relevant: The passage is dedicated to the query and contains the exact answer.
    2 = Highly relevant: The passage has some answer for the query, but the answer may be a bit unclear, or hidden amongst extraneous information.
    1 = Related: The passage seems related to the query but does not answer it.
    0 = Irrelevant: The passage has nothing to do with the query

    Assume that you are writing an answer to the query. If the passage seems to be related to the query but does not include any answer to the query, mark it 1. If you would use any of the information contained in the passage in such an answer, mark it 2. If the passage is primarily about the query, or contains vital information about the topic, mark it 3. Otherwise, mark it 0."""

In [None]:
def get_prompt(query, passage):
    """Build the user message that asks for a relevance score.

    Args:
        query: The query text; interpolated after "Query: ".
        passage: The passage text; interpolated after "Passage: ".

    Returns:
        A multi-line string: the instruction sentence, a blank line, the
        indented query and passage lines, a blank line, and a trailing
        "Score:" cue.
    """
    prompt_lines = (
        "Please rate how the given passage is relevant to the query. The output must be only a score that indicate how relevant they are.",
        "",
        f"    Query: {query}",
        f"    Passage: {passage}",
        "",
        "    Score:",
    )
    return "\n".join(prompt_lines)

In [None]:
# Identifier of the model handed to the pipeline below.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Text-generation pipeline for that model: weights are requested in bfloat16
# and device placement is delegated through device_map="auto".
pipeline = transformers.pipeline(
    task="text-generation",
    model=model_id,
    device_map="auto",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

In [None]:
def get_relevance_score(prompt, *, max_new_tokens=256, do_sample=True,
                        temperature=0.6, top_p=0.9):
  """Send one rating prompt to the model and return its reply text.

  The prompt is wrapped in a chat conversation together with the module-level
  ``system_message``, rendered with the tokenizer's chat template, and passed
  to the module-level ``pipeline``.

  Args:
    prompt: User message to rate (e.g. the string built by ``get_prompt``).
    max_new_tokens: Upper bound on the number of generated tokens.
    do_sample: Whether to sample. Pass ``False`` for greedy decoding so that
      repeated runs give the same score.
    temperature: Sampling temperature; only forwarded when ``do_sample``.
    top_p: Nucleus-sampling threshold; only forwarded when ``do_sample``.

  Returns:
    The generated continuation (the text after the rendered prompt) with
    leading and trailing whitespace removed, so that it can be written into a
    whitespace-delimited results line.
  """
  messages = [
      {"role": "system", "content": system_message},
      {"role": "user", "content": prompt},
  ]

  # Kept under its own name so the caller's `prompt` argument is not shadowed.
  chat_prompt = pipeline.tokenizer.apply_chat_template(
      messages,
      tokenize=False,
      add_generation_prompt=True,
  )

  # Generation stops at either the tokenizer's EOS token or "<|eot_id|>".
  terminators = [
      pipeline.tokenizer.eos_token_id,
      pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
  ]

  generation_kwargs = {
      "max_new_tokens": max_new_tokens,
      "eos_token_id": terminators,
      # NOTE(review): hard-coded id; presumably the "<|eot_id|>" token of this
      # tokenizer -- confirm.
      "pad_token_id": 128009,
      "do_sample": do_sample,
  }
  if do_sample:
    # Sampling-only knobs are omitted for greedy decoding, where they have no
    # effect.
    generation_kwargs["temperature"] = temperature
    generation_kwargs["top_p"] = top_p

  outputs = pipeline(chat_prompt, **generation_kwargs)

  # The pipeline output starts with the prompt itself; keep only what follows
  # it and drop surrounding whitespace/newlines.
  return outputs[0]["generated_text"][len(chat_prompt):].strip()

In [None]:
# Read the test qrels as a space-separated file with no header row; the three
# columns are named qid, Q0 and docid.
test_qrel = pd.read_csv(
    "./data/llm4eval_test_qrel_2024.txt",
    names=["qid", "Q0", "docid"],
    header=None,
    sep=" ",
)
# Preview the first five rows as the cell output.
test_qrel.head()

In [None]:
# Score every (query, document) pair from test_qrel and write one
# "qid 0 docid score" record per line to the results file.
with open('llm4eval_test_qrel_results.txt', 'w') as result_file:
  # itertuples() is a generator without a length, so the row count is passed
  # explicitly to let the progress bar show a total and an ETA.
  for eachline in tqdm(test_qrel.itertuples(index=False), total=len(test_qrel)):
    qidx = eachline.qid
    docidx = eachline.docid
    prompt = get_prompt(query=qid_to_query[qidx], passage=docid_to_doc[docidx])
    pred_score = get_relevance_score(prompt)
    # The model reply is free text: collapse every whitespace run (including
    # newlines) so that a record can never spill over onto another line.
    pred_score = " ".join(str(pred_score).split())
    result_file.write(f"{qidx} 0 {docidx} {pred_score}\n")