<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/T2SQL_EBM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q transformers peft accelerate bitsandbytes -q

In [None]:
import sqlite3
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM

# --- 1. LLM and EBM Initialization ---

# Load the Mistral-7B-text-to-sql model together with its PEFT adapter.
peft_model_id = "frankmorales2020/Mistral-7B-text-to-sql"

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# float16 halves memory on the GPU, but many half-precision kernels are missing
# or very slow on the CPU, so fall back to float32 when no GPU is present.
dtype = torch.float16 if device == "cuda" else torch.float32

# Place the whole model (base weights + adapter) on the selected device.
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    device_map={"": device},
    torch_dtype=dtype,
)

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Use EOS as the padding token so generate() has a pad id to work with.
tokenizer.pad_token_id = tokenizer.eos_token_id


# The explainer reuses the same adapter-equipped model instance, so no second
# copy of the weights is loaded.
explainer_model = model

## full code


This version incorporates the following corrections and improvements:

* Organized structure: model loading comes first, followed by database setup and the remaining components.

* Efficient model loading: Loads the base Mistral model once and applies the PEFT adapter.

* Reusing the base model: Uses the same Mistral model instance for T2SQL and explanations.

* Simplified EBM: Includes a simplified EBM function.

* Simulated RLHF: Simulates human feedback and EBM updates.

* Error handling: Includes error handling for query execution.

* Correct parameter substitution: When the SQL query contains a `?` placeholder, the actual value is bound to it at execution time rather than being left unfilled.

* Corrected EBM call: The generate_sql_with_ebm function is called with both the customer_question and the model's generated output as arguments.

* Improved EBM Logic: The generate_sql_with_ebm function now checks for the presence of order_id in the sql_query to ensure it's more likely to be correct before returning it.

* Corrected customer question: The question is now formulated correctly to avoid ambiguity and provide context to the T2SQL model.

* pad_token_id in generate(): The pad_token_id is passed as an argument to the model.generate() function to suppress the warning.

This complete code provides a functional and illustrative example of the LLM-EBM-RLHF synergy for a T2SQL task. Remember that the EBM and RLHF parts are still simplified for demonstration purposes. You would need to implement a real EBM and a more sophisticated RLHF system in a production environment.

In [None]:
import sqlite3
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM

# --- 1. LLM and EBM Initialization ---

# Load the Mistral-7B-text-to-sql model together with its PEFT adapter.
peft_model_id = "frankmorales2020/Mistral-7B-text-to-sql"

# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# float16 halves memory on the GPU, but many half-precision kernels are missing
# or very slow on the CPU, so fall back to float32 when no GPU is present.
dtype = torch.float16 if device == "cuda" else torch.float32

# Place the whole model (base weights + adapter) on the selected device.
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    device_map={"": device},
    torch_dtype=dtype,
)

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
# Use EOS as the padding token so generate() has a pad id to work with.
tokenizer.pad_token_id = tokenizer.eos_token_id

# The explainer reuses the same adapter-equipped model instance, so no second
# copy of the weights is loaded.
explainer_model = model

# --- 2. Set up a SQLite database (example schema) ---

conn = sqlite3.connect('customer_orders.db')
cursor = conn.cursor()

cursor.execute('''
  CREATE TABLE IF NOT EXISTS customers (
    customer_id INTEGER PRIMARY KEY,
    customer_name TEXT NOT NULL
  );
''')

cursor.execute('''
  CREATE TABLE IF NOT EXISTS orders (
    order_id INTEGER PRIMARY KEY,
    customer_id INTEGER NOT NULL,
    order_date TEXT NOT NULL,
    order_status TEXT NOT NULL,
    FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
  );
''')

# The database file persists between runs, so remove rows left by earlier
# executions; otherwise every re-run appends another copy of the sample data.
# Orders go first because they reference customers. With a plain INTEGER
# PRIMARY KEY (no AUTOINCREMENT), ids restart at 1 once a table is empty, so
# the hard-coded customer_id values below still point at Alice and Bob.
cursor.execute("DELETE FROM orders")
cursor.execute("DELETE FROM customers")

# Insert some sample data
cursor.execute("INSERT INTO customers (customer_name) VALUES ('Alice')")
cursor.execute("INSERT INTO customers (customer_name) VALUES ('Bob')")
cursor.execute("INSERT INTO orders (customer_id, order_date, order_status) VALUES (1, '2024-11-10', 'Pending')")
cursor.execute("INSERT INTO orders (customer_id, order_date, order_status) VALUES (2, '2024-11-12', 'Shipped')")
conn.commit()


# --- 3. Define a function to generate SQL with EBM (simplified example) ---

def generate_sql_with_ebm(question, sql_query):
    """Pick the SQL to run for *question* using a simplified, rule-based "EBM".

    The LLM's candidate ``sql_query`` is accepted only if it mentions the
    column the question asks about, the ``orders`` table and ``order_id``.
    Otherwise a known-good parameterized fallback is returned.

    Args:
      question: The customer's question in natural language. Column names may
        be written with spaces ("order status") or underscores ("order_status").
      sql_query: The candidate SQL text produced by the LLM.

    Returns:
      The SQL query to execute. The fallback queries for status/date contain a
      single ``?`` placeholder for the order id.
    """
    # Questions often spell column names with underscores ("order_status");
    # a plain "order status" substring test would never match those.
    normalized_question = question.lower().replace("_", " ")
    candidate = sql_query.lower()

    if "order status" in normalized_question:
        # Check if the query contains the correct column name
        if "order_status" in candidate and "orders" in candidate and "order_id" in candidate:
            return sql_query
        return "SELECT order_status FROM orders WHERE order_id = ?"
    if "order date" in normalized_question:
        # Check if the query contains the correct column name
        if "order_date" in candidate and "orders" in candidate and "order_id" in candidate:
            return sql_query
        return "SELECT order_date FROM orders WHERE order_id = ?"
    # If none of the above conditions match, return a default query
    return "SELECT * FROM customers"


# --- 4. RLHF (Simplified Example) ---

def get_human_feedback(sql_query, result, *, ask=input):
    """Ask a human whether the generated SQL and its result are correct.

    Args:
      sql_query: The SQL query that was executed.
      result: The query result (``None`` if execution failed).
      ask: Keyword-only callable used to read the answer. It defaults to the
        built-in ``input`` and can be swapped out for tests or batch runs.

    Returns:
      1 if the answer is "yes" or "y" (case-insensitive, surrounding
      whitespace ignored), otherwise 0.
    """
    print(f"Generated SQL: {sql_query}")
    print(f"Result: {result}")
    feedback = ask("Is this correct? (yes/no): ")
    # Tolerate stray whitespace and the common short form "y".
    return 1 if feedback.strip().lower() in ("yes", "y") else 0

def update_ebm(feedback, question, sql_query):
    """Placeholder for adjusting the EBM from a human feedback signal.

    For now it only reports the feedback value. ``question`` and
    ``sql_query`` are accepted so a real update rule can use them later.
    """
    message = f"Updating EBM with feedback: {feedback}"
    print(message)
    # TODO: implement the actual EBM adjustment driven by `feedback`.


# --- 5. Main Execution ---

customer_question = "What is the order_status of the order with order_id 1?"  # Corrected question

# 6. Use the Mistral T2SQL LLM for preprocessing
prompt = f"""
### sqlite
SELECT * FROM customers;
SELECT * FROM orders;
### {customer_question}
"""

inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
# generate() returns prompt + continuation for a causal LM, so decode only the
# new tokens. The echoed prompt already contains "orders", "order_status" and
# "order_id", so it would pass the EBM check and be sent to SQLite as "SQL".
prompt_length = inputs["input_ids"].shape[-1]
preprocessed_question = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

print("Preprocessed Question:", preprocessed_question)

# 7. Generate SQL with the EBM
sql_query = generate_sql_with_ebm(customer_question, preprocessed_question[0].strip())
print("Generated SQL:", sql_query)

# 8. Execute the query
try:
    if "?" in sql_query:
        cursor.execute(sql_query, (1,))  # order_id is declared INTEGER
    else:
        cursor.execute(sql_query)
    result = cursor.fetchone()
    print("Query Result:", result)
except (sqlite3.Error, sqlite3.Warning) as e:
    # Model-generated SQL can fail in several ways (syntax errors, multiple
    # statements, wrong number of bindings), not only with OperationalError.
    print(f"Error executing query: {e}")
    result = None

# 9. Get human feedback (simulated)
feedback = get_human_feedback(sql_query, result)

# 10. Update the EBM based on feedback (simulated)
update_ebm(feedback, customer_question, sql_query)

# 11. Generate an explanation using the Mistral explainer LLM
prompt = f"""Explain this SQL query to a customer: {sql_query}"""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = explainer_model.generate(**inputs, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id)
# Again skip the echoed prompt so only the explanation itself is printed.
explanation = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print("Explanation:", explanation[0])

# 12. Close the database connection
conn.close()

## partial code 1

In [None]:
# --- 2. Set up a SQLite database (example schema) ---

conn = sqlite3.connect('customer_orders.db')
cursor = conn.cursor()

# Run each statement in order on the shared cursor: create the schema if it
# is missing, wipe rows left by earlier runs, then seed fresh sample data.
setup_statements = [
    '''
  CREATE TABLE IF NOT EXISTS customers (
    customer_id INTEGER PRIMARY KEY,
    customer_name TEXT NOT NULL
  );
''',
    '''
  CREATE TABLE IF NOT EXISTS orders (
    order_id INTEGER PRIMARY KEY,
    customer_id INTEGER NOT NULL,
    order_date TEXT NOT NULL,
    order_status TEXT NOT NULL,
    FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
  );
''',
    # Start from empty tables so re-running the cell doesn't duplicate rows.
    "DELETE FROM customers",
    "DELETE FROM orders",
    # Sample data.
    "INSERT INTO customers (customer_name) VALUES ('Alice')",
    "INSERT INTO customers (customer_name) VALUES ('Bob')",
    "INSERT INTO orders (customer_id, order_date, order_status) VALUES (1, '2024-11-10', 'Pending')",
    "INSERT INTO orders (customer_id, order_date, order_status) VALUES (2, '2024-11-12', 'Shipped')",
]
for statement in setup_statements:
    cursor.execute(statement)
conn.commit()

def check_database_records(cursor):
    """Verify that the sample tables hold the expected number of rows.

    Args:
      cursor: An open sqlite3 cursor on the example database.

    Returns:
      True when both ``customers`` and ``orders`` contain exactly 2 rows.
      Otherwise it prints an error for the first mismatching table (customers
      are checked before orders) and returns False.
    """
    expectations = (("customers", 2), ("orders", 2))
    for table, expected in expectations:
        cursor.execute(f"SELECT COUNT(*) FROM {table}")
        (found,) = cursor.fetchone()
        if found != expected:
            print(f"Error: Expected {expected} {table}, found {found}")
            return False
    return True

print("Checking database records...")
# Verify the seed data before moving on to the model-driven steps.
records_ok = check_database_records(cursor)
if records_ok:
    print("Database check passed!")
else:
    print("Database check failed!")
    # TODO: handle the failure (e.g. stop here or re-insert the sample data).
print('\n')

# --- 3. Define a function to generate SQL with EBM (simplified example) ---


def generate_sql_with_ebm(question, sql_query):
    """Pick the SQL to run for *question* using a simplified, rule-based "EBM".

    In a real system a trained energy-based model would score the candidate.
    Here simple rules accept the LLM's ``sql_query`` only if it mentions the
    requested column and the ``orders`` table. Otherwise a known-good
    parameterized fallback is returned.

    Args:
      question: The customer's question in natural language. Column names may
        be written with spaces ("order status") or underscores ("order_status").
      sql_query: The candidate SQL query generated by the LLM.

    Returns:
      The SQL query to execute. The fallback queries for status/date contain a
      single ``?`` placeholder for the order id.
    """
    # Questions often spell column names with underscores ("order_status");
    # a plain "order status" substring test would never match those.
    normalized_question = question.lower().replace("_", " ")
    candidate = sql_query.lower()

    if "order status" in normalized_question:
        # Use the generated query only if it references the right column/table.
        if "order_status" in candidate and "orders" in candidate:
            return sql_query
        return "SELECT order_status FROM orders WHERE order_id = ?"  # Fallback to the correct query
    if "order date" in normalized_question:
        if "order_date" in candidate and "orders" in candidate:
            return sql_query
        return "SELECT order_date FROM orders WHERE order_id = ?"  # Fallback to the correct query
    return "SELECT * FROM customers"  # Default query

# --- 4. RLHF (Simplified Example) ---

def get_human_feedback(sql_query, result, *, ask=input):
    """Ask a human whether the generated SQL and its result are correct.

    Args:
      sql_query: The SQL query that was executed.
      result: The query result (``None`` if execution failed).
      ask: Keyword-only callable used to read the answer. It defaults to the
        built-in ``input`` and can be swapped out for tests or batch runs.

    Returns:
      1 if the answer is "yes" or "y" (case-insensitive, surrounding
      whitespace ignored), otherwise 0.
    """
    print(f"Generated SQL: {sql_query}")
    print(f"Result: {result}")
    feedback = ask("Is this correct? (yes/no): ")
    # Tolerate stray whitespace and the common short form "y".
    return 1 if feedback.strip().lower() in ("yes", "y") else 0

def update_ebm(feedback, question, sql_query):
    """Stand-in for the EBM update step driven by human feedback.

    Only the feedback value is reported. The question and query parameters are
    kept so a real implementation can condition its update on them.
    """
    print("Updating EBM with feedback: " + f"{feedback}")
    # TODO: adjust the EBM using `feedback`, `question` and `sql_query`.

# --- 5. Main Execution ---

customer_question = "What is the order_status of the order with order_id 1?"  # Corrected question

# 6. Use the Mistral T2SQL LLM for preprocessing
prompt = f"""
### sqlite
SELECT * FROM customers;
SELECT * FROM orders;
### {customer_question}
"""

# Assuming the model and tokenizer are already loaded in Step 1
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
# generate() returns prompt + continuation for a causal LM, so decode only the
# new tokens. The echoed prompt already contains "orders" and "order_status",
# so it would pass the EBM check and be sent to SQLite as "SQL".
prompt_length = inputs["input_ids"].shape[-1]
preprocessed_question = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

print("Preprocessed Question:", preprocessed_question)

# 7. Generate SQL with the EBM
sql_query = generate_sql_with_ebm(customer_question, preprocessed_question[0].strip())
print("Generated SQL:", sql_query)

# 8. Execute the query
try:
    if "?" in sql_query:
        cursor.execute(sql_query, (1,))  # order_id is declared INTEGER
    else:
        cursor.execute(sql_query)
    result = cursor.fetchall()  # Fetch all records
    print("Query Result:", result)
except (sqlite3.Error, sqlite3.Warning) as e:
    # Model-generated SQL can fail in several ways (syntax errors, multiple
    # statements, wrong number of bindings), not only with OperationalError.
    print(f"Error executing query: {e}")
    result = None


# 9. Get human feedback (simulated)
feedback = get_human_feedback(sql_query, result)

# 10. Update the EBM based on feedback (simulated)
update_ebm(feedback, customer_question, sql_query)

# 11. Generate an explanation using the Mistral explainer LLM
prompt = f"""Explain this SQL query to a customer: {sql_query}"""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = explainer_model.generate(**inputs, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id)
# Again skip the echoed prompt so only the explanation itself is printed.
explanation = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print("Explanation:", explanation[0])

# 12. Close the database connection so the handle is not leaked.
conn.close()

Checking database records...
Database check passed!




## partial code 2 corrected

In [None]:
# --- 2. Set up a SQLite database (example schema) ---

conn = sqlite3.connect('customer_orders.db')
cursor = conn.cursor()

cursor.execute('''
  CREATE TABLE IF NOT EXISTS customers (
    customer_id INTEGER PRIMARY KEY,
    customer_name TEXT NOT NULL
  );
''')

cursor.execute('''
  CREATE TABLE IF NOT EXISTS orders (
    order_id INTEGER PRIMARY KEY,
    customer_id INTEGER NOT NULL,
    order_date TEXT NOT NULL,
    order_status TEXT NOT NULL,
    FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
  );
''')

# The database file persists between runs, so remove rows left by earlier
# executions; otherwise every re-run appends another copy of the sample data.
# Orders go first because they reference customers. With a plain INTEGER
# PRIMARY KEY (no AUTOINCREMENT), ids restart at 1 once a table is empty, so
# the hard-coded customer_id values below still point at Alice and Bob.
cursor.execute("DELETE FROM orders")
cursor.execute("DELETE FROM customers")

# Insert some sample data
cursor.execute("INSERT INTO customers (customer_name) VALUES ('Alice')")
cursor.execute("INSERT INTO customers (customer_name) VALUES ('Bob')")
cursor.execute("INSERT INTO orders (customer_id, order_date, order_status) VALUES (1, '2024-11-10', 'Pending')")
cursor.execute("INSERT INTO orders (customer_id, order_date, order_status) VALUES (2, '2024-11-12', 'Shipped')")
conn.commit()


# --- 3. Define a function to generate SQL with EBM (simplified example) ---

def generate_sql_with_ebm(question, sql_query):
    """Pick the SQL to run for *question* using a simplified, rule-based "EBM".

    The LLM's candidate ``sql_query`` is accepted only if it mentions the
    column the question asks about, the ``orders`` table and ``order_id``.
    Otherwise a known-good parameterized fallback is returned.

    Args:
      question: The customer's question in natural language. Column names may
        be written with spaces ("order status") or underscores ("order_status").
      sql_query: The candidate SQL text produced by the LLM.

    Returns:
      The SQL query to execute. The fallback queries for status/date contain a
      single ``?`` placeholder for the order id.
    """
    # Questions often spell column names with underscores ("order_status");
    # a plain "order status" substring test would never match those.
    normalized_question = question.lower().replace("_", " ")
    candidate = sql_query.lower()

    if "order status" in normalized_question:
        # Check if the query contains the correct column name
        if "order_status" in candidate and "orders" in candidate and "order_id" in candidate:
            return sql_query
        return "SELECT order_status FROM orders WHERE order_id = ?"
    if "order date" in normalized_question:
        # Check if the query contains the correct column name
        if "order_date" in candidate and "orders" in candidate and "order_id" in candidate:
            return sql_query
        return "SELECT order_date FROM orders WHERE order_id = ?"
    # If none of the above conditions match, return a default query
    return "SELECT * FROM customers"


# --- 4. RLHF (Simplified Example) ---

def get_human_feedback(sql_query, result, *, ask=input):
    """Ask a human whether the generated SQL and its result are correct.

    Args:
      sql_query: The SQL query that was executed.
      result: The query result (``None`` if execution failed).
      ask: Keyword-only callable used to read the answer. It defaults to the
        built-in ``input`` and can be swapped out for tests or batch runs.

    Returns:
      1 if the answer is "yes" or "y" (case-insensitive, surrounding
      whitespace ignored), otherwise 0.
    """
    print(f"Generated SQL: {sql_query}")
    print(f"Result: {result}")
    feedback = ask("Is this correct? (yes/no): ")
    # Tolerate stray whitespace and the common short form "y".
    return 1 if feedback.strip().lower() in ("yes", "y") else 0

def update_ebm(feedback, question, sql_query):
    """Report the feedback that a real EBM update would consume.

    This is a simulation only. ``question`` and ``sql_query`` are unused for
    now but are part of the interface a real update would need.
    """
    status_line = "Updating EBM with feedback: {}".format(feedback)
    print(status_line)
    # TODO: replace this print with an actual EBM parameter update.


# --- 5. Main Execution ---

customer_question = "What is the order_status of the order with order_id 1?"  # Corrected question

# 6. Use the Mistral T2SQL LLM for preprocessing
prompt = f"""
### sqlite
SELECT * FROM customers;
SELECT * FROM orders;
### {customer_question}
"""

inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)
# generate() returns prompt + continuation for a causal LM, so decode only the
# new tokens. The echoed prompt already contains "orders", "order_status" and
# "order_id", so it would pass the EBM check and be sent to SQLite as "SQL".
prompt_length = inputs["input_ids"].shape[-1]
preprocessed_question = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)

print("Preprocessed Question:", preprocessed_question)

# 7. Generate SQL with the EBM
sql_query = generate_sql_with_ebm(customer_question, preprocessed_question[0].strip())
print("Generated SQL:", sql_query)

# 8. Execute the query
try:
    if "?" in sql_query:
        cursor.execute(sql_query, (1,))  # order_id is declared INTEGER
    else:
        cursor.execute(sql_query)
    result = cursor.fetchone()
    print("Query Result:", result)
except (sqlite3.Error, sqlite3.Warning) as e:
    # Model-generated SQL can fail in several ways (syntax errors, multiple
    # statements, wrong number of bindings), not only with OperationalError.
    print(f"Error executing query: {e}")
    result = None

# 9. Get human feedback (simulated)
feedback = get_human_feedback(sql_query, result)

# 10. Update the EBM based on feedback (simulated)
update_ebm(feedback, customer_question, sql_query)

# 11. Generate an explanation using the Mistral explainer LLM
prompt = f"""Explain this SQL query to a customer: {sql_query}"""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = explainer_model.generate(**inputs, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id)
# Again skip the echoed prompt so only the explanation itself is printed.
explanation = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print("Explanation:", explanation[0])

# Close the database connection so the handle is not leaked.
conn.close()