In [None]:
import sqlite3
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import transformers

# Step 1: Function to capture user's question
def get_user_input():
    """Prompt on the console and return whatever line the user types."""
    answer = input("Enter your question: ")
    return answer

# Step 2: Extract Database Schema
def get_database_schema(db_path):
    """Read the column layout of every table in a SQLite database.

    Args:
        db_path: Database location handed to ``sqlite3.connect``.

    Returns:
        A dict mapping each table name listed in ``sqlite_master`` to the
        rows returned by ``PRAGMA table_info`` for that table.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried. The
            connection is closed before the error propagates.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]
        schema = {}
        for table_name in table_names:
            # The PRAGMA argument is an identifier, not a bindable value, so
            # quote it (doubling embedded quotes) to support names that
            # contain spaces, punctuation, or reserved words.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            schema[table_name] = cursor.fetchall()
        return schema
    finally:
        # Release the connection even when one of the queries above fails.
        conn.close()

# Step 3: Create Keyword to Schema Mapping
def generate_schema_mapping(db_path):
    """Build a lookup from lower-cased column name to its (table, column) pair.

    The schema is obtained from ``get_database_schema(db_path)``; element 1 of
    each schema row is used as the column name. When the same lower-cased
    column name occurs more than once, the entry seen last replaces the
    earlier one.
    """
    return {
        info[1].lower(): (table_name, info[1].lower())
        for table_name, table_info in get_database_schema(db_path).items()
        for info in table_info
    }

# Step 4: Map user's question to relevant columns and tables
def refined_map_keywords_to_schema(question, schema_mapping):
    """Find the tables and columns whose names are mentioned in a question.

    The question is lower-cased and split on whitespace. Each word is looked
    up in ``schema_mapping`` as written and, if that fails, again with
    surrounding punctuation removed, so that "age?" or "dose," still match
    the keys "age" and "dose".

    Args:
        question: Free-text question.
        schema_mapping: Dict mapping a lower-cased keyword to a
            ``(table, column)`` tuple.

    Returns:
        A ``(tables, columns)`` tuple of lists without duplicates, ordered by
        first appearance in the question so the result is the same on every
        run.
    """
    edge_punctuation = ".,;:!?\"'`()[]{}<>"
    # Dicts are used as insertion-ordered sets to get stable, duplicate-free
    # output.
    tables = {}
    columns = {}
    for word in question.lower().split():
        # Try the raw word first so keys that themselves contain punctuation
        # keep matching exactly as before.
        for candidate in (word, word.strip(edge_punctuation)):
            if candidate in schema_mapping:
                table, column = schema_mapping[candidate]
                tables.setdefault(table)
                columns.setdefault(column)
                break
    return list(tables), list(columns)

# Step 5: Generate a refined prompt for T5 model
def refined_mapping_based_generate_prompt(question, schema_mapping):
    """Return a model prompt that lists the schema items found in the question.

    When ``refined_map_keywords_to_schema`` yields no tables or no columns,
    the question is returned unchanged.
    """
    matched_tables, matched_columns = refined_map_keywords_to_schema(question, schema_mapping)
    if not (matched_tables and matched_columns):
        return question
    table_list = ', '.join(matched_tables)
    column_list = ', '.join(matched_columns)
    return f"Translate the question '{question}' into SQL. Relevant tables: {table_list}. Relevant columns: {column_list}."

# Step 6: Generate SQL using T5 model
def get_sql(enhanced_prompt, tokenizer, model):
    """Generate a SQL string for a prompt with a seq2seq model.

    Args:
        enhanced_prompt: Prompt text; it is prefixed with "English to SQL: "
            and runs of whitespace are collapsed to single spaces.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Model exposing ``generate`` and ``device``.

    Returns:
        The first generated sequence decoded to text, with special tokens
        removed.
    """
    source_text = ' '.join(("English to SQL: " + enhanced_prompt).split())
    # Only `padding="max_length"` is passed: the deprecated
    # `pad_to_max_length` flag is redundant when `padding` is given.
    encoded = tokenizer(
        [source_text],
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )
    # Inputs must live on the same device as the model weights, otherwise
    # generation fails when the model has been moved off the CPU.
    device = model.device
    generated_ids = model.generate(
        input_ids=encoded['input_ids'].to(device=device, dtype=torch.long),
        attention_mask=encoded['attention_mask'].to(device=device, dtype=torch.long),
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True
    )
    # A single prompt was encoded, so only the first sequence is of interest.
    return tokenizer.decode(
        generated_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

# Step 7: Validate and correct the SQL
def refined_validate_sql_against_schema(sql_query, expected_tables, expected_columns):
    """Replace an unexpected table or first selected column in a SQL string.

    The query is split on whitespace. The word after each ``FROM`` is checked
    against ``expected_tables`` and the first item after each ``SELECT``
    (skipping a leading ``DISTINCT``/``ALL``) against ``expected_columns``.
    Comparison is case-insensitive and ignores a trailing ``,`` or ``;``, so
    an identifier that is already valid is left alone. An unknown identifier
    is replaced by the first expected name, keeping its trailing punctuation.

    Args:
        sql_query: SQL text to check.
        expected_tables: Acceptable table names; the first is the fallback.
        expected_columns: Acceptable column names; the first is the fallback.

    Returns:
        The query re-joined with single spaces. Nothing is replaced for a
        clause whose expected list is empty.
    """
    tokens = sql_query.split()
    known_tables = {name.lower() for name in expected_tables}
    known_columns = {name.lower() for name in expected_columns}
    for i, token in enumerate(tokens):
        keyword = token.upper()
        if keyword == "FROM":
            _replace_unknown_identifier(tokens, i + 1, expected_tables, known_tables)
        elif keyword == "SELECT":
            target = i + 1
            # DISTINCT/ALL are modifiers, not select items; leave them alone.
            if target < len(tokens) and tokens[target].upper() in ("DISTINCT", "ALL"):
                target += 1
            _replace_unknown_identifier(tokens, target, expected_columns, known_columns)
    return ' '.join(tokens)


def _replace_unknown_identifier(tokens, index, replacements, known_lower):
    """Swap ``tokens[index]`` for ``replacements[0]`` if it is not a known name."""
    if index >= len(tokens) or not replacements:
        return
    token = tokens[index]
    # Separate trailing list/statement punctuation so it survives replacement.
    core = token.rstrip(",;")
    suffix = token[len(core):]
    if core.lower() not in known_lower:
        tokens[index] = replacements[0] + suffix

# Main Workflow
def enhanced_workflow_with_mapping(question, tokenizer, model, schema_mapping):
    """Run the full question-to-SQL pipeline and return the checked query.

    Builds the prompt, generates SQL with the model, then passes the result
    through ``refined_validate_sql_against_schema`` using the tables and
    columns matched in the question.
    """
    prompt = refined_mapping_based_generate_prompt(question, schema_mapping)
    raw_sql = get_sql(prompt, tokenizer, model)
    tables, columns = refined_map_keywords_to_schema(question, schema_mapping)
    return refined_validate_sql_against_schema(raw_sql, tables, columns)

# Example usage: load the pretrained checkpoint, build the keyword mapping
# from the SQLite file, then translate one question read from the console.
# NOTE: these statements execute at module top level, in this order.
model = T5ForConditionalGeneration.from_pretrained('dsivakumar/text2sql')
tokenizer = T5Tokenizer.from_pretrained('dsivakumar/text2sql')
db_path = "example-covid-vaccinations.sqlite3"  # Replace with your actual database path
# Maps lower-cased column names to (table, column) pairs.
schema_mapping = generate_schema_mapping(db_path)
question = get_user_input()
resulting_sql = enhanced_workflow_with_mapping(question, tokenizer, model, schema_mapping)
# The generated SQL is only printed; it is not executed against the database.
print(resulting_sql)