In [5]:
from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain.callbacks.base import BaseCallbackHandler

#Tracking with LLm is applied
class ModelTrackingCallback(BaseCallbackHandler):
    def __init__(self):
        self.model_name = None

    def on_chat_model_start(self, serialized, messages, **kwargs):
        self.model_name = serialized.get("name") or serialized.get("model")
        
# -------------------------------------------------
# 1. Prompt (RAG / Legal safe)
# -------------------------------------------------
prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a Legal Assistant.
Use ONLY the provided DOCUMENT CONTEXT.
If the answer is not in the context, say:
"I cannot find this in our database."
Do NOT use external knowledge."""),
    ("user", "DOCUMENT CONTEXT:\n{data_context}\n\nQUESTION: {question}")
])

# -------------------------------------------------
# 2. Models (ordered by preference)
# -------------------------------------------------

# Primary (best quality)
gpt4 = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    timeout=30
)

# Fast / cheaper fallback
fast_gpt = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    timeout=20
)

# Local offline fallback (Ollama)
local_llama = ChatOllama(
    model="llama3.2",
    temperature=0,
    base_url="http://ollama:11434"  # Docker service name
)

# -------------------------------------------------
# 3. Robust model with fallbacks
# -------------------------------------------------
robust_model = local_llama.with_fallbacks([
    fast_gpt,
    gpt4
])

tracker = ModelTrackingCallback()


# -------------------------------------------------
# 4. Output parser
# -------------------------------------------------
parser = StrOutputParser()

# -------------------------------------------------
# 5. LCEL Chain
# -------------------------------------------------
chain = prompt | robust_model | parser

# -------------------------------------------------
# 6. Invoke
# -------------------------------------------------
result = chain.invoke({
    "data_context": "Doc 42: Employees are entitled to 20 days of annual leave.",
    "question": "How many days of annual leave do employees get?"
}    ,config={"callbacks": [tracker]}
)

print(result)
print("Model used:", tracker.model_name)

Employees are entitled to 20 days of annual leave.
Model used: ChatOllama
