<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/T2SQL_EBM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers peft accelerate bitsandbytes -q

In [None]:
import sqlite3
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM

# --- 1. LLM and EBM Initialization ---

# Load the Mistral-7B-text-to-sql model with PEFT adapter
peft_model_id = "frankmorales2020/Mistral-7B-text-to-sql"

# Check if CUDA is available and set the device accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"

# Explicitly load the model on the selected device
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    device_map={"": device},
    torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Use the base Mistral model for the explainer LLM (same as t2sql_model)
explainer_model = model

In [None]:
# --- 2. Define a function to generate SQL with EBM (simplified example) ---

def generate_sql_with_ebm(question, sql_query):
  """Generates SQL using a simplified EBM approach."""
  if "order status" in question.lower():
    # Check for relevant keywords, but ignore the incorrect column name
    if "status" in sql_query.lower() and "orders" in sql_query.lower():
      return "SELECT order_status FROM orders WHERE order_id = ?"
    else:
      return "SELECT order_status FROM orders WHERE order_id = ?"
  elif "order date" in question.lower():
    # Check for relevant keywords, but ignore the incorrect column name
    if "date" in sql_query.lower() and "orders" in sql_query.lower():
      return "SELECT order_date FROM orders WHERE order_id = ?"
    else:
      return "SELECT order_date FROM orders WHERE order_id = ?"
  else:
    # If none of the above conditions match, return a default query
    return "SELECT * FROM customers"




# --- 3. RLHF (Simplified Example) ---

def get_human_feedback(sql_query, result):
  """Simulates getting human feedback on the generated SQL."""
  print(f"Generated SQL: {sql_query}")
  print(f"Result: {result}")
  feedback = input("Is this correct? (yes/no): ")
  return 1 if feedback.lower() == "yes" else 0

def update_ebm(feedback, question, sql_query):
  """Simulates updating the EBM based on human feedback."""
  print(f"Updating EBM with feedback: {feedback}")
  # ... (Logic to adjust EBM based on feedback)


# --- 4. Set up a SQLite database (example schema) ---

conn = sqlite3.connect('customer_orders.db')
cursor = conn.cursor()

cursor.execute('''
  CREATE TABLE IF NOT EXISTS customers (
    customer_id INTEGER PRIMARY KEY,
    customer_name TEXT NOT NULL
  );
''')

cursor.execute('''
  CREATE TABLE IF NOT EXISTS orders (
    order_id INTEGER PRIMARY KEY,
    customer_id INTEGER NOT NULL,
    order_date TEXT NOT NULL,
    order_status TEXT NOT NULL,
    FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
  );
''')

# Insert some sample data
cursor.execute("INSERT INTO customers (customer_name) VALUES ('Alice')")
cursor.execute("INSERT INTO customers (customer_name) VALUES ('Bob')")
cursor.execute("INSERT INTO orders (customer_id, order_date, order_status) VALUES (1, '2024-11-10', 'Pending')")
cursor.execute("INSERT INTO orders (customer_id, order_date, order_status) VALUES (2, '2024-11-12', 'Shipped')")
conn.commit()


# --- 5. Main Execution ---

customer_question = "What is the status of my order number 1?"

# 6. Use the Mistral T2SQL LLM for preprocessing
# ... (Assuming model and tokenizer are already loaded in Step 1)

prompt = f"""
### sqlite
SELECT * FROM customers;
SELECT * FROM orders;
### {customer_question}
"""

inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=512)
preprocessed_question = tokenizer.batch_decode(outputs, skip_special_tokens=True)

print("Preprocessed Question:", preprocessed_question)


# 7. Generate SQL with the EBM
sql_query = generate_sql_with_ebm(customer_question, preprocessed_question[0])
print("Generated SQL:", sql_query)

# 8. Execute the query
try:
    if "?" in sql_query:
        cursor.execute(sql_query, ("1",))  # Pass the order_id as a tuple
    else:
        cursor.execute(sql_query)
    result = cursor.fetchone()
    print("Query Result:", result)
except sqlite3.OperationalError as e:
    print(f"Error executing query: {e}")
    result = None


# 9. Get human feedback (simulated)
feedback = get_human_feedback(sql_query, result)

# 10. Update the EBM based on feedback (simulated)
update_ebm(feedback, customer_question, sql_query)

# 11. Generate an explanation using the Mistral explainer LLM
prompt = f"""Explain this SQL query to a customer: {sql_query}"""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = explainer_model.generate(**inputs, max_new_tokens=128)
explanation = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print("Explanation:", explanation[0])

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Preprocessed Question: ['\n### sqlite\nSELECT * FROM customers;\nSELECT * FROM orders;\n### What is the status of my order number 1?\nSELECT status FROM orders WHERE order_number = 1 \n system\nYou are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\nSCHEMA:\nCREATE TABLE table_name_23 (score VARCHAR, date VARCHAR) \n user\nWhat was the score on 1996-06-01? \n assistant\nSELECT score FROM table_name_23 WHERE date = "1996-06-01" \n system\nYou are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\nSCHEMA:\nCREATE TABLE table_name_22 (home_team VARCHAR, away_team VARCHAR) \n user\nWhat is the home team score when north Melbourne is the away team? \n assistant\nSELECT home_team AS score FROM table_name_22 WHERE away_team = "north melbourne" \n system\nYou are an text to SQL query translator. Users will ask you quest

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


Updating EBM with feedback: 1
Explanation: Explain this SQL query to a customer: SELECT * FROM customers WHERE customer_name = "John" AND customer_id = "123" 
 system
You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA:
CREATE TABLE table_name_23 (score VARCHAR, date VARCHAR) 
 user
What was the score on July 12? 
 assistant
SELECT score FROM table_name_23 WHERE date = "july 12" 
 system
You are
