# 기본환경 설정

In [None]:
# Read the Hugging Face token from Colab's secret store (secret name "HF_KEY").
from google.colab import userdata
HF_KEY = userdata.get("HF_KEY")

In [None]:
# Log in to the Hugging Face Hub with the token from the "HF_KEY" secret so
# that later Hub downloads are authenticated.
import huggingface_hub
huggingface_hub.login(HF_KEY)

# 모델 로딩

In [None]:
# Install Unsloth (IPython "!" shell escape: only valid inside a notebook cell).
!pip install unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

In [None]:
# Load the "unsloth/gemma-3-4b-it" checkpoint with 4-bit weights.
# NOTE(review): for Gemma 3 the returned "tokenizer" may be a multimodal
# processor rather than a plain tokenizer — TODO confirm for the installed Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-3-4b-it",
    load_in_4bit=True
)

In [None]:
# Prepare the model for generation with Unsloth's for_inference; the result is
# reassigned so any returned/wrapped model object is the one used from here on.
model = FastLanguageModel.for_inference(model)

# Custom ChatModel 클래스

In [None]:
from typing import List, Any, ClassVar
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs import ChatResult, ChatGeneration
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

In [None]:
class GemmaChatModel(BaseChatModel):
    """LangChain chat model backed by a locally loaded Gemma model.

    LangChain messages are rendered into Gemma's chat format
    (``<start_of_turn>role ... <end_of_turn>``), passed to ``model.generate``,
    and only the newly generated tokens are returned as an ``AIMessage``.

    Args:
        model: Loaded causal LM exposing ``generate`` and ``device``.
        tokenizer: Matching tokenizer, or a multimodal processor wrapping one.
        max_tokens: Default maximum number of new tokens per reply.
        do_sample: Default decoding mode; ``False`` means greedy decoding.
        temperature: Default sampling temperature (only used when sampling).
        top_p: Default nucleus-sampling threshold (only used when sampling).
    """

    def __init__(self, model, tokenizer, max_tokens: int = 512, do_sample: bool = True, temperature: float = 0.7, top_p: float = 0.9):
        super().__init__()
        # BaseChatModel is a pydantic model with no declared fields for these
        # values; object.__setattr__ stores them without pydantic validation so
        # arbitrary objects (torch model, tokenizer) can be attached.
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "tokenizer", tokenizer)
        object.__setattr__(self, "max_tokens", max_tokens)
        object.__setattr__(self, "do_sample", do_sample)
        object.__setattr__(self, "temperature", temperature)
        object.__setattr__(self, "top_p", top_p)

    @property
    def _llm_type(self) -> str:
        return "gemma-chat"

    def _text_tokenizer(self) -> Any:
        """Return the text tokenizer, unwrapping a multimodal processor if present."""
        return getattr(self.tokenizer, "tokenizer", self.tokenizer)

    def _to_turns(self, messages: List[Any]) -> List[dict]:
        """Convert LangChain messages into ``{"role", "text"}`` turns.

        All system messages are folded into one leading system turn and
        consecutive messages with the same role are merged, because Gemma's
        chat template accepts only an optional first system message followed
        by strictly alternating user/assistant turns. Message types other than
        system/human/AI are skipped.
        """
        system_texts: List[str] = []
        turns: List[dict] = []
        for message in messages:
            text = str(message.content)
            if isinstance(message, SystemMessage):
                system_texts.append(text)
                continue
            if isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                continue
            if turns and turns[-1]["role"] == role:
                turns[-1]["text"] += "\n\n" + text
            else:
                turns.append({"role": role, "text": text})
        if system_texts:
            turns.insert(0, {"role": "system", "text": "\n\n".join(system_texts)})
        return turns

    def _gemma_fallback_prompt(self, turns: List[dict]) -> str:
        """Build Gemma's turn format by hand when no chat template is available.

        Gemma has no system role, so system text is prepended to the next user
        turn (or emitted as its own user turn if no user turn follows).
        """
        bos = getattr(self._text_tokenizer(), "bos_token", None) or ""
        pieces = [bos]
        system_prefix = ""
        for turn in turns:
            if turn["role"] == "system":
                system_prefix = turn["text"] + "\n\n"
                continue
            text = turn["text"]
            if turn["role"] == "user":
                text, system_prefix = system_prefix + text, ""
            gemma_role = "user" if turn["role"] == "user" else "model"
            pieces.append(f"<start_of_turn>{gemma_role}\n{text}<end_of_turn>\n")
        if system_prefix:
            pieces.append(f"<start_of_turn>user\n{system_prefix.rstrip()}<end_of_turn>\n")
        # Open a model turn so generation continues as the assistant.
        pieces.append("<start_of_turn>model\n")
        return "".join(pieces)

    def _format_messages(self, messages: List[Any]) -> str:
        """Render LangChain messages into a prompt string ending in an open model turn.

        Prefers the tokenizer's own chat template; falls back to a hand-built
        Gemma prompt when the tokenizer has none.
        """
        turns = self._to_turns(messages)
        if getattr(self.tokenizer, "chat_template", None):
            # List-of-parts content is accepted by both Gemma 3 tokenizers and
            # processors; plain-string content can trip processor templates.
            chat = [
                {"role": turn["role"], "content": [{"type": "text", "text": turn["text"]}]}
                for turn in turns
            ]
            return self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        return self._gemma_fallback_prompt(turns)

    @staticmethod
    def _stop_token_ids(text_tokenizer: Any) -> List[int] | None:
        """Return token ids that end generation: EOS plus Gemma's ``<end_of_turn>``."""
        ids: List[int] = []
        eos_id = getattr(text_tokenizer, "eos_token_id", None)
        if eos_id is not None:
            ids.append(eos_id)
        # Instruction-tuned Gemma closes a reply with <end_of_turn>; stopping on
        # <eos> alone can let generation run on until max_new_tokens.
        convert = getattr(text_tokenizer, "convert_tokens_to_ids", None)
        if convert is not None:
            eot_id = convert("<end_of_turn>")
            unk_id = getattr(text_tokenizer, "unk_token_id", None)
            if isinstance(eot_id, int) and eot_id != unk_id and eot_id not in ids:
                ids.append(eot_id)
        return ids or None

    def _generate(self, messages: List[Any], stop: List[str] | None = None, **kwargs) -> ChatResult:
        """Generate a single assistant reply for ``messages``.

        Args:
            messages: LangChain messages (system/human/AI) forming the conversation.
            stop: Optional stop strings (passed by LangChain); the reply is cut
                at the first occurrence of any of them.
            **kwargs: Per-call overrides for ``max_tokens``, ``do_sample``,
                ``temperature`` and ``top_p``.

        Returns:
            ChatResult holding one ``AIMessage`` with the stripped reply text.
        """
        prompt = self._format_messages(messages)
        text_tokenizer = self._text_tokenizer()

        # Chat templates usually emit BOS themselves; letting the tokenizer add
        # special tokens as well would produce a duplicated BOS.
        bos = getattr(text_tokenizer, "bos_token", None)
        add_special_tokens = not (bos and prompt.startswith(bos))
        inputs = self.tokenizer(
            text=[prompt],
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        ).to(self.model.device)

        do_sample = kwargs.get("do_sample", self.do_sample)
        generate_kwargs = {
            "max_new_tokens": kwargs.get("max_tokens", self.max_tokens),
            "do_sample": do_sample,
            "eos_token_id": self._stop_token_ids(text_tokenizer),
        }
        if do_sample:
            # temperature/top_p only matter when sampling; passing them for
            # greedy decoding just triggers transformers warnings.
            generate_kwargs["temperature"] = kwargs.get("temperature", self.temperature)
            generate_kwargs["top_p"] = kwargs.get("top_p", self.top_p)

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generate_kwargs)

        # Decode only the newly generated tokens so no prompt text leaks into the reply.
        prompt_length = inputs["input_ids"].shape[-1]
        response = text_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        for stop_text in stop or []:
            if stop_text:
                response = response.split(stop_text, 1)[0]
        response = response.strip()

        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=response))])

In [None]:
# Wrap the loaded model and tokenizer as a LangChain chat model with a
# 512-token budget per reply.
chat_model = GemmaChatModel(model=model, tokenizer=tokenizer, max_tokens=512)

# Prompt 수행

In [None]:
# Single-turn smoke test: a system persona plus one user question.
# System (Korean): "You are a kind AI expert who knows machine-learning technology well."
# User (Korean): "What is LangChain?"
result = chat_model.invoke([
    SystemMessage(content="너는 친절하고 머신러닝 기술을 잘아는 전문가 AI야."),
    HumanMessage(content="LangChain은 무엇인가요?"),
])
print(result.content)

# Chat History In Memory 구성

In [None]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory

In [None]:
# Prompt template: a fixed system instruction, then the session's stored
# messages ("history"), then the new user input ("input"). Rendering it yields
# the system/human/AI message list that GemmaChatModel consumes.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant."),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{input}"),
])

In [None]:
# Chat histories keyed by session id; kept only in this process's memory.
store = {}
def get_history(session_id: str) -> BaseChatMessageHistory:
    """Return the history object for ``session_id``, creating an empty one on first use.

    The same object is returned on every later call for that id, so messages
    appended to it persist across invocations.
    """
    try:
        return store[session_id]
    except KeyError:
        fresh_history = InMemoryChatMessageHistory()
        store[session_id] = fresh_history
        return fresh_history

In [None]:
# Wrap prompt | chat_model with per-session memory. Each call looks up the
# session's history via get_history, fills the "history" placeholder with it,
# and uses the "input" key as the new user message. The session is chosen via
# config={"configurable": {"session_id": ...}}.
runnable = RunnableWithMessageHistory(
    prompt | chat_model,  # chat_model is the GemmaChatModel instance
    get_history,
    input_messages_key="input",
    history_messages_key="history"
)

# session별 Chat History In Memory 수행

# user1 session 테스트

In [None]:
# Select the session id for the first test user.
session_id = "user1"

In [None]:
# Run the conversation. Turn 1: the user introduces themself
# (Korean: "My name is Minsu.").
response1 = runnable.invoke(
    {"input": "나는 민수라고 해."},
    config={"configurable": {"session_id": session_id}}
)
print("Response 1:", response1.content)

In [None]:
# Turn 2 (Korean: "Do you remember my name?"): answering correctly requires
# the turn-1 messages stored in this session's history.
response2 = runnable.invoke(
    {"input": "내 이름 기억해?"},
    config={"configurable": {"session_id": session_id}}
)
print("Response 2:", response2.content)

In [None]:
# Print the stored history for the current session to verify what was recorded.
history = get_history(session_id)
print("\n=== Conversation History ===")
for msg in history.messages:
    # msg.type is the LangChain message type (e.g. "human", "ai"), shown upper-cased.
    print(f"{msg.type.upper()}: {msg.content}")

# user2 session 테스트

In [None]:
# Switch to a second, independent session.
session_id = "user2"

In [None]:
# Same two-turn conversation in a separate session: introduce the user
# (Korean: "My name is Cheolsu."), then ask "Do you remember my name?".
# Histories are stored per session_id, so this session does not see user1's turns.
response1 = runnable.invoke(
    {"input": "나는 철수라고 해."},
    config={"configurable": {"session_id": session_id}}
)
# Print both replies, matching the user1 section, so every turn is visible.
print("Response 1:", response1.content)
response2 = runnable.invoke(
    {"input": "내 이름 기억해?"},
    config={"configurable": {"session_id": session_id}}
)
print("Response 2:", response2.content)

In [None]:
# Print the stored history for the current session to verify what was recorded.
history = get_history(session_id)
print("\n=== Conversation History ===")
for msg in history.messages:
    # msg.type is the LangChain message type (e.g. "human", "ai"), shown upper-cased.
    print(f"{msg.type.upper()}: {msg.content}")

# user1 session에 추가 질문

In [None]:
# Switch back to the first session to check its history survived the user2 run.
session_id = "user1"

In [None]:
# Follow-up question in the user1 session (Korean: "Do you remember my name?").
# The answer should draw on user1's earlier turns, not user2's.
# NOTE: the result is stored in response2, overwriting the previous cell's value.
response2 = runnable.invoke(
    {"input": "내 이름 기억해?"},
    config={"configurable": {"session_id": session_id}}
)
print("Response 2:", response2.content)

In [None]:
# Print the stored history for the current session to verify what was recorded.
history = get_history(session_id)
print("\n=== Conversation History ===")
for msg in history.messages:
    # msg.type is the LangChain message type (e.g. "human", "ai"), shown upper-cased.
    print(f"{msg.type.upper()}: {msg.content}")