In [4]:
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [5]:
# Hugging Face Hub checkpoint used for both the tokenizer and the model
# (the name suggests a FLAN-T5 text-to-SQL fine-tune -- confirm on the Hub).
# Kept in a single constant so the tokenizer and model can never be loaded
# from different checkpoints by accident.
MODEL_NAME = "juierror/flan-t5-text2sql-with-schema-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


In [21]:
def get_prompt(tables, question):
    """Build the text-to-SQL instruction prompt.

    Args:
        tables: Schema text inserted verbatim after ``tables:`` (for example
            ``"t(a,b), u(c)"``, the format ``prepare_input`` builds).
        question: Natural-language question inserted after ``question:``.

    Returns:
        The full prompt string.
    """
    instruction = "convert question and table into SQL query. "
    return instruction + f"tables: {tables}. question: {question}"

def prepare_input(question: str, tables: Dict[str, List[str]], *, max_length: int = 512):
    """Tokenize a question together with its table schema.

    Each table is rendered as ``name(col1,col2,...)`` and the rendered tables
    are joined with ``", "`` before being inserted into the prompt built by
    ``get_prompt``.

    Args:
        question: Natural-language question to translate into SQL.
        tables: Mapping of table name -> list of column names.
        max_length: Maximum number of prompt tokens kept (default 512).

    Returns:
        The ``input_ids`` tensor produced by the tokenizer with
        ``return_tensors="pt"``.
    """
    # Use a separate name instead of rebinding ``tables`` (dict -> list -> str).
    schema = ", ".join(
        f"{table_name}({','.join(columns)})" for table_name, columns in tables.items()
    )
    prompt = get_prompt(schema, question)
    # Passing max_length without truncation makes transformers warn and fall
    # back to 'longest_first'; request truncation explicitly instead.
    # NOTE(review): with the tokenizer's default right-side truncation, a very
    # large schema pushes the question (which comes last) out of the window.
    input_ids = tokenizer(
        prompt, max_length=max_length, truncation=True, return_tensors="pt"
    ).input_ids
    return input_ids

def inference(
    question: str,
    tables: Dict[str, List[str]],
    *,
    num_beams: int = 10,
    max_length: int = 512,
) -> str:
    """Translate a natural-language question into SQL text with the model.

    Args:
        question: Natural-language question.
        tables: Mapping of table name -> list of column names, used as the
            schema context in the prompt.
        num_beams: Beam width for beam-search decoding (default 10).
        max_length: Maximum length of the generated sequence, in tokens
            (default 512).

    Returns:
        The first generated sequence (the best beam), decoded with special
        tokens stripped.
    """
    input_ids = prepare_input(question=question, tables=tables).to(model.device)
    # Plain beam search (do_sample is not enabled). The previous top_k=10 only
    # applies to sampling, so generate() ignored it apart from emitting a
    # warning; it has been dropped.
    outputs = model.generate(inputs=input_ids, num_beams=num_beams, max_length=max_length)
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)

# Example schemas (table name -> column names) used to exercise the model.
UNIVERSITY_SCHEMA = {
    'Students': ['student_id', 'name', 'birth_year'],
    'Courses': ['course_id', 'name', 'max_enrollments'],
    'Enrollments': ['student_id', 'course_id'],
}
# Shared by both employee questions; defined once so the examples stay in sync.
COMPANY_SCHEMA = {
    'employees': ['employee_id', 'name', 'salary', 'age'],
    'company': ['company_id', 'name', 'city'],
    'work': ['employee_id', 'company_id'],
}

print(inference(
    "List all students who are subscribed in the Database Systems and are born in 2003.",
    UNIVERSITY_SCHEMA,
))

print(inference(
    "List names of employees who have a salary greater than $1300, work in the city of Oporto and are younger than 30 years.",
    COMPANY_SCHEMA,
))

print(inference(
    "List the names of the employees who work in the company named 'Google'.",
    COMPANY_SCHEMA,
))

SELECT T1.name FROM students AS T1 JOIN enrollments AS T2 ON T1.student_id = T2.student_id WHERE T2.course_id = 'Database Systems' AND T1.birth_year = 2003
SELECT t1.name FROM employees AS t1 JOIN work AS t2 ON t1.employee_id = t2.employee_id JOIN company AS t3 ON t2.company_id = t3.company_id WHERE t3.city = 'Oporto' AND t1.age < 30
SELECT T1.name FROM employees AS T1 JOIN work AS T2 ON T1.employee_id = T2.employee_id WHERE T2.name = 'Google'
