<a href="https://colab.research.google.com/github/jai-llm/TEXT2SQL/blob/main/0_Text2SQL_Create_Datav3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Text2SQL Data Pre-Processing
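**Step 0:** Process the dataset for training & evaluation: load `b-mc2/sql-create-context`, build instruction prompts, split it into train/test, and save the splits to Google Drive.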

In [1]:
!pip install -q -U datasets

### Imports

In [2]:
import pandas as pd
import json
import torch
import os

# Load Methods from Datasets Library
from datasets import load_dataset, Dataset, load_metric, load_from_disk

### Global Constants

In [3]:
# Hugging Face Hub id of the source Text2SQL dataset (loaded by process_data).
dataset_name = 'b-mc2/sql-create-context'

In [4]:
# --- Storage locations on Google Drive ---
# NOTE(review): assumes Drive is already mounted at /content/drive; no cell in
# this notebook mounts it.
DATA_PATH = "/content/drive/MyDrive/Text2SQL/Data/"  # root folder (keeps trailing slash)
DS_DIR = "sql_train_test"   # sub-folder for the saved train/test dataset
PKL_DIR = "test/"           # sub-folder for the pickled test DataFrame
PKL_FILE = "sql_test.pkl"   # file name of the pickled test DataFrame

# --- Switches for generating the train/test data ---
TABLE_NAMES = True   # drop SQL using generic `FROM table_...` names (most of WikiSQL)
SIMPLE_INST = False  # False -> detailed text-to-SQL prompt; True -> short prompt
SAMPLE_RATE = 0.1    # fraction held out for test: 90% train / 10% test

### Common Functions

In [5]:
def process_data(dataset_name, sample_rate,
                 table_names=False, simple_inst=False, seed=42):
  '''Load a Text2SQL dataset, build prompt text and split it into train/test.

  Args:
    dataset_name: Hugging Face dataset id passed to `load_dataset`; only its
      'train' split is used.
    sample_rate: fraction of rows held out for the test split (passed as
      `test_size` to `train_test_split`).
    table_names: if True, drop rows whose SQL answer contains 'FROM table_'
      (examples without a definite table name, i.e. most of WikiSQL).
    simple_inst: if True use the short "Below is an instruction..." prompt,
      otherwise the detailed "powerful text-to-SQL model" prompt.
    seed: random seed for the train/test split (default 42).

  Returns:
    (dsdf, dataset): the cleaned DataFrame (the 'answer' column renamed to
    'response', plus a new 'text' column holding the full training prompt)
    and the dataset returned by `Dataset.from_pandas(...).train_test_split`,
    with 'train' and 'test' splits.
  '''
  # 1. Move the 'train' split into a flat DataFrame.
  txt2sql_ds = load_dataset(dataset_name)
  dsdf = txt2sql_ds['train'].to_pandas()

  # 2. Cleanup steps. Non-inplace so the filtered slice below never triggers
  # SettingWithCopyWarning.
  if table_names:
    # Drop all examples where no definite table name is given,
    # i.e. most of WikiSQL. regex=False: match the literal text.
    dsdf = dsdf.loc[~dsdf['answer'].str.contains('FROM table_', regex=False)]
  dsdf = dsdf.drop_duplicates().rename(columns={'answer': 'response'})

  # 3. Prompt template. NOTE: the indentation and `\` continuations inside the
  # literals are part of the prompt text; they must stay identical to the
  # templates in process_test so training and evaluation prompts share a prefix.
  if simple_inst:
    template = """Below is an instruction that describes a task. \
    Write a response that appropriately completes the request.

    ### Instruction:
    Generate SQL query: {question}, \
    given the following schema: {context}

    ### Response:
    {response}
    ### End"""
  else:
    # Detailed instruction; change it according to the task.
    template = """### Instruction:
    You are a powerful text-to-SQL model. \
    Your job is to answer questions about a database. \
    You are given a question and context regarding one or more tables.

    You must output the SQL query that answers the question.

    ### Input:
    {question}
    ### Context:
    {context}
    ### Response:
    {response}
    ### End"""
  # A list comprehension (rather than apply(axis=1)) also works when the
  # cleanup left the frame empty.
  dsdf = dsdf.assign(text=[template.format_map(row)
                           for row in dsdf.to_dict('records')])
  display(dsdf.head(2))

  # 4. Train/test split. The non-contiguous index left by the cleanup is kept
  # by from_pandas as an '__index_level_0__' column.
  dataset = Dataset.from_pandas(dsdf).train_test_split(test_size=sample_rate,
                                                       seed=seed)
  print('Training Sample:')
  display(pd.DataFrame(dataset["train"]).head(2))
  print('Testing Sample:')
  display(pd.DataFrame(dataset["test"]).head(2))
  return dsdf, dataset

In [6]:
def process_test(ds, col='test', table_names=False, simple_inst=False):
  '''Build an evaluation DataFrame whose prompts stop at "### Response:".

  Args:
    ds: mapping of split name to data (e.g. the dataset returned by
      process_data); `pd.DataFrame(ds[col])` must yield question, context and
      response columns, optionally with a 'text' column, which is replaced.
    col: split to take from `ds` (default 'test').
    table_names: if True, drop rows whose SQL response contains 'FROM table_'.
    simple_inst: choose the short prompt (True) or the detailed prompt
      (False). Must match the value used in process_data.

  Returns:
    DataFrame with the input columns (minus the old 'text') plus a new 'text'
    column holding the prompt WITHOUT the expected SQL response.
  '''
  dsdf = pd.DataFrame(ds[col])

  # Cleanup steps. Non-inplace so the filtered slice never triggers
  # SettingWithCopyWarning.
  if table_names:
    # Drop all examples where no definite table name is given,
    # i.e. most of WikiSQL. regex=False: match the literal text.
    dsdf = dsdf.loc[~dsdf['response'].str.contains('FROM table_', regex=False)]
  # Drop duplicates, then drop the training "text" (it embeds the answer);
  # a new "text" without the response is built below. errors='ignore' lets
  # this also run on data that has no "text" column.
  dsdf = dsdf.drop_duplicates().drop(columns=['text'], errors='ignore')

  # Prompt template. NOTE: the indentation and `\` continuations inside the
  # literals are part of the prompt text; they must stay identical to the
  # templates in process_data (up to "### Response:") so evaluation prompts
  # match the training prompts.
  if simple_inst:
    template = """Below is an instruction that describes a task. \
    Write a response that appropriately completes the request.

    ### Instruction:
    Generate SQL query: {question}, \
    given the following schema: {context}

    ### Response:
    """
  else:
    # Detailed instruction; change it according to the task.
    template = """### Instruction:
    You are a powerful text-to-SQL model. \
    Your job is to answer questions about a database. \
    You are given a question and context regarding one or more tables.

    You must output the SQL query that answers the question.

    ### Input:
    {question}
    ### Context:
    {context}
    ### Response:
    """
  # A list comprehension (rather than apply(axis=1)) also works when the
  # cleanup left the frame empty.
  dsdf = dsdf.assign(text=[template.format_map(row)
                           for row in dsdf.to_dict('records')])
  display(dsdf.head(2))
  return dsdf

### Load, Process and Store the Dataset
- Stored as a Hugging Face `DatasetDict` with `train` and `test` splits

In [7]:
# Build prompts and the train/test split. table_names & simple_inst must be
# passed the same way to the process_test call further down so that training
# and evaluation prompts match.
dsdf, dataset = process_data(
    dataset_name,
    sample_rate=SAMPLE_RATE,
    simple_inst=SIMPLE_INST,
    table_names=TABLE_NAMES,
)

Unnamed: 0,question,context,response,text
0,How many heads of the departments are older th...,CREATE TABLE head (age INTEGER),SELECT COUNT(*) FROM head WHERE age > 56,### Instruction:\n You are a powerful text-...
1,"List the name, born state and age of the heads...","CREATE TABLE head (name VARCHAR, born_state VA...","SELECT name, born_state, age FROM head ORDER B...",### Instruction:\n You are a powerful text-...


Training Sample:


Unnamed: 0,question,context,response,text,__index_level_0__
0,What are the nationalities that are shared by ...,CREATE TABLE people (Nationality VARCHAR),SELECT Nationality FROM people GROUP BY Nation...,### Instruction:\n You are a powerful text-...,4326
1,What is the checking balance of the account wh...,"CREATE TABLE checking (balance VARCHAR, custid...",SELECT T2.balance FROM accounts AS T1 JOIN che...,### Instruction:\n You are a powerful text-...,1034


Testing Sample:


Unnamed: 0,question,context,response,text,__index_level_0__
0,Show the name of track and the number of races...,"CREATE TABLE track (name VARCHAR, track_id VAR...","SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",### Instruction:\n You are a powerful text-...,429
1,Show names of shops and the carriers of device...,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...","SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",### Instruction:\n You are a powerful text-...,2907


In [8]:
# Persist the train/test dataset. os.path.join builds the same path as before
# but no longer depends on DATA_PATH ending with a slash.
dataset.save_to_disk(os.path.join(DATA_PATH, DS_DIR))

Saving the dataset (0/1 shards):   0%|          | 0/4086 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/454 [00:00<?, ? examples/s]

In [9]:
# Reload the saved dataset from disk so it can be compared with the in-memory one.
ds2 = load_from_disk(os.path.join(DATA_PATH, DS_DIR))

#### Check Stored Dataset

In [10]:
# Show both in-memory splits (features and row counts).
display(dataset['train'], dataset['test'])

Dataset({
    features: ['question', 'context', 'response', 'text', '__index_level_0__'],
    num_rows: 4086
})

Dataset({
    features: ['question', 'context', 'response', 'text', '__index_level_0__'],
    num_rows: 454
})

In [11]:
display(ds2['train'])
display(ds2['test'])
# Verify the disk round-trip instead of only eyeballing the reprs.
for split in ('train', 'test'):
  assert ds2[split].num_rows == dataset[split].num_rows, split
  assert ds2[split].column_names == dataset[split].column_names, split

Dataset({
    features: ['question', 'context', 'response', 'text', '__index_level_0__'],
    num_rows: 4086
})

Dataset({
    features: ['question', 'context', 'response', 'text', '__index_level_0__'],
    num_rows: 454
})

### Save the Test Split as a Pandas DataFrame (prompts without responses)

In [12]:
# Note: table_names & simple_inst need to match in process_test and process_data
test_df = process_test(dataset, col='test', table_names=TABLE_NAMES,
                       simple_inst=SIMPLE_INST)
# Show one full prompt. Use positional .iloc: process_test may drop rows,
# leaving gaps in the index, so the label 4 need not exist.
display(test_df['text'].iloc[4])

Unnamed: 0,question,context,response,__index_level_0__,text
0,Show the name of track and the number of races...,"CREATE TABLE track (name VARCHAR, track_id VAR...","SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",429,### Instruction:\n You are a powerful text-...
1,Show names of shops and the carriers of device...,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...","SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",2907,### Instruction:\n You are a powerful text-...


'### Instruction:\n    You are a powerful text-to-SQL model.     Your job is to answer questions about a database.     You are given a question and context regarding one or more tables.\n\n    You must output the SQL query that answers the question.\n\n    ### Input:\n    Find the locations that have more than one movie theater with capacity above 300.\n    ### Context:\n    CREATE TABLE cinema (LOCATION VARCHAR, capacity INTEGER)\n    ### Response:\n    '

In [13]:
# Pickle the test prompts. to_pickle does not create missing folders, so make
# sure DATA_PATH/PKL_DIR exists first (it does not on a fresh Drive).
os.makedirs(os.path.join(DATA_PATH, PKL_DIR), exist_ok=True)
test_df.to_pickle(os.path.join(DATA_PATH, PKL_DIR, PKL_FILE))

In [14]:
# Reload the pickled test DataFrame, building the path from the same constants
# as the save cell (DATA_PATH, PKL_DIR, PKL_FILE) so the two cannot drift apart.
# read_pickle can execute arbitrary code; only load files this notebook wrote.
test_df2 = pd.read_pickle(os.path.join(DATA_PATH, PKL_DIR, PKL_FILE))

#### Check Stored Test DataFrame

In [15]:
# First two rows of the in-memory test DataFrame.
display(test_df.iloc[:2])

Unnamed: 0,question,context,response,__index_level_0__,text
0,Show the name of track and the number of races...,"CREATE TABLE track (name VARCHAR, track_id VAR...","SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",429,### Instruction:\n You are a powerful text-...
1,Show names of shops and the carriers of device...,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...","SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",2907,### Instruction:\n You are a powerful text-...


In [16]:
display(test_df2.head(2))
# The pickle round-trip must reproduce the in-memory DataFrame exactly.
pd.testing.assert_frame_equal(test_df, test_df2)

Unnamed: 0,question,context,response,__index_level_0__,text
0,Show the name of track and the number of races...,"CREATE TABLE track (name VARCHAR, track_id VAR...","SELECT T2.name, COUNT(*) FROM race AS T1 JOIN ...",429,### Instruction:\n You are a powerful text-...
1,Show names of shops and the carriers of device...,"CREATE TABLE shop (Shop_Name VARCHAR, Shop_ID ...","SELECT T3.Shop_Name, T2.Carrier FROM stock AS ...",2907,### Instruction:\n You are a powerful text-...
