<a target="_blank" href="https://colab.research.google.com/github/gox6/colab-demos/blob/main/rags/evaluate-rags-rigorously-or-perish.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# The code accompanying the Medium article [Evaluate RAGs Rigorously or Perish](https://medium.com/@jgrygolec/evaluate-rags-rigorously-or-perish-54f790557357)

# Project Setup

In [1]:
# Installing Python packages & hiding

!pip install --quiet \
  chromadb \
  datasets \
  langchain \
  langchain_chroma \
  optuna \
  plotly \
  polars \
  ragas \
  1> /dev/null

In [73]:
# Importing the packages
from functools import reduce
import json
import os
import requests
import warnings

import chromadb
from chromadb.api.models.Collection import Collection as ChromaCollection
from datasets import load_dataset, Dataset
from getpass import getpass
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables.base import RunnableSequence
from langchain_community.document_loaders import WebBaseLoader, PolarsDataFrameLoader
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter
from operator import itemgetter
import optuna
import pandas as pd
import plotly.express as px
import polars as pl
from ragas import evaluate
from ragas.metrics import (
    answer_relevancy,
    faithfulness,
    context_recall,
    context_precision,
    answer_correctness
)
from ragas.testset.generator import TestsetGenerator
from ragas.testset.evolutions import simple, reasoning, multi_context, conditional


In [3]:
# Managing secrets
# - If using Colab please use Colab Secrets
# - If running outside Colab please provide secrets as environmental variables
COLAB = os.getenv("COLAB_RELEASE_TAG") is not None

if COLAB:
  from google.colab import userdata, data_table
  # Secrets
  OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
  os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
  runtime_info = "Colab runtime"

  # Enabling Colab's data formatter for pandas
  data_table.enable_dataframe_formatter()
elif OPENAI_API_KEY := os.environ.get('OPENAI_API_KEY'):
  # Secrets
  runtime_info = "Non Colab runtime"
else:
  OPENAI_API_KEY = getpass("OPENAI_API_KEY")
  os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
  runtime_info = "Non Colab runtime"

print(runtime_info)

Colab runtime


#Exploring Different Types of Question Evolution in RAGAs




In [4]:
# Getting the batch article
the_batch_newsletter_loader = WebBaseLoader("https://www.deeplearning.ai/the-batch/data-points-issue-245/")
the_batch_newsletter = the_batch_newsletter_loader.load()


the_batch_newsletter

[Document(page_content='Big Updates for GPT-4 Turbo, Gemini 1.5, Mixtral, and More\uf8ffüåü New Course! Enroll in Getting Started with MistralCoursesShort CoursesSpecializationsAI NewsletterThe BatchAndrew\'s LetterData PointsML ResearchBlogCommunityForumEventsAmbassadorsAmbassador SpotlightResourcesCompanyAboutCareersContactStart LearningWeekly IssuesAndrew\'s LettersData PointsML ResearchBusinessScienceAI & SocietyCultureHardwareAI CareersAboutSubscribeThe BatchData PointsArticleBig Updates for GPT-4 Turbo, Gemini 1.5, Mixtral, and More Plus, AI Helps Rebuild Lost MemoriesData PointsPublishedApr 17, 2024Reading time6 min readShareThis week\'s top AI news and research stories\xa0featured\xa0Google\'s Vertex AI Agent Builder, security holes in generated code, a series of policy violations in the GPT Store, and RA-DIT, a fine-tuning procedure that trains an LLM and retrieval model together to improve the LLM‚Äôs ability to capitalize on retrieved content. But first: U.S. and Japan gover

In [5]:
# Examining question evolution types evailable in ragas library
llm = ChatOpenAI(model="gpt-3.5-turbo")
generator_llm = llm
critic_llm = llm
embeddings = OpenAIEmbeddings()

example_generator = TestsetGenerator.from_langchain(
    generator_llm,
    critic_llm,
    embeddings
)

# Change resulting question type distribution
distributions = [{simple: 1}, {reasoning: 1}, {multi_context: 1}, {conditional: 1}]

In [30]:
# This step COSTS $$$ ...
# Generating the example evolutions
avoid_costs = True

if not avoid_costs:
  # Running ragas to get examples of question evolutions
  question_evolution_types = list(map(lambda x: example_generator.generate_with_langchain_docs(the_batch_newsletter, 1, x), distributions))
  question_evolution_types_pd = reduce(lambda x, y: pd.concat([x, y], axis=0), [x.to_pandas() for x in question_evolution_types])
  question_evolution_types_pd = question_evolution_types_pd.loc[:, ["evolution_type", "question", "ground_truth"]]
else:
  # Downloading examples for question evolutions discussed in the article:
  question_evolution_types_pd  = pl.read_csv(
    "https://gist.github.com/gox6/31f66ff936be445a9d16836a79f640a9/raw/example-question-evolution-types-in-ragas.csv",
    separator=",",
).drop("index").to_pandas()


In [31]:
# Displaying examples
# There is randomness in generating evaluation sets in ragas, which stems for the ragas sampling as well as from LLMs indeterminacy.
# As the result the above generated examples may be different than the ones described in the Medium blog post, which are displayed below.

display(question_evolution_types_pd)

Unnamed: 0,evolution_type,question,ground_truth
0,simple,How has generative AI adoption impacted artist...,Generative AI adoption has positively impacted...
1,reasoning,How does Spotify's AI playlist generator work ...,Spotify's AI Playlist feature allows users to ...
2,multi_context,What sets GPT-4 Turbo apart from Mixtral 8x22B?,GPT-4 Turbo stands out from Mixtral 8x22B due ...
3,conditional,How does generative AI impact artists' product...,Generative AI adoption boosts artists' product...


#Getting data: CNN and Daily Mail news articles


In [32]:
# Loading small sample of article from CNN and Daily Mail news dateset on HF: https://huggingface.co/datasets/cnn_dailymail
# To save time leveraging the gist with tiny extract from the dataset on HF
# - Not directly via LangChain with HuggingFaceDatasetLoader class because, it doesn't have split argument
save_time = True

if not save_time:
  news_hf = load_dataset(path="cnn_dailymail", name='1.0.0', split='train[:100]')
  news_pl = (pl.from_arrow(news_hf.data.table)
            .with_columns([pl.col("article").str.split(' ').list.len().alias("word_count")]))
else:

  news_pl = pl.read_csv(
      "https://gist.github.com/gox6/ef0aabc16dab6811e9b3da1e6694a84e/raw/cnn_daily_mail_tiny_extract.csv",
      separator=",",
  )
  news_hf = Dataset(news_pl.to_arrow())


news_pd = news_pl.to_pandas()

loader = PolarsDataFrameLoader(news_pl, page_content_column="article")
news = loader.load()


In [33]:
# Distribution of artciles by word count
fig = px.histogram(news_pl, x="word_count", marginal="rug")
fig.update_layout(
    title_text="Distribution of articles by word count", # title of plot
    xaxis_title_text='Word Count', # xaxis label
    yaxis_title_text='# Articles', # yaxis label
)
fig.show()

In [34]:
# Seeing news data
if COLAB:
  display(data_table.DataTable(news_pd, include_index=False, num_rows_per_page=5))
else:
  display(news_pd.head(5))

Unnamed: 0,article,highlights,id,word_count
0,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gets £20M f...,42c027e4ff9730fbb3de84c1af0d2c506e41c3e4,456
1,Editor's note: In our Behind the Scenes series...,Mentally ill inmates in Miami are housed on th...,ee8871b15c50d0db17b0179a6d2beab35065f1e9,700
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who we...","NEW: ""I thought I was going to die,"" driver sa...",06352019a19ae31e527f37f7571c6dd7f0c5da37,746
3,WASHINGTON (CNN) -- Doctors removed five small...,"Five small polyps found during procedure; ""non...",24521a2abb2e1f5e34e6824e0f9e56904a2b0e88,415
4,(CNN) -- The National Football League has ind...,"NEW: NFL chief, Atlanta Falcons owner critical...",7fe70cc8b12fab2d0a258fababf7d9c6b5e1262a,977
...,...,...,...,...
95,"DENVER, Colorado -- A Colorado man terrorized ...",Some witnesses say Colorado does nothing to pr...,f70a7abb6c5b0ef383ea12a4d9ca046a5bd854e5,844
96,"LONDON, England (CNN) -- Previously unseen foo...",NEW: Jury shown new footage of Diana taken hou...,a3dd38ec7bc9d7e8423b96d8fd0641a2a5d5c984,659
97,WASHINGTON (CNN) -- Republicans reacted with s...,"Republican Sen. Lindsey Graham: ""I am astounde...",654c6b29b96d2a5a818d91400c20f838b0e8b6df,721
98,"ST. PETERSBURG, Florida (CNN) -- The acrimony ...","YouTube questions address taxes, the Bible, ab...",764d9ce99a1e3f79d95fbc4b68adbce14e7f8bcd,1161


#Generating Synthetic Evaluation Set

In [37]:
# Examining question evolution types evailable in ragas library
llm = ChatOpenAI(model="gpt-3.5-turbo")
generator_llm = llm
critic_llm = llm
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
generator = TestsetGenerator.from_langchain(
    generator_llm,
    critic_llm,
    embeddings
)

# Set question type distribution
distributions = {simple: 0.25, reasoning: 0.25, multi_context: 0.25,conditional: 0.25}

In [64]:
# This costs some real $$$
avoid_costs = True
save_time = True

if not (avoid_costs or save_time):
  # Generate evaluation set
  synthetic_evaluation_set_hf = generator.generate_with_langchain_docs(documents=news, test_size=20, distributions=distributions).to_dataset()
  synthetic_evaluation_set_pl = pl.from_arrow(synthetic_evaluation_set_hf.data.table)
  synthetic_evaluation_set_pd = synthetic_evaluation_set_hf.to_pandas()
else:
  # Download the pre-computed evaluation set
  synthetic_evaluation_set_url = "https://gist.github.com/gox6/0858a1ae2d6e3642aa132674650f9c76/raw/synthetic-evaluation-set-cnn-daily-mail.csv"
  synthetic_evaluation_set_pl = pl.read_csv(synthetic_evaluation_set_url, separator=",").drop("index")
  synthetic_evaluation_set_pd = synthetic_evaluation_set_pl.to_pandas()
  synthetic_evaluation_set_hf = Dataset(synthetic_evaluation_set_pl.to_arrow())



In [65]:
# Seeing news data
if COLAB:
  display(data_table.DataTable(synthetic_evaluation_set_pd, include_index=False, num_rows_per_page=3))
else:
  display(synthetic_evaluation_set_pd.head(5))

Unnamed: 0,question,contexts,ground_truth,evolution_type,metadata,episode_done,cnt
0,What actions are law enforcement taking in res...,(CNN) -- A girl who was shown on a videotape b...,Law enforcement is actively seeking Chester Ar...,conditional,"{'highlights': 'Stiles described as ""survivali...",True,0
1,What controversial practices could arise if Wa...,(CNN) -- Polygamist sect leader Warren Jeffs t...,If Warren Jeffs stepped down as the FLDS proph...,conditional,{'highlights': 'Documents say after suicide at...,True,1
2,How many John Lewis and Waitrose stores are ow...,(CNN) -- The partnership started as a single s...,The partnership owns 26 John Lewis department ...,conditional,"{'highlights': ""John Lewis Partnership began a...",True,2
3,How does Katherine Heigl describe her Invisali...,(InStyle.com) -- A hit TV show. An Emmy. A sum...,Katherine Heigl describes her Invisalign exper...,conditional,"{'highlights': '""Grey\'s Anatomy"" actress Kath...",True,3
4,"Who hosts ""Wine Library TV"" and has a unique a...","(LiveWire) -- Voluble Gary Vaynerchuk, 31, the...",,conditional,"{'highlights': '""Wine Library TV"" Internet sho...",True,4
5,How has the policy of Iraqi children enrolling...,"AMMAN, Jordan (CNN) -- In the sunbathed school...","The policy of allowing all Iraqi children, reg...",conditional,{'highlights': 'Jordan opens school doors to a...,True,5
6,How would Barcelona's win over Atletico Madrid...,"MADRID, Spain -- Lionel Messi scored for the s...","If Messi didn't score, Barcelona's win over At...",conditional,"{'highlights': ""Lionel Messi scores for the si...",True,6
7,What caused the gun battles in Mogadishu betwe...,"MOGADISHU, Somalia (CNN) -- An enraged crowd d...",The gun battles in Mogadishu between Ethiopian...,conditional,{'highlights': 'Ethiopian soldier dragged afte...,True,7
8,How did Interpol use software to analyze image...,"PARIS, France (CNN) -- Interpol on Monday took...",Interpol used software to unscramble digitally...,conditional,{'highlights': 'Man posted photos on the Inter...,True,8
9,What did Uru see while escaping the school wit...,"uru, one of the school's teachers, on radio st...",Uru saw a woman's body as he fled the school w...,conditional,"{'highlights': ""NEW: Teen gunman is dead, Finn...",True,9


In [14]:
def concat_pl_dfs(paths: list[str]):
  dfs = pl.DataFrame()
  for path in paths:
    df = pl.read_csv(path, separator=",")
    dfs = pl.concat([dfs, df], how="vertical")

  return dfs

urls = ["https://gist.github.com/gox6/20f8332dc4f0071e81b1e6e0ed15f14e/raw/synthetic2.csv",
        "https://gist.github.com/gox6/16a8578440b5fdc5d2304192a64ca721/raw/synthetic3.csv",
        "https://gist.github.com/gox6/f3d0f4b6cd9d8bd7d4c0d481f3d94a22/raw/synthetic4.csv",
        "https://gist.github.com/gox6/3de8ea92a91783e28e6f63e712c1b77d/raw/synthetic5.csv",
        "https://gist.github.com/gox6/13f70ebb8fffdce13a8a4d7b1a3af194/raw/synthetic6.csv",]

eval_set = concat_pl_dfs(urls).drop("index")



In [15]:
eval_set = eval_set.sort("contexts").with_columns([pl.col("contexts").cum_count().over("contexts").alias("cnt")])
eval_set.filter(pl.col("cnt") == pl.lit(0))
eval_set = eval_set.sort("evolution_type").with_columns([pl.col("evolution_type").cum_count().over("evolution_type").alias("cnt")])
eval_set.filter(pl.col("cnt") < pl.lit(7))


question,contexts,ground_truth,evolution_type,metadata,episode_done,cnt
str,str,str,str,str,bool,u32
"""What actions a…","""(CNN) -- A gir…","""Law enforcemen…","""conditional""","""{'highlights':…",true,0
"""What controver…","""(CNN) -- Polyg…","""If Warren Jeff…","""conditional""","""{'highlights':…",true,1
"""How many John …","""(CNN) -- The p…","""The partnershi…","""conditional""","""{'highlights':…",true,2
"""How does Kathe…","""(InStyle.com) …","""Katherine Heig…","""conditional""","""{'highlights':…",true,3
"""Who hosts ""Win…","""(LiveWire) -- …","""nan""","""conditional""","""{'highlights':…",true,4
"""How has the po…","""AMMAN, Jordan …","""The policy of …","""conditional""","""{'highlights':…",true,5
"""How would Barc…","""MADRID, Spain …","""If Messi didn'…","""conditional""","""{'highlights':…",true,6
"""What are the m…",""" identified on…","""The main compo…","""multi_context""","""{'highlights':…",true,0
"""What was the d…","""(CNN) -- Filmm…","""Michael Moore …","""multi_context""","""{'highlights':…",true,1
"""How did Beckha…","""(CNN) -- Footb…","""Beckham's move…","""multi_context""","""{'highlights':…",true,2


# Setting up a vector database: ChromaDB

In [61]:
# Setting up a ChromaDB client
chroma_client = chromadb.EphemeralClient()

# Listing exististing document collections in Chroma DB
chroma_client.list_collections()


[Collection(name=text-embedding-ada-002_chunk1000_overlap200)]

In [66]:
# Defining a function to get document collection from vector db with given hyperparemeters
# The function embeds the documents only if collection is missing
# This development version as for production one would rather implement document level check


def get_vectordb_collection(chroma_client,
                            documents,
                            embedding_model="text-embedding-ada-002",
                            chunk_size=None, overlap_size=0) -> ChromaCollection:

    if chunk_size is None:
      collection_name = "full_text"
      docs_pp = documents
    else:
      collection_name = f"{embedding_model}_chunk{chunk_size}_overlap{overlap_size}"

      text_splitter = CharacterTextSplitter(
        separator=".",
        chunk_size=chunk_size,
        chunk_overlap=overlap_size,
        length_function=len,
        is_separator_regex=False,
      )

      docs_pp = text_splitter.transform_documents(documents)


    embedding = OpenAIEmbeddings(model=embedding_model)

    langchain_chroma = Chroma(client=chroma_client,
                              collection_name=collection_name,
                              embedding_function=embedding,
                              )

    existing_collections = [collection.name for collection in chroma_client.list_collections()]

    if chroma_client.get_collection(collection_name).count() == 0:
      langchain_chroma.from_documents(collection_name=collection_name,
                                        documents=docs_pp,
                                        embedding=embedding)
    return langchain_chroma

# Simple RAG in LangChain

In [63]:
# Defininig a function to get a simple RAG as Langchain chain with given hyperparemeters
# RAG returns also the context documents retrieved for evaluation purposes in RAGAs

def get_chain(chroma_client,
              documents,
              embedding_model="text-embedding-ada-002",
              llm_model="gpt-3.5-turbo",
              chunk_size=None,
              overlap_size=0,
              top_k=4,
              lambda_mult=0.25) -> RunnableSequence:

    vectordb_collection = get_vectordb_collection(chroma_client=chroma_client,
                                                  documents=documents,
                                                  embedding_model=embedding_model,
                                                  chunk_size=chunk_size,
                                                  overlap_size=overlap_size)

    retriever = vectordb_collection.as_retriever(top_k=top_k, lambda_mult=lambda_mult)

    template = """Answer the question based only on the following context.
    If the context doesn't contain entities present in the question say you don't know.

    {context}

    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)
    llm = ChatOpenAI(model=llm_model)

    def format_docs(docs):
        return "\n\n".join([doc.page_content for doc in docs])

    chain_from_docs = (
      RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
      | prompt
      | llm
      | StrOutputParser()
    )

    chain_with_context_and_ground_truth = RunnableParallel(
      context=itemgetter("question") | retriever,
      question=itemgetter("question"),
      ground_truth=itemgetter("ground_truth"),
    ).assign(answer=chain_from_docs)

    return chain_with_context_and_ground_truth

In [118]:
# Testing full text rag

with warnings.catch_warnings():
  rag_prototype = get_chain(chroma_client=chroma_client, documents=news, chunk_size=1000, overlap_size=200)

rag_prototype.invoke( {'question': 'What happened in Minneapolis to the bridge?',
                       "ground_truth": "x"})["answer"]

'The bridge in Minneapolis collapsed.'

# Evaluation of RAG

In [20]:
# We create the helper function to generate the RAG ansers together with Ground Truth based on synthetic evaluation set
# The dataset for RAGAS evaluation should contain the columns: question, answer, ground_truth, contexts
# RAGAs expects the data in Huggingface Dataset format

def generate_rag_answers_for_synthetic_questions(chain,
                                                 synthetic_evaluation_set) -> pl.DataFrame:

  df = pl.DataFrame()

  for row in synthetic_evaluation_set.iter_rows(named=True):
    rag_output = chain.invoke({"question": row["question"], "ground_truth": row["ground_truth"]})
    rag_output["contexts"] = [doc.page_content for doc in rag_output["context"]]
    del rag_output["context"]
    rag_output_pp = {k: [v] for k, v in rag_output.items()}
    df = pl.concat([df, pl.DataFrame(rag_output_pp)], how="vertical")

  return df

avoid_costs = True
save_time = True

if not (avoid_costs or save_time):

  rag_prototype_answers = generate_rag_answers_for_synthetic_questions(rag_prototype, synthetic_evaluation_set_pl)

else:
  url = "https://gist.github.com/gox6/73927c9e273dc0ed48525d89bf9f36dd/raw/rag_prototype_answers_with_ground_truth.json"
  response = requests.get(url)
  rag_prototype_answers = pl.from_dicts(json.loads(response.text))


In [21]:
rag_prototype_answers_pd = rag_prototype_answers.to_pandas()
rag_prototype_answers_pl = pl.from_pandas(rag_prototype_answers_pd)
rag_prototype_answers_hf = Dataset.from_pandas(rag_prototype_answers_pd )

if COLAB:
  display(data_table.DataTable(rag_prototype_answers_pd, include_index=False, num_rows_per_page=3))
else:
  display(rag_prototype_answers_pd.head(5))

Unnamed: 0,question,ground_truth,answer,contexts
0,What actions are law enforcement taking in res...,Law enforcement is actively seeking Chester Ar...,Law enforcement is actively seeking Chester Ar...,[(CNN) -- With his hands and feet shackled an...
1,What controversial practices could arise if Wa...,If Warren Jeffs stepped down as the FLDS proph...,If Warren Jeffs stepped down as the FLDS proph...,[The FLDS -- which is not affiliated with the ...
2,How many John Lewis and Waitrose stores are ow...,The partnership owns 26 John Lewis department ...,The partnership owns 26 John Lewis department ...,[(CNN) -- The partnership started as a single ...
3,How does Katherine Heigl describe her Invisali...,Katherine Heigl describes her Invisalign exper...,Katherine Heigl describes her Invisalign exper...,[Then there's her deal with Coty to be the fac...
4,"Who hosts ""Wine Library TV"" and has a unique a...",,No entities present in the given context.,"[About 40,000 Internet viewers -- many of the..."
5,How has the policy of Iraqi children enrolling...,"The policy of allowing all Iraqi children, reg...",The policy of allowing all Iraqi children to e...,"[According to the charity Save the Children, 2..."
6,How would Barcelona's win over Atletico Madrid...,"If Messi didn't score, Barcelona's win over At...",If Messi didn't score in Barcelona's win over ...,"[MADRID, Spain -- Lionel Messi scored for the ..."
7,What caused the gun battles in Mogadishu betwe...,The gun battles in Mogadishu between Ethiopian...,The gun battles in Mogadishu between Ethiopian...,"[MOGADISHU, Somalia (CNN) -- An enraged crowd ..."
8,How did Interpol use software to analyze image...,Interpol used software to unscramble digitally...,Interpol used software to unscramble and resto...,"[PARIS, France (CNN) -- Interpol on Monday too..."
9,What did Uru see while escaping the school wit...,Uru saw a woman's body as he fled the school w...,Uru saw a woman's body as he fled the school.,[The agency reported Kiuru as saying that he s...


In [22]:
prototype_result = evaluate(rag_prototype_answers_hf,
                            metrics=[answer_correctness],
                            )

print(prototype_result)

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

# Optimising RAG using RAGAs and Optuna

In [119]:
# Train test split
# We need at least 2 sets: train and test for RAG optimization.

shuffled = synthetic_evaluation_set_pl.sample(fraction=1,
                                              shuffle=True,
                                              seed=6)
test_fraction = 0.5

test_n = round(len(synthetic_evaluation_set_pl) * test_fraction)
train, test = (shuffled.head(-test_n),
               shuffled.head( test_n))


In [90]:
def objective(trial):

  embedding_model = trial.suggest_categorical(name="embedding_model",
                                              choices=["text-embedding-ada-002", 'text-embedding-3-small'])

  chunk_size = trial.suggest_int(name="chunk_size",
                                 low=500,
                                 high=1000,
                                 step=100)

  overlap_size = trial.suggest_int(name="overlap_size",
                                   low=100,
                                   high=400,
                                   step=50)

  top_k = trial.suggest_int(name="top_k",
                            low=1,
                            high=10,
                            step=1)


  challenger_chain = get_chain(chroma_client,
                            news,
                            embedding_model=embedding_model,
                            llm_model="gpt-3.5-turbo",
                            chunk_size=chunk_size,
                            overlap_size= overlap_size ,
                            top_k=top_k,
                            lambda_mult=0.25)


  challenger_answers_pl = generate_rag_answers_for_synthetic_questions(challenger_chain , train)
  challenger_answers_hf = Dataset.from_pandas(challenger_answers_pl.to_pandas())

  challenger_result = evaluate(challenger_answers_hf,
                               metrics=[answer_correctness],
                              )

  return challenger_result['answer_correctness']



In [91]:
sampler = optuna.samplers.TPESampler(seed=6)
study = optuna.create_study(study_name="RAG Optimisation",
                            direction="maximize",
                            sampler=sampler)
study.set_metric_names(['answer_correctness'])

educated_guess = {"embedding_model": "text-embedding-3-small",
                  "chunk_size": 1000,
                  "overlap_size": 200,
                  "top_k": 3}


study.enqueue_trial(educated_guess)

print(f"Sampler is {study.sampler.__class__.__name__}")
study.optimize(objective, timeout=180)

[I 2024-04-26 08:50:53,743] A new study created in memory with name: RAG Optimisation


Sampler is TPESampler


Evaluating:   0%|          | 0/20 [00:00<?, ?it/s]

[I 2024-04-26 08:51:34,460] Trial 0 finished with value: 0.637602656181653 and parameters: {'embedding_model': 'text-embedding-3-small', 'chunk_size': 1000, 'overlap_size': 200, 'top_k': 3}. Best is trial 0 with value: 0.637602656181653.


Evaluating:   0%|          | 0/20 [00:00<?, ?it/s]

[I 2024-04-26 08:52:18,095] Trial 1 finished with value: 0.6475852431804713 and parameters: {'embedding_model': 'text-embedding-ada-002', 'chunk_size': 900, 'overlap_size': 100, 'top_k': 2}. Best is trial 1 with value: 0.6475852431804713.


Evaluating:   0%|          | 0/20 [00:00<?, ?it/s]

[I 2024-04-26 08:53:03,895] Trial 2 finished with value: 0.6719732084471782 and parameters: {'embedding_model': 'text-embedding-ada-002', 'chunk_size': 700, 'overlap_size': 200, 'top_k': 7}. Best is trial 2 with value: 0.6719732084471782.


Evaluating:   0%|          | 0/20 [00:00<?, ?it/s]

[I 2024-04-26 08:53:51,473] Trial 3 finished with value: 0.6501407376785545 and parameters: {'embedding_model': 'text-embedding-3-small', 'chunk_size': 800, 'overlap_size': 300, 'top_k': 7}. Best is trial 2 with value: 0.6719732084471782.


Evaluating:   0%|          | 0/20 [00:00<?, ?it/s]

[I 2024-04-26 08:54:42,448] Trial 4 finished with value: 0.700130617593832 and parameters: {'embedding_model': 'text-embedding-ada-002', 'chunk_size': 700, 'overlap_size': 400, 'top_k': 9}. Best is trial 4 with value: 0.700130617593832.


In [113]:
print("Best trial with answer_correctness:", study.best_trial.value)
print("Hyper-parameters for the best trial:", study.best_trial.params)

Best trial with answer_correctness: 0.700130617593832
Hyper-parameters for the best trial: {'embedding_model': 'text-embedding-ada-002', 'chunk_size': 700, 'overlap_size': 400, 'top_k': 9}


In [115]:
# Evaluation of the best trial parameters on the test set
challenger_chain = get_chain(chroma_client, news, **study.best_trial.params)
challenger_answers_pl = generate_rag_answers_for_synthetic_questions(challenger_chain , test)
challenger_answers_hf = Dataset.from_pandas(challenger_answers_pl.to_pandas())

challenger_result = evaluate(challenger_answers_hf, metrics=[answer_correctness])
challenger_result

Evaluating:   0%|          | 0/20 [00:00<?, ?it/s]

{'answer_correctness': 0.6788}