<a href="https://colab.research.google.com/github/matejkvassay/colab-notebooks/blob/master/langchain1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Install/upgrade the LLM tooling used in this notebook.
# NOTE(review): versions are unpinned, so a fresh runtime may pull releases whose APIs differ
# from what this notebook was written against (a deprecation warning is already recorded for the
# ChatOpenAI setup cell) — consider pinning exact versions.
%pip install -q --upgrade pip tiktoken cohere openai langchain langchain-openai langchain-community langchain-experimental huggingface

[0m

In [7]:
from google.colab import userdata
from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain.globals import set_debug, set_verbose
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

In [8]:
# Read the OpenAI key from Colab's secret store instead of hard-coding it,
# then build the default chat model used by the early demo cells.
API_KEY = userdata.get("OPENAI_API_KEY")
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    openai_api_key=API_KEY,
    verbose=True,
)

  warn_deprecated(


# Basic prompt

In [9]:
# Simplest chain: prompt template -> chat model -> plain-string output.
# The instruction asks the model to refuse when {country} is not a country
# (the sample input is a city).
prompt = ChatPromptTemplate.from_template("If {country} is a country, tell me its population. Otherwise politely refuse to answer.")
output_parser = StrOutputParser()

chain = prompt | llm | output_parser

chain.invoke({"country": "bratislava"})

'I apologize for any confusion, but Bratislava is not a country. It is the capital city of Slovakia. The population of Bratislava is approximately 437,725 as of 2021.'

# Basic prompt - output to JSON

In [10]:
# Target schema for the JSON-output demo: JsonOutputParser derives format instructions from it,
# and those instructions are embedded in the prompt.
# NOTE(review): pydantic may copy a class docstring into the schema (changing the prompt text),
# so documentation here is kept to comments.
class CityInfo(BaseModel):
    # Requested population figure. NOTE(review): the recorded run returned None for this key,
    # so the int annotation is evidently not enforced on the parsed output.
    population: int = Field(description="population")
    # Free-text reasoning the model is asked to return alongside the number.
    explanation: str = Field(description="detailed step-by-step explanation of how was the population count derived")

# JSON-output chain: the parser's format instructions (derived from CityInfo) are injected
# into the prompt as a partial variable, and the model's reply is parsed as JSON.
parser = JsonOutputParser(pydantic_object=CityInfo)

# The query line follows the format instructions directly (no dangling ". " prefix).
TEMP = "Answer the user query.\n{format_instructions}\nQuery: Tell me the population of {country}. Only answer if {country} is a country."


prompt = PromptTemplate(
    template=TEMP,
    input_variables=["country"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

chain = prompt | llm | parser

# "Paris" is a city, so the prompt invites a refusal / empty answer.
chain.invoke({"country": "Paris"})

{'population': None, 'explanation': None}

# Chat

In [11]:
# Sample "explain this SQL" request for the chat section.
# NOTE(review): not referenced by any later cell in this notebook.
SQL_EXAMPLE = """
Explain the following SQL query:
SELECT  DISTINCT (ct.date),
        cty.card_type_name,
        SUM (ct.amount) OVER (PARTITION BY cty.card_type_name ORDER BY ct.date ASC) AS transaction_running_total
FROM card_transaction ct
JOIN card_number cn ON ct.card_number_id = cn.id
JOIN card_type cty ON cn.card_type_id = cty.id
WHERE date > '2020-11-30' AND date <= '2020-12-31'
ORDER BY cty.card_type_name ASC;
"""

In [12]:
from langchain_core.messages import HumanMessage, SystemMessage

In [13]:
# Build the chat history step by step: a system persona first, then the user's question.
messages = [SystemMessage(content="You're a helpful assistant")]
messages.append(HumanMessage(content="What is the purpose of model regularization?"))

In [14]:
# Send the system + user messages in one call; the reply message is kept in ai_msg.
ai_msg = llm.invoke(messages)



In [15]:
# Display just the reply text.
ai_msg.content

"The purpose of model regularization is to prevent overfitting in machine learning models. Overfitting occurs when a model becomes too complex and starts to memorize the training data instead of generalizing patterns and relationships. Regularization techniques help to control the model's complexity and reduce the risk of overfitting.\n\nRegularization introduces additional constraints or penalties to the model during the training process. These constraints encourage the model to prioritize simpler solutions and avoid fitting noise or irrelevant patterns in the data. By doing so, regularization improves the model's ability to generalize well to unseen data and improves its performance on test or validation sets."

In [17]:
# invoke also accepts a plain string prompt; the reply comes back as an AIMessage (see output).
llm.invoke('Hey there!')

AIMessage(content='Hello! How can I assist you today?')

# SQL2Text: explaining SQL queries in plain language

In [49]:
# Keep LangChain's global debug flag off for the SQL-explanation experiments below.
set_debug(False)

In [62]:
# Two chat models for the SQL-explanation experiments; settings are identical except the model name.
gpt3_5, gpt4 = (
    ChatOpenAI(model=model_name, openai_api_key=API_KEY, verbose=True)
    for model_name in ("gpt-3.5-turbo", "gpt-4")
)

In [121]:
# Default prompt/query indices for ad-hoc experimentation; run_experiment itself takes its
# indices as explicit arguments.
PROMPT_IDX, QUERY_IDX = 1, 2

# Candidate instructions placed before the SQL text in run_experiment; chosen via prompt_idx.
# NOTE(review): entry 0 is .strip()-ed while entries 1 and 2 keep their surrounding newlines.
explanation_prompts = [
# 0: terse list-item explanation aimed at a non-technical reader.
"""Explain what data is following SQL query fetching to a non-technical user.
Be very brief, mention short summary, tables and filtering criteria. Format answer as list items.
""".strip(),

# 1: one-sentence "Summary" section followed by a "Details" bullet list.
"""
Provide brief summary of what data is following SQL query fetching.
Start with "Summary" section containing 1 sentence explanation.
Afterwards follow by "Details" section where you mention what tables is query using and what filtering criteria is applied in short bullet points.
""",

# 2: TODO empty placeholder — with prompt_idx=2 the model effectively receives only the SQL query.
"""
"""
]

# Sample SQL queries to be explained; run_experiment selects one via query_idx.
queries = [
# 0: Q1 "Sales" income-statement values for one company, ordered by fiscal year.
"""
select
FiscalYear,
Q1
from Data
where
Company = 'Royal Bank'
and
StatementType = 'Income Statement'
and
LineItem = 'Sales'
order by FiscalYear
""".strip(),

# 1: top-10 clients by 2017 net sales joined with their 2016 year-to-date figures.
# NOTE(review): contains a template placeholder #{month} and a stray trailing double quote after
# the final semicolon, so it is not runnable SQL as-is (fine for an explanation experiment).
"""
WITH top_cust AS (
      SELECT SUM(sales.amount) AS sales_17,
             SUM(discounts.amount) AS discounts_17,
             SUM(sales.amount) + SUM(discounts.amount) AS net_sales_17,
             clients.name AS client,
             clients.id AS client_id
      FROM clients
      JOIN sales
      ON sales.client_id = clients.id
      JOIN discounts
      ON discounts.sale_id = sales.id
      WHERE EXTRACT(year FROM sales.transaction_date) = 2017
      GROUP BY 4,5
      ORDER BY 3 DESC
      LIMIT 10
      ),
      ytd_2016 AS (
      SELECT SUM(sales.amount) AS sales_16,
             SUM(discounts.amount) AS discounts_16,
             SUM(sales.amount) + SUM(discounts.amount) AS net_sales_16,
             clients.name AS client,
             clients.id AS id
      FROM clients
      JOIN sales
      ON sales.client_id = clients.id
      JOIN discounts
      ON discounts.sale_id = sales.id
      WHERE EXTRACT(month FROM sales.transaction_date) BETWEEN 01 AND #{month}
            AND EXTRACT(year FROM sales.transaction_date) = 2016
      GROUP BY 4,5
      )
      SELECT sales_16, discounts_16, net_sales_16, top_cust.*
      FROM top_cust
      JOIN ytd_2016 ON top_cust.client_id = ytd_2016.id;"
""",

# 2: recursive CTE (T-SQL date functions) producing four quarter labels, stepping back 3 months per row.
"""
;WITH cte AS
(
    SELECT 1 as QuartersAgo, GETDATE() as DT,
        CAST(YEAR(DATEADD(MONTH, 3, GETDATE())) AS VARCHAR(4)) + CAST(DATEPART(QUARTER, DATEADD(MONTH, 3, GETDATE())) AS VARCHAR(1)) as FinancialQuarter
    UNION ALL
    SELECT QuartersAgo + 1, DATEADD(MONTH, -3, cte.DT),
        CAST(YEAR(DATEADD(MONTH, 3, DATEADD(MONTH, -3, cte.DT))) AS VARCHAR(4)) + CAST(DATEPART(QUARTER, DATEADD(MONTH, 3, DATEADD(MONTH, -3, cte.DT))) AS VARCHAR(1))
    FROM cte
    WHERE QuartersAgo < 4
)
SELECT FinancialQuarter, QuartersAgo FROM cte
""",

# 3: negative-amount totals per category, also as a percentage of the sum of positive amounts.
"""
SELECT
  abs(sum(amount)) as total_per_cat,
  abs(sum(amount))/(select sum(amount) from finance_flow where amount>0) *100 as percentage,
  category.name
FROM finance_flow
LEFT JOIN category ON finance_flow.cat_id = category.id
WHERE amount<0
GROUP BY category.name
""",

# 4: invoice line items with per-invoice and per-day totals computed via window functions.
"""
SELECT i.InvoiceID, i.CustomerID, i.EmployeeID, i.storeID, i.InvoiceDateTimeUTC, ii.Price, ii.Quantity, ii.Price * ii.Quantity AS ItemTotal,
       SUM(ii.Price * ii.Quantity) OVER (PARTITION BY i.InvoiceID) AS InvoiceTotal,
       SUM(ii.Price * ii.Quantity) OVER (PARTITION BY CAST(I.InvoiceDateTimeUTC AS DATE)) AS DailyInvoiceTotal
  FROM dbo.Invoices i
    INNER JOIN dbo.InvoiceItems ii
      ON i.InvoiceID = ii.InvoiceID;
"""

]

def run_experiment(query_idx, prompt_idx, model):
  """Ask `model` to explain one SQL query with one explanation prompt, and print the result.

  Args:
    query_idx: index into `queries` selecting the SQL query to explain.
    prompt_idx: index into `explanation_prompts` selecting the instruction text.
    model: chat model (e.g. `gpt3_5` or `gpt4`) placed between the prompt and the parser.

  Prints the query followed by the model's explanation; returns None.
  """
  # Plain local template string. This avoids the former chained assignment
  # `prompt = PromptTemplate = ...`, which bound a local named like the imported class.
  template = """
{explanation_prompt}

{sql_query}

""".strip()
  prompt = ChatPromptTemplate.from_template(template)
  # Fix the instruction part now; only the SQL text varies per invoke.
  prompt = prompt.partial(explanation_prompt=explanation_prompts[prompt_idx])
  sql_expl_chain = prompt | model | StrOutputParser()

  q = queries[query_idx]
  resp = sql_expl_chain.invoke({"sql_query": q})
  # Label the output with the prompt actually used (previously the global PROMPT_IDX was printed,
  # which mislabeled runs whenever prompt_idx differed from it).
  print(f"SQL QUERY (no. {query_idx})\n\n" + q + f"\n\nEXPLANATION (prompt {prompt_idx}):\n\n" + resp)

In [123]:
# Explain query 3 (finance_flow totals by category) with the Summary/Details prompt on GPT-3.5.
run_experiment(model=gpt3_5, prompt_idx=1, query_idx=3)

SQL QUERY (no. 3)


SELECT 
  abs(sum(amount)) as total_per_cat, 
  abs(sum(amount))/(select sum(amount) from finance_flow where amount>0) *100 as percentage,
  category.name 
FROM finance_flow 
LEFT JOIN category ON finance_flow.cat_id = category.id 
WHERE amount<0
GROUP BY category.name


EXPLANATION (prompt 1):

Summary: This SQL query fetches the total amount and percentage of negative amounts per category from the finance_flow table, along with the name of each category.

Details:
- Tables used: finance_flow, category
- Filtering criteria: 
  - amount<0
