## Amazon Bedrock LLM Router Evaluation

### Intro and Goal
This Jupyter Notebook is designed to fine-tune an LLM (Large Language Model) routing system on a Text-to-SQL use case.

The goal is to take a prompt, determine its level of complexity, and then route it to either a small or a large LLM to generate the corresponding SQL query.

*WIP*

In [None]:
# 1. Create a conda environment

# !conda create -y --name bedrock-router-eval python=3.11.8
# !conda init && activate bedrock-router-eval
# !conda install -n bedrock-router-eval ipykernel --update-deps --force-reinstall -y
# !conda install -c conda-forge ipython-sql

In [None]:
# 2. Install dependencies

# !pip install -r requirements.txt

In [75]:
# 3. Import necessary libraries and load environment variables

import numpy as np
from scipy.spatial.distance import cdist
import json
from dotenv import load_dotenv, find_dotenv
import os
import boto3
import sqlite3
from pandas.io import sql
from botocore.config import Config
import pandas as pd
import io
import json
from io import StringIO
import sqlparse
import sqlite3
import time
import matplotlib.pyplot as plt
import re
import typing as t
from queue import Queue
from threading import Thread

# loading environment variables that are stored in local file
local_env_filename = 'bedrock-router-eval.env'
load_dotenv(find_dotenv(local_env_filename), override=True)


def _require_env(name):
    """Return the value of the environment variable *name*.

    load_dotenv(override=True) above has already copied the .env values into
    os.environ, so this only reads them back.

    Raises:
        KeyError: if the variable is not set, with a message naming the
            variable and the .env file to fix (writing None into os.environ
            would otherwise fail with an opaque TypeError).
    """
    value = os.getenv(name)
    if value is None:
        raise KeyError(
            f"Required environment variable {name!r} is not set; "
            f"define it in {local_env_filename!r} or in the process environment."
        )
    return value


REGION = _require_env('REGION')
HF_TOKEN = _require_env('HF_TOKEN')
SQL_DATABASE = _require_env('SQL_DATABASE')  # LOCAL or GLUE
SQL_DIALECT = _require_env('SQL_DIALECT')  # SQlite or awsathena
SQL_DATABASE_NAME = _require_env('SQL_DATABASE_NAME')
# os.environ['AWS_ACCESS_KEY'] = os.getenv('AWS_ACCESS_KEY')
# os.environ['AWS_SECRET_ACCESS_KEY'] = os.getenv('AWS_SECRET_ACCESS_KEY')

MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" # anthropic.claude-3-haiku-20240307-v1:0 "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# get ground truth data
file_path = 'question_query_good_results.jsonl'
groundtruth_df = pd.read_json(file_path, lines=True)

In [76]:
# # 4. Define Helper functions

# def balance_dataset(
#     dataset_df: pd.DataFrame, key: str, random_state: int = 42
# ) -> pd.DataFrame:
#     """
#     Balance the dataset by oversampling the minority class.
#     """
#     # Determine the minority class
#     min_count = dataset_df[key].value_counts().min()

#     # Create a balanced DataFrame
#     sampled_dfs = []
#     for label in dataset_df[key].unique():
#         sampled = dataset_df[dataset_df[key] == label].sample(
#             n=min_count, random_state=random_state
#         )
#         sampled_dfs.append(sampled)

#     balanced_df = pd.concat(sampled_dfs).sample(frac=1, random_state=random_state)
#     return balanced_df
    
# def visualize_distribution(df, key):
#     # Check if 'score' column exists in the DataFrame
#     if key not in df.columns:
#         raise ValueError(f"The DataFrame does not contain a '{key}' column.")
    
#     # Count the frequency of each score
#     score_counts = df[key].value_counts().sort_index()
    
#     # Create a bar chart
#     plt.figure(figsize=(10, 6))
#     plt.bar(score_counts.index, score_counts.values)
    
#     # Customize the chart
#     plt.title(f'Distribution of {key}')
#     plt.xlabel(f'{key}')
#     plt.ylabel('Frequency')
#     plt.xticks(range(int(score_counts.index.min()), int(score_counts.index.max()) + 1))
    
#     # Add value labels on top of each bar
#     for i, v in enumerate(score_counts.values):
#         plt.text(score_counts.index[i], v, str(v), ha='center', va='bottom')
    
#     # Display the chart
#     plt.tight_layout()
#     plt.show()

# def execution_accuracy(generated_sql, labeled_sql):
#     """
#     Calculate Execution Accuracy (EX)
    
#     Args:
#     generated_sql (str): The SQL query generated by the model
#     labeled_sql (str): The labeled (ground truth) SQL query
    
#     Returns:
#     float: 1.0 if the queries match, 0.0 otherwise
#     """
#     # Normalize and compare the SQL queries
#     gen_normalized = sqlparse.format(generated_sql, strip_comments=True, reindent=True)
#     lab_normalized = sqlparse.format(labeled_sql, strip_comments=True, reindent=True)
    
#     return 1.0 if gen_normalized == lab_normalized else 0.0

# def exact_set_match_accuracy(generated_sql, labeled_sql, db_connection):
#     """
#     Calculate Exact Set Match Accuracy (EM)
    
#     Args:
#     generated_sql (str): The SQL query generated by the model
#     labeled_sql (str): The labeled (ground truth) SQL query
#     db_connection: A database connection object
    
#     Returns:
#     float: 1.0 if the result sets match, 0.0 otherwise
#     """
#     try:
#         # Execute both queries
#         gen_result = pd.read_sql_query(generated_sql, db_connection)
#         lab_result = pd.read_sql_query(labeled_sql, db_connection)
        
#         # Compare the result sets
#         return 1.0 if gen_result.equals(lab_result) else 0.0
#     except Exception as e:
#         print(f"Error executing SQL: {e}")
#         return 0.0

# def valid_efficiency_score(generated_sql, labeled_sql, db_connection):
#     """
#     Calculate Valid Efficiency Score (VES)
    
#     Args:
#     generated_sql (str): The SQL query generated by the model
#     labeled_sql (str): The labeled (ground truth) SQL query
#     db_connection: A database connection object
    
#     Returns:
#     float: The VES score
#     """
#     try:
#         # Execute both queries and measure execution time
#         gen_start = time.time()
#         gen_result = pd.read_sql_query(generated_sql, db_connection)
#         gen_time = time.time() - gen_start
#         # print(f'generated_sql_execution_time: {gen_time}')
#         lab_start = time.time()
#         lab_result = pd.read_sql_query(labeled_sql, db_connection)
#         lab_time = time.time() - lab_start
#         # print(f'labeled_sql_execution_time: {lab_time}')
        
#         # Check if results match
#         if not gen_result.equals(lab_result):
#             return 0.0
        
#         # Calculate VES
#         ves = min(lab_time / gen_time, 1.0)
#         return ves
#     except Exception as e:
#         print(f"Error executing SQL: {e}")
#         return 0.0


# def dataframe_to_s3_jsonl(df, bucket_name, prefix, filename):
#     """
#     Convert a pandas DataFrame to JSONL format and upload it to S3.

#     Parameters:
#     df (pandas.DataFrame): The DataFrame to be converted and uploaded.
#     bucket_name (str): The name of the S3 bucket.
#     prefix (str): The S3 prefix (folder path) where the file will be uploaded.
#     filename (str): The name of the file to be created in S3.

#     Returns:
#     str: The S3 URI of the uploaded file.
#     """
#     # Convert DataFrame to JSONL
#     jsonl_buffer = StringIO()
#     for _, row in df.iterrows():
#         json.dump(row.to_dict(), jsonl_buffer)
#         jsonl_buffer.write('\n')
#     jsonl_buffer.seek(0)
#     s3_client = boto3.client('s3')
#     # Upload the JSONL data to S3
#     s3_key = f"{prefix.rstrip('/')}/{filename}"
#     s3_client.put_object(
#         Bucket=bucket_name,
#         Key=s3_key,
#         Body=jsonl_buffer.getvalue(),
#         ContentType='application/json'
#     )

#     # Return the S3 URI of the uploaded file
#     return f"s3://{bucket_name}/{s3_key}"


# def download_and_parse_jsonl(bucket_name, object_key):
#     """
#     Downloads a JSONL file from an Amazon S3 bucket and parses it into a pandas DataFrame.

#     Args:
#         bucket_name (str): The name of the S3 bucket where the JSONL file is stored.
#         object_key (str): The key (path) of the JSONL file in the S3 bucket.

#     Returns:
#         pandas.DataFrame: A DataFrame containing the data from the JSONL file.
#     """
    
#     s3_client = boto3.client('s3')
#     # Download the JSONL file from S3
#     response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
#     jsonl_data = response['Body'].read().decode('utf-8')

#     # Parse the JSONL data into a list of dictionaries
#     data = [json.loads(line) for line in jsonl_data.strip().split('\n')]

#     # Create a DataFrame from the list of dictionaries
#     df = pd.DataFrame(data)

#     return df

# def check_job_status_and_wait(job_arn):
#     # # check status
#     # bedrock.get_model_invocation_job(jobIdentifier=jobArn)['status']

#     # # list batch jobs
#     # bedrock.list_model_invocation_jobs(
#     #     maxResults=10,
#     #     statusEquals="Failed",
#     #     sortOrder="Descending"
#     # )

#     while True:
#         job_status = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)['status']
#         print(f"Job status: {job_status}")

#         if job_status == 'COMPLETED':
#             output_s3_uri = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)['outputDataConfig']['s3OutputDataConfig']['s3Uri']
#             output_file_key = output_s3_uri.replace(f"s3://{output_bucket}/{output_prefix}", "")
#             output_file_name = output_file_key.split("/")[-1]
#             break
#         elif job_status == 'FAILED':
#             print("Job failed.")
#             break
#         else:
#             time.sleep(60)  # Wait for 1 minute before checking again
    
#     return output_s3_uri

# def get_schema(database_name, table_names=None):
#     try:
#         glue_client = boto3.client('glue', region_name=REGION)
#         table_schema_list = []
#         response = glue_client.get_tables(DatabaseName=database_name)

#         all_table_names = [table['Name'] for table in response['TableList']]

#         if table_names:
#             table_names = [name for name in table_names if name in all_table_names]
#         else:
#             table_names = all_table_names

#         for table_name in table_names:
#             response = glue_client.get_table(DatabaseName=database_name, Name=table_name)
#             columns = response['Table']['StorageDescriptor']['Columns']
#             schema = {column['Name']: column['Type'] for column in columns}
#             table_schema_list.append({"Table: {}".format(table_name): 'Schema: {}'.format(schema)})
#     except Exception as e:
#         print(f"Error: {str(e)}")
#     return table_schema_list

# def execute_athena_query(database, query):
#     athena_client = boto3.client('athena', region_name=REGION)
#     # Start query execution
#     response = athena_client.start_query_execution(
#         QueryString=query,
#         QueryExecutionContext={
#             'Database': database
#         },
#         ResultConfiguration={
#             'OutputLocation': outputLocation
#         }
#     )

#     # Get query execution ID
#     query_execution_id = response['QueryExecutionId']
#     print(f"Query Execution ID: {query_execution_id}")

#     # Wait for the query to complete
#     response_wait = athena_client.get_query_execution(QueryExecutionId=query_execution_id)

#     while response_wait['QueryExecution']['Status']['State'] in ['QUEUED', 'RUNNING']:
#         print("Query is still running...")
#         response_wait = athena_client.get_query_execution(QueryExecutionId=query_execution_id)

#     print(f'response_wait {response_wait}')

#     # Check if the query completed successfully
#     if response_wait['QueryExecution']['Status']['State'] == 'SUCCEEDED':
#         print("Query succeeded!")

#         # Get query results
#         query_results = athena_client.get_query_results(QueryExecutionId=query_execution_id)

#         # Extract and return the result data
#         code = 'SUCCEEDED'
#         return code, extract_result_data(query_results)

#     else:
#         print("Query failed!")
#         code = response_wait['QueryExecution']['Status']['State']
#         message = response_wait['QueryExecution']['Status']['StateChangeReason']
    
#         return code, message

# def extract_result_data(query_results):
#     #Return a cleaned response to the agent
#     result_data = []

#     # Extract column names
#     column_info = query_results['ResultSet']['ResultSetMetadata']['ColumnInfo']
#     column_names = [column['Name'] for column in column_info]

#     # Extract data rows
#     for row in query_results['ResultSet']['Rows']:
#         data = [item['VarCharValue'] for item in row['Data']]
#         result_data.append(dict(zip(column_names, data)))

#     return result_data


In [79]:
# SQL generation prompt.
# Placeholders: user_question, sql_database_schema, sql_dialect.
# The model is instructed to wrap its answer in <SQL></SQL> tags.
sql_template = """You are a SQL expert. You will be provided with the original user question and a SQL database schema.
                Only return the SQL query and nothing else.
                Here is the original user question.
                <user_question>
                {user_question}
                </user_question>

                Here is the SQL database schema.
                <sql_database_schema>
                {sql_database_schema}
                </sql_database_schema>

                Instructions:
                Generate a SQL query that answers the original user question.
                Using the schema, create a syntactically correct {sql_dialect} query to answer the question.
                Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
                Pay attention to use only the column names that you can see in the schema description.
                Be careful to not query for columns that do not exist.
                Pay attention to which column is in which table.
                Also, qualify column names with the table name when needed.
                If you cannot answer the user question with the help of the provided SQL database schema,
                then output that this question cannot be answered based on the information stored in the database.
                Return the SQL query inside <SQL></SQL> tags.
                """

In [80]:
# router template
# Placeholders: user_question, sql_dialect, sql_database_schema.
# The model rates difficulty 1-3 in <difficulty></difficulty> tags and, for the
# easiest questions (difficulty 1), also answers directly inside <SQL></SQL> tags.
route_prompt_template="""
                        Instructions:
                        1. Give this question a difficulty rating from 1 to 3, where 3 is the most difficult and 1 is the easiest.
                        2. Return the difficulty inside <difficulty></difficulty> tags.
                        3. If the difficulty is 1, then generate the SQL query to answer the question.
                        4. Return the generated SQL query inside <SQL></SQL> tags.
                        5. Review your formatted response. It needs to be valid XML.

                        Question:
                        <user_question>
                        {user_question}
                        </user_question>

                        SQL database schema for the SQL dialect {sql_dialect}:
                        <sql_database_schema>
                        {sql_database_schema}
                        </sql_database_schema>
                        """

In [82]:
# Grading prompt (LLM-as-a-judge).
# Placeholders: question, sql_schema, sql_query, sql_query_run_result,
# sql_query_run_error, groundtruth_sql_query, ex_score, em_score, ves_score.
# The judge returns its reasoning in <thinking></thinking> and a 0-5 total in <score></score>.

evaluation_template = """You are a SQL expert.
                Your task is to evaluate a given SQL query based on a provided SQL schema and question using the criteria provided below.

                Evaluation Criteria (Additive Score, 0-5):
                1. Context: Award 1 point if the generated SQL query uses only information provided in the SQL schema, without introducing external or fabricated details.
                2. Completeness: Add 1 point if the generated SQL query addresses all key elements of the question based on the available SQL schema and Exact Set Match Accuracy (EM) score.
                3. ExecutionAccuracy: Add 1 point if the generated SQL query is very close to the groundtruth answer based on Execution Accuracy score.
                4. Faultless: Add 1 point if the generated SQL query ran without any errors.
                5. ValidEfficiencyScore: Add 1 point if the runtime of the generated SQL query is similar to or better than that of the groundtruth query as measured by the Valid Efficiency Score (VES).

                Evaluation Steps:
                1. Read the provided schema, question, and generated SQL query carefully.
                2. Go through each evaluation criterion one by one and assess whether the generated SQL query meets it.
                3. Compose your reasoning for each criterion, explaining why you did or did not award a point. You can only award full points.
                4. Calculate the total score by summing the points awarded.
                5. Think through the evaluation criteria inside <thinking></thinking> tags.
                Then, output the total score inside <score></score> tags.
                Review your formatted response. It needs to be valid XML.

                Original question:
                <question>
                {question}
                </question>

                SQL schema:
                <sql_schema>
                {sql_schema}
                </sql_schema>

                Generated SQL query based on these instructions:
                <sql_query>
                {sql_query}
                </sql_query>

                SQL result based on the generated SQL query:
                <sql_query_run_result>
                {sql_query_run_result}
                </sql_query_run_result>

                Any SQL errors that might have occurred based on the generated SQL query:
                <sql_query_run_error>
                {sql_query_run_error}
                </sql_query_run_error>

                Groundtruth SQL query for comparison with the generated SQL query:
                <groundtruth_sql_query>
                {groundtruth_sql_query}
                </groundtruth_sql_query>

                Execution Accuracy, which compares the generated SQL query to the labeled SQL query to determine whether it is a match:
                <ex_score>
                {ex_score}
                </ex_score>

                Exact Set Match Accuracy (EM), which evaluates whether the returned result set actually answers the question, regardless of how the query was written:
                <em_score>
                {em_score}
                </em_score>

                Valid Efficiency Score (VES), which compares the runtime of the SQL provided as groundtruth to the generated SQL query:
                <ves_score>
                {ves_score}
                </ves_score>
                """

### Create a Custom Classifier

In [3]:
# Load the small-LLM grading results and derive the binary routing label:
# rows whose grade is 4 or higher get label 1, every other row gets 0.
file_path = 'question_query_small_llm_grades.jsonl'
df1_graded = pd.read_json(file_path, lines=True)

df1_graded["routing_label"] = (df1_graded["score"] >= 4).astype("int64")

visualize_distribution(df1_graded, key="routing_label")

In [4]:
# Equalise the routing-label classes with balance_dataset (helper cell), then
# re-plot the label distribution and report how many rows remain.
balanced_train_df = balance_dataset(df1_graded, key="routing_label")
visualize_distribution(balanced_train_df, key="routing_label")
print("Train size: {}".format(len(balanced_train_df)))

In [36]:
# sample training data and reformat to format for finetuning
from sklearn.model_selection import train_test_split

n_total_samples = 90
train_ratio = 0.75  # 75% for training, 25% for validation

# DataFrame.sample(n=...) without replacement raises ValueError when n exceeds the
# number of rows, so never request more rows than the balanced dataset holds.
n_total_samples = min(n_total_samples, len(balanced_train_df))

# Calculate the number of samples for each set
n_train = int(n_total_samples * train_ratio)
n_val = n_total_samples - n_train

# Sample the data
sampled_df = balanced_train_df.sample(n=n_total_samples, random_state=42)

# Split the sampled data into training and validation sets
train_df, val_df = train_test_split(sampled_df, train_size=n_train, random_state=42)

# Define output file names
output_file = "sampled_train_data.jsonl"
val_output_file = "sampled_val_data.jsonl"


def _to_finetuning_df(split_df):
    """Reformat a data split into the prompt/completion layout used for fine-tuning.

    Args:
        split_df: DataFrame with ``Question``, ``Context`` and ``routing_label`` columns.

    Returns:
        DataFrame with a ``prompt`` column (str, built by ``build_prediction_prompt``)
        and a ``completion`` column (int64 routing label).
    """
    records = [
        {
            'prompt': str(build_prediction_prompt(user_question=row['Question'], sql_database_schema=row['Context'])),
            # int() also accepts a float-typed label (1.0), which str() would turn
            # into "1.0" and break the int64 cast below.
            'completion': int(row['routing_label']),
        }
        for _, row in split_df.iterrows()
    ]
    # Explicit columns keep the schema (and the casts below) valid for an empty split.
    finetune_df = pd.DataFrame(records, columns=['prompt', 'completion'])
    finetune_df['prompt'] = finetune_df['prompt'].astype(str)
    finetune_df['completion'] = finetune_df['completion'].astype('int64')
    return finetune_df


training_df = _to_finetuning_df(train_df)
print(training_df.columns)
training_df.to_json(output_file, orient="records", lines=True)

validation_df = _to_finetuning_df(val_df)
print(validation_df.columns)
validation_df.to_json(val_output_file, orient="records", lines=True)



Index(['prompt', 'completion'], dtype='object')
Index(['prompt', 'completion'], dtype='object')


In [None]:
# # TBD: Finetune a small local LLM instruct model (DistilBERT) as binary classifier model
# # %pip install transformers torch scikit-learn transformers[torch] accelerate -U
# import torch
# from torch.utils.data import Dataset
# from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# # Load tokenizer and model
# model_name = "distilbert-base-uncased"
# tokenizer = DistilBertTokenizer.from_pretrained(model_name)
# model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=2)


# class CustomDataset(Dataset):
#     def __init__(self, dataframe, tokenizer, max_length=128):
#         self.prompts = dataframe['prompt'].tolist()
#         self.completions = dataframe['completion'].tolist()
#         self.tokenizer = tokenizer
#         self.max_length = max_length

#     def __getitem__(self, idx):
#         prompt = self.prompts[idx]
#         completion = self.completions[idx]
        
#         # Tokenize the prompt and completion
#         encoding = self.tokenizer(
#             prompt,
#             completion,
#             truncation=True,
#             padding='max_length',
#             max_length=self.max_length,
#             return_tensors='pt'
#         )
        
#         # Remove the batch dimension
#         item = {key: val.squeeze(0) for key, val in encoding.items()}
        
#         # item['labels'] = torch.tensor(self.labels[idx])
        
#         return item

#     def __len__(self):
#         return len(self.prompts)

# # Create datasets with a DataFrame called training_df with 'prompt' and 'completion' columns
# train_dataset = CustomDataset(training_df, tokenizer)
# val_dataset = CustomDataset(validation_df, tokenizer)

# # Define training arguments
# training_args = TrainingArguments(
#     output_dir='./results',
#     num_train_epochs=3,
#     per_device_train_batch_size=16,
#     per_device_eval_batch_size=64,
#     warmup_steps=500,
#     weight_decay=0.01,
#     logging_dir='./logs',
#     logging_steps=10,
#     evaluation_strategy="epoch",
# )

# # Define metrics function
# def compute_metrics(pred):
#     labels = pred.label_ids
#     preds = pred.predictions.argmax(-1)
#     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
#     acc = accuracy_score(labels, preds)
#     return {
#         'accuracy': acc,
#         'f1': f1,
#         'precision': precision,
#         'recall': recall
#     }

# # Create Trainer
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=val_dataset,
#     compute_metrics=compute_metrics,
# )

# # Train the model
# trainer.train()

# # Evaluate the model
# eval_results = trainer.evaluate()
# print(eval_results)

# # Save the model
# model.save_pretrained("./finetuned_model")
# tokenizer.save_pretrained("./finetuned_model")

# # Inference example
# def predict(text):
#     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
#     outputs = model(**inputs)
#     probs = outputs.logits.softmax(dim=-1)
#     return probs.argmax().item()

# # Test the model
# test_text = "List all suppliers with their contact information."
# prediction = predict(test_text)
# print(f"Prediction for '{test_text}': {'Small_LLM' if prediction == 1 else 'Large_LLM'}")


### FINE-TUNING JOB

In [8]:
BEDROCK_FINE_TUNING = False  # True: Bedrock model customization; False: SageMaker JumpStart path in the following cells
if BEDROCK_FINE_TUNING:  # requires Bedrock Provisioned Throughput to deploy
    import time

    # upload to S3
    bucket_name = 'felixh-demo'
    prefix = 'finetuning'
    filename = output_file
    s3_uri = dataframe_to_s3_jsonl(training_df, bucket_name, prefix, filename)
    print(f's3_uri: {s3_uri}')

    # Set parameters
    customizationType = "FINE_TUNING"
    baseModelIdentifier = "arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1"
    # NOTE(review): hard-coded account ID and an admin role; consider reading the
    # role ARN from the .env file and scoping its permissions to this job.
    roleArn = "arn:aws:iam::026459568683:role/admin"
    jobName = "Text-to-SQL-Routing-Classifier-Job-V2"
    customModelName = "LLM-Routing-Classifier"
    hyperParameters = {
        "epochCount": "1",  # The maximum number of iterations through the entire training dataset
        "batchSize": "1",  # The number of samples processed before updating model parameters
        "learningRate": ".0005",  # Multiplier that influences the learning rate at which model parameters are updated after each batch
        "learningRateWarmupSteps": "0",
    }
    trainingDataConfig = {"s3Uri": s3_uri}
    outputDataConfig = {"s3Uri": f"s3://{bucket_name}/{prefix}/output"}

    # Create job
    response_ft = bedrock_client.create_model_customization_job(
        jobName=jobName,
        customModelName=customModelName,
        roleArn=roleArn,
        baseModelIdentifier=baseModelIdentifier,
        hyperParameters=hyperParameters,
        trainingDataConfig=trainingDataConfig,
        outputDataConfig=outputDataConfig,
    )

    jobArn = response_ft.get('jobArn')
    print(f'jobArn: {jobArn}')

    # A job that was just submitted is still running, so a single status check
    # could never reach the deployment branch: poll until a terminal state.
    terminal_states = {'Completed', 'Failed', 'Stopped'}
    while True:
        response = bedrock_client.get_model_customization_job(jobIdentifier=jobArn)
        status = response.get('status')
        if status in terminal_states:
            break
        print(f'finetuning job status: {status}')
        time.sleep(60)  # Wait for 1 minute before checking again

    if status == 'Completed':
        outputModelArn = response.get("outputModelArn")
        print(f'outputModelArn: {outputModelArn}')
        customModelName = "LLM-Routing-Classifier"
        # Custom Bedrock models can only be invoked through Provisioned Throughput.
        response_pt = bedrock_client.create_provisioned_model_throughput(
            modelId=outputModelArn,
            provisionedModelName=customModelName,
            modelUnits=1,
        )

        provisionedModelArn = response_pt.get('provisionedModelArn')
        print(f'provisionedModelArn: {provisionedModelArn}')
    else:
        print(f'finetuning job status: {status}')
        print(f'finetuning job response: {response}')

In [7]:
if not BEDROCK_FINE_TUNING:  # use SageMaker endpoint
    # Fine-tune a SageMaker JumpStart model on the routing prompts written to
    # `output_file` by the data-preparation cell, then deploy it to an endpoint.
    import json
    import random

    import sagemaker
    from sagemaker import TrainingJobAnalytics, hyperparameters
    from sagemaker.jumpstart.estimator import JumpStartEstimator
    from sagemaker.s3 import S3Uploader

    # Prompt/completion template uploaded next to the training data.
    # Presumably JumpStart uses it to render each JSONL record -- confirm for this model.
    template = {
        "prompt": "{prompt}\n\n",
        "completion": " {completion}",
    }
    with open("template.json", "w") as f:
        json.dump(template, f)

    bucket_name = 'felixh-demo'
    prefix = 'finetuning'

    # NOTE(review): only the training split is uploaded; val_output_file is never used here.
    train_data_location = f"s3://{bucket_name}/{prefix}"
    S3Uploader.upload(output_file, train_data_location)
    S3Uploader.upload("template.json", train_data_location)
    print(f"Training data: {train_data_location}")

    model_id = 'huggingface-llm-mistral-7b'  # check https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html
    model_version = '*'

    finetuning_hyperparameters = hyperparameters.retrieve_default(
        model_id=model_id, model_version=model_version
    )
    print(finetuning_hyperparameters)

    finetuning_hyperparameters["epoch"] = "1"
    finetuning_hyperparameters["per_device_train_batch_size"] = "2"
    finetuning_hyperparameters["gradient_accumulation_steps"] = "2"
    finetuning_hyperparameters["instruction_tuned"] = "True"
    finetuning_hyperparameters["max_input_length"] = "32000"

    # validate parameters
    hyperparameters.validate(
        model_id=model_id, model_version=model_version, hyperparameters=finetuning_hyperparameters
    )

    # start training
    instruction_tuned_estimator = JumpStartEstimator(
        model_id=model_id,
        hyperparameters=finetuning_hyperparameters,
        instance_type="ml.g5.12xlarge",
    )
    instruction_tuned_estimator.fit({"train": train_data_location}, logs=True)

    # get metrics -- printed explicitly, because a bare `df.head(10)` in the
    # middle of a cell is evaluated and discarded instead of displayed.
    training_job_name = instruction_tuned_estimator.latest_training_job.job_name
    df = TrainingJobAnalytics(training_job_name=training_job_name).dataframe()
    print(df.head(10))

    # deploy fine-tuned model
    instruction_tuned_predictor = instruction_tuned_estimator.deploy()


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


NameError: name 'output_file' is not defined

In [5]:
if not BEDROCK_FINE_TUNING:  # use SageMaker endpoint
    # Send one text-to-SQL prompt to the deployed Mistral endpoint and print the reply.
    import sagemaker
    from sagemaker.predictor import Predictor
    import json

    # Initialize the SageMaker session
    sagemaker_session = sagemaker.Session()

    # Specify the endpoint name
    endpoint_name = "hf-llm-mistral-7b-2024-08-23-01-17-36-474"

    # Create the predictor; requests and responses are (de)serialized as JSON.
    predictor = Predictor(
        endpoint_name=endpoint_name,
        sagemaker_session=sagemaker_session,
        serializer=sagemaker.serializers.JSONSerializer(),
        deserializer=sagemaker.deserializers.JSONDeserializer(),
    )

    prompt = build_sqlquerygen_prompt(user_question='What is the total number of customers?', sql_database_schema=schema)
    # Wrap the prompt in the Mistral instruction format
    prompt = f"<s>[INST] {prompt} [/INST]"

    # Greedy decoding for deterministic SQL output. do_sample=True combined with
    # temperature=0 is rejected by TGI-based containers (temperature must be
    # strictly positive); top_k/top_p only apply when sampling, so they are omitted.
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 500,
            "do_sample": False,
        },
    }

    # Make the prediction
    result = predictor.predict(payload)
    print(f'result: {result}')

NameError: name 'BEDROCK_FINE_TUNING' is not defined

### Conclusion
In this tutorial, we built and evaluated a fine-tuned LLM router.
We generated synthetic labeled data with the LLM-as-a-judge method to train the router, fine-tuned an LLM classifier (with Amazon Bedrock model customization or, as configured by default, SageMaker JumpStart),
and conducted an offline evaluation.

### Sources

https://github.com/lm-sys/RouteLLM

https://medium.com/@learngrowthrive.fast/routellm-achieves-90-gpt-4-quality-at-80-lower-cost-6686e5f46e2a

https://medium.com/ai-insights-cobet/beyond-basic-chatbots-how-semantic-router-is-changing-the-game-783dd959a32d

https://medium.com/@bhawana.prs/semantic-routes-in-llms-to-make-chatbots-more-accurate-d99c17e30487


Popular benchmarks for evaluating routers: MT Bench, MMLU, and GSM8K.

* Semantic routing: Using a vector analysis to route the query to the closest “cluster”
https://github.com/aurelio-labs/semantic-router

* Prompt Chaining: Similar to what has been implemented inside Bedrock agents and LangChain’s custom functions, this approach uses a small LLM to analyze the question and route it to the next part of the chain. https://aws.amazon.com/blogs/machine-learning/enhance-conversational-ai-with-advanced-routing-techniques-with-amazon-bedrock/
You can optimize this by having the “router” model answer directly simple questions instead of routing them to another model.

* Intent Classification: Creating a custom model, similar to ROHF or Rerankers to classify the query and route it to the right LLM.  
https://medium.com/aimonks/intent-classification-generative-ai-based-application-architecture-3-79d2927537b4

https://www.anyscale.com/blog/building-an-llm-router-for-high-quality-and-cost-effective-responses

https://github.com/aws-samples/amazon-bedrock-samples/blob/main/function-calling/function_calling_text2SQL_converse_bedrock_streamlit.py

https://github.com/aws-samples/amazon-bedrock-samples/tree/main/rag-solutions/sql-query-generator

### Next steps

Explore other data sources: the BIRD benchmark (https://bird-bench.github.io/), the latest Spider dataset, or any SQL dataset from Hugging Face (HF).

<!-- # https://huggingface.co/datasets/b-mc2/sql-create-context

# from datasets import load_dataset
# # %sql sqlite:///routedb.db
# # to load SQL dataset from starcoder
## ds = load_dataset("bigcode/starcoderdata", data_dir="sql", split="train", token=True)
# ds = load_dataset("b-mc2/sql-create-context", split="train", token=True)
# from datasets import load_dataset_builder
# ds_builder = load_dataset_builder("b-mc2/sql-create-context")
# ds_builder.info.description
# ds_builder.info.features -->


### SCRATCHPAD

In [None]:
# Scratchpad: send one text-to-SQL prompt to Mixtral 8x7B Instruct
# (mistral.mixtral-8x7b-instruct-v0:1) through the Bedrock Runtime Converse API,
# retrying a failed call up to `max_attempts` times with a 10 s pause.

import boto3
import json
import time

from botocore.exceptions import BotoCoreError, ClientError

# Create a Bedrock Runtime client in the AWS Region of your choice.
bedrock_runtime_client = boto3.client("bedrock-runtime", region_name="us-east-1")

# Model to call: Mixtral 8x7B Instruct.
model_id = "mistral.mixtral-8x7b-instruct-v0:1"

# Define the prompt for the model.
prompt = build_sqlquerygen_prompt(user_question='What is the total number of customers?', sql_database_schema=schema)

# No system prompt; a single user turn.
system_prompts = []  # [{"text": "You are a helpful AI Assistant."}]
messages = [{"role": "user", "content": [{"text": prompt}]}]

# Base inference parameters, plus model-specific fields passed via additionalModelRequestFields.
inference_config = {"temperature": 0.0}
additional_model_fields = {"top_k": 5}

max_attempts = 3
result_text = None
for attempt in range(1, max_attempts + 1):
    try:
        response = bedrock_runtime_client.converse(
            modelId=model_id,
            messages=messages,
            system=system_prompts,
            inferenceConfig=inference_config,
            additionalModelRequestFields=additional_model_fields,
        )
    except (ClientError, BotoCoreError) as e:
        print("Error with calling Bedrock: " + str(e))
        if attempt >= max_attempts:
            print("Max attempts reached!")
            result_text = str(e)
        else:  # retry in 10 seconds
            print("retry")
            time.sleep(10)
        continue

    # Extract the generated text, then log token usage and latency.
    text = response['output'].get('message').get('content')[0].get('text')
    result_text = text
    print(f'text: {text}')
    token_usage = response['usage']
    print(f'token_usage: {token_usage}')
    latency = response['metrics'].get('latencyMs')
    print(f'latency: {latency}')
    break

In [None]:
# initialize SageMaker predictor with endpoint name e.g. hf-llm-mistral-7b-2024-08-23-01-17-36-474
# Scratchpad: send the same text-to-SQL prompt to a Mistral 7B model hosted on a
# SageMaker real-time endpoint and print the raw JSON result.
import sagemaker
from sagemaker.predictor import Predictor
import json

# Initialize the SageMaker session
sagemaker_session = sagemaker.Session()

# Specify the endpoint name
endpoint_name = "hf-llm-mistral-7b-2024-08-23-01-17-36-474"

# Create the predictor; requests are sent as JSON and responses are parsed from JSON.
predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer()
)

# Build the prompt; build_sqlquerygen_prompt() and `schema` come from earlier cells.
prompt = build_sqlquerygen_prompt(user_question= 'What is the total number of customers?', sql_database_schema= schema)
# print(f'prompt: {prompt}')
# NOTE(review): Mistral *Instruct* checkpoints expect the [INST] ... [/INST] template;
# uncomment the next line if this endpoint serves an instruct variant.
# prompt = f"[INST] {prompt} [/INST]"

# Greedy decoding gives deterministic SQL output, matching temperature=0.0 in the
# Bedrock cell. Assumes the endpoint runs the Hugging Face LLM (TGI) container, as the
# "hf-llm-" name prefix suggests. That container rejects temperature=0 ("must be
# strictly positive"), and top_k/top_p only apply when sampling. So the sampling knobs
# are omitted instead of being combined with do_sample=True.
payload = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 500,
        "do_sample": False,
    },
}

result = predictor.predict(payload)
print(f'result: {result}')