# Dataset Preperation

In [6]:
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor

import nltk
import pandas as pd
import sqlalchemy
from datasets import Dataset
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from sqlalchemy import MetaData, create_engine
from sqlalchemy.exc import CompileError, NoReferencedColumnError
from sqlalchemy.schema import CreateTable

nltk.download("wordnet")
nltk.download("stopwords")

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/schilver/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/schilver/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [9]:
DATASET_PATH = 'spider/database'
database_paths = [(db, os.path.join(DATASET_PATH, f'{db}/{db}.sqlite')) for db in os.listdir(DATASET_PATH)]
database_schema_paths = { db[0]: os.path.join(DATASET_PATH, f'{db[0]}/schema.sql') for db in database_paths }
database_con_strings = [(db[0], f'sqlite:///{db[1]}') for db in database_paths]
print("Found Databases: ", database_con_strings)


def crawl_database(con_string, db_id):
    engine = create_engine(con_string)

    # Reflect the existing database into a new MetaData instance
    metadata = MetaData()
    metadata.reflect(bind=engine)

    # Generate 'CREATE TABLE' statements for all tables
    create_table_statements = []
    try:
        for table in metadata.sorted_tables:
            create_statement = str(CreateTable(table).compile(engine)).strip()
            create_table_statements.append(create_statement)
    except (NoReferencedColumnError, CompileError)  as e:
        with open(database_schema_paths[db_id], 'r') as f:
            create_table_statements = filter(lambda x: len(x) > 0, map(lambda x: x.strip(), f.readlines()))
    except FileNotFoundError as e:
        print(f"Schema not found for {db_id}")
        return ''
    return '\n'.join(create_table_statements)

Found Databases:  [('browser_web', 'sqlite:///spider/database/browser_web/browser_web.sqlite'), ('musical', 'sqlite:///spider/database/musical/musical.sqlite'), ('farm', 'sqlite:///spider/database/farm/farm.sqlite'), ('voter_1', 'sqlite:///spider/database/voter_1/voter_1.sqlite'), ('game_injury', 'sqlite:///spider/database/game_injury/game_injury.sqlite'), ('hospital_1', 'sqlite:///spider/database/hospital_1/hospital_1.sqlite'), ('manufacturer', 'sqlite:///spider/database/manufacturer/manufacturer.sqlite'), ('station_weather', 'sqlite:///spider/database/station_weather/station_weather.sqlite'), ('perpetrator', 'sqlite:///spider/database/perpetrator/perpetrator.sqlite'), ('storm_record', 'sqlite:///spider/database/storm_record/storm_record.sqlite'), ('flight_1', 'sqlite:///spider/database/flight_1/flight_1.sqlite'), ('manufactory_1', 'sqlite:///spider/database/manufactory_1/manufactory_1.sqlite'), ('cre_Theme_park', 'sqlite:///spider/database/cre_Theme_park/cre_Theme_park.sqlite'), ('mu

In [10]:
database_schemas = {db[0]: crawl_database(db[1], db[0]) for db in database_con_strings}
with open('database_schemas.json', 'w') as f:
    json.dump(database_schemas, f)

  metadata.reflect(bind=engine)
  metadata.reflect(bind=engine)
  metadata.reflect(bind=engine)
  metadata.reflect(bind=engine)
  metadata.reflect(bind=engine)


In [11]:
databases = {}
type_map = []

def extract_tables_columns(tables):
    def extract_foriegn_keys(foriegn_keys):
        result = []
        for key in foriegn_keys:
            try:
                result.append((key.column.table.name, key.column.name))
            except Exception as e:
                pattern = r"'(.*?)'"
                matches = re.findall(pattern, str(key))
                if not matches:
                    matches = str(key).split('.')
                if len(matches) == 1:
                    result.append(matches[0].split('.'))
                elif len(matches) >= 2:
                    result.append((matches[0], matches[1]))
                else:
                    print(f"Error extracting foreign key: {e}")
        return result

# Finding all matches of the pattern in the input string
    result = {}
    for name, table in tables.items():
        for col in table.columns:
            col_type = str(col.type).split('(')[0]
            try:
                col_type_index = type_map.index(col_type)
            except ValueError:
                type_map.append(col_type)
                col_type_index = len(type_map) - 1
            if name not in result:
                result[name] = []
            result[name].append((col.name, col_type_index, int(col.nullable), int(col.primary_key), col.default, extract_foriegn_keys(col.foreign_keys)))
    return result


for name, db in database_con_strings:
    engine = sqlalchemy.create_engine(db)
    inspector = sqlalchemy.inspect(engine)
    metadata = sqlalchemy.MetaData()
    metadata.reflect(engine)
    try:
        tables = extract_tables_columns(metadata.tables)
    except Exception as e:
        print(f"Error extracting tables from {db}: {e}")
    databases[name] = tables

with open('databases.json', 'w') as f:
    json.dump(databases, f)

  metadata.reflect(engine)
  metadata.reflect(engine)
  metadata.reflect(engine)
  metadata.reflect(engine)
  metadata.reflect(engine)


In [15]:
def preprocess_text(text):
    # Tokenize the text
    tokens = word_tokenize(text)

    # Remove stopwords
    stop_words = set(stopwords.words('english'))
    tokens = [t for t in tokens if t not in stop_words]

    # Lemmatize the tokens
    lemmatizer = WordNetLemmatizer()
    tokens = [lemmatizer.lemmatize(t) for t in tokens]

    # Remove punctuation
    tokens = [t for t in tokens if t.isalpha()]

    # Join the tokens back into a string
    text = ' '.join(tokens)

    return text

def preprocess_dataset_entry(entry):
    """
    This function preprocesses a single dataset entry.
    It assumes `database_schemas` is a dictionary with DB schemas accessible by `db_id`.
    """
    processed_text = preprocess_text(entry['question'])
    processed_entry = {
        'input': f"translate to SQL: {processed_text} \n Schema: {json.dumps(databases.get(entry['db_id'], ''))}",
        'target': entry['query'],
    }
    return processed_entry

# Load your dataset
with open('spider/train_spider.json') as f:
    data = json.load(f)

processed_results = list(map(preprocess_dataset_entry, data))

# Convert the list of dictionaries to a pandas DataFrame
dataset_df = pd.DataFrame(processed_results)

In [16]:
dataset = Dataset.from_pandas(dataset_df)

In [17]:
max_input_len = max(map(lambda x: len(x['input'].split()), dataset))
max_target_len = max(map(lambda x: len(x['target'].split()), dataset))
print("Max Input Length: ", max_input_len)
print("Max Target Length: ", max_target_len)
print("Dataset Size: ", len(dataset))

dataset.save_to_disk('spider_dataset')

Max Input Length:  2181
Max Target Length:  87
Dataset Size:  7000


Saving the dataset (1/1 shards): 100%|██████████| 7000/7000 [00:00<00:00, 1233463.34 examples/s]


In [18]:
ds = dataset.train_test_split(test_size=0.1)
train_data = ds['train']
test_data = ds['test']

# Training the model

In [19]:
CKPT = 't5-small'
from transformers import AutoTokenizer, T5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained(CKPT)
model = T5ForConditionalGeneration.from_pretrained(CKPT)

In [20]:
# tokenize the examples
def convert_to_features(example_batch):
    input_encodings = tokenizer.batch_encode_plus(example_batch['input'], pad_to_max_length=True, max_length=2048)
    target_encodings = tokenizer.batch_encode_plus(example_batch['target'], pad_to_max_length=True, max_length=128)

    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': target_encodings['input_ids'],
        'decoder_attention_mask': target_encodings['attention_mask']
    }

    return encodings

In [21]:
train_data = train_data.map(convert_to_features, batched=True, remove_columns=train_data.column_names)
test_data = test_data.map(convert_to_features, batched=True, remove_columns=test_data.column_names)

columns = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask']

train_data.set_format(type='torch', columns=columns)
test_data.set_format(type='torch', columns=columns)

Map:   0%|          | 0/6300 [00:00<?, ? examples/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Map: 100%|██████████| 6300/6300 [00:03<00:00, 1871.17 examples/s]
Map: 100%|██████████| 700/700 [00:00<00:00, 1574.75 examples/s]


In [22]:
from transformers import Seq2SeqTrainer
from transformers import Seq2SeqTrainingArguments

In [23]:
# set training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="t5-small-finetuned-spider",
    per_device_train_batch_size=16,
    num_train_epochs=10,
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    evaluation_strategy="epoch",
    do_train=True,
    do_eval=True,
    logging_steps=500,
    save_strategy="epoch",
    #save_steps=1000,
    #eval_steps=1000,
    overwrite_output_dir=True,
    save_total_limit=3,
    load_best_model_at_end=True,
    #fp16=True,
)

In [24]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    # compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=test_data,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [25]:
trainer.train()

  0%|          | 0/3940 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 15.09 GB, other allocations: 2.06 GB, max allowed: 18.13 GB). Tried to allocate 2.00 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).