In [1]:
from datasets import load_dataset
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def datasetTojsonl(file_path:str, dataset_path:str, type:str = "train"):
    import json 
    '''
    file_path - 저장할 파일위치와 이름
    dataset_path - huggingface repo
    type - choose one of thing what is train and test
    '''
    dataset = load_dataset(dataset_path)

    # DataFrame으로 변환하는 이유는 dataset에서 for문을 실행하면 매우 오래걸림
    df = pd.DataFrame(dataset[f'{type}'])
    df = df[['sql_prompt', 'sql_context', 'sql']]
    # df = df[(df['domain'] == 'food service') & (df['sql_complexity'] == 'basic SQL')]
    
    with open(file_path, "w") as f:
        for row in df.iterrows():
            newitem = {
                "input": row[1]['sql_prompt'],
                "context": row[1]['sql_context'],
                "output": row[1]['sql'], 
            }
            f.write(json.dumps(newitem) + "\n")

In [3]:
from math import ceil
from pathlib import Path
import json

from math import ceil

def load_jsonl(data_dir):
    data_path = Path(data_dir).as_posix()
    data = load_dataset("json", data_files=data_path)
    return data

def get_train_val_splits(
    data_dir: str,
    val_ratio: float = 0.1,
    seed: int = 42,
    shuffle: bool = True,
):
    data = load_jsonl(data_dir)
    num_samples = len(data["train"])
    val_set_size = ceil(val_ratio * num_samples)

    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=shuffle, seed=seed
    )
    return train_val["train"].shuffle(), train_val["test"].shuffle()

def save_jsonl(data_dicts, out_path):
    with open(out_path, "w") as fp:
        for data_dict in data_dicts:
            fp.write(json.dumps(data_dict) + "\n")

In [4]:
datasetTojsonl("./data/hf_train.json", "gretelai/synthetic_text_to_sql", "train")

In [5]:
raw_train_data, raw_val_data = get_train_val_splits(data_dir="./data/hf_train.json")

Generating train split: 100000 examples [00:00, 792847.12 examples/s]


In [6]:
save_jsonl(raw_train_data, "./data/train_hf_raw.jsonl")
save_jsonl(raw_val_data, "./data/val_hf_raw.jsonl")

In [7]:
raw_train_data[2]

{'input': 'What is the average age of players who play VR games and their total revenue?',
 'context': "CREATE TABLE Players (PlayerID INT, Age INT, Gender VARCHAR(10), PlayVR INT, TotalRevenue INT); INSERT INTO Players (PlayerID, Age, Gender, PlayVR, TotalRevenue) VALUES (1, 30, 'Female', 1, 5000); INSERT INTO Players (PlayerID, Age, Gender, PlayVR, TotalRevenue) VALUES (2, 25, 'Male', 0, 4000); INSERT INTO Players (PlayerID, Age, Gender, PlayVR, TotalRevenue) VALUES (3, 35, 'Non-binary', 1, 6000); INSERT INTO Players (PlayerID, Age, Gender, PlayVR, TotalRevenue) VALUES (4, 28, 'Male', 1, 7000); INSERT INTO Players (PlayerID, Age, Gender, PlayVR, TotalRevenue) VALUES (5, 40, 'Female', 0, 8000);",
 'output': 'SELECT AVG(Players.Age), SUM(Players.TotalRevenue) FROM Players WHERE Players.PlayVR = 1;'}

In [8]:
### Format is similar to the nous-hermes LLMs

text_to_sql_tmpl_str = """\
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n### Instruction:{system_message}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>{user_message}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n### Response:\nGiven the database schema, here is the SQL query that answers\n{response}<|eot_id|>"""

text_to_sql_inference_tmpl_str = """\
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n### Instruction:{system_message}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>{user_message}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n### Response:\nGiven the database schema, here is the SQL query that answers\n"""

### Alternative Format
### Recommended by gradient.ai docs, but empirically we found worse results here

# text_to_sql_tmpl_str = """\
# <s>[INST] SYS\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] {response} </s>"""

# text_to_sql_inference_tmpl_str = """\
# <s>[INST] SYS\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] """


def _generate_prompt_sql(input, context, output=""):
    system_message = f"""
- Generate a SQL query to answer [QUESTION]{input}[/QUESTION]
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity"""
    user_message = f"""
### Input:
[QUESTION]{input}[/QUESTION]


### Context:
This query will run on a database whose schema is represented in this string:
{context}
"""
    if output:
        return text_to_sql_tmpl_str.format(
            system_message=system_message,
            user_message=user_message,
            response=output,
        )
    else:
        return text_to_sql_inference_tmpl_str.format(
            system_message=system_message, user_message=user_message
        )


def generate_prompt(data_point):    
    full_prompt = _generate_prompt_sql(
        data_point["input"],
        data_point['context'],
        output=data_point["output"],
    )
  
    return {"inputs": full_prompt}

In [9]:
# <s>[INST] <<SYS>> <</SYS>> [/INST]</s>

In [10]:
train_data = []
for d in raw_train_data.map(generate_prompt):
    t_res = {"inputs": d["inputs"]}
    train_data.append(t_res)
save_jsonl(train_data, "./data/train_hf_last.jsonl")

val_data = []
for j in raw_val_data.map(generate_prompt):
    v_res = {"inputs": j["inputs"]}
    val_data.append(v_res)
save_jsonl(val_data, "./data/val_hf_last.jsonl")


Map: 100%|██████████| 90000/90000 [00:04<00:00, 18007.02 examples/s]
Map: 100%|██████████| 10000/10000 [00:00<00:00, 17745.21 examples/s]


In [56]:
sample = '''
 CREATE TABLE emp (
  EMPNO INT NOT NULL, -- number of employee
  ENAME VARCHAR(10) NULL, -- Name of the employee
  JOB VARCHAR(9) NOT NULL,  -- The employee's job
  MGR DECIMAL(4,0) NULL,  -- direct supervisor's employee number
  HIREDATE DATE NULL,  -- Employee joining date
  SAL DECIMAL(7,2) NULL, -- The employee's monthly salary
  COMM DECIMAL(7,2) NULL, -- Commissions		
  DEPTNO INT NULL, -- Department number
  PRIMARY KEY (EMPNO, DEPTNO)
 );

CREATE TABLE dept (
  DEPTNO INT , -- Department ID
  DNAME VARCHAR(14) NOT NULL, -- The Department's Name
  LOC VARCHAR(13) NOT NULL,  -- The Department's Location
  PRIMARY KEY (DEPTNO)
 );
 
CREATE TABLE athletics(
  id INT NOT NULL PRIMARY KEY, -- ID of the Athletics
  name VARCHAR(100) NOT NULL -- The Athletics name
);

CREATE TABLE events(
  id int NOT NULL PRIMARY KEY, -- ID of the Event
  sport varchar(50) NOT NULL, -- Name of the Sport
  event varchar(100) NOT NULL -- Name of the Event
);

CREATE TABLE teams(
  id INT NOT NULL PRIMARY KEY, -- ID of the Team Nation
  team VARCHAR(10) NOT NULL -- Name of the Team Nation
);

CREATE TABLE olympic_games(
  id INT NOT NULL PRIMARY KEY, -- ID of the Olympic games
  year INT NOT NULL, -- Olympic games Year
  season VARCHAR(10) NOT NULL, -- Olympic Season(Summer or Winter)
  city VARCHAR(50) NOT NULL -- Olympic host city
);

CREATE TABLE records(
  id INT NOT NULL PRIMARY KEY, -- ID of the Records
  athlete_id INT NOT NULL, -- ID of the Athlete
  sex VARCHAR(5), -- Player's sex(M or F)
  age INT NULL, -- Player's age at time of participation
  weight DECIMAL(5,1) NULL, -- Player's weight at time of participation
  height DECIMAL(5,1) NULL, -- Player's height at time of participation
  game_id INT NOT NULL, -- ID of the Olympic games
  team_id INT NOT NULL, -- ID of the Team Nation
  event_id INT NOT NULL, -- ID of the Event
  medal VARCHAR(10) NULL -- Medal(Gold, Silver, Bronze),
  FOREIGN KEY(athlete_id) REFERENCES athletics(id),
  FOREIGN KEY(game_id) REFERENCES olympic_games(id),
  FOREIGN KEY(team_id) REFERENCES teams(id),
  FOREIGN KEY(event_id) REFERENCES events(id),
);

-- emp.DEPTNO can be joined with dept.DEPTNO
-- athletes.id can be joined with records.athlete_id
-- olympic_games.id can be joined with records.game_id
-- teams.id can be joined with records.team_id
-- events.id can be joined with records.event_id
'''

In [57]:
classicmodels = '''
CREATE TABLE customers (
  customerNumber INT NOT NULL, -- number of customer
  customerName VARCHAR(50) NOT NULL, -- Name of the customer
  contactLastName VARCHAR(50) NOT NULL,  -- Last name of the employee
  contactFirstName VARCHAR(50) NOT NULL,  -- First name of the employee
  phone VARCHAR(50) NOT NULL,  -- Phone Number of the customer
  addressLine1 VARCHAR(50) NOT NULL, -- First AddressLine of the customer
  addressLine2 VARCHAR(50) NOT NULL, -- Second AddressLine of the customer
  city VARCHAR(50) NOT NULL, -- City where customers live
  state VARCHAR(50) NULL, -- state where customers live
  PRIMARY KEY(customerNumber)
 );

CREATE TABLE offices (
  officeCode VARCHAR(10) NOT NULL, -- officeCode 
  city VARCHAR(50) NOT NULL, -- The city where the office is located
  phone VARCHAR(50) NOT NULL,  -- The office phone number
  addressLine1 VARCHAR(50) NOT NULL,  -- Office's first address line
  addressLine2 VARCHAR(50) NULL,  -- Office's second address line
  state VARCHAR(50) NULL, -- The state where the office is located
  country VARCHAR(50) NULL, -- Country where the office is located
  postalCode VARCHAR(15) NOT NULL, -- PostalCode of office
  territory VARCHAR(10) NOT NULL, -- the state of territory
  PRIMARY KEY(officeCode)
  );

CREATE TABLE orderdetails (
  orderNumber INT NOT NULL, -- Order number
  productCode VARCHAR(50) NOT NULL, -- Product code number
  quantityOrdered VARCHAR(50) NOT NULL,  -- Order quantity of product
  priceEach DECIMAL(10,2) NOT NULL,  -- price per order
  orderLineNumber SMALLINT NOT NULL,  -- ?
  PRIMARY KEY(orderNumber, productCode)
);

CREATE TABLE orders (
  orderNumber INT NOT NULL, -- Order number
  orderDate DATE NOT NULL, -- Product order date
  requiredDate DATE NOT NULL,  -- The day the customer wants to receive the product they want
  shippedDate DATE NULL,  -- Date the product was delivered
  status VARCHAR(15) NOT NULL,  -- delivery status
  comments TEXT NULL,  -- Customer comment(requests)
  customerNumber INT NOT NULL,  -- number of customer
  PRIMARY KEY(orderNumber),
  FOREIGN KEY(customerNumber) REFERENCES customers(customerNumber)
);

CREATE TABLE payments (
  customerNumber INT NOT NULL, -- number of customer
  checkNumber VARCHAR(50) NOT NULL, -- payment confirmation number
  paymentDate DATE NOT NULL,  -- The day you paid for the product
  amount DECIMAL(10,2) NOT NULL,  -- Number of products ordered
  PRIMARY KEY(customerNumber, checkNumber),
  FOREIGN KEY(customerNumber) REFERENCES customers(customerNumber)
);

CREATE TABLE productlines (
  productLine VARCHAR(50) NOT NULL, -- Category of product
  textDescription VARCHAR(4000) NULL, -- Advertisement text for the product
  htmlDescription MEDIUMTEXT NULL,
  image MEDIUMBLOB NULL,
  PRIMARY KEY(productLine)
);

CREATE TABLE products (
  productCode VARCHAR(15) NOT NULL, -- Code of product
  productName VARCHAR(70) NOT NULL, -- Product name
  productLine VARCHAR(50) NOT NULL,  -- Category of product
  productScale VARCHAR(10) NOT NULL,  -- Measures for product size
  productVendor VARCHAR(50) NOT NULL,  -- A company that sells products
  productDescription TEXT NOT NULL, -- Description of the product
  quantityInStock SMALLINT NOT NULL, -- Quantity of product in stock
  buyPrice DECIMAL(10, 2) NOT NULL, -- price of one product
  MSRP DECIMAL(10, 2) NOT NULL, -- Manufacturer's Suggested Retail Price
  PRIMARY KEY(productCode),
  FOREIGN KEY(productLine) REFERENCES productlines(productLine)
);

CREATE TABLE employees (
  employeeNumber INT NOT NULL, -- number of employee
  lastName VARCHAR(50) NOT NULL, -- Last name of the employee
  firstName VARCHAR(50) NOT NULL,  -- First name of the employee
  extension VARCHAR(50) NOT NULL,  -- extension
  email VARCHAR(100) NOT NULL,  -- Email of the employee
  officeCode VARCHAR(10) NOT NULL, -- OfficeCode of the employee
  reportsTo INT NULL, -- Immediate Superior of the employee
  jobTitle VARCHAR(50) NOT NULL, -- Job position of the employee
  PRIMARY KEY(employeeNumber),
  FOREIGN KEY(officeCode) REFERENCES offices(officeCode),
  FOREIGN KEY(reportsTo) REFERENCES employees (employeeNumber)
);

-- productlines.productLine can be joined with products.productLine
-- orderdetails.customerNumber can be joined with orders.customerNumber
-- customers.customerNumber can be joined with payments.customerNumber
-- customers.customerNumber can be joined with orders.customerNumber
-- orderdetails.orderNumber can be joined with orders.orderNumber
-- offices.officeCode can be joined with employees.officeCode
-- orderdetails.orderNumber can be joined with orders.orderNumber
'''

In [11]:
train_data = []
for d in raw_train_data.map(generate_prompt):
    t_res = {"inputs": d["inputs"]}
    train_data.append(t_res)
save_jsonl(train_data, "./data/train_hf_last.jsonl")

val_data = []
for d in raw_val_data.map(generate_prompt):
    v_res = {"inputs": d["inputs"]}
    val_data.append(v_res)
save_jsonl(val_data, "./data/val_hf_last.jsonl")

Map: 100%|██████████| 90000/90000 [00:02<00:00, 32415.95 examples/s]
Map: 100%|██████████| 10000/10000 [00:00<00:00, 31746.94 examples/s]


### custom dataset

In [17]:
import pandas as pd 
from googletrans import Translator

df = pd.read_csv("./data/query_data.csv")
df = df[['스키마','질문' ,'Google translate', '답변', '결과개수']]
df.dropna(subset=['질문'], axis=0, how='all', inplace=True)
df
# df.reset_index(drop=True).head(3)

Unnamed: 0,스키마,질문,Google translate,답변,결과개수
0,sample,사원 테이블의 모든 레코드를 조회하시오.,Search all records in the employee table.,SELECT * FROM EMP,14.0
2,sample,사원명과 입사일을 조회하시오.,Check the employee name and employment date,"SELECT ENAME,HIREDATE FROM EMP",14.0
3,sample,사원번호와 이름을 조회하시오.,Check the employee number and name.,"SELECT DEPTNO,ENAME FROM EMP",14.0
4,sample,사원테이블에 있는 직책의 목록을 조회하시오.,View the list of positions in the employee table.,SELECT DISTINCT JOB FROM EMP,5.0
5,sample,총 사원수를 구하시오.,Find the total number of employees.,SELECT COUNT(EMPNO) FROM EMP,1.0
...,...,...,...,...,...
86,sample,2010년대에 개최된 올림픽에서 메달 수상을 하지 못한 팀을 조회하는 쿼리를 작성,Write a query to find teams that did not win a...,select distinct(t.team)\nfrom records r left j...,119.0
87,sample,여름에 개최된 올림픽에 참가한 선수들의 BMI 지수 상위 500명의 선수명과 BMI...,Write a query to retrieve the names and BMI in...,"select distinct(a.name),\n(r.weight/(r.height/...",11.0
88,sample,올림픽에 참가한 국가별 메달이 가장 많은 종목을 조회하는 쿼리를 작성\n결과는 참가...,Write a query to find the sport with the most ...,"select t.team,e.sport,count(*) AS Medal_Count\...",83.0
89,sample,"1988년에 개최된 올림픽에서 메달을 수상한 선수명과 종목명, 나이, 키를 조회하는...",Create a query to retrieve the name of the ath...,"select a.name,e.event, r.age, r.height \nfrom ...",11.0


In [50]:
### Format is similar to the nous-hermes LLMs

text_to_sql_tmpl_str = """\
<|begin_of_text|><|start_header_id|>system<|end_header_id|>### Instruction:{system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>{user_message}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>### Response:\nGiven the database schema, here is the SQL query that answers\n{response}<|eot_id|>"""

text_to_sql_inference_tmpl_str = """\
<|begin_of_text|><|start_header_id|>system<|end_header_id|>### Instruction:{system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>{user_message}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>### Response:\nGiven the database schema, here is the SQL query that answers\n"""

### Alternative Format
### Recommended by gradient.ai docs, but empirically we found worse results here

# text_to_sql_tmpl_str = """\
# <s>[INST] SYS\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] {response} </s>"""

# text_to_sql_inference_tmpl_str = """\
# <s>[INST] SYS\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] """


def _generate_prompt_sql(input, context, output=""):
    system_message = f"""
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity"""
    user_message = f"""

### Input:
{input}


### Context:
This query will run on a database whose schema is represented in this string:
{context}

"""
    if output:
        return text_to_sql_tmpl_str.format(
            system_message=system_message,
            user_message=user_message,
            response=output,
        )
    else:
        return text_to_sql_inference_tmpl_str.format(
            system_message=system_message, user_message=user_message
        )


def generate_prompt(data_point):
    if data_point['schema'] == 'sample':
        full_prompt = _generate_prompt_sql(
            data_point["input"],
            sample,
            output=data_point["output"],
        )
    elif data_point['schema'] == 'classicmodels':
        full_prompt = _generate_prompt_sql(
            data_point["input"],
            classicmodels,
            output=data_point["output"],
        )
    return {"inputs": full_prompt}

In [26]:
def datasetTojsonl(csv_path:str, save_path:str):
    import json 
    '''
    file_path - 저장할 파일위치와 이름
    dataset_path - huggingface repo
    type - choose one of thing what is train and test
    '''

    # DataFrame으로 변환하는 이유는 dataset에서 for문을 실행하면 매우 오래걸림
    df = pd.read_csv(csv_path)
    df = df[['스키마', '질문', 'Google translate', '답변']]
    df.dropna(subset=['질문'], axis=0, how='all', inplace=True)
    # df = df[(df['domain'] == 'food service') & (df['sql_complexity'] == 'basic SQL')]
    
    with open(save_path, "w") as f:
        for row in df.iterrows():
            newitem = {
                "schema": row[1]['스키마'],
                "input": row[1]['Google translate'],
                "output": row[1]['답변'], 
            }
            f.write(json.dumps(newitem) + "\n")

In [27]:
datasetTojsonl("./data/query_data.csv", "./data/query_data.json")

In [None]:
from math import ceil
from pathlib import Path
import json

from math import ceil

def load_jsonl(data_dir):
    data_path = Path(data_dir).as_posix()
    data = load_dataset("json", data_files=data_path)
    return data

def get_train_val_splits(
    data_dir: str,
    val_ratio: float = 0.1,
    seed: int = 42,
    shuffle: bool = True,
):
    data = load_jsonl(data_dir)
    num_samples = len(data["train"])
    val_set_size = ceil(val_ratio * num_samples)

    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=shuffle, seed=seed
    )
    return train_val["train"].shuffle(), train_val["test"].shuffle()

def save_jsonl(data_dicts, out_path):
    with open(out_path, "w") as fp:
        for data_dict in data_dicts:
            fp.write(json.dumps(data_dict) + "\n")

In [29]:
from math import ceil
from pathlib import Path
import json

from math import ceil

def load_jsonl(data_dir):
    data_path = Path(data_dir).as_posix()
    data = load_dataset("json", data_files=data_path)
    return data

def save_jsonl(data_dicts, out_path):
    with open(out_path, "w") as fp:
        for data_dict in data_dicts:
            fp.write(json.dumps(data_dict) + "\n")

In [33]:
train_0421 = load_jsonl("./data/query_data.json")

Generating train split: 74 examples [00:00, 36805.23 examples/s]


In [43]:
train_0421['train']['input'][2]

'Check the employee number and name.'

In [34]:
save_jsonl(train_0421, "./data/train_0421_raw.jsonl")

In [58]:
train_data = []
for d in train_0421['train'].map(generate_prompt):
    print(d)
    t_res = {"inputs": d["inputs"]}
    train_data.append(t_res)
save_jsonl(train_data, "./data/train_0421_query.jsonl")


Map: 100%|██████████| 74/74 [00:00<00:00, 20898.09 examples/s]

{'schema': 'sample', 'input': 'Search all records in the employee table.', 'output': 'SELECT *  FROM EMP', 'inputs': "<|begin_of_text|><|start_header_id|>system<|end_header_id|>### Instruction:\n- If you cannot answer the question with the available database schema, return 'I do not know'\n- Remember that revenue is price multiplied by quantity\n- Remember that cost is supply_price multiplied by quantity<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n### Input:\nSearch all records in the employee table.\n\n\n### Context:\nThis query will run on a database whose schema is represented in this string:\n\n CREATE TABLE emp (\n  EMPNO INT NOT NULL, -- number of employee\n  ENAME VARCHAR(10) NULL, -- Name of the employee\n  JOB VARCHAR(9) NOT NULL,  -- The employee's job\n  MGR DECIMAL(4,0) NULL,  -- direct supervisor's employee number\n  HIREDATE DATE NULL,  -- Employee joining date\n  SAL DECIMAL(7,2) NULL, -- The employee's monthly salary\n  COMM DECIMAL(7,2) NULL, -- Commissions\t\


