In [1]:
import torch
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig, PeftModel, get_peft_config, get_peft_model
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import pandas as pd
import re

In [2]:
# Load the Llama 3.1 8B Instruct checkpoint (bf16 weights, device_map='auto')
# together with its tokenizer.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# print(base_model.hf_device_map)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
# Read the three schema JSON files and the FinQA training split from disk.
with open("/home/broodling/finQA/pre_schemas_0928.json", "r") as pre_file:
  pres = json.load(pre_file)

with open("/home/broodling/finQA/table_schemas_0928.json", "r") as table_file:
  tables = json.load(table_file)

with open("/home/broodling/finQA/post_schemas_0928.json", "r") as post_file:
  posts = json.load(post_file)

with open("/home/broodling/finQA/datasets/FinQA/dataset/train.json", "r") as train_file:
  datas = json.load(train_file)

# Keep only the question text of the first 200 training records.
questions = [record['qa']['question'] for record in datas[:200]]

# Report the four list sizes side by side.
print(len(posts), len(pres), len(tables), len(questions))

200 200 200 200


In [10]:
# prompt engineering
# Builds the fixed part of the chat: one system instruction followed by two
# worked (user, assistant) few-shot turns. Per-example user turns are rendered
# from `template` and added after these by the generation cell.
sys_prompt = """Given the following pair of SQL schema and natural language question, generate a correct SQL query to correctly answer the given question. Considering given schema information, especially column names, generate SQL query that can solve given question and syntactically perfect(must be executed without any errors). 
Use the following format and ONLY answer in sql query:
Question: "Question"
SQLQuery: "SQL Query to run"
"""

# User-turn template; `{schema}` and `{ques}` are filled with str.format.
template = """Given the SQL schema and a natural language question, generate a corresponding SQL query.
Schema: {schema}\n
Question: {ques}
SQLQuery: """

## few-shot example (BookSQL)
# The example answer below aggregates `debit`, so the schema shown to the model
# must declare that column; otherwise the demonstration teaches the model to
# reference columns that are absent from the given schema.
sch1="""CREATE TABLE chart_of_accounts(
    id INTEGER ,
    businessID INTEGER NOT NULL,
    Account_name TEXT NOT NULL,
    Account_type TEXT NOT NULL,
);
CREATE TABLE master_txn_table(
    id INTEGER ,
    businessID INTEGER NOT NULL ,
    Transaction_ID INTEGER NOT NULL,
    Transaction_DATE DATE NOT NULL,
    Transaction_TYPE TEXT NOT NULL,
    Amount DOUBLE NOT NULL,
    Debit DOUBLE,
    Account TEXT NOT NULL,
    Due_DATE DATE,             
);"""
user_prompt_1 = template.format(schema=sch1, ques="What acount had our biggest expense This week to date?")
assistant_prompt_1 = "SELECT account, SUM(debit) FROM master_txn_table AS T1 JOIN chart_of_accounts AS T2 ON T1.account = T2.account_name  WHERE account_type IN ('Expense','Other Expense') AND transaction_date BETWEEN date( current_date, \"weekday 0\", \"-7 days\") AND date( current_date) GROUP BY account ORDER BY SUM(debit) DESC LIMIT 1"

# Second few-shot example: a nested NOT IN query over two tables.
sch2 = """CREATE TABLE student (
    student_id INTEGER,
    last_name TEXT,
    first_name TEXT,
    age INTEGER,
    sex TEXT,
    major INTEGER,
    advisor INTEGER, 
    city_code TEXT, 
);
CREATE TABLE has_pet (
    student_id INTEGER,
    pet_id INTEGER,
);"""
user_prompt_2 = template.format(schema=sch2, ques="What is the average age for all students who do not own any pets?")
assistant_prompt_2 = "SELECT avg(age) FROM student WHERE student_id NOT IN (SELECT T1.student_id FROM student AS T1 JOIN has_pet AS T2 ON T1.student_id = T2.student_id)"

# Conversation prefix shared by every generation request.
messages =[
  {"role": "system", "content": sys_prompt},
  {"role": "user", "content": user_prompt_1},
  {"role": "assistant", "content": assistant_prompt_1},
  {"role": "user", "content": user_prompt_2},
  {"role": "assistant", "content": assistant_prompt_2},
]

In [11]:
# Scratch check: walk the three schema lists in lockstep and combine each aligned
# pre/table/post triple with `+`. `total` is overwritten on every pass, so only the
# last combination survives the loop; no later cell shown here reads it.
for pre, table, post in zip(pres, tables, posts):
  total = pre + table + post

In [16]:
# Run text2SQL generation (Generate SQL Query) => posts, pres, tables, questions
# For each example index: combine its pre/table/post schema entries, render the
# user turn from `template`, generate after the few-shot turns in `messages`,
# and collect the generated text in `SQLs`.
SQLs = []
for idx in tqdm(range(0,5)):
  total_db = pres[idx] + tables[idx] + posts[idx]
  user_prompt = template.format(schema=total_db, ques=questions[idx])
  dic = {"role": "user", "content": user_prompt}
  # print(dic["content"])

  # Build a per-example conversation instead of append/del on the shared list,
  # so an exception mid-iteration cannot leave an extra turn in `messages`.
  # add_generation_prompt=True ends the prompt with the assistant header, so
  # everything the model emits is the answer itself.
  input_ids = tokenizer.apply_chat_template(messages + [dic],
                                            add_generation_prompt=True,
                                            return_tensors="pt")
  input_ids = input_ids.to(base_model.device)

  output = base_model.generate(input_ids=input_ids,
                               max_length = 12500,
                               temperature=0.2,
                               pad_token_id = tokenizer.eos_token_id)[0]

  # Decode only the newly generated tokens and let the tokenizer drop special
  # tokens. The previous str.rstrip("<|eot_id|>") treated its argument as a set
  # of characters, so queries ending in any of `<|eot_id>` lost trailing
  # characters (e.g. "...name" -> "...nam"), and the split(...)[3] lookup only
  # worked for exactly two few-shot turns.
  prompt_len = input_ids.shape[-1]
  res = tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
  res = res.strip()
  # print(res)

  SQLs.append(res)

 20%|██        | 1/5 [00:00<00:03,  1.09it/s]

SELECT interest_expense FROM TableData WHERE date = '2009-01-01'


 40%|████      | 2/5 [00:03<00:06,  2.01s/it]

SELECT CASE WHEN SUM(EquityAwards.granted_value) - SUM(EquityAwards.forfeited_value) > PostD.additional_stock_based_compensation_expense THEN 'Yes' ELSE 'No' END 
FROM EquityAwards 
JOIN PostD ON EquityAwards.year = PostD.year 
WHERE EquityAwards.year = 2012


 60%|██████    | 3/5 [00:04<00:03,  1.55s/it]

SELECT total_operating_expenses / 1000000 FROM OperatingExpenses WHERE year = 2018


 80%|████████  | 4/5 [00:06<00:01,  1.83s/it]

SELECT CAST(T2.available_for_sale_investments AS REAL) * 100 / T2.total_cash_and_investments FROM TableData AS T2 JOIN PreD AS T1 ON T2.date = T1.date WHERE T1.date = '2012-12-29'


100%|██████████| 5/5 [00:08<00:00,  1.78s/it]

SELECT (net_revenue - (SELECT net_revenue FROM TableData WHERE year = 2007)) / (SELECT net_revenue FROM TableData WHERE year = 2007) * 100 FROM TableData WHERE year = 2008





In [2]:
import json

# Reload the saved generations and print items 100 through 104 for a spot check.
with open("/home/broodling/finQA/text2sql_0929_200.json") as saved_sqls:
  sqls = json.load(saved_sqls)

for sql in sqls[100:105]:
  print(sql)

SELECT T1.net_change FROM UnrecognizedTaxBenefits AS T1 JOIN Skyworks AS T2 ON T1.year = T2.year WHERE T1.year = 2012 AND T2.year = 2011
SELECT T1.operating_profit_increase FROM OperatingProfit AS T1 JOIN Aeronautics AS T2 ON T1.year = T2.year WHERE T2.year = 2011
SELECT CAST((operating_lease_obligations - LAG(operating_lease_obligations) OVER (ORDER BY year)) / LAG(operating_lease_obligations) OVER (ORDER BY year) * 100 AS DECIMAL(10, 2)) FROM ContractualObligations WHERE year IN (2009, 2010)
SELECT CAST(SUM(CASE WHEN T2.country = 'United States' THEN T1.square_feet ELSE 0 END) AS DECIMAL(10, 2)) * 100 / SUM(T1.square_feet) FROM TableData AS T1 JOIN AppliedLocations AS T2 ON T1.location = T2.location_nam
SELECT total_return FROM StockReturn WHERE index_name = 'Nasdaq Composite' AND initial_investment = 1000000.00 AND start_date BETWEEN '2009-01-01' AND '2009-12-31' AND end_date BETWEEN '2010-01-01' AND '2010-12-31'
