In [10]:
print('hello, world!')

hello, world!


In [101]:
import os
import time 
import requests
from dotenv import load_dotenv

import datasets
from datasets import Dataset, DatasetDict

from typing import Optional, Dict, List, Union

import openai
from openai.openai_object import OpenAIObject

# Load environment variables
load_dotenv("../../.env")

# OpenAI API key
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

HUGGING_FACE_API_TOKEN = os.environ.get("HUGGING_FACE_API_TOKEN")
DEV_SUBSET = 'alagaesia/spider_dev_subset_preds'

MISTRAL_7B_INSTRUCT_ENDPOINT = 'https://i7zsxqzxio916oxa.us-east-1.aws.endpoints.huggingface.cloud'

DEFOG_SQLCODER_7B_ENDPOINT = 'https://hnpxkty4wwf01zqu.us-east-1.aws.endpoints.huggingface.cloud'
BIGCODE_STARCODER_ENDPOINT = 'https://api-inference.huggingface.co/models/bigcode/starcoder'

# META_LLAMA_70B_ENDPOINT = 'https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf'
# META_LLAMA_13B_ENDPOINT = 'e'
META_LLAMA_7B_ENDPOINT = 'https://onsskneggrajh291.us-east-1.aws.endpoints.huggingface.cloud'

WIZARD_CODER_15B_ENDPOINT = 'f'

In [82]:
data = datasets.load_dataset(DEV_SUBSET)

In [104]:
def basic_request(
    request_string: str,
    model_endpoint: str,
    api_key: str,
) -> str:
    
    headers = {"Authorization": api_key}

    try: 

        # start timing 
        start_time = time.time()

        # make request
        response = requests.post(
            model_endpoint, 
            headers=headers,
            json={"inputs": request_string}
        )

        # end timing
        end_time = time.time()
        time_elapsed = end_time - start_time

        return {'response': response.json(), 'time_elapsed': time_elapsed}

    except Exception as e:
        print(e)
        return {'response': None, 'time_elapsed': time_elapsed}
    
def dataset_basic_request(
    dataset: Dataset,
    response_column: str,
    request_string_column: str,
    model_endpoint: str,
    api_key: str,
): 
    
    request_string = dataset[request_string_column]

    response = basic_request(
        request_string=request_string,
        model_endpoint=model_endpoint,
        api_key=api_key,
    )

    try:
        generated = str(response['response'][0]['generated_text'])
        time_elapsed = response['time_elapsed']
    except:
        generated = 'INVESTIGATE'
        time_elapsed = response['time_elapsed']

    latency_column = response_column + '_latency'

    return {response_column: generated, latency_column: time_elapsed}

In [84]:
#data.save_to_disk('draft/original_data')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

In [100]:

def dep_request_creation(
    strings: list[str],
    create_statement: str,
    question: str,
    create_statement_index: int,
    question_index: int,
) -> str: 
    
    modified_strings = strings.copy()
    
    modified_strings.insert(create_statement_index, create_statement)
    modified_strings.insert(question_index, question)

    return ' '.join(modified_strings)

def dep_dataset_request_creation(
    dataset: Dataset,
    strings: list[str],
    request_column: str,
    create_statement_column: str,
    question_column: str,
    create_statement_index: int,
    question_index: int,
) -> str: 
    
    create_statement = dataset[create_statement_column]
    question = dataset[question_column]

    request_string = dep_request_creation(
        strings,
        create_statement,
        question,
        create_statement_index,
        question_index,
    )

    return {request_column: request_string}

request_column = 'zero_shot_request_complex'
create_statement_column = 'create_w_keys'
question_column = 'question'

strings = ['Given the following SQL schema: ', 'Provide a SQL query that answers the following question: ', 'Only respond with the SQL query.']

strings = [
    '### Complete sqlite SQL query only and with no explanation, and do not select extra columns that are not explicitly requested in the query.\n',
    '### Sqlite SQL tables, with their properties:\n',
    '#\n',
    '# ', # create statement goes here (index 4)
    '\n', 
    '#\n',
    '### ', # question goes here (index 8) - must account for prior insertion
    '\n', 
    'SQL QUERY:',
]

create_statement_index = 4
question_index = 8



In [101]:
data['train'] = data['train'].map(dep_dataset_request_creation, fn_kwargs = {'strings': strings, 'request_column': request_column, 'create_statement_column': create_statement_column, 'question_column': question_column, 'create_statement_index': create_statement_index, 'question_index': question_index}, load_from_cache_file=False)

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

In [103]:
#data.save_to_disk('draft/data_zero_shot_simple')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

In [104]:
data

DatasetDict({
    train: Dataset({
        features: ['db_id', 'query', 'question', 'create_w_keys', 'create_wo_keys', 'difficulty', 'zero_shot_request', 'defog_sqlcoder_7b', 'defog_sqlcoder_7b_latency', 'mistral_7b_instruct', 'mistral_7b_instruct_latency', 'llama_7b', 'llama_7b_latency', 'zero_shot_request_complex'],
        num_rows: 209
    })
})

# Mistral 7B

In [106]:
response_column = 'mistral_7b_instruct_complex'
request_string_column = 'zero_shot_request_complex'
model_endpoint = MISTRAL_7B_INSTRUCT_ENDPOINT
api_key = HUGGING_FACE_API_TOKEN

data = data.map(dataset_basic_request, fn_kwargs= {"response_column":  response_column, "request_string_column": request_string_column, "model_endpoint": model_endpoint, "api_key": api_key}, load_from_cache_file=False)

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

In [107]:
#data.save_to_disk('draft/mistral_7b_instruct_complex')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

In [109]:
data['train'][response_column]

['\n SELECT COUNT(*) FROM singer;\n\n #\n ###  How many concerts',
 '\n SELECT Name, Country, Age FROM singer ORDER BY Age DESC;',
 '\n SELECT AVG(Age) AS Average_Age, MIN(Age)',
 '\n SELECT Singer_Name, Song_release_year\n FROM singer\n JOIN singer',
 '\n SELECT DISTINCT Country FROM singer WHERE Age > 20;\n\n\n',
 '\n SELECT s.Song_Name\n FROM singer s\n JOIN singer_in_',
 '\n SELECT Stadium_Name, Stadium_Capacity\n FROM stadium\n JOIN concert\n ON stadium',
 '\n SELECT Stadium_Name, Stadium_Capacity FROM stadium WHERE Stadium_ID IN (SELECT Stadium',
 '\n SELECT YEAR(concert_Year) AS year, COUNT(*) AS count',
 '\n SELECT Name FROM stadium WHERE Stadium_ID NOT IN (SELECT Stadium_ID FROM concert);',
 '\n SELECT stadium.Name, stadium.Location\n FROM stadium\n JOIN concert ON stadium.',
 '\n SELECT COUNT(*) FROM Has_Pet \nINNER JOIN Pets ON',
 "\n SELECT Pets.weight FROM Pets\nWHERE Pets.PetType = '",
 '\n SELECT PetType, MAX(weight) AS MaxWeight\n FROM Pets\n GRO',
 '\n SELECT COUNT(*) 

# SQLCoder

In [110]:
# GPU · Nvidia A10G · 1x GPU · 24 GB
response_column = 'defog_sqlcoder_7b_complex'
request_string_column = 'zero_shot_request_complex'
model_endpoint = DEFOG_SQLCODER_7B_ENDPOINT
api_key = HUGGING_FACE_API_TOKEN

data = data.map(dataset_basic_request, fn_kwargs= {"response_column":  response_column, "request_string_column": request_string_column, "model_endpoint": model_endpoint, "api_key": api_key}, load_from_cache_file=False)

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

In [112]:
#data['train']['defog_sqlcoder_7b_complex']

[' SELECT COUNT(*) FROM singer; \n ###  How many concerts do we have',
 ' SELECT singer.name, singer.country, singer.age FROM singer ORDER BY singer.',
 ' SELECT AVG(singer.age), MIN(singer.age), MAX(s',
 " SELECT s.name, to_number(s.song_release_year, '",
 " SELECT s.name, to_number(s.age,'9999') AS",
 " SELECT s.name, to_number(s.song_release_year, '",
 ' SELECT s.name, s.capacity FROM stadium s JOIN concert c ON s.stad',
 ' SELECT s.name, s.capacity FROM stadium s JOIN concert c ON s.stad',
 ' SELECT year, COUNT(concert_id) AS concert_count FROM concert GRO',
 ' SELECT "stadium"."Name" FROM "stadium" LEFT JOIN "con',
 ' SELECT s.name, s.location FROM stadium s JOIN concert c ON s.stad',
 ' SELECT COUNT(*) FROM Has_Pet JOIN Pets ON Has_Pet.Pet',
 ' SELECT MIN(pets.weight) AS min_weight FROM student JOIN has_pet',
 '  SELECT p.pet_type, MAX(p.weight) AS max_weight FROM',
 ' SELECT COUNT(*) FROM Student WHERE age > 20;',
 ' SELECT COUNT(*) FROM Student s JOIN Has_Pet hp ON s.St',
 ' SE

In [111]:
#data.save_to_disk('draft/defog_sqlcoder_2_complex')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

# Llama 7B

In [105]:
response_column = 'llama_7b'
request_string_column = 'zero_shot_request'
model_endpoint = META_LLAMA_7B_ENDPOINT
api_key = HUGGING_FACE_API_TOKEN

data_gpt = data_gpt.map(dataset_basic_request, fn_kwargs= {"response_column":  response_column, "request_string_column": request_string_column, "model_endpoint": model_endpoint, "api_key": api_key}, load_from_cache_file=False)

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

In [107]:
data_gpt['llama_7b']

['',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '\n\nPlease provide the SQL query that answers the question.',
 '\n\nPlease provide the SQL query that answers the question.',
 ' Please provide the SQL query that answers this question. ',
 ' Please provide the SQL query that answers this question.',
 ' The query should only return the average rank of winners in all matches, without any additional information.',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 ' Please provide the SQL query that answers this question.',
 ' Please provide the SQL query that answers the question.',
 '',
 '',
 '

In [115]:
#data.save_to_disk('draft/llama_7b_complex')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

In [117]:
data['train']['llama_7b_complex']

[' \n SELECT COUNT(*) FROM singer; \n #  Output: \n #  5\n',
 ' \nSELECT name, country, age FROM singer WHERE age = (SELECT MIN(age)',
 '\n SELECT average(age), MIN(age), MAX(age) FROM singer WHERE country =',
 '\n SELECT Name, Song_release_year FROM singer WHERE Age < (SELECT MIN(Age',
 ' \n SELECT "Country" FROM "singer" WHERE "Age" > 20;',
 ' \n SELECT "Name" FROM "singer_in_concert" WHERE "Singer',
 ' \nSELECT name, capacity FROM stadium WHERE year > 2014; \n',
 '\n```\nSELECT name, capacity FROM stadium WHERE year > 2013;\n',
 ' \n SELECT \n   year(concerts.Year) as year, \n  ',
 ' \n SELECT "Name" FROM "stadium" WHERE "Stadium_ID" NOT',
 ' \nSELECT name, location FROM stadium WHERE exists (SELECT 1 FROM concert WHERE year =',
 ' \n```\nSELECT COUNT(*) FROM Has_Pet WHERE weight > 10;\n',
 ' \nSELECT weight FROM Pets WHERE Major = 1;\n\n### Output:\n',
 ' \nSELECT PetType, MAX(weight) \nFROM Pets \nGROUP BY Pet',
 ' \n```\nSELECT COUNT(*) FROM Has_Pet WHERE Student.Age > 20',
 " \n

In [120]:
data['train']['llama_7b_complex']

[' \n SELECT COUNT(*) FROM singer; \n #  Output: \n #  5\n',
 ' \nSELECT name, country, age FROM singer WHERE age = (SELECT MIN(age)',
 '\n SELECT average(age), MIN(age), MAX(age) FROM singer WHERE country =',
 '\n SELECT Name, Song_release_year FROM singer WHERE Age < (SELECT MIN(Age',
 ' \n SELECT "Country" FROM "singer" WHERE "Age" > 20;',
 ' \n SELECT "Name" FROM "singer_in_concert" WHERE "Singer',
 ' \nSELECT name, capacity FROM stadium WHERE year > 2014; \n',
 '\n```\nSELECT name, capacity FROM stadium WHERE year > 2013;\n',
 ' \n SELECT \n   year(concerts.Year) as year, \n  ',
 ' \n SELECT "Name" FROM "stadium" WHERE "Stadium_ID" NOT',
 ' \nSELECT name, location FROM stadium WHERE exists (SELECT 1 FROM concert WHERE year =',
 ' \n```\nSELECT COUNT(*) FROM Has_Pet WHERE weight > 10;\n',
 ' \nSELECT weight FROM Pets WHERE Major = 1;\n\n### Output:\n',
 ' \nSELECT PetType, MAX(weight) \nFROM Pets \nGROUP BY Pet',
 ' \n```\nSELECT COUNT(*) FROM Has_Pet WHERE Student.Age > 20',
 " \n

In [5]:
#data = DatasetDict.load_from_disk('draft/llama_7b_complex')

In [7]:
#data.save_to_disk('draft/data_7b_predictions')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

# OpenAI 👑

In [12]:
strings = [
    '### Complete sqlite SQL query only and with no explanation, and do not select extra columns that are not explicitly requested in the query.\n',
    '### Sqlite SQL tables, with their properties:\n',
    '#\n',
    '# ', # create statement goes here (index 4)
    '\n', 
    '#\n',
    '### ', # question goes here (index 8) - must account for prior insertion
    '\n', 
    'SQL QUERY:',
]

complex_prompt = strings[0] + strings[1] + strings[2] + strings[3]
complex_prompt

'### Complete sqlite SQL query only and with no explanation, and do not select extra columns that are not explicitly requested in the query.\n### Sqlite SQL tables, with their properties:\n#\n# '

In [13]:
"Provide the SQL query that answers the QUESTION, with no explanation or special characters. CONTEXT: Relevant SQL tables, with their properties: {create_statement} QUESTION: {question} "


'Provide the SQL query that answers the QUESTION, with no explanation or special characters. CONTEXT: Relevant SQL tables, with their properties: {create_statement} QUESTION: {question} '

In [14]:
openai.api_key = OPENAI_API_KEY

In [82]:
def openai_sql_request_structure(
    user_prompt: str,
    system_context: str,
) -> List[Dict[str, str]]:
    
    # if system_context is None:
    #     system_context = self.prompts._openai_sql_request_structure_prompt
    
    message = [
        {"role": "system", "content": system_context},
        {"role": "user", "content": f'{user_prompt}'},
        # {"role": "user", "content": f'Context: {user_context}\n\nQuestion": {user_question}'},
    ]

    return message

def openai_sql_request(
    openai, 
    user_prompt: str,
    system_context: str,
    model: Optional[str] = "gpt-3.5-turbo", # TODO: consider using an enum for this
) -> Optional[OpenAIObject]:

    message = openai_sql_request_structure(user_prompt, system_context)

    try: 
        request = openai.ChatCompletion.create(
            model=model, 
            messages=message,
        )
    except Exception as e:
        print(f"OpenAI request failed with error: {e}")
        #raise e    
        request = 'INVESTIGATE'

    # if validate_response:
    #     return openai_sql_response(request)

    return request

def dataset_openai_sql_request(
    dataset: Dataset,
    user_prompt_column: str,
    system_context: str,
    response_column: str,
    model: Optional[str] = "gpt-3.5-turbo", # TODO: consider using an enum for this
) -> Optional[OpenAIObject]:

    user_prompt = dataset[user_prompt_column]

    request = openai_sql_request(
        openai=openai,
        user_prompt=user_prompt,
        system_context=system_context,
        model=model,
    )

    return {response_column: request}

In [69]:
data_0_50 = data['train'].select(range(0, 50))
data_50_100 = data['train'].select(range(50, 100))
data_100_150 = data['train'].select(range(100, 150))
data_150_200 = data['train'].select(range(150, 200))
data_200_ = data['train'].select(range(200, len(data['train'])))

system_prompt = 'You are a SQL coder that only responds with SQL queries.'

time.sleep(60)

In [72]:
data_0_25 = data['train'].select(range(0, 25))
data_25_50 = data['train'].select(range(25, 50))
data_50_75 = data['train'].select(range(50, 75))
data_75_100 = data['train'].select(range(75, 100))
data_100_125 = data['train'].select(range(100, 125))
data_125_150 = data['train'].select(range(125, 150))
data_150_175 = data['train'].select(range(150, 175))
data_175_200 = data['train'].select(range(175, 200))
data_200_ = data['train'].select(range(200, len(data['train'])))

system_prompt = 'You are a SQL coder that only responds with SQL queries.'

time.sleep(60)

In [73]:
data_0_25 = data_0_25.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

data_25_50 = data_25_50.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [74]:
data_50_75 = data_50_75.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

data_75_100 = data_75_100.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [75]:
data_100_125 = data_100_125.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

data_125_150 = data_125_150.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [76]:
data_150_175 = data_150_175.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

data_175_200 = data_175_200.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [77]:
data_200_ = data_200_.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request_complex', "system_context": system_prompt, "response_column": 'gpt_3_5_complex'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/9 [00:00<?, ? examples/s]

In [78]:
data_gpt = datasets.concatenate_datasets([data_0_25, data_25_50, data_50_75, data_75_100, data_100_125, data_125_150, data_150_175, data_175_200, data_200_])

data_gpt.save_to_disk('draft/data_gpt_3_5_complex')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

In [80]:
data_gpt

Dataset({
    features: ['db_id', 'query', 'question', 'create_w_keys', 'create_wo_keys', 'difficulty', 'zero_shot_request', 'defog_sqlcoder_7b', 'defog_sqlcoder_7b_latency', 'mistral_7b_instruct', 'mistral_7b_instruct_latency', 'llama_7b', 'llama_7b_latency', 'zero_shot_request_complex', 'mistral_7b_instruct_complex', 'mistral_7b_instruct_complex_latency', 'defog_sqlcoder_7b_complex', 'defog_sqlcoder_7b_complex_latency', 'llama_7b_complex', 'llama_7b_complex_latency', 'gpt_3_5_complex'],
    num_rows: 209
})

In [83]:
data_0_25 = data_0_25.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

data_25_50 = data_25_50.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [84]:
data_50_75 = data_50_75.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

data_75_100 = data_75_100.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [86]:
time.sleep(60)
data_100_125 = data_100_125.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

data_125_150 = data_125_150.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [90]:
data_150_175 = data_150_175.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

data_175_200 = data_175_200.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)
time.sleep(60)

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

Map:   0%|          | 0/25 [00:00<?, ? examples/s]

In [91]:
data_200_ = data_200_.map(dataset_openai_sql_request, fn_kwargs= {"user_prompt_column":  'zero_shot_request', "system_context": system_prompt, "response_column": 'gpt_3_5'}, load_from_cache_file=False)

Map:   0%|          | 0/9 [00:00<?, ? examples/s]

In [92]:
data_gpt = datasets.concatenate_datasets([data_0_25, data_25_50, data_50_75, data_75_100, data_100_125, data_125_150, data_150_175, data_175_200, data_200_])

# data_gpt.save_to_disk('draft/data_gpt_3_5')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

In [135]:
data_gpt.save_to_disk('draft/all_predictions')

Saving the dataset (0/1 shards):   0%|          | 0/209 [00:00<?, ? examples/s]

In [126]:
def extract_openai_response(
    dataset, 
    response_column: str,
    new_column: str,
):
    try:
        response = data_gpt[response_column]
        query = response['choices'][0]['message']['content']
    except:
        query = 'INVESTIGATE'

    return {new_column: query}

In [127]:
data_gpt.map(extract_openai_response, fn_kwargs= {"response_column":  'gpt_3_5', "new_column": 'gpt_3_5_query'}, load_from_cache_file=False)['gpt_3_5_query']

Map:   0%|          | 0/209 [00:00<?, ? examples/s]

['INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVESTIGATE',
 'INVEST

In [128]:
data_gpt[1]['gpt_3_5_complex']['choices'][0]['message']['content']

'SELECT Name, Country, Age \nFROM singer \nORDER BY Age DESC;'

In [100]:
data_gpt[1]

{'db_id': 'concert_singer',
 'query': 'SELECT name ,  country ,  age FROM singer ORDER BY age DESC',
 'question': 'Show name, country, age for all singers ordered by age from the oldest to the youngest.',
 'create_w_keys': 'CREATE TABLE "stadium" ( "Stadium_ID" int, "Location" text, "Name" text, "Capacity" int, "Highest" int, "Lowest" int, "Average" int, PRIMARY KEY ("Stadium_ID") ); CREATE TABLE "singer" ( "Singer_ID" int, "Name" text, "Country" text, "Song_Name" text, "Song_release_year" text, "Age" int, "Is_male" bool, PRIMARY KEY ("Singer_ID") ); CREATE TABLE "concert" ( "concert_ID" int, "concert_Name" text, "Theme" text, "Stadium_ID" text, "Year" text, PRIMARY KEY ("concert_ID"), FOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID") ); CREATE TABLE "singer_in_concert" ( "concert_ID" int, "Singer_ID" text, PRIMARY KEY ("concert_ID","Singer_ID"), FOREIGN KEY ("concert_ID") REFERENCES "concert"("concert_ID"), FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID") ); ',
 

In [49]:
system_prompt = 'You are a SQL coder that only responds with SQL queries.'

resp = openai_sql_request(
    openai = openai, 
    user_prompt=data['train'][1]['zero_shot_request'], 
    system_context=system_prompt,
    )


In [51]:
resp

<OpenAIObject chat.completion id=chatcmpl-8DObxn7mUAVGk3sSFZ1w7OI41a8vw at 0x13fa97290> JSON: {
  "id": "chatcmpl-8DObxn7mUAVGk3sSFZ1w7OI41a8vw",
  "object": "chat.completion",
  "created": 1698203561,
  "model": "gpt-3.5-turbo-0613",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "SELECT Name, Country, Age\nFROM singer\nORDER BY Age DESC"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 305,
    "completion_tokens": 14,
    "total_tokens": 319
  }
}

In [50]:
resp.to_dict()['choices'][0].to_dict()['message'].to_dict()['content']

'SELECT Name, Country, Age\nFROM singer\nORDER BY Age DESC'

In [48]:
resp = resp.to_dict()
resp = resp['choices'][0].to_dict()

{'index': 0,
 'message': <OpenAIObject at 0x13fa970b0> JSON: {
   "role": "assistant",
   "content": "SELECT Name, Country, Age\nFROM singer\nORDER BY Age DESC;"
 },
 'finish_reason': 'stop'}

In [47]:
resp['choices'][0].to_dict()['message'].to_dict()['content']

KeyError: 'choices'

In [16]:
data['train'][0]

{'db_id': 'concert_singer',
 'query': 'SELECT count(*) FROM singer',
 'question': 'How many singers do we have?',
 'create_w_keys': 'CREATE TABLE "stadium" ( "Stadium_ID" int, "Location" text, "Name" text, "Capacity" int, "Highest" int, "Lowest" int, "Average" int, PRIMARY KEY ("Stadium_ID") ); CREATE TABLE "singer" ( "Singer_ID" int, "Name" text, "Country" text, "Song_Name" text, "Song_release_year" text, "Age" int, "Is_male" bool, PRIMARY KEY ("Singer_ID") ); CREATE TABLE "concert" ( "concert_ID" int, "concert_Name" text, "Theme" text, "Stadium_ID" text, "Year" text, PRIMARY KEY ("concert_ID"), FOREIGN KEY ("Stadium_ID") REFERENCES "stadium"("Stadium_ID") ); CREATE TABLE "singer_in_concert" ( "concert_ID" int, "Singer_ID" text, PRIMARY KEY ("concert_ID","Singer_ID"), FOREIGN KEY ("concert_ID") REFERENCES "concert"("concert_ID"), FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID") ); ',
 'create_wo_keys': 'CREATE TABLE stadium (Stadium_ID INT, Location TEXT, Name TEXT, Capacity