In [None]:
from pypdf import PdfReader
import pickle
from dotenv import load_dotenv
import os
# For downloading the embeddings model
from huggingface_hub import snapshot_download

# For embeddings and vector stores
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

# For langchain_groq
from langchain_groq import ChatGroq

# For Rag chain
from IPython.display import Markdown, display
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Settings
- api_key : Groq API key, read from the `GROQ_API_KEY` environment variable
- model : name of the Groq chat model to use

In [None]:
# Load variables from a local .env file (if one exists) into the process
# environment. load_dotenv was imported but never called, so a key stored
# only in .env was invisible to os.getenv below.
load_dotenv()

# Groq API key; None when GROQ_API_KEY is not set anywhere.
api_key = os.getenv('GROQ_API_KEY')
# Name of the Groq chat model.
model = "llama-3.1-70b-versatile"
# Directory that holds the input PDF files.
pdf_path = "./pdf"
# Local directory of the embeddings model (see the download cell).
embeddings_path = "./path_to_model/intfloat/multilingual-e5-base"

# Extract text from PDFs

In [5]:
def read_text_from_pdf(pdf_path):
    """Return the text of every page of one PDF as a single string.

    Args:
        pdf_path: Path of the PDF file, passed straight to ``PdfReader``.

    Returns:
        The ``extract_text()`` result of each page, concatenated in page
        order with no separator inserted between pages.
    """
    pages = PdfReader(pdf_path).pages
    return "".join(page.extract_text() for page in pages)


# Read every PDF found in pdf_path and concatenate the extracted text.
# os.listdir returns all directory entries in no guaranteed order, so keep
# only *.pdf files (anything else would be handed to PdfReader) and sort the
# names to make the concatenation order reproducible between runs.
pdfs = sorted(name for name in os.listdir(pdf_path) if name.lower().endswith(".pdf"))
pdf_text = "".join(read_text_from_pdf(os.path.join(pdf_path, pdf)) for pdf in pdfs)

# Cache the extracted text so later runs can skip the PDF parsing step.
with open("pdf_text.pkl", "wb") as f:
    pickle.dump(pdf_text, f)

In [6]:
# Restore the extracted PDF text from the pickle cache written by the
# extraction cell.
with open("pdf_text.pkl", "rb") as text_cache:
    pdf_text = pickle.load(text_cache)

# DB establishment

## Download the embeddings model

In [None]:
model_name = "intfloat/multilingual-e5-base"
# Local directory for the model files. Built once so that the existence check
# and the download target cannot drift apart, and bound up front so that
# download_path is defined whether or not a download happens.
download_path = f"path_to_model/{model_name}"

# config.json is used as the marker that the model was already downloaded.
if os.path.isfile(f"{download_path}/config.json"):
    print("Model already exists.")
else:
    download_path = snapshot_download(
        repo_id=model_name,
        local_dir=download_path,
        # Request real files in local_dir instead of symlinks into the cache.
        # NOTE(review): recent huggingface_hub releases deprecate this
        # argument - confirm against the installed version.
        local_dir_use_symlinks=False,
    )


## Make retriever

In [None]:
# Embedding model, loaded from the local directory set in the settings cell.
embeddings = HuggingFaceEmbeddings(model_name=embeddings_path)

# Split the text into chunks that overlap with their neighbours.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=128)
splited_text = text_splitter.split_text(pdf_text)

# Convert the chunks into embedding vectors and store them in a FAISS index.
index = FAISS.from_texts(splited_text, embedding=embeddings)

# Retriever on top of the FAISS index, configured with k=4.
retriever = index.as_retriever(search_kwargs={"k": 4})

# Persist the retriever so later runs can load it instead of rebuilding it.
with open("retriever.pkl", "wb") as retriever_file:
    pickle.dump(retriever, retriever_file)

### From the second run onward: load the saved retriever

In [None]:
# Restore the retriever saved by the "Make retriever" cell.
# NOTE: pickle.load can execute arbitrary code; only load a retriever.pkl
# that this notebook produced itself.
with open("retriever.pkl", "rb") as retriever_file:
    retriever = pickle.load(retriever_file)

# RAG chain

In [None]:
class LLM:
    """Groq chat model behind a fixed prompt, with and without retrieval.

    ``chat`` fills the prompt with retrieved documents and their sources;
    ``not_fine_tuned_chat`` sends the same prompt with empty context so the
    two answers can be compared.
    """

    def __init__(self, api_key, model="llama-3.1-70b-versatile"):
        """Build the Groq client, the prompt and the prompt -> LLM -> str chain.

        Args:
            api_key: Groq API key handed to ``ChatGroq``.
            model: Name of the Groq chat model. Defaults to the value that
                was previously hard-coded in ``self.model``.
        """
        self.api_key = api_key
        self.model = model
        # Hand the model name to the client: it used to be stored on
        # self.model only, so ChatGroq never received it.
        self.client = ChatGroq(api_key=api_key, model=self.model)

        # {sources} was declared as an input variable and computed in chat(),
        # but the template had no placeholder for it, so the source list never
        # reached the model. It is now rendered right after the context.
        # NOTE(review): "Answer in very easy terms t." looks like a typo; it is
        # left untouched because this text is sent to the model verbatim.
        self.prompt_template = """
            <|system|>
            Use the following pieces of context to answer the question at the end.
            If you don't know the answer, just say that you don't know, don't try to make up an answer.
            keep the answer as concise as possible.
            Use markdown formatting when displaying code.
            Emphasis should be used to terminologies.
            You need to provide the sources of the information you provide.

            Answer in very easy terms t.
            Answer in English.
            Then translate the answer into Japanese.

            {context}

            {sources}

            </s>
            <|user|>
            {question}
            </s>
            <|assistant|>

        """
        #  Create the PromptTemplate Instance
        self.prompt = PromptTemplate(
            input_variables=[
                "context",
                "question",
                "sources"
                ],
            template=self.prompt_template,
        )

        # prompt -> chat model -> plain string
        self.llm_chain = self.prompt | self.client | StrOutputParser()

    def format_docs(self, docs):
        """Join the ``page_content`` of ``docs`` with blank lines in between."""
        return "\n\n".join(doc.page_content for doc in docs)

    def format_sources(self, docs):
        """Format the ``source`` metadata of ``docs`` as a "Sources:" list.

        Documents without a ``source`` entry are listed as "Unknown Source".
        Values are converted with ``str`` so non-string metadata cannot break
        the join.
        """
        sources = [str(doc.metadata.get('source', 'Unknown Source')) for doc in docs]
        return "Sources:\n" + "\n".join(sources)

    def show_markdown(self, markdown_text, title=None):
        """Render ``title`` followed by ``markdown_text`` as Markdown output.

        ``title`` may be omitted; the default ``None`` used to raise a
        ``TypeError`` on the string concatenation.
        """
        display(Markdown((title or "") + markdown_text))

    def chat(self, user_input, retriever, show=True):
        """Answer ``user_input`` using documents fetched by ``retriever``.

        Args:
            user_input: The question; used both as the retrieval query and as
                the ``question`` prompt variable.
            retriever: Object whose ``invoke(query)`` returns documents that
                have ``page_content`` and ``metadata`` attributes.
            show: When true, also display the answer as Markdown.

        Returns:
            The model's answer as a string.
        """
        # Retrieve once and derive both prompt fields from the same documents
        # (the retriever used to be run separately for context and sources).
        docs = retriever.invoke(user_input)
        response = self.llm_chain.invoke(
            {
                "context": self.format_docs(docs),
                "sources": self.format_sources(docs),
                "question": user_input,
            }
        )

        if show:
            self.show_markdown(response, title="# RAG Answer\n")
        return response

    def not_fine_tuned_chat(self, user_input, show=True):
        """Answer ``user_input`` with the same prompt but no retrieved context.

        Args:
            user_input: The question.
            show: When true, also display the answer as Markdown.

        Returns:
            The model's answer as a string.
        """
        # Every template variable must be supplied, including the empty ones.
        response = self.llm_chain.invoke(
            {"context": "", "sources": "", "question": user_input}
        )

        # Separate heading so this answer is not mistaken for the RAG one.
        if show:
            self.show_markdown(response, title="# Non-RAG Answer\n")
        return response

In [None]:
llm = LLM(api_key)
query = "半導体検出器について教えてください。"
# Keep the retrieval-augmented answer under its own name; it used to be
# overwritten by the second call, so the two answers could not be compared.
rag_response = llm.chat(query, retriever)
# Same question without retrieved context; `response` still ends up holding
# this answer, as before.
response = llm.not_fine_tuned_chat(query)