In [7]:
import openai

import itertools
import pandas as pd
import random
import os
import csv

from openai import OpenAI

In [2]:
# Maps each target industry domain to a one-sentence description of the data it
# covers. The keys feed the cartesian product of generation parameters and the
# descriptions are attached to the sampled rows as `domain_description`.
domains = {
    'healthcare': 'Comprehensive data on patient demographics, medical histories, treatment protocols, and outcomes; drug efficacy studies and side effects.',
    'technology': 'Data on emerging technologies, market trends, innovation, and technological disruptions across various industries.',
    'hospitality': 'Information on travel and tourism trends, customer preferences, hotel management, and global hospitality analysis.',
    'environment': 'Data related to climate change, environmental conservation efforts, biodiversity, and sustainable development initiatives.',
    'entertainment': 'Records on media consumption patterns, content creation, audience engagement, and entertainment industry revenue streams.',
    'automotive': 'Detailed information on vehicle sales, automotive technologies, sustainable transportation, and global automotive market analysis.',
    'agriculture': 'Data on crop yields, farming practices, agricultural supply chain, and global food security initiatives.',
    'government': 'Comprehensive records on public policy, legislative processes, government spending, citizen demographics, and governance effectiveness.',
    'fashion': 'Information on fashion trends, consumer behavior, sustainable fashion practices, and global fashion industry analysis.',
    'sports': 'Detailed data on sports events, athlete performance, fan engagement, sports industry economics, and global sports market trends.'
}

# SQL proficiency levels a generated example can target.
# NOTE: this is a set literal, so iteration order is not guaranteed to follow
# the order written here.
user_proficiencies = {
    "beginner",
    "intermediate",
    "advanced",
    "expert"
}

# SQL feature categories the generated query should exercise; each string is
# interpolated verbatim into the generation prompt.
# NOTE: set literal, so iteration order is not guaranteed.
sql_complexity = {
    "basic SQL with a simple select statement",
    "only one join (specify inner, outer, cross)",
    "two or more joins (specify inner, outer, cross)",
    "subqueries, including correlated and nested subqueries",
    "common table expressions",
    "aggregation functions (COUNT, SUM, AVG, MIN, MAX, etc.), and HAVING clause",
    "set operations such as UNION, INTERSECT, and EXCEPT",
    "window functions (e.g., ROW_NUMBER, LEAD, LAG, RANK, NTILE, PERCENT_RANK, etc.)",
    "pivoting and unpivoting"
}

# High-level task types ("<name>: <description>") the natural-language prompt
# should be about; each string is interpolated verbatim into the generation prompt.
# NOTE: set literal, so iteration order is not guaranteed.
sql_task = {
    "data definition: creating, altering, or dropping tables and other database objects",
    "data retrieval: basic data fetching queries",
    "data manipulation: inserting, updating, or deleting records",
    "analytics and reporting: generating reports, dashboards, and analytical insights",
    "transactional processing: SQL transaction control statements (e.g., BEGIN, COMMIT, ROLLBACK)"
}


In [3]:
# Build the full cartesian product of generation parameters.
# The set-typed inputs are sorted first: set iteration order for strings is not
# stable between interpreter runs, which would defeat the fixed seed below.
combinations = list(itertools.product(
    domains.keys(),
    sorted(user_proficiencies),
    sorted(sql_complexity),
    sorted(sql_task),
))

# Shuffle with a dedicated seeded RNG so the sample is reproducible on a fresh
# kernel and the global `random` state used elsewhere is left untouched.
SAMPLE_SEED = 42
random.Random(SAMPLE_SEED).shuffle(combinations)

# Keep the first N_SAMPLES shuffled combinations (a random sample without
# replacement). Slicing simply yields fewer rows if the product is smaller.
N_SAMPLES = 1000
sampled_combinations = combinations[:N_SAMPLES]

# One row per sampled combination.
df = pd.DataFrame(sampled_combinations, columns=['domain', 'user_proficiency', 'sql_complexity', 'sql_task'])

# Attach the human-readable description of each row's domain.
df['domain_description'] = df['domain'].map(domains)


In [4]:
# Preview the first five sampled combinations.
df.head()

Unnamed: 0,domain,user_proficiency,sql_complexity,sql_task,domain_description
0,government,advanced,common table expressions,"data manipulation: inserting, updating, or del...","Comprehensive records on public policy, legisl..."
1,automotive,intermediate,"aggregation functions (COUNT, SUM, AVG, MIN, M...",transactional processing: SQL transaction cont...,"Detailed information on vehicle sales, automot..."
2,entertainment,beginner,basic SQL with a simple select statement,"data manipulation: inserting, updating, or del...","Records on media consumption patterns, content..."
3,government,expert,"subqueries, including correlated and nested su...","analytics and reporting: generating reports, d...","Comprehensive records on public policy, legisl..."
4,fashion,expert,"set operations such as UNION, INTERSECT, and E...",data retrieval: basic data fetching queries,"Information on fashion trends, consumer behavi..."


In [5]:
def get_prompt(domain: str, data_description: str, user_level: str, sql_complexity: str, sql_task: str, verbose=False) -> str:
    """Build the LLM instruction for generating one synthetic text-to-SQL CSV row.

    Args:
        domain: Industry domain the example should belong to.
        data_description: Description of the kind of data found in that domain.
        user_level: SQL proficiency level the query should be written at.
        sql_complexity: SQL feature category the query should exercise.
        sql_task: Task type the natural-language prompt should be about.
        verbose: Accepted for backward compatibility; currently unused.

    Returns:
        The prompt text. It asks for a single headerless CSV row with four
        double-quoted columns: prompt, context, sql, explanation.
    """
    # The example row must be separated by `","` with no spaces. A space before
    # an opening quote makes Python's csv module treat that field as unquoted
    # and split on the commas inside the SQL, so a model that imitated a
    # `", "` separator would produce rows with the wrong column count.
    prompt = f"""
    ###### GUIDELINES
    Create 1 comma delimited row with following specifications. Exclude headers and just give me csv. Do not include ```csv ``` in response, just want the raw text in csv format.
    Please wrap each value in double quotes.

    example row:
    "Create a table that stores the average fuel efficiency for different vehicle models across various regions, and then display the top 5 most fuel-efficient models.","CREATE TABLE FuelEfficiency (ModelID INT PRIMARY KEY, ModelName VARCHAR(50), Region VARCHAR(50), AvgFuelEfficiency DECIMAL); INSERT INTO FuelEfficiency (ModelID, ModelName, Region, AvgFuelEfficiency) VALUES (1, 'Model S', 'North America', 120.5), (2, 'Model 3', 'Europe', 118.0), (3, 'Leaf', 'Asia', 110.0), (4, 'Bolt', 'North America', 115.7), (5, 'Prius', 'Europe', 95.0);","WITH RankedModels AS (SELECT ModelName, AvgFuelEfficiency, DENSE_RANK() OVER (ORDER BY AvgFuelEfficiency DESC) AS Rank FROM FuelEfficiency) SELECT ModelName, AvgFuelEfficiency FROM RankedModels WHERE Rank <= 5;","The query creates a common table expression (CTE) to rank vehicle models based on their average fuel efficiency in descending order. It then selects the top 5 most fuel-efficient models using the DENSE_RANK function."    
    
    Column 1: natural language prompt
    * column name: prompt
    * a well-formulated question or command in everyday English, representing a user query to a database
    * prompt should require using {user_level} level of SQL and ideally {sql_complexity}
    * prompt should be in the {domain} domain and pertain to {data_description}
    * prompt should pertain to {sql_task} first and foremost

    Column 2: tables and views that already exist in the database
    * column name: context
    * this should be database context, NOT the SQL query representing a user’s prompt
    * include complete executable SQL table CREATE statements and/or view CREATE statements with capitalized keywords
    * only table CREATE and CREATE + INSERT statements are allowed as database context. No ALTER/DROP/UPDATE
    * provide up to five tables/views that are relevant to the user’s natural language prompt
    * provide only the SQL code for create statements, separated by a semicolon
    * make sure there is no text preceding or following SQL code
    * do not use newline characters or breakline tags in SQL code
    * do not use ellipsis anywhere: this should be a valid, executable SQL statement
    * do not use phrases like “same as previous example”
    * table names and schemas should correspond to the {domain} domain and focus on {data_description}
    * make sure there is no text of any kind preceding or following SQL code

    Column 3: SQL query
    * column name: sql
    * A complete and executable SQL query used to answer/execute the natural language prompt
    * do not use ellipsis anywhere: this should be a valid, executable SQL statement
    * SQL should be based on the database context generated above
    * To the extent possible, SQL should leverage {sql_complexity}
    * SQL should be written at an {user_level} SQL proficiency level
    * do not use newline characters or breakline tags in SQL code
    * make sure there is no text of any kind preceding or following SQL code

    Column 4: explanation
    * column name: explanation
    * a step-by-step explanation of what the SQL query is doing
    """
    return prompt


In [8]:
# Read the API key from the environment rather than embedding it in the
# notebook, so the secret never lands in version control or shared outputs.
# Export OPENAI_API_KEY before starting the kernel.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)


In [None]:
# def validate_csv_response(response_text: str, num_columns: int=4) -> list:
#     rows = response_text.strip().split('\n')
#     valid_rows = []
#     for row in rows:
#         try:
#             parsed_row = [r for r in csv.reader([row])]
#             if len(parsed_row) == 1 and len(parsed_row[0]) == num_columns:
#                 valid_rows.append(parsed_row[0])
#         except Exception as e:
#             continue
#     return valid_rows

# # Helper function to write rows to CSV
# def write_rows_to_csv(rows: list, file_path: str):
#     with open(file_path, mode='a', newline='') as file:
#         writer = csv.writer(file)
#         for row in rows:
#             writer.writerow(row)

# output_csv = 'gpt-data_1k.csv'
# headers = ["domain", "domain_description", "user_proficiency", "sql_complexity", "sql_task", "sql_prompt", "sql_context", "sql", "sql_explanation"]
# if not os.path.exists(output_csv):
#     with open(output_csv, mode='w', newline='') as file:
#         writer = csv.writer(file)
#         writer.writerow(headers)

# index = 0
# rows_written = 0
# max_rows = 1000
# num_columns = len(headers)

# df_shuffled = df.sample(frac=1).reset_index(drop=True)
# while rows_written < max_rows and index < len(df_shuffled):
#     row = df_shuffled.iloc[index]
#     domain = row['domain']
#     data_description = row['domain_description']
#     user_level = row['user_proficiency']
#     sql_complexity = row['sql_complexity']
#     sql_task = row['sql_task']
    
#     prompt = get_prompt(domain, data_description, user_level, sql_complexity, sql_task)
    
#     response = client.chat.completions.create(
#         model="gpt-4o",
#         messages=[
#             {"role": "system", "content": "You are a data and SQL expert specializing in generating synthetic data."},
#             {"role": "user", "content": prompt}
#         ],
#         temperature=0.8,
#     )

    
#     response_text = response.choices[0].message.content.strip()
    
#     valid_rows = validate_csv_response(response_text)
    
#     if valid_rows:
#         rows_to_write = min(len(valid_rows), max_rows - rows_written)
#         for valid_row in valid_rows[:rows_to_write]:
#             complete_row = [domain, data_description, user_level, sql_complexity, sql_task] + valid_row 
#             write_rows_to_csv([complete_row], output_csv)
#             rows_written += 1

#     print(f"Rows Written: {rows_written}")
#     print(f"Index: {index}")
#     print("="*20)
#     index += 1

In [None]:
import os
import csv
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def validate_csv_response(response_text: str, num_columns: int=4) -> list:
    """Extract well-formed CSV rows from a raw model response.

    Each line of ``response_text`` is parsed on its own. A line is kept only if
    it yields exactly one CSV record with exactly ``num_columns`` fields;
    everything else (blank lines, commentary, malformed rows) is dropped.

    Args:
        response_text: Raw text returned by the model.
        num_columns: Required number of fields per row.

    Returns:
        A list of rows, each a list of ``num_columns`` strings.
    """
    valid_rows = []
    for line in response_text.strip().split('\n'):
        try:
            # skipinitialspace=True accepts `"a", "b"` as well as `"a","b"`.
            # Without it a space before an opening quote makes the field
            # unquoted, so commas inside the SQL split it into extra columns
            # and an otherwise valid row is rejected.
            records = list(csv.reader([line], skipinitialspace=True))
        except csv.Error:
            # Unparseable line: skip it, keep scanning the rest.
            continue
        if len(records) == 1 and len(records[0]) == num_columns:
            valid_rows.append(records[0])
    return valid_rows

# Helper function to write rows to CSV
def write_rows_to_csv(rows: list, file_path: str):
    """Append ``rows`` to the CSV file at ``file_path``.

    Args:
        rows: Iterable of rows, each an iterable of field values.
        file_path: Destination CSV path; opened in append mode, so existing
            content is preserved and the file is created if missing.
    """
    # Explicit UTF-8: model output often contains non-ASCII characters (curly
    # quotes, accents) that the platform default encoding may not be able to
    # write.
    # NOTE(review): this helper is called from several worker threads; it relies
    # on each call's append being written in one piece — confirm for large rows.
    with open(file_path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerows(rows)

# Function to handle each row
def process_row(row, client, output_csv, max_rows, max_retries=2):
    """Generate, validate and persist synthetic rows for one sampled combination.

    Args:
        row: Mapping with the keys 'domain', 'domain_description',
            'user_proficiency', 'sql_complexity' and 'sql_task'.
        client: OpenAI client used for the chat completion request.
        output_csv: Path of the CSV file that valid rows are appended to.
        max_rows: Upper bound on the number of rows written by this call.
        max_retries: How many times a rate-limited (HTTP 429) request is retried.

    Returns:
        The number of rows appended to ``output_csv`` (0 when the response
        contained no valid row).

    Raises:
        RuntimeError: If the request fails with a non-429 error, or is still
            rate-limited after ``max_retries`` retries. Raising (instead of
            returning an error string) keeps the return type an int, so callers
            that accumulate the count see the real error rather than a
            ``TypeError`` from ``int += str``.
    """
    # The request spec does not change between retries, so build it once.
    domain = row['domain']
    data_description = row['domain_description']
    user_level = row['user_proficiency']
    sql_complexity = row['sql_complexity']
    sql_task = row['sql_task']
    prompt = get_prompt(domain, data_description, user_level, sql_complexity, sql_task)
    row_prefix = [domain, data_description, user_level, sql_complexity, sql_task]

    retries = 0
    while True:
        try:
            # Random initial sleep to space out requests
            time.sleep(random.uniform(3, 10))

            response = client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are a data and SQL expert specializing in generating synthetic data."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.8,
            )

            # `content` can be None (e.g. an empty completion); treat as no text.
            response_text = (response.choices[0].message.content or "").strip()

            valid_rows = validate_csv_response(response_text)
            if not valid_rows:
                return 0

            complete_rows = [row_prefix + valid_row for valid_row in valid_rows[:max_rows]]
            write_rows_to_csv(complete_rows, output_csv)
            return len(complete_rows)

        except Exception as e:
            # Prefer a structured status code when the exception carries one;
            # fall back to the message text otherwise.
            rate_limited = getattr(e, "status_code", None) == 429 or '429' in str(e)
            if rate_limited and retries < max_retries:
                retries += 1
                wait_time = random.uniform(3, 10)
                print(f"Retrying {retries}/{max_retries} after {wait_time:.2f} seconds due to 429 error...")
                time.sleep(wait_time)
            else:
                raise RuntimeError(f"error generating: {e}") from e

output_csv = 'gpt-data_1k.csv'
headers = ["domain", "domain_description", "user_proficiency", "sql_complexity", "sql_task", "sql_prompt", "sql_context", "sql", "sql_explanation"]

# Write the header only when starting a fresh file, so re-running the cell
# appends to an existing dataset instead of truncating it.
if not os.path.exists(output_csv):
    with open(output_csv, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(headers)

rows_written = 0
max_rows = 1000
num_columns = len(headers)

df_shuffled = df.sample(frac=1).reset_index(drop=True)

with ThreadPoolExecutor(max_workers=4) as executor:
    # Every task is queued before any result is collected, so `rows_written` is
    # still 0 here and cannot limit submission. The cap is enforced in the
    # result loop below by cancelling tasks that have not started yet.
    futures = []
    for index, row in df_shuffled.iterrows():
        futures.append(executor.submit(process_row, row, client, output_csv, max_rows))

    cap_reached = False
    for future in as_completed(futures):
        if future.cancelled():
            # Cancelled after the cap was reached; nothing was generated.
            continue

        try:
            result = future.result()
        except Exception as e:
            print(f"Error processing row: {e}")
        else:
            # Only integer counts are accumulated; anything else is reported
            # as an error message rather than added to the counter.
            if isinstance(result, int):
                rows_written += result
            else:
                print(f"Error processing row: {result}")

        print(f"Rows Written: {rows_written}")
        print("="*20)

        if not cap_reached and rows_written >= max_rows:
            # Stop spending API calls once enough rows exist. Tasks already
            # running (at most max_workers) still finish and are counted.
            cap_reached = True
            for pending in futures:
                pending.cancel()

print(f"Total Rows Written: {rows_written}")


In [58]:
# Load Model

In [1]:
import csv
import yaml
from typing import Tuple
import re
from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from peft import AutoPeftModelForCausalLM
import torch

from llmtune.pydantic_models.config_model import Config

In [2]:
# Identifier of the experiment run whose config is loaded here.
EXPERIMENT_HASH = "CmWSJ"
config_path = f"./experiment/{EXPERIMENT_HASH}/config/config.yml"

# Parse the YAML file and expand the resulting mapping into the Config model.
with open(config_path, "r") as f:
    config = Config(**yaml.safe_load(f))

In [3]:
# Directory holding the fine-tuned weights (and tokenizer files) for this experiment.
weights_path = f"./experiment/{EXPERIMENT_HASH}/weights/"
# Load the PEFT checkpoint in bfloat16, letting device_map="auto" decide
# placement and requesting the flash_attention_2 attention implementation.
model = AutoPeftModelForCausalLM.from_pretrained(
    weights_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
# Per the PEFT docs, merge_and_unload() folds the adapter weights into the base
# model and returns it; from here on `model` refers to that merged model.
model = model.merge_and_unload()
# The tokenizer is loaded from the same weights directory.
# NOTE(review): device_map is passed here as well — confirm it has any effect
# for a tokenizer.
tok = AutoTokenizer.from_pretrained(weights_path, device_map="auto")



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [4]:
def infer(prompt: str, model, tok: AutoTokenizer) -> Tuple[str, float]:
    """Sample a completion for ``prompt`` and score it.

    Args:
        prompt: Full text prompt fed to the model.
        model: Causal LM exposing ``generate``; its ``device`` attribute (when
            present) decides where the input ids are placed.
        tok: Tokenizer matching ``model``.

    Returns:
        ``(seq, seq_score)``: the decoded newly generated text (prompt tokens
        excluded, special tokens skipped) and the product of the per-token
        probabilities of the generated tokens as a Python float.
    """
    # Follow the model's device instead of assuming CUDA; fall back to the
    # previous behaviour for objects that do not expose `.device`.
    device = getattr(model, "device", "cuda")
    input_ids = tok.encode(prompt, return_tensors="pt", truncation=True).to(device)
    gen_config = GenerationConfig(
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.2,
        pad_token_id=tok.pad_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        )

    gen_results = model.generate(input_ids, generation_config=gen_config)
    # Drop the echoed prompt: keep only the tokens produced after the input.
    gen_tok_ids = gen_results.sequences.squeeze(0)[input_ids.shape[1]:]

    seq = tok.decode(gen_tok_ids, skip_special_tokens=True)

    # One score tensor per generated step; presumably (steps, 1, vocab) after
    # stacking for this single-prompt call, hence the squeeze on dim 1.
    step_scores = torch.stack(gen_results.scores).squeeze(1)
    # log_softmax is the numerically stable form of log(softmax(x)); the naive
    # form can underflow to log(0) = -inf on sharply peaked distributions.
    log_probs = torch.log_softmax(step_scores, dim=1)
    # Pick the log-prob of the token actually emitted at each step. squeeze(1)
    # removes only the gathered column, so a one-token generation stays 1-D.
    token_log_probs = log_probs.gather(dim=1, index=gen_tok_ids.view(-1, 1)).squeeze(1)
    seq_score = torch.exp(token_log_probs.sum()).item()  # sum and exp to get prob

    return seq, seq_score

In [5]:
# Load the evaluation prompts; the first CSV line is the column header.
with open("prompts.csv", mode='r', newline='') as file:
    reader = csv.reader(file)
    headers = next(reader)
    data = [tuple(record) for record in reader]

# Run the model over every prompt. Column 0 of each record is appended to the
# prediction after the "----- bird -----" marker; column 1 is the prompt text.
processed_preds = {}
total = len(data)
for idx, row in enumerate(data):
    tag, prompt_text = row[0], row[1]
    prediction, _ = infer(prompt_text, model, tok)
    processed_preds[str(idx)] = f"{prediction.strip()}\t----- bird -----\t{tag}"
    print(f"Done {idx+1}/{total}")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Done 1/500
Done 2/500
Done 3/500
Done 4/500
Done 5/500
Done 6/500
Done 7/500
Done 8/500
Done 9/500
Done 10/500
Done 11/500
Done 12/500
Done 13/500
Done 14/500
Done 15/500
Done 16/500
Done 17/500
Done 18/500
Done 19/500
Done 20/500
Done 21/500
Done 22/500
Done 23/500
Done 24/500
Done 25/500
Done 26/500
Done 27/500
Done 28/500
Done 29/500
Done 30/500
Done 31/500
Done 32/500
Done 33/500
Done 34/500
Done 35/500
Done 36/500
Done 37/500
Done 38/500
Done 39/500
Done 40/500
Done 41/500
Done 42/500
Done 43/500
Done 44/500
Done 45/500
Done 46/500
Done 47/500
Done 48/500
Done 49/500
Done 50/500
Done 51/500
Done 52/500
Done 53/500
Done 54/500
Done 55/500
Done 56/500
Done 57/500
Done 58/500
Done 59/500
Done 60/500
Done 61/500
Done 62/500
Done 63/500
Done 64/500
Done 65/500
Done 66/500
Done 67/500
Done 68/500
Done 69/500
Done 70/500
Done 71/500
Done 72/500
Done 73/500
Done 74/500
Done 75/500
Done 76/500
Done 77/500
Done 78/500
Done 79/500
Done 80/500
Done 81/500
Done 82/500
Done 83/500
Done 84/500
D

In [6]:
import json
# Persist the predictions as pretty-printed JSON (4-space indent).
with open("predict_mini_dev_mistral-gpt_SQLite.json", 'w') as json_file:
    serialized = json.dumps(processed_preds, indent=4)
    json_file.write(serialized)

In [7]:
# Display the full predictions dict (one entry per prompt, keyed by the stringified row index).
processed_preds

{'0': "SELECT COUNT(*) FROM customers WHERE Currency = 'EUR' / COUNT(*) FROM customers WHERE Currency = 'CZK';\t----- bird -----\tdebit_card_specializing",
 '1': "WITH RankedCustomers AS (SELECT CustomerID, Consumption, DENSE_RANK() OVER (ORDER BY Consumption) AS Rank FROM yearmonth WHERE Date = '2012-01-01' AND Segment = 'LAM') SELECT CustomerID FROM RankedCustomers WHERE Rank = 1;\t----- bird -----\tdebit_card_specializing",
 '2': "SELECT AVG(Consumption) AS AvgConsumption FROM yearmonth WHERE CustomerID IN (SELECT CustomerID FROM customers WHERE Segment = 'SME') AND YEAR(Date) = 2013 AND MONTH(Date) BETWEEN 1 AND 12;\t----- bird -----\tdebit_card_specializing",
 '3': "SELECT SUM(CASE WHEN customers.Currency = 'CZK' THEN transactions_1k.Amount ELSE 0 END) - SUM(CASE WHEN customers.Currency = 'EUR' THEN transactions_1k.Amount ELSE 0 END) AS Difference FROM customers INNER JOIN transactions_1k ON customers.CustomerID = transactions_1k.CustomerID WHERE YEAR(transactions_1k.Date) = 2012;