In [None]:
import sqlglot
from sqlglot.executor import execute
from sqlglot.errors import ExecuteError, TokenError, SchemaError, ExecuteError, ParseError, UnsupportedError, SqlglotError

import json

from datasets import load_dataset

from faker import Faker
import json

# Initialize a Faker instance and seed Faker's shared random generator so the
# filler rows produced further down are the same on every Restart & Run All.
fake = Faker()
Faker.seed(0)

In [None]:
# Load the 'b-mc2/sql-create-context' dataset. Later cells read each example's
# 'context' field (parsed with sqlglot) and 'answer' field (the SQL query).
dataset = load_dataset('b-mc2/sql-create-context')

In [None]:
# Display the loaded object to see what it contains before transforming it.
dataset

# Creating a new column

```python
def compute_review_length(example):
    return {"review_length": len(example["review"].split())}

drug_dataset = drug_dataset.map(compute_review_length)
```

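Alternatively, attach a precomputed list of values as a new column with `add_column(name, column)`:

```python
drug_dataset = drug_dataset.add_column("review_length", review_lengths)
```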

In [None]:
def compute_table_count(dataset):
    """Count the ``;``-separated statements in one example's ``context``.

    Args:
        dataset: A single example (despite the name); ``dataset['context']``
            must be a string of SQL statements separated by ``;``.

    Returns:
        ``{"table_count": n}`` where ``n`` is the number of non-blank
        fragments obtained by splitting the context on ``;``.
    """
    # Ignore blank fragments so a trailing ';' (or an empty context) is not
    # counted as an extra statement. This is still a plain text split: a ';'
    # inside a quoted SQL literal would be counted as a separator.
    statements = [part for part in dataset['context'].split(';') if part.strip()]
    return {"table_count": len(statements)}

In [None]:
def abstract_column_types(dataset):
    """Extract a ``{table: {column: type}}`` mapping from one example's ``context``.

    The context string is parsed with ``sqlglot.parse``. For every parsed
    statement, the first ``Identifier`` found is used as the table name, and
    for every entry of ``statement.this.expressions`` the first ``Identifier``
    and the first ``DataType`` found give the column name and type name.

    Args:
        dataset: A single example (despite the name) with a ``context`` string.

    Returns:
        ``{"column_types": <JSON text>}``. If two statements yield the same
        table name, the later one replaces the earlier entry.
    """
    identifier_cls = sqlglot.expressions.Identifier
    datatype_cls = sqlglot.expressions.DataType

    schema_by_table = {}
    for statement in sqlglot.parse(dataset['context']):
        name_of_table = statement.find(identifier_cls).this
        # Column name is looked up before its type, one column at a time.
        schema_by_table[name_of_table] = {
            column.find(identifier_cls).this: column.find(datatype_cls).this.value
            for column in statement.this.expressions
        }

    return {"column_types": json.dumps(schema_by_table)}

In [None]:
def identify_duplicate_create_table(dataset):
    """Flag examples whose statement count differs from their distinct-table count.

    Args:
        dataset: A single example (despite the name) carrying ``table_count``
            and ``column_types`` (JSON text keyed by table name).

    Returns:
        ``{"duplicate_create_table": False}`` when ``table_count`` equals the
        number of keys in the decoded ``column_types``; ``True`` otherwise.
    """
    statement_total = dataset['table_count']
    distinct_tables = len(json.loads(dataset['column_types']).keys())
    return {"duplicate_create_table": not (statement_total == distinct_tables)}

In [None]:
# Hand-written sample schema as JSON text: {table name: {column name: type name}}.
# It is decoded here and fed to generate_filler_data at the bottom of this cell
# as a quick manual check of the generator.
column_types_data = '{"station": {"name": "VARCHAR", "id": "VARCHAR", "installation_date": "VARCHAR"}, "status": {"station_id": "VARCHAR", "bikes_available": "INT"}}'
column_types = json.loads(column_types_data)

def generate_random_data(data_type):
    """Produce one fake value for a column of the given type name.

    Args:
        data_type: Type name to generate for; only ``"VARCHAR"`` and ``"INT"``
            are recognised.

    Returns:
        ``fake.name()`` for ``"VARCHAR"``, ``fake.random_int(min=1, max=100)``
        for ``"INT"``, and ``None`` for anything else.
    """
    if data_type == "VARCHAR":
        return fake.name()
    if data_type == "INT":
        return fake.random_int(min=1, max=100)
    # Unrecognised type names get no value.
    return None

def generate_filler_data(column_types, num_records=5):
    """Build ``num_records`` fake rows for every table in a schema mapping.

    Args:
        column_types: ``{table name: {column name: type name}}``.
        num_records: Number of rows to generate per table (default 5).

    Returns:
        ``{table name: [row, ...]}`` where each row maps column name to the
        value returned by ``generate_random_data`` for that column's type.
    """
    return {
        table: [
            {column: generate_random_data(kind) for column, kind in schema.items()}
            for _ in range(num_records)
        ]
        for table, schema in column_types.items()
    }

def populate_data(dataset):
    """Generate filler rows for one example from its ``column_types`` JSON.

    Args:
        dataset: A single example (despite the name) whose ``column_types``
            field is JSON text of the form ``{table: {column: type}}``.

    Returns:
        ``{"filler_data": <JSON text>}`` holding the rows produced by
        ``generate_filler_data`` with its default record count.
    """
    schema = json.loads(dataset['column_types'])
    rows_by_table = generate_filler_data(schema)
    return {"filler_data": json.dumps(rows_by_table)}

# Manual check: generate filler rows for the sample schema decoded at the top
# of this cell and pretty-print them as indented JSON.
filler_data = generate_filler_data(column_types)
print(json.dumps(filler_data, indent=4))

In [None]:
def validate_query(dataset):
    """Execute one example's ``answer`` query against its generated filler rows.

    Args:
        dataset: A single example (despite the name) providing
            ``filler_data`` (JSON text mapping table name to a list of row
            dicts) and ``answer`` (the SQL query string).

    Returns:
        A dict with two keys:

        * ``query_result`` - always a string: ``str(result.rows)`` when the
          query ran, the error class name for the sqlglot errors handled
          below, or ``repr(exc)`` for any other exception.
        * ``valid_query`` - ``True`` when the query ran or failed with a
          generic ``SqlglotError``; ``False`` otherwise.

    Raises:
        Whatever ``json.loads`` or the key lookups raise; those run outside
        the ``try`` block.
    """
    tables = json.loads(dataset['filler_data'])
    query = dataset['answer']

    try:
        result = execute(query, tables=tables)
        # str() always yields a string, so no None check is needed afterwards.
        return {"query_result": str(result.rows), "valid_query": True}
    except ExecuteError:
        return {"query_result": "ExecuteError", "valid_query": False}
    except TokenError:
        return {"query_result": "TokenError", "valid_query": False}
    except SchemaError:
        return {"query_result": "SchemaError", "valid_query": False}
    except ParseError:
        return {"query_result": "ParseError", "valid_query": False}
    except UnsupportedError:
        return {"query_result": "UnsupportedError", "valid_query": False}
    except SqlglotError:
        # Deliberately treated as valid: this one seems to be raised often
        # even for correct queries.
        # NOTE(review): keep this handler after the specific ones above; if
        # SqlglotError is their base class, moving it up would shadow them.
        return {"query_result": "SqlglotError", "valid_query": True}
    except Exception as e:
        # Return text rather than the exception object so 'query_result' has
        # the same type (str) in every branch.
        return {"query_result": repr(e), "valid_query": False}

In [None]:
# Add the 'table_count' field computed from each example's 'context'.
dataset = dataset.map(compute_table_count)

In [None]:
# Add the 'column_types' field (JSON schema parsed from 'context').
dataset = dataset.map(abstract_column_types)

In [None]:
# Add 'duplicate_create_table'. Reads 'table_count' and 'column_types', so the
# two map calls that produce those fields must have run first.
dataset = dataset.map(identify_duplicate_create_table)

In [None]:
# Add 'filler_data' (fake rows as JSON). Reads 'column_types'.
dataset = dataset.map(populate_data)

In [None]:
# Add 'query_result' and 'valid_query'. Reads 'filler_data' and 'answer'.
dataset = dataset.map(validate_query)

In [None]:
# Spot-check one processed training example.
# NOTE(review): index 12 looks like an arbitrary pick - confirm.
dataset['train'][12]

In [None]:
# Keep only the training examples whose 'valid_query' flag is False; the
# resulting subset is displayed in the next cell.
invalid_queries = dataset['train'].filter(lambda x: x['valid_query'] == False)

In [None]:
# Display the filtered subset of invalid examples.
invalid_queries

In [None]:
# Spot-check one invalid example.
# NOTE(review): index 6 looks like an arbitrary pick - confirm.
invalid_queries[6]