<a href="https://colab.research.google.com/github/kstyle2198/NLP_TIPS/blob/main/RAG_CHATBOT_LLAMA2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import re
import time
import pandas as pd
import streamlit as st
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders import Docx2txtLoader
from langchain.document_loaders import UnstructuredPowerPointLoader
from loguru import logger

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.llms import CTransformers
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

from langchain.callbacks import StreamlitCallbackHandler
from llama_cpp import Llama


# Page configuration is the first Streamlit call made by this script.
st.set_page_config(layout="wide", page_title="llama2")
# The relative model and image paths used further down are resolved against this directory.
os.chdir("/home/shared/")
os.environ["TOKENIZERS_PARALLELISM"] = "false"


####### Session State Variables ####################################
# (key, factory) pairs. A key that is missing from the session is seeded with
# factory(): int() -> 0, list() -> a fresh empty list, str() -> "".
# Keys that already exist are left untouched, so values survive Streamlit reruns.
_SESSION_DEFAULTS = (
    ("file_uploader_key", int),
    ("questions", list),
    ("answer", list),
    ("trans_answer", list),
    ("src_docu1", list),
    ("src_meta1", list),
    ("src_docu2", list),
    ("src_meta2", list),
    ("response", list),
    ("text_splitter", str),
    ("embeddings", str),
    ("llm", str),
    ("dbqa", str),
    ("replied", str),
)

for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()


#### functions ###########################################################

@st.cache_data
def get_text(docs):
    """Save uploaded files to the working directory and load them as documents.

    Args:
        docs: Iterable of uploaded-file objects exposing ``name`` and
            ``getvalue()`` (as used on the values returned by ``st.file_uploader``).

    Returns:
        A flat list with the results of ``loader.load_and_split()`` for every
        supported file, in upload order. Files whose extension is not
        ``.pdf``, ``.docx`` or ``.pptx`` (case-insensitive) are skipped with a
        warning instead of raising or re-adding the previous file's documents.
    """
    # Dispatch on the real file extension; a substring test such as
    # "'.pdf' in name" misses "REPORT.PDF" and mis-routes "notes.pdf.docx".
    loader_by_suffix = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".pptx": UnstructuredPowerPointLoader,
    }

    doc_list = []
    for doc in docs:
        file_name = doc.name  # the upload's own name is reused as the on-disk file name
        suffix = os.path.splitext(file_name)[1].lower()
        loader_cls = loader_by_suffix.get(suffix)
        if loader_cls is None:
            logger.warning(f"Skipped {file_name}: unsupported file type")
            continue

        # The loaders take a path, so the upload is written to disk first.
        with open(file_name, "wb") as file:
            file.write(doc.getvalue())
        logger.info(f"Uploaded {file_name}")

        doc_list.extend(loader_cls(file_name).load_and_split())
    return doc_list


from datetime import datetime, timedelta

def calculate_time_delta(start_time, end_time):
    """Return the whole seconds elapsed from *start_time* to *end_time*.

    Args:
        start_time: ``datetime`` marking the start of the interval.
        end_time: ``datetime`` marking the end of the interval.

    Returns:
        The elapsed time as an ``int`` (fractional seconds are truncated).
        Negative when *end_time* precedes *start_time*.
    """
    # total_seconds() covers the whole interval; the .seconds attribute only
    # holds the 0..86399 remainder and silently drops the days component.
    return int((end_time - start_time).total_seconds())


from googletrans import Translator

class Google_Translator:
    """Small wrapper around ``googletrans.Translator`` that reports results as dicts.

    Each result dict has the keys ``src_text``, ``src_lang``, ``tgt_text`` and
    ``tgt_lang``.
    """

    def __init__(self):
        self.translator = Translator()
        # Most recent result; replaced (never mutated in place) by translate().
        self.result = {'src_text': '', 'src_lang': '', 'tgt_text': '', 'tgt_lang': ''}

    def translate(self, text, lang='ko'):
        """Translate *text* into *lang* and return the result dict.

        A new dict is built on every call, so a result returned earlier is not
        overwritten by later translations. ``self.result`` always refers to
        the dict returned by the latest call.
        """
        translated = self.translator.translate(text, dest=lang)
        self.result = {
            'src_text': translated.origin,
            'src_lang': translated.src,
            'tgt_text': translated.text,
            'tgt_lang': translated.dest,
        }
        return self.result

    def translate_file(self, file_path, lang='ko'):
        """Read the text file at *file_path* and translate its whole content.

        The file is opened in text mode with the platform default encoding.
        """
        with open(file_path, 'r') as f:
            text = f.read()
        return self.translate(text, lang)

def trans(en):
    """Translate *en* using the default target language of ``Google_Translator.translate``.

    The argument is converted with ``str()`` before translation. Returns the
    translated text, or the source text when the result has no ``"tgt_text"``
    entry.
    """
    outcome = Google_Translator().translate(str(en))
    try:
        return outcome["tgt_text"]
    except KeyError:
        return outcome["src_text"]


if __name__ == "__main__":
    # Script entry point. This first part draws the page title, the overview
    # panel, the document uploader and the model-selection controls. The values
    # collected here (uploaded_files, LLM_model, model_type, temp_value,
    # max_token, col1..col5) are consumed by the model-loading step that follows.

    st.title("📑 :red[문서분석 ChatBot] with :blue[LLAMA2] & :green[MISTRAL]")
    st.markdown("---")

    # Overview panel: usage notes and contact details shown to the user.
    with st.expander("🎈 **검토 개요**", expanded=True):
        st.markdown('''
        - **첨부한 문서에 대해 Q&A 채팅 가능(첨부는 :red[영문 문서]여야 함 / :green[첨부문서 추가, 모델 Temperature, 모델 종류 등 변경시] 모델로딩 버튼 다시 클릭)**
        - :red[**🦙Llama2**] 모델 및 :blue[**🪁Mistral**] 모델 선택 가능(8-bit Quantized :blue[CPUs using Llama2-7B, Mistral-7B], C Transformers, and LangChain)
        - 한글로 질문해도 응답 하나, 영어로 질문할 때 더 좋은 응답을 제공함 / 질문, 응답, 번역응답, 응답근거, 근거파일를 **csv 파일 형태로 다운로드** 가능
        - **[Contact]** 김종배 책임(jongbae.kim@hd.com)
        ''')

    ##### read uploaded files ####################
    # Multi-file uploader restricted to pdf/docx/pptx. The widget key is read
    # from session state (seeded with 0; not modified in the visible code).
    uploaded_files = st.file_uploader("📚 **검토대상 문서첨부(pdf, docx, pptx)** / 복수 파일 업로드 가능",type=['pdf','docx', 'pptx'],
                                           accept_multiple_files=True, key=st.session_state["file_uploader_key"])
    # Five columns that later receive one status message per loading stage.
    col1, col2, col3, col4, col5 = st.columns(5)

    # Selector label -> model file path (relative, so resolved against the
    # process working directory).
    models = {"🦙LLAMA2_7B": "M_model/llama-2-7b-chat.ggmlv3.q8_0.bin",
             "🪁MISTRAL_7B":"M_model/mistral-7b-instruct-v0.1.Q8_0.gguf",
             "🦙TinyLlama_1B":"M_model/tinyllama-1.1b-chat-v0.3.Q8_0.gguf"}

    # Selector label -> value passed as model_type to CTransformers.
    model_types = {"🦙LLAMA2_7B": "llama",
                   "🪁MISTRAL_7B":"mistral",
                   "🦙TinyLlama_1B":"llama"}

    col701, col702, col703 = st.columns([3,4,4])
    with col701:
        # Model choice; the label is mapped to its file path and model type.
        sel01 = st.selectbox("🚩 **:red[Select LLM(언어모델 선택)]**", ("🦙LLAMA2_7B", "🪁MISTRAL_7B", "🦙TinyLlama_1B"))
        LLM_model = models[sel01]
        model_type = model_types[sel01]

    with col702:
        # Sampling temperature, 0.0 to 2.0 in steps of 0.1 (default 0.0).
        temp_value = st.slider("🌡️ **모델 Temperature :red[(0에 가까우면 보수적, 2에 가까우면 창의적)]**", min_value=0.0, max_value=2.0, value=0.0, step=0.1)

    with col703:
        # Upper bound on generated tokens, starting at 256 in steps of 100.
        max_token = st.slider("🌡️ **Max_New_Token(최대생성가능토큰수)**", min_value=256, max_value=2000, value=256, step=100)

    # Model-loading step: on button press, build the text splitter, embeddings,
    # FAISS index, LLM and RetrievalQA chain, and keep them in session state.
    btn111 = st.button("⚙️ 모델로딩", type='primary')
    with st.spinner("로딩중..."):
        if not btn111:
            st.empty()
        elif not uploaded_files:
            # A multi-file uploader yields an empty list (not None) when nothing
            # is attached, so emptiness is what has to be tested here.
            st.error("🚨 검토대상 파일을 첨부해주세요.")
        else:
            try:
                texts = get_text(uploaded_files)

                st.session_state['text_splitter'] = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
                docs = st.session_state['text_splitter'].split_documents(texts)
                with col1:
                    st.info("🍉text_splitter 완료")

                st.session_state['embeddings'] = HuggingFaceEmbeddings(
                    model_name='sentence-transformers/all-MiniLM-L6-v2',
                    model_kwargs={'device': 'cpu'},
                )
                with col2:
                    st.info("🍊embeddings 완료")

                vectorstore = FAISS.from_documents(docs, st.session_state['embeddings'])
                with col3:
                    st.info("🍎vectorstore 완료")

                st.session_state['llm'] = CTransformers(
                    model=LLM_model,  # location of the downloaded model file
                    model_type=model_type,
                    stream=True,
                    config={'max_new_tokens': max_token,
                            'temperature': temp_value},
                )
                with col4:
                    st.info("🍓LLM 로딩 완료")

                qa_template = """Use the following pieces of information to answer the user's question.
                If you don't know the answer, just say that you don't know, don't try to make up an answer.

                Context: {context}
                Question: {question}

                Only return the helpful answer below and nothing else.
                Helpful answer:
                """

                prompt = PromptTemplate(template=qa_template, input_variables=['context', 'question'])
                st_callback = StreamlitCallbackHandler(st.container())
                dbqa = RetrievalQA.from_chain_type(
                    llm=st.session_state['llm'],
                    chain_type='stuff',
                    callbacks=[st_callback],
                    retriever=vectorstore.as_retriever(search_type="mmr", search_kwargs={'k': 2}),
                    return_source_documents=True,
                    chain_type_kwargs={'prompt': prompt},
                )
                st.session_state['dbqa'] = dbqa
                with col5:
                    st.info("🍇RetrievalQA Chain 로딩 완료")
            except Exception as exc:
                # Report the real failure (missing model file, unreadable document,
                # ...) instead of blaming a missing attachment. Catching Exception
                # rather than using a bare except leaves BaseException-derived
                # signals untouched.
                logger.exception("Model loading failed")
                st.error(f"🚨 모델 로딩 중 오류가 발생했습니다: {exc}")

        # NOTE(review): bare expression kept from the original; with Streamlit
        # "magic" enabled it renders the current LLM object on the page.
        # Confirm whether this is intended or a debugging leftover.
        st.session_state['llm']

    tab901, tab902 = st.tabs(["👔 **개별질의/응답**", "🍜 **Bulk질의/응답(:blue[점심시간에 AI 일시키기])**"])

    # Single question/answer tab: ask one question, show the answer (optionally
    # translated) and offer the accumulated history as a CSV download.
    with tab901:

        input100 = st.text_area("🖊️ **(파일첨부 및 모델 로딩후) 질문을 입력하세요.**")

        with st.expander("**질문 예시**", expanded=False):
            col71, col72, col73 = st.columns(3)
            with col71:
                st.code("what is the main topic of the attached article?")
                st.code("what is the main purpose of this document?")
            with col72:
                st.code("summarize the attached article within 5 lines")
                st.code("Are there any risky conditions in the document?")
            with col73:
                st.code("what is expected in the near future?")
                st.code("what is the best suggestion of the document?")

        chk1 = st.checkbox("응답 한글 번역 포함", value=True)
        st.session_state['replied'] = st.button("⚙️ 질문 제출", type='primary')

        with st.spinner("🤗 추론중..."):
            if st.session_state['replied'] and not callable(st.session_state['dbqa']):
                # 'dbqa' still holds its "" placeholder until a model has been loaded;
                # calling it would raise TypeError.
                st.error("🚨 먼저 파일을 첨부하고 ⚙️ 모델로딩 버튼을 눌러주세요.")
            elif st.session_state['replied']:
                start_time = datetime.now()

                response = st.session_state['dbqa']({'query': input100})
                answer_text = response["result"]

                # Up to two supporting passages; pad with blanks when the retriever
                # returned fewer so every history list gets exactly one entry.
                evidence = [(src.page_content, src.metadata.get("source", ""))
                            for src in response["source_documents"][:2]]
                evidence += [("", "")] * (2 - len(evidence))
                (docu1, meta1), (docu2, meta2) = evidence

                # History is recorded only after the chain call succeeded, so a
                # failure cannot leave the lists with different lengths (which
                # would break the DataFrame below for the rest of the session).
                st.session_state['questions'].append(input100)
                st.session_state['answer'].append(answer_text)
                st.session_state['src_docu1'].append(docu1)
                st.session_state['src_meta1'].append(meta1)
                st.session_state['src_docu2'].append(docu2)
                st.session_state['src_meta2'].append(meta2)

                st.markdown(f"😆 :blue[{answer_text}]")
                end_time = datetime.now()
                delta = calculate_time_delta(start_time, end_time)
                st.warning(f"⏱️ 추론후 응답소요시간(초) : {delta}")

                # Translation is best-effort: a failure is logged and recorded as
                # an empty string. Exactly one entry is appended on every path.
                translated = ""
                if chk1:
                    try:
                        translated = trans(answer_text)
                    except Exception as exc:
                        logger.warning(f"Translation failed: {exc}")
                st.session_state['trans_answer'].append(translated)
                if translated:
                    st.markdown(f"😃[번역] :blue[{translated}]")

                end_time = datetime.now()
                delta = calculate_time_delta(start_time, end_time)
                st.warning(f"⏱️ 번역후 응답소요시간(초) : {delta}")

            with st.expander("✔️ **전체 질의/응답결과 모음 (다운로드)**", expanded=True):
                df = pd.DataFrame({
                    "질문": st.session_state['questions'],
                    "응답": st.session_state['answer'],
                    "번역응답": st.session_state['trans_answer'],
                    "근거문장1": st.session_state['src_docu1'],
                    "근거파일1": st.session_state['src_meta1'],
                    "근거문장2": st.session_state['src_docu2'],
                    "근거파일2": st.session_state['src_meta2']})
                st.dataframe(df, use_container_width=True)

                @st.cache_data
                def convert_df(df):
                    # Cached so the CSV is not re-encoded on every rerun.
                    return df.to_csv().encode('utf-8-sig')

                csv = convert_df(df)

                st.download_button(
                    label="🗄️ Download data as CSV",
                    data=csv,
                    file_name='answers.csv',
                    mime='text/csv',
                )
    # Bulk question/answer tab: read questions from an Excel sheet, answer them
    # one by one and offer the accumulated history as a CSV download.
    with tab902:
        with st.expander("**벌크 질의/응답 방법 안내**", expanded=True):
            st.markdown('''
            - 질문 첨부 엑셀 파일은 :blue[**.xlsx**] 형식으로 하고, 1행에 :red[**"질문" 칼럼명**] 유지 (1열 데이터)
            - 아래 응답 추론 버튼 누르면, 질문을 순차적으로 불러와서 응답후 결과를 모아줌 (CSV 파일로 다운로드 가능)
            - 한글 번역은 시간당 API 요청 초과시 미실시 될 수 있음 (번역은 LLM이 아니라, 별도 Googletrans API 사용)
            - 검토 대상 문서의 사이즈, CPU-base Serving 등으로 추론에 다소간 시간이 소요되는 점 고려, 점심시간 등에 이용 권장(추론시간 단축 방법은 지속 검토중)
            ''')
            st.image("09_개발/김종배_llama2/images/첨부예시이미지.png", caption="엑셀 질문파일 예시 이미지")

        # Stays None until an Excel file has been attached, so the button handler
        # below can tell "no file" apart from "empty sheet".
        bulk_questions = None
        uploaded_file = st.file_uploader("질문 목록 엑셀 파일 첨부")
        if uploaded_file is not None:
            t_df = pd.read_excel(uploaded_file)
            bulk_questions = t_df["질문"]
            question_count = len(bulk_questions)
            st.markdown(f"🎈 **입력 질문 목록 / :green[총 질문개수 :{question_count}개]**")
            st.dataframe(bulk_questions, use_container_width=True)

        btn999 = st.button("🔍 Bulk 응답 추론", type="primary")
        with st.spinner("벌크 응답 추론중..."):
            if btn999 and bulk_questions is None:
                # Previously this path hit an unbound variable (NameError).
                st.error("🚨 질문 목록 엑셀 파일을 첨부해주세요.")
            elif btn999 and not callable(st.session_state['dbqa']):
                # 'dbqa' still holds its "" placeholder until a model has been loaded.
                st.error("🚨 먼저 파일을 첨부하고 ⚙️ 모델로딩 버튼을 눌러주세요.")
            elif btn999:

                for idx, question in enumerate(bulk_questions):
                    time.sleep(1)
                    st.markdown(f"[질문{idx+1}] {question}")
                    start_time = datetime.now()
                    # NOTE(review): bare expression kept from the original; with
                    # Streamlit "magic" enabled it prints the start timestamp.
                    # Confirm whether this is intended or a debugging leftover.
                    start_time

                    response = st.session_state['dbqa']({'query': question})
                    answer_text = response["result"]

                    # Up to two supporting passages; pad with blanks when the
                    # retriever returned fewer (the second one was indexed
                    # unconditionally before and could raise IndexError).
                    evidence = [(src.page_content, src.metadata.get("source", ""))
                                for src in response["source_documents"][:2]]
                    evidence += [("", "")] * (2 - len(evidence))
                    (docu1, meta1), (docu2, meta2) = evidence

                    # Record history only after the chain call succeeded so the
                    # lists always keep equal lengths.
                    st.session_state['questions'].append(question)
                    st.session_state['answer'].append(answer_text)
                    st.session_state['src_docu1'].append(docu1)
                    st.session_state['src_meta1'].append(meta1)
                    st.session_state['src_docu2'].append(docu2)
                    st.session_state['src_meta2'].append(meta2)

                    st.markdown(f"😆 :blue[{answer_text}]")

                    # Uses trans(); the previous call to the undefined name
                    # trans3 raised NameError, which the bare except swallowed,
                    # so bulk answers were never translated.
                    translated = ""
                    if chk1:
                        try:
                            translated = trans(answer_text)
                        except Exception as exc:
                            logger.warning(f"Translation failed: {exc}")
                    st.session_state['trans_answer'].append(translated)
                    if translated:
                        st.markdown(f"😃[번역] :blue[{translated}]")

                    end_time = datetime.now()
                    delta = calculate_time_delta(start_time, end_time)
                    st.warning(f"⏱️ 응답소요시간(초) : {delta}")

                with st.expander("✔️ **전체 질의/응답결과 모음 (다운로드)**", expanded=True):
                    df = pd.DataFrame({
                        "질문": st.session_state['questions'],
                        "응답": st.session_state['answer'],
                        "번역응답": st.session_state['trans_answer'],
                        "근거문장1": st.session_state['src_docu1'],
                        "근거파일1": st.session_state['src_meta1'],
                        "근거문장2": st.session_state['src_docu2'],
                        "근거파일2": st.session_state['src_meta2']})
                    st.dataframe(df, use_container_width=True)

                    @st.cache_data
                    def convert_df(df):
                        # Cached so the CSV is not re-encoded on every rerun.
                        return df.to_csv().encode('utf-8-sig')

                    csv = convert_df(df)

                    st.download_button(
                        label="🗄️ Download data as CSV",
                        data=csv,
                        file_name='bulk_answers.csv',
                        mime='text/csv',
                    )










