In [None]:
import chainlit as cl
from chainlit.playground.providers import ChatOpenAI
import json
from typing import Dict, List, Any

# Import your travel agent components and classes
from langchain.graphs.state import StateGraph, END
from langchain.schema import AgentState
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_openai import ChatOpenAI
from sentence_transformers import SentenceTransformer
from langchain_core.embeddings import Embeddings
from langchain.vectorstores import FAISS
from langchain.chains import create_stuff_documents_chain, create_retrieval_chain

# Define your embedding class
class FineTunedTravelEmbeddings(Embeddings):
    """Adapter exposing a fine-tuned SentenceTransformer through LangChain's ``Embeddings`` interface.

    Args:
        model_path: Model name or local directory handed to ``SentenceTransformer``.
        device: Optional device string forwarded to ``SentenceTransformer``.
    """

    def __init__(self, model_path="travel_assistant_embeddings", device=None):
        """Load the SentenceTransformer model and report where it was loaded from."""
        encoder = SentenceTransformer(model_path, device=device)
        self.model = encoder
        print(f"Loaded fine-tuned model from {model_path}")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding (a list of floats) per input text, encoding in batches of 32."""
        # A progress bar only pays off for larger inputs.
        verbose = len(texts) > 50
        vectors = self.model.encode(
            texts,
            convert_to_numpy=True,
            show_progress_bar=verbose,
            batch_size=32,
        )
        return vectors.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single query string as a list of floats."""
        return self.model.encode(text, convert_to_numpy=True).tolist()

# Shared chat model used by the router, every specialist agent and the RAG chain.
llm = ChatOpenAI(temperature=0.7, model="gpt-4-turbo")

# Define RAG setup
@cl.cache
def setup_rag_chain_fine_tuned():
    """Build the retrieval-augmented QA chain over the travel knowledge base.

    Prefers the fine-tuned ``travel_assistant_embeddings`` model together with its
    ``travel_db_fine_tuned`` FAISS index. If either the model or that index fails
    to load, falls back to ``OpenAIEmbeddings`` with the ``travel_db_faiss`` index.
    The result is cached by ``@cl.cache``.

    Returns:
        The chain built by ``create_retrieval_chain``; callers invoke it with
        ``{"input": question}`` and read the ``"answer"`` key.

    Raises:
        Exception: whatever loading the fallback index raises (printed first).
    """
    import os
    import pickle

    def _load_store(vector_db_path, embeddings):
        """Load a FAISS index, replacing its docstore with ``<path>_docstore.pkl`` if present."""
        # NOTE(security): allow_dangerous_deserialization=True and pickle.load both
        # execute code embedded in these files; only point them at indexes this
        # application built itself, never at user-supplied files.
        store = FAISS.load_local(
            vector_db_path,
            embeddings,
            allow_dangerous_deserialization=True
        )
        docstore_path = f"{vector_db_path}_docstore.pkl"
        if os.path.exists(docstore_path):
            with open(docstore_path, 'rb') as f:
                store.docstore = pickle.load(f)
        return store

    # Use the fine-tuned embeddings if both the model AND its index are available;
    # a missing/corrupt fine-tuned index also triggers the OpenAI fallback.
    travel_db = None
    try:
        embeddings = FineTunedTravelEmbeddings(model_path="travel_assistant_embeddings")
        travel_db = _load_store("travel_db_fine_tuned", embeddings)
    except Exception as e:
        print(f"Using default embeddings due to error: {e}")

    if travel_db is None:
        from langchain_openai import OpenAIEmbeddings
        embeddings = OpenAIEmbeddings()
        try:
            travel_db = _load_store("travel_db_faiss", embeddings)
        except Exception as e:
            print(f"Error loading vector store: {e}")
            raise

    # Counts indexed entries (chunks), which the message calls "documents".
    print(f"Loaded knowledge base with {len(travel_db.index_to_docstore_id)} documents")

    # Create a retriever
    retriever = travel_db.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 7}
    )

    rag_prompt = ChatPromptTemplate.from_template("""You are a knowledgeable travel assistant with expertise in destinations worldwide.
        Use the following travel information to provide detailed, accurate responses to the user's query.
        If the retrieved information doesn't fully answer the question, use your knowledge to provide
        the best possible response, but prioritize the retrieved information.
        
        Retrieved information: {context}        
        Question: {input}
    """)

    # "Stuff" all retrieved documents into {context} of a single LLM call.
    document_chain = create_stuff_documents_chain(llm, rag_prompt)

    return create_retrieval_chain(retriever, document_chain)

# Define agent functions
def router_agent(state: AgentState) -> dict:
    """Pick the specialist agent that should handle ``state.query``.

    Asks the LLM for one of the four agent names, normalises its reply and falls
    back to ``information_agent`` when the reply does not identify exactly one
    valid agent.

    Args:
        state: Graph state; reads every field, uses ``query`` for routing.

    Returns:
        A full state dict with ``agent_executor`` set to the chosen agent name and
        all other fields copied unchanged from ``state``.
    """
    router_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a travel assistant router. Your job is to determine which specialized agent
        should handle the user's travel-related query. Choose the most appropriate agent from:
        
        - itinerary_agent: For requests to create travel itineraries, vacation plans, or multi-day travel schedules
        - flight_agent: For questions about flights, airfares, airlines, or flight bookings
        - accommodation_agent: For questions about hotels, resorts, accommodations, or places to stay
        - information_agent: For general travel information, destination facts, or travel advice
        
        Respond ONLY with the name of the appropriate agent. Do not include any explanations or additional text.
        """),
        ("human", "{query}")
    ])

    chain = router_prompt | llm | StrOutputParser()
    reply = chain.invoke({"query": state.query})

    valid_agents = ["itinerary_agent", "flight_agent", "accommodation_agent", "information_agent"]

    # Models often decorate the bare name (quotes, backticks, trailing period,
    # capitals, a short preamble); normalise before validating so those replies
    # are not silently rerouted to information_agent.
    normalized = reply.strip().strip("`'\".").strip().lower()
    if normalized in valid_agents:
        agent_executor = normalized
    else:
        # Accept a reply that mentions exactly one valid agent name.
        mentioned = [name for name in valid_agents if name in normalized]
        agent_executor = mentioned[0] if len(mentioned) == 1 else "information_agent"

    return {"agent_executor": agent_executor,
            "query": state.query,
            "chat_history": state.chat_history,
            "agent_response": state.agent_response,
            "final_response": state.final_response,
            "context": state.context,
            "error": state.error}

def itinerary_agent(state: AgentState) -> dict:
    """Create a personalized day-by-day travel itinerary for ``state.query``.

    Makes two LLM calls. The first extracts trip details (destinations, duration,
    budget, interests, dates, travelers) as JSON. The second writes the itinerary
    from the query plus those details. Extraction is best-effort: if the reply
    cannot be parsed into a JSON object, the itinerary is still generated with
    empty detail fields.

    Args:
        state: Graph state; reads ``query`` and ``context``.

    Returns:
        ``{"agent_response": <itinerary>, "context": <context merged with the
        extracted details>}`` on success, or ``{"error": <message>}`` if any step
        raises.
    """
    def _as_text(value):
        """Render a list as a comma-separated string and anything else as plain text."""
        if isinstance(value, (list, tuple)):
            return ", ".join(str(item) for item in value if item)
        return "" if value is None else str(value)

    try:
        # Work on a copy so the incoming state's context dict is not mutated in place.
        context = dict(state.context or {})

        # Literal JSON braces are doubled ({{ }}) because ChatPromptTemplate treats
        # single braces as template variables; unescaped, the prompt demands a
        # variable that is never supplied and every call fails.
        extraction_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a travel analysis assistant. Extract key information from the user's query.
            Return a JSON object with these fields (leave empty strings if not mentioned):
            {{
                "destinations": ["list of destination cities or countries"],
                "duration": "duration of the trip mentioned (e.g., '5 days', '1 week')",
                "budget": "any budget information (e.g., 'luxury', 'budget-friendly')",
                "interests": ["list of mentioned activities or interests"],
                "travel_dates": "when they plan to travel",
                "travelers": "information about who is traveling (e.g., 'family with kids', 'solo')"
            }}
            """),
            ("human", "{query}")
        ])

        extraction_chain = extraction_prompt | llm | StrOutputParser()
        raw = extraction_chain.invoke({"query": state.query})

        # The model may wrap its JSON in prose or Markdown fences; keep only the
        # outermost {...} span before parsing.
        start, end = raw.find("{"), raw.rfind("}")
        if start != -1 and end > start:
            raw = raw[start:end + 1]
        try:
            extracted_info = json.loads(raw)
        except json.JSONDecodeError:
            extracted_info = None  # best-effort: continue without extracted details
        # Only a JSON object can be merged; a list or scalar would break update().
        if isinstance(extracted_info, dict):
            context.update(extracted_info)

        # Generate the itinerary
        itinerary_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a travel itinerary expert who creates detailed, personalized travel plans.
            Create a comprehensive day-by-day itinerary for the user's request, including:
            
            - Appropriate pacing with realistic timing for activities
            - A variety of activities tailored to the mentioned interests
            - Specific recommendations for attractions, restaurants, and transportation
            - Practical travel tips related to the destination(s)
            - Brief contextual information about key attractions
            
            If certain details (like budget, dates, or interests) aren't specified, make balanced recommendations
            that would appeal to most travelers. Format the response with clear headings and organize by days.
            """),
            ("human", """{query}
            
            Travel Details:
            - Destinations: {destinations}
            - Duration: {duration}
            - When: {travel_dates}
            - Budget: {budget}
            - Interests: {interests}
            - Travelers: {travelers}
            """)
        ])

        itinerary_chain = itinerary_prompt | llm | StrOutputParser()

        # Fields may come back as a list or a bare string; _as_text avoids
        # ", ".join("Paris") splitting a string into characters.
        response = itinerary_chain.invoke({
            "query": state.query,
            "destinations": _as_text(context.get("destinations")),
            "duration": _as_text(context.get("duration")),
            "budget": _as_text(context.get("budget")),
            "interests": _as_text(context.get("interests")),
            "travel_dates": _as_text(context.get("travel_dates")),
            "travelers": _as_text(context.get("travelers"))
        })

        return {"agent_response": response, "context": context}

    except Exception as e:
        return {"error": str(e)}

def flight_agent(state: AgentState) -> dict:
    """Answer a flight question using extracted search parameters plus RAG.

    Makes three LLM-backed calls:
    1. Extract flight search parameters from ``state.query`` as JSON (best-effort).
    2. Query the RAG chain with the raw query.
    3. Write the final flight advice from the RAG answer and the parameters.

    Args:
        state: Graph state; reads ``query`` and ``context``.

    Returns:
        ``{"agent_response": <advice>, "context": <context with "flight_params">}``
        on success, or ``{"error": <message>}`` if any step raises.
    """
    try:
        # Build a new context dict and return it as a state update; assigning
        # state.context directly would not be recorded by the graph.
        context = dict(state.context or {})

        # Literal JSON braces are doubled ({{ }}) because ChatPromptTemplate treats
        # single braces as template variables.
        extraction_prompt = ChatPromptTemplate.from_messages([
            ("system", """Extract flight search parameters from the user's query.
            Return a JSON object with these fields (leave empty if not mentioned):
            {{
                "origin": "departure city or airport",
                "destination": "arrival city or airport",
                "departure_date": "in YYYY-MM-DD format if possible",
                "return_date": "in YYYY-MM-DD format if possible",
                "passengers": "number of passengers",
                "class": "economy, business, or first class",
                "preferences": ["non-stop", "specific airline", etc.]
            }}
            """),
            ("human", "{query}")
        ])

        extraction_chain = extraction_prompt | llm | StrOutputParser()
        raw = extraction_chain.invoke({"query": state.query})

        # The model may wrap its JSON in prose or Markdown fences; keep only the
        # outermost {...} span before parsing.
        start, end = raw.find("{"), raw.rfind("}")
        if start != -1 and end > start:
            raw = raw[start:end + 1]
        try:
            flight_params = json.loads(raw)
        except json.JSONDecodeError:
            flight_params = None  # best-effort: answer without structured params
        if isinstance(flight_params, dict):
            context["flight_params"] = flight_params

        # Get flight information using RAG
        rag_chain = setup_rag_chain_fine_tuned()
        retrieval_result = rag_chain.invoke({"input": state.query})

        # Generate response
        flight_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a flight expert. Provide detailed flight information and recommendations
            based on the user's query and the retrieved data. Include:
            
            - Flight options including airlines and typical schedules
            - Estimated price ranges
            - Booking recommendations and timing advice
            - Airport tips and information
            - Seasonal considerations for the route
            
            If specific flight data isn't available, provide general advice about flights for the
            requested route, typical costs, best booking times, and airlines that service the route.
            
            Retrieved flight information:
            {context}
            
            Extracted flight parameters:
            {flight_params}
            """),
            ("human", "{query}")
        ])

        flight_chain = flight_prompt | llm | StrOutputParser()
        response = flight_chain.invoke({
            "query": state.query,
            # The RAG chain's generated answer (not the raw documents) is the context.
            "context": retrieval_result.get("answer", ""),
            "flight_params": json.dumps(context.get("flight_params", {}), indent=2)
        })

        return {"agent_response": response, "context": context}

    except Exception as e:
        return {"error": str(e)}

def accommodation_agent(state: AgentState) -> dict:
    """Recommend accommodation using extracted preferences plus RAG.

    Makes three LLM-backed calls:
    1. Extract accommodation preferences from ``state.query`` as JSON (best-effort).
    2. Query the RAG chain for hotels in the extracted location, falling back to
       the raw query when no location was found.
    3. Write the final recommendations.

    Args:
        state: Graph state; reads ``query`` and ``context``.

    Returns:
        ``{"agent_response": <advice>, "context": <context with
        "accommodation_params">}`` on success, or ``{"error": <message>}`` if any
        step raises.
    """
    try:
        # Build a new context dict and return it as a state update; assigning
        # state.context directly would not be recorded by the graph.
        context = dict(state.context or {})

        # Literal JSON braces are doubled ({{ }}) because ChatPromptTemplate treats
        # single braces as template variables.
        extraction_prompt = ChatPromptTemplate.from_messages([
            ("system", """Extract accommodation preferences from the user's query.
            Return a JSON object with these fields (leave empty if not mentioned):
            {{
                "location": "city or specific area",
                "check_in_date": "in YYYY-MM-DD format",
                "check_out_date": "in YYYY-MM-DD format",
                "guests": "number of guests",
                "rooms": "number of rooms",
                "budget_range": "price range per night"
            }}
            """),
            ("human", "{query}")
        ])

        extraction_chain = extraction_prompt | llm | StrOutputParser()
        raw = extraction_chain.invoke({"query": state.query})

        # The model may wrap its JSON in prose or Markdown fences; keep only the
        # outermost {...} span before parsing.
        start, end = raw.find("{"), raw.rfind("}")
        if start != -1 and end > start:
            raw = raw[start:end + 1]
        try:
            accommodation_params = json.loads(raw)
        except json.JSONDecodeError:
            accommodation_params = None  # best-effort: answer without structured params
        if isinstance(accommodation_params, dict):
            context["accommodation_params"] = accommodation_params

        params = context.get("accommodation_params")
        if not isinstance(params, dict):
            params = {}
        location = str(params.get("location") or "").strip()
        # Without a location, "hotels in " is a near-empty search; retrieve with
        # the user's own wording instead.
        retrieval_query = f"hotels in {location}" if location else state.query

        # Get accommodation information using RAG
        rag_chain = setup_rag_chain_fine_tuned()
        retrieval_result = rag_chain.invoke({"input": retrieval_query})

        # Generate response
        accommodation_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a hotel and accommodation expert. Provide detailed recommendations
            based on the user's preferences and the retrieved accommodation data. Include:
            
            - Suitable hotel/accommodation options
            - Price ranges and value considerations
            - Location benefits and proximity to attractions
            - Amenities and facilities
            - Guest ratings and reviews summary
            - Booking tips and optimal timing
            
            If specific accommodation data isn't available, provide general advice about
            accommodations in the requested location, typical options at different price points,
            and best areas to stay.
            
            Retrieved accommodation information:
            {context}
            
            Extracted accommodation parameters:
            {accommodation_params}
            """),
            ("human", "{query}")
        ])

        accommodation_chain = accommodation_prompt | llm | StrOutputParser()
        response = accommodation_chain.invoke({
            "query": state.query,
            # The RAG chain's generated answer (not the raw documents) is the context.
            "context": retrieval_result.get("answer", ""),
            "accommodation_params": json.dumps(context.get("accommodation_params", {}), indent=2)
        })

        return {"agent_response": response, "context": context}

    except Exception as e:
        return {"error": str(e)}

def information_agent(state: AgentState) -> dict:
    """Answer a general travel question with retrieval followed by an LLM refinement pass.

    The RAG chain's answer for ``state.query`` is handed to a second prompt. That
    prompt fills gaps from general knowledge while keeping retrieved facts
    distinguishable.

    Returns:
        ``{"agent_response": <answer>}`` on success, or ``{"error": <message>}``
        if any step raises.
    """
    try:
        retrieved = setup_rag_chain_fine_tuned().invoke({"input": state.query})

        specialist_brief = """You are a knowledgeable travel information specialist. Review and enhance
            the retrieved information to provide a comprehensive, accurate response to the user's query.
            
            If the retrieved information is incomplete, add relevant details from your knowledge while
            clearly distinguishing between retrieved facts and general knowledge.
            
            Focus on providing practical, useful information that directly addresses the user's needs.
            Include cultural insights, traveler tips, and seasonal considerations when relevant.
            
            Retrieved information:
            {rag_response}
            """
        refine = (
            ChatPromptTemplate.from_messages([("system", specialist_brief), ("human", "{input}")])
            | llm
            | StrOutputParser()
        )
        answer = refine.invoke({
            "input": state.query,
            "rag_response": retrieved.get("answer", ""),
        })
    except Exception as e:
        return {"error": str(e)}

    return {"agent_response": answer}

def generate_final_response(state: AgentState) -> dict:
    """Polish the specialist agent's answer into the reply shown to the user.

    ``state.agent_response`` and ``state.query`` go through a formatting prompt.
    The prompt asks for a warm tone, headings, bullets, bold text, light emoji
    use and a friendly closing.

    Returns:
        ``{"agent_response": <formatted reply>}``. The polished text replaces the
        specialist's raw answer under the same key.
    """
    style_guide = """You are a friendly, helpful travel assistant. Format the specialized agent's response
        into a clear, well-structured, and engaging reply. Maintain all the factual information and advice
        while improving readability with:
        
        - A warm, conversational tone
        - Logical organization with headings where appropriate
        - Bullet points for lists
        - Bold text for important information
        - Emojis where appropriate (but not excessive)
        
        Make sure the response completely addresses the user's query. Add a brief, friendly closing
        that invites further questions.
        
        Original agent response:
        {agent_response}
        """
    polish = (
        ChatPromptTemplate.from_messages([("system", style_guide), ("human", "{query}")])
        | llm
        | StrOutputParser()
    )
    polished = polish.invoke({
        "query": state.query,
        "agent_response": state.agent_response,
    })
    return {"agent_response": polished}

def handle_error(state: AgentState) -> dict:
    """Produce a graceful fallback reply when a specialist agent reported an error.

    The LLM receives the error text, or a generic placeholder when ``state.error``
    is empty. It is asked to acknowledge the problem, give general travel advice
    and suggest a rephrasing.

    Returns:
        ``{"agent_response": <fallback reply>}``.
    """
    apology_brief = """You are a helpful travel assistant. The system encountered an error while
        processing the user's query. Provide a helpful response that:
        
        1. Acknowledges the issue
        2. Offers general travel advice related to their query
        3. Suggests how they might rephrase their question for better results
        
        Error message: {error}
        """
    apology_chain = (
        ChatPromptTemplate.from_messages([("system", apology_brief), ("human", "{query}")])
        | llm
        | StrOutputParser()
    )
    error_text = state.error or "Unknown error occurred"
    reply = apology_chain.invoke({"query": state.query, "error": error_text})
    return {"agent_response": reply}

def create_travel_assistant_graph():
    """Build and compile the LangGraph workflow for the travel assistant.

    Flow:
    1. ``router``.
    2. One specialist agent, chosen from ``state.agent_executor``.
    3. ``response_generator`` on success, or ``error_handler`` when
       ``state.error`` is set.
    4. END.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    graph = StateGraph(AgentState)

    # Specialist nodes, in registration order.
    specialists = {
        "itinerary_agent": itinerary_agent,
        "flight_agent": flight_agent,
        "accommodation_agent": accommodation_agent,
        "information_agent": information_agent,
    }

    graph.add_node("router", router_agent)
    for node_name, node_fn in specialists.items():
        graph.add_node(node_name, node_fn)
    graph.add_node("response_generator", generate_final_response)
    graph.add_node("error_handler", handle_error)

    def pick_specialist(state):
        # Anything other than the three named specialists goes to information_agent.
        choice = state.agent_executor
        if choice in ("itinerary_agent", "flight_agent", "accommodation_agent"):
            return choice
        return "information_agent"

    def after_specialist(state):
        # A recorded error short-circuits to the error handler.
        return "response_generator" if state.error is None else "error_handler"

    graph.set_entry_point("router")
    graph.add_conditional_edges("router", pick_specialist)
    for node_name in specialists:
        graph.add_conditional_edges(node_name, after_specialist)

    graph.add_edge("response_generator", END)
    graph.add_edge("error_handler", END)

    return graph.compile()

# Initialize the travel assistant: compile the LangGraph workflow once at import
# time. on_chat_start stores this same compiled object in every user session.
travel_assistant = create_travel_assistant_graph()

# Chainlit code
@cl.on_chat_start
async def on_chat_start():
    """Greet the user, then store the shared compiled graph in the user session."""
    welcome = cl.Message(content="👋 Hi there! I'm your AI Travel Assistant. Whether you need help planning an itinerary, finding flights, booking accommodations, or just want travel information, I'm here to help. What can I assist you with today?")
    await welcome.send()
    cl.user_session.set("travel_assistant", travel_assistant)

@cl.on_message
async def on_message(message: cl.Message):
    """Answer one chat message with the travel-assistant graph.

    Steps:
    1. Show a "Thinking..." placeholder.
    2. Run the graph on the message text plus the exchanges already recorded in
       the session's ``message_history``.
    3. Append the new user/assistant pair to that history.
    4. Replace the placeholder with the reply.

    On any exception the traceback is printed to stdout and the user sees a
    generic apology instead.
    """
    # Fall back to the module-level graph if on_chat_start did not store one.
    assistant = cl.user_session.get("travel_assistant") or travel_assistant

    # Show thinking message
    thinking_msg = cl.Message(content="Thinking...", author="Travel Assistant")
    await thinking_msg.send()

    try:
        # Chainlit's user_session only offers get()/set(); it does not support
        # `in`, so create the history list via get() plus a None check.
        history = cl.user_session.get("message_history")
        if history is None:
            history = []
            cl.user_session.set("message_history", history)

        # Anything not authored by the user is replayed as an assistant turn.
        chat_history = [
            {"role": "user" if msg["role"] == "user" else "assistant", "content": msg["content"]}
            for msg in history
        ]

        # graph.invoke is synchronous and makes several blocking LLM calls; run it
        # in a worker thread so the event loop keeps serving other sessions.
        result = await cl.make_async(assistant.invoke)({
            "query": message.content,
            "chat_history": chat_history
        })
        reply = result["agent_response"]

        # Store this exchange in chat history (same list object held by the session).
        history.append({"role": "user", "content": message.content})
        history.append({"role": "assistant", "content": reply})

        # Message.update() takes no content argument; set the field, then push it.
        thinking_msg.content = reply
        await thinking_msg.update()

    except Exception as e:
        import traceback
        error_msg = f"Error processing your request: {str(e)}\n\n{traceback.format_exc()}"
        thinking_msg.content = "I encountered an error while processing your request. Please try again with a more specific travel question."
        await thinking_msg.update()
        print(error_msg)