In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

tokenizer = AutoTokenizer.from_pretrained("leotod/gemma-270m-Text2SQL-Fine-tuned")
model = AutoModelForCausalLM.from_pretrained("leotod/gemma-270m-Text2SQL-Fine-tuned")

In [6]:
from datasets import load_dataset

dataset = load_dataset("gretelai/synthetic_text_to_sql")

In [7]:
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
  return {
    "messages": [
      # {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }

dataset["test"] = dataset["test"].map(create_conversation)

In [None]:
import re

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
rand_idx = 3
test_sample = dataset["test"][rand_idx]

stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
print("test sample: ")
print(test_sample["messages"][:1])
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:1], tokenize=False, add_generation_prompt=True)

# Generate SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())


Device set to use mps:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


test sample: 
[{'content': "Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n\n<SCHEMA>\nCREATE TABLE sales (id INT, location VARCHAR(20), quantity INT, price DECIMAL(5,2)); INSERT INTO sales (id, location, quantity, price) VALUES (1, 'Northeast', 50, 12.99), (2, 'Midwest', 75, 19.99), (3, 'West', 120, 14.49);\n</SCHEMA>\n\n<USER_QUERY>\nWhat is the maximum quantity of seafood sold in a single transaction?\n</USER_QUERY>\n", 'role': 'user'}]


In [None]:
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")

Original Answer:
SELECT MAX(quantity) FROM sales;


In [None]:
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")

Generated Answer:
SELECT MAX(quantity) FROM sales WHERE price > 1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000


In [None]:
outputs[0]["generated_text"], outputs[0]

("<|user|>\nGiven the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n\n<SCHEMA>\nCREATE TABLE sales (id INT, location VARCHAR(20), quantity INT, price DECIMAL(5,2)); INSERT INTO sales (id, location, quantity, price) VALUES (1, 'Northeast', 50, 12.99), (2, 'Midwest', 75, 19.99), (3, 'West', 120, 14.49);\n</SCHEMA>\n\n<USER_QUERY>\nWhat is the maximum quantity of seafood sold in a single transaction?\n</USER_QUERY>\n\n<|assistant|>\nSELECT MAX(quantity) FROM sales WHERE price > 1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
 {'generated_text': "<|user|>\nGiven the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considerin