In [93]:
# pip install langchain langchain-google-genai faiss-cpu pandas sentence-transformers

In [95]:
from langchain_community.document_loaders.csv_loader import CSVLoader

In [97]:
import pandas as pd

In [99]:
df = pd.read_parquet('passages.parquet')

In [100]:
idx = [i for i in range(df.shape[0]) if df.iloc[i]['passage'].lower() == 'nan']

In [101]:
len(idx)

12220

In [102]:
df = df.iloc[[i for i in range(df.shape[0]) if i not in idx]]

In [103]:
df

Unnamed: 0_level_0,passage
id,Unnamed: 1_level_1
9797,New data on viruses isolated from patients wit...
11906,We describe an improved method for detecting d...
16083,We have studied the effects of curare on respo...
23188,Kinetic and electrophoretic properties of 230-...
23469,Male Wistar specific-pathogen-free rats aged 2...
...,...
34885209,LncRNAs are involved in the occurrence and pro...
34886835,BACKGROUND: COVID-19 patients with long incuba...
34888619,Spinal muscular atrophy (SMA) is an autosomal ...
34893673,Amphiregulin (AREG) is an epidermal growth fac...


In [104]:
df.to_csv('passages.csv')

In [105]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

In [107]:
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAI

In [108]:
from  langchain.embeddings import HuggingFaceEmbeddings

In [109]:
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
import pandas as pd

In [110]:
class LangChainRAG:
    def __init__(self, google_api_key):
        self.llm = GoogleGenerativeAI(
            model="gemini-pro",
            google_api_key=google_api_key,
            temperature=0.3
        )
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            length_function=len
        )
        self.vectorstore = None
        self.qa_chain = None
        
    def load_and_index_passages(self, csv_path, passage_column = 'passage'):
        df = pd.read_csv(csv_path)
        passages = df.iloc[:1000][passage_column].tolist()
        texts = self.text_splitter.split_text('\n'.join(passages))
        self.vectorstore = FAISS.from_texts(
            texts,
            self.embeddings
        )

        prompt_template = """Use the following pieces of context to answer the question at the end. 
        If you don't know the answer or can't find it in the context, just say that you don't know, 
        don't try to make up an answer.
        Context: {context}

        Question: {question}

        Answer:"""
        
        PROMPT = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_kwargs={"k": 3}
            ),
            return_source_documents=True,
            chain_type_kwargs={"prompt": PROMPT}
        )
        
    def get_answer(self, query):
        result = self.qa_chain({"query": query})
        return result["result"], result["source_documents"]

In [111]:
qa = pd.read_parquet('test.parquet')

In [112]:
qa

Unnamed: 0_level_0,question,answer,relevant_passage_ids
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,Is Hirschsprung disease a mendelian or a multi...,"Coding sequence mutations in RET, GDNF, EDNRB,...","[20598273, 6650562, 15829955, 15617541, 230011..."
1,List signaling molecules (ligands) that intera...,The 7 known EGFR ligands are: epidermal growt...,"[23821377, 24323361, 23382875, 22247333, 23787..."
2,Is the protein Papilin secreted?,"Yes, papilin is a secreted protein","[21784067, 19297413, 15094122, 7515725, 332004..."
3,Are long non coding RNAs spliced?,Long non coding RNAs appear to be spliced thro...,"[22955974, 21622663, 22707570, 22955988, 24285..."
4,Is RANKL secreted from the cells?,Receptor activator of nuclear factor κB ligand...,"[22867712, 23827649, 21618594, 23835909, 24265..."
...,...,...,...
4714,Is PPROM a condition that occurs in males or f...,Preterm premature rupture of fetal membranes (...,"[23599878, 23573382, 24304137, 18301713, 23179..."
4715,What is EpiMethylTag?,"EpiMethylTag is a fast, low-input, low sequenc...",[31752933]
4716,What is the target of Sutimlimab?,Sutimlimab is a novel humanized monoclonal ant...,"[30635392, 31229501, 33826820, 32176765, 31114..."
4717,Can parasite infections by Schistosoma japonic...,A peptide named as SJMHE1 from Schistosoma jap...,"[26840774, 34703270, 28614408, 31496071, 18654..."


In [113]:
google_api_key = "AIzaSyC5ihtfdzs86uWcNmuu4swqZB3tkzM37dE"
rag = LangChainRAG(google_api_key)
rag.load_and_index_passages("passages.csv")
question = qa.iloc[0,0]
answer, source = rag.get_answer(question)
print(f"Question: {question}")
print(f"\nAnswer: {answer}")
print(f"\nSource: {source[0]}")

Question: Is Hirschsprung disease a mendelian or a multifactorial disorder?

Answer: Multifactorial

Sources: page_content='characterized by parotitis, uveitis, and facial nerve paralysis. A case is \npresented and the clinical manifestations are discussed. Angiotensin converting \nenzyme assays along with tissue biopsy demonstrating noncaseating granulomas \nconfirm the diagnosis.\nHirschsprung disease, or congenital aganglionic megacolon, is commonly assumed \nto be a sex-modified multifactorial trait. To test this hypothesis, complex \nsegregation analysis was performed on data on 487 probands and their families.'
