<a href="https://colab.research.google.com/github/dimitrod/ehu_nlp_dimathina/blob/develop/RAG_QA_Embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install langchain langchain-openai langchain-community faiss-cpu tiktoken datasets bitsandbytes transformers sentence_transformers



In [2]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
#from langchain_community.vectorstores import FAISS
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.chains import RetrievalQA
from datasets import load_dataset
from tqdm import tqdm
import os
import json

In [3]:
def download_dataset(split):
  evaluation_data = load_dataset('trivia_qa', name='rc.wikipedia', split=split)
  return evaluation_data

In [4]:
def format_to_json(data):
  file = "{\"Data\":["
  i = 0
  for i in tqdm(range(len(data)), desc="Loading dataset"):
    file += json.dumps(data[i])
    if i < len(data) - 1:
      file += ",\n"
    i += 1
  file += "]}\n"
  return json.loads(file)

In [35]:
def extract_qa_pairs(data):
  qa_pairs = []
  for i in tqdm(range(len(data['Data'])), desc="Extracting QA Pairs"):
    entry = data['Data'][i]
    qa_pairs.append((entry['question'], entry['answer']['aliases']))
  return qa_pairs

In [43]:
def write_file(file):
  print("Saving file...")
  with open(f"qa_database.json", "w", encoding="utf-8") as f:
    f.write(file)

In [52]:
split = "validation"
data = download_dataset(split)

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

In [54]:
data = format_to_json(data)

Loading dataset: 100%|██████████| 7993/7993 [00:09<00:00, 855.84it/s]


In [60]:
data = extract_qa_pairs(data)

Extracting QA Pairs: 100%|██████████| 7993/7993 [00:00<00:00, 348609.43it/s]


In [63]:
data =json.dumps(data)
write_file(data)

Saving file...


In [65]:
with open("qa_database.json", "r") as f:
  data = json.load(f)

In [67]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=0,
    length_function=len,
)

In [68]:
model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")

In [69]:
import faiss

index = faiss.IndexFlatL2(len(model.embed_query("hello world")))

In [70]:
from langchain_community.docstore.in_memory import InMemoryDocstore

vector_store = FAISS(
    embedding_function=model,
    index=index,
    docstore= InMemoryDocstore(),
    index_to_docstore_id={}
)

In [74]:
print(data[0].__class__)

<class 'list'>


In [78]:
from langchain_core.documents import Document

train_documents = []
for i in tqdm(range(len(data)), desc="Embedding data"):
  qa_pair = json.dumps(data[i])
  doc = Document(page_content=qa_pair, metadata={"source": "Wikipedia"})
  train_documents.append(doc)
  vector_store.add_documents([doc])

Embedding data: 100%|██████████| 7993/7993 [1:09:10<00:00,  1.93it/s]


In [79]:
vector_store.save_local("qa_database")

In [84]:
db = FAISS.load_local("qa_database", model, allow_dangerous_deserialization=True)

In [86]:
contexts = db.similarity_search("What is the capital of France?")

In [87]:
for context in contexts:
  print(context)

page_content='["What is the capital of the French region of Burgundy?", ["Dijon", "Dijonnaise", "DIJON", "Dijon, France"]]' metadata={'source': 'Wikipedia'}
page_content='["What French region's capital city is Ajaccio?", ["La Corse", "Corsica, France", "Corsica (Journal)", "Cyrnos", "Cyrnus", "Corsic", "Corse", "Korsika", "Corsica"]]' metadata={'source': 'Wikipedia'}
page_content='["The Dakar Rally (previously known as the Paris Dakar rally) is an annual event, but of which country is Dakar the capital city?", ["Republic of S\u00e9n\u00e9gal", "Sengal", "Sport in Senegal", "ISO 3166-1:SN", "Indigenous cultures, kingdoms and ethnic groups of Senegal", "Culture of Senegal", "Republic of Senegal", "Senegal", "Senegalese", "R\u00e9publique du S\u00e9n\u00e9gal", "S\u00e9n\u00e9gal", "Indigenous Cultures, Kingdoms and Ethnic Groups of Senegal", "Etymology of Senegal"]]' metadata={'source': 'Wikipedia'}
page_content='["Metz is the capital of which region of France?", ["Lorraine (France)", "L