In [None]:
!pip install -U datasets bitsandbytes accelerate peft trl

Collecting datasets
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/542.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m225.3/542.0 kB[0m [31m6.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bitsandbytes
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.30.1-py3-none-any.whl (302 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.6/302.6 kB[0m [31m37.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting peft
  Downloading peft-0.11.1-py3-none-any.whl (251 kB)
[2K     [90m━━━━━━━━━━━━━━━━━

In [None]:
import os
import torch
from datasets import Dataset, load_dataset
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

def set_deterministic(seed):
    """Seed the RNGs and switch PyTorch/cuDNN to deterministic behaviour.

    Args:
        seed: Integer seed forwarded to ``transformers.set_seed``.
    """
    # Assign through os.environ rather than os.putenv: os.environ also calls
    # putenv() under the hood, but additionally keeps the Python-level mapping
    # in sync, so the setting is visible to later os.environ lookups.
    # It is set first so it is in place before determinism is enabled below.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # Disable cuDNN autotuning and request deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # warn_only=True: ops without a deterministic implementation only warn
    # instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

    set_seed(seed, deterministic=True)

In [None]:
# Fetch the text-to-SQL dataset and display its splits and columns.
dataset = load_dataset("djagatiya/synthetic_text_to_sql_d14")
dataset

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.48M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/95.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6713 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/408 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'domain', 'sql_complexity', 'sql_prompt', 'sql_context', 'sql'],
        num_rows: 6713
    })
    test: Dataset({
        features: ['id', 'domain', 'sql_complexity', 'sql_prompt', 'sql_context', 'sql'],
        num_rows: 408
    })
})

In [None]:
# Look at one raw training record to see which fields feed the prompt.
dataset['train'][0]

{'id': 5131,
 'domain': 'healthcare',
 'sql_complexity': 'basic SQL',
 'sql_prompt': 'What is the minimum cultural competency score by worker?',
 'sql_context': 'CREATE TABLE worker_scores (worker_id INT, score INT); INSERT INTO worker_scores (worker_id, score) VALUES (1, 95), (2, 88), (3, 72);',
 'sql': 'SELECT worker_id, MIN(score) FROM worker_scores;'}

### Prepare Dataset

In [None]:
# Training template: instruction, schema context, user question and the
# reference SQL, followed by the literal <|endoftext|> marker.
SQL_PROMPT = """
Instruct: Write SQL query of question asked by user based on following database structure context.
Context: {context}
Question: {question}
Output: {sql}
<|endoftext|>
"""


def prepare_dataset(sample):
    """Attach the rendered training prompt to a record under the ``text`` key.

    Args:
        sample: Mapping with ``sql_context``, ``sql_prompt`` and ``sql`` entries.

    Returns:
        The same ``sample`` object, now carrying ``text``: the filled-in
        template with surrounding whitespace stripped.
    """
    context = sample['sql_context']
    question = sample['sql_prompt']
    sql = sample['sql']

    rendered = SQL_PROMPT.format(context=context, question=question, sql=sql)
    sample['text'] = rendered.strip()
    return sample

In [None]:
# Add the rendered `text` column to every record of every split.
# Note: this rebinds `dataset`, so re-running the cell maps the already-mapped data again.
dataset = dataset.map(prepare_dataset)

Map:   0%|          | 0/6713 [00:00<?, ? examples/s]

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

In [None]:
# Show the rendered prompt for the first training record.
print(dataset['train'][0]['text'])

Instruct: Write SQL query of question asked by user based on following database structure context.
Context: CREATE TABLE worker_scores (worker_id INT, score INT); INSERT INTO worker_scores (worker_id, score) VALUES (1, 95), (2, 88), (3, 72);
Question: What is the minimum cultural competency score by worker?
Output: SELECT worker_id, MIN(score) FROM worker_scores;
<|endoftext|>


### Model Building

In [None]:
# Identifier of the base checkpoint; used below to load both the tokenizer and the model.
model_name='microsoft/phi-2'

In [None]:
# Load the base model's tokenizer. Padding is applied on the left, so the real
# tokens of every sequence in a batch end at the same (last) position.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left",
    use_fast=False,
    trust_remote_code=True
)
# Register a dedicated padding token instead of reusing <|endoftext|>.
# This grows the vocabulary, so the model's embeddings must be resized to match.
tokenizer.add_special_tokens({'pad_token': '<|padding|>'})
print(tokenizer.all_special_tokens)

tokenizer_config.json:   0%|          | 0.00/7.34k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


['<|endoftext|>', '<|padding|>']


In [None]:
# Sanity check: pad a short string to 10 tokens to confirm that padding is
# added on the left and that the literal <|endoftext|> becomes a single id.
encoded = tokenizer("Hello world<|endoftext|>", add_special_tokens=True, max_length=10, padding='max_length')
encoded

{'input_ids': [50295, 50295, 50295, 50295, 50295, 50295, 50295, 15496, 995, 50256], 'attention_mask': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]}

In [None]:
# Map the ids back to token strings (special tokens kept) to read the encoding above.
tokenizer.convert_ids_to_tokens(encoded['input_ids'], skip_special_tokens=False)

['<|padding|>',
 '<|padding|>',
 '<|padding|>',
 '<|padding|>',
 '<|padding|>',
 '<|padding|>',
 '<|padding|>',
 'Hello',
 'Ġworld',
 '<|endoftext|>']

In [None]:
# Token length of the longest rendered training example; used to pick a
# sequence-length limit for training.
train_max_len = max(
    len(tokenizer.encode(text))
    for text in dataset['train']['text']
)
train_max_len

534

In [None]:
# Release any model left over from a previous run of this cell before loading
# a new one, then ask Python and the CUDA allocator to give memory back.
import gc
if 'model' in globals():
    print("Existing model deleted.")
    del model
gc.collect()
torch.cuda.empty_cache()

# Seed everything and enable deterministic kernels before the model is built.
set_deterministic(1)

# 4-bit NF4 weight quantization with float16 compute; double quantization is off.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype='float16',
    bnb_4bit_use_double_quant=False,
)

# Load the quantized base model onto device 0.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=0,
    quantization_config=bnb_config,
)

# The tokenizer gained a padding token, so the embedding table must grow to len(tokenizer).
model.resize_token_embeddings(len(tokenizer))

from peft import LoraConfig, get_peft_model

# LoRA adapters on the query/key projections and both MLP layers.
# lm_head and embed_tokens are listed in modules_to_save, i.e. they are kept
# as full trainable copies and stored with the adapter (the recorded output
# reports ~331M trainable parameters, 10.66% of the total).
peft_config = LoraConfig(
    lora_alpha=32,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "mlp.fc1", "mlp.fc2"],
    modules_to_save=["lm_head", "embed_tokens"]
)

# Rebind `model` to the PEFT-wrapped version used for training.
model = get_peft_model(model, peft_config)

model.print_trainable_parameters()

Existing model deleted.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 330,966,136 || all params: 3,106,020,592 || trainable%: 10.6556


In [None]:
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import DataCollatorForLanguageModeling, TrainingArguments

In [None]:
# Token ids of the marker that separates the prompt from the answer.
# The template is encoded together with the preceding newline; presumably this
# is so the ids match how "Output:" is tokenized inside the full training text
# (TODO confirm against the TRL completion-only collator notes).
response_template_with_context = "\nOutput:"
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)
response_template_ids

[198, 26410, 25]

In [None]:
# Collator keyed on the response-template ids computed above.
# NOTE(review): per TRL's documentation this collator restricts the loss to the
# tokens after the response template — confirm for the installed TRL version.
collator = DataCollatorForCompletionOnlyLM(
    response_template_ids,
    tokenizer=tokenizer,
    mlm=False
)

In [None]:
# Short aliases for the two splits; the test split doubles as the evaluation set during training.
train_ds, test_ds = dataset['train'], dataset['test']

In [None]:
# Training configuration.
# Batch size 1 with 4 accumulation steps gives an effective batch of 4
# examples per optimizer step (6713 examples * 2 epochs / 4 ≈ 3356 steps,
# matching the recorded global_step).
# Logging, evaluation and checkpointing all happen every 300 steps, and
# checkpoints go under "model2" (save_only_model=True).
args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=1,
    num_train_epochs=2,
    # learning_rate=1e-4,
    optim="adamw_8bit",
    output_dir="model2",
    # save_embedding_layers=True,
    logging_strategy='steps',
    save_strategy='steps',
    evaluation_strategy='steps',
    eval_steps=300,
    logging_steps=300,
    save_steps=300,
    save_only_model=True
)

# Supervised fine-tuning on the rendered `text` column.
# max_seq_length=600 leaves headroom over the longest training example
# (534 tokens, measured earlier in the notebook).
trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    data_collator=collator,
    args=args,
    dataset_text_field='text',
    train_dataset=train_ds,
    eval_dataset=test_ds,
    max_seq_length=600
)

trainer.train()



Step,Training Loss,Validation Loss
300,0.1973,0.170346
600,0.1899,0.162319
900,0.1778,0.159404
1200,0.176,0.154415




Step,Training Loss,Validation Loss
300,0.1973,0.170346
600,0.1899,0.162319
900,0.1778,0.159404
1200,0.176,0.154415
1500,0.1601,0.149537
1800,0.1475,0.145144
2100,0.1384,0.139313
2400,0.1339,0.135821
2700,0.1314,0.136104
3000,0.1258,0.13544


TrainOutput(global_step=3356, training_loss=0.15461336415487478, metrics={'train_runtime': 7930.59, 'train_samples_per_second': 1.693, 'train_steps_per_second': 0.423, 'total_flos': 4.043936056164211e+16, 'train_loss': 0.15461336415487478, 'epoch': 1.9997020706092656})

### Evaluation

In [None]:
# Archive the step-3300 checkpoint. -j drops directory paths, so the files sit at the top level of the zip.
!zip -rj checkpoint-3300.zip /content/model2/checkpoint-3300

  adding: trainer_state.json (deflated 75%)
  adding: tokenizer_config.json (deflated 93%)
  adding: README.md (deflated 66%)
  adding: special_tokens_map.json (deflated 79%)
  adding: added_tokens.json (deflated 82%)
  adding: merges.txt (deflated 53%)
  adding: adapter_config.json (deflated 52%)
  adding: vocab.json (deflated 68%)
  adding: training_args.bin (deflated 51%)
  adding: adapter_model.safetensors (deflated 8%)


In [None]:
# Copy the archive to Google Drive.
# NOTE(review): assumes Drive is already mounted at /content/drive — no mount cell is visible here.
!cp /content/checkpoint-3300.zip /content/drive/MyDrive

In [None]:
# Unpack the archived checkpoint from Drive into ./phi2_text2sql for evaluation.
!unzip /content/drive/MyDrive/checkpoint-3300.zip -d phi2_text2sql

Archive:  /content/drive/MyDrive/checkpoint-3300.zip
  inflating: phi2_text2sql/trainer_state.json  
  inflating: phi2_text2sql/tokenizer_config.json  
  inflating: phi2_text2sql/README.md  
  inflating: phi2_text2sql/special_tokens_map.json  
  inflating: phi2_text2sql/added_tokens.json  
  inflating: phi2_text2sql/merges.txt  
  inflating: phi2_text2sql/adapter_config.json  
  inflating: phi2_text2sql/vocab.json  
  inflating: phi2_text2sql/training_args.bin  
  inflating: phi2_text2sql/adapter_model.safetensors  


In [None]:
# Directory holding the unpacked adapter checkpoint and its tokenizer files.
load_model_name = '/content/phi2_text2sql'

# Reload the tokenizer saved with the checkpoint (it includes the added
# padding token), with the same left-padding setting used for training.
tokenizer = AutoTokenizer.from_pretrained(
    load_model_name,
    padding_side="left",
    use_fast=False,
    trust_remote_code=True
)

# Load a fresh 4-bit copy of the base model with the same quantization
# settings that were used for training.
model = AutoModelForCausalLM.from_pretrained(
    'microsoft/phi-2',
    device_map=0,
    quantization_config=BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype='float16',
    bnb_4bit_use_double_quant=False)
)

# Match the embedding table to the tokenizer size before attaching the
# adapter, which carries its own saved embed_tokens / lm_head copies.
model.resize_token_embeddings(len(tokenizer))

from peft import PeftModel

# Attach the fine-tuned adapter weights to the base model and switch to eval
# mode (disables dropout, including the LoRA dropout layers).
inference_model = PeftModel.from_pretrained(model, load_model_name)
inference_model = inference_model.eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Print the module tree to confirm that the LoRA layers and the saved embedding copies are attached.
inference_model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PhiForCausalLM(
      (model): PhiModel(
        (embed_tokens): ModulesToSaveWrapper(
          (original_module): Embedding(50296, 2560)
          (modules_to_save): ModuleDict(
            (default): Embedding(50296, 2560)
          )
        )
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x PhiDecoderLayer(
            (self_attn): PhiSdpaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=2560, bias=False)
           

In [None]:
import sqlite3

def query_execution_eval(context, true_sql, pred_sql):
    """Check whether a predicted query returns the same rows as the reference.

    A fresh in-memory SQLite database is populated by running ``context`` as a
    script, both queries are executed against it, and the string forms of
    their ``fetchall()`` results are compared. Because the comparison is on
    the stringified row lists, row order and column order both matter.

    Args:
        context: SQL script (CREATE TABLE / INSERT statements) that builds the
            database.
        true_sql: Reference query.
        pred_sql: Predicted query to score.

    Returns:
        True when both queries run and yield identical results; False when the
        results differ or when the script or either query fails to execute.
    """
    db = sqlite3.connect(":memory:")
    try:
        cur = db.cursor()
        cur.executescript(context)
        true_result = str(cur.execute(true_sql).fetchall())
        pred_result = str(cur.execute(pred_sql).fetchall())
        return true_result == pred_result
    except Exception:
        # Best-effort scoring: any execution failure counts as a miss.
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate, so an interrupted run is not recorded as a wrong answer.
        return False
    finally:
        # Always release the connection; the original never closed it.
        db.close()

In [None]:
# Inference template: same layout as the training prompt, but it stops right
# after "Output:" so the model writes the SQL.
PRED_SQL_PROMPT = """
Instruct: Write SQL query of question asked by user based on following database structure context.
Context: {context}
Question: {question}
Output:"""

def prediction(samples):
    """Generate one SQL statement per example in a batch.

    Args:
        samples: Column-oriented batch (dict of lists, as obtained by slicing
            the dataset) with ``sql_context`` and ``sql_prompt`` columns.

    Returns:
        A list with one generated SQL string per example, whitespace-stripped
        and cut right after its first ``;`` when one is present.
    """
    # .strip() removes the template's leading newline so every prompt starts
    # with "Instruct:", exactly like the stripped text the model was trained on.
    inputs = [
        PRED_SQL_PROMPT.format(context=context, question=question).strip()
        for context, question in zip(samples['sql_context'], samples['sql_prompt'])
    ]

    model_inputs = tokenizer(inputs, padding=True, return_tensors="pt").to("cuda")

    generated_ids = inference_model.generate(
        **model_inputs,
        max_new_tokens=500,
        # Pad finished sequences with the tokenizer's dedicated padding token.
        # Without this, generate() falls back to eos_token_id and logs a
        # warning on every call.
        pad_token_id=tokenizer.pad_token_id,
    )

    # The tokenizer pads on the left, so every prompt occupies the first
    # `prompt_len` positions and the generated tokens start right after it.
    prompt_len = model_inputs['input_ids'].shape[1]
    decoded = tokenizer.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)

    # Avoid repetitive output: keep only the first statement, up to and
    # including its terminating semicolon.
    result = []
    for text in decoded:
        text = text.strip()
        stop_index = text.find(";")
        result.append(text if stop_index == -1 else text[:stop_index + 1])

    return result

In [None]:
import math

# Score the test split in consecutive batches: generate SQL for each batch,
# then execute every prediction against its reference query.
batch_size = 8
start_index = 0

all_evals, pred_sqls = [], []   # per-example pass/fail flags and raw predictions

num_batches = math.ceil(len(test_ds) / batch_size)
for i in range(num_batches):

    end_index = start_index + batch_size
    print(start_index, end_index)

    samples = test_ds[start_index:end_index]
    results = prediction(samples)
    print(results)
    pred_sqls += results

    triples = zip(samples['sql_context'], samples['sql'], results)
    eval_result = [
        query_execution_eval(context, sql, pred)
        for context, sql, pred in triples
    ]
    print(eval_result)
    all_evals += eval_result

    # Advance to the next batch window.
    start_index = end_index

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


0 8


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT name FROM renewable_project WHERE country = 'India' AND budget BETWEEN 50.0 AND 200.0;", "SELECT AVG(energy_rating) FROM energy_efficiency WHERE building_type = 'Commercial' AND location = 'Texas';", 'SELECT player, AVG(running_speed) FROM world_cup WHERE match_id = 2020 GROUP BY player;', 'SELECT name, total_transactions FROM shariah_compliant_institutions;', 'SELECT gender, race, SUM(reoffender) as reoffenders_2017, SUM(reoffenders_2017) as total_reoffenders_2017, SUM(reoffenders_2018) as reoffenders_2018, SUM(reoffenders_2018) as total_reoffenders_2018 FROM parolee WHERE year IN (2017, 2018) GROUP BY gender, race;', 'SELECT district_id, SUM(cases) FROM restorative_justice GROUP BY district_id;', 'SELECT Director, AVG(Rating) as AvgRating FROM DirectorMoviesRating GROUP BY Director;', "SELECT AVG(rating) FROM movies WHERE production_year BETWEEN 2010 AND 2020 AND country = 'USA';"]
[True, False, False, False, False, True, True, True]
8 16


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT SUM(quantity) FROM sales WHERE product_category = 'Apparel' AND year = 2021;", 'SELECT source_type, MAX(size) as max_size, MIN(size) as min_size FROM space_debris GROUP BY source_type;', 'SELECT grade_name, AVG(mental_health_score) as avg_score FROM student_mental_health GROUP BY grade_name ORDER BY avg_score DESC;', "SELECT COUNT(*) FROM power_plants WHERE state = 'Texas' AND source_type IN ('Wind', 'Solar', 'Hydro');", 'SELECT name, AVG(points) FROM hockey_players JOIN nhl_teams ON hockey_players.id = nhl_teams.players_id GROUP BY name;', "SELECT product_name FROM shariah_compliant_products WHERE region = 'Southeast Asia';", 'SELECT release_year, COUNT(*) as num_movies FROM movies GROUP BY release_year;', "SELECT SUM(mass) FROM space_objects_heo WHERE orbit = 'High Earth Orbit';"]
[False, True, False, False, False, True, True, True]
16 24


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT team_name, SUM(points_home - points_away) as total_points FROM baseball_season GROUP BY team_name;', 'SELECT warehouse_id, revenue * 0.9 AS discounted_revenue FROM warehouse_revenue;', 'SELECT genre, SUM(frequency) FROM media_content GROUP BY genre;', 'SELECT COUNT(*) FROM space_objects_count;', "SELECT AVG(mental_health_score) FROM students JOIN courses ON students.course_id = courses.course_id WHERE courses.course_type = 'Traditional';", 'SELECT name, position, MAX(points_per_game) FROM points GROUP BY name, position;', "SELECT AVG(price) FROM products WHERE category = 'Electronics' AND is_circular_supply_chain = TRUE;", "SELECT store_id, SUM(revenue) FROM sales WHERE region = 'Northern' GROUP BY store_id;"]
[True, True, True, True, True, False, True, True]
24 32


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT COUNT(*) FROM storage_projects WHERE country = 'China' AND year_built BETWEEN 2010 AND 2020;", 'SELECT AVG(attendance) FROM nfl_games;', "SELECT AVG(amount) FROM loans WHERE country IN ('Turkey', 'Iran');", "SELECT SUM(services_provided) FROM legal_aid_services WHERE location = 'Rural Area' AND state = 'California' AND year = 2021;", 'SELECT name, country FROM content_creators WHERE represents_group = true ORDER BY views DESC;', "SELECT SUM(mass) FROM space_debris WHERE orbit = 'MEO' AND launch_date < '2010-01-01';", "SELECT AVG(revenue) FROM RetailSales WHERE garment_type = 'Jeans' AND country = 'Mexico' AND year = 2021;", "SELECT AVG(sustainability_score) FROM garment_data_2 WHERE collection = 'Autumn 2021';"]
[True, True, True, True, False, True, True, True]
32 40


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT country, COUNT(*) FROM farm_data WHERE is_organic = true GROUP BY country;', 'SELECT team_name, SUM(points_scored) FROM nba_teams GROUP BY team_name;', 'SELECT continent, name FROM ethics_by_continent;', 'SELECT country, COUNT(*) FROM socially_responsible_loans GROUP BY country;', "SELECT SUM(income) FROM clients WHERE country = 'Canada' AND is_socially_responsible_investor = true;", "SELECT AVG(transit_time) FROM ground_freight_routes WHERE origin = 'Toronto' AND destination = 'Montreal';", "SELECT initiative, country FROM historical_legal_tech WHERE launch_date >= '2010-01-01';", "SELECT SUM(quantity) FROM inventory WHERE fabric_name IN ('Tencel Lyocell', 'Bamboo Viscose');"]
[True, True, True, True, True, True, True, True]
40 48


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT name, SUM(area_in_hectares) FROM crop GROUP BY name;', "SELECT * FROM Farmers WHERE location = 'Asia';", "SELECT AVG(mental_health_score) FROM students WHERE gender = 'Female';", "SELECT SUM(capacity) FROM energy_storage WHERE country IN ('Australia', 'Canada') AND year >= 2018;", "SELECT warehouse.location FROM inventory INNER JOIN warehouse ON inventory.warehouse_id = warehouse.id WHERE inventory.item_code = 'ORG-01' GROUP BY warehouse.location ORDER BY inventory.warehouse_id LIMIT 1;", 'SELECT COUNT(*) FROM workers JOIN factories ON workers.factory_id = factories.id WHERE factories.audit_passed = TRUE;', "SELECT SUM(quantity) FROM WAREHOUSE WHERE product = 'Product A';", 'SELECT SUM(sales_quantity) FROM mexico_mens_garments WHERE quarter = 4 AND year = 2020;']
[True, True, True, True, False, True, True, True]
48 56


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT name FROM research_institutes WHERE type = 'Non-profit' AND location!= 'Midwest';", "SELECT SUM(num_projects) FROM latam_renewable_projects WHERE country IN ('Colombia', 'Peru') AND year IN (2020, 2021);", "SELECT SUM(savings) FROM energy_efficiency WHERE state = 'Texas' AND year = 2020;", "SELECT MIN(quantity) FROM products WHERE category = 'gifts';", 'SELECT AVG(num_satellites) FROM countries INNER JOIN space_programs ON countries.id = space_programs.country WHERE country IS NOT NULL;', "SELECT MAX(SpaceMissions) FROM Astronauts WHERE Nationality = 'Japan';", "SELECT SUM(mass) FROM space_debris WHERE orbit = 'LEO' AND mass > 10;", 'SELECT (COUNT(*) FILTER (WHERE hours > 20)) * 100.0 / COUNT(*) FROM teachers;']
[False, True, True, True, False, False, True, True]
56 64


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT AVG(age) FROM cricket_players;', 'SELECT player_name, goals FROM world_cup_goals ORDER BY goals DESC LIMIT 3;', 'SELECT Program FROM Social_Good_Tech WHERE Month BETWEEN 1 AND 6;', "SELECT AVG(financial_wellbeing_score) FROM shariah_compliant_customers WHERE shariah_compliant_account = true AND wellbeing_assessment_date BETWEEN '2022-04-01' AND '2022-06-30';", 'SELECT MAX(amount_invested) FROM shariah_compliant_funds_investments;', 'SELECT state, 100.0 * SUM(spokes_spanish) / COUNT(*) as percentage_speaking_spanish FROM community_health_workers_lang GROUP BY state;', "SELECT COUNT(*) FROM community_health_workers WHERE state IN ('New York', 'California');", "SELECT crops.name, SUM(crops.yield) FROM crops JOIN farmers ON crops.farmer_id = farmers.id WHERE farmers.country = 'Asia' GROUP BY crops.name;"]
[True, True, False, True, False, False, True, False]
64 72


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT team_name, wins FROM football_teams ORDER BY wins DESC LIMIT 1;', 'SELECT WarehouseId, Country, SUM(Quantity) as TotalQuantity FROM Shipments GROUP BY WarehouseId, Country;', "SELECT MAX(mental_health_score) FROM student_mental_health WHERE date BETWEEN '2021-09-01' AND '2021-09-30';", "SELECT position, name FROM players WHERE sport = 'Hockey';", 'SELECT name, team, AVG(points_per_game) as avg_points_per_game FROM players GROUP BY name, team ORDER BY avg_points_per_game DESC LIMIT 5;', 'SELECT country, COUNT(*) FROM shipments GROUP BY country;', 'SELECT route_id, start_location, end_location, distance FROM parcel_delivery WHERE distance > 1000;', 'SELECT location, MIN(age) as min_age, MAX(age) as max_age, AVG(age) as avg_age FROM victims JOIN restorative_justice_participants ON victims.id = restorative_justice_participants.victim_id GROUP BY location;']
[False, False, True, False, True, True, True, True]
72 80


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT suppliers.supplier_id, suppliers.supplier_name FROM suppliers INNER JOIN materials ON suppliers.supplier_id = materials.supplier_id GROUP BY suppliers.supplier_id HAVING COUNT(*) >= 3;', "SELECT COUNT(*) FROM orders JOIN garments ON orders.garment_id = garments.id WHERE garments.name = 'Vegan Leather Shoes' AND garments.country = 'France' AND orders.quantity > 3;", 'SELECT collection, AVG(co2_emissions) FROM emissions GROUP BY collection;', 'SELECT Gender, COUNT(*) FROM MentalHealthParityGender GROUP BY Gender;', 'SELECT District, CrimeType, COUNT(*) as Count FROM Crimes GROUP BY District, CrimeType ORDER BY District, CrimeType;', "SELECT garment_name FROM Spring2023 WHERE material IN ('Silk', 'Cotton');", "SELECT MIN(budget) FROM ai_projects WHERE country IN ('Germany', 'France', 'UK', 'Spain');", 'SELECT SUM(num_employees) FROM Companies WHERE has_ethical_ai = true;']
[False, True, True, True, False, True, True, True]
80 88


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT category, SUM(quantity) FROM products GROUP BY category;', "SELECT COUNT(*) FROM Satellites WHERE Orbit = 'Low Earth Orbit' AND Operational = TRUE;", "SELECT AVG(age) FROM astronauts WHERE country = 'Japan';", 'SELECT COUNT(*) FROM mars_missions;', 'SELECT team, SUM(games) FROM nba_schedule GROUP BY team;', 'SELECT SUM(penalties) FROM penalties WHERE team_id = 306;', "SELECT COUNT(*) FROM loans WHERE is_socially_responsible = true AND region = 'South';", 'SELECT location, total_inventory FROM warehouse;']
[False, True, True, True, False, True, True, True]
88 96


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT SUM(quantity) FROM inventory;', 'SELECT type, resolution_method, SUM(success) as success_rate FROM disputes GROUP BY type, resolution_method;', 'SELECT m.manufacturer_name, SUM(w.retail_price) as total_retail_value FROM Winter2022 w JOIN Manufacturers m ON w.manufacturer_id = m.manufacturer_id GROUP BY m.manufacturer_name;', "SELECT SUM(capacity) FROM energy_storage WHERE region = 'California' AND year IN (2018, 2019);", "SELECT MIN(tech_accessibility_score), MAX(tech_accessibility_score) FROM org_accessibility WHERE sector = 'education';", 'SELECT salesperson, SUM(items) FROM sales GROUP BY salesperson;', 'SELECT program_id, COUNT(DISTINCT org_id) as org_count FROM community_orgs GROUP BY program_id;', "SELECT COUNT(*) FROM agroecology_research WHERE country IN ('CO', 'PE');"]
[True, False, True, False, True, True, True, False]
96 104


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT AVG(years_of_experience) FROM teachers WHERE mental_health_resource_access IS NOT NULL;', "SELECT COUNT(*) FROM renewable_count WHERE location = 'Country R';", 'SELECT country, production_quantity FROM solar_energy ORDER BY production_quantity DESC LIMIT 3;', "SELECT SUM(initiative_id, region, funds) FROM ethical_ai_initiatives WHERE region = 'North America';", "SELECT organization FROM ai_oversight WHERE region = 'Canada';", "SELECT AVG(weight) FROM Shipments WHERE origin_country = 'UK' AND shipment_date BETWEEN '2022-01-01' AND '2022-01-31';", "SELECT worker_name, SUM(patients_served) as total_patients_served FROM community_workers WHERE community_type IN ('African American', 'Hispanic', 'LGBTQ+', 'Rural', 'Asian', 'Native American') GROUP BY worker_name ORDER BY total_patients_served DESC LIMIT 1;", "SELECT (COUNT(*) FILTER (WHERE city = 'Los Angeles' AND has_been_homeless = true)) * 100.0 / COUNT(*) FROM legal_aid_clients;"]
[True, True, True, False, True, True, False, Tru

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT judge_name, COUNT(*) as total_cases FROM criminal_cases GROUP BY judge_name;', "SELECT COUNT(*) FROM Spacecrafts WHERE Manufacturer = 'SpaceX' AND Operational = TRUE;", 'SELECT MIN(launch_date) FROM space_missions;', "SELECT garment_type, SUM(quantity) as total_quantity FROM garment_sales WHERE region = 'Europe' GROUP BY garment_type ORDER BY total_quantity DESC LIMIT 3;", 'SELECT student_id, MAX(score) - MIN(score) as improvement FROM student_mental_health GROUP BY student_id;', "SELECT SUM(energy_produced) FROM renewable_energy WHERE country = 'Germany' AND year = 2020;", "SELECT COUNT(*) FROM games WHERE league = 'NHL' AND year >= 2000;", 'SELECT team, AVG(engagement) as avg_engagement FROM social_media GROUP BY team;']
[True, True, True, True, False, True, True, True]
112 120


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT MIN(price) FROM Devices WHERE community LIKE '%Underrepresented%';", "SELECT SUM(quantity) FROM production WHERE category = 'Ethical Clothing' AND year IN (2021, 2022);", 'SELECT grade_level, AVG(participation_score) FROM student_open_pedagogy GROUP BY grade_level;', "SELECT COUNT(*) FROM legal_tech_events WHERE location IN ('New York', 'California');", 'SELECT language, format, SUM(views) as total_views FROM open_education_resources GROUP BY language, format ORDER BY total_views DESC;', "SELECT SUM(projects) FROM Ethical_AI WHERE sector = 'Healthcare';", "SELECT client_country, COUNT(*) as num_clients FROM socially_responsible_loans WHERE client_country NOT IN ('Saudi Arabia', 'UAE') GROUP BY client_country;", "SELECT ProductName, SUM(Quantity) FROM Shipments JOIN Warehouses ON Shipments.WarehouseID = Warehouses.WarehouseID WHERE Warehouses.WarehouseName = 'Tokyo Warehouse' GROUP BY ProductName;"]
[False, False, True, True, True, False, True, True]
120 128


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT genre, AVG(length) as avg_length FROM tracks GROUP BY genre ORDER BY avg_length DESC;', "SELECT SUM(weight) FROM shipments WHERE origin = 'Canada' AND destination = 'United States' AND shipped_at BETWEEN '2021-01-01' AND '2021-01-31';", 'SELECT ht, SUM(programs) FROM language_access GROUP BY ht;', "SELECT product_name, price FROM products WHERE country_of_manufacture!= 'USA' AND is_on_sale = FALSE;", 'SELECT c.customer_name, SUM(p.purchase_value) as total_spent_ethical_fashion FROM ethical_fashion_purchases p JOIN customers c ON p.customer_id = c.customer_id WHERE p.purchase_id BETWEEN 10 AND 13 AND p.customer_id IN (1, 2) GROUP BY c.customer_name ORDER BY total_spent_ethical_fashion DESC LIMIT 3;', 'SELECT AVG(Moons) FROM SolarSystem;', 'SELECT country, SUM(cost) FROM missions GROUP BY country;', "SELECT COUNT(*) FROM transactions WHERE region = 'Midwest' AND non_gmo = true;"]
[False, True, False, True, False, False, False, True]
128 136


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT SUM(capacity) FROM renewable_plants WHERE country = 'Australia';", "SELECT name FROM solar_plants WHERE country = 'Spain' AND capacity > 50;", "SELECT AVG(score) FROM ai_tools WHERE type = 'PersonsWithDisabilities';", "SELECT SUM(violation_date) FROM mental_health_parity WHERE location = 'Illinois' AND violation_date BETWEEN '2020-01-01' AND '2020-12-31';", 'SELECT District, AVG(HearingDuration) as AvgHearingDuration FROM CommunityCourtHearings GROUP BY District;', "SELECT MAX(likes) FROM posts WHERE domain = 'Media Literacy' AND region = 'Asia';", "SELECT COUNT(*) FROM products WHERE category = 'grocery';", "SELECT artists.artist_name, COUNT(streams.stream_id) as streams_count FROM artists JOIN streams ON artists.artist_id = streams.artist_id WHERE streams.stream_date BETWEEN '2019-01-01' AND '2019-03-31' GROUP BY artists.artist_name;"]
[True, False, True, False, False, True, True, True]
136 144


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT student_id FROM student_lifelong_learning WHERE course_id IS NULL;', 'SELECT team_name, wins FROM nba_teams ORDER BY wins DESC LIMIT 3;', 'SELECT athlete, AVG(time_in_pool) FROM olympic_swimming GROUP BY athlete;', "SELECT client_id, account_balance FROM microfinance_program WHERE program_name = 'Socially Responsible Microfinance';", "SELECT warehouse_country, warehouse_city, MAX(pallets) FROM warehouse_stats WHERE warehouse_country = 'Colombia' GROUP BY warehouse_country, warehouse_city;", "SELECT AVG(products.price) FROM products JOIN vendors ON products.vendor_id = vendors.vendor_id WHERE products.organic = true AND vendors.country = 'USA';", "SELECT SUM(revenue) FROM carbon_pricing WHERE country = 'Canada' AND year = 2021;", 'SELECT name, AVG(home_run_distance) FROM players;']
[True, True, False, False, False, True, True, False]
144 152


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT MAX(distance) FROM Routes WHERE destination_city = 'City X';", 'SELECT m.name, COUNT(t.training_id) as num_trainings FROM mental_health_parity_officers m JOIN trainings_conducted t ON m.officer_id = t.officer_id GROUP BY m.name;', "SELECT AVG(mental_health_score) FROM patients WHERE community IN ('African American', 'Latinx', 'Asian American');", "SELECT MAX(altitude) FROM leo_satellites WHERE type = 'LEO';", 'SELECT country, COUNT(*) as launches_count FROM launches GROUP BY country;', 'SELECT Agency_Name, COUNT(*) as Num_Satellites FROM Space_Satellites GROUP BY Agency_Name ORDER BY Num_Satellites DESC;', "SELECT COUNT(*) FROM Urban_Agriculture WHERE State IN ('California', 'New York') AND Year = 2019;", "SELECT country, SUM(quantity) FROM production WHERE crop = 'rice' GROUP BY country;"]
[False, True, False, False, True, True, True, True]
152 160


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT gender, AVG(mental_health_score) FROM students GROUP BY gender;', "SELECT COUNT(*) FROM Team_B_Matches WHERE result = 'Win';", 'SELECT r.region, COUNT(t.id) as num_countries FROM technology_access t JOIN regions r ON t.region = r.region GROUP BY r.region ORDER BY num_countries ASC LIMIT 1;', "SELECT Name, Sentence FROM Sentences WHERE Sentence = 'Life Imprisonment without Parole';", 'SELECT AVG(days_in_space) FROM astronauts;', 'SELECT ConcertID, COUNT(DISTINCT Artist) FROM ArtistConcert GROUP BY ConcertID;', 'SELECT CountryName, AVG(Budget) as AvgBudget FROM Country GROUP BY CountryName;', "SELECT AVG(account_balance) FROM islamic_banking_clients WHERE segment = 'Islamic Banking';"]
[True, True, False, True, True, True, False, True]
160 168


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT destination_province, MAX(weight) FROM Packages JOIN Warehouses ON Packages.warehouse_id = Warehouses.id WHERE Warehouses.city = 'Mexico City' GROUP BY destination_province;", "SELECT COUNT(*) FROM cases WHERE resolution_type = 'Mediation' AND resolution_date >= '2020-01-01' AND city = 'New York';", "SELECT AVG(price) FROM products WHERE vegan = true AND country = 'USA';", 'SELECT COUNT(DISTINCT Country) FROM Country_Spacecraft;', 'SELECT DISTINCT principle FROM ethical_ai_principles;', "SELECT AVG(budget) FROM ai_projects WHERE region = 'Latin America';", "SELECT name, capacity FROM Warehouses WHERE country = 'Canada';", "SELECT AVG(temperature), AVG(precipitation) FROM weather JOIN farms ON weather.farm_id = farms.id WHERE farms.location = 'Urban' AND weather.month = 4;"]
[True, True, False, True, True, True, True, False]
168 176


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT teams.team_name, SUM(points.points) FROM teams INNER JOIN points ON teams.team_id = points.team_id WHERE points.season = '2021' GROUP BY teams.team_name;", 'SELECT contributor, contributions FROM accessibility_contributors ORDER BY contributions DESC LIMIT 3;', "SELECT COUNT(*) FROM Packages WHERE arrived >= '2021-01-01' AND destination = 'Texas' AND (arrived < '2021-01-01' OR arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrived >= '2021-01-01' AND arrived < '2021-01-01' AND arrive

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT ReportYear, COUNT(*) FROM MentalHealthParity WHERE ReportYear IN (2020, 2021) GROUP BY ReportYear;', "SELECT MIN(views) FROM videos_3 WHERE category ='music';", 'SELECT region, SUM(score) as total_score FROM media_representation GROUP BY region;', "SELECT SUM(mass_kg) FROM spacecraft WHERE name = 'Juno';", 'SELECT customer_id, total_sales_2022 FROM customers ORDER BY total_sales_2022 DESC LIMIT 1;', "SELECT MAX(capacity) FROM max_energy_storage WHERE country = 'Australia';", 'SELECT foul_type, COUNT(*) as foul_count FROM basketball_fouls GROUP BY foul_type ORDER BY foul_count DESC LIMIT 1;', "SELECT account_type, SUM(loan_amount) FROM loans WHERE client_region = 'Asia-Pacific' GROUP BY account_type;"]
[False, True, True, True, True, True, False, True]
184 192


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT MIN(lead_time) FROM WarehouseTransfers WHERE source_warehouse_id = 6;', 'SELECT teacher_id, COUNT(course_id) as courses_completed FROM teacher_pd GROUP BY teacher_id ORDER BY courses_completed DESC;', "SELECT COUNT(*) FROM solar_projects WHERE country IN ('Germany', 'Spain') AND completed = true;", 'SELECT COUNT(*) FROM acc_proj;', "SELECT MAX(score) FROM financial_capability WHERE country IN ('South Africa', 'Egypt', 'Nigeria');", "SELECT MAX(delivery_time) FROM deliveries WHERE warehouse = 'Mumbai' AND quarter = 4;", 'SELECT d.name FROM directors d JOIN movies_per_director m ON d.id = m.id ORDER BY m.movies_count DESC LIMIT 1;', 'SELECT outcome, COUNT(*) as count FROM space_missions GROUP BY outcome;']
[True, True, True, False, False, True, False, True]
192 200


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT City, 100.0 * SUM(FullPrice) / COUNT(*) as FullPricePercentage FROM Transactions GROUP BY City;', 'SELECT crop_type, AVG(yield/acres) as avg_yield_per_acre FROM crop_types GROUP BY crop_type;', 'SELECT department, AVG(course_completed) FROM teacher_professional_development GROUP BY department;', "SELECT COUNT(*) FROM employee_roster WHERE team = 'Ethical AI' AND join_date > '2021-06-01';", 'SELECT CountryName, SUM(CertificationCount) as TotalCertifications FROM EthicalAICertifications GROUP BY CountryName;', "SELECT program_name FROM fwp_programs WHERE country IN ('USA', 'UK');", "SELECT Warehouse.name, SUM(Handling.pallets) FROM Handling INNER JOIN Warehouse ON Handling.warehouse_id = Warehouse.id WHERE Warehouse.city = 'Paris' GROUP BY Warehouse.name;", "SELECT SUM(volume) FROM Canada_Freight WHERE origin_country = 'Mexico' AND destination_country = 'Canada';"]
[False, True, True, True, True, True, True, True]
200 208


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT SUM(weight) FROM shipments WHERE country = 'USA';", "SELECT SUM(revenue) FROM Sales WHERE market = 'US';", "SELECT product_name, SUM(units_sold) FROM product_sales WHERE country = 'Canada' AND quarter IN (2, 3) GROUP BY product_name;", "SELECT artists.artist_name, COUNT(*) as song_count FROM songs JOIN artists ON songs.artist_name = artists.artist_name WHERE songs.genre = 'R&B' GROUP BY artists.artist_name ORDER BY song_count DESC LIMIT 2;", 'SELECT s.school_name, 100.0 * COUNT(*) FILTER (WHERE s.school_id = s.participant_in_program) / COUNT(*) FROM school_lifelong_learning_participation s JOIN schools s ON s.school_id = s.school_id GROUP BY s.school_name;', 'SELECT AVG(capacity_MW) FROM geothermal_plants;', "SELECT COUNT(*) FROM patents WHERE ethical = true AND filed_country IN ('Mexico', 'Argentina', 'Colombia');", "SELECT AVG(cost) FROM warehouse_costs_apac WHERE warehouse_location IN ('Sydney Warehouse', 'Melbourne Warehouse') AND quarter = 2 AND year = 2023;"]
[True, True

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT name FROM Astronauts WHERE agency = 'Roscosmos';", 'SELECT m.name, COUNT(s.id) as num_satellites FROM satellites s JOIN manufacturers m ON s.manufacturer_id = m.id GROUP BY m.name ORDER BY num_satellites DESC LIMIT 1;', 'SELECT SUM(amount) FROM funding JOIN organizations ON funding.org_id = organizations.org_id WHERE implemented_digital_divide_initiatives = TRUE;', 'SELECT country, COUNT(*) FROM financial_wellbeing_programs GROUP BY country;', "SELECT AVG(amount) FROM shariah_financing WHERE client_country IN ('Indonesia', 'Bahrain', 'UAE') GROUP BY client_country ORDER BY AVG(amount) DESC LIMIT 3;", "SELECT COUNT(*) FROM SpaceMissions WHERE agency = 'NASA' AND year < 2000 AND manned = true;", 'SELECT region, AVG(score) FROM energy_efficiency WHERE year = 2021 GROUP BY region;', 'SELECT AVG(rebounds) FROM wilt_stats WHERE game = 1;']
[True, False, True, True, False, True, True, False]
216 224


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT continent, SUM(quantity) FROM shipments JOIN shipment_items ON shipments.shipment_id = shipment_items.shipment_id GROUP BY continent;', 'SELECT title FROM shows WHERE runtime > 60;', "SELECT AVG(labor_cost) FROM factories WHERE country LIKE 'Africa%';", "SELECT SUM(streams) FROM song_streams WHERE song_title = 'Bohemian Rhapsody' AND platform IN ('Spotify', 'Apple Music');", "SELECT SUM(courses_completed) FROM teachers WHERE school = 'Westside' AND year = 2019;", "SELECT SUM(EnergyConsumption.Consumption) FROM EnergyConsumption INNER JOIN Emissions ON EnergyConsumption.Sector = Emissions.Sector WHERE EnergyConsumption.Year = 2020 AND EnergyConsumption.Sector IN ('Residential', 'Commercial');", "SELECT SUM(capacity) FROM solar_farm WHERE country IN ('China', 'Spain');", 'SELECT name FROM organizations WHERE gov_funding = TRUE AND private_funding = FALSE;']
[True, True, False, True, True, False, True, True]
224 232


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT AVG(score) FROM ai_tools WHERE type = 'SocialGood';", "SELECT MIN(budget) FROM projects WHERE region = 'Africa';", "SELECT COUNT(*) FROM shariah_compliant_finance WHERE country = 'United Arab Emirates';", "SELECT SUM(shipment_weight) FROM ShipmentWeights WHERE destination_continent = 'South America';", "SELECT COUNT(*) FROM farms WHERE country = 'USA' AND organic = TRUE;", "SELECT State, AVG(Age) FROM CommunityHealthWorkers WHERE Gender = 'Non-binary' GROUP BY State;", 'SELECT home_team, SUM(yellow_cards_home) as total_yellow_cards FROM soccer_matches GROUP BY home_team;', "SELECT name FROM organizations WHERE region = 'Asia' AND involvement ='social good';"]
[True, False, True, True, True, True, False, True]
232 240


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT * FROM ReverseLogistics WHERE date BETWEEN '2023-01-01' AND '2023-01-31';", 'SELECT State, SUM(Coverage) as TotalCoverage FROM MentalHealthParity GROUP BY State ORDER BY TotalCoverage DESC;', "SELECT COUNT(*) FROM Manufacturing WHERE garment_type = 'T-Shirt' AND country = 'Turkey' AND year = 2022;", 'SELECT system_name, production FROM african_indigenous_systems ORDER BY production DESC;', "SELECT emissions FROM co2_emissions WHERE country = 'Australia' AND sector = 'Energy';", "SELECT AVG(generation) FROM hydro_power WHERE country IN ('Norway', 'Sweden');", "SELECT name FROM users WHERE region = 'North America' AND age > 30;", "SELECT artist_name, MAX(total_streams) FROM artist_streams WHERE platform IN ('Spotify', 'Apple Music') GROUP BY artist_name;"]
[False, True, True, False, True, False, True, True]
240 248


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT sport, AVG(salary) FROM athlete_salaries GROUP BY sport;', "SELECT MAX(SocialGoodBudget.Budget) FROM SocialGoodBudget INNER JOIN Countries ON SocialGoodBudget.Country = Countries.Country WHERE Countries.Continent = 'Africa';", 'SELECT destination_continent, SUM(quantity) FROM Shipment GROUP BY destination_continent;', 'SELECT article_language, COUNT(*) as article_count FROM articles GROUP BY article_language;', "SELECT AVG(yield) FROM crops WHERE region = 'Pacific' AND year = 2021;", "SELECT COUNT(*) FROM open_pedagogy_resources WHERE access_date >= '2022-03-01' AND access_date < '2022-04-01';", 'SELECT socially_responsible_loans.client_id, credit_cards.card_type FROM socially_responsible_loans INNER JOIN credit_cards ON socially_responsible_loans.client_id = credit_cards.client_id WHERE socially_responsible_loans.client_id IS NOT NULL AND credit_cards.client_id IS NOT NULL;', "SELECT Name, CulturalCompetencyScore FROM Hospitals WHERE Region = 'Northeast';"]
[True, True, True,

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT SUM(amount) FROM sales JOIN vendors ON sales.vendor_id = vendors.vendor_id WHERE vendors.region = 'Midwest';", "SELECT project_name, SUM(budget) FROM energy_efficiency_projects WHERE state = 'California' GROUP BY project_name;", "SELECT COUNT(*) FROM ethical_ai_initiatives WHERE region IN ('Asia', 'Europe');", 'SELECT Area, AVG(MentalHealthScore) as Avg_Score FROM MentalHealthScores GROUP BY Area;', "SELECT SUM(number_of_cases) FROM court_cases WHERE county = 'Los Angeles' AND year = 2020;", 'SELECT country, frequency FROM media_content;', 'SELECT supplier_id, COUNT(*) FROM product GROUP BY supplier_id;', 'SELECT mass_range, COUNT(*) as count FROM space_debris_by_mass GROUP BY mass_range;']
[True, True, False, False, True, False, True, False]
256 264


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT AVG(area_ha) FROM agroecological_projects WHERE location = 'Asia';", 'SELECT region FROM initiatives WHERE success_score > 75 ORDER BY success_score DESC;', "SELECT type, COUNT(*) as count FROM resources WHERE district = 'Brookside' GROUP BY type ORDER BY count ASC LIMIT 1;", 'SELECT SUM(program_completed) FROM teacher_development WHERE program_completed > 0;', 'SELECT Category, Resolution, COUNT(Cases) FROM CasesByJusticeCategory WHERE Year = 2021 GROUP BY Category, Resolution;', "SELECT MAX(case_type) FROM cases WHERE country = 'Australia' AND case_type = 'Restorative Justice';", 'SELECT vendors.vendor_name, SUM(sales.amount) FROM vendors INNER JOIN sales ON vendors.vendor_id = sales.vendor_id WHERE sales.amount > 10000 GROUP BY vendors.vendor_name;', "SELECT name FROM teachers WHERE subject = 'Computer Science' ORDER BY hire_date;"]
[True, False, True, True, False, False, True, False]
264 272


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT team_name, AVG(strikeouts) FROM baseball_teams GROUP BY team_name;', "SELECT MAX(goals_scored) FROM goals WHERE team = 'Montreal Canadiens';", 'SELECT Ethnicity, COUNT(*) as TotalCases FROM DiversityInJustice GROUP BY Ethnicity;', 'SELECT AVG(volunteer_age) FROM restorative_justice_programs;', 'SELECT hour_type, location, ethnicity, SUM(hours) FROM legal_aid_hours_ethnicity GROUP BY hour_type, location, ethnicity;', 'SELECT country, COUNT(*) as num_successes FROM space_missions GROUP BY country;', "SELECT f.name, COUNT(DISTINCT c.variety) as unique_varieties FROM farms f JOIN crops c ON f.id = c.farm_id WHERE f.location = 'Asia' AND c.last_harvest_date >= '2022-01-01' GROUP BY f.name;", 'SELECT name, location FROM Farmers WHERE years_of_experience > 10;']
[True, True, False, True, True, False, False, False]
272 280


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT MAX(home_team_player_hat_tricks + away_team_player_hat_tricks) FROM german_matches;', 'SELECT AVG(weight) FROM shipment WHERE warehouse_id = 3;', 'SELECT product_category, SUM(sales) FROM sales GROUP BY product_category;', "SELECT SUM(capacity) FROM windfarm WHERE country IN ('Germany', 'France');", "SELECT SUM(budget) FROM accessible_tech WHERE sector = 'education';", "SELECT SUM(amount) FROM donations WHERE donor = 'Aisha' AND donation_date BETWEEN '2021-01-01' AND '2021-12-31';", 'SELECT transportation_mode, COUNT(DISTINCT item_type) FROM shipments GROUP BY transportation_mode;', "SELECT MAX(number) FROM mental_health_parity_violations WHERE state IN ('Alabama', 'Georgia', 'Florida', 'North Carolina', 'South Carolina', 'Mississippi', 'Louisiana', 'Arkansas', 'Tennessee', 'Kentucky', 'Virginia');"]
[True, True, True, True, True, True, True, True]
280 288


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT COUNT(DISTINCT service_type) FROM defendant_services;', "SELECT AVG(songs.song_length) FROM songs JOIN artists ON songs.artist_id = artists.artist_id WHERE artists.country = 'United States';", 'SELECT district_name, AVG(mental_health_score) as avg_score FROM student_mental_health GROUP BY district_name ORDER BY avg_score DESC;', "SELECT SUM(generation) FROM energy_generation WHERE country IN ('Kenya', 'Nigeria', 'South Africa') AND generation_date BETWEEN '2021-01-01' AND '2021-03-31';", 'SELECT name, MAX(home_runs) FROM baseball_stats;', "SELECT SUM(income) as total_income, SUM(expenses) as total_expenses FROM FinancialWellbeingPrograms WHERE country = 'Australia';", "SELECT SUM(weight) FROM parcels JOIN shipments ON parcels.id = shipments.shipment_id WHERE shipments.source_airport = 'FRA' AND shipments.destination_airport = 'ICN' AND shipped_date BETWEEN '2022-03-01' AND '2022-03-31';", "SELECT country, SUM(product_quantity) as total_quantity FROM ethical_brands JOIN sales O

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT stadium_name, capacity FROM stadiums WHERE capacity > 70000;', 'SELECT DISTINCT city FROM conferences JOIN ethical_ai_topics ON conferences.id = ethical_ai_topics.conference_id WHERE conference_id IN (1, 4);', "SELECT AVG(amount) FROM shariah_compliant_loans WHERE region IN ('Middle East', 'Africa');", "SELECT COUNT(*) FROM financial_institutions WHERE country IN ('UAE', 'Egypt') AND is_shariah_compliant = true;", 'SELECT District, COUNT(*) as NumOfPrograms FROM ADRPrograms WHERE YearEstablished BETWEEN 2010 AND 2020 GROUP BY District;', 'SELECT r.name, AVG(m.carbon_footprint) FROM regions r JOIN manufacturers m ON r.id = m.region_id GROUP BY r.name;', "SELECT MIN(Price) FROM Products INNER JOIN Stores ON Products.StoreID = Stores.StoreID WHERE Products.Category = 'Grocery' AND Stores.Country = 'USA' AND Stores.State = 'New York';", "SELECT SUM(quantity) FROM sales WHERE business_size ='small';"]
[True, False, True, False, True, True, True, True]
296 304


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT COUNT(DISTINCT country) FROM space_agencies;', "SELECT mission_name, launch_date FROM missions WHERE country = 'Russia';", 'SELECT country, COUNT(*) as mission_count FROM SpaceMissions GROUP BY country;', 'SELECT return_point, COUNT(*) as count FROM returns WHERE return_half = 1 AND return_year = 2022 GROUP BY return_point ORDER BY count DESC LIMIT 3;', 'SELECT region, worker_count FROM region_health_workers;', 'SELECT county, COUNT(*) as org_count FROM legal_aid_organizations GROUP BY county;', 'SELECT country, COUNT(*) as num_missions FROM space_missions GROUP BY country;', "SELECT SUM(quantity) FROM emissions WHERE emission_type = 'CO2' AND country = 'France';"]
[True, True, True, False, False, True, True, True]
304 312


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT title, streams FROM songs WHERE genre = 'Hip-Hop' AND year = 2021 ORDER BY streams DESC LIMIT 3;", 'SELECT country, SUM(energy_consumption) as total_consumption FROM energy_consumption WHERE year = 2020 GROUP BY country;', "SELECT name, wins FROM teams WHERE league = 'UEFA Champions League';", 'SELECT MAX(homeruns) FROM single_game_homeruns;', "SELECT MAX(loan_amount) FROM socially_responsible_loans WHERE region = 'Asia-Pacific';", "SELECT genre, COUNT(*) as num_shows, MAX(rating) as max_rating FROM tv_shows WHERE production_country = 'Japan' AND release_year BETWEEN 2015 AND 2020 GROUP BY genre ORDER BY max_rating DESC;", 'SELECT region, COUNT(*) FROM socially_responsible_lending GROUP BY region;', "SELECT warehouse_location, SUM(quantity) FROM warehouse_data WHERE item_name = 'Widget' GROUP BY warehouse_location HAVING SUM(quantity) < 50;"]
[True, True, False, True, True, False, True, False]
312 320


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT Region, HealthWorkerCount FROM RegionHealthWorkers;', 'SELECT case_type, AVG(processing_time) FROM case_processing GROUP BY case_type;', "SELECT SUM(production) FROM production_data WHERE country = 'Kenya';", "SELECT name, energy_star_rating FROM appliances WHERE country = 'USA' ORDER BY energy_star_rating DESC LIMIT 3;", "SELECT state, year, SUM(consumption) as total_consumption, energy_type, SUM(consumption) as renewable_consumption, SUM(consumption) as non_renewable_consumption FROM energy_consumption WHERE state = 'New York' AND year = 2020 GROUP BY state, year, energy_type;", 'SELECT AVG(goals_home) FROM games WHERE attendance > 50000;', "SELECT MIN(health_equity_metric_score) FROM healthcare_providers WHERE location = 'Rural';", "SELECT category, COUNT(*) FROM media_content WHERE studio_location IN ('Brazil', 'Japan') GROUP BY category;"]
[True, True, False, True, False, True, True, True]
320 328


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT country, SUM(cost) FROM missions GROUP BY country;', "SELECT organization, COUNT(*) as num_projects FROM ai_ethics WHERE country = 'Canada' GROUP BY organization;", "SELECT gender, AVG(followers) FROM news_anchors WHERE news_channel = 'Channel1' GROUP BY gender;", 'SELECT type, AVG(temperature) FROM crop GROUP BY type;', "SELECT name FROM urban_agriculture_initiatives WHERE location = 'Montreal' AND area_ha > 0.5;", 'SELECT COUNT(DISTINCT organization_name) FROM social_good_middle_east;', 'SELECT occupation, gender, AVG(score) as avg_score FROM financial_capability_3 GROUP BY occupation, gender;', "SELECT COUNT(*) FROM podcasts WHERE publication_year = 2019 AND creator_community = 'Underrepresented Community';"]
[False, True, True, True, True, True, False, True]
328 336


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT MAX(units_sold) FROM product_sales WHERE country = 'Germany' AND year = 2020;", 'SELECT country, SUM(satellites) as total_satellites FROM SpaceRadar GROUP BY country ORDER BY total_satellites DESC;', "SELECT SUM(quantity) FROM shipments JOIN warehouses ON shipments.warehouse_id = warehouses.warehouse_id WHERE warehouses.city = 'NYC';", "SELECT COUNT(*) FROM tv_shows WHERE country = 'Spain' AND year = 2017;", "SELECT MIN(pub_date) FROM articles_tech WHERE category = 'Tech';", "SELECT SUM(Revenue) FROM Revenue WHERE Practice = 'Ethical Labor' AND Country = 'South America';", "SELECT name FROM space_craft WHERE orbit = 'GTO' ORDER BY mass DESC LIMIT 1;", 'SELECT SUM(Mass) FROM Space_Debris;']
[True, True, True, True, True, True, True, True]
336 344


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT salesperson_id, COUNT(*) AS sales_count, SUM(revenue) AS total_revenue FROM sales GROUP BY salesperson_id ORDER BY sales_count DESC;', 'SELECT country, MAX(initiatives) as max_initiatives FROM urban_agriculture GROUP BY country ORDER BY max_initiatives DESC;', "SELECT AVG(efficiency_rating) FROM energy_efficiency WHERE building_type = 'Residential' AND country = 'India';", "SELECT SUM(sessions) FROM financial_capability_training WHERE country = 'Germany' AND quarter = 1 AND year = 2022;", "SELECT worker_id, region, SUM(CASE WHEN metric1 THEN 1 ELSE 0 END) AS metric1_met, SUM(CASE WHEN metric2 THEN 1 ELSE 0 END) AS metric2_met, SUM(CASE WHEN metric3 THEN 1 ELSE 0 END) AS metric3_met FROM health_equity_metrics_worker WHERE region = 'West' AND year = 2020 GROUP BY worker_id, region, metric1_met + metric2_met + metric3_met;", "SELECT product_category, SUM(sale_amount) as total_sales FROM sales JOIN products ON sales.product_id = products.product_id WHERE sale_region = 'Europe' GRO

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT AVG(height) FROM players WHERE team = 'Atlanta Hawks';", "SELECT AVG(funding) FROM projects WHERE category = 'SocialGood' AND name LIKE '%Technology%';", 'SELECT route, AVG(delivery_time) FROM delivery GROUP BY route;', "SELECT SUM(Shipments.Weight) FROM Shipments JOIN FreightForwarders ON Shipments.FreightForwarderID = FreightForwarders.ID WHERE Shipments.Origin = 'Brazil' AND Shipments.Destination = 'India' AND FreightForwarders.Name = 'DEF Logistics';", "SELECT AVG(rating) FROM movies WHERE director LIKE '%Woman%';", "SELECT SUM(duration) FROM Videos WHERE category = 'Entertainment' AND rating > 8;", 'SELECT SUM(mass) FROM exoplanets WHERE atmosphere = true;', "SELECT COUNT(*) FROM streams WHERE genre = 'Country' AND country = 'USA' AND stream_date BETWEEN '2021-02-01' AND '2021-02-28';"]
[True, False, True, True, False, True, True, True]
352 360


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT ProjectName FROM Projects WHERE LeaderCommunity LIKE 'Historically Underrepresented%' AND Domain = 'Social Good';", "SELECT principle_name FROM EthicalAI WHERE project_location = 'India';", "SELECT AVG(budget) FROM company_tech WHERE name LIKE '%ethical AI%';", "SELECT MIN(salary) FROM salaries WHERE team = 'Social Good';", "SELECT Warehouses.name, SUM(Inventory.pallets) FROM Inventory INNER JOIN Warehouses ON Inventory.warehouse_id = Warehouses.id WHERE Warehouses.country = 'France' GROUP BY Warehouses.name;", "SELECT COUNT(*) FROM shipment WHERE warehouse_id = 4 AND delivery_location = 'Berlin' AND shipped_date BETWEEN '2021-02-01' AND '2021-02-28' AND weight > 15;", "SELECT victim_name FROM restorative_justice_programs WHERE program_state = 'New York';", 'SELECT location, COUNT(DISTINCT genre) FROM media GROUP BY location;']
[True, True, False, True, True, True, True, True]
360 368


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['SELECT AVG(altitude) FROM geostationary_satellites;', "SELECT SUM(satellites) FROM satellite_launches WHERE country IN ('India', 'USA');", 'SELECT artist_name, COUNT(*) as song_count FROM Songs GROUP BY artist_name;', 'SELECT COUNT(DISTINCT customer_id) FROM credit_cards;', "SELECT AVG(score) FROM financial_wellbeing_eu WHERE country IN ('Germany', 'France', 'UK');", 'SELECT year, country, SUM(num_satellites) FROM satellite_launches_by_year_country GROUP BY year, country;', "SELECT SUM(Quantity) FROM MaizeProduction WHERE System = 'Indigenous Food Systems';", 'SELECT warehouse_id, AVG(weight) FROM packages WHERE weight <= 80 GROUP BY warehouse_id;']
[True, True, True, False, True, True, True, True]
368 376


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT Metric_Name, Metric_Value FROM HealthEquityMetrics WHERE Region = 'rural';", 'SELECT AVG(cost) FROM mars_missions;', "SELECT f.name, f.age, f.location, p.product_name, p.price FROM Agroecology_Farmers f JOIN Agroecology_Produce p ON f.id = p.farmer_id WHERE f.location IN ('Senegalese Savannah', 'Kenyan Highlands');", "SELECT SUM(yield) FROM organic_farms WHERE state IN ('CA', 'TX') AND year = 2020;", 'SELECT gender, age, ethnicity, AVG(mental_health_score) as avg_mental_health_score FROM students GROUP BY gender, age, ethnicity;', 'SELECT AVG(mental_health_score) FROM students WHERE participated_in_open_pedagogy = TRUE;', 'SELECT country, COUNT(*) as num_resources FROM student_access JOIN open_resources ON student_access.resource_id = open_resources.resource_id GROUP BY country;', 'SELECT player_name, SUM(points) FROM nba_scores GROUP BY player_name;']
[True, True, False, True, True, True, True, True]
376 384


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT COUNT(*) FROM digital_divide_initiatives WHERE region = 'Asia';", "SELECT name FROM farms WHERE size > 150 AND location = 'Texas';", "SELECT SUM(revenue) FROM sales_view WHERE farmer_id = 1 AND crop_name = 'Potatoes';", "SELECT ArtistName, SUM(SalesAmount) as TotalRevenue FROM MusicSales WHERE Genre = 'Digital' GROUP BY ArtistName ORDER BY TotalRevenue DESC LIMIT 3;", "SELECT AVG(Budget) FROM Accessible_Tech_Projects WHERE Location = 'Africa';", 'SELECT name, MAX(balance) FROM shariah_compliant_finance GROUP BY name;', 'SELECT WarehouseManagementTransactions.TransactionID, WarehouseManagementTransactions.TransactionStatus, WarehouseManagementTransactions.TransactionDate FROM WarehouseManagementTransactions WHERE WarehouseManagementTransactions.WarehouseID = 3;', "SELECT COUNT(*) FROM community_health_workers WHERE community IN ('First Nations', 'Inuit', 'Métis');"]
[True, True, True, False, True, False, True, True]
384 392


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT MAX(altitude) FROM RocketAltitudes WHERE rocket = 'Falcon 9';", 'SELECT country, SUM(cost) FROM satellites GROUP BY country;', 'SELECT Country, AVG(Consumption) as AvgConsumptionPerCapita FROM EnergyConsumptionPerCapita WHERE Year = 2021 GROUP BY Country;', "SELECT MAX(usage) FROM power_usage WHERE building_type = 'Industrial' AND location = 'California';", "SELECT matches FROM tennis_tournaments WHERE court = 'Grass';", "SELECT SUM(fans_attended) FROM matches WHERE team = 'Manchester United' AND year = 2020;", 'SELECT Genre, SUM(RunningTime) FROM GenreRunningTimes GROUP BY Genre;', "SELECT products.name, suppliers.name FROM products INNER JOIN suppliers ON products.supplier_id = suppliers.id WHERE suppliers.name!= 'Green Cotton Inc.';"]
[True, True, True, False, False, True, True, True]
392 400


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["SELECT SUM(mass) FROM gso_debris WHERE orbit = 'GSO' AND source = 'Defunct Satellite';", 'SELECT MAX(cost) FROM rover_missions;', 'SELECT SUM(cost) FROM space_missions;', 'SELECT s.name, AVG(l.progress) as avg_progress FROM lifelong_learning l JOIN schools s ON l.school_id = s.school_id GROUP BY s.name;', "SELECT COUNT(cards) FROM yellow_cards JOIN teams ON yellow_cards.team_id = teams.team_id WHERE teams.name = 'Bayern Munich' AND tournament = 'Champions League';", 'SELECT COUNT(*) FROM clients JOIN loans ON clients.client_id = loans.client_id WHERE clients.is_financially_capable = true;', "SELECT COUNT(*) FROM Warehouses WHERE city = 'City Y' AND capacity > 100000;", "SELECT outlet_name, COUNT(*) as num_articles FROM media_outlets JOIN fact_checks ON media_outlets.outlet_id = fact_checks.outlet_id WHERE fact_checks.is_true = TRUE AND fact_checks.fact_check_date >= '2021-01-01' GROUP BY outlet_name ORDER BY num_articles DESC LIMIT 5;"]
[True, False, True, False, False, True, True, T

In [None]:
# Number of evaluated examples and the fraction scored as correct.
len(all_evals), np.mean(all_evals)

(408, 0.6936274509803921)

In [None]:
# Attach the per-example evaluation results to the test split.
# `add_column` refuses to add a column whose name already exists, so drop any
# copies left over from a previous run first; this keeps the cell re-runnable.
result_columns = ('score', 'pred_sql')
stale_columns = [name for name in result_columns if name in test_ds.column_names]
if stale_columns:
    test_ds = test_ds.remove_columns(stale_columns)
test_ds = test_ds.add_column('score', np.array(all_evals))
test_ds = test_ds.add_column('pred_sql', pred_sqls)

In [None]:
# Materialise the scored test split as a pandas DataFrame for the analysis below.
# `Dataset.to_pandas()` is the dataset's own conversion, so pandas does not have
# to iterate over the dataset one row (dict) at a time.
check_df = test_ds.to_pandas()

In [None]:
# Mean score for each `sql_complexity` value.
check_df['score'].groupby(check_df['sql_complexity']).mean()

sql_complexity
aggregation    0.631148
basic SQL      0.756637
single join    0.583333
Name: score, dtype: float64