In [1]:
import os
import sqlite3
from dotenv import dotenv_values
from langchain_openai import ChatOpenAI

config = {**dotenv_values("../configs/local.env")}

In [2]:
os.environ["OPENAI_API_KEY"] = config["OPENAI_API_KEY"]

In [3]:
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

In [14]:
from operator import itemgetter
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.utilities import SQLDatabase

from sql_table_qa.answerers.langchain_answerer.langchain_sql_connector import execute_sql
from CONSTANTS impor 

execute_query = execute_sql
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
write_query = create_sql_query_chain(llm, db)

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
    
    Question: {question}
    SQL Query: {query}
    SQL Result: {result}
    Answer: """
)

answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)

chain.invoke({"question": "How many employees are there"})

  sample_rows_result = connection.execute(command)  # type: ignore


'There are 8 employees.'

In [16]:
chain.invoke({"question": "Which customer has spent the most money in total?"})

'The customer who has spent the most money in total is Helena Holý, with a total spent amount of $49.62.'

In [17]:
RunnablePassthrough.assign(query=write_query).assign(result=itemgetter("query") | execute_query).invoke({"question": "Which customer has spent the most money in total?"})

{'question': 'Which customer has spent the most money in total?',
 'query': 'SELECT c."CustomerId", c."FirstName", c."LastName", SUM(i."Total") AS TotalSpent\nFROM "Customer" c\nJOIN "Invoice" i ON c."CustomerId" = i."CustomerId"\nGROUP BY c."CustomerId"\nORDER BY TotalSpent DESC\nLIMIT 1;',
 'result': "[(6, 'Helena', 'Holý', 49.620000000000005)]"}

In [18]:
chain.invoke({"question": "How many horses are there?"})

'There are 347 horses in the database.'

In [19]:
chain.invoke({"question": "Can you tell me about the customers?"})

'Based on the SQL query and result provided, the customers in the database have the following information available: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, and Email. The query returned the details of the first 5 customers in the database.'

In [9]:
write_query

RunnableAssign(mapper={
  input: RunnableLambda(...),
  table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for k, v in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. B

In [10]:
write_query.invoke({"question": "Which customer has spent the most money in total?"})

'SELECT "CustomerID", SUM("Total") AS total_spent\nFROM orders\nGROUP BY "CustomerID"\nORDER BY total_spent DESC\nLIMIT 1;'

In [11]:
import pandas as pd

eval = pd.read_csv("../data/evaluation_dataset.csv")

In [12]:
eval

Unnamed: 0,question,sql_query,tables_needed,sql_query_soft_eval,answer_detailed,answer_brief,sql_detailed_answer
0,How many tracks are there in the database?,SELECT COUNT(*) FROM Track;,Track,SELECT COUNT(*) FROM Track;,"There are 3,503 tracks in the database.",3503,"(3503,)"
1,What is the total revenue from all invoices?,SELECT SUM(Total) FROM Invoice;,Invoice,SELECT SUM(Total) FROM Invoice;,Total revenue from all invoices is approximate...,2328.600000000004,"(2328.600000000004,)"
2,What is the name of the most popular genre by ...,SELECT Name FROM Genre WHERE GenreId = (SELECT...,"Genre, Track","SELECT Genre.Name, COUNT(*) AS TrackCount FROM...","The most popular genre is Rock with 1,297 tracks.",Rock,"('Rock', 1297)"
3,What is the highest amount ever billed to a si...,SELECT MAX(Total) FROM Invoice;,Invoice,"SELECT MAX(Total) AS HighestTotal, InvoiceId F...",The highest invoice total is $25.86 for invoic...,25.86,"(25.86, 404)"
4,What is the name of the track that has generat...,SELECT T.Name FROM Track T JOIN InvoiceLine IL...,"Track, InvoiceLine","SELECT Track.Name, SUM(InvoiceLine.UnitPrice *...",The track 'The Woman King' generated the most ...,The Woman King,"('The Woman King', 3.98)"
5,Which customer has spent the most money in total?,SELECT FirstName|| ' ' ||LastName FROM Custome...,"Customer, Invoice",SELECT Customer.FirstName || ' ' || Customer.L...,"Helena Holý spent the most money, totaling app...",Helena Holý,"('Helena Holý', 49.620000000000005)"
6,Which artist's tracks are the most purchased?,SELECT A.Name FROM Artist A JOIN Album Al ON A...,"Artist, Album, Track, InvoiceLine","SELECT Artist.Name, COUNT(*) AS TotalPurchases...",Iron Maiden's tracks were purchased 140 times.,Iron Maiden,"('Iron Maiden', 140)"
7,Which employee has generated the most revenue ...,SELECT E.FirstName|| ' ' ||E.LastName FROM Emp...,"Employee, Customer, Invoice",SELECT Customer.FirstName || ' ' || Customer.L...,Fynn Zimmermann generated the most revenue at ...,Fynn Zimmermann,"('Fynn Zimmermann', 833.0400000000016)"
8,What is the name of the most popular playlist ...,SELECT P.Name FROM Playlist P JOIN PlaylistTra...,"Playlist, PlaylistTrack","SELECT Playlist.Name, COUNT(*) AS TrackCount F...","The 'Music' playlist contains 3,290 tracks.",Music,"('Music', 3290)"
9,Which genre has generated the highest total re...,SELECT G.Name FROM Genre G JOIN Track T ON G.G...,"Genre, Track, InvoiceLine","SELECT Genre.Name, SUM(InvoiceLine.UnitPrice *...",Rock generated the highest total revenue at ap...,Rock,"('Rock', 826.6500000000061)"


In [15]:
import mlflow
from mlflow.metrics.genai import answer_correctness, answer_relevance
from datetime import datetime

from sql_table_qa.evaluators.llm_evaluators import openai_correctness_evaluator, openai_relevance_evaluator

# Initialize MLflow client and set the experiment
mlflow.set_tracking_uri(config["MLFLOW_TRACKING_URI"])  # Set this to your MLflow tracking server URI
experiment_name = "Naive Langchain Prototype"
mlflow.set_experiment(experiment_name)
run_prefix = "notebook-initial-test"

with mlflow.start_run(run_name=f"{run_prefix}-{datetime.now().strftime('%Y%m%d_%H%M%S')}"):
    # Log parameters and results to MLflow
    model_load_success = False
    try:
        model_info = mlflow.langchain.log_model(chain, "naive_langchain_model")
        model = mlflow.pyfunc.load_model(model_info.model_uri)
        model_load_success = True
    except TypeError:
        # Some error with SQLAlchemy object being unpickle-able
        model = chain

    answers = []
    for _, example in eval.iterrows():
        if model_load_success:
            ans = model.predict({"question": example["question"]})
        else:
            ans = model.invoke({"question": example["question"]})
        answers.append(ans)
    eval_w_ans = eval.drop(columns=[c for c in eval.columns if c not in ("question", "answer_detailed")]).assign(model_answer=answers)
    
    mlflow.log_table(data=eval_w_ans, artifact_file="answers.json")
    results = mlflow.evaluate(
        data = eval_w_ans,
        targets = "answer_detailed",
        predictions = "model_answer",
        evaluators=None,
        extra_metrics=[openai_correctness_evaluator, openai_relevance_evaluator, mlflow.metrics.latency()],
        evaluator_config={'col_mapping': {"inputs": "question"}}
    )
    print(f"See aggregated evaluation results below:")
    display(results.metrics)

    # Evaluation result for each data record is available in `results.tables`.
    eval_table = results.tables["eval_results_table"]
    print(f"See evaluation table below:")
    display(eval_table)

  string_columns = trimmed_df.columns[(df.applymap(type) == str).all(0)]
  data = data.applymap(_hash_array_like_element_as_bytes)
  data = data.applymap(_hash_array_like_element_as_bytes)
2024/04/26 17:44:26 INFO mlflow.models.evaluation.base: Evaluating the model with the default evaluator.
2024/04/26 17:44:26 INFO mlflow.models.evaluation.default_evaluator: Testing metrics on first row...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

2024/04/26 17:44:30 INFO mlflow.models.evaluation.default_evaluator: Evaluating metrics: answer_correctness


  0%|          | 0/16 [00:00<?, ?it/s]

2024/04/26 17:44:33 INFO mlflow.models.evaluation.default_evaluator: Evaluating metrics: answer_relevance


  0%|          | 0/16 [00:00<?, ?it/s]

See aggregated evaluation results below:


{'latency/mean': 0.0,
 'latency/variance': 0.0,
 'latency/p90': 0.0,
 'answer_correctness/v1/mean': 4.0625,
 'answer_correctness/v1/variance': 1.68359375,
 'answer_correctness/v1/p90': 5.0,
 'answer_relevance/v1/mean': 4.625,
 'answer_relevance/v1/variance': 1.109375,
 'answer_relevance/v1/p90': 5.0}

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

See evaluation table below:


Unnamed: 0,question,answer_detailed,model_answer,latency,answer_correctness/v1/score,answer_correctness/v1/justification,answer_relevance/v1/score,answer_relevance/v1/justification
0,How many tracks are there in the database?,"There are 3,503 tracks in the database.","There are 3,503 tracks in the database.",0,5,The output provided by the model is correct an...,5,The output directly answers the question by pr...
1,What is the total revenue from all invoices?,Total revenue from all invoices is approximate...,The total revenue from all invoices is $2328.60.,0,5,The output provided by the model is correct an...,5,The output directly provides the total revenue...
2,What is the name of the most popular genre by ...,"The most popular genre is Rock with 1,297 tracks.",The name of the most popular genre by number o...,0,5,The output is correct and demonstrates a high ...,5,The output provides the name of the most popul...
3,What is the highest amount ever billed to a si...,The highest invoice total is $25.86 for invoic...,The highest amount ever billed to a single cus...,0,5,The output provided by the model is correct an...,5,The output directly addresses the question by ...
4,What is the name of the track that has generat...,The track 'The Woman King' generated the most ...,The track that has generated the most revenue ...,0,1,The output is completely incorrect as it state...,5,The output directly answers the question by pr...
5,Which customer has spent the most money in total?,"Helena Holý spent the most money, totaling app...",The customer who has spent the most money in t...,0,5,The output provided by the model is correct an...,5,The output directly addresses the input questi...
6,Which artist's tracks are the most purchased?,Iron Maiden's tracks were purchased 140 times.,The artist whose tracks are the most purchased...,0,5,The output correctly identifies Iron Maiden as...,5,The output directly addresses the input questi...
7,Which employee has generated the most revenue ...,Fynn Zimmermann generated the most revenue at ...,Jane Peacock has generated the most revenue fr...,0,3,The output correctly identifies Jane Peacock a...,5,The output directly answers the input question...
8,What is the name of the most popular playlist ...,"The 'Music' playlist contains 3,290 tracks.",The name of the most popular playlist by numbe...,0,5,The output is correct and demonstrates a high ...,5,The output provides the exact name of the most...
9,Which genre has generated the highest total re...,Rock generated the highest total revenue at ap...,The genre that has generated the highest total...,0,5,The output provided by the model is correct an...,5,The output directly addresses the input questi...
