# Testing similarity between table columns and a natural-language question

### Compute similarities between the question and each table column

In [1]:
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from sentence_transformers import SentenceTransformer, util

# Fetch the NLTK tokenizer data and the stopword list used by compute_similarity().
for nltk_resource in ("punkt", "stopwords"):
    nltk.download(nltk_resource)

def remove_special_characters(text):
    """Normalize a column name into single-space-separated words.

    Every run of characters that is not a letter or digit (punctuation,
    underscores, whitespace, symbols) becomes one space, and leading/trailing
    whitespace is dropped, e.g. ``"order_date (UTC)"`` -> ``"order date UTC"``.

    Letters and digits are matched Unicode-aware, so accented or non-Latin
    names (``"Préço"``, ``"名前"``) are kept instead of being erased. For
    ASCII input the result is the same as an ``[^a-zA-Z0-9]`` cleanup.
    """
    # [\W_] = any non-word character, plus '_' (which \w would otherwise keep).
    return ' '.join(re.sub(r'[\W_]+', ' ', text).split())

# Loaded SentenceTransformer models keyed by model name, so repeated
# compute_similarity() calls do not reload the weights every time.
_SENTENCE_MODEL_CACHE = {}


def _load_sentence_model(model_name):
    """Return a SentenceTransformer for `model_name`, loading it at most once."""
    model = _SENTENCE_MODEL_CACHE.get(model_name)
    if model is None:
        model = SentenceTransformer(model_name)
        _SENTENCE_MODEL_CACHE[model_name] = model
    return model


def compute_similarity(question, table_info, model_name='all-MiniLM-L12-v2'):
    """Score every column of every table against a natural-language question.

    The question is lower-cased, tokenized with NLTK and stripped of English
    stopwords. The remaining tokens are joined back into one string and
    embedded once. Each column name is cleaned with
    ``remove_special_characters`` and embedded, and the cosine similarity
    between the two embeddings is recorded.

    Args:
        question: natural-language question.
        table_info: mapping of table name -> iterable of column names.
        model_name: SentenceTransformer checkpoint used for both embeddings.

    Returns:
        List of ``("table.column", score)`` tuples sorted by score, highest
        first. Keys use the original (uncleaned) column name.
    """
    model = _load_sentence_model(model_name)

    stopwords_set = set(stopwords.words("english"))
    # De-duplicate while keeping first-occurrence order. A plain set iterates
    # in string-hash order, which changes between interpreter runs
    # (PYTHONHASHSEED). That made the text fed to the encoder, and therefore
    # every score, non-reproducible.
    question_tokens = [
        token
        for token in dict.fromkeys(word_tokenize(question.lower()))
        if token not in stopwords_set
    ]
    # The question embedding does not depend on the column, so compute it once.
    question_embedding = model.encode(' '.join(question_tokens), convert_to_tensor=True)

    column_similarities = {}
    for table_name, columns in table_info.items():
        for column in columns:
            column_embedding = model.encode(remove_special_characters(column), convert_to_tensor=True)
            score = util.pytorch_cos_sim(question_embedding, column_embedding).item()
            # Key on the original column name so callers can map back to the schema.
            column_similarities[f'{table_name}.{column}'] = score

    return sorted(column_similarities.items(), key=lambda item: item[1], reverse=True)


[nltk_data] Downloading package punkt to /home/namtrinh/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/namtrinh/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [14]:
# Sanity check: similarity between a raw (un-tokenized) question and one column name.
# Use a dedicated name so this cell does not clobber the global `model`, which the
# text-to-SQL cells bind to the seq2seq model used by inference().
sentence_model = SentenceTransformer('all-MiniLM-L12-v2')
similarity_score = util.pytorch_cos_sim(
    sentence_model.encode("what is id with name jui and age less than 25", convert_to_tensor=True),
    sentence_model.encode("name", convert_to_tensor=True),
).item()

print(similarity_score)

0.30940142273902893


### Filter table columns by similarity score

In [5]:
def filter_columns(table_info, similarities, threshold):
    """Keep, per table, only the columns whose similarity reaches `threshold`.

    Args:
        table_info: mapping of table name -> list of column names.
        similarities: iterable of ``("table.column", score)`` pairs, as
            returned by ``compute_similarity``.
        threshold: minimum score (inclusive) for a column to be kept.

    Returns:
        Dict mapping table name -> kept column names, in the order they appear
        in `similarities` (highest score first for ``compute_similarity``
        output). Tables with no kept column are omitted.
    """
    # Materialize once: `similarities` is scanned once per table below, and a
    # generator would otherwise be exhausted after the first table.
    scored = list(similarities)
    filtered_table_info = {}

    for table_name, columns in table_info.items():
        # Match on "table." (with the dot) and confirm the remainder is a real
        # column of this table. A bare startswith(table_name) let
        # "people_age.age" leak into a table named "people", and
        # split('.')[-1] truncated column names that contain a dot.
        prefix = f'{table_name}.'
        known_columns = set(columns)
        kept = []
        for qualified_name, score in scored:
            if qualified_name.startswith(prefix) and score >= threshold:
                column = qualified_name[len(prefix):]
                if column in known_columns:
                    kept.append(column)

        # Only include tables that have at least one matching column.
        if kept:
            filtered_table_info[table_name] = kept

    return filtered_table_info

### Test `compute_similarity` and `filter_columns` on an example schema

In [8]:
# Toy schema: four tables whose column names contain spaces
# (cleaned by remove_special_characters before embedding).
table_info = {
    'product_inventory': ['Product ID', 'Product Name', 'Category', 'Price', 'Quantity Available'],
    'employee_info': ['Employee ID', 'First Name', 'Last Name', 'Department', 'Salary'],
    'customer_orders': ['Order ID', 'Customer Name', 'Product ID', 'Quantity Ordered', 'Order Date'],
    'supplier_info': ['Supplier ID', 'Supplier Name', 'Product ID', 'Price', 'Availability'],
}

question = "What is the price of Product ID 12345?"

# Score every `table.column` pair against the question, highest first.
similarities = compute_similarity(question, table_info)
for qualified_name, score in similarities:
    print(f'{qualified_name}: {score}')

# Minimum similarity for a column to survive filtering.
threshold = 0.5
filtered_table_info = filter_columns(table_info, similarities, threshold)
print('\nFiltered table info: ', filtered_table_info)

product_inventory.Product ID: 0.6935524344444275
customer_orders.Product ID: 0.6935524344444275
supplier_info.Product ID: 0.6935524344444275
product_inventory.Price: 0.601889431476593
supplier_info.Price: 0.601889431476593
customer_orders.Order ID: 0.5161583423614502
supplier_info.Supplier ID: 0.47025567293167114
product_inventory.Product Name: 0.4453805088996887
product_inventory.Quantity Available: 0.39232760667800903
employee_info.Employee ID: 0.36830848455429077
customer_orders.Customer Name: 0.34300094842910767
customer_orders.Quantity Ordered: 0.3100193738937378
supplier_info.Supplier Name: 0.23508597910404205
customer_orders.Order Date: 0.1878075897693634
employee_info.Department: 0.18518131971359253
supplier_info.Availability: 0.18446308374404907
employee_info.First Name: 0.16425180435180664
employee_info.Salary: 0.1637563407421112
employee_info.Last Name: 0.1349141150712967
product_inventory.Category: 0.11724652349948883

Filtered table info:  {'product_inventory': ['Product I

In [10]:
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

TEXT2SQL_CHECKPOINT = "juierror/flan-t5-text2sql-with-schema-v2"
# Tokenizer and seq2seq model come from the same checkpoint; both are read as
# module-level globals by prepare_input() and inference().
tokenizer = AutoTokenizer.from_pretrained(TEXT2SQL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(TEXT2SQL_CHECKPOINT)

In [12]:
def get_prompt(tables, question):
    """Wrap a serialized schema and a question in the text-to-SQL instruction.

    `tables` is the already-formatted schema string (e.g. ``"people(id,name)"``).
    Both arguments are interpolated verbatim.
    """
    template = "convert question and table into SQL query. tables: {}. question: {}"
    return template.format(tables, question)

def prepare_input(question: str, tables: Dict[str, List[str]], max_length: int = 512):
    """Serialize a schema plus question into input ids for the text-to-SQL model.

    Each table is rendered as ``name(col1,col2,...)``, and tables are joined
    with ", ". For example ``{"people": ["id", "name"]}`` becomes
    ``people(id,name)``.

    Args:
        question: natural-language question.
        tables: mapping of table name -> list of column names.
        max_length: maximum prompt length in tokens.

    Returns:
        The ``input_ids`` produced by the global ``tokenizer`` with
        ``return_tensors="pt"``.
    """
    schema = ", ".join(f"{name}({','.join(columns)})" for name, columns in tables.items())
    prompt = get_prompt(schema, question)
    # Explicit truncation=True: rely on it rather than the tokenizer's implicit
    # handling of a bare max_length (which logs a warning).
    return tokenizer(prompt, max_length=max_length, truncation=True, return_tensors="pt").input_ids

def inference(question: str, tables: Dict[str, List[str]], threshold: float = 0.4, verbose: bool = True) -> str:
    """Generate a SQL query for `question` against a similarity-pruned schema.

    Steps:
    - Score every column against the question (``compute_similarity``).
    - Drop columns scoring below `threshold` (``filter_columns``).
    - Build the prompt from the remaining schema (``prepare_input``).
    - Decode the beam-search output of the global seq2seq ``model``.

    Args:
        question: natural-language question.
        tables: mapping of table name -> list of column names.
        threshold: minimum similarity for a column to be kept in the prompt.
        verbose: print the similarity scores and the schema sent to the model.

    Returns:
        The decoded SQL string.

    NOTE(review): the recorded outputs of this cell show the model referencing
    columns (``name``, ``id``) that the filter removed from the prompt. A fixed
    threshold can therefore prune columns the query actually needs.
    """
    similarities = compute_similarity(question, tables)
    filtered_table_info = filter_columns(tables, similarities, threshold=threshold)
    if not filtered_table_info:
        # Nothing cleared the threshold. An empty "tables:" section gives the
        # model no schema to ground on, so fall back to the full schema.
        filtered_table_info = tables

    if verbose:
        print(similarities)
        print('Filtered table info put to the prompt: ', str(filtered_table_info))

    input_data = prepare_input(question=question, tables=filtered_table_info).to(model.device)
    outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)

# Schema with name and age split across two tables.
split_schema = {
    "people_name": ["id", "name"],
    "people_age": ["people_id", "age"],
}
print(inference("how many people with name jui and age less than 25", split_schema))

# Similar question against a single table that holds all three columns.
single_schema = {
    "people_name": ["id", "name", "age"],
}
print(inference("what is id with name jui and age less than 25", single_schema))

[('people_age.age', 0.5091162323951721), ('people_age.people_id', 0.27685049176216125), ('people_name.name', 0.17837558686733246), ('people_name.id', 0.0970044806599617)]
Filtered table info put to the prompt:  {'people_age': ['age']}




SELECT count(*) FROM people_age WHERE name = 'jui' AND age < 25
[('people_name.age', 0.4023008644580841), ('people_name.id', 0.32995232939720154), ('people_name.name', 0.2714892625808716)]
Filtered table info put to the prompt:  {'people_name': ['age']}
SELECT id FROM people_name WHERE name = 'jui' AND age < 25
