In [1]:
import sqlglot
from sqlglot.executor import execute
from sqlglot.errors import ExecuteError, TokenError, SchemaError, ExecuteError, ParseError, UnsupportedError, SqlglotError, OptimizeError

import json

from datasets import load_dataset

from faker import Faker
import json

# Shared Faker instance used by generate_random_data() to synthesise filler rows.
# Faker.seed() seeds the RNG shared by Faker instances, so a fresh
# "Restart & Run All" regenerates the same filler tables (and therefore the
# same query_result / valid_query columns).
FAKER_SEED = 42
Faker.seed(FAKER_SEED)
fake = Faker()

In [48]:
from data import SQLData

In [49]:
# Hugging Face Hub id of the text-to-SQL dataset used throughout this notebook.
sn = "b-mc2/sql-create-context"

In [53]:
# NOTE(review): SQLData comes from the local `data` module, which is not shown
# here; it appears to duplicate the load_dataset() call further down — confirm
# which loading path is intended and keep only one.
dataz = SQLData()

In [54]:
# Load the dataset named by `sn` through the SQLData wrapper (its behaviour is
# defined in data.py, not visible here). Nothing in the cells below reads `dataz`.
dataz.load_data(dataset_name=sn)

In [2]:
# Load the Hub dataset named in `sn` (reuse the constant instead of repeating
# the literal, so the two cannot drift apart). It has a single "train" split.
dataset = load_dataset(sn)

In [109]:
# Show the loaded DatasetDict: available splits, feature names and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['answer', 'context', 'question'],
        num_rows: 78577
    })
})

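# Creating a new column

```python
def compute_review_length(example):
    return {"review_length": len(example["review"].split())}

drug_dataset = drug_dataset.map(compute_review_length)
```

Alternatively, compute the values up front and attach them with `add_column`, which takes the new column's name and its values:

```python
review_lengths = [len(review.split()) for review in drug_dataset["review"]]
drug_dataset = drug_dataset.add_column("review_length", review_lengths)
```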

In [3]:
def compute_table_count(dataset):
    """Count the ``;``-separated statements in an example's ``context`` field.

    Blank segments (a trailing ``;``, ``;;``, or an empty context) are ignored,
    so ``"CREATE TABLE a (x INT);"`` counts as 1 rather than 2. Counting them
    would make identify_duplicate_create_table report a false duplicate.

    Note: this is a plain text split, so a ``;`` inside a quoted literal in the
    DDL would still be treated as a separator.

    Args:
        dataset: a single example (row) with a ``context`` string.

    Returns:
        ``{"table_count": int}``
    """
    statements = [segment for segment in dataset['context'].split(';') if segment.strip()]
    return {"table_count": len(statements)}

In [4]:
def abstract_column_types(dataset):
    """Parse the CREATE TABLE statements in ``context`` into a column-type map.

    Args:
        dataset: a single example (row) whose ``context`` holds
            ``;``-separated CREATE TABLE statements.

    Returns:
        ``{"column_types": json_str}``. ``json_str`` encodes
        ``{table_name: {column_name: type_name}}``, where ``type_name`` is the
        sqlglot ``DataType.Type`` value (e.g. ``"VARCHAR"``, ``"INT"``), or
        ``null`` for a column declared without a type. If a table is declared
        more than once, the last declaration wins.
    """
    tables = {}

    for statement in sqlglot.parse(dataset['context']):
        # sqlglot.parse is typed to return Optional entries; skip empty ones.
        if statement is None:
            continue
        table = statement.find(sqlglot.expressions.Table)
        if table is None:
            continue

        column_types = {}
        # Visit only real column definitions. Table-level constraints such as
        # PRIMARY KEY (...) carry no DataType and used to raise AttributeError.
        for column in statement.find_all(sqlglot.expressions.ColumnDef):
            kind = column.args.get('kind')
            column_types[column.name] = kind.this.value if kind is not None else None
        tables[table.name] = column_types

    return {"column_types": json.dumps(tables)}

In [5]:
def identify_duplicate_create_table(dataset):
    """Flag rows whose context declares fewer distinct tables than statements.

    ``table_count`` (from compute_table_count) counts statements, while
    ``column_types`` (from abstract_column_types) has one entry per distinct
    table name. A mismatch therefore usually means a table was declared more
    than once.

    Returns:
        ``{"duplicate_create_table": bool}``
    """
    statement_count = dataset['table_count']
    distinct_tables = json.loads(dataset['column_types'])
    return {"duplicate_create_table": statement_count != len(distinct_tables.keys())}

In [6]:
# Example schema in the same JSON shape that abstract_column_types() writes to
# the "column_types" column; used below to smoke-test the filler generator.
column_types_data = (
    '{"station": {"name": "VARCHAR", "id": "VARCHAR", "installation_date": "VARCHAR"}, '
    '"status": {"station_id": "VARCHAR", "bikes_available": "INT"}}'
)
column_types = json.loads(column_types_data)

# sqlglot DataType.Type value names grouped by the kind of fake value to emit.
_TEXT_TYPES = frozenset({"VARCHAR", "NVARCHAR", "CHAR", "NCHAR", "TEXT"})
_INTEGER_TYPES = frozenset({"INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT"})
_FRACTIONAL_TYPES = frozenset({"FLOAT", "DOUBLE", "DECIMAL"})


def generate_random_data(data_type):
    """Return one fake value for a column whose SQL type name is ``data_type``.

    ``data_type`` is a type name as stored in the ``column_types`` JSON
    (e.g. ``"VARCHAR"``, ``"INT"``), or ``None`` for an untyped column.

    * text types       -> a random person name (``fake.name()``)
    * integer types    -> a random int in [1, 100]
    * fractional types -> a random 2-decimal float in [1.0, 100.0]
    * anything else    -> ``None`` (a NULL in the filler table)

    "VARCHAR" and "INT" produce exactly the values they did before; the other
    type names used to fall through to ``None``.
    """
    if data_type is None:
        return None
    normalized = data_type.upper()
    if normalized in _TEXT_TYPES:
        return fake.name()
    if normalized in _INTEGER_TYPES:
        return fake.random_int(min=1, max=100)
    if normalized in _FRACTIONAL_TYPES:
        return fake.random_int(min=100, max=10000) / 100
    return None

def generate_filler_data(column_types, num_records=5):
    """Build ``num_records`` fake rows for every table in a schema.

    Args:
        column_types: ``{table_name: {column_name: type_name}}``.
        num_records: number of rows to create per table (default 5).

    Returns:
        ``{table_name: [row_dict, ...]}``, where each ``row_dict`` maps every
        column to a value from generate_random_data().
    """
    return {
        table_name: [
            {
                column_name: generate_random_data(data_type)
                for column_name, data_type in columns.items()
            }
            for _ in range(num_records)
        ]
        for table_name, columns in column_types.items()
    }

def populate_data(dataset):
    """Attach synthetic rows for every table in an example's schema.

    Decodes the example's ``column_types`` JSON, builds filler rows with
    generate_filler_data() (default record count), and returns them
    re-encoded as JSON.

    Returns:
        ``{"filler_data": json_str}``
    """
    schema = json.loads(dataset['column_types'])
    rows_by_table = generate_filler_data(schema)
    return {"filler_data": json.dumps(rows_by_table)}

# Smoke-test: build filler rows for the example schema and pretty-print them.
filler_data = generate_filler_data(column_types=column_types)
print(json.dumps(filler_data, indent=4))

{
    "station": [
        {
            "name": "Timothy Butler",
            "id": "Shaun Douglas",
            "installation_date": "Linda Owens"
        },
        {
            "name": "Amy Wood",
            "id": "Jennifer Coleman",
            "installation_date": "James Baker"
        },
        {
            "name": "Shaun Moreno",
            "id": "Mark Moses",
            "installation_date": "Pamela Beltran"
        },
        {
            "name": "Shari Tucker",
            "id": "Lauren Mack",
            "installation_date": "Christina Tucker"
        },
        {
            "name": "James Goodman",
            "id": "Benjamin Williams",
            "installation_date": "David Craig"
        }
    ],
    "status": [
        {
            "station_id": "Edward Hamilton",
            "bikes_available": 81
        },
        {
            "station_id": "Jasmine Griffin",
            "bikes_available": 99
        },
        {
            "station_id": "William Rodriguez"

In [7]:
def blanket_answer_syntax(dataset):
    """Rewrite double-quoted string literals in ``answer`` as single-quoted ones.

    The dataset's answers quote string values as ``"..."``, which sqlglot reads
    as identifiers (e.g. ``Column '"western kentucky"' could not be resolved``).

    The conversion works token by token:
    * each ``"..."`` token becomes a standard SQL ``'...'`` literal;
    * an embedded ``'`` is escaped as ``''`` and a doubled ``""`` is collapsed
      to ``"``;
    * existing single-quoted literals are copied through untouched, so a ``"``
      inside them survives;
    * unterminated quotes are left as they are.

    A blind character replace would turn ``"O'Brien"`` into the broken
    ``'O'Brien'``.

    Returns:
        ``{"answer": rewritten_sql}``
    """
    import re

    def _to_single_quoted(match):
        token = match.group(0)
        if token.startswith("'"):
            return token  # already a single-quoted literal
        body = token[1:-1].replace('""', '"').replace("'", "''")
        return "'" + body + "'"

    # Alternation order: single-quoted literal first, then double-quoted token.
    answer = re.sub(r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"", _to_single_quoted, dataset['answer'])
    return {"answer": answer}
    

In [15]:
# (exception class, label) pairs checked in order. The labels are the exact
# query_result strings that the downstream filters compare against.
_SQLGLOT_ERROR_LABELS = (
    (ExecuteError, "ExecuteError"),
    (OptimizeError, "OptimizeError"),
    (TokenError, "TokenError"),
    (SchemaError, "SchemaError"),
    (ParseError, "ParseError"),
    (UnsupportedError, "UnsupportedError"),
)


def validate_query(dataset):
    """Execute an example's ``answer`` SQL against its ``filler_data`` tables.

    Returns:
        ``{"query_result": str, "valid_query": bool}``.

        * On success, ``query_result`` is ``str`` of the executor's result rows
          (``''`` if it returned no row list).
        * On a sqlglot error, it is the error-class label: "ExecuteError",
          "OptimizeError", "TokenError", "SchemaError", "ParseError",
          "UnsupportedError", or the generic "SqlglotError".
        * On any other exception, it is the exception message.
    """
    tables = json.loads(dataset['filler_data'])
    query = dataset['answer']

    try:
        rows = execute(query, tables=tables).rows
        # The old `result == None` check ran after str() and could never fire.
        return {"query_result": '' if rows is None else str(rows), "valid_query": True}
    except SqlglotError as err:
        label = next(
            (name for error_cls, name in _SQLGLOT_ERROR_LABELS if isinstance(err, error_cls)),
            "SqlglotError",
        )
        return {"query_result": label, "valid_query": False}
    except Exception as err:  # deliberate best-effort so Dataset.map keeps going
        return {"query_result": str(err), "valid_query": False}

In [9]:
# Stage 1: count the ;-separated CREATE statements per row -> "table_count".
dataset = dataset.map(function=compute_table_count)

In [10]:
# Stage 2: parse per-table column types with sqlglot -> "column_types" (JSON).
dataset = dataset.map(function=abstract_column_types)

In [11]:
# Stage 3: compare statement count with distinct tables -> "duplicate_create_table".
dataset = dataset.map(function=identify_duplicate_create_table)

In [12]:
# Stage 4: synthesise Faker rows for every table -> "filler_data" (JSON).
dataset = dataset.map(function=populate_data)

Map:   0%|          | 0/78577 [00:00<?, ? examples/s]

In [13]:
# Stage 5: normalise double-quoted literals in "answer" to single quotes.
dataset = dataset.map(function=blanket_answer_syntax)

Map:   0%|          | 0/78577 [00:00<?, ? examples/s]

In [16]:
# Stage 6: run each answer against its filler tables -> "query_result", "valid_query".
dataset = dataset.map(function=validate_query)

Map:   0%|          | 0/78577 [00:00<?, ? examples/s]



In [17]:
# get a count of invalid queries
# Keep only the rows whose query failed against its filler data, then show how
# many there are (the old comment promised a count but never displayed one).
invalid_queries = dataset['train'].filter(lambda row: not row['valid_query'])
invalid_queries.num_rows

Filter:   0%|          | 0/78577 [00:00<?, ? examples/s]

In [93]:
# Rows whose query failed in sqlglot's optimizer (e.g. a column that could not be resolved).
check_error = dataset['train'].filter(lambda row: row['query_result'] == "OptimizeError")

Filter:   0%|          | 0/78577 [00:00<?, ? examples/s]

In [32]:
# Row 80 declares `station` twice in its context. abstract_column_types keeps
# the last declaration (note installation_date in column_types), so
# duplicate_create_table is True.
dataset['train'][80]

{'answer': "SELECT T1.name, T1.id FROM station AS T1 JOIN status AS T2 ON T1.id = T2.station_id GROUP BY T2.station_id HAVING AVG(T2.bikes_available) > 14 UNION SELECT name, id FROM station WHERE installation_date LIKE '12/%'",
 'context': 'CREATE TABLE station (name VARCHAR, id VARCHAR); CREATE TABLE station (name VARCHAR, id VARCHAR, installation_date VARCHAR); CREATE TABLE status (station_id VARCHAR, bikes_available INTEGER)',
 'question': 'What are the names and ids of stations that had more than 14 bikes available on average or were installed in December?',
 'table_count': 3,
 'column_types': '{"station": {"name": "VARCHAR", "id": "VARCHAR", "installation_date": "VARCHAR"}, "status": {"station_id": "VARCHAR", "bikes_available": "INT"}}',
 'duplicate_create_table': True,
 'filler_data': '{"station": [{"name": "Emily Nichols", "id": "Colleen Roth", "installation_date": "Andres Lopez"}, {"name": "Patrick Duncan", "id": "Spencer Miller", "installation_date": "Kevin Carroll"}, {"name":

In [30]:
# Preview the quote rewrite on row 80's answer. It contains no double quotes,
# so the result matches the stored answer.
dataset['train'][80]['answer'].replace('"', "'")

"SELECT T1.name, T1.id FROM station AS T1 JOIN status AS T2 ON T1.id = T2.station_id GROUP BY T2.station_id HAVING AVG(T2.bikes_available) > 14 UNION SELECT name, id FROM station WHERE installation_date LIKE '12/%'"

In [23]:
# Example ExecuteError: max_dew_point_f is declared VARCHAR, so its filler
# values are person names, yet the query compares it with 70. The type
# mismatch (not the SQL) is presumably what fails here.
invalid_queries[5]

{'answer': 'SELECT DISTINCT zip_code FROM weather EXCEPT SELECT DISTINCT zip_code FROM weather WHERE max_dew_point_f >= 70',
 'context': 'CREATE TABLE weather (zip_code VARCHAR, max_dew_point_f VARCHAR)',
 'question': 'Find all the zip codes in which the max dew point have never reached 70.',
 'table_count': 1,
 'column_types': '{"weather": {"zip_code": "VARCHAR", "max_dew_point_f": "VARCHAR"}}',
 'duplicate_create_table': False,
 'filler_data': '{"weather": [{"zip_code": "Juan Gentry", "max_dew_point_f": "Karen Rodriguez"}, {"zip_code": "Kevin Ritter", "max_dew_point_f": "Douglas Garza"}, {"zip_code": "Angela Smith", "max_dew_point_f": "Robert Combs"}, {"zip_code": "Julia Velazquez", "max_dew_point_f": "Carrie Knapp DDS"}, {"zip_code": "Rachel Simmons", "max_dew_point_f": "Gail Hayes"}]}',
 'query_result': 'ExecuteError',
 'valid_query': False}

In [103]:
# NOTE(review): check_error can be empty (see the IndexError on check_error[4105]
# further down), in which case this line raises. `tt` is also a non-descriptive name.
tt = check_error[5]['answer']

In [107]:
# Preview the double-quote -> single-quote rewrite on the selected answer.
tt.replace('"', "'")

"SELECT T2.student_id FROM courses AS T1 JOIN student_course_registrations AS T2 ON T1.course_id = T2.course_id WHERE T1.course_name = 'statistics' ORDER BY T2.registration_date"

In [74]:
# NOTE(review): check_error holds only OptimizeError rows, which are a subset of
# the invalid rows (2799 in the recorded invalid_queries), so index 41005 looks
# out of range. This output was probably produced against an earlier kernel
# state. Verify which row was meant before relying on `table`.
table = json.loads(check_error[41005]['filler_data'])
table

{'table_name_51': [{'pick': 'Emily Potter',
   'college_high_school_club': 'Dr. Megan Coleman MD',
   'round': 'Rodney Bridges',
   'draft': 'David Lucas'},
  {'pick': 'Beverly Mendoza',
   'college_high_school_club': 'Tiffany Chen',
   'round': 'Jonathan Baker',
   'draft': 'Deborah Rodriguez'},
  {'pick': 'Stephanie Burgess',
   'college_high_school_club': 'Sandra Morgan',
   'round': 'Mr. Donald Gray',
   'draft': 'Michelle Mcdonald'},
  {'pick': 'Bonnie Frank',
   'college_high_school_club': 'Angela Edwards',
   'round': 'Matthew Joseph',
   'draft': 'Rebecca Obrien'},
  {'pick': 'Melissa Schmidt',
   'college_high_school_club': 'Robert Barnes',
   'round': 'Jennifer Pitts',
   'draft': 'Mary Schneider'}]}

In [91]:
# NOTE(review): 4105 does not match the 41005 used to build `table` above, so
# `query` and `table` would come from different rows. The IndexError below also
# shows check_error was empty when this ran. Likely a typo; confirm the
# intended row.
query = check_error[4105]['answer']
query

IndexError: Invalid key: 4105 is out of bounds for size 0

In [77]:
# Run the failing query by hand against that row's filler tables. The
# OptimizeError below shows sqlglot treating the double-quoted
# "western kentucky" as a column name, which is the case blanket_answer_syntax targets.
execute(query, tables=table)

OptimizeError: Column '"western kentucky"' could not be resolved

In [17]:
# In the recorded run, 2799 of the 78577 rows have a query that fails against
# its filler data.
invalid_queries

Dataset({
    features: ['answer', 'context', 'question', 'table_count', 'column_types', 'duplicate_create_table', 'filler_data', 'query_result', 'valid_query'],
    num_rows: 2799
})

In [35]:
# Another likely type-driven failure: age is declared VARCHAR, so the filler
# ages are person names and `age > 40` presumably cannot be evaluated.
invalid_queries[100]

{'answer': "SELECT party FROM driver WHERE home_city = 'Hartford' AND age > 40",
 'context': 'CREATE TABLE driver (party VARCHAR, home_city VARCHAR, age VARCHAR)',
 'question': 'Show the party with drivers from Hartford and drivers older than 40.',
 'table_count': 1,
 'column_types': '{"driver": {"party": "VARCHAR", "home_city": "VARCHAR", "age": "VARCHAR"}}',
 'duplicate_create_table': False,
 'filler_data': '{"driver": [{"party": "Joseph Strickland", "home_city": "John Hamilton", "age": "Steven Rhodes"}, {"party": "Lisa Sanchez", "home_city": "Daniel Paul", "age": "Anthony Walsh"}, {"party": "Edwin Davis", "home_city": "Joshua Franco", "age": "Scott Jenkins"}, {"party": "Raymond Boyer", "home_city": "Zachary Cain", "age": "Jacqueline Young"}, {"party": "Cody Ramos", "home_city": "Gabriela Henson", "age": "Phyllis Anderson"}]}',
 'query_result': 'ExecuteError',
 'valid_query': False}

In [25]:
# Drop the ExecuteError rows to see which other failure modes remain.
non_execute_error = invalid_queries.filter(lambda row: row['query_result'] != 'ExecuteError')

Filter:   0%|          | 0/2799 [00:00<?, ? examples/s]

In [32]:
# Is anything left that is neither an ExecuteError nor a ParseError?
non_execute_error.filter(lambda row: row['query_result'] != 'ParseError')

Filter:   0%|          | 0/201 [00:00<?, ? examples/s]

Dataset({
    features: ['answer', 'context', 'question', 'table_count', 'column_types', 'duplicate_create_table', 'filler_data', 'query_result', 'valid_query'],
    num_rows: 0
})

In [41]:
# Example ParseError: the stored answer itself is malformed
# (`MIN(2002 AS _population)`), so this is a dataset defect rather than a
# filler-data problem.
non_execute_error[5]

{'answer': 'SELECT MIN(2002 AS _population) FROM table_13764346_1',
 'context': 'CREATE TABLE table_13764346_1 (Id VARCHAR)',
 'question': 'What is the smallest population recorded back in 2002?',
 'table_count': 1,
 'column_types': '{"table_13764346_1": {"Id": "VARCHAR"}}',
 'duplicate_create_table': False,
 'filler_data': '{"table_13764346_1": [{"Id": "Kathleen Burns"}, {"Id": "Jared West"}, {"Id": "Jacqueline Smith"}, {"Id": "Curtis Thomas"}, {"Id": "Charles Mcbride"}]}',
 'query_result': 'ParseError',
 'valid_query': False}