In [1]:
import torch
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig, PeftModel, get_peft_config, get_peft_model
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import pandas as pd
import re

In [2]:
# Call model/tokenizer
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    device_map='auto', 
    torch_dtype=torch.bfloat16)

tokenizer.pad_token = tokenizer.eos_token
# print(base_model.hf_device_map)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [17]:
# call data
with open("/home/broodling/finQA/tr_post_table_0808.json", "r") as f1:
  posts = json.load(f1)

with open("/home/broodling/finQA/tr_pre_table_0808.json", "r") as f2:
  pres = json.load(f2)

with open("/home/broodling/finQA/subtable_extract_0817.json", "r") as f3:
  tables = json.load(f3)

with open("/home/broodling/finQA/datasets/FinQA/dataset/train.json", "r") as q:
  datas = json.load(q)

questions = []
for ques in datas:
  questions.append(ques['qa']['question'])

print(len(posts), len(pres), len(tables), len(questions))

6251 6251 6251 6251


In [4]:
example = """|<cell>| <col>  restricted stock and restricted stock units at beginning of year number of shares ( in thousands ) </col> <val> 407 </val> |</cell>|
 |<cell>| <col>  granted number of shares ( in thousands ) </col> <val> 607 </val> |</cell>|
 |<cell>| <col>  vested number of shares ( in thousands ) </col> <val> -134 ( 134 ) </val> |</cell>|
 |<cell>| <col>  forfeited number of shares ( in thousands ) </col> <val> -9 ( 9 ) </val> |</cell>|
 |<cell>| <col>  restricted stock and restricted stock units at end of year number of shares ( in thousands ) </col> <val> 871 </val> |</cell>|
 |<cell>| <col>  granted weighted average grant date fair value ( per share ) </col> <val> 18.13 </val> |</cell>|
 |<cell>| <col>  vested weighted average grant date fair value ( per share ) </col> <val> 10.88 </val> |</cell>|
 |<cell>| <col>  forfeited weighted average grant date fair value ( per share ) </col> <val> 13.72 </val> |</cell>|
 |<cell>| <col>  restricted stock and restricted stock units at end of year weighted average grant date fair value ( per share ) </col> <val> $ 15.76 </val> |</cell"""

In [29]:
# make schema code
def table_schema(table):
  schema = "CREATE TABLE TableD (\n"

  # 정규식을 통해 col/val 추출 및 val 정규화
  pattern = re.compile(r"<col>\s*(.*?)\s*</col>\s*<val>\s*(.*?)\s*</val>")
  matches = pattern.findall(table)
  # number_pattern = re.compile(r"^-?\d{1,3}(,\d{3})*(\.\d+)?$")
  number_pattern = re.compile(r"^-?\d+(\.\d+)?$")

  for col, val in matches:
    col_name = col.strip().replace(" ", "_").replace("-", "_").lower()
    val_clean = val.strip().replace(",", "")

    # (음수) 예외처리
    if("( " in val_clean):
       val_clean = val_clean.split("( ")[0]
       val_clean = val_clean.rstrip(" ")

    # $ 예외처리
    if("$ " in val_clean):
       val_clean = val_clean.lstrip("$ ")

    # data type 확인
    if number_pattern.match(val_clean):
        if "." in val_clean:
            col_type = "FLOAT"
        else:
            col_type = "INT"
    else:
        col_type = "VARCHAR(255)"

    schema += f"    {col_name} {col_type},\n"
  
  schema = schema.rstrip(",\n") + "\n);" # print(schema)
  
  return schema


table_schema(example)

'CREATE TABLE TableD (\n    restricted_stock_and_restricted_stock_units_at_beginning_of_year_number_of_shares_(_in_thousands_) INT,\n    granted_number_of_shares_(_in_thousands_) INT,\n    vested_number_of_shares_(_in_thousands_) INT,\n    forfeited_number_of_shares_(_in_thousands_) INT,\n    restricted_stock_and_restricted_stock_units_at_end_of_year_number_of_shares_(_in_thousands_) INT,\n    granted_weighted_average_grant_date_fair_value_(_per_share_) FLOAT,\n    vested_weighted_average_grant_date_fair_value_(_per_share_) FLOAT,\n    forfeited_weighted_average_grant_date_fair_value_(_per_share_) FLOAT,\n    restricted_stock_and_restricted_stock_units_at_end_of_year_weighted_average_grant_date_fair_value_(_per_share_) FLOAT\n);'

In [30]:
# pre/post text schema
def text_schema(text, pos):
  schema = f"CREATE TABLE {pos} (\n"

  pattern = re.compile(r"<col>\s*(.*?)\s*</col>\s*<val>\s*(.*?)\s*</val>")
  matches = pattern.findall(text)
  # number_pattern = re.compile(r"^-?\d{1,3}(,\d{3})*(\.\d+)?$")
  number_pattern = re.compile(r"^-?\d+(\.\d+)?$")
  
  for col, val in matches:
    col_name = col.strip().replace(" ", "_").replace("-", "_").lower()
    val_clean = val.strip().replace(",", "")

    # data type 확인
    if number_pattern.match(val_clean):
        if "." in val_clean:
            col_type = "FLOAT"
        else:
            col_type = "INT"
    else:
        col_type = "VARCHAR(255)"

    schema += f"    {col_name} {col_type},\n"
  
  schema = schema.rstrip(",\n") + "\n);" # print(schema)
  
  return schema

In [31]:
# prompt engineering
sys_prompt = """Given the following SQL schema tables and a natual language question, write a syntactically correct SQL query to run. Answer ONLY SQL query format, not specific explanation.
Use the following format:
Question: "Question"
SQLQuery: "SQL Query to run"
"""

template = """Given below SQL tables and a natual language question, generate a corresponding SQL query. \n{db}

Question: {ques}
SQLQuery: """

messages =[
  {"role": "system", "content": sys_prompt},
]

In [39]:
# Run text2SQL generation (Generate SQL Query) => posts, pres, tables, questions
total_len = len(tables)
SQLs = []
schemas = []

for idx in range(0,total_len):
  table_db = table_schema(tables[idx])
  pre_db = text_schema(pres[idx], "PreD")
  post_db = text_schema(posts[idx], "PostD")

  total_db = pre_db + "\n" + table_db + "\n" + post_db
  schemas.append(total_db)
  user_prompt = template.format(db=total_db, ques=questions[idx])
  dic = {"role": "user", "content": user_prompt}
  messages.append(dic)
  # print(dic["content"])

  input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
  input_ids = input_ids.to(base_model.device) 

  output = base_model.generate(input_ids=input_ids,
                               max_length = 4096,
                               temperature=0.2,
                               pad_token_id = tokenizer.eos_token_id)[0]
  
  response = tokenizer.decode(output)
  res = response.split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")[1]
  res = res.rstrip("<|eot_id|>")
  res = res.lstrip("\n")
  print(res)

  SQLs.append(res)
  del messages[-1]


## save result
with open("text2SQL_0819.json", "w") as sql:
  json.dump(SQLs, sql)

with open("schemas_0819.json", "w") as sch:
  json.dump(schemas, sch)

Question: what is the the interest expense in 2009?
SQLQuery: SELECT fair_value_of_forward_exchange_contracts_asset_(_liability_)_october_31_2009 FROM TableD
SELECT CASE 
       WHEN (SUM(granted_weighted_average_grant_date_fair_value_(_per_share_)) - SUM(forfeited_weighted_average_grant_date_fair_value_(_per_share_))) > 
       (SUM(granted_number_of_shares_(_in_thousands_)) * vested_weighted_average_grant_date_fair_value_(_per_share_))
       THEN 'Yes'
       ELSE 'No'
       END AS answer
FROM TableD
SELECT year_2018_aircraft_fuelexpense * 1000000 FROM TableD
Question: what percentage of total cash and investments as of dec. 29 2012 was comprised of available-for-sale investments?
SQLQuery: SELECT CAST(((_in_millions_)_available_for_sale_investments_dec_292012 * 1.0) / ((_in_millions_)_total_cash_and_investments_dec_292012 * 1.0) * 100 AS FLOAT) FROM TableD
SELECT (TableD.2008_net_revenue_amount_(_in_millions_) - TableD.2007_net_revenue_amount_(_in_millions_)) / TableD.2007_net_rev

KeyboardInterrupt: 