In [65]:
# System
import os
import sys
import json 
from dotenv import load_dotenv
from typing import Optional, Dict, List, Union

# External
from datasets import DatasetDict, Dataset
import datasets

# Internal
from predict import SQLPredict

# Load credentials and Replicate model ids from the repository-level .env file.
# NOTE: the path is relative to the kernel's working directory, so this assumes
# the notebook is started two directories below the repo root.
load_dotenv("../../.env")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
# This previously read only REPLICATE_LLAMA_7B_TUNED, which does not match the
# 13B name it is bound to. Prefer the matching 13B key; fall back to the old
# 7B key so existing .env files keep working. TODO confirm the intended model.
REPLICATE_LLAMA_13B_TUNED = os.getenv("REPLICATE_LLAMA_13B_TUNED") or os.getenv(
    "REPLICATE_LLAMA_7B_TUNED"
)
REPLICATE_LLAMA_13B_BASE = os.getenv("REPLICATE_LLAMA_13B_BASE")

# Alias under which the Replicate model is registered on SQLPredict and
# requested in `replicate_dataset_request` below.
model_name = "llama_2_13b_base"


In [55]:
# Build the SQL predictor backed by the base Llama-2 13B Replicate model,
# registered under `model_name` (the output of the next cell shows the mapping).
sqp = SQLPredict.from_replicate_model(
    model_name=model_name,
    model_id=REPLICATE_LLAMA_13B_BASE,
    replicate_api_key=REPLICATE_API_TOKEN,
    openai_api_key=OPENAI_API_KEY,
)

In [56]:
# Sanity check: the predictor must know the model registered under `model_name`
# before any dataset rows are sent for inference.
assert model_name in sqp.replicate_models, f"{model_name!r} is not registered on sqp"
sqp.replicate_models

{'llama_2_13b_base': 'meta/llama-2-13b:078d7a002387bd96d93b0302a4c03b3f15824b63104034bfa943c63a8f208c38'}

In [57]:
# Load the evaluation dataset written by an earlier run (its rows already carry
# OpenAI inference columns, as the row displayed further below shows). The
# augmented result is saved under a separate "_b" path at the end.
rich_testing = Dataset.load_from_disk("../../local_data/rich_testing_subset_llama_13b_1_0_0_inferences_two")

In [68]:
# Carve the first 300 rows into three fixed 100-row chunks. Each chunk gets its
# own inference cell below, presumably so a failed Replicate run only needs to
# be repeated for that chunk.
(
    rich_testing_subset_0_100,
    rich_testing_subset_100_200,
    rich_testing_subset_200_300,
) = [rich_testing.select(range(start, start + 100)) for start in (0, 100, 200)]

In [69]:
# Run Replicate inference on rows 0-99. The output column is derived from
# `model_name` (currently "llama_2_13b_base_inference") so the label stays
# correct if the config cell switches to a different model.
# NOTE: this rebinds the chunk to its mapped version; re-running the cell maps
# the already-mapped chunk again rather than the raw slice.
rich_testing_subset_0_100 = rich_testing_subset_0_100.map(
    sqp.replicate_dataset_request,
    fn_kwargs={"model_name": model_name, "column_name": f"{model_name}_inference"},
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [70]:
# Run Replicate inference on rows 100-199. The output column is derived from
# `model_name` (currently "llama_2_13b_base_inference") so the label stays
# correct if the config cell switches to a different model.
# NOTE: this rebinds the chunk to its mapped version; re-running the cell maps
# the already-mapped chunk again rather than the raw slice.
rich_testing_subset_100_200 = rich_testing_subset_100_200.map(
    sqp.replicate_dataset_request,
    fn_kwargs={"model_name": model_name, "column_name": f"{model_name}_inference"},
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [71]:
# Run Replicate inference on rows 200-299. The output column is derived from
# `model_name` (currently "llama_2_13b_base_inference") so the label stays
# correct if the config cell switches to a different model.
# NOTE: this rebinds the chunk to its mapped version; re-running the cell maps
# the already-mapped chunk again rather than the raw slice.
rich_testing_subset_200_300 = rich_testing_subset_200_300.map(
    sqp.replicate_dataset_request,
    fn_kwargs={"model_name": model_name, "column_name": f"{model_name}_inference"},
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [72]:
# Reassemble the inferred chunks in their original row order
# (0-99, 100-199, 200-299) into a single dataset.
rich_testing_subset = datasets.concatenate_datasets(
    [
        rich_testing_subset_0_100,
        rich_testing_subset_100_200,
        rich_testing_subset_200_300,
    ]
)

In [73]:
# Spot-check the first row of the merged dataset (displayed as a column -> value
# dict). The inference column requested above should presumably appear among
# its keys; confirm before saving.
rich_testing_subset[0]

{'answer': 'SELECT COUNT(*) FROM COURSE',
 'context': 'CREATE TABLE COURSE (Id VARCHAR)',
 'question': 'How many courses are there in total?',
 'table_count': 1,
 'column_types': '{"COURSE": {"Id": "VARCHAR"}}',
 'duplicate_create_table': False,
 'filler_data': '{"COURSE": [{"Id": "Brittany Hernandez"}, {"Id": "Roy Ramirez"}, {"Id": "Russell Blair"}, {"Id": "Stephen Hoffman"}, {"Id": "Taylor Powers"}]}',
 'query_result': '[(5,)]',
 'valid_query': True,
 'openai_inference': {'choices': [{'finish_reason': 'stop',
    'index': 0,
    'message': {'content': 'SELECT COUNT(*) FROM COURSE;',
     'role': 'assistant'}}],
  'created': 1695830663,
  'id': 'chatcmpl-83RJPNGEgmb71scPDnqvwvRoxypos',
  'model': 'gpt-3.5-turbo-0613',
  'object': 'chat.completion',
  'usage': {'completion_tokens': 7, 'prompt_tokens': 48, 'total_tokens': 55}},
 'openai_result': '[(5,)]',
 'openai_valid': True,
 'tuning_format': '{"prompt": "[INST] <<SYS>>\\nContext contains the relevant SQL tables, respond with the SQL

In [75]:
# Persist the merged dataset next to its source under a "_b" suffix, so the
# input dataset loaded at the top of the notebook is left untouched.
rich_testing_subset.save_to_disk("../../local_data/rich_testing_subset_llama_13b_1_0_0_inferences_two_b")

Saving the dataset (0/1 shards):   0%|          | 0/300 [00:00<?, ? examples/s]