<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/FP_AGENTIC_T2SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!nvidia-smi

Mon Jun 16 07:36:14 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   42C    P8             12W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Install the packages used by this notebook (LangChain, CrewAI, Transformers, colab-env, Unsloth, PyTorch).
# Run once per fresh Colab runtime; `-q` keeps pip's output quiet.
!pip install -U langchain-community -q
!pip install -U crewai -q
!pip install 'crewai [tools]' -q
!pip install transformers -U -q
!pip install colab-env -q
!pip install unsloth -q
!pip install torch -q

In [4]:
# Standard library
import datetime
import os
import re
import traceback
import warnings
from typing import Any, List, Dict, Optional

import torch

# LangChain core abstractions used by the custom chat-model wrapper
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langchain_core.outputs import ChatResult, ChatGeneration

# Prompt template and chain used for SQL generation
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Unsloth and Transformers imports for model loading (keep unsloth ahead of transformers)
from unsloth import FastLanguageModel
from transformers import pipeline, AutoConfig
from langchain.tools import BaseTool  # base class for the SQL / flight-optimizer tools


In [None]:
# 1. Custom LLM Wrapper (UnslothCrewAILLM)
# Adapts a locally loaded transformers / Unsloth model to LangChain's BaseChatModel
# interface so it can drive LangChain chains such as LLMChain.
class UnslothCrewAILLM(BaseChatModel):
    """LangChain chat model backed by a local Hugging Face / Unsloth model.

    Generation goes through ``pipeline`` when one is supplied, otherwise through
    ``model.generate`` with ``tokenizer``. Only the content of the LAST message is
    used as the prompt; earlier messages (e.g. a system message) are ignored.

    ``stop`` sequences are honoured by truncating the decoded text at the earliest
    occurrence of any stop string (both generation paths).

    Generation failures are not raised: they are printed and returned as the AI
    message text, prefixed with ``"Error generating response: "``.
    """

    model: Any
    tokenizer: Any
    pipeline: Any = None
    max_new_tokens: int = 1024
    temperature: float = 0.1  # only forwarded to generation when do_sample is True
    do_sample: bool = False
    trust_remote_code: bool = True  # stored for reference; not used by the generation code below

    def __init__(self, model, tokenizer, pipeline=None, max_new_tokens=1024, temperature=0.1, do_sample=False, trust_remote_code=True):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            pipeline=pipeline,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            trust_remote_code=trust_remote_code,
        )

        # Tokenizers without a pad token cannot pad for generate(); fall back to EOS.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def _generation_kwargs(self) -> Dict[str, Any]:
        """Decoding kwargs shared by the pipeline and manual generation paths.

        ``temperature`` is only included when sampling: with ``do_sample=False``
        transformers ignores it and warns that the flag is not valid.
        """
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.do_sample,
        }
        if self.do_sample:
            gen_kwargs["temperature"] = self.temperature
        return gen_kwargs

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut ``text`` at the earliest occurrence of any non-empty stop string."""
        if not stop:
            return text
        hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
        return text[:min(hits)] if hits else text

    def _generate_with_pipeline(self, prompt_text: str, stop: Optional[List[str]]) -> str:
        """Generate with the transformers text-generation pipeline; returns text or an error string."""
        try:
            response = self.pipeline(
                prompt_text,
                num_return_sequences=1,
                return_full_text=False,
                **self._generation_kwargs(),
            )
            raw_text = response[0].get('generated_text', '') if response else ""
            return self._truncate_at_stop(raw_text.strip(), stop).strip()
        except Exception as e:
            print(f"Error during pipeline generation in wrapper: {e}")
            return f"Error generating response: {e}"

    def _generate_with_model(self, prompt_text: str, stop: Optional[List[str]]) -> str:
        """Tokenize, call ``model.generate`` and decode only the new tokens; returns text or an error string."""
        try:
            max_input_length = getattr(self.tokenizer, 'model_max_length', self.max_new_tokens)
            inputs = self.tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=max_input_length)

            # The tokenizer returns CPU tensors; generate() needs them on the model's device (e.g. CUDA).
            model_device = getattr(self.model, "device", None)
            if model_device is not None:
                inputs = inputs.to(model_device)

            if self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            prompt_length = inputs["input_ids"].shape[1]
            # NOTE: `stop` is a list of strings, not a StoppingCriteriaList, so it is NOT passed
            # to generate(); it is applied to the decoded text instead.
            outputs = self.model.generate(
                **inputs,
                pad_token_id=self.tokenizer.pad_token_id,
                **self._generation_kwargs(),
            )

            # Decoder-only models echo the prompt before the new tokens: drop it. When nothing
            # new was generated this yields an empty slice instead of returning the prompt.
            if outputs.shape[1] >= prompt_length:
                generated_ids = outputs[0, prompt_length:]
            else:
                generated_ids = outputs[0]

            decoded = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            return self._truncate_at_stop(decoded.strip(), stop).strip()
        except Exception as e:
            print(f"Error during manual generation in wrapper: {e}")
            traceback.print_exc()
            return f"Error generating response: {e}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a single AI message for ``messages``.

        Raises:
            ValueError: if ``messages`` is empty.
        """
        if not messages:
            raise ValueError("No messages provided to the LLM wrapper.")

        # LLMChain sends a single HumanMessage; the last message is treated as the full prompt.
        prompt_text = messages[-1].content

        if self.pipeline:
            generated_text = self._generate_with_pipeline(prompt_text, stop)
        elif self.model and self.tokenizer:
            generated_text = self._generate_with_model(prompt_text, stop)
        else:
            generated_text = "Error: Model or pipeline not loaded in wrapper."

        return ChatResult(generations=[ChatGeneration(message=AIMessage(content=generated_text))])

    @property
    def _llm_type(self) -> str:
        return "unsloth_transformer_wrapper"

    def supports_stop_words(self) -> bool:
        """Returns whether the model supports stop words."""
        return True

    @property
    def supports_control_chars(self) -> bool:
        """Returns whether the model supports control characters."""
        return False

    # Streaming: _stream/_astream are intentionally NOT overridden. When they are left
    # unimplemented, BaseChatModel.stream()/astream() fall back to invoke()/ainvoke()
    # and yield the whole response as one chunk, instead of raising.

    # Convenience helpers kept for backward compatibility; LangChain's public
    # invoke()/batch() route through _generate()/_agenerate() instead.
    def _invoke(self, prompt: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any):
        """Return the AIMessage produced by ``_generate`` for ``prompt``."""
        return self._generate(prompt, stop=stop, run_manager=run_manager, **kwargs).generations[0].message

    def _batch(self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any):
        """Run ``_generate`` sequentially for each message list."""
        return [self._generate(msgs, stop=stop, run_manager=run_manager, **kwargs) for msgs in messages]

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async entry point; calls the synchronous ``_generate`` directly (blocks the event loop while generating)."""
        return self._generate(messages, stop, run_manager, **kwargs)

    async def _ainvoke(self, prompt: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any):
        """Async counterpart of ``_invoke``."""
        return (await self._agenerate(prompt, stop=stop, run_manager=run_manager, **kwargs)).generations[0].message

    async def _abatch(self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any):
        """Async counterpart of ``_batch`` (gathers one ``_agenerate`` per message list)."""
        import asyncio
        return await asyncio.gather(*[self._agenerate(msgs, stop=stop, run_manager=run_manager, **kwargs) for msgs in messages])

# 2. Database Schema Definition for Flight Planning
# Maps each table name to its ordered column list. Key and column order are kept
# stable because the dict's repr is what gets pasted into the SQL-generation prompt.
db_schema = dict(
    tables=dict(
        flights=[
            'flight_id', 'departure_airport', 'arrival_airport', 'departure_time',
            'arrival_time', 'aircraft_type', 'status', 'price',
        ],
        airports=['airport_code', 'airport_name', 'city', 'country'],
        passengers=['passenger_id', 'first_name', 'last_name', 'email'],
        bookings=['booking_id', 'flight_id', 'passenger_id', 'booking_date', 'seat_number'],
    )
)

# Plain-text (Python repr) form of the schema, injected as {db_schema} in the prompt.
db_schema_string_for_prompt = repr(db_schema)

# 3. Model Loading (using the model from the reference)
fine_tuned_model_id = "frankmorales2020/deepseek_r1_text2sql_finetuned"
max_seq_length = 2048
load_in_4bit = True

# Decoding settings shared by the HF pipeline and the LangChain wrapper below,
# defined once so the two cannot drift apart.
generation_max_new_tokens = 1024
generation_temperature = 0.1
generation_do_sample = False  # greedy decoding; temperature only matters when sampling

print(f"\n--- Attempting Direct LLM Loading for {fine_tuned_model_id} using Unsloth ---")

# Prefer bfloat16 on GPUs that support it, otherwise float16.
unsloth_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

# Reset first so a failed (re)load never leaves a stale object from a previous run in scope.
model = None
tokenizer = None
unsloth_wrapper_pipeline = None
llm_instance = None  # the chain cell checks this for None before building the LLMChain

try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=fine_tuned_model_id,
            max_seq_length=max_seq_length,
            dtype=unsloth_dtype,
            load_in_4bit=load_in_4bit,
            trust_remote_code=True,
        )
        print("Model and Tokenizer loaded successfully using Unsloth.")

    # Only pass temperature when sampling: with do_sample=False transformers ignores it
    # and emits a "generation flags are not valid" warning.
    pipeline_generation_kwargs = {
        "max_new_tokens": generation_max_new_tokens,
        "do_sample": generation_do_sample,
    }
    if generation_do_sample:
        pipeline_generation_kwargs["temperature"] = generation_temperature

    try:
        # Optional: the wrapper falls back to model.generate() when no pipeline is supplied.
        unsloth_wrapper_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            pad_token_id=tokenizer.eos_token_id,
            return_full_text=False,
            **pipeline_generation_kwargs,
        )
        print("Text generation pipeline created.")
    except Exception as e:
        print(f"Warning: Could not create transformers pipeline: {e}. Falling back to manual generation.")
        unsloth_wrapper_pipeline = None

    llm_instance = UnslothCrewAILLM(
        model=model,
        tokenizer=tokenizer,
        pipeline=unsloth_wrapper_pipeline,  # may be None -> manual generation path
        max_new_tokens=generation_max_new_tokens,
        temperature=generation_temperature,
        do_sample=generation_do_sample,
        trust_remote_code=True,
    )
    print("UnslothCrewAILLM instance created.")

except ImportError as e:
    print(f"\n--- Skipping model loading: Unsloth or necessary libraries not installed, or compatible GPU not found ({e}). ---")
    print("Please ensure you have 'unsloth' and 'torch' installed and a compatible GPU/CUDA setup.")
except Exception as e:
    print(f"\n--- An error occurred during model loading (Unsloth): {e} ---")
    traceback.print_exc()

# 4. Define the SQL Query Executor Tool (as a Langchain BaseTool)
class SQLQueryExecutorTool(BaseTool):
    """Simulated SQL executor: screens a query and returns canned results.

    No database is contacted. Queries containing ``DROP TABLE`` / ``DELETE FROM`` are
    blocked, non-SELECT statements are rejected, a JFK->LAX flights query returns a
    fixed sample row, and any other SELECT ... FROM query returns a generic success text.
    """

    name: str = "SQL Query Executor"
    description: str = "Executes a given SQL query against the flight database and returns the results."

    def _run(self, query: str) -> str:
        """Return the simulated execution result (or an error string) for ``query``."""
        print(f"\n--- Attempting to execute SQL query: \n{query}\n---")
        upper_query = query.upper()

        # Simple validation/simulation
        if "DROP TABLE" in upper_query or "DELETE FROM" in upper_query:
            return "Error: Harmful SQL query detected and blocked for safety."

        # The table check is case- and whitespace-insensitive (e.g. "FROM\n  flights");
        # the previous mixed-case needle against a lower-cased query could never match.
        is_flights_query = re.search(r"\bfrom\s+flights\b", query, re.IGNORECASE) is not None
        if "SELECT" in upper_query and is_flights_query and "'JFK'" in upper_query and "'LAX'" in upper_query:
            return "SQL executed successfully. Sample results for flight query: [{'flight_id': 101, 'departure_airport': 'JFK', 'arrival_airport': 'LAX', 'departure_time': '2025-07-01 10:00:00', 'price': 250.00}]"

        if not query.strip().lower().startswith("select"):
            return "Error: Only SELECT queries are supported by this tool for safety and simplicity in this simulation."

        if "SELECT" in upper_query and "FROM" in upper_query:
            return "SQL executed successfully. (Simulated) No specific results available for this general query."
        return "Error: Invalid or unexecutable SQL query format (simulated error)."

sql_executor_tool = SQLQueryExecutorTool()
print("\nSQL Query Executor Tool defined.")

# 5. Define the Flight Optimizer Tool (NEW)
# Criteria tokens the optimizer recognises; anything else falls back to "DEFAULT".
_KNOWN_OPTIMIZATION_CRITERIA = ("FUEL_EFFICIENT", "SHORTEST_TIME", "LOW_TURBULENCE", "OPTIMAL")


def _normalize_optimization_criteria(token: str) -> str:
    """Canonicalise a criteria token, e.g. 'most fuel-efficient' / 'MOST_FUEL-EFFICIENT' -> 'FUEL_EFFICIENT'."""
    normalized = re.sub(r"[\s\-]+", "_", token.strip().upper())
    if normalized.startswith("MOST_"):
        normalized = normalized[len("MOST_"):]
    return normalized


class FlightOptimizerTool(BaseTool):
    """Simulated route optimizer that returns a canned, human-readable flight plan.

    Input: ``"ORIGIN,DESTINATION[,CRITERIA][,YYYY-MM-DD]"``. Fields are upper-cased and
    surrounding quotes are dropped. CRITERIA spelling variants (spaces, hyphens, a
    leading "MOST") are normalised; unknown criteria fall back to ``DEFAULT``. The date
    defaults to the literal ``TODAY``. Parsing problems are returned as an error string,
    never raised.
    """

    name: str = "Flight Optimizer"
    description: str = "Optimizes a flight route between two airports (e.g., shortest, most efficient, considering real-time weather and air traffic) and returns a simulated optimal flight plan including estimated duration, fuel efficiency, and potential waypoints. Input should be two comma-separated airport codes and optionally optimization criteria (e.g., 'JFK,LAX,fuel_efficient,2025-07-01') or just 'JFK,LAX'."

    def _run(self, optimization_params: str) -> str:
        """Return the simulated optimized plan for ``optimization_params`` (or an error string)."""
        print(f"\n--- Attempting to optimize route for: {optimization_params} ---")
        try:
            # Upper-case each field and drop quotes an LLM may wrap codes in (e.g. "'YUL'").
            parts = [p.strip().strip("'\"").upper() for p in optimization_params.split(',')]
            origin = parts[0]
            destination = parts[1]  # IndexError for single-field input -> reported below
            if not origin or not destination:
                raise ValueError("origin and destination airport codes are required")
            optimization_criteria = "DEFAULT"  # Default to a general optimization
            date_str = "TODAY"

            for part in parts[2:]:
                if re.match(r"\d{4}-\d{2}-\d{2}", part):  # Simple YYYY-MM-DD date detection
                    date_str = part
                    continue
                criteria = _normalize_optimization_criteria(part)
                if criteria in _KNOWN_OPTIMIZATION_CRITERIA:
                    optimization_criteria = criteria

            # Simulated plans for YUL-ZSPD follow the case studies in
            # 'AI-Driven Flight Path Optimization'; other routes get generic plans.
            if origin == 'YUL' and destination == 'ZSPD':
                if optimization_criteria == "FUEL_EFFICIENT":
                    return (
                        f"**Optimized Flight Plan: {origin} to {destination} (Fuel Efficient)**\n"
                        f"  Date: {date_str}\n"
                        f"  Route: {origin} -> (AI Predicted Waypoint: 46.6180°N, -74.1754°W near Saint-Michel-des-Saints) -> {destination}\n"
                        f"  Estimated Flight Time: 12h 57m (aligned with AI-based Linear Regression model from study)\n"
                        f"  Estimated Fuel Consumption: 77340.77 kg (Excellent efficiency)\n"
                        f"  Considerations: Leveraging favorable wind patterns, avoiding major air traffic zones. (Source: 'AI-Driven Flight Path Optimization' Case Study 3)"
                    )
                elif optimization_criteria == "SHORTEST_TIME":
                    return (
                        f"**Optimized Flight Plan: {origin} to {destination} (Shortest Time)**\n"
                        f"  Date: {date_str}\n"
                        f"  Route: {origin} -> (Hypothetical Advanced AI Waypoint for speed) -> {destination}\n"
                        f"  Estimated Flight Time: ~12h 00m (Hypothetically even better with advanced AI and real-time data)\n"
                        f"  Estimated Fuel Consumption: Very Low\n"
                        f"  Considerations: Optimal altitudes for speed, dynamic adaptation to real-time weather and air traffic for minimal delays. (Source: 'AI-Driven Flight Path Optimization' Case Study 4)"
                    )
                elif optimization_criteria == "LOW_TURBULENCE":
                    return (
                        f"**Optimized Flight Plan: {origin} to {destination} (Low Turbulence)**\n"
                        f"  Date: {date_str}\n"
                        f"  Route: {origin} -> (Optimized for smoother air, e.g., slightly longer path to avoid storms) -> {destination}\n"
                        f"  Estimated Flight Time: ~14h 30m\n"
                        f"  Estimated Fuel Consumption: Moderate\n"
                        f"  Considerations: Proactive avoidance of known turbulence zones based on forecast data."
                    )
                else:  # DEFAULT / OPTIMAL: general plan with a Vancouver layover
                    return (
                        f"**Optimized Flight Plan: {origin} to {destination} (General Optimization)**\n"
                        f"  Date: {date_str}\n"
                        f"  Route: {origin} -> Vancouver (Layover: 49.2827°N, -123.1207°W) -> {destination}\n"
                        f"  Estimated Flight Time: 14h 08m (aligned with Vancouver layover study)\n"
                        f"  Estimated Fuel Consumption: 86644.55 kg (Good efficiency due to jet streams)\n"
                        f"  Considerations: Balancing time and fuel, potential benefits of layovers for advantageous conditions. (Source: 'AI-Driven Flight Path Optimization' Case Study 2)"
                    )
            # General JFK to LAX case
            elif origin == 'JFK' and destination == 'LAX':
                return f"**Optimized Flight Plan: {origin} to {destination}**\n  Date: {date_str}\n  Optimal route: JFK -> DEN -> LAX. Estimated duration: 5h 15m, Fuel efficiency: Very High. Weather conditions: Clear. Current air traffic: Moderate."
            else:
                return f"**Optimized Flight Plan: {origin} to {destination}**\n  Date: {date_str}\n  Direct route available. Estimated duration: ~4h, Fuel efficiency: Normal. Real-time data considered: Basic."

        except Exception as e:
            return f"Error optimizing route: Invalid input format. Please use 'ORIGIN,DESTINATION' or 'ORIGIN,DESTINATION,CRITERIA,YYYY-MM-DD'. Error: {e}"

flight_optimizer_tool = FlightOptimizerTool()
print("\nFlight Optimizer Tool defined.")


# 6. Prompt used by the SQL-generation chain.
# Placeholders: {db_schema} receives the serialized schema, {query} the user's question.
# The model is told to reply with a bare SQL statement (no prose, no markdown fences).
sql_gen_template = """Translate the following natural language query into a precise SQL query based on the provided database schema.
If the query asks for flight planning, route optimization, or efficiency, identify the origin, destination, and any specific date/time or optimization criteria (e.g., 'most fuel-efficient', 'shortest time', 'low turbulence').
Prioritize extracting airport codes and any optimization criteria for the Flight Optimizer Tool.
Output ONLY the SQL query string, no additional text, explanation, or formatting like markdown.

Database Schema:
{db_schema}

Natural Language Query:
{query}

SQL:
"""

# from_template infers the input variables ("db_schema", "query") from the placeholders.
sql_gen_prompt = PromptTemplate.from_template(sql_gen_template)
print("\nSQL Generation Prompt Template defined.")

In [7]:
# 7. Create the LLM Chain for SQL Generation and run the combined SQL + route-optimization flow.

# Phrases in the user's question -> criteria tokens understood by FlightOptimizerTool.
# They are searched across the WHOLE query, because phrases such as
# "most fuel-efficient" usually appear before "from ... to ...".
_OPTIMIZATION_PHRASES = (
    (r"fuel[\s-]*efficien", "FUEL_EFFICIENT"),
    (r"shortest\s+time", "SHORTEST_TIME"),
    (r"low\s+turbulence", "LOW_TURBULENCE"),
)


def _clean_generated_sql(raw_text: str) -> str:
    """Return the first SQL statement from raw LLM output.

    Removes an optional <think>...</think> block (in case the model emits one) and
    markdown code fences, then keeps only the text before the first ';'.
    """
    text = re.sub(r"<think>.*?</think>", "", raw_text, flags=re.DOTALL)
    text = re.sub(r"```(?:sql)?", "", text, flags=re.IGNORECASE)
    return text.split(';')[0].strip()


def _extract_route_details(query: str):
    """Extract (origin, destination, date, criteria) from a natural-language question.

    Airports come from ``from '<City> (<CODE>)' to '<City> (<CODE>)'`` (or a bare quoted
    value) and are upper-cased; both are None when that pattern is absent. The date is
    the first YYYY-MM-DD in the query, otherwise today's date. Criteria default to
    "DEFAULT" when no known phrase is present.
    """
    route = re.search(
        r"from\s+'(?:[^()']+?\((\w+)\)|([^']+))'\s+to\s+'(?:[^()']+?\((\w+)\)|([^']+))'",
        query,
        re.IGNORECASE,
    )
    origin = destination = None
    if route:
        origin = (route.group(1) or route.group(2)).strip().upper()
        destination = (route.group(3) or route.group(4)).strip().upper()

    date_match = re.search(r"\b\d{4}-\d{2}-\d{2}\b", query)
    date_str = date_match.group(0) if date_match else datetime.date.today().strftime("%Y-%m-%d")

    criteria = "DEFAULT"
    for pattern, token in _OPTIMIZATION_PHRASES:
        if re.search(pattern, query, re.IGNORECASE):
            criteria = token
            break
    return origin, destination, date_str, criteria


if llm_instance is None:
    print("\nERROR: LLM instance is NOT available. Cannot create LLM Chain.")
else:
    try:
        sql_gen_chain = LLMChain(
            llm=llm_instance,
            prompt=sql_gen_prompt,
            verbose=True,  # print the formatted prompt sent to the LLM
        )
        print("\nLLMChain for SQL generation created.")

        # 8. Natural-language query targeting both optimization and pricing.
        # YUL -> ZSPD matches the case studies simulated by the Flight Optimizer Tool.
        combined_query = "What is the most fuel-efficient flight route from 'Montreal (YUL)' to 'Shanghai (ZSPD)' for today, and what are the flight prices for that date?"

        print(f"\n--- Running Langchain Flow for combined query: \"{combined_query}\"")

        # 9. Generate SQL for the pricing half of the question.
        print("\n--- Generating SQL using LLMChain for initial flight details/airport/date extraction ---")
        chain_output = sql_gen_chain.invoke({"db_schema": db_schema_string_for_prompt, "query": combined_query})
        final_generated_sql = _clean_generated_sql(chain_output[sql_gen_chain.output_key])

        print("\n--- Generated SQL: ---")
        print(final_generated_sql)

        # Route details for the optimizer come from the question itself, not from the SQL.
        (origin_airport, destination_airport,
         date_str_extracted, optimization_criteria_extracted) = _extract_route_details(combined_query)
        can_optimize = bool(origin_airport and destination_airport)

        # 10. Execute the generated SQL (always: it answers the "flight prices" part).
        if can_optimize:
            print("\n--- Executing Generated SQL using SQL Executor Tool ---")
        else:
            print("\nCould not extract origin and destination airports for optimization from the query.")
            print("\n--- Executing Generated SQL using SQL Executor Tool (optimization not possible) ---")
        sql_tool_execution_result = sql_executor_tool.run(final_generated_sql)
        print("\n--- SQL Tool Execution Result: ---")
        print(sql_tool_execution_result)

        # 11. Call the Flight Optimizer Tool with the extracted airports, criteria and date.
        if can_optimize:
            optimizer_input = f"{origin_airport},{destination_airport},{optimization_criteria_extracted},{date_str_extracted}"
            print(f"\n--- Calling Flight Optimizer Tool for {optimizer_input} ---")
            optimization_result = flight_optimizer_tool.run(optimizer_input)
            print("\n--- Flight Optimization Result: ---")
            print(optimization_result)

        print("\n### Combined Agentic Flow Finished ###")

    except Exception as e:
        print(f"\n--- An error occurred during the Langchain flow: {e} ---")
        traceback.print_exc()

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



LLMChain for SQL generation created.

--- Running Langchain Flow for combined query: "What is the most fuel-efficient flight route from 'Montreal (YUL)' to 'Shanghai (ZSPD)' for today, and what are the flight prices for that date?"

--- Generating SQL using LLMChain for initial flight details/airport/date extraction ---


[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mTranslate the following natural language query into a precise SQL query based on the provided database schema.
If the query asks for flight planning, route optimization, or efficiency, identify the origin, destination, and any specific date/time or optimization criteria (e.g., 'most fuel-efficient', 'shortest time', 'low turbulence').
Prioritize extracting airport codes and any optimization criteria for the Flight Optimizer Tool.
Output ONLY the SQL query string, no additional text, explanation, or formatting like markdown.

Database Schema:
{'tables': {'flights': ['flight_id', 'departure_