In [54]:
import json
import ollama
import random
import sqlparse
from datasets import load_dataset
import re

# ✅ SPIDER-Dataset laden
ds = load_dataset("CM/spider")

def get_random_questions(dataset, num_samples=5):
    """
    Wählt zufällige Fragen aus der SPIDER-Datenbank aus.
    """
    dataset_list = list(dataset['train'])  # Konvertiere Dataset in eine Liste
    return [sample['question'] for sample in random.sample(dataset_list, num_samples)]

def get_database_schema():
    """
    Gibt das Datenbankschema in einem einheitlichen Format zurück.
    """
    return {
        "tables": ["highlow", "film", "studio", "Staff", "roller_coaster", "country", 
                   "tip", "member", "purchase", "author", "writes", "cite", "city",
                   "farm_competition", "customers", "customer_orders", "order_items", "products"],

        "columns": {
            "highlow": ["state_name", "lowest_elevation"],
            "film": ["Studio_ID", "Name", "Gross_in_dollar", "departments", "budget"],
            "studio": ["Studio_ID", "Name"],
            "Staff": ["first_name", "last_name", "date_joined_staff"],
            "roller_coaster": ["Name", "Country_ID"],
            "country": ["Country_ID", "Name", "born_state"],
            "tip": ["text", "month"],
            "member": ["member_id", "level"],
            "purchase": ["purchase_id", "member_id"],
            "author": ["authorid", "authorname"],
            "writes": ["authorid", "paperid"],
            "cite": ["paperid", "citedpaperid"],
            "city": ["City_ID", "Official_Name"],
            "farm_competition": ["Host_city_ID"],
            "customers": ["customer_id", "customer_name"],
            "customer_orders": ["customer_id", "order_id"],
            "order_items": ["order_id", "product_id"],
            "products": ["product_id", "product_name"]
        }
    }

import json
import ollama
import random
import sqlparse
from datasets import load_dataset
import re

# ✅ SPIDER-Dataset laden
ds = load_dataset("CM/spider")

def get_random_questions(dataset, num_samples=5):
    """
    Wählt zufällige Fragen aus der SPIDER-Datenbank aus.
    """
    dataset_list = list(dataset['train'])  # Konvertiere Dataset in eine Liste
    return [sample['question'] for sample in random.sample(dataset_list, num_samples)]

def get_database_schema():
    """
    Gibt das Datenbankschema in einem einheitlichen Format zurück.
    """
    return {
        "tables": ["highlow", "film", "studio", "Staff", "roller_coaster", "country", 
                   "tip", "member", "purchase", "author", "writes", "cite", "city",
                   "farm_competition", "customers", "customer_orders", "order_items", "products"],

        "columns": {
            "highlow": ["state_name", "lowest_elevation"],
            "film": ["Studio_ID", "Name", "Gross_in_dollar", "departments", "budget"],
            "studio": ["Studio_ID", "Name"],
            "Staff": ["first_name", "last_name", "date_joined_staff"],
            "roller_coaster": ["Name", "Country_ID"],
            "country": ["Country_ID", "Name", "born_state"],
            "tip": ["text", "month"],
            "member": ["member_id", "level"],
            "purchase": ["purchase_id", "member_id"],
            "author": ["authorid", "authorname"],
            "writes": ["authorid", "paperid"],
            "cite": ["paperid", "citedpaperid"],
            "city": ["City_ID", "Official_Name"],
            "farm_competition": ["Host_city_ID"],
            "customers": ["customer_id", "customer_name"],
            "customer_orders": ["customer_id", "order_id"],
            "order_items": ["order_id", "product_id"],
            "products": ["product_id", "product_name"]
        }
    }

def extract_keywords(text):
    """
    Extrahiert Schlüsselwörter aus einer Frage mit Ollama.
    """
    few_shot_examples = """Extract keywords from the following questions. Examples:
    
    Question: "How many heads of the departments are older than 56?"
    Keywords: ["heads", "departments", "older", "56"]
    
    Question: "List the name, born state, and age of the heads of departments ordered by age."
    Keywords: ["name", "born state", "age", "heads", "departments", "ordered", "age"]
    
    Question: "What is the average number of employees of the departments where budget is over 1 billion?"
    Keywords: ["average", "number of employees", "departments", "budget", "over", "1 billion"]
    
    Now extract keywords from this new question:
    Question: "{text}"
    Keywords:
    """

    response = ollama.chat(
        model='llama3.1',
        messages=[{"role": "user", "content": few_shot_examples.format(text=text)}]
    )
    
    # 🔹 Falls Antwort JSON enthält, parse es
    try:
        keywords = json.loads(response['message']['content'])
    except json.JSONDecodeError:
        keywords = response['message']['content'].strip().replace("Keywords:", "").split()

    return keywords

def generate_sql_with_ollama(question, keywords, database_schema):
    """
    Nutzt Ollama zur SQL-Generierung und stellt sicher, dass nur gültige Tabellen und Spalten verwendet werden.
    """
    prompt = f"""
    You are an SQL expert. Given the following database schema:
    {json.dumps(database_schema, indent=2)}

    Keywords extracted from the user question: {keywords}

    Generate an optimized SQL query that correctly answers the user question:
    "{question}"

    Ensure:
    - The query is valid for the given schema.
    - Uses appropriate JOINs if needed.
    - Uses correct column names from the schema.
    - Orders results if necessary.

    Return only the SQL query, without any explanation.
    """

    response = ollama.chat(
        model='llama3.1',
        messages=[{"role": "user", "content": prompt}]
    )

    sql_query = response['message']['content'].strip()

    # 🔹 Falls Ollama Erklärungstext enthält, entfernen wir diesen
    sql_query = re.sub(r'.*?```sql', '', sql_query, flags=re.DOTALL).strip()
    sql_query = re.sub(r'```$', '', sql_query, flags=re.DOTALL).strip()

    return sql_query

def validate_sql_query(sql_query, database_schema):
    """
    Prüft, ob die SQL-Abfrage nur existierende Tabellen und Spalten verwendet.
    """
    parsed = sqlparse.parse(sql_query)
    valid = True
    errors = []

    # Tabellen & Spalten aus dem Schema holen
    tables = set(database_schema["tables"])
    columns = {col for table in database_schema["columns"].values() for col in table}

    for statement in parsed:
        for token in statement.tokens:
            token_str = token.value.lower().strip()

            # 🔹 Tabellenprüfung (FROM, JOIN)
            if "from" in token_str or "join" in token_str:
                table_name = token_str.split()[-1]  # Letztes Wort als Tabellenname
                if table_name not in tables:
                    errors.append(f"⚠️ Ungültige Tabelle: {table_name}")
                    valid = False

            # 🔹 Spaltenprüfung (table.column)
            if "." in token_str:
                table_col = token_str.split(".")
                if len(table_col) == 2:
                    table, column = table_col
                    if table not in tables or column not in columns:
                        errors.append(f"⚠️ Ungültige Spalte: {token_str}")
                        valid = False

    return valid, errors

if __name__ == "__main__":
    print("\n📊 **Dynamische Evaluierung auf SPIDER-Dataset mit besserer Keyword-Extraktion & SQL-Validierung:**")

    # ✅ Zufällige Fragen abrufen
    questions = get_random_questions(ds, num_samples=5)

    # ✅ Datenbankschema abrufen
    db_schema = get_database_schema()

    for question in questions:
        print(f"\n📝 **Frage:** {question}")

        # 🔹 1. Keywords extrahieren
        keywords = extract_keywords(question)
        print(f"🔑 **Extrahierte Keywords:** {keywords}")

        # 🔹 2. SQL generieren mit Ollama
        sql_query = generate_sql_with_ollama(question, keywords, db_schema)
        print(f"📌 **Generierte SQL:**\n```sql\n{sql_query}\n```")

        # 🔹 3. SQL-Validierung
        valid, errors = validate_sql_query(sql_query, db_schema)
        if valid:
            print("✅ **SQL-Abfrage ist gültig!**")
        else:
            print(f"❌ **SQL-Fehler gefunden:** {errors}")
            print("🔄 **Generiere neue Abfrage...**")
            sql_query = generate_sql_with_ollama(question, keywords, db_schema)


def generate_sql_with_ollama(question, keywords, database_schema):
    """
    Nutzt Ollama zur SQL-Generierung und stellt sicher, dass nur gültige Tabellen und Spalten verwendet werden.
    """
    prompt = f"""
    You are an SQL expert. Given the following database schema:
    {json.dumps(database_schema, indent=2)}

    Keywords extracted from the user question: {keywords}

    Generate an optimized SQL query that correctly answers the user question:
    "{question}"

    Ensure:
    - The query is valid for the given schema.
    - Uses appropriate JOINs if needed.
    - Uses correct column names from the schema.
    - Orders results if necessary.

    Return only the SQL query, without any explanation.
    """

    response = ollama.chat(
        model='llama3.1',
        messages=[{"role": "user", "content": prompt}]
    )

    sql_query = response['message']['content'].strip()

    # 🔹 Falls Ollama Erklärungstext enthält, entfernen wir diesen
    sql_query = re.sub(r'.*?```sql', '', sql_query, flags=re.DOTALL).strip()
    sql_query = re.sub(r'```$', '', sql_query, flags=re.DOTALL).strip()

    return sql_query

def validate_sql_query(sql_query, database_schema):
    """
    Prüft, ob die SQL-Abfrage nur existierende Tabellen und Spalten verwendet.
    """
    parsed = sqlparse.parse(sql_query)
    valid = True
    errors = []

    # Tabellen & Spalten aus dem Schema holen
    tables = set(database_schema["tables"])
    columns = {col for table in database_schema["columns"].values() for col in table}

    for statement in parsed:
        for token in statement.tokens:
            token_str = token.value.lower().strip()

            # 🔹 Tabellenprüfung (FROM, JOIN)
            if "from" in token_str or "join" in token_str:
                table_name = token_str.split()[-1]  # Letztes Wort als Tabellenname
                if table_name not in tables:
                    errors.append(f"⚠️ Ungültige Tabelle: {table_name}")
                    valid = False

            # 🔹 Spaltenprüfung (table.column)
            if "." in token_str:
                table_col = token_str.split(".")
                if len(table_col) == 2:
                    table, column = table_col
                    if table not in tables or column not in columns:
                        errors.append(f"⚠️ Ungültige Spalte: {token_str}")
                        valid = False

    return valid, errors

if __name__ == "__main__":
    print("\n📊 **Dynamische Evaluierung auf SPIDER-Dataset mit besserer Keyword-Extraktion & SQL-Validierung:**")

    # ✅ Zufällige Fragen abrufen
    questions = get_random_questions(ds, num_samples=5)

    # ✅ Datenbankschema abrufen
    db_schema = get_database_schema()

    for question in questions:
        print(f"\n📝 **Frage:** {question}")

        # 🔹 1. Keywords extrahieren
        keywords = extract_keywords(question)
        print(f"🔑 **Extrahierte Keywords:** {keywords}")

        # 🔹 2. SQL generieren mit Ollama
        sql_query = generate_sql_with_ollama(question, keywords, db_schema)
        print(f"📌 **Generierte SQL:**\n```sql\n{sql_query}\n```")

        # 🔹 3. SQL-Validierung
        valid, errors = validate_sql_query(sql_query, db_schema)
        if valid:
            print("✅ **SQL-Abfrage ist gültig!**")
        else:
            print(f"❌ **SQL-Fehler gefunden:** {errors}")
            print("🔄 **Generiere neue Abfrage...**")
            sql_query = generate_sql_with_ollama(question, keywords, db_schema)


Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]


📊 **Dynamische Evaluierung auf SPIDER-Dataset mit besserer Keyword-Extraktion & SQL-Validierung:**

📝 **Frage:** Give me all the phone numbers and email addresses of the workshop groups where services are performed.
🔑 **Extrahierte Keywords:** ['Here', 'are', 'the', 'extracted', 'keywords:', '["phone', 'numbers",', '"email', 'addresses",', '"workshop', 'groups",', '"services"]']
📌 **Generierte SQL:**
```sql
SELECT phone_numbers, email_addresses 
FROM (
  SELECT 
    workshop_groups.phone_number AS phone_numbers,
    workshop_groups.email_address AS email_addresses
  FROM customers 
  JOIN customer_orders ON customers.customer_id = customer_orders.customer_id
  JOIN order_items ON customer_orders.order_id = order_items.order_id
  JOIN products ON order_items.product_id = products.product_id
  WHERE products.category = 'services'
) t;
```
❌ **SQL-Fehler gefunden:** ['⚠️ Ungültige Tabelle: from', '⚠️ Ungültige Tabelle: t']
🔄 **Generiere neue Abfrage...**

📝 **Frage:** What is the total n