# 기본환경 설정

In [None]:
# !pip install faiss-cpu

In [None]:
from google.colab import userdata
# Read the Hugging Face access token from the Colab secret named "HF_KEY".
HF_KEY = userdata.get("HF_KEY")

In [None]:
import huggingface_hub

# Authenticate this session with the Hugging Face Hub using the token
# fetched from Colab secrets (HF_KEY).
huggingface_hub.login(token=HF_KEY)

# 모델 로딩

In [None]:
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo langchain-community pypdf langchain_huggingface faiss-cpu
!pip install --no-deps unsloth

In [None]:
from unsloth import FastModel
from langchain.embeddings import HuggingFaceEmbeddings
import torch

In [None]:
# Pick the compute device. Use the second GPU when the machine has more than
# one; single-GPU runtimes (e.g. Colab, typically) have no "cuda:1", so fall
# back to GPU 0 there. Without CUDA, use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load the instruction-tuned Gemma 3 4B checkpoint through Unsloth.
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/gemma-3-4b-it",
    max_seq_length=5 * 1024,  # maximum sequence length: 5120 tokens
    load_in_4bit=True,        # 4-bit quantization to reduce memory use
    device_map={"": device},  # "" key: map the whole model onto `device`
)

In [None]:
# Switch the Unsloth model into inference mode before generating.
# NOTE(review): this rebinds `model` to the return value of for_inference,
# which assumes it returns the model — confirm for the installed Unsloth version.
model = FastModel.for_inference(model)

# Custom ChatModel 클래스

In [None]:
from typing import Any, Dict, List, Optional
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult

In [None]:
class GemmaChatModel(BaseChatModel):
    """LangChain chat model backed by a locally loaded Gemma model.

    LangChain messages are rendered in Gemma's ``<start_of_turn>`` /
    ``<end_of_turn>`` chat format and passed to ``model.generate``. Only the
    newly generated text is returned, as a single ``AIMessage``.

    Args:
        model: Loaded model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer for ``model``, or a processor that wraps one in
            its ``tokenizer`` attribute.
        max_tokens: Default ``max_new_tokens`` for each call.
        do_sample: Use sampling (True) or greedy decoding (False).
        temperature: Sampling temperature; forwarded only when sampling.
        top_p: Nucleus-sampling threshold; forwarded only when sampling.
    """

    def __init__(self, model, tokenizer, max_tokens: int = 512, do_sample: bool = True, temperature: float = 0.7, top_p: float = 0.9):
        super().__init__()
        # BaseChatModel is a pydantic model and these are not declared fields,
        # so they are attached with object.__setattr__ to bypass pydantic's
        # attribute validation.
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "max_tokens", max_tokens)
        object.__setattr__(self, "do_sample", do_sample)
        object.__setattr__(self, "temperature", temperature)
        object.__setattr__(self, "top_p", top_p)

    @property
    def _llm_type(self) -> str:
        return "gemma-chat"

    @staticmethod
    def _turn(role: str, text: str) -> str:
        """Format one completed Gemma chat turn for ``role`` ("user" or "model")."""
        return f"<start_of_turn>{role}\n{text}<end_of_turn>\n"

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Render ``messages`` as a Gemma prompt ending with an open model turn.

        Gemma has no system role, so system text is prepended (separated by a
        blank line) to the next user turn. Message types other than system,
        human and AI are skipped.
        """
        parts: List[str] = []
        pending_system: List[str] = []
        for m in messages:
            if isinstance(m, SystemMessage):
                pending_system.append(str(m.content))
            elif isinstance(m, HumanMessage):
                parts.append(self._turn("user", "\n\n".join(pending_system + [str(m.content)])))
                pending_system = []
            elif isinstance(m, AIMessage):
                parts.append(self._turn("model", str(m.content)))
        if pending_system:
            # System text not followed by a user message must still reach the model.
            parts.append(self._turn("user", "\n\n".join(pending_system)))
        # Leave the model turn open so generation continues as the assistant.
        parts.append("<start_of_turn>model\n")
        return "".join(parts)

    def _text_tokenizer(self):
        """Return the plain text tokenizer (unwrapping a processor if needed)."""
        inner = getattr(self.tokenizer, "tokenizer", None)
        if inner is not None and hasattr(inner, "convert_tokens_to_ids"):
            return inner
        return self.tokenizer

    @staticmethod
    def _stop_token_ids(text_tokenizer) -> List[int]:
        """Token ids that end generation: EOS plus ``<end_of_turn>`` when the vocab has it."""
        ids: List[int] = []
        if text_tokenizer.eos_token_id is not None:
            ids.append(text_tokenizer.eos_token_id)
        end_of_turn = text_tokenizer.convert_tokens_to_ids("<end_of_turn>")
        # convert_tokens_to_ids falls back to the unk id for unknown tokens.
        if isinstance(end_of_turn, int) and end_of_turn != text_tokenizer.unk_token_id and end_of_turn not in ids:
            ids.append(end_of_turn)
        return ids

    def _apply_stop(self, text: str, stop: Optional[List[str]]) -> str:
        """Truncate ``text`` at the earliest occurrence of any ``stop`` string."""
        if not stop:
            return text
        cut = len(text)
        for s in stop:
            idx = text.find(s)
            if idx != -1:
                cut = min(cut, idx)
        return text[:cut]

    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[Any] = None, **kwargs: Any) -> ChatResult:
        """Generate one reply to ``messages``.

        Per-call ``max_tokens``, ``do_sample``, ``temperature`` and ``top_p``
        in ``kwargs`` override the instance defaults. ``stop`` strings are
        applied to the decoded reply after generation.
        """
        prompt = self._format_messages(messages)
        # Gemma 3 may load with a multimodal processor in place of a plain
        # tokenizer; pass the prompt as ``text=`` so it is never taken as images.
        inputs = self.tokenizer(text=prompt, return_tensors="pt").to(self.model.device)
        text_tokenizer = self._text_tokenizer()

        do_sample = kwargs.get("do_sample", self.do_sample)
        gen_kwargs = {
            "max_new_tokens": kwargs.get("max_tokens", self.max_tokens),
            "do_sample": do_sample,
        }
        if do_sample:
            # Sampling parameters are ignored by greedy decoding.
            gen_kwargs["temperature"] = kwargs.get("temperature", self.temperature)
            gen_kwargs["top_p"] = kwargs.get("top_p", self.top_p)

        # Stop at <end_of_turn> as well as EOS; a single EOS id would override
        # the model's own stop list and let it run on into new turns.
        stop_ids = self._stop_token_ids(text_tokenizer)
        if stop_ids:
            gen_kwargs["eos_token_id"] = stop_ids
        pad_id = text_tokenizer.pad_token_id
        if pad_id is None:
            pad_id = text_tokenizer.eos_token_id
        if pad_id is not None:
            gen_kwargs["pad_token_id"] = pad_id

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Decode only the tokens produced after the prompt.
        prompt_length = inputs["input_ids"].shape[-1]
        response = text_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        response = self._apply_stop(response.strip(), stop)

        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=response))])

In [None]:
# The prompt and the generated tokens share the 5120-token `max_seq_length`
# the model was loaded with, so the generation budget must leave room for the
# (long) ReAct agent prompt instead of claiming the whole window.
chat_model = GemmaChatModel(model=model, tokenizer=tokenizer, max_tokens=1024)

# Tool 정의

In [None]:
from langchain_core.tools import tool, Tool

In [None]:
@tool
def calculator(expression: str) -> str:
    """문자열 수식을 계산합니다. 예: '12 * (3 + 4) / 5'"""
    # Evaluate an arithmetic expression; names from the `math` module
    # (sqrt, pi, ...) are available directly. Errors are returned as text.
    import ast
    import math

    # NOTE(security): `expression` is generated by the LLM, so treat it as
    # untrusted. An empty __builtins__ alone is not a sandbox: attribute chains
    # such as `().__class__.__base__.__subclasses__()` still reach arbitrary
    # objects, so every attribute access is rejected before evaluating.
    # Huge powers (e.g. 9**9**9) can still exhaust CPU/memory.
    safe = {k: v for k, v in math.__dict__.items() if not k.startswith("__")}
    try:
        # strip(): eval() on a string tolerates leading whitespace; ast.parse does not.
        tree = ast.parse(expression.strip(), mode="eval")
        if any(isinstance(node, ast.Attribute) for node in ast.walk(tree)):
            raise ValueError("attribute access is not allowed")
        return str(eval(compile(tree, "<calculator>", "eval"), {"__builtins__": {}}, safe))
    except Exception as e:
        return f"계산 실패: {e}"

In [None]:
@tool
def get_time(_: str = "") -> str:
    """현재 시간을 ISO8601 형식으로 반환합니다."""
    # The argument is ignored; presumably it exists so the tool accepts the
    # single string input an agent passes.
    from datetime import datetime, timezone

    utc_now = datetime.now(timezone.utc)
    local_now = utc_now.astimezone()  # convert to the machine's local time zone
    return local_now.isoformat()

In [None]:
# Expose the calculator as a single-string-input Tool for the agent.
# `calculator` is already a LangChain tool object (from @tool), so pass its
# underlying Python function rather than the tool itself: invoking a tool
# object as a plain callable goes through the deprecated `BaseTool.__call__`.
calculator_tool = Tool(
    name="CalculatorTool",
    func=calculator.func,
    description="문자열 수식을 계산합니다."
)

In [None]:
# Expose the clock as a single-string-input Tool for the agent.
# Pass the underlying function of the @tool object: calling the tool object
# directly uses the deprecated `BaseTool.__call__` and runs its inferred
# argument schema, which may not handle the `_` parameter name — verify.
time_tool = Tool(
    name="TimeTool",
    func=get_time.func,
    description="현재 시간을 ISO8601 형식으로 반환합니다."
)

In [None]:
# Tools made available to the agent.
tools = [
    calculator_tool,
    time_tool,
]

# 에이전트 구성 및 실행

In [None]:
from langchain.agents import AgentExecutor, Tool, initialize_agent, AgentType

In [None]:
# Build a zero-shot ReAct agent over `tools`, driven by the local Gemma model.
# initialize_agent selects the agent type through its `agent=` parameter;
# `agent_type` is not one of its parameters.
agent = initialize_agent(
    tools=tools,
    llm=chat_model,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    # A small local model often drifts from the strict "Action:/Action Input:"
    # format; feed parse errors back to the model instead of raising.
    handle_parsing_errors=True,
)

In [None]:
# Example run: ask the agent to evaluate an expression with the calculator tool.
question = "계산기로 (12 * (3 + 4)) / 5 값을 구해줘."
res = agent.invoke({"input": question})
answer = res["output"]
print(answer)

In [None]:
# Direct evaluation of the same expression the agent was asked about
# (12 * 7 / 5 = 16.8), for comparison with its answer.
12 * (3 + 4) / 5