In [None]:
# %run "wall_of_imports.ipynb"
# %run "agent_state.ipynb"

In [None]:
def _parse_flight_params(raw_output: str) -> dict:
    """Best-effort conversion of the extraction chain's text reply into a dict.

    The reply is free-form model text, so it may carry Markdown code fences or
    surrounding prose in addition to the JSON object. Only the span from the
    first ``{`` to the last ``}`` is parsed.

    Args:
        raw_output: Text returned by the extraction chain.

    Returns:
        The decoded JSON object, or ``{}`` when no object is present, the span
        is not valid JSON, or the decoded value is not a dict.
    """
    start = raw_output.find("{")
    end = raw_output.rfind("}")
    if start == -1 or end < start:
        return {}
    try:
        parsed = json.loads(raw_output[start:end + 1])
    except json.JSONDecodeError:
        return {}
    # Downstream code treats flight_params as a mapping; reject anything else.
    return parsed if isinstance(parsed, dict) else {}


def flight_agent(state: AgentState) -> dict:
    """Handles flight-related questions and searches.

    Runs three steps against ``state.query``:
      1. Asks the LLM to extract flight search parameters as JSON and stores
         the parsed dict in ``state.context["flight_params"]`` (``{}`` when the
         reply cannot be parsed into a JSON object).
      2. Invokes the chain returned by ``setup_rag_chain()`` and reads its
         ``"answer"`` entry as the retrieved flight information.
      3. Asks the LLM for the final reply, given the retrieved information,
         the extracted parameters, and the original query.

    Args:
        state: Agent state; ``state.query`` is read and ``state.context`` is
            updated in place with the ``"flight_params"`` key.

    Returns:
        On success: ``{"agent_response", "context", "query"}``.
        On any exception: ``{"error", "context", "query"}`` where ``"error"``
        is ``str(exception)``; the exception is not re-raised.
    """
    try:
        # Extract flight search parameters
        extraction_prompt = ChatPromptTemplate.from_template("""Extract flight search parameters from the user's query.
            Return a JSON object with these fields (leave empty if not mentioned):
        {{
            "origin": "origin airport or city code",
            "destination": "destination airport or city code",
            "departure_date": "departure date in YYYY-MM-DD format",
            "return_date": "return date in YYYY-MM-DD format (if round-trip)",
            "num_passengers": "number of passengers",
            "cabin_class": "economy/business/first",
            "price_range": "budget constraints",
            "airline_preferences": ["preferred airlines"]
        }}
        
        Query: {input}
        """)

        extraction_chain = extraction_prompt | llm | StrOutputParser()

        # Parsing never raises: an unusable reply degrades to empty parameters
        # so the RAG lookup and final answer below still run.
        raw_extraction = extraction_chain.invoke({"input": state.query})
        state.context.update({"flight_params": _parse_flight_params(raw_extraction)})

        # Get flight information using RAG
        rag_chain = setup_rag_chain()
        retrieval_result = rag_chain.invoke({"input": state.query})

        flight_prompt = ChatPromptTemplate.from_template("""You are a flight search specialist. Provide helpful information about flights
            based on the retrieved flight data and the user's query. Include details about available flights matching the criteria, 
            price ranges and fare comparisons, airline options, departure/arrival times, travel duration, layovers (if applicable), 
            and booking recommendations.
            
            If exact flight information isn't available in the retrieved data, provide general advice
            about the requested route, typical prices, and best booking strategies.
            
            Retrieved flight information:
            {context_str}
            
            Extracted flight parameters:
            {parameters}
        
            Query: {input}
            """)

        # Format the context and parameters for the prompt
        context_str = retrieval_result.get("answer", "")
        parameters_str = json.dumps(state.context.get("flight_params", {}), indent=2)

        flight_chain = flight_prompt | llm | StrOutputParser()
        response = flight_chain.invoke({
            "input": state.query,
            "context_str": context_str,
            "parameters": parameters_str
        })

        return {
            "agent_response": response,
            "context": state.context,
            "query": state.query
        }

    except Exception as e:
        # Deliberate best-effort: report the failure in the returned payload
        # instead of propagating it to the caller.
        return {
            "error": str(e),
            "context": state.context,
            "query": state.query
        }

In [None]:
# def flight_agent(state: AgentState) -> dict:
#     """Handles flight-related questions and searches."""
#     try:
#         # Extract flight search parameters
#         extraction_prompt = ChatPromptTemplate.from_template("""Extract flight search parameters from the user's query.
#             Return a JSON object with these fields (leave empty if not mentioned):
#         {{
#             "origin": "origin airport or city code",
#             "destination": "destination airport or city code",
#             "departure_date": "departure date in YYYY-MM-DD format",
#             "return_date": "return date in YYYY-MM-DD format (if round-trip)",
#             "num_passengers": "number of passengers",
#             "cabin_class": "economy/business/first",
#             "price_range": "budget constraints",
#             "airline_preferences": ["preferred airlines"]
#         }}
        
#         Query: {input}
#         """)       
        
#         extraction_chain = extraction_prompt | llm | StrOutputParser()

#         try:
#             flight_params = json.loads(extraction_chain.invoke({"input": state.query}))  # Changed from input.query to state.query
#             state.context.update({"flight_params": flight_params})
#         except json.JSONDecodeError:
#             state.context.update({"flight_params": {}})
        
#         # Get flight information using RAG
#         rag_chain = setup_rag_chain()
#         retrieval_result = rag_chain.invoke({"input": state.query})

#         flight_prompt = ChatPromptTemplate.from_template("""You are a flight search specialist. Provide helpful information about flights
#             based on the retrieved flight data and the user's query. Include details about available flights matching the criteria, 
#             price ranges and fare comparisons, airline options, departure/arrival times, travel duration, layovers (if applicable), 
#             and booking recommendations.
            
#             If exact flight information isn't available in the retrieved data, provide general advice
#             about the requested route, typical prices, and best booking strategies.
            
#             Retrieved flight information:
#             {context_str}
            
#             Extracted flight parameters:
#             {parameters}
        
#             Query: {input}
#             """)
        
#         # Format the context and parameters for the prompt
#         context_str = retrieval_result.get("answer", "")
#         parameters_str = json.dumps(state.context.get("flight_params", {}), indent=2)
        
#         flight_chain = flight_prompt | llm | StrOutputParser()
#         response = flight_chain.invoke({
#             "input": state.query,
#             "context_str": context_str,
#             "parameters": parameters_str
#         })
        
#         return {
#             "agent_response": response,
#             "context": state.context,
#             "query": state.query
#         }
        
#     except Exception as e:
#         return {
#             "error": str(e),
#             "context": state.context,
#             "query": state.query
#         }