<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/AGENTIC_T2SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's dependencies. These are IPython `!` shell commands, so run this cell
# once per fresh runtime before the cells below.
# NOTE(review): crewai and colab-env are installed here but not imported by any cell shown in
# this notebook excerpt — confirm they are still needed.
!pip install -U langchain-community -q
!pip install -U crewai -q
# crewai together with its optional "tools" extras.
!pip install 'crewai [tools]' -q
!pip install transformers -U -q
!pip install colab-env -q
!pip install unsloth -q
!pip install torch -q

In [None]:
import os
import torch
import warnings
from typing import Any, List, Dict, Optional

# Ensure all necessary Langchain/Transformers/Unsloth imports are here
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langchain_core.outputs import ChatResult, ChatGeneration

# Import PromptTemplate and LLMChain for the new approach
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Unsloth and Transformers imports for model loading
from unsloth import FastLanguageModel
from transformers import pipeline, AutoConfig # Make sure AutoConfig is imported

# Import BaseTool if you still want to use your tool class structure
from langchain.tools import BaseTool

# --- 1. Custom LLM Wrapper (UnslothCrewAILLM) ---
# Adapts a locally loaded Unsloth/transformers model to LangChain's chat-model
# interface so it can be plugged into an LLMChain.
class UnslothCrewAILLM(BaseChatModel):
    """LangChain chat model backed by a local Unsloth/transformers model.

    Text is produced by ``pipeline`` when one is supplied; otherwise the prompt
    is tokenized and ``model.generate`` is called directly. Only the content of
    the *last* message is sent to the model. Any ``stop`` sequences are applied
    by truncating the decoded text at the earliest match.
    """

    model: Any
    tokenizer: Any
    pipeline: Any = None  # optional transformers text-generation pipeline
    max_new_tokens: int = 1024
    temperature: float = 0.1
    do_sample: bool = False
    # Stored on the instance but not read by any method of this class.
    trust_remote_code: bool = True

    def __init__(self, model, tokenizer, pipeline=None, max_new_tokens=1024, temperature=0.1, do_sample: bool = False, trust_remote_code=True):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            pipeline=pipeline,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            trust_remote_code=trust_remote_code,
        )

        # generate() needs a pad token id; fall back to EOS when the tokenizer defines none.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Return *text* cut just before the earliest occurrence of any stop sequence.

        Empty stop strings are ignored. *text* is returned unchanged when *stop*
        is None/empty or when none of the stop sequences occurs.
        """
        if not stop:
            return text
        hits = [idx for idx in (text.find(s) for s in stop if s) if idx != -1]
        return text[:min(hits)] if hits else text

    def _decoding_kwargs(self) -> Dict[str, Any]:
        """Return the decoding keyword arguments shared by both generation paths.

        ``temperature`` is forwarded only when sampling is enabled. With
        ``do_sample=False`` decoding is greedy, temperature has no effect, and
        transformers reports it as an invalid/ignored generation flag.
        """
        decoding: Dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.do_sample,
        }
        if self.do_sample:
            decoding["temperature"] = self.temperature
        return decoding

    def _generate_with_pipeline(self, prompt_text: str, stop: Optional[List[str]]) -> str:
        """Generate with the transformers pipeline.

        On failure the error is printed and returned as text.
        """
        try:
            response = self.pipeline(
                prompt_text,
                num_return_sequences=1,
                return_full_text=False,
                **self._decoding_kwargs(),
            )
            text = response[0].get('generated_text', '') if response else ''
            return self._truncate_at_stop(text, stop).strip()
        except Exception as e:
            print(f"Error during pipeline generation in wrapper: {e}")
            return f"Error generating response: {e}"

    def _generate_with_model(self, prompt_text: str, stop: Optional[List[str]]) -> str:
        """Generate with ``model.generate``.

        On failure a traceback is printed and the error is returned as text.
        """
        try:
            max_input_length = getattr(self.tokenizer, 'model_max_length', self.max_new_tokens)
            inputs = self.tokenizer(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_input_length,
            ).to(self.model.device)

            if self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            # Stop strings are deliberately NOT passed to generate(): its
            # `stopping_criteria` argument expects StoppingCriteria objects, not
            # strings. They are applied to the decoded text below instead.
            outputs = self.model.generate(
                **inputs,
                pad_token_id=self.tokenizer.pad_token_id,
                **self._decoding_kwargs(),
            )
            # The output sequence starts with the prompt tokens; keep only the new ones.
            prompt_length = inputs.input_ids.shape[1]
            new_token_ids = outputs[0, prompt_length:]
            text = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)
            return self._truncate_at_stop(text, stop).strip()
        except Exception as e:
            print(f"Error during manual generation in wrapper: {e}")
            import traceback
            traceback.print_exc()
            return f"Error generating response: {e}"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce a single AI message for *messages*.

        Args:
            messages: Conversation so far; only the last message's content is used.
            stop: Optional stop sequences; the output is cut before the first match.
            run_manager: Accepted for LangChain compatibility; unused.

        Returns:
            A ChatResult with one ChatGeneration. Generation errors are caught and
            returned as the message text (best-effort, never raised).

        Raises:
            ValueError: If *messages* is empty.
        """
        if not messages:
            raise ValueError("No messages provided to the LLM wrapper.")

        # Earlier messages (e.g. a system prompt) are not forwarded to the model.
        prompt_text = messages[-1].content

        if self.pipeline:
            generated_text = self._generate_with_pipeline(prompt_text, stop)
        elif self.model and self.tokenizer:
            generated_text = self._generate_with_model(prompt_text, stop)
        else:
            generated_text = "Error: Model or pipeline not loaded in wrapper."

        message = AIMessage(content=generated_text)
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "unsloth_transformer_wrapper"

    def supports_stop_words(self) -> bool:
        """Returns whether the model supports stop words."""
        return True

    @property
    def supports_control_chars(self) -> bool:
        """Returns whether the model supports control characters."""
        return False

    # _stream/_astream are intentionally NOT overridden. LangChain treats an
    # overridden _stream as "streaming implemented". An override that only
    # raised NotImplementedError would make .stream()/.astream() raise instead
    # of falling back to a single non-streamed response.

    # NOTE(review): the helpers below are thin conveniences that delegate to
    # _generate/_agenerate; LangChain's public invoke()/batch() are believed to
    # route through _generate directly — confirm before relying on them.
    def _invoke(self, prompt: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any):
        """Return the first generated message for *prompt*."""
        return self._generate(prompt, stop=stop, run_manager=run_manager, **kwargs).generations[0].message

    def _batch(self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any) -> List[ChatResult]:
        """Run _generate on each conversation in *messages*, sequentially."""
        return [self._generate(msgs, stop=stop, run_manager=run_manager, **kwargs) for msgs in messages]

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async entry point; runs the synchronous _generate on the calling thread."""
        return self._generate(messages, stop, run_manager, **kwargs)

    async def _ainvoke(self, prompt: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any):
        """Async counterpart of _invoke."""
        return (await self._agenerate(prompt, stop=stop, run_manager=run_manager, **kwargs)).generations[0].message

    async def _abatch(self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, run_manager: Any = None, **kwargs: Any) -> List[ChatResult]:
        """Async counterpart of _batch (gathers one _agenerate per conversation)."""
        import asyncio
        return await asyncio.gather(*[self._agenerate(msgs, stop=stop, run_manager=run_manager, **kwargs) for msgs in messages])


# --- 2. Database Schema Definition for Flight Planning ---
# Table name -> ordered column names. Key and column order is preserved and
# appears verbatim in the prompt string built below.
db_schema = dict(
    tables=dict(
        flights=[
            "flight_id", "departure_airport", "arrival_airport", "departure_time",
            "arrival_time", "aircraft_type", "status", "price",
        ],
        airports=["airport_code", "airport_name", "city", "country"],
        passengers=["passenger_id", "first_name", "last_name", "email"],
        bookings=["booking_id", "flight_id", "passenger_id", "booking_date", "seat_number"],
    ),
)
# Python repr of the schema, injected into the SQL-generation prompt.
db_schema_string_for_prompt = repr(db_schema)

# --- 3. Model Loading (using the model from the reference) ---
fine_tuned_model_id = "frankmorales2020/deepseek_r1_text2sql_finetuned"
max_seq_length = 2048
load_in_4bit = True

# Decoding settings shared by the transformers pipeline and the UnslothCrewAILLM
# wrapper, which falls back to manual model.generate calls when no pipeline
# exists. They are defined once so the two generation paths cannot drift apart.
# do_sample=False means greedy decoding, under which temperature has no effect.
generation_max_new_tokens = 1024
generation_temperature = 0.1
generation_do_sample = False

print(f"\n--- Attempting Direct LLM Loading for {fine_tuned_model_id} using Unsloth ---")

# bfloat16 when CUDA is available and supports it, otherwise float16.
unsloth_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

# Start every handle at None so later cells can detect a failed load.
model = None
tokenizer = None
unsloth_wrapper_pipeline = None
llm_instance = None  # the LangChain-compatible wrapper used by the chain cell

try:
    with warnings.catch_warnings():
        # Suppress FutureWarnings emitted while the checkpoint is loaded.
        warnings.filterwarnings("ignore", category=FutureWarning)
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=fine_tuned_model_id,
            max_seq_length=max_seq_length,
            dtype=unsloth_dtype,
            load_in_4bit=load_in_4bit,
            trust_remote_code=True,
        )
    print("Model and Tokenizer loaded successfully using Unsloth.")

    try:
        # Optional: if this fails, the wrapper generates via model.generate instead.
        unsloth_wrapper_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=generation_max_new_tokens,
            temperature=generation_temperature,
            do_sample=generation_do_sample,
            pad_token_id=tokenizer.eos_token_id,
            return_full_text=False,
        )
        print("Text generation pipeline created.")
    except Exception as e:
        print(f"Warning: Could not create transformers pipeline: {e}. Falling back to manual generation.")
        unsloth_wrapper_pipeline = None  # ensure the wrapper takes the manual path

    # Wrap model/tokenizer (and the pipeline, if any) for LangChain.
    llm_instance = UnslothCrewAILLM(
        model=model,
        tokenizer=tokenizer,
        pipeline=unsloth_wrapper_pipeline,
        max_new_tokens=generation_max_new_tokens,
        temperature=generation_temperature,
        do_sample=generation_do_sample,
        trust_remote_code=True,
    )
    print("UnslothCrewAILLM instance created.")

except ImportError as e:
    print(f"\n-- Skipping model loading: Unsloth or necessary libraries not installed, or compatible GPU/CUDA setup not found. Error: {e}")
    print("Please ensure you have 'unsloth' and 'torch' installed and a compatible GPU/CUDA setup.")
except Exception as e:
    print(f"\n--- An error occurred during model loading (Unsloth): {e} ---")
    import traceback
    traceback.print_exc()

# --- 4. Define the SQL Query Executor Tool (as a Langchain BaseTool) ---
class SQLQueryExecutorTool(BaseTool):
    """Simulated SQL executor for the demo flow.

    No database is contacted. The query text is pattern-matched and the tool
    returns either a canned result string or an error string.
    """

    name: str = "SQL Query Executor"
    description: str = "Executes a given SQL query against the flight database and returns the results or errors."

    def _run(self, query: str) -> str:
        """Return a simulated execution result (or error message) for *query*."""
        print(f"\n--- Attempting to execute SQL query: ---\n{query}\n--------------------------------------")
        upper_query = query.upper()
        lower_query = query.lower()

        # Block obviously destructive statements first.
        if "DROP TABLE" in upper_query or "DELETE FROM" in upper_query:
            return "Error: Harmful SQL query detected and blocked for safety."

        # Enforce SELECT-only before any canned success below, so a non-SELECT
        # statement can never be reported as executed successfully.
        if not lower_query.strip().startswith("select"):
            return "Error: Only SELECT queries are supported by this tool for safety and simplicity in this demo."

        # Canned result for the demo flight query. The needle must be lowercase
        # because it is searched in the lowercased query.
        if "from flights" in lower_query and "'JFK'" in query and "'LAX'" in query and "2025-07-01" in query:
            return "SQL executed successfully. Sample results for flight query: [{'flight_id': 101, 'departure_airport': 'JFK', 'arrival_airport': 'LAX', 'price': 450.00}, {'flight_id': 105, 'departure_airport': 'JFK', 'arrival_airport': 'LAX', 'price': 520.00}]"
        # Legacy simulated results kept for other tests.
        if "category = 'Electronics'" in query:
            return "SQL executed successfully. Sample results: [{'name': 'Laptop', 'price': 1200}, {'name': 'Smartphone', 'price': 800}]"
        if "orders made after 2023-01-01" in query:
            return "SQL executed successfully. Sample results: [{'order_id': 1, 'order_date': '2023-02-15'}, {'order_id': 2, 'order_date': '2024-01-20'}]"

        # Any other SELECT: succeed only if it at least has a FROM clause.
        if "FROM" in upper_query:
            return "SQL executed successfully. (Simulated) No specific results available for this general query."
        return "Error: Invalid or unexecutable SQL query format (simulated error)."

sql_executor_tool = SQLQueryExecutorTool()
print("\nSQL Query Executor Tool defined.")


# --- 5. Define the Prompt Template for SQL Generation ---
# Assembled line by line. "{db_schema}" and "{query}" are the placeholders
# PromptTemplate fills in; the final "" entry keeps the newline after "SQL:".
sql_gen_template = "\n".join(
    [
        "Translate the following natural language query into a precise SQL query based on the provided database schema.",
        "",
        "Database Schema:",
        "{db_schema}",
        "",
        "Natural Language Query:",
        "{query}",
        "",
        "Output ONLY the SQL query string, no additional text, explanation, or formatting like markdown.",
        "",
        "SQL:",
        "",
    ]
)

sql_gen_prompt = PromptTemplate(
    template=sql_gen_template,
    input_variables=["db_schema", "query"],
)
print("\nSQL Generation Prompt Template defined.")


In [4]:
# --- 6. Create the LLM Chain for SQL Generation ---


def _extract_sql(raw_text: str) -> str:
    """Return the first SQL statement found in raw model output.

    Steps, in order:
    1. Drop everything up to the last ``</think>`` (reasoning preamble).
    2. Prefer the body of a markdown code fence if one is present.
    3. Skip leading prose up to the first line that starts with ``SELECT``.
    4. Cut at the first ``;`` (the terminator is dropped). If there is no ``;``,
       keep everything up to the first blank line, so multi-line SQL survives.

    Returns "" for empty input.
    """
    import re

    text = re.split(r"</think>", raw_text, flags=re.IGNORECASE)[-1]

    fenced = re.search(r"```[A-Za-z]*\s*(.*?)```", text, flags=re.DOTALL)
    if fenced:
        text = fenced.group(1)

    select_line = re.search(r"^[ \t]*SELECT\b", text, flags=re.IGNORECASE | re.MULTILINE)
    if select_line:
        text = text[select_line.start():]

    text = text.strip()
    if ";" in text:
        return text.split(";", 1)[0].strip()
    return re.split(r"\n[ \t]*\n", text, maxsplit=1)[0].strip()


if llm_instance is None:
    print("\nERROR: LLM instance is NOT available. Cannot create LLM Chain.")
else:
    try:
        sql_gen_chain = LLMChain(
            llm=llm_instance,
            prompt=sql_gen_prompt,
            verbose=True,  # print the formatted prompt sent to the LLM
        )
        print("\nLLMChain for SQL generation created.")

        # --- 7. Define the Natural Language Query ---
        flight_query = "Find all flights departing from 'JFK' to 'LAX' after 2025-07-01 and their prices."

        print(f"\n--- Running Langchain Flow for query: \"{flight_query}\" ---")

        # --- 8. Run the LLMChain to generate SQL ---
        # invoke() replaces the deprecated Chain.run(). It returns a dict whose
        # generated text is stored under the chain's output key.
        print("\n--- Generating SQL using LLMChain ---")
        chain_output = sql_gen_chain.invoke(
            {"db_schema": db_schema_string_for_prompt, "query": flight_query}
        )
        generated_sql = chain_output[sql_gen_chain.output_key].strip()

        # Strip reasoning, fences and trailing text; keep one SQL statement.
        final_generated_sql = _extract_sql(generated_sql)

        print(f"\n--- Generated SQL: ---")
        print(final_generated_sql)

        # --- 9. Manually execute the generated SQL using the Tool ---
        print("\n--- Executing Generated SQL using Tool ---")
        tool_execution_result = sql_executor_tool.run(final_generated_sql)

        print(f"\n--- Tool Execution Result: ---")
        print(tool_execution_result)

        print("\n### Langchain Flow Finished ###")

    except Exception as e:
        print(f"\n--- An error occurred during the Langchain flow: {e} ---")
        import traceback
        traceback.print_exc()

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



LLMChain for SQL generation created.

--- Running Langchain Flow for query: "Find all flights departing from 'JFK' to 'LAX' after 2025-07-01 and their prices." ---

--- Generating SQL using LLMChain ---


[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3mTranslate the following natural language query into a precise SQL query based on the provided database schema.

Database Schema:
{'tables': {'flights': ['flight_id', 'departure_airport', 'arrival_airport', 'departure_time', 'arrival_time', 'aircraft_type', 'status', 'price'], 'airports': ['airport_code', 'airport_name', 'city', 'country'], 'passengers': ['passenger_id', 'first_name', 'last_name', 'email'], 'bookings': ['booking_id', 'flight_id', 'passenger_id', 'booking_date', 'seat_number']}}

Natural Language Query:
Find all flights departing from 'JFK' to 'LAX' after 2025-07-01 and their prices.

Output ONLY the SQL query string, no additional text, explanation, or formatting like markdown.

SQL:

