In [1]:
from typing import Dict, List, Tuple, Union

import pandas as pd
from pandasql import sqldf
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pretrained checkpoint identifier shared by the tokenizer and the model.
# Keeping it in a single constant guarantees both are always loaded from the
# same checkpoint and gives readers one place to swap the model.
MODEL_NAME = "juierror/flan-t5-text2sql-with-schema-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


In [3]:
def pysqldf(q, dataframes):
    """Run the SQL string ``q`` through ``sqldf`` and return its result.

    ``dataframes`` is forwarded unchanged as the second positional argument
    of ``sqldf``; the caller in this notebook passes a dict of variables.
    """
    result = sqldf(q, dataframes)
    return result

#df_table = st.session_state.uploaded_data

def get_prompt(tables, question):
    """Build the text prompt asking the model to turn a question into SQL.

    Args:
        tables: Description of the available tables; interpolated as-is.
        question: Natural-language question; interpolated as-is.

    Returns:
        The formatted prompt string.
    """
    template = "convert question and table into SQL query. tables: {tables}. question: {question}"
    return template.format(tables=tables, question=question)

def prepare_input(question: str, tables: Dict[str, List[str]]):
    """Tokenize a question plus table schema into model input ids.

    Each table is rendered as ``name(col1,col2,...)``, the rendered tables are
    joined with ``", "``, and the result is embedded in the prompt produced by
    ``get_prompt``.

    Args:
        question: Natural-language question to translate.
        tables: Mapping of table name to its list of column names.

    Returns:
        The ``input_ids`` produced by the tokenizer, requested as PyTorch
        tensors via ``return_tensors="pt"``.
    """
    # Build the schema text under its own name instead of rebinding the
    # ``tables`` parameter to a list and then a string.
    schema_text = ", ".join(
        f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()
    )
    prompt = get_prompt(schema_text, question)
    # truncation=True makes the 512-token cap explicit. Without it the
    # tokenizer emits a warning and falls back to the same 'longest_first'
    # strategy, so the encoded result is unchanged.
    encoded = tokenizer(prompt, max_length=512, truncation=True, return_tensors="pt")
    return encoded.input_ids

def inference(question: str, tables: Dict[str, List[str]]) -> str:
    """Generate SQL text for ``question`` against the schema in ``tables``.

    The prompt is tokenized by ``prepare_input``, moved to the model's device,
    passed to ``model.generate`` with 10 beams and ``max_length=512``, and the
    first returned sequence is decoded to text with special tokens removed.

    Args:
        question: Natural-language question to translate.
        tables: Mapping of table name to its list of column names.

    Returns:
        The decoded model output.
    """
    prompt_ids = prepare_input(question=question, tables=tables).to(model.device)
    generated = model.generate(inputs=prompt_ids, num_beams=10, max_length=512)
    first_sequence = generated[0]
    return tokenizer.decode(token_ids=first_sequence, skip_special_tokens=True)

def user_query_dataframe(question: str, df_table: pd.DataFrame) -> Tuple[str, Union[pd.DataFrame, str]]:
    """Translate a question into SQL and run it against ``df_table``.

    The dataframe is presented to the model, and to the SQL engine, under the
    table name ``df_table`` with its column labels as the schema.

    Args:
        question: Natural-language question about the data.
        df_table: Dataframe to query.

    Returns:
        ``('table2', result)`` on success, where ``result`` is whatever
        ``pysqldf`` returned for the generated query; ``('0', question)`` if
        generating or executing the query raised an exception.
    """
    try:
        schema = {"df_table": df_table.columns.tolist()}
        query = inference(question, schema)
        print(query)
        # Expose only the dataframe to the SQL engine. Passing locals() would
        # also make unrelated locals (question, query, ...) resolvable as
        # table names.
        result_df = pysqldf(query, {"df_table": df_table})
        return ('table2', result_df)
    except Exception as e:
        # Deliberate best-effort: report the failure and hand the original
        # question back with the '0' marker so the caller can react.
        print(f"Error executing SQL query: {str(e)}")
        return ('0', question)


In [4]:
# Demo: one table, two filter conditions (equality on name, upper bound on age).
print(
    inference(
        question="what is id with name jui and age less than 25",
        tables={"people_name": ["id", "name", "age"]},
    )
)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


SELECT id FROM people_name WHERE name = 'jui' AND age < 25


In [6]:
# Demo: counting query with an equality filter and an age upper bound.
print(
    inference(
        question="how many customers are there with product play_sports and younger than 35",
        tables={"product_table": ["customernumber", "product", "cus_age"]},
    )
)

SELECT count(*) FROM product_table WHERE product = 'play_sports' AND cus_age < 35


In [8]:
# Demo: counting query with three conditions listed in one sentence.
# NOTE(review): the output recorded with this cell is not usable SQL. It
# separates predicates with a comma, compares cus_age to 'young_adults', and
# references an `age` column that is not in the schema.
print(
    inference(
        question="how many customers are there with product play_sports, lifestage young_adults and older than 18",
        tables={"customer_table": ["customernumber", "product", "cus_age", "lifestage"]},
    )
)

SELECT count(*) FROM customer_table WHERE product = 'play_sports', cus_age = 'young_adults' AND age > 18


In [9]:
# Demo: same schema as the previous cell, with the question narrowed to two
# explicitly phrased conditions.
print(
    inference(
        question="how many customers are there with lifestage equal to young_adult and older than 18",
        tables={"customer_table": ["customernumber", "product", "cus_age", "lifestage"]},
    )
)

SELECT count(*) FROM customer_table WHERE lifestage = 'young_adult' AND cus_age > 18
