In [1]:
# System
import os
import sys
import json 
import pandas as pd
from dotenv import load_dotenv
from typing import Optional, Dict, List, Union
from IPython.display import display, clear_output
import time

# External
from datasets import DatasetDict, Dataset
import datasets

# Internal
from predict import SQLPredict

# Load API credentials and model identifiers from the repository-level .env file.
load_dotenv("../../.env")

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
# This constant used to read REPLICATE_LLAMA_7B_TUNED despite its 13B name.
# Prefer the correctly named variable, but fall back to the legacy name so
# existing .env files keep resolving to the same value.
REPLICATE_LLAMA_13B_TUNED = os.environ.get(
    "REPLICATE_LLAMA_13B_TUNED", os.environ.get("REPLICATE_LLAMA_7B_TUNED")
)
REPLICATE_LLAMA_13B_BASE = os.environ.get("REPLICATE_LLAMA_13B_BASE")
HUGGING_FACE_API_TOKEN = os.environ.get("HUGGING_FACE_API_TOKEN")
MISTRAL_7B_INSTRUCT = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"

# Name under which the Replicate base model is registered with SQLPredict below.
model_name = "llama_2_13b_base"


In [2]:
# Build the predictor around the Replicate-hosted Llama 2 13B base model
# (registered under `model_name`). Then attach the Hugging Face token and
# register the Mistral 7B Instruct Inference API endpoint under "mistral".
sqp = SQLPredict.from_replicate_model(
    model_id=REPLICATE_LLAMA_13B_BASE,
    model_name=model_name,
    replicate_api_key=REPLICATE_API_TOKEN,
    openai_api_key=OPENAI_API_KEY,
)

sqp.hf_key = HUGGING_FACE_API_TOKEN
sqp.add_model_endpoint('mistral', MISTRAL_7B_INSTRUCT)

In [5]:
# Load the saved test subset; judging by its name and the openai_* columns shown
# below, it already carries earlier model inferences.
rich_testing = Dataset.load_from_disk(
    "../../local_data/rich_testing_subset_llama_13b_1_0_0_inferences_two"
)

In [6]:
# Small 10-row sample (rows 0-9) for trying out the Mistral endpoint cheaply.
rich_testing_subset_0_10 = rich_testing.select(range(10))

In [8]:
# Send each sample row to the "mistral" endpoint, storing replies in "mistral_response".
# load_from_cache_file=False forces fresh requests instead of reusing cached map results.
mistral_request_kwargs = {"model_name": "mistral", "response_column_name": "mistral_response"}
rich_testing_subset_0_10 = rich_testing_subset_0_10.map(
    sqp.basic_text_generation_dataset_request,
    fn_kwargs=mistral_request_kwargs,
    load_from_cache_file=False,
)

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [9]:
# Inspect the raw Mistral response for a single example (row 5) of the sample.
sample_response = rich_testing_subset_0_10[5]['mistral_response']
display(sample_response)

[{'generated_text': 'Context details the databse: CREATE TABLE table_name_41 (song VARCHAR, draw INTEGER) # Question to answer: What song has draw number less than 2? # Answer as a SQL query: \nSELECT song FROM table_name_41 WHERE draw < 2'}]

In [10]:
# Page through every Mistral response in the sample one at a time, replacing
# the previous output so each response (and its row index) can be read alone.
# Iterate over the sample's actual length rather than a hard-coded 10 so the
# cell keeps working if the sample size above changes.
SECONDS_PER_RESPONSE = 5

for i in range(len(rich_testing_subset_0_10)):
    clear_output(wait=True)
    display(rich_testing_subset_0_10[i]['mistral_response'])
    display(i)
    time.sleep(SECONDS_PER_RESPONSE)

[{'generated_text': "Context details the databse: CREATE TABLE airport (id VARCHAR, airport_id VARCHAR, pilot VARCHAR); CREATE TABLE flight (id VARCHAR, airport_id VARCHAR, pilot VARCHAR) # Question to answer: How many airports haven't the pilot 'Thompson' driven an aircraft? # Answer as a SQL query: \nSELECT COUNT(*) FROM airport\nWHERE airport_id NOT IN (\n    SE"}]

9

In [6]:
# Split the first 300 rows into three contiguous 100-row shards; each shard is
# run through inference in its own cell below.
(
    rich_testing_subset_0_100,
    rich_testing_subset_100_200,
    rich_testing_subset_200_300,
) = (rich_testing.select(range(start, start + 100)) for start in (0, 100, 200))

In [69]:
# Run the model registered as `model_name` over rows 0-99; the request helper
# presumably writes each reply to the "llama_2_13b_base_inference" column.
rich_testing_subset_0_100 = rich_testing_subset_0_100.map(
    sqp.replicate_dataset_request,
    fn_kwargs=dict(model_name=model_name, column_name="llama_2_13b_base_inference"),
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [70]:
# Same inference pass as the previous shard, for rows 100-199.
rich_testing_subset_100_200 = rich_testing_subset_100_200.map(
    sqp.replicate_dataset_request,
    fn_kwargs=dict(model_name=model_name, column_name="llama_2_13b_base_inference"),
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [71]:
# Same inference pass as the previous shards, for rows 200-299.
rich_testing_subset_200_300 = rich_testing_subset_200_300.map(
    sqp.replicate_dataset_request,
    fn_kwargs=dict(model_name=model_name, column_name="llama_2_13b_base_inference"),
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [72]:
# Stitch the three inferred shards back into one dataset, in row order.
inferred_shards = [
    rich_testing_subset_0_100,
    rich_testing_subset_100_200,
    rich_testing_subset_200_300,
]
rich_testing_subset = datasets.concatenate_datasets(inferred_shards)

In [73]:
# Spot-check the first row of the combined dataset before saving it.
rich_testing_subset[0]

{'answer': 'SELECT COUNT(*) FROM COURSE',
 'context': 'CREATE TABLE COURSE (Id VARCHAR)',
 'question': 'How many courses are there in total?',
 'table_count': 1,
 'column_types': '{"COURSE": {"Id": "VARCHAR"}}',
 'duplicate_create_table': False,
 'filler_data': '{"COURSE": [{"Id": "Brittany Hernandez"}, {"Id": "Roy Ramirez"}, {"Id": "Russell Blair"}, {"Id": "Stephen Hoffman"}, {"Id": "Taylor Powers"}]}',
 'query_result': '[(5,)]',
 'valid_query': True,
 'openai_inference': {'choices': [{'finish_reason': 'stop',
    'index': 0,
    'message': {'content': 'SELECT COUNT(*) FROM COURSE;',
     'role': 'assistant'}}],
  'created': 1695830663,
  'id': 'chatcmpl-83RJPNGEgmb71scPDnqvwvRoxypos',
  'model': 'gpt-3.5-turbo-0613',
  'object': 'chat.completion',
  'usage': {'completion_tokens': 7, 'prompt_tokens': 48, 'total_tokens': 55}},
 'openai_result': '[(5,)]',
 'openai_valid': True,
 'tuning_format': '{"prompt": "[INST] <<SYS>>\\nContext contains the relevant SQL tables, respond with the SQL

In [75]:
# Persist the combined dataset under a new "_b" path, leaving the dataset
# that was loaded at the top of the notebook untouched on disk.
rich_testing_subset.save_to_disk(
    "../../local_data/rich_testing_subset_llama_13b_1_0_0_inferences_two_b"
)

Saving the dataset (0/1 shards):   0%|          | 0/300 [00:00<?, ? examples/s]

In [7]:
# First row of the dataset as loaded from disk, for comparison with the combined dataset.
rich_testing[0]

{'answer': 'SELECT COUNT(*) FROM COURSE',
 'context': 'CREATE TABLE COURSE (Id VARCHAR)',
 'question': 'How many courses are there in total?',
 'table_count': 1,
 'column_types': '{"COURSE": {"Id": "VARCHAR"}}',
 'duplicate_create_table': False,
 'filler_data': '{"COURSE": [{"Id": "Brittany Hernandez"}, {"Id": "Roy Ramirez"}, {"Id": "Russell Blair"}, {"Id": "Stephen Hoffman"}, {"Id": "Taylor Powers"}]}',
 'query_result': '[(5,)]',
 'valid_query': True,
 'openai_inference': {'choices': [{'finish_reason': 'stop',
    'index': 0,
    'message': {'content': 'SELECT COUNT(*) FROM COURSE;',
     'role': 'assistant'}}],
  'created': 1695830663,
  'id': 'chatcmpl-83RJPNGEgmb71scPDnqvwvRoxypos',
  'model': 'gpt-3.5-turbo-0613',
  'object': 'chat.completion',
  'usage': {'completion_tokens': 7, 'prompt_tokens': 48, 'total_tokens': 55}},
 'openai_result': '[(5,)]',
 'openai_valid': True,
 'tuning_format': '{"prompt": "[INST] <<SYS>>\\nContext contains the relevant SQL tables, respond with the SQL

In [62]:
# "tuning_format" holds a JSON string; pull out its Llama-2-chat style prompt
# ([INST] <<SYS>> ... [/INST], as displayed below) for the first row.
tuning_record = json.loads(rich_testing[0]["tuning_format"])
prompt = tuning_record['prompt']

In [21]:
# Show the extracted prompt text.
prompt

'[INST] <<SYS>>\nContext contains the relevant SQL tables, respond with the SQL query that answers the Question.\n<</SYS>>\n\nContext: CREATE TABLE COURSE (Id VARCHAR)\n\nQuestion: How many courses are there in total?[/INST]\n\n'

In [68]:
# Pick one example row; `i` is read by the prompt-building cell below.
i = 5
rich_testing[i]

{'answer': 'SELECT song FROM table_name_41 WHERE draw < 2',
 'context': 'CREATE TABLE table_name_41 (song VARCHAR, draw INTEGER)',
 'question': 'What song has draw number less than 2?',
 'table_count': 1,
 'column_types': '{"table_name_41": {"song": "VARCHAR", "draw": "INT"}}',
 'duplicate_create_table': False,
 'filler_data': '{"table_name_41": [{"song": "Alan Williams", "draw": 32}, {"song": "Debbie Byrd", "draw": 6}, {"song": "Mark Munoz MD", "draw": 73}, {"song": "Darryl Rubio", "draw": 80}, {"song": "John Jackson", "draw": 1}]}',
 'query_result': "[('John Jackson',)]",
 'valid_query': True,
 'openai_inference': {'choices': [{'finish_reason': 'stop',
    'index': 0,
    'message': {'content': 'SELECT song \nFROM table_name_41\nWHERE draw < 2;',
     'role': 'assistant'}}],
  'created': 1695830668,
  'id': 'chatcmpl-83RJUWbvMAN8JemQspA0zCyVOaakO',
  'model': 'gpt-3.5-turbo-0613',
  'object': 'chat.completion',
  'usage': {'completion_tokens': 15, 'prompt_tokens': 55, 'total_tokens':

In [69]:
# Build a plain-text "Context # Question # Answer" prompt for row `i`.
# Set INCLUDE_INSTRUCTION = True to prepend "Prompt: <prompt_script>"; this
# replaces the previously commented-out `"Prompt: " + prompt_script + " # " +` prefix.
INCLUDE_INSTRUCTION = False
prompt_script = "Context contains the relevant SQL tables, provide the SQL query that answers the Question as the Answer."

prompt_parts = [
    "Context: " + rich_testing[i]['context'],
    "Question: " + rich_testing[i]['question'],
    "Answer: ",
]
if INCLUDE_INSTRUCTION:
    prompt_parts.insert(0, "Prompt: " + prompt_script)
prompt = " # ".join(prompt_parts)
prompt

'Context: CREATE TABLE table_name_41 (song VARCHAR, draw INTEGER) # Question: What song has draw number less than 2? # Answer: '

In [60]:
import requests

# Query the Hugging Face Inference API for Mistral 7B Instruct directly, bypassing
# SQLPredict. MISTRAL_7B_INSTRUCT and HUGGING_FACE_API_TOKEN come from the config cell.
# The Inference API expects "Authorization: Bearer <token>"; the scheme is added
# only if the configured token does not already include it. With no token, no
# Authorization header is sent (as before, since requests drops None-valued headers).
headers = {}
if HUGGING_FACE_API_TOKEN:
    headers["Authorization"] = (
        HUGGING_FACE_API_TOKEN
        if HUGGING_FACE_API_TOKEN.startswith("Bearer ")
        else f"Bearer {HUGGING_FACE_API_TOKEN}"
    )


def query(payload, url=MISTRAL_7B_INSTRUCT, timeout=60):
    """POST ``payload`` as JSON to an Inference API endpoint and return the decoded body.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        url: Endpoint to call. Defaults to the Mistral 7B Instruct model; the
            previous version referenced an undefined ``API_URL`` here.
        timeout: Seconds to wait before giving up, so a stalled request
            cannot hang the kernel indefinitely.

    Returns:
        The parsed JSON response. Error responses are returned as-is (not
        raised) so their message can be inspected in the notebook.
    """
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    return response.json()


output = query({"inputs": prompt})

In [61]:
output

[{'generated_text': 'Prompt: Context contains the relevant SQL tables, provide the SQL query that answers the Question as the Answer. # Context: CREATE TABLE table_21991074_3 (pts_agst INTEGER) # Question: Name the most points agst # Answer: \nSELECT pts_agst, name\nFROM table_2199107'}]