In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache
set_llm_cache(SQLiteCache(database_path="/tmp/langchain-cache.db"))

In [3]:
import pandas as pd

df = pd.read_json('/Users/bdsaglam/knowledge/bellek/data/generated/musique-kg-llm/train/dataset.jsonl', orient='records', lines=True)
df = df.iloc[:100]

In [4]:
def make_docs(example, only_supporting=False):
    ps = example["paragraphs"]
    for p in ps:
        if only_supporting and not p["is_supporting"]:
            continue
        idx = p["idx"]
        title = p["title"]
        body = p["paragraph_text"]
        is_supporting = p["is_supporting"]
        text = f"# {title}\n{body}"
        yield dict(
            text=text,
            metadata={"parent_id": example["id"], "idx": idx, "is_supporting": is_supporting},
        )

In [5]:
def present_example(example, predicted_answer):
    text = "\n\n".join([p["paragraph_text"] for p in example['paragraphs']])
    print("="*80)
    print("Question:", example["question"])
    print("Reference Answer:", example['answer'])
    print("Predicted Answer:", predicted_answer)
    print("-"*80)
    print("Paragraphs")
    print(text)

In [6]:
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.1)

SYSTEM_PROMPT = """You are a helpful assistant that answers user's questions about the given text in 2-4 words.
# Text
{text}
"""
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("user", "{question}")
])

chain = prompt | llm 

In [7]:
def answer(example):
    documents = list(make_docs(example, only_supporting=False))
    text = "\n\n".join([doc["text"] for doc in documents])
    return chain.invoke({"text": text, "question": example['question']}).content

In [8]:
df['predicted_answer'] = df.apply(answer, axis=1)
df.to_json('/Users/bdsaglam/knowledge/bellek/data/generated/musique-kg-llm/train/baseline.jsonl', orient='records', lines=True)

In [9]:
from difflib import SequenceMatcher

def similarity(a, b):
    return SequenceMatcher(None, a, b).ratio()

def fuzzy_match(a, b, threshold=0.7):
    return similarity(a, b) >= threshold

def is_correct(example):
    return (example["answer"] in example["predicted_answer"]) or (
        fuzzy_match(example["predicted_answer"], example["answer"])
    )

In [10]:
df["is_correct"] = df.apply(is_correct, axis=1)
df["is_correct"].mean()

0.48

In [11]:
df[['id', 'question', 'answer', 'predicted_answer', 'is_correct']]

Unnamed: 0,id,question,answer,predicted_answer,is_correct
0,2hop__128801_205185,What county is the town where KNFM is licensed...,Midland County,Midland County,True
1,2hop__719559_217649,What's the record label of the artist who put ...,Warner Bros.,Columbia Records,False
2,2hop__128806_205185,What region is the town where KQRX is liscense...,Midland County,"Midland, Texas is the capital of the region wh...",False
3,2hop__837090_278127,What is the record label of the Do It Again pe...,Roc-A-Fella Records,Roc-A-Fella Records,True
4,2hop__128895_11424,How many households were there in the town WPU...,15504,"15,504 households.",True
...,...,...,...,...,...
95,2hop__651488_94210,Who was the place where Pieta is located desig...,Giorgio Vasari,Pietà (Perugino) - Pietro Perugino,False
96,2hop__362083_467995,What is the record label of the performer of M...,RCA Records,Lotus (Christina Aguilera album),False
97,2hop__525596_543261,The Roman Catholic Diocese of Jim Norton's bir...,Delaware,North Carolina,False
98,2hop__394596_8607,What metro area is JAKAZiD's birthplace a part...,South Hampshire,"Portsmouth, England is not mentioned in the gi...",False
