#### SQL few shot example
https://python.langchain.com/docs/use_cases/qa_structured/sql

In [9]:
import os
import sys
import getpass
from dotenv import load_dotenv, dotenv_values
import pandas as pd
import openai

from IPython.display import display, Markdown, Latex, HTML, JSON
import sqlite3
from sqlite3 import Error
import pymysql
from sqlalchemy import create_engine, text as sql_text

import langchain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

from cmd import PROMPT
from pyexpat.errors import messages

import tiktoken

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document

sys.path.append(r"/Users/dovcohen/Documents/Projects/AI/NL2SQL")

from OpenAI.NL2SQL.lib_OpenAI import GenAI_NL2SQL


In [10]:
def Instantiate_OpenAI_Class(Model="gpt-3.5-turbo-instruct", Max_Tokens=250, Encoding_Base = "cl100k_base",Temperature=0):
    load_dotenv("/Users/dovcohen/.NL2SQL_env")
    # SQL DB
    DB = 'mysql'
    MYSQL_USER = os.getenv("MYSQL_USER", None)
    MYSQL_PWD = os.getenv("MYSQL_PWD", None)

    # OpenAI
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", None)
    # LLM parameters
   # Model = ["gpt-3.5-turbo-instruct",'text-embedding-ada-002']
    Token_Cost = {"gpt-3.5-turbo-instruct":{"Input":0.0015/1000,"Output":0.002/1000},
                 "text-embedding-ada-002":{"Input":0.0001/1000}}
    

    #Instantiate GenAI_NL2SQL Object
    return GenAI_NL2SQL(OPENAI_API_KEY, Model, Encoding_Base, Max_Tokens, Temperature, \
                        Token_Cost,DB, MYSQL_USER, MYSQL_PWD)



In [11]:
ADA=Instantiate_OpenAI_Class( Model = 'text-embedding-ada-002')

data = [['tom', 10], ['nick', 15], ['juli', 14]] 
  
# Create the pandas DataFrame 
df = pd.DataFrame(data, columns=['Name', 'Age']) 
  
# print dataframe. 
df 



Unnamed: 0,Name,Age
0,tom,10
1,nick,15
2,juli,14


### Embeddings Tutorial
https://cookbook.openai.com/examples/get_embeddings_from_dataset

The dataset used in this example is fine-food reviews from Amazon. The dataset contains a total of 568,454 food reviews Amazon users left up to October 2012. We will use a subset of this dataset, consisting of 1,000 most recent reviews for illustration purposes. The reviews are in English and tend to be positive or negative. Each review has a ProductId, UserId, Score, review title (Summary) and review body (Text).

We will combine the review summary and review text into a single combined text. The model will encode this combined text and it will output a single vector embedding.

To run this notebook, you will need to install: pandas, openai, transformers, plotly, matplotlib, scikit-learn, torch (transformer dep), torchvision, and scipy.



In [6]:
max_tokens = 8000
Datafile =  '/Users/dovcohen/Documents/Projects/Data/Amazon_Fine_Food_Reviews/Reviews.csv'
#
#Datafile =  '/Users/dovcohen/Documents/Projects/Data/Amazon_Fine_Food_Reviews/Reviews_Combined_Text.csv'
df = pd.read_csv(Datafile,index_col=0)
df = df[["Time", "ProductId", "UserId", "Score", "Summary", "Text"]]
df = df.dropna()
df["Combined"] = (
    "Title: " + df.Summary.str.strip() + "; Content: " + df.Text.str.strip()
)
#df.head(2)
top_n = 1000
df = df.sort_values("Time").tail(top_n * 2)  # first cut to first 2k entries, assuming less than half will be filtered out
df.drop("Time", axis=1, inplace=True)

encoding = tiktoken.get_encoding("cl100k_base")

# omit reviews that are too long to embed
df["n_tokens"] = df.Combined.apply(lambda x: len(encoding.encode(x)))
df = df[df.n_tokens <= max_tokens].tail(top_n)
len(df)
df.head(5)
Datafile =  '/Users/dovcohen/Documents/Projects/Data/Amazon_Fine_Food_Reviews/Reviews_1k_tokens.csv'
df.to_csv(Datafile, encoding='utf-8', index=False)


In [17]:
Datafile =  '/Users/dovcohen/Documents/Projects/Data/Amazon_Fine_Food_Reviews/Reviews_1k_tokens.csv'
df = pd.read_csv(Datafile,index_col=0)

df1 = df.reset_index()
df2 = df1.iloc[0:5]
#x = df['Combined'][0:5]
#x.iloc[2]
#Embed.OpenAI_Get_Embedding(x)
#openai.Embedding.create(input = x[0], model="text-embedding-ada-002")['data'][0]['embedding']
#openai.Embedding.create(input = x, model=model)['data'][0]['embedding']
#df2['ada_embedding'] = df2.Combined.apply(lambda x: ADA.OpenAI_Get_Embedding(x))
print(df2)
#df.to_csv('output/embedded_1k_reviews.csv', index=False)

    ProductId          UserId  Score  \
0  B003XPF9BO  A3R7JR3FMEBXQB      5   
1  B003JK537S  A3JBPC3WFUT5ZP      1   
2  B000JMBE7M   AQX1N6A51QOKG      4   
3  B004AHGBX4  A2UY46X0OSNVUQ      3   
4  B001BORBHO  A1AFOYZ9HSM2CZ      5   

                                             Summary  \
0  where does one  start...and stop... with a tre...   
1                                  Arrived in pieces   
2          It isn't blanc mange, but isn't bad . . .   
3        These also have SALT and it's not sea salt.   
4                             Happy with the product   

                                                Text  \
0  Wanted to save some to bring to my Chicago fam...   
1  Not pleased at all. When I opened the box, mos...   
2  I'm not sure that custard is really custard wi...   
3  I like the fact that you can see what you're g...   
4  My dog was suffering with itchy skin.  He had ...   

                                            Combined  n_tokens  
0  Title: where does

In [19]:
df2 = df2.reset_index()
df2['ada_embedding'] = df2.Combined.apply(lambda x: ADA.OpenAI_Get_Embedding(x))


In [20]:
df2

Unnamed: 0,index,ProductId,UserId,Score,Summary,Text,Combined,n_tokens,ada_embedding
0,0,B003XPF9BO,A3R7JR3FMEBXQB,5,where does one start...and stop... with a tre...,Wanted to save some to bring to my Chicago fam...,Title: where does one start...and stop... wit...,52,"[0.007060592994093895, -0.02732112631201744, 0..."
1,1,B003JK537S,A3JBPC3WFUT5ZP,1,Arrived in pieces,"Not pleased at all. When I opened the box, mos...",Title: Arrived in pieces; Content: Not pleased...,35,"[-0.023609420284628868, -0.011784634552896023,..."
2,2,B000JMBE7M,AQX1N6A51QOKG,4,"It isn't blanc mange, but isn't bad . . .",I'm not sure that custard is really custard wi...,"Title: It isn't blanc mange, but isn't bad . ....",267,
3,3,B004AHGBX4,A2UY46X0OSNVUQ,3,These also have SALT and it's not sea salt.,I like the fact that you can see what you're g...,Title: These also have SALT and it's not sea s...,239,"[0.010532955639064312, -0.01354704238474369, 0..."
4,4,B001BORBHO,A1AFOYZ9HSM2CZ,5,Happy with the product,My dog was suffering with itchy skin. He had ...,Title: Happy with the product; Content: My dog...,86,"[0.015255776233971119, -0.003898625960573554, ..."


In [None]:
def get_embedding(text, model="text-embedding-ada-002"):
   text = text.replace("\n", " ")
   return openai.Embedding.create(input = [text], model="text-embedding-ada-002")['data'][0]['embedding']

#df['ada_embedding'] = df.co1mbined.apply(lambda x: get_embedding(x, model='text-embedding-ada-002'))
#df.to_csv('output/embedded_1k_reviews.csv', index=False)

In [None]:
embeddings = OpenAIEmbeddings()

few_shot_docs = [Document(page_content=question, metadata={'sql_query': few_shots[question]}) for question in few_shots.keys()]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()

### Few Shot Prompting with code llama langchain and mysql
https://medium.com/@yernenip/few-shot-prompting-with-codellama-langchain-and-mysql-94020ee16a08

### Setting up the example prompts

In [24]:
from langchain.prompts.prompt import PromptTemplate

examples = [
        {
            "input": "How many customers are from district California?",
            "sql_cmd": "SELECT COUNT(*) FROM customer cu JOIN address ad ON cu.address_id = ad.address_id \
            WHERE ad.district = 'California';",
            "result": "[(9,)]",
            "answer": "There are 9 customers from California",
        },
        {
            "input": "How many customers are from city San Bernardino?",
            "sql_cmd": "SELECT COUNT(*) FROM customer cu JOIN address ad ON cu.address_id = ad.address_id \
            JOIN city ci  ON ad.city_id = ci.city_id WHERE ci.city = 'San Bernardino';",
            "result": "[(1,)]",
            "answer": "There is 1 customer from San Bernardino",
        },
        {
            "input": "How many customers are from country United States?",
            "sql_cmd": "SELECT COUNT(*) FROM customer cu JOIN address ad ON cu.address_id = ad.address_id \
            JOIN city ci ON ad.city_id = ci.city_id JOIN country co ON ci.country_id = co.country_id \
            WHERE co.country = 'United States';",
            "result": "[(36,)]",
            "answer": "There are 36 customers from United States",
        },
]

example_prompt = PromptTemplate(
    input_variables=["input", "sql_cmd", "result", "answer",],
    template="\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {result}\nAnswer: {answer}",
)

### Vectorizing the examples and using an example selector
#### Using Chroma DB Vector Store - running in-memory

In [27]:
# !pip install chromadb

In [28]:
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


embeddings = HuggingFaceEmbeddings()

to_vectorize = [" ".join(example.values()) for example in examples]

vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples)

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=1,
)

### Setting up the few shot prompt

In [30]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt

#print(PROMPT_SUFFIX)

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=_mysql_prompt,
    suffix=PROMPT_SUFFIX, 
    input_variables=["input", "table_info", "top_k"], #These variables are used in the prefix and suffix
)

In [31]:
PROMPT_SUFFIX

'Only use the following tables:\n{table_info}\n\nQuestion: {input}'

In [32]:
_mysql_prompt

'You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use CURDATE() function to get the current date, if the question involves "today".\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult

The _mysql_prompt and PROMPT_SUFFIX variables contain additional prompt text 
that provides instructions to give more context to the LLM. They also include the input variables.

### Calling the LLM with few shot prompting
<br> Finally, let’s prompt the LLM using the few-shot prompt and examine the result. <br>
My question differs slightly from the example; I replaced “United States” with “Canada.”


In [None]:
local_chain = SQLDatabaseChain.from_llm(llm, db, prompt=few_shot_prompt, use_query_checker=True, 
                                        verbose=True, return_sql=False,)
local_chain.run("How many customers are from country Canada?")

FewShotPromptTemplate(input_variables=['input', 'table_info', 'top_k'], example_selector=SemanticSimilarityExampleSelector(vectorstore=<langchain.vectorstores.chroma.Chroma object at 0x1549d20d0>, k=1, example_keys=None, input_keys=None), example_prompt=PromptTemplate(input_variables=['input', 'sql_cmd', 'result', 'answer'], template='\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {result}\nAnswer: {answer}'), suffix='Only use the following tables:\n{table_info}\n\nQuestion: {input}', prefix='You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the column

In [None]:
few_shots = {'List all artists.': 'SELECT * FROM artists;',
#              "Find all albums for the artist 'AC/DC'.": "SELECT * FROM albums WHERE ArtistId = (SELECT ArtistId FROM artists WHERE Name = 'AC/DC');",
#              "List all tracks in the 'Rock' genre.": "SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');",
#              'Find the total duration of all tracks.': 'SELECT SUM(Milliseconds) FROM tracks;',
#              'List all customers from Canada.': "SELECT * FROM customers WHERE Country = 'Canada';",
#              'How many tracks are there in the album with ID 5?': 'SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;',
#              'Find the total number of invoices.': 'SELECT COUNT(*) FROM invoices;',
#              'List all tracks that are longer than 5 minutes.': 'SELECT * FROM tracks WHERE Milliseconds > 300000;',
#              'Who are the top 5 customers by total purchase?': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;',
#              'Which albums are from the year 2000?': "SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';",
#              'How many employees are there': 'SELECT COUNT(*) FROM "employee"'
#             }