In [1]:
# Reload edited project modules (e.g. modules.utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from modules import utils

In [3]:
# Secrets (API keys, service URLs) are read from a local env file, never hard-coded.
ENVS = utils.load_env_to_dict("./../secrets/env")

GENERATE_MODEL_NAME = "phatjk/vietcuna-7b-v3-AWQ"
EMBEDDINGS_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
HUGGINGFACE_API_KEY = ENVS["HUGGINGFACE_API_KEY"]

# Qdrant settings used by LLMServe.load_retriever for any non-"wiki" source.
# Previously these names were referenced but never defined (NameError).
# Read with .get() so the default Wikipedia path keeps working when the env
# file has no Qdrant entries.
# NOTE(review): key names are assumed; confirm they match ./../secrets/env.
QDRANT_URL = ENVS.get("QDRANT_URL")
QDRANT_API_KEY = ENVS.get("QDRANT_API_KEY")
QDRANT_COLLECTION_NAME = ENVS.get("QDRANT_COLLECTION_NAME")

In [4]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.bettertransformer import BetterTransformer
import torch

In [5]:
# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# Hard-coding "mps" would fail as soon as a tensor/model is moved to it on a
# machine without MPS support.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [6]:
# Display the device selected above (the reranker model and its inputs are moved to it).
device

device(type='mps')

In [7]:
# Cross-encoder that scores (query, passage) pairs for the Rerank*Retriever
# classes. The model id is defined once so the model and tokenizer can never
# drift apart.
RERANK_MODEL_NAME = "amberoad/bert-multilingual-passage-reranking-msmarco"

model_rerank = AutoModelForSequenceClassification.from_pretrained(
    RERANK_MODEL_NAME
).to(device)
tokenizer_rerank = AutoTokenizer.from_pretrained(RERANK_MODEL_NAME)

In [8]:
from langchain.schema.document import Document
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.retrievers import WikipediaRetriever

from typing import List

In [9]:
class RerankRetriever(VectorStoreRetriever):
    """Rerank the output of another retriever with the cross-encoder reranker.

    Candidates come from ``vectorstore`` (a retriever such as
    ``Qdrant.as_retriever(...)``). Each ``(query, passage)`` pair is scored with
    the notebook-level ``tokenizer_rerank`` / ``model_rerank`` on ``device``.
    The ``top_n`` highest-scoring documents are returned, best first.

    NOTE(review): this subclasses ``VectorStoreRetriever`` only to reuse the
    retriever plumbing; it wraps a retriever, not a vector store.
    """

    # First-stage retriever that supplies the candidate documents.
    vectorstore: VectorStoreRetriever
    # Number of documents returned after reranking.
    top_n: int = 2

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``top_n`` candidates for ``query``, most relevant first."""
        docs = self.vectorstore.get_relevant_documents(query=query)
        if not docs:
            # Nothing to rerank; also avoids running the model on an empty batch.
            return []

        passages = [doc.page_content for doc in docs]
        features = tokenizer_rerank(
            [query] * len(passages),
            passages,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            logits = model_rerank(**features).logits

        if logits.shape[-1] == 1:
            # Single-logit reranker head: the logit itself is the relevance score.
            scores = logits[:, 0]
        else:
            # Two-way classifier head: score by the probability of class 1.
            # Summing both logits does not measure relevance.
            # Assumes index 1 means "relevant" (MS MARCO convention); TODO
            # confirm against the model card.
            scores = torch.softmax(logits, dim=-1)[:, 1]

        # Highest score first; slicing also copes with fewer than top_n candidates.
        ranked = torch.argsort(scores, descending=True).tolist()
        return [docs[i] for i in ranked[: self.top_n]]

In [10]:
class RerankWikiRetriever(VectorStoreRetriever):
    """Rerank Wikipedia search results with the cross-encoder reranker.

    Candidates come from ``vectorstore`` (a ``WikipediaRetriever``). Each
    ``(query, passage)`` pair is scored with the notebook-level
    ``tokenizer_rerank`` / ``model_rerank`` on ``device``. The ``top_n``
    highest-scoring documents are returned, best first.

    NOTE(review): logic mirrors ``RerankRetriever``; only the type of the wrapped
    retriever differs. Consider merging the two once they are known to agree.
    """

    # Wikipedia retriever that supplies the candidate documents.
    vectorstore: WikipediaRetriever
    # Number of documents returned after reranking.
    top_n: int = 2

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``top_n`` Wikipedia passages for ``query``, most relevant first."""
        docs = self.vectorstore.get_relevant_documents(query=query)
        if not docs:
            # Nothing to rerank; also avoids running the model on an empty batch.
            return []

        passages = [doc.page_content for doc in docs]
        features = tokenizer_rerank(
            [query] * len(passages),
            passages,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            logits = model_rerank(**features).logits

        if logits.shape[-1] == 1:
            # Single-logit reranker head: the logit itself is the relevance score.
            scores = logits[:, 0]
        else:
            # Two-way classifier head: score by the probability of class 1.
            # Summing both logits does not measure relevance.
            # Assumes index 1 means "relevant" (MS MARCO convention); TODO
            # confirm against the model card.
            scores = torch.softmax(logits, dim=-1)[:, 1]

        # Highest score first; slicing also copes with fewer than top_n candidates.
        ranked = torch.argsort(scores, descending=True).tolist()
        return [docs[i] for i in ranked[: self.top_n]]

In [11]:
from langchain.retrievers import WikipediaRetriever
from langchain.vectorstores import Qdrant
from langchain.llms import HuggingFacePipeline
from qdrant_client import QdrantClient
from langchain.prompts import PromptTemplate
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.chains import RetrievalQA, MultiRetrievalQAChain
from langchain.llms import VLLM
from langchain.llms import HuggingFaceHub

In [12]:
class LLMServe:
    """Own the RAG components and rebuild the QA chain when the source changes.

    Embeddings, the vLLM generator and the prompt are loaded once. The retriever
    and the ``RetrievalQA`` chain built on top of it are recreated by
    :meth:`rag` whenever a different source is requested.
    """

    def __init__(self, source: str = "wiki", max_new_tokens: int = 300) -> None:
        """Load all components and build the chain for the initial source.

        Args:
            source: Initial retriever name. ``"wiki"`` selects the Vietnamese
                Wikipedia retriever; any other value selects Qdrant.
            max_new_tokens: Generation length limit handed to the vLLM model.
        """
        self.embeddings = self.load_embeddings()
        self.current_source = source
        # The retriever is built before the (expensive) LLM so that a bad
        # retriever configuration fails fast.
        self.retriever = self.load_retriever(
            retriever_name=self.current_source, embeddings=self.embeddings
        )
        self.pipe = self.load_model_pipeline(max_new_tokens=max_new_tokens)
        self.prompt = self.load_prompt_template()
        self.rag_pipeline = self.load_rag_pipeline(
            llm=self.pipe, retriever=self.retriever, prompt=self.prompt
        )

    def load_embeddings(self):
        """Return HF Inference API embeddings for ``EMBEDDINGS_MODEL_NAME``.

        Only the Qdrant branch of :meth:`load_retriever` uses them.
        """
        return HuggingFaceInferenceAPIEmbeddings(
            model_name=EMBEDDINGS_MODEL_NAME,
            api_key=HUGGINGFACE_API_KEY,
        )

    def load_retriever(self, retriever_name, embeddings):
        """Build the reranking retriever for ``retriever_name``.

        - ``"wiki"``: Vietnamese Wikipedia search (15 results, each capped at
          800 chars), reranked by ``RerankWikiRetriever``.
        - Any other name: the Qdrant collection queried with ``k=15``,
          reranked by ``RerankRetriever``.

        The Qdrant branch requires the notebook-level globals ``QDRANT_URL``,
        ``QDRANT_API_KEY`` and ``QDRANT_COLLECTION_NAME``.
        """
        if retriever_name == "wiki":
            return RerankWikiRetriever(
                vectorstore=WikipediaRetriever(
                    lang="vi",
                    doc_content_chars_max=800,
                    top_k_results=15,
                )
            )

        client = QdrantClient(
            url=QDRANT_URL, api_key=QDRANT_API_KEY, prefer_grpc=False
        )
        db = Qdrant(
            client=client,
            embeddings=embeddings,
            collection_name=QDRANT_COLLECTION_NAME,
        )
        return RerankRetriever(vectorstore=db.as_retriever(search_kwargs={"k": 15}))

    def load_model_pipeline(self, max_new_tokens=100):
        """Return the vLLM generator for ``GENERATE_MODEL_NAME`` (AWQ weights, fp16).

        NOTE(review): in the saved run the ``VLLM`` constructor raised
        "RuntimeError: Failed to infer device type" on the MPS kernel. vLLM
        presumably needs a supported GPU backend; confirm before relying on it.
        """
        return VLLM(
            model=GENERATE_MODEL_NAME,
            trust_remote_code=True,  # mandatory for hf models
            max_new_tokens=max_new_tokens,
            # Earlier sampling setting: temperature=1.0, top_k=50, top_p=0.9.
            top_k=10,
            top_p=0.95,
            temperature=0.4,
            dtype="half",
            vllm_kwargs={"quantization": "awq"},
        )

    def load_prompt_template(self):
        """Return the RAG prompt with ``{context}`` and ``{question}`` slots.

        The active template (Vietnamese) tells the model it is a smart chatbot
        that answers questions based on the given context, followed by
        ``### Context`` / ``### Human`` / ``### Assistant`` sections.
        """
        # Earlier prompt variants, kept for reference:
        # query_template = "Bạn là một trợ lý của trường Đại học Nguyễn Tất Thành. Hãy trả lời câu hỏi sau dựa trên ngữ cảnh, nếu ngữ cảnh không cung cấp câu trả lời hoặc không chắc chắn hãy trả lời 'Tôi không biết thông tin này, tuy nhiên đoạn thông tin dưới phần tham khảo có thể có câu trả lời cho bạn!' đừng cố tạo ra câu trả lời không có trong ngữ cảnh.\nNgữ cảnh: {context} \nCâu hỏi: {question}\nTrả lời: "
        # query_template = "Tham khảo ngữ cảnh:{context}\n\n### Câu hỏi:{question}\n\n### Trả lời:"
        query_template = "Bạn là một chatbot thông minh trả lời câu hỏi dựa trên ngữ cảnh (context).\n\n### Context:{context} \n\n### Human: {question}\n\n### Assistant:"
        return PromptTemplate(
            template=query_template, input_variables=["context", "question"]
        )

    def load_rag_pipeline(self, llm, retriever, prompt):
        """Return a "stuff" ``RetrievalQA`` chain that also returns its source docs."""
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt},
            return_source_documents=True,
        )

    def rag(self, source):
        """Return the QA chain for ``source``, rebuilding it if the source changed.

        The new retriever and chain are fully built before any attribute is
        updated, so an exception leaves the previously active source intact.
        """
        if source == self.current_source:
            return self.rag_pipeline

        retriever = self.load_retriever(
            retriever_name=source, embeddings=self.embeddings
        )
        rag_pipeline = self.load_rag_pipeline(
            llm=self.pipe, retriever=retriever, prompt=self.prompt
        )
        # Commit only after both pieces were built successfully.
        self.retriever = retriever
        self.rag_pipeline = rag_pipeline
        self.current_source = source
        return self.rag_pipeline

In [13]:
# Build the full RAG service (embeddings, Wikipedia reranking retriever, vLLM, prompt).
# NOTE(review): in the saved run this raised "RuntimeError: Failed to infer
# device type" on the MPS kernel. The next cell reproduces the same error from
# the VLLM constructor alone; vLLM presumably needs a supported GPU backend.
app = LLMServe()

RuntimeError: Failed to infer device type

In [14]:
# Standalone VLLM construction with the same settings as
# LLMServe.load_model_pipeline(max_new_tokens=100), used to check whether the
# "Failed to infer device type" error comes from vLLM itself.
llm = VLLM(
    # Model loading: AWQ-quantized weights served in half precision.
    model=GENERATE_MODEL_NAME,
    dtype="half",
    vllm_kwargs={"quantization": "awq"},
    trust_remote_code=True,  # mandatory for hf models
    # Sampling (an earlier run used temperature=1.0, top_k=50, top_p=0.9).
    temperature=0.4,
    top_k=10,
    top_p=0.95,
    max_new_tokens=100,
)

RuntimeError: Failed to infer device type