# 기본환경 설정

In [None]:
# !pip install faiss-cpu

In [None]:
from google.colab import userdata
# Read the value stored under the Colab user-secret name "HF_KEY"
# (used below as the Hugging Face access token).
HF_KEY = userdata.get("HF_KEY")

In [None]:
import huggingface_hub
# Log this session in to the Hugging Face Hub with the token read above.
huggingface_hub.login(HF_KEY)

# 모델 로딩

In [None]:
# Install Unsloth plus the quantization, training, LangChain and FAISS packages.
# --no-deps tells pip not to install or upgrade each package's own dependencies,
# so anything they need must already be present in the runtime.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo langchain-community pypdf langchain_huggingface faiss-cpu
!pip install --no-deps unsloth

In [None]:
from unsloth import FastModel
from langchain.embeddings import HuggingFaceEmbeddings
import torch

In [None]:
# Use the default CUDA device when a GPU is present, otherwise the CPU.
# A hard-coded "cuda:1" names a second GPU, which does not exist on
# single-GPU runtimes and fails as soon as the device is actually used.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the "unsloth/gemma-3-4b-it" checkpoint through Unsloth's FastModel.
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/gemma-3-4b-it",
    max_seq_length=5 * 1024,  # sequence-length budget in tokens
    load_in_4bit=True,  # 4-bit quantization to reduce memory
    # device_map={"": device}  # explicit placement, left disabled
)

# Custom ChatModel

In [None]:
from typing import List, Any
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

In [None]:
class GemmaChatModel(BaseChatModel):
    """LangChain chat-model adapter around a locally loaded Gemma model.

    Messages are rendered into Gemma's turn-based prompt format, passed to
    ``model.generate`` and the newly generated tokens are returned as a
    single ``AIMessage``.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer (or processor wrapping one) matching ``model``.
        max_tokens: Upper bound on newly generated tokens per call.
        do_sample: Default for sampling vs. greedy decoding.
        temperature: Default sampling temperature (used only when sampling).
        top_p: Default nucleus-sampling threshold (used only when sampling).
    """

    def __init__(self, model, tokenizer, max_tokens: int = 512, do_sample: bool = True, temperature: float = 0.7, top_p: float = 0.9):
        super().__init__()
        # These attributes are not declared as fields on the base model, so
        # they are attached with object.__setattr__ to bypass its own
        # attribute handling.
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "max_tokens", max_tokens)
        object.__setattr__(self, "do_sample", do_sample)
        object.__setattr__(self, "temperature", temperature)
        object.__setattr__(self, "top_p", top_p)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to label this model type."""
        return "gemma-chat"

    def _format_messages(self, messages: List[Any]) -> str:
        """Render LangChain messages into Gemma's chat prompt format.

        Gemma has no dedicated system role: system text is prepended to the
        next user turn, separated by a blank line. Assistant turns use the
        ``model`` role. Message types other than system/human/AI are skipped.
        The returned prompt ends with an open ``model`` turn so generation
        continues as the assistant. No ``<bos>`` is emitted here; the
        tokenizer is expected to add it.
        """
        start, end = "<start_of_turn>", "<end_of_turn>"
        turns = []
        pending_system = []
        for message in messages:
            if isinstance(message, SystemMessage):
                pending_system.append(str(message.content))
            elif isinstance(message, HumanMessage):
                user_text = "\n\n".join(pending_system + [str(message.content)])
                pending_system = []
                turns.append(f"{start}user\n{user_text}{end}\n")
            elif isinstance(message, AIMessage):
                turns.append(f"{start}model\n{message.content}{end}\n")

        # System text with no following user message still has to reach the
        # model, so it is sent as its own user turn.
        if pending_system:
            leftover = "\n\n".join(pending_system)
            turns.append(f"{start}user\n{leftover}{end}\n")

        turns.append(f"{start}model\n")
        return "".join(turns)

    def _stop_token_ids(self):
        """Return token ids that end generation, or ``None`` if none are known.

        Includes the tokenizer's EOS id and, when present in the vocabulary,
        Gemma's ``<end_of_turn>`` id so generation stops at the end of the
        assistant turn instead of running on to ``max_tokens``.
        """
        # A processor wraps the real tokenizer under `.tokenizer`.
        tok = getattr(self.tokenizer, "tokenizer", self.tokenizer)
        stop_ids = []

        eos_id = getattr(tok, "eos_token_id", None)
        if eos_id is not None:
            stop_ids.append(eos_id)

        convert = getattr(tok, "convert_tokens_to_ids", None)
        if convert is not None:
            end_of_turn_id = convert("<end_of_turn>")
            unknown_id = getattr(tok, "unk_token_id", None)
            if (
                isinstance(end_of_turn_id, int)
                and end_of_turn_id != unknown_id
                and end_of_turn_id not in stop_ids
            ):
                stop_ids.append(end_of_turn_id)

        # None lets generate() fall back to the model's generation config.
        return stop_ids or None

    def _generate(self, messages: List[Any], **kwargs) -> ChatResult:
        """Generate one assistant reply for ``messages``.

        Args:
            messages: Conversation so far as LangChain message objects.
            **kwargs: Optional per-call overrides ``do_sample``,
                ``temperature`` and ``top_p``; other keys are ignored.

        Returns:
            A ``ChatResult`` holding a single ``AIMessage`` with the reply text.
        """
        prompt = self._format_messages(messages)
        # `text=` works for both plain tokenizers and multimodal processors.
        inputs = self.tokenizer(text=prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        generation_args = {
            "max_new_tokens": self.max_tokens,
            "do_sample": kwargs.get("do_sample", self.do_sample),
            "eos_token_id": self._stop_token_ids(),
        }
        # Sampling knobs are only meaningful when sampling is enabled.
        if generation_args["do_sample"]:
            generation_args["temperature"] = kwargs.get("temperature", self.temperature)
            generation_args["top_p"] = kwargs.get("top_p", self.top_p)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_args)

        # Decode only the newly generated tokens; the turn markers are special
        # tokens and vanish on decode, so the prompt cannot be split off as text.
        completion_ids = outputs[0][prompt_length:]
        response = self.tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=response))])

In [None]:
# Wrap the loaded model/tokenizer as a LangChain chat model, capping each
# reply at 512 newly generated tokens.
llm = GemmaChatModel(model=model, tokenizer=tokenizer, max_tokens=512)

# 툴 정의

In [None]:
# Example tool (replace with a real API call in practice).
def get_weather(city: str) -> str:
    """Return a fixed demo weather report prefixed with *city*."""
    demo_report = "맑음, 25℃ (데모)"
    return "{}: {}".format(city, demo_report)

In [None]:
# Tool name -> callable that implements it.
TOOL_FUNCS = {"get_weather": get_weather}

# Tool name -> description and JSON-Schema-style parameter spec that is
# shown to the model when it picks a tool.
TOOLS = {
    "get_weather": {
        "description": "도시의 현재 날씨를 조회",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
            },
            "required": ["city"],
        },
    },
}

In [None]:
parser = JsonOutputParser()  # parses the model output as JSON; expected shape: {"tool": "...", "args": {...}}

In [None]:
# System prompt template for the tool-selection step. `{tool_names}` and
# `{tool_schema}` are template variables; the doubled braces `{{ }}` are
# escapes that render as literal braces in the final prompt.
select_instruct = """\
너는 도구를 선택해 답을 찾는 어시스턴트다.
반드시 아래 JSON 스키마만 출력하라. 설명/문장/코드블록 금지.

JSON 스키마:
{{
  "tool": "<{tool_names} 중 하나>|none",
  "args": <object>  # 선택한 툴의 parameters와 일치
}}

사용 가능 도구 정의:
{tool_schema}
""".strip()

In [None]:
# Tool-selection prompt: the selection instructions as the system message,
# followed by the user's question as the human message.
select_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", select_instruct),
        ("human", "사용자 질문: {input}"),
    ]
)

In [None]:
# Tool-selection chain: fill the prompt, run the chat model, then parse
# the model's reply as JSON.
select_chain = select_prompt | llm | parser

# 파이썬 라우팅: 선택 결과를 받아 툴 실행 → observation 생성

In [None]:
def run_tool(selection: dict):
    """Execute the tool named in ``selection`` and return its observation.

    Args:
        selection: Parsed tool choice, expected to look like
            ``{"tool": <name or "none">, "args": {...}}``. Anything that is
            not a dict (the JSON parser can also yield lists, strings,
            numbers or ``None``) is treated as "no tool selected".

    Returns:
        A dict with keys ``observation``, ``tool`` and ``args``. When no
        usable tool was selected, ``observation`` is ``None``, ``tool`` is
        ``"none"`` and ``args`` is empty. When the tool call raises, the
        observation is a ``"TOOL_ERROR: ..."`` string instead of propagating.
    """
    no_tool = {"observation": None, "tool": "none", "args": {}}

    # Model output is untrusted: guard against non-dict JSON values.
    if not isinstance(selection, dict):
        return no_tool

    tool = selection.get("tool", "none")

    # Reject missing, non-string (possibly unhashable) or unknown tool names.
    if not isinstance(tool, str) or tool == "none" or tool not in TOOL_FUNCS:
        return no_tool

    # A missing or null "args" entry means "call with no arguments".
    args = selection.get("args") or {}

    # Best-effort execution: a failing tool (including malformed args) is
    # reported to the final prompt as text rather than aborting the run.
    try:
        result = TOOL_FUNCS[tool](**args)
    except Exception as e:
        result = f"TOOL_ERROR: {e}"

    return {"observation": result, "tool": tool, "args": args}

In [None]:
# Wrap run_tool as a LangChain runnable so it can be called with .invoke().
route = RunnableLambda(run_tool)

# LLM Prompt

In [None]:
# System prompt template for the final-answer step. `{observation}` is a
# template variable that receives the tool result (or None when no tool ran).
final_instruct = """\
너는 도구 결과를 참고하여 간결하고 정확하게 한국어로 답한다.
도구결과: {observation}

지시:
- 도구결과가 있으면 그것을 근거로 한 문단의 최종 답을 써라.
- 도구결과가 없으면 도구 없이 바로 답하되, 모르면 솔직히 모른다고 말해라.
""".strip()

In [None]:
# Final-answer prompt: the answering instructions (with the tool result)
# as the system message, followed by the user's question.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", final_instruct),
        ("human", "사용자 질문: {input}"),
    ]
)

In [None]:
# Final-answer chain: fill the prompt and run the chat model; the result is
# the model's message (no output parser is applied).
final_chain = final_prompt | llm

# 엔드-투-엔드 실행

In [None]:
def answer(user_input: str):
    """Answer ``user_input`` end to end: select a tool, run it, then respond.

    Args:
        user_input: The user's question.

    Returns:
        The chat model's final message, as returned by ``final_chain``.
    """
    # Local imports: the notebook has no single top-of-file import block.
    import json

    from langchain_core.exceptions import OutputParserException

    try:
        selection = select_chain.invoke({
            "input": user_input,
            # Serialize as real JSON; formatting the dict directly would put
            # its Python repr (single quotes) into a prompt that demands JSON.
            "tool_schema": json.dumps(TOOLS, ensure_ascii=False, indent=2),
            "tool_names": ", ".join(TOOLS),
        })
    except OutputParserException:
        # The model did not emit parseable JSON. Fall back to answering
        # without a tool, which the final prompt explicitly supports.
        selection = {"tool": "none"}

    routed = route.invoke(selection)
    return final_chain.invoke({
        "observation": routed["observation"],
        "input": user_input,
    })

In [None]:
# Print only the reply text; printing the message object itself shows its
# full repr (content plus metadata fields) instead of the answer.
print(answer("서울 날씨 알려줘").content)