# Dataset Preparation for RAG Fine-tuning

Here we download a subset of Simple English Wikipedia and use a powerful LLM (OpenAI's GPT-4o mini) to create a labeled dataset that can be used for:

- Fine-tuning Embedding Models for RAG Systems
- Fine-tuning LLMs for RAG Systems

## Get Wikipedia Data

In [1]:
!gdown 1oWBnoxBZ1Mpeond8XDUSO6J9oAjcRDyW

Downloading...
From (original): https://drive.google.com/uc?id=1oWBnoxBZ1Mpeond8XDUSO6J9oAjcRDyW
From (redirected): https://drive.google.com/uc?id=1oWBnoxBZ1Mpeond8XDUSO6J9oAjcRDyW&confirm=t&uuid=8cec7abc-bc59-472f-ad14-27b700f14cca
To: /workspace/training-fine-tuning-large-language-models-workshop-dhs2024/Module-04-Instruction-Fine-tuning-LLMs-with-Supervised-Fine-tuning/simplewiki-2020-11-01.jsonl.gz
100%|██████████████████████████████████████| 50.2M/50.2M [00:02<00:00, 22.8MB/s]


In [2]:
import gzip
import json

wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'

docs = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())
        # Keep only the first paragraph of each article so later steps run faster
        docs.append({
            'metadata': {
                'title': data.get('title'),
                'article_id': data.get('id')
            },
            'data': data.get('paragraphs')[0]
        })

# Keep only paragraphs containing the token "india" (case-insensitive, whitespace-split)
docs = [doc for doc in docs if 'india' in doc['data'].lower().split()]
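The keyword filter above splits on whitespace, so it only keeps paragraphs that contain the exact token `india`; a token with attached punctuation such as `India.` or `India,` is missed. A minimal sketch of a word-boundary regex alternative, assuming you want to catch those cases too (the `mentions_keyword` and `split_filter` helpers are hypothetical, not part of the notebook). Note that switching filters would change the 767-document count used below.

```python
import re

def mentions_keyword(text: str, keyword: str = "india") -> bool:
    """Return True if `keyword` appears as a whole word (case-insensitive)."""
    pattern = rf"\b{re.escape(keyword)}\b"
    return re.search(pattern, text, flags=re.IGNORECASE) is not None

def split_filter(text: str, keyword: str = "india") -> bool:
    """The notebook's approach: exact match against whitespace-separated tokens."""
    return keyword in text.lower().split()

samples = [
    "Basil is originally native to India and other tropical regions.",
    "Chennai is a large city in India.",
    "The Indian Air Force is the air arm of the armed forces.",
]

for s in samples:
    print(split_filter(s), mentions_keyword(s))
# True True     ("India" is a standalone token)
# False True    ("india." keeps its trailing period after split())
# False False   ("Indian" is a different word for both filters)
```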

In [3]:
len(docs)

767

In [4]:
docs[0]

{'metadata': {'title': 'Basil', 'article_id': '73985'},
 'data': 'Basil ("Ocimum basilicum") ( or ) is a plant of the Family Lamiaceae. It is also known as Sweet Basil or Tulsi. It is a tender low-growing herb that is grown as a perennial in warm, tropical climates. Basil is originally native to India and other tropical regions of Asia. It has been cultivated there for more than 5,000 years. It is prominently featured in many cuisines throughout the world. Some of them are Italian, Thai, Vietnamese and Laotian cuisines. It grows to between 30–60\xa0cm tall. It has light green, silky leaves 3–5\xa0cm long and 1–3\xa0cm broad. The leaves are opposite each other. The flowers are quite big. They are white in color and arranged as a spike.'}

## Enter OpenAI API Key

In [5]:
from getpass import getpass

OPENAI_KEY = getpass('Enter Open AI API Key: ')

Enter Open AI API Key:  ········


In [6]:
import os

os.environ['OPENAI_API_KEY'] = OPENAI_KEY

## Load Connection to GPT-4o Mini

In [7]:
from langchain_openai import ChatOpenAI

chatgpt = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

## Use LLMs to Filter out Irrelevant Documents

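Here we want to keep only documents that focus on the country of India. The keyword filter above keeps any paragraph that mentions "india", so we ask the LLM to label each document as relevant (1) or not (0).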

In [8]:
from tqdm import tqdm 

results = []
for doc in tqdm(docs[:5]):
    prompt = f"""Given the following document, follow these rules:
                  - Return 1 if the content is focused on the country of India
                  - Return 0 if the content is not focused on the country of India
                Just return the number and nothing else

                Document: {doc}
             """
    response = chatgpt.invoke(prompt)
    results.append(response.content)

100%|██████████| 5/5 [00:01<00:00,  2.67it/s]
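The loop stores whatever text the model returns, and later in the notebook the labels are compared with the string `'1'`. A reply such as `" 1\n"` or `"1."` would then be silently dropped. A minimal sketch of normalizing the reply first, assuming a hypothetical `parse_relevance` helper that returns 1, 0, or `None` when the model does not follow the "just return the number" instruction:

```python
from typing import Optional

def parse_relevance(reply: str) -> Optional[int]:
    """Map a raw LLM reply to 1 or 0; return None if it is not a bare 0/1."""
    # Remove surrounding whitespace, quotes, and a trailing period
    cleaned = reply.strip().strip("\"'.")
    if cleaned in {"0", "1"}:
        return int(cleaned)
    return None

replies = ["1", " 0\n", "'1'", "1.", "Yes, it is about India"]
print([parse_relevance(r) for r in replies])  # [1, 0, 1, 1, None]
```

Replies that come back as `None` can be re-queried or reviewed by hand instead of being discarded by an exact string comparison.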


In [9]:
docs[:5]

[{'metadata': {'title': 'Basil', 'article_id': '73985'},
  'data': 'Basil ("Ocimum basilicum") ( or ) is a plant of the Family Lamiaceae. It is also known as Sweet Basil or Tulsi. It is a tender low-growing herb that is grown as a perennial in warm, tropical climates. Basil is originally native to India and other tropical regions of Asia. It has been cultivated there for more than 5,000 years. It is prominently featured in many cuisines throughout the world. Some of them are Italian, Thai, Vietnamese and Laotian cuisines. It grows to between 30–60\xa0cm tall. It has light green, silky leaves 3–5\xa0cm long and 1–3\xa0cm broad. The leaves are opposite each other. The flowers are quite big. They are white in color and arranged as a spike.'},
 {'metadata': {'title': 'Roerich’s Pact', 'article_id': '259745'},
  'data': 'The Roerich Pact is a treaty on Protection of Artistic and Scientific Institutions and Historic Monuments, signed by the representatives of 21 states in the Oval Office of 

In [10]:
results

['1', '1', '1', '0', '0']

In [11]:
sample_docs = [doc['data'] for doc in docs[:3]]
sample_ids = [doc['metadata']['article_id'] for doc in docs[:3]]

In [12]:
sample_docs

['Basil ("Ocimum basilicum") ( or ) is a plant of the Family Lamiaceae. It is also known as Sweet Basil or Tulsi. It is a tender low-growing herb that is grown as a perennial in warm, tropical climates. Basil is originally native to India and other tropical regions of Asia. It has been cultivated there for more than 5,000 years. It is prominently featured in many cuisines throughout the world. Some of them are Italian, Thai, Vietnamese and Laotian cuisines. It grows to between 30–60\xa0cm tall. It has light green, silky leaves 3–5\xa0cm long and 1–3\xa0cm broad. The leaves are opposite each other. The flowers are quite big. They are white in color and arranged as a spike.',
 'The Roerich Pact is a treaty on Protection of Artistic and Scientific Institutions and Historic Monuments, signed by the representatives of 21 states in the Oval Office of the White House on 15 April 1935. As of January 1, 1990, the Roerich Pact had been ratified by ten nations: Brazil, Chile, Colombia, Cuba, the 

In [13]:
sample_ids

['73985', '259745', '207506']

## Use LLMs to Generate Data for Fine-tuning RAG Systems

We will generate structured records from the unstructured documents with the following fields:

- Context
- Question
- Answer

To keep generation fast, we create only two question-answer records per Wikipedia article.
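To make the target format concrete, here is a minimal sketch of one record and of how it might be turned into training inputs for the two use cases listed at the top. The `RagRecord` class, the helper functions, and the prompt template are illustrative assumptions, not the format used by later modules; the sample values are shortened from the Chennai paragraph and answer that appear further down.

```python
from dataclasses import dataclass

@dataclass
class RagRecord:
    article_id: int
    context: str   # source paragraph from Wikipedia
    question: str  # LLM-generated question about the context
    answer: str    # LLM-generated answer grounded in the context

def to_embedding_pair(rec: RagRecord) -> tuple[str, str]:
    """(query, positive passage) pair, e.g. for contrastive embedding fine-tuning."""
    return rec.question, rec.context

def to_sft_example(rec: RagRecord) -> dict[str, str]:
    """Prompt/completion pair for supervised fine-tuning of a generator LLM."""
    prompt = (
        "Answer the question using only the context.\n\n"
        f"Context: {rec.context}\n\nQuestion: {rec.question}\n\nAnswer:"
    )
    return {"prompt": prompt, "completion": " " + rec.answer}

rec = RagRecord(
    article_id=5113,
    context=("Chennai (formerly known as Madras) is the capital city of the Indian "
             "state of Tamil Nadu. The city is on the Coromandel Coast of the Bay of Bengal."),
    question="Where is Chennai located?",
    answer="Dear Friend, Chennai is located on the Coromandel Coast of the Bay of Bengal.",
)
print(to_embedding_pair(rec))
print(to_sft_example(rec)["prompt"])
```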

In [14]:
import json
from tqdm import tqdm

results = []

for doc, article_id in tqdm(zip(sample_docs, sample_ids), total=len(sample_docs)):
    prompt = f"""Given the following document and article_id, follow these rules:
                  - Generate two questions based on the content of the document
                  - Questions should be brief like humans ask questions, not super long
                  - Questions should not focus on things like what is the population, count etc.
                  - Questions should be more generic like Tell me about the capital of India etc.
                  - Generate answers to each of these questions only from the content of the document
                  - Start each answer by saying, Dear Friend,
                Do not make up answers, only use the content of the document.
                Each answer should be at least two lines if possible.

                Return a list of article_id, question, answer as a valid JSON
                Do not return a markdown JSON but just the JSON string
                JSON response should always be a list of the two records

                Document: {doc}
                article_id: {article_id}
            """
    response = chatgpt.invoke(prompt)
    result = json.loads(response.content)
    results.extend(result)

100%|██████████| 3/3 [00:04<00:00,  1.43s/it]
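`json.loads(response.content)` raises `json.JSONDecodeError` if the model wraps its output in a markdown code fence despite the instruction, which aborts the whole loop. A minimal sketch of a more forgiving parser, assuming the reply is either a bare JSON value or a single fenced JSON block (`parse_qa_json` is a hypothetical helper, not part of LangChain):

```python
import json
import re

TICKS = "`" * 3  # three backticks (a markdown code fence)
FENCE_RE = re.compile(rf"^{TICKS}(?:json)?\s*(.*?)\s*{TICKS}$", re.DOTALL)

def parse_qa_json(reply: str) -> list[dict]:
    """Parse a list of {article_id, question, answer} records from an LLM reply."""
    text = reply.strip()
    match = FENCE_RE.match(text)
    if match:  # remove a markdown fence the model added despite the instruction
        text = match.group(1)
    records = json.loads(text)
    if isinstance(records, dict):  # tolerate a single object instead of a list
        records = [records]
    required = {"article_id", "question", "answer"}
    # Keep only well-formed records
    return [r for r in records if isinstance(r, dict) and required.issubset(r)]

bare = ('[{"article_id": 5113, "question": "Where is Chennai located?", '
        '"answer": "Dear Friend, Chennai is located on the Coromandel Coast of the Bay of Bengal."}]')
fenced = TICKS + "json\n" + bare + "\n" + TICKS
print(parse_qa_json(bare) == parse_qa_json(fenced))  # True
```

Other malformed replies still raise `json.JSONDecodeError`; wrapping the call in `try`/`except` and re-asking the model is one option.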


In [15]:
len(results)

6

In [16]:
results

[{'article_id': 73985,
  'question': 'What is Basil commonly known as?',
  'answer': 'Dear Friend, Basil is commonly known as Sweet Basil or Tulsi. This tender herb is well-regarded in various cuisines around the world.'},
 {'article_id': 73985,
  'question': 'Where is Basil originally native to?',
  'answer': 'Dear Friend, Basil is originally native to India and other tropical regions of Asia. It has a long history of cultivation in these areas, spanning over 5,000 years.'},
 {'article_id': 259745,
  'question': 'What is the significance of the Roerich Pact?',
  'answer': 'Dear Friend, the Roerich Pact is significant because it emphasizes that the protection of culture is more important than military needs. It represents a commitment to safeguarding artistic and scientific institutions as well as historic monuments.'},
 {'article_id': 259745,
  'question': 'When was the Roerich Pact signed and by whom?',
  'answer': 'Dear Friend, the Roerich Pact was signed on 15 April 1935 by represe

## Actual Runs to Create the Full Dataset

You don't need to run the commented-out code below. Creating the full dataset with the LLM takes a while (roughly 7 minutes for relevance labeling and 30 minutes for QA generation in the logged runs), but you can reuse this code on your own data in the future.
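Long runs like these can fail midway, for example on a rate limit or a malformed reply, and the in-memory `results` list is then lost. A minimal sketch of one way to make them restartable, assuming results are appended to a local JSONL file keyed by `article_id`. The `run_with_checkpoint` function, the file name, and `fake_llm` are hypothetical; in the notebook, `call_llm` would wrap `chatgpt.invoke(prompt).content`.

```python
import json
import os
import tempfile
from typing import Callable

def run_with_checkpoint(items: list[tuple[str, str]],
                        call_llm: Callable[[str], str],
                        path: str) -> dict[str, dict]:
    """Send each (article_id, doc) to `call_llm`, skipping ids already saved in `path`."""
    done: dict[str, dict] = {}
    if os.path.exists(path):
        with open(path, encoding="utf8") as f:
            for line in f:
                rec = json.loads(line)
                done[rec["article_id"]] = rec
    with open(path, "a", encoding="utf8") as f:
        for article_id, doc in items:
            if article_id in done:
                continue  # already processed in an earlier run
            rec = {"article_id": article_id, "reply": call_llm(doc)}
            f.write(json.dumps(rec) + "\n")
            f.flush()  # write each result to disk right away
            done[article_id] = rec
    return done

# Usage with a stand-in for the real LLM call
calls = []
def fake_llm(doc: str) -> str:
    calls.append(doc)
    return "1"

path = os.path.join(tempfile.mkdtemp(), "relevance.jsonl")  # hypothetical file name
items = [("73985", "Basil ..."), ("5113", "Chennai ...")]
run_with_checkpoint(items[:1], fake_llm, path)    # first run, "interrupted" after one item
out = run_with_checkpoint(items, fake_llm, path)  # resumed run only calls the LLM for 5113
print(len(calls), sorted(out))  # 2 ['5113', '73985']
```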

In [None]:
# results = []
# from tqdm import tqdm
# # remove the [:10] slice to run on all docs (the logged run below used all 767)
# for doc in tqdm(docs[:10]):
#     prompt = f"""Given the following document, follow these rules:
#                 - Return 1 if the content is focused on the country of India
#                 - Return 0 if the content is not focused on the country of India
#               Just return the number and nothing else

#               Document: {doc}
#            """
#     results.append(chatgpt.invoke(prompt))

100%|██████████| 767/767 [06:58<00:00,  1.83it/s]


In [None]:
len(results)

767

In [None]:
results = [r.content for r in results]

In [None]:
results[:5]

['1', '1', '1', '0', '0']

In [None]:
import pandas as pd

df = pd.DataFrame(docs)
df['relevance'] = results
df.head()

|   | metadata | data | relevance |
|---|---|---|---|
| 0 | {'title': 'Basil', 'article_id': '73985'} | Basil ("Ocimum basilicum") ( or ) is a plant o... | 1 |
| 1 | {'title': 'Roerich’s Pact', 'article_id': '259... | The Roerich Pact is a treaty on Protection of ... | 1 |
| 2 | {'title': 'Indian Air Force', 'article_id': '2... | The Indian Air Force is the air arm of the Ind... | 1 |
| 3 | {'title': 'Makran (princely state)', 'article_... | Makran was an autonomous princely state of bot... | 0 |
| 4 | {'title': 'Kharan (princely state)', 'article_... | The State of Kharan was an autonomous princely... | 0 |


In [None]:
df['relevance'].value_counts()

| relevance | count |
|---|---|
| 1 | 670 |
| 0 | 97 |


In [None]:
filtered_df = df[df['relevance'] == '1']

In [None]:
filtered_df.head()

|   | metadata | data | relevance |
|---|---|---|---|
| 0 | {'title': 'Basil', 'article_id': '73985'} | Basil ("Ocimum basilicum") ( or ) is a plant o... | 1 |
| 1 | {'title': 'Roerich’s Pact', 'article_id': '259... | The Roerich Pact is a treaty on Protection of ... | 1 |
| 2 | {'title': 'Indian Air Force', 'article_id': '2... | The Indian Air Force is the air arm of the Ind... | 1 |
| 8 | {'title': 'Prem Rawat', 'article_id': '4954'} | Prem Pal Singh Rawat (in India called Maharaji... | 1 |
| 9 | {'title': 'Chennai', 'article_id': '5113'} | Chennai (formerly known as Madras) is the capi... | 1 |


In [None]:
filtered_df.iloc[4]

|   | 9 |
|---|---|
| metadata | {'title': 'Chennai', 'article_id': '5113'} |
| data | Chennai (formerly known as Madras) is the capital city of the Indian state of Tamil Nadu. It has a population of about 7 million people. Almost 10% of all of the people in the state live in Chennai. The city is the fourth largest city of India. It was founded in 1661 by the British East India Company. The city is on the Coromandel Coast of the Bay of Bengal. |
| relevance | 1 |


In [None]:
docs = filtered_df['data'].tolist()
article_ids = [doc['article_id'] for doc in filtered_df['metadata'].tolist()]

In [None]:
len(set(article_ids))

670

In [None]:
len(result)

2

In [None]:
result

[{'article_id': 5113,
  'question': 'What is the historical significance of Chennai?',
  'answer': 'Dear Friend, Chennai, formerly known as Madras, was founded in 1661 by the British East India Company. This historical foundation marks its importance in the context of colonial India.'},
 {'article_id': 5113,
  'question': 'Where is Chennai located?',
  'answer': 'Dear Friend, Chennai is located on the Coromandel Coast of the Bay of Bengal. This coastal position contributes to its significance as a major city in India.'}]

In [None]:
# results = []
# from tqdm import tqdm

# for doc, article_id in tqdm(zip(docs, article_ids)):
#     prompt = f"""Given the following document and article_id, follow these rules:
#                 - Generate two questions based on the content of the document
#                 - Questions should be brief like humans ask questions, not super long
#                 - Questions should not focus on things like what is the population, count etc.
#                 - Questions should be more generic like Tell me about the capital of India etc.
#                 - Generate answers to each of these questions only from the content of the document
#                 - Start each answer by saying, Dear Friend,
#                  Do not make up answers, only use the content of the document.
#                  Each answer should be at least two lines if possible.

#                  Return a list of article_id, question, answer as a valid JSON
#                  Do not return a markdown JSON but just the JSON string
#                  JSON response should always be a list of the two records

#               Document: {doc}
#               article_id: {article_id}
#           """
#     response = chatgpt.invoke(prompt)
#     result = json.loads(response.content)
#     results.extend(result)

670it [29:30,  2.64s/it]


In [None]:
len(results)

1340

In [None]:
results[:2]

[{'article_id': 73985,
  'question': 'What is Basil commonly known as?',
  'answer': 'Dear Friend, Basil is commonly known as Sweet Basil or Tulsi. This plant is a member of the Lamiaceae family and is cherished for its culinary uses.'},
 {'article_id': 73985,
  'question': 'Where is Basil originally native to?',
  'answer': 'Dear Friend, Basil is originally native to India and other tropical regions of Asia. It has a rich history of cultivation in these areas for over 5,000 years.'}]

In [None]:
qa_df = pd.DataFrame(results)
qa_df.head()

|   | article_id | question | answer |
|---|---|---|---|
| 0 | 73985 | What is Basil commonly known as? | Dear Friend, Basil is commonly known as Sweet ... |
| 1 | 73985 | Where is Basil originally native to? | Dear Friend, Basil is originally native to Ind... |
| 2 | 259745 | What is the significance of the Roerich Pact? | Dear Friend, the Roerich Pact is significant b... |
| 3 | 259745 | When was the Roerich Pact signed? | Dear Friend, the Roerich Pact was signed on 15... |
| 4 | 207506 | What is the history of the Indian Air Force? | Dear Friend, the Indian Air Force was founded ... |


In [None]:
qa_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1340 entries, 0 to 1339
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   article_id  1340 non-null   object
 1   question    1340 non-null   object
 2   answer      1340 non-null   object
dtypes: object(3)
memory usage: 31.5+ KB


In [None]:
# assign() returns a new DataFrame, so adding the column does not trigger
# pandas' SettingWithCopyWarning on the filtered slice
filtered_df = filtered_df.assign(article_id=article_ids)
filtered_df.head()


|   | metadata | data | relevance | article_id |
|---|---|---|---|---|
| 0 | {'title': 'Basil', 'article_id': '73985'} | Basil ("Ocimum basilicum") ( or ) is a plant o... | 1 | 73985 |
| 1 | {'title': 'Roerich’s Pact', 'article_id': '259... | The Roerich Pact is a treaty on Protection of ... | 1 | 259745 |
| 2 | {'title': 'Indian Air Force', 'article_id': '2... | The Indian Air Force is the air arm of the Ind... | 1 | 207506 |
| 8 | {'title': 'Prem Rawat', 'article_id': '4954'} | Prem Pal Singh Rawat (in India called Maharaji... | 1 | 4954 |
| 9 | {'title': 'Chennai', 'article_id': '5113'} | Chennai (formerly known as Madras) is the capi... | 1 | 5113 |


In [None]:
filtered_df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 670 entries, 0 to 765
Data columns (total 4 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   metadata    670 non-null    object
 1   data        670 non-null    object
 2   relevance   670 non-null    object
 3   article_id  670 non-null    object
dtypes: object(4)
memory usage: 26.2+ KB


In [None]:
qa_df['article_id'] = pd.to_numeric(qa_df['article_id'], errors='coerce')
# assign() avoids SettingWithCopyWarning if filtered_df is still a slice of df
filtered_df = filtered_df.assign(
    article_id=pd.to_numeric(filtered_df['article_id'], errors='coerce')
)
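Both columns are converted because the metadata stores `article_id` as a string (`'73985'`), while the LLM echoes it back as a JSON value; the `object` dtype in the `qa_df.info()` output suggests mixed types. Without a common numeric type, `73985` and `'73985'` would not line up in the merge. A minimal sketch on toy values (illustrative, not taken from the dataset) showing what `pd.to_numeric(..., errors='coerce')` does:

```python
import pandas as pd

# The LLM may echo article_id as a number or a string (illustrative values)
llm_ids = pd.Series([73985, "259745", "n/a"], dtype="object")
# Wikipedia metadata stores article_id as a string
meta_ids = pd.Series(["73985", "259745", "5113"])

before = set(llm_ids) & set(meta_ids)  # 73985 != '73985', so only '259745' matches

llm_num = pd.to_numeric(llm_ids, errors="coerce")    # "n/a" becomes NaN
meta_num = pd.to_numeric(meta_ids, errors="coerce")
after = sorted(int(x) for x in set(llm_num.dropna()) & set(meta_num))

print(before, after)  # {'259745'} [73985, 259745]
```

Any value that cannot be parsed becomes NaN, so it is worth checking `qa_df['article_id'].isna().sum()` before merging.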


In [None]:
# Attach each generated QA pair to its source paragraph via article_id (inner join)
final_df = pd.merge(qa_df, filtered_df, on='article_id')
final_df = final_df[['article_id', 'data', 'question', 'answer']]
final_df.columns = ['article_id', 'context', 'question', 'answer']
final_df.head(10)

|   | article_id | context | question | answer |
|---|---|---|---|---|
| 0 | 73985 | Basil ("Ocimum basilicum") ( or ) is a plant o... | What is Basil commonly known as? | Dear Friend, Basil is commonly known as Sweet ... |
| 1 | 73985 | Basil ("Ocimum basilicum") ( or ) is a plant o... | Where is Basil originally native to? | Dear Friend, Basil is originally native to Ind... |
| 2 | 259745 | The Roerich Pact is a treaty on Protection of ... | What is the significance of the Roerich Pact? | Dear Friend, the Roerich Pact is significant b... |
| 3 | 259745 | The Roerich Pact is a treaty on Protection of ... | When was the Roerich Pact signed? | Dear Friend, the Roerich Pact was signed on 15... |
| 4 | 207506 | The Indian Air Force is the air arm of the Ind... | What is the history of the Indian Air Force? | Dear Friend, the Indian Air Force was founded ... |
| 5 | 207506 | The Indian Air Force is the air arm of the Ind... | What role does the Indian Air Force play? | Dear Friend, the Indian Air Force serves as th... |
| 6 | 4954 | Prem Pal Singh Rawat (in India called Maharaji... | Who is Prem Pal Singh Rawat? | Dear Friend, Prem Pal Singh Rawat, also known ... |
| 7 | 4954 | Prem Pal Singh Rawat (in India called Maharaji... | What organizations has Prem Pal Singh Rawat be... | Dear Friend, Prem Pal Singh Rawat has been ass... |
| 8 | 5113 | Chennai (formerly known as Madras) is the capi... | What is significant about Chennai's history? | Dear Friend, Chennai, formerly known as Madras... |
| 9 | 5113 | Chennai (formerly known as Madras) is the capi... | Where is Chennai located? | Dear Friend, Chennai is located on the Coroman... |
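`pd.merge` defaults to an inner join, so any QA row whose `article_id` failed to parse (NaN after coercion) or does not appear in `filtered_df` is dropped without notice. A minimal sketch of a sanity check using `merge`'s `validate` and `indicator` parameters, on toy frames (the data is illustrative):

```python
import pandas as pd

# Toy frames: one QA row has an article_id that failed to parse
qa = pd.DataFrame({
    "article_id": [73985.0, 73985.0, 5113.0, float("nan")],
    "question": ["Q1", "Q2", "Q3", "Q4"],
})
contexts = pd.DataFrame({"article_id": [73985.0, 5113.0],
                         "data": ["Basil ...", "Chennai ..."]})

checked = pd.merge(
    qa, contexts, on="article_id", how="left",
    validate="many_to_one",  # raise MergeError if an article_id repeats in `contexts`
    indicator=True,          # add a `_merge` column: 'both' or 'left_only'
)
unmatched = checked[checked["_merge"] == "left_only"]
print(len(unmatched))  # 1
```

If `unmatched` is empty, the inner join in the notebook keeps every generated QA pair.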


In [None]:
final_df.to_csv('india_rag_wikidata.csv', index=False)

In [None]:
filtered_df.shape

(670, 4)

In [None]:
from datasets import Dataset

ds = Dataset.from_pandas(final_df)
ds

Dataset({
    features: ['article_id', 'context', 'question', 'answer'],
    num_rows: 1340
})

In [17]:
# from huggingface_hub import login

# login() 


In [None]:
# ds.push_to_hub("dipanjanS/india_rag_wikidata")

CommitInfo(commit_url='https://huggingface.co/datasets/dipanjanS/india_rag_wikidata/commit/0e545ea5d9b285cec8d354bf90165115c488e05a', commit_message='Upload dataset', commit_description='', oid='0e545ea5d9b285cec8d354bf90165115c488e05a', pr_url=None, pr_revision=None, pr_num=None)

In [18]:
from datasets import load_dataset

dataset = load_dataset("dipanjanS/india_rag_wikidata", split="train")

In [19]:
dataset

Dataset({
    features: ['article_id', 'context', 'question', 'answer'],
    num_rows: 1340
})

In [20]:
df = dataset.to_pandas()
df.head(10)

|   | article_id | context | question | answer |
|---|---|---|---|---|
| 0 | 73985 | Basil ("Ocimum basilicum") ( or ) is a plant o... | What is Basil commonly known as? | Dear Friend, Basil is commonly known as Sweet ... |
| 1 | 73985 | Basil ("Ocimum basilicum") ( or ) is a plant o... | Where is Basil originally native to? | Dear Friend, Basil is originally native to Ind... |
| 2 | 259745 | The Roerich Pact is a treaty on Protection of ... | What is the significance of the Roerich Pact? | Dear Friend, the Roerich Pact is significant b... |
| 3 | 259745 | The Roerich Pact is a treaty on Protection of ... | When was the Roerich Pact signed? | Dear Friend, the Roerich Pact was signed on 15... |
| 4 | 207506 | The Indian Air Force is the air arm of the Ind... | What is the history of the Indian Air Force? | Dear Friend, the Indian Air Force was founded ... |
| 5 | 207506 | The Indian Air Force is the air arm of the Ind... | What role does the Indian Air Force play? | Dear Friend, the Indian Air Force serves as th... |
| 6 | 4954 | Prem Pal Singh Rawat (in India called Maharaji... | Who is Prem Pal Singh Rawat? | Dear Friend, Prem Pal Singh Rawat, also known ... |
| 7 | 4954 | Prem Pal Singh Rawat (in India called Maharaji... | What organizations has Prem Pal Singh Rawat be... | Dear Friend, Prem Pal Singh Rawat has been ass... |
| 8 | 5113 | Chennai (formerly known as Madras) is the capi... | What is significant about Chennai's history? | Dear Friend, Chennai, formerly known as Madras... |
| 9 | 5113 | Chennai (formerly known as Madras) is the capi... | Where is Chennai located? | Dear Friend, Chennai is located on the Coroman... |
