# 기본환경 설정

In [19]:
# One-time setup: uncomment on a fresh environment to install the FAISS package
# that the `FAISS` vector store cell further down relies on.
# !pip install faiss-cpu

In [None]:
# Resolve the Hugging Face access token without hardcoding it.
# Inside Google Colab the token is read from the Colab secrets store; anywhere
# else (e.g. a local GPU machine, where `google.colab` is not installed) fall
# back to the HF_KEY environment variable so this cell does not crash.
import os

try:
    from google.colab import userdata
except ImportError:
    # Not running in Colab: HF_KEY is None when the variable is unset.
    HF_KEY = os.environ.get("HF_KEY")
else:
    HF_KEY = userdata.get("HF_KEY")

In [None]:
import huggingface_hub

# Authenticate this session against the Hugging Face Hub with the token
# resolved in the previous cell.
huggingface_hub.login(token=HF_KEY)

# 모델 로딩

In [1]:
from unsloth import FastLanguageModel
from langchain.embeddings import HuggingFaceEmbeddings
import torch

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 08-06 15:50:37 [__init__.py:235] Automatically detected platform cuda.


In [2]:
# Load the instruction-tuned Gemma 3 4B checkpoint with 4-bit weights.
MODEL_NAME = "unsloth/gemma-3-4b-it"
model, tokenizer = FastLanguageModel.from_pretrained(model_name=MODEL_NAME, load_in_4bit=True)

==((====))==  Unsloth 2025.7.8: Fast Gemma3 patching. Transformers: 4.54.0. vLLM: 0.10.0.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 2. Max memory: 23.494 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
# Switch the model into Unsloth's inference mode. The call patches `model`
# in place, so its return value is deliberately not assigned back: that keeps
# `model` intact even on Unsloth releases where `for_inference` returns None,
# and makes the cell safe to re-run.
FastLanguageModel.for_inference(model)

In [4]:
# Embedding generator (multilingual model that also covers Korean).
# Alternative model that was tried: "intfloat/multilingual-e5-base"
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

  embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")


# Custom ChatModel 클래스

In [5]:
from typing import List, Any, ClassVar
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

In [26]:
class GemmaChatModel(BaseChatModel):
    """LangChain chat-model adapter around a locally loaded Gemma instruct model.

    The adapter renders LangChain messages into Gemma's turn-based prompt
    format, runs ``model.generate`` and returns only the newly generated text
    as an ``AIMessage``.

    Args:
        model: Loaded causal LM exposing ``generate`` and ``device``.
        tokenizer: Matching tokenizer (or processor wrapping one); it is
            called with ``text=...`` and must expose ``decode``.
        max_tokens: Upper bound on newly generated tokens per call.
        do_sample: Default for whether to sample instead of greedy decoding.
        temperature: Default sampling temperature (used only when sampling).
        top_p: Default nucleus-sampling threshold (used only when sampling).
    """

    def __init__(self, model, tokenizer, max_tokens: int = 512, do_sample: bool = True, temperature: float = 0.7, top_p: float = 0.9):
        super().__init__()
        # NOTE(review): object.__setattr__ is used instead of plain assignment,
        # presumably because the pydantic-based BaseChatModel rejects attributes
        # that are not declared as fields -- confirm before refactoring.
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "max_tokens", max_tokens)
        object.__setattr__(self, "do_sample", do_sample)
        object.__setattr__(self, "temperature", temperature)
        object.__setattr__(self, "top_p", top_p)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to label this model type."""
        return "gemma-chat"

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a message's ``content`` into stripped plain text.

        Plain strings are returned stripped. For list content, string items and
        the ``"text"`` value of dict items are joined with newlines; any other
        items are skipped. Anything else is converted with ``str``.
        """
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "\n".join(parts).strip()
        return str(content).strip()

    def _format_messages(self, messages: List[Any]) -> str:
        """Render messages in Gemma's chat format, ending with an open model turn.

        Gemma has only ``user`` and ``model`` roles. System messages are
        therefore folded into the next user turn (separated by a blank line),
        which mirrors how Gemma's own chat template treats system text. AI
        messages become ``model`` turns; every other message type is rendered
        as a ``user`` turn rather than being dropped. No ``<bos>`` is written
        here because the tokenizer is expected to add it when encoding.
        """
        turns = []
        pending_system = []  # system text waiting for the next user turn

        for message in messages:
            text = self._content_to_text(message.content)
            if isinstance(message, SystemMessage):
                pending_system.append(text)
                continue

            role = "model" if isinstance(message, AIMessage) else "user"
            if role == "user" and pending_system:
                text = "\n\n".join(pending_system + [text])
                pending_system = []
            turns.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")

        if pending_system:
            # Only system text remained (no later user turn): send it as a user turn.
            system_text = "\n\n".join(pending_system)
            turns.append(f"<start_of_turn>user\n{system_text}<end_of_turn>\n")

        # Open the model turn so generation continues as the assistant.
        turns.append("<start_of_turn>model\n")
        return "".join(turns)

    def _stop_token_ids(self) -> Any:
        """Collect token ids that should end generation.

        Returns the tokenizer's EOS id plus the id of ``<end_of_turn>`` when the
        tokenizer knows that token, or ``None`` when neither can be resolved so
        that ``generate`` falls back to the model's own generation config.
        """
        # A processor keeps the text tokenizer on `.tokenizer`; a plain tokenizer is used as is.
        text_tokenizer = getattr(self.tokenizer, "tokenizer", self.tokenizer)

        stop_ids = []
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        if eos_id is None:
            eos_id = getattr(text_tokenizer, "eos_token_id", None)
        if eos_id is not None:
            stop_ids.append(eos_id)

        convert = getattr(text_tokenizer, "convert_tokens_to_ids", None)
        if convert is not None:
            end_of_turn_id = convert("<end_of_turn>")
            unknown_id = getattr(text_tokenizer, "unk_token_id", None)
            # Unknown tokens map to the unk id; do not treat that as a stop token.
            if isinstance(end_of_turn_id, int) and end_of_turn_id != unknown_id and end_of_turn_id not in stop_ids:
                stop_ids.append(end_of_turn_id)

        return stop_ids or None

    def _generate(self, messages: List[Any], stop: Any = None, run_manager: Any = None, **kwargs) -> ChatResult:
        """Generate one assistant reply for ``messages``.

        Args:
            messages: Conversation so far as LangChain message objects.
            stop: Optional iterable of strings; the reply is cut before the
                first occurrence of any of them.
            run_manager: Accepted for compatibility with the base class; unused.
            **kwargs: May override ``do_sample``, ``temperature`` and ``top_p``
                for this call.

        Returns:
            ChatResult holding a single ``AIMessage`` with the generated text.
        """
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(text=prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        do_sample = kwargs.get("do_sample", self.do_sample)
        generation_args = {
            "max_new_tokens": self.max_tokens,
            "do_sample": do_sample,
            "eos_token_id": self._stop_token_ids(),
        }
        if do_sample:
            # Sampling knobs are only meaningful (and only passed) when sampling.
            generation_args["temperature"] = kwargs.get("temperature", self.temperature)
            generation_args["top_p"] = kwargs.get("top_p", self.top_p)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_args)

        # Decode only the tokens produced after the prompt instead of searching
        # the decoded text for a role marker, which the reply itself may contain.
        new_tokens = outputs[0][prompt_length:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Defensive: drop a turn terminator if it survived decoding as plain text.
        response = response.split("<end_of_turn>")[0]
        for stop_sequence in stop or ():
            if stop_sequence:
                response = response.split(stop_sequence)[0]
        response = response.strip()

        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=response))])

In [27]:
# Wrap the loaded Gemma model and tokenizer in the LangChain-compatible chat model.
chat_model = GemmaChatModel(
    model=model,
    tokenizer=tokenizer,
    max_tokens=512,
)

# Documents 준비

In [28]:
from langchain.vectorstores import FAISS

In [29]:
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA

In [30]:
# Prepare the documents: a tiny in-memory corpus of Korean sentences that the
# cells below wrap in Document objects, split, and index.
texts = [
    "서울은 대한민국의 수도입니다.",
    "부산은 한국에서 두 번째로 큰 도시입니다.",
    "제주는 아름다운 섬으로 유명합니다.",
]

In [31]:
# Wrap each raw sentence in a LangChain Document.
docs = []
for text in texts:
    docs.append(Document(page_content=text))

In [32]:
# Split the documents into chunks (chunk_size=256, chunk_overlap=32).
CHUNK_SIZE = 256
CHUNK_OVERLAP = 32
splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)
split_docs = splitter.split_documents(docs)

In [33]:
# Build the FAISS vector store from the split documents using the embedding model.
vectordb = FAISS.from_documents(
    documents=split_docs,
    embedding=embedding,
)

  return forward_call(*args, **kwargs)


In [34]:
# Build the RetrievalQA chain on top of the vector store.
retriever = vectordb.as_retriever()  # vector-search retriever
retrieval_chain = RetrievalQA.from_chain_type(
    llm=chat_model,                # local Gemma model
    chain_type="stuff",            # simple chain type
    retriever=retriever,
    return_source_documents=True,  # include the source documents in the result
)

In [35]:
# Ask a question through the chain, then show the answer and the retrieved sources.
query = "한국의 수도는 어디야?"
result = retrieval_chain.invoke(query)

answer = result["result"]
source_texts = [source.page_content for source in result["source_documents"]]

print("💬 답변:", answer)
print("📄 사용된 문서:", source_texts)

  return forward_call(*args, **kwargs)


💬 답변: 서울은 대한민국의 수도입니다.
📄 사용된 문서: ['서울은 대한민국의 수도입니다.', '부산은 한국에서 두 번째로 큰 도시입니다.', '제주는 아름다운 섬으로 유명합니다.']
