# 기본환경 설정

In [None]:
# !pip install faiss-cpu

In [None]:
from google.colab import userdata
# Read the Hugging Face access token stored under the name "HF_KEY" via google.colab.userdata.
HF_KEY = userdata.get("HF_KEY")

In [None]:
import huggingface_hub
# Log in to the Hugging Face Hub with the token read above.
huggingface_hub.login(HF_KEY)

# 모델 로딩

In [None]:
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo langchain-community pypdf langchain_huggingface faiss-cpu
!pip install --no-deps unsloth

In [None]:
from unsloth import FastModel
from langchain.embeddings import HuggingFaceEmbeddings
import torch

In [None]:
# Select the compute device.
# The second GPU ("cuda:1") is used only when more than one GPU is present;
# single-GPU runtimes expose just "cuda:0", and asking for "cuda:1" there fails
# with an invalid device ordinal. Without CUDA, fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load the "unsloth/gemma-3-4b-it" checkpoint and its tokenizer through Unsloth.
# device_map maps the root module name "" to `device`.
model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-4b-it",
    max_seq_length = 1024*5, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    device_map = {"": device}
)

In [None]:
# Switch the model to Unsloth's inference mode.
# The call prepares the model in place; only rebind `model` when something is
# actually returned, so a version that returns nothing does not overwrite the
# loaded model with None.
_inference_model = FastModel.for_inference(model)
if _inference_model is not None:
    model = _inference_model

# Custom ChatModel 클래스

In [None]:
from typing import Any, Dict, Optional, List, Union, Type, Literal
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult

In [None]:
from langchain_core.runnables import RunnableLambda
from typing import Any, Dict, Optional, List, Union, Type, Literal
import json, warnings
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage

In [None]:
from pydantic import BaseModel as PydanticBaseModel

In [None]:
class GemmaChatModel(BaseChatModel):
    """LangChain chat model backed by a locally loaded Gemma model and tokenizer.

    Messages are rendered with Gemma's turn markers
    (``<start_of_turn>user`` / ``<start_of_turn>model`` ... ``<end_of_turn>``),
    generation is delegated to ``model.generate`` and only the newly generated
    tokens are decoded into the reply. ``with_structured_output`` adds a
    prompt-based JSON mode with optional Pydantic validation.
    """

    def __init__(self, model, tokenizer, max_tokens: int = 512, do_sample: bool = True, temperature: float = 0.7, top_p: float = 0.9):
        """Store the model, tokenizer and default generation settings.

        Args:
            model: Object exposing ``generate(...)`` and ``device``.
            tokenizer: Callable tokenizer exposing ``decode``, ``eos_token_id``
                and ``pad_token_id``.
            max_tokens: Default upper bound on newly generated tokens.
            do_sample: Whether to sample by default instead of decoding greedily.
            temperature: Default sampling temperature.
            top_p: Default nucleus-sampling threshold.
        """
        super().__init__()
        # These attributes are not declared as fields on the class, so they are
        # stored with object.__setattr__ instead of normal attribute assignment.
        for attr_name, attr_value in (
            ("model", model),
            ("tokenizer", tokenizer),
            ("max_tokens", max_tokens),
            ("do_sample", do_sample),
            ("temperature", temperature),
            ("top_p", top_p),
        ):
            object.__setattr__(self, attr_name, attr_value)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "gemma-chat"

    def _format_messages(self, messages: List[BaseMessage]) -> str:
        """Render messages into a Gemma-formatted prompt string.

        Gemma has no system role, so the content of all system messages is
        joined and prepended to the first user turn (or sent as its own user
        turn when the conversation has no human message). Message types other
        than system/human/AI are skipped. The prompt ends with an open model
        turn so generation continues as the assistant.
        """
        system_text = "\n\n".join(
            str(m.content) for m in messages if isinstance(m, SystemMessage)
        )
        turns: List[str] = []
        for m in messages:
            if isinstance(m, HumanMessage):
                content = m.content
                if system_text:
                    content = f"{system_text}\n\n{content}"
                    system_text = ""
                turns.append(f"<start_of_turn>user\n{content}<end_of_turn>\n")
            elif isinstance(m, AIMessage):
                turns.append(f"<start_of_turn>model\n{m.content}<end_of_turn>\n")
        if system_text:
            # No human message consumed the system text: send it as the opening user turn.
            turns.insert(0, f"<start_of_turn>user\n{system_text}<end_of_turn>\n")
        turns.append("<start_of_turn>model\n")
        return "".join(turns)

    def _apply_stop(self, text: str, stop: Optional[List[str]]) -> str:
        """Truncate ``text`` at the earliest occurrence of any stop string."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(s) for s in stop) if pos != -1]
        return text[:min(hits)] if hits else text

    def _eos_token_ids(self) -> List[int]:
        """Collect the token ids that should end generation.

        Always includes the tokenizer's EOS id when set, and additionally the
        ``<end_of_turn>`` id when the tokenizer can resolve it to a real token,
        so generation stops at the end of the model turn.
        """
        ids: List[int] = []
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            ids.append(eos_id)
        convert = getattr(self.tokenizer, "convert_tokens_to_ids", None)
        if callable(convert):
            turn_end_id = convert("<end_of_turn>")
            unk_id = getattr(self.tokenizer, "unk_token_id", None)
            if isinstance(turn_end_id, int) and turn_end_id != unk_id and turn_end_id not in ids:
                ids.append(turn_end_id)
        return ids

    def _generate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[Any] = None, **kwargs: Any) -> ChatResult:
        """Generate one assistant reply for ``messages``.

        Args:
            messages: Conversation so far.
            stop: Optional stop strings; the reply is cut at the first match.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Per-call overrides for ``max_tokens``, ``do_sample``,
                ``temperature`` and ``top_p``. Other keys are ignored.

        Returns:
            A ``ChatResult`` holding a single ``AIMessage``.
        """
        prompt = self._format_messages(messages)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        temperature = kwargs.get("temperature", self.temperature)
        sampling = bool(kwargs.get("do_sample", self.do_sample))
        if temperature is None or temperature <= 0:
            # A non-positive temperature cannot be sampled from; decode greedily instead.
            sampling = False

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": kwargs.get("max_tokens", self.max_tokens),
            "do_sample": sampling,
        }
        if sampling:
            # Sampling parameters are only meaningful (and only accepted
            # without warnings) when sampling is enabled.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = kwargs.get("top_p", self.top_p)

        eos_ids = self._eos_token_ids()
        if eos_ids:
            gen_kwargs["eos_token_id"] = eos_ids
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None and eos_ids:
            pad_id = eos_ids[0]
        if pad_id is not None:
            gen_kwargs["pad_token_id"] = pad_id

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # The output sequence starts with the prompt tokens; decode only what
        # was generated after them instead of searching the text for a marker.
        prompt_length = inputs["input_ids"].shape[-1]
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        response = self._apply_stop(response.strip(), stop)

        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=response))])

    ### Structured output support #######################################################################################
    def _build_json_system_prompt(self, schema_text: str) -> str:
        """Build the system prompt that instructs the model to emit only JSON matching the schema."""
        # Strongly instruct the model to output JSON only (includes rules meant to curb hallucination).
        return (
            "You are a strict JSON generator.\n"
            "Return ONLY a single JSON object, no prose, no backticks, no explanations.\n"
            "Do not include trailing commas. Do not include comments.\n"
            "Conform exactly to the following JSON schema (fields, types, required):\n"
            f"{schema_text}\n"
        )

    def _ensure_pydantic(self):
        """Raise ``RuntimeError`` when the Pydantic base class is unavailable."""
        if PydanticBaseModel is None:
            raise RuntimeError(
                "Pydantic is not available. Install pydantic or pass a dict schema instead of a BaseModel."
            )

    def _schema_to_text(self, schema: Union[Type["PydanticBaseModel"], Dict[str, Any]]) -> str:
        """Serialize a Pydantic model class or a dict JSON schema to indented JSON text.

        Raises:
            TypeError: If ``schema`` is neither a Pydantic model class nor a dict.
        """
        if PydanticBaseModel is not None and isinstance(schema, type) and issubclass(schema, PydanticBaseModel):
            # Same v2-then-v1 detection as in _parse_structured.
            if hasattr(schema, "model_json_schema"):
                json_schema = schema.model_json_schema()  # pydantic v2
            else:
                json_schema = schema.schema()  # pydantic v1
            return json.dumps(json_schema, ensure_ascii=False, indent=2)
        if isinstance(schema, dict):
            return json.dumps(schema, ensure_ascii=False, indent=2)
        raise TypeError("schema must be a Pydantic BaseModel subclass or a dict JSON schema.")

    @staticmethod
    def _load_json(text: str) -> Any:
        """Parse JSON from model output, tolerating code fences and surrounding prose.

        Raises:
            json.JSONDecodeError: If no JSON value can be decoded from ``text``.
        """
        candidate = text.strip()
        if candidate.startswith("```"):
            # Drop the opening fence line (which may carry a language tag such
            # as "json") and everything from the last closing fence onward.
            newline = candidate.find("\n")
            candidate = candidate[newline + 1:] if newline != -1 else candidate[3:]
            closing = candidate.rfind("```")
            if closing != -1:
                candidate = candidate[:closing]
            candidate = candidate.strip()
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            # Fall back to decoding the first JSON value embedded in the raw
            # text, preferring an object because the prompt asks for one.
            for opener in ("{", "["):
                start = text.find(opener)
                if start != -1:
                    obj, _ = json.JSONDecoder().raw_decode(text[start:])
                    return obj
            raise

    def _parse_structured(self, text: str, schema: Union[Type["PydanticBaseModel"], Dict[str, Any]], include_raw: bool):
        """Parse model output as JSON and, for Pydantic schemas, validate it.

        Args:
            text: Raw model output.
            schema: Pydantic model class (validated) or dict schema (not validated).
            include_raw: When true, return ``{"parsed": ..., "raw": text}``.

        Returns:
            The validated model instance or parsed object, optionally wrapped
            together with the raw text.
        """
        obj = self._load_json(text)

        if PydanticBaseModel is not None and isinstance(schema, type) and issubclass(schema, PydanticBaseModel):
            validated = schema.model_validate(obj) if hasattr(schema, "model_validate") else schema.parse_obj(obj)
            return {"parsed": validated, "raw": text} if include_raw else validated
        # Dict schemas are returned without validation (jsonschema could be used if needed).
        return {"parsed": obj, "raw": text} if include_raw else obj

    def with_structured_output(
        self,
        schema: Union[Type["PydanticBaseModel"], Dict[str, Any]],
        *,
        method: Literal["json_mode"] = "json_mode",
        include_raw: bool = False,
        system_prefix: Optional[str] = None,
        deterministic: bool = True,
    ):
        """Return a runnable that makes the model answer with JSON matching ``schema``.

        Args:
            schema: Pydantic model class or dict JSON schema describing the output.
            method: Accepted for interface compatibility; only prompt-based JSON mode exists.
            include_raw: When true, results are ``{"parsed": ..., "raw": text}``.
            system_prefix: Optional text placed before the JSON instructions.
            deterministic: When true, generation uses greedy decoding.

        Returns:
            A ``RunnableLambda`` accepting a prompt value, a list of messages,
            a single message, or any other value (sent as one human message).
            Parsing is retried once with a stronger instruction; a second
            failure propagates the parsing/validation error.
        """
        schema_text = self._schema_to_text(schema)
        sys_prompt = self._build_json_system_prompt(schema_text)
        if system_prefix:
            sys_prompt = system_prefix.rstrip() + "\n\n" + sys_prompt

        def _as_messages(value: Any) -> List[BaseMessage]:
            # Prompt templates hand over a prompt value; use its messages
            # rather than its string representation.
            if hasattr(value, "to_messages"):
                return list(value.to_messages())
            if isinstance(value, BaseMessage):
                return [value]
            if isinstance(value, list) and all(isinstance(m, BaseMessage) for m in value):
                return list(value)
            # Strings, dicts and anything else become a single human message.
            return [HumanMessage(content=str(value))]

        def _invoke(messages_or_any):
            msgs = [SystemMessage(content=sys_prompt)] + _as_messages(messages_or_any)

            # Deterministic option: greedy decoding.
            kw = {"do_sample": False} if deterministic else {}

            # First attempt.
            result = self._generate(msgs, **kw)
            text = result.generations[0].message.content
            try:
                return self._parse_structured(text, schema, include_raw)
            except (ValueError, TypeError):
                # JSON decoding and Pydantic validation errors land here.
                # Retry once with a stronger instruction.
                retry_msgs = [SystemMessage(content=sys_prompt + "\nOutput must be valid JSON. Try again.")] + msgs[1:]
                result2 = self._generate(retry_msgs, **kw)
                text2 = result2.generations[0].message.content
                return self._parse_structured(text2, schema, include_raw)

        # Wrapped as a Runnable so it can be piped directly in a chain.
        return RunnableLambda(_invoke)

In [None]:
# Wrap the loaded model and tokenizer in the LangChain chat model, allowing up to 1024*5 new tokens per reply.
chat_model = GemmaChatModel(model=model, tokenizer=tokenizer, max_tokens=1024*5)

# prompt와 model 연결

In [None]:
from langchain_core.prompts import ChatPromptTemplate

In [None]:
# Chat prompt with a fixed system instruction (asks for a recipe for the dish
# the user entered) and a human message filled from the {dish} variable.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "사용자가 입력한 요리의 레시피를 생각해 주세요."),
        ("human", "{dish}"),
    ]
)

In [None]:
# Pipe the rendered prompt into the chat model.
chain = prompt | chat_model

In [None]:
# Run the chain for the dish "카레" (curry) and print the content of the returned message.
ai_message = chain.invoke({"dish": "카레"})
print(ai_message.content)

# Output Parser에 연결

In [None]:
from langchain_core.output_parsers import StrOutputParser

In [None]:
# Same chain as before, with StrOutputParser appended as the final step.
chain = prompt | chat_model | StrOutputParser()

In [None]:
# Run the chain for "카레" (curry) and print what the output parser returned.
output = chain.invoke({"dish": "카레"})
print(output)

# PydanticOutputParser 사용

In [None]:
from langchain_core.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field

In [None]:
# Schema of the structured recipe: ingredients, steps and tips, each a list of strings.
# Documented with `#` comments on purpose: a class docstring would be added to
# the JSON schema that is generated from this model.
class Recipe(BaseModel):
    ingredients: list[str] = Field(description="ingredients of the dish")
    steps: list[str] = Field(description="steps to make the dish")
    tips: list[str] = Field(description="Tips for making food delicious.")

In [None]:
# Parser configured with Recipe as its target Pydantic model.
output_parser = PydanticOutputParser(pydantic_object=Recipe)

In [None]:
# Same recipe prompt as before, with a {format_instructions} placeholder
# appended to the system message.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "사용자가 입력한 요리의 레시피를 생각해 주세요.\n\n{format_instructions}"),
        ("human", "{dish}"),
    ]
)

In [None]:
# Pre-fill {format_instructions} with the parser's instructions so that only
# {dish} has to be supplied when the chain is invoked.
prompt_with_format_instructions = prompt.partial(
    format_instructions=output_parser.get_format_instructions()
)

In [None]:
# Build a lower-temperature chat model and bind a response_format argument to it.
# NOTE(review): this rebinds the global `model`, which until now held the loaded
# Gemma model, to the bound chat runnable; running this cell a second time would
# pass that runnable in as `model=` instead of the underlying model.
# NOTE(review): GemmaChatModel._generate only reads max_tokens, do_sample,
# temperature and top_p from its keyword arguments, so the bound
# response_format does not appear to change generation; confirm intent.
model = GemmaChatModel(model=model, tokenizer=tokenizer, max_tokens=1024*5, temperature=0.3).bind(
    response_format={"type": "json_object"}
)

In [None]:
# Prompt (with format instructions) -> chat model -> Pydantic output parser.
chain = prompt_with_format_instructions | model | output_parser

In [None]:
# Run the chain for "카레" (curry), then print the type and value of the parser's result.
recipe = chain.invoke({"dish": "카레"})
print(type(recipe))
print(recipe)

# with_structured_output

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

In [None]:
# Redefines Recipe for this section with only ingredients and steps (no tips field).
# Documented with `#` comments on purpose: a class docstring would be added to
# the JSON schema that is generated from this model.
class Recipe(BaseModel):
    ingredients: list[str] = Field(description="ingredients of the dish")
    steps: list[str] = Field(description="steps to make the dish")

In [None]:
# Plain recipe prompt again (no format instructions): a fixed system
# instruction plus the user's {dish}.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "사용자가 입력한 요리의 레시피를 생각해 주세요."),
        ("human", "{dish}"),
    ]
)

In [None]:
# Pipe the prompt into the chat model's structured-output runnable for Recipe.
# `chat_model` is used rather than `model`: `model` holds the raw Gemma model
# unless the earlier rebinding cell has run, and after that cell it is a bound
# runnable rather than the chat model itself. The default structured-output
# path decodes deterministically, so the wrapper's temperature is not used.
chain = prompt | chat_model.with_structured_output(Recipe)

In [None]:
# Run the structured-output chain for "카레" (curry), then print the type and value of the result.
recipe = chain.invoke({"dish": "카레"})
print(type(recipe))
print(recipe)