In [2]:
from openai import OpenAI
from dotenv import load_dotenv
import os
import langchain
from langchain.chat_models import ChatOpenAI
from langchain.agents import tool
from typing import Optional
from pydantic import BaseModel, Field
from langchain.tools.render import format_tool_to_openai_function



In [3]:
langchain.__version__

'0.0.354'

In [5]:

# Load environment variables from .env file
load_dotenv()

# Get the OpenAI API key from the environment
api_key = os.getenv("OPENAI_API_KEY")
# Check if the API key is available
if not api_key:
    raise ValueError("API key is not set. Make sure it is available in your .env file.")

In [8]:
client = OpenAI(api_key=api_key)

# Rest of your script remains unchanged
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who won the world series in 2020?"},
        {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
        {"role": "user", "content": "Where was it played?"}
    ]
)
print(response)


ChatCompletion(id='chatcmpl-8dMVvOxk2Iv7hOzIylxoCQ0tpiDCp', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='The World Series in 2020 was played at Globe Life Field in Arlington, Texas.', role='assistant', function_call=None, tool_calls=None))], created=1704392027, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=18, prompt_tokens=53, total_tokens=71))


## Defining Tools

In [9]:
import psycopg2
import pandas as pd


connection_params = {
  "user": "postgres",
  "password": "password",
  "database": "youtube_db",
  "host": "localhost",
  "port": "5432"
}


def get_postgres_data(query:str):
    # Establish a connection to the PostgreSQL database
    conn = psycopg2.connect(**connection_params)

    # Create a cursor object to execute queries
    cursor = conn.cursor()

    try:
        # Execute the SQL query
        cursor.execute(query)

        # Fetch all the results
        results = cursor.fetchall()

        # Get the cursor description
        description = cursor.description

        # Return both results and cursor description
        return results, description

    finally:
        cursor.close()
        conn.close()

def get_postgres_df(query, connection_params=connection_params):
    # Execute the query and get the results along with the cursor description
    results, description = get_postgres_data(query)
    print(results)
    print(description)

    # Get the column names from the cursor description
    columns = [desc[0] for desc in description]

    # Create a DataFrame from the results
    df = pd.DataFrame(results, columns=columns)

    return df

# Example usage:
postgres_query = "SELECT * FROM cities_table_data"
postgres_df = get_postgres_df(postgres_query)
# print(postgres_df)


[('Total', None, None, None, 26625, 2157.3497, '0:04:51'), ('0x164b85cef5ab402d:0x8467b6b037a24d49', 'Addis Ababa', 'ET', 'ET-AA', 1252, 127.5042, '0:06:06'), ('0x487a4d4c5226f5db:0xd9be143804fe6baa', 'Manchester', 'GB', 'GB-ENG', 23, 0.0402, '0:00:06'), ('0x3568eb6de823cd35:0x35d8cb74247108a7', 'Busan', 'KR', 'KR-26', 14, 0.0354, '0:00:09'), ('0x168e8fde9837cabf:0x191f55de7e67db40', 'Khartoum', 'SD', 'SD-KH', 12, 1.2627, '0:06:18'), ('0x3397ba0942ef7375:0x4a9a32d9fe083d40', 'Quezon City', 'PH', 'PH-00', 11, 0.0615, '0:00:20'), ('0x3ae253d10f7a7003:0x320b2e4d32d3838d', 'Colombo', 'LK', 'LK-1', 11, 0.0363, '0:00:11'), ('0x102354e509f894f7:0xc8fde921f89849f6', 'Cotonou', 'BJ', 'BJ-LI', 10, 1.1665, '0:06:59'), ('0x399bfd991f32b16b:0x93ccba8909978be7', 'Lucknow', 'IN', 'IN-UP', 10, 0.029, '0:00:10')]
(Column(name='Cities', type_code=25), Column(name='City name', type_code=25), Column(name='Geography', type_code=25), Column(name='Geography.1', type_code=25), Column(name='Views', type_code=2

In [10]:
from typing import List, Tuple


class SQLQuery(BaseModel):
    query: str = Field(description="SQL query to execute")

@tool
def execute_sql(query: str) -> Tuple[List[Tuple], List[Tuple]]:
    """Returns the result of SQL query execution and cursor description"""
    return get_postgres_data(query)

In [11]:
execute_sql('SELECT * FROM device_type_table_data WHERE "Average view duration" = \'0:05:19\'')

([('Computer', 19267, 1709.2645, '0:05:19')],
 (Column(name='Device type', type_code=25),
  Column(name='Views', type_code=20),
  Column(name='Watch time (hours)', type_code=701),
  Column(name='Average view duration', type_code=25)))

In [12]:
class SQLTable(BaseModel):
    table: str = Field(description="Table name")

@tool
def get_table_columns(table:str) -> str:
    """Returns a list of table column names and types in JSON"""
    
    query = f'''
   SELECT column_name, data_type
    FROM information_schema.columns
    WHERE table_name = '{table}';
    '''
        
    result_df = get_postgres_df(query)
    
    # Convert the result DataFrame to a list of dictionaries
    result_list = result_df.to_dict('records')
    
    # Convert the list to a JSON-formatted string
    return str(result_list)

In [13]:
get_table_columns({'table': 'cities_table_data'})

[('Cities', 'text'), ('City name', 'text'), ('Geography', 'text'), ('Geography.1', 'text'), ('Views', 'bigint'), ('Watch time (hours)', 'double precision'), ('Average view duration', 'text')]
(Column(name='column_name', type_code=1043), Column(name='data_type', type_code=1043))


"[{'column_name': 'Cities', 'data_type': 'text'}, {'column_name': 'City name', 'data_type': 'text'}, {'column_name': 'Geography', 'data_type': 'text'}, {'column_name': 'Geography.1', 'data_type': 'text'}, {'column_name': 'Views', 'data_type': 'bigint'}, {'column_name': 'Watch time (hours)', 'data_type': 'double precision'}, {'column_name': 'Average view duration', 'data_type': 'text'}]"

In [14]:
class SQLTableColumn(BaseModel):
    database: str = Field(description="Database name")
    table: str = Field(description="Table name")
    column: str = Field(description="Column name")
    n: Optional[int] = Field(description="Number of rows, default limit 10")

@tool
def get_table_column_distr(database: str, table: str, column: str, n:int = 10) -> str:
    """Returns top n values for the column in JSON"""
    print(column)
    q = f'''

     SELECT "{column}", COUNT(1) AS count
    FROM {table}
    GROUP BY 1
    ORDER BY 2 DESC
    LIMIT {n};
   
    '''
    
    return str(list(get_postgres_df(q)[column].values))

In [15]:
get_table_column_distr({'database': 'youtube_db', 'table': 'cities_table_data', 'column': 'Cities'})

Cities
[('0x399bfd991f32b16b:0x93ccba8909978be7', 1), ('Total', 1), ('0x164b85cef5ab402d:0x8467b6b037a24d49', 1), ('0x102354e509f894f7:0xc8fde921f89849f6', 1), ('0x3397ba0942ef7375:0x4a9a32d9fe083d40', 1), ('0x3568eb6de823cd35:0x35d8cb74247108a7', 1), ('0x3ae253d10f7a7003:0x320b2e4d32d3838d', 1), ('0x487a4d4c5226f5db:0xd9be143804fe6baa', 1), ('0x168e8fde9837cabf:0x191f55de7e67db40', 1)]
(Column(name='Cities', type_code=25), Column(name='count', type_code=20))


"['0x399bfd991f32b16b:0x93ccba8909978be7', 'Total', '0x164b85cef5ab402d:0x8467b6b037a24d49', '0x102354e509f894f7:0xc8fde921f89849f6', '0x3397ba0942ef7375:0x4a9a32d9fe083d40', '0x3568eb6de823cd35:0x35d8cb74247108a7', '0x3ae253d10f7a7003:0x320b2e4d32d3838d', '0x487a4d4c5226f5db:0xd9be143804fe6baa', '0x168e8fde9837cabf:0x191f55de7e67db40']"

In [16]:
sql_functions = list(map(format_tool_to_openai_function, [execute_sql, get_table_columns, get_table_column_distr]))

In [17]:
sql_tools = {
    'execute_sql': execute_sql,
    'get_table_columns': get_table_columns,
    'get_table_column_distr': get_table_column_distr
}

### OpenAI Functions Agent

In [18]:
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser

In [19]:
llm = ChatOpenAI(temperature=0.1, model = 'gpt-3.5-turbo')\
  .bind(functions = sql_functions)

In [20]:
system_message = '''
You are working as a data analyst for the 10 Academy YouTube channel. Your role is crucial, as the insights you provide guide the decision-making process for the content and strategy team. Accuracy in data reporting is paramount, and if there is uncertainty about a request, you seek clarification before providing an answer.

The data for analysis is stored in a SQL database, and the relevant tables along with their descriptions and columns are as follows:
when generating sql query, put the column name under double quotation
- device_type_chart_data: Information about the views on different devices. Columns: "Date" (TEXT), "Device type" (TEXT), "Views" (INTEGER)

- device_type_table_data: Detailed data on device types, including watch time and average view duration. Columns: "Device type" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- subscription_source_chart_data: Data on subscribers based on subscription sources. Columns: "Date" (TEXT), "Subscription source" (TEXT), "Subscribers" (INTEGER)

- subscription_source_table_data: Detailed information on subscription sources, including subscribers gained and lost. Columns: "Subscription source" (TEXT), "Subscribers" (INTEGER), "Subscribers gained" (INTEGER), "Subscribers lost" (INTEGER)

- viewership_by_date_table_data: Overview of views, watch time, and average view duration by date. Columns: "Date" (TEXT), "Views" (DOUBLE PRECISION), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- viewer_gender_table_data: Data on viewer gender, including views percentage and watch time distribution. Columns: "Viewer gender" (TEXT), "Views (%)" (DOUBLE PRECISION), "Average view duration" (TEXT), "Average percentage viewed (%)" (DOUBLE PRECISION), "Watch time (hours) (%)" (DOUBLE PRECISION)

- traffic_source_chart_data: Information about views from different traffic sources. Columns: "Date" (TEXT), "Traffic source" (TEXT), "Views" (INTEGER)

- traffic_source_table_data: Detailed data on traffic sources, including watch time, average view duration, impressions, and click-through rate. Columns: "Traffic source" (TEXT), "Views" (DOUBLE PRECISION), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT), "Impressions" (DOUBLE PRECISION), "Impressions click-through rate (%)" (DOUBLE PRECISION)

- subtitles_and_cc_chart_data: Analysis of views based on the presence of subtitles and closed captions. Columns: "Date" (TEXT), "Subtitles and CC" (TEXT), "Views" (INTEGER)

- subtitles_and_cc_table_data: Detailed data on subtitles and closed captions, including watch time and average view duration. Columns: "Subtitles and CC" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- new_and_returning_viewers_chart_data: Data on new and returning viewers and their respective views. Columns: "Date" (TEXT), "New and returning viewers" (TEXT), "Views" (INTEGER)

- new_and_returning_viewers_table_data: Detailed information on new and returning viewers, including watch time and average view duration. Columns: "New and returning viewers" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- operating_system_chart_data: Analysis of views based on different operating systems. Columns: "Date" (TEXT), "Operating system" (TEXT), "Views" (INTEGER)

- operating_system_table_data: Detailed data on operating systems, including watch time and average view duration. Columns: "Operating system" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- cities_chart_data: Information about views based on different cities. Columns: "Date" (TEXT), "Cities" (TEXT), "City name" (TEXT), "Views" (INTEGER)

- cities_table_data: Detailed data on cities, including geography, views, watch time, and average view duration. Columns: "Cities" (TEXT), "City name" (TEXT), "Geography" (TEXT), "Geography.1" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- viewer_age_table_data: Data on viewer age, including views percentage and watch time distribution. Columns: "Viewer age" (TEXT), "Views (%)" (DOUBLE PRECISION), "Average view duration" (TEXT), "Average percentage viewed (%)" (DOUBLE PRECISION), "Watch time (hours) (%)" (DOUBLE PRECISION)

- subscription_status_chart_data: Analysis of views based on subscription status. Columns: "Date" (TEXT), "Subscription status" (TEXT), "Views" (INTEGER)

- subscription_status_table_data: Detailed data on subscription status, including views, watch time, and average view duration. Columns: "Subscription status" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- geography_chart_data: Overview of views based on different geographic locations. Columns: "Date" (TEXT), "Geography" (TEXT), "Views" (INTEGER)

- geography_table_data: Detailed data on geography, including views, watch time, and average view duration. Columns: "Geography" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- sharing_service_chart_data: Information about shares on different sharing services. Columns: "Date" (TEXT), "Sharing service" (TEXT), "Shares" (INTEGER)

- sharing_service_table_data: Detailed data on sharing services, including the number of shares. Columns: "Sharing service" (TEXT), "Shares" (INTEGER)

- content_type_chart_data: Analysis of views based on different content types. Columns: "Date" (TEXT), "Content type" (TEXT), "Views" (INTEGER)

- content_type_table_data: Detailed data on content types, including views, watch time, and average view duration. Columns: "Content type" (TEXT), "Views" (INTEGER), "Watch time (hours)" (DOUBLE PRECISION), "Average view duration" (TEXT)

- totals_table_data: Overall summary of views, subscribers, and shares by date. Columns: "Date" (TEXT), "Views" (DOUBLE PRECISION), "Subscribers" (DOUBLE PRECISION), "Shares" (DOUBLE PRECISION)

'''

analyst_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_message),
        ("user", "{question}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

In [21]:
analyst_agent = (
    {
        "question": lambda x: x["question"],
        "agent_scratchpad": lambda x: format_to_openai_function_messages(x["intermediate_steps"]),
    }
    | analyst_prompt
    | llm
    | OpenAIFunctionsAgentOutputParser()
)

In [22]:
analyst_agent.invoke({"question": "which device have the highest Average view duration", "intermediate_steps": []})

AgentActionMessageLog(tool='get_table_column_distr', tool_input={'database': 'device_type_table_data', 'table': 'device_type_table_data', 'column': 'Average view duration', 'n': 1}, log="\nInvoking: `get_table_column_distr` with `{'database': 'device_type_table_data', 'table': 'device_type_table_data', 'column': 'Average view duration', 'n': 1}`\n\n\n", message_log=[AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\n  "database": "device_type_table_data",\n  "table": "device_type_table_data",\n  "column": "Average view duration",\n  "n": 1\n}', 'name': 'get_table_column_distr'}})])

In [23]:
from langchain_core.agents import AgentFinish

question = "which device have the highest Average view duration?", "intermediate_steps"
intermediate_steps = []
num_iters = 0

while True:
    if num_iters >= 10:  
        break

    output = analyst_agent.invoke(
        {
            "question": question,
            "intermediate_steps": intermediate_steps,
        }
    )
    num_iters += 1

    if isinstance(output, AgentFinish):
        model_output = output.return_values["output"]
        break
    else:
        print(f'Executing tool: {output.tool}, arguments: {output.tool_input}')
        observation = sql_tools[output.tool](output.tool_input)
        print(f'Observation: {observation}')
        print()
        intermediate_steps.append((output, observation))
        
print('Model output:', model_output)

Executing tool: get_table_columns, arguments: {'table': 'device_type_table_data'}
[('Device type', 'text'), ('Views', 'bigint'), ('Watch time (hours)', 'double precision'), ('Average view duration', 'text')]
(Column(name='column_name', type_code=1043), Column(name='data_type', type_code=1043))
Observation: [{'column_name': 'Device type', 'data_type': 'text'}, {'column_name': 'Views', 'data_type': 'bigint'}, {'column_name': 'Watch time (hours)', 'data_type': 'double precision'}, {'column_name': 'Average view duration', 'data_type': 'text'}]

Executing tool: execute_sql, arguments: {'query': 'SELECT "Device type", "Average view duration" FROM device_type_table_data ORDER BY "Average view duration" DESC LIMIT 1'}
Observation: ([('Computer', '0:05:19')], (Column(name='Device type', type_code=25), Column(name='Average view duration', type_code=25)))

Model output: The device with the highest average view duration is "Computer" with an average view duration of 0 hours, 5 minutes, and 19 seco

In [25]:
from langchain.agents import AgentExecutor

analyst_agent_executor = AgentExecutor(
    agent=analyst_agent, 
    tools=[execute_sql, get_table_columns, get_table_column_distr], 
    verbose=True,
    max_iterations=10, # early stopping criteria
    early_stopping_method='generate', 
    # to ask model to generate the final answer after stopping
)

response = analyst_agent_executor.invoke(
  {"question": "what is the Device type having the highest Average view duration?"}
)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_table_column_distr` with `{'database': 'device_type_table_data', 'table': 'device_type_table_data', 'column': 'Average view duration', 'n': 1}`


[0mAverage view duration
[('0:05:19', 1)]
(Column(name='Average view duration', type_code=25), Column(name='count', type_code=20))
[38;5;200m[1;3m['0:05:19'][0m[32;1m[1;3m
Invoking: `execute_sql` with `{'query': 'SELECT "Device type" FROM device_type_table_data WHERE "Average view duration" = \'0:05:19\''}`


[0m[36;1m[1;3m([('Computer',)], (Column(name='Device type', type_code=25),))[0m[32;1m[1;3mThe device type with the highest average view duration is "Computer".[0m

[1m> Finished chain.[0m


In [27]:
response["output"]

'The device type with the highest average view duration is "Computer".'

In [53]:
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.schema import SystemMessage

agent_kwargs = {
    "system_message": SystemMessage(content=system_message)
}

analyst_agent_openai = initialize_agent(
    llm=ChatOpenAI(temperature=0.1, model = 'gpt-3.5-turbo'),
    agent=AgentType.OPENAI_FUNCTIONS, 
    tools=[execute_sql, get_table_columns, get_table_column_distr], 
    agent_kwargs=agent_kwargs,
    verbose=True,
    max_iterations=10,
    early_stopping_method='generate'
)

In [54]:
analyst_agent_openai.get_input_schema().schema()

{'title': 'ChainInput',
 'type': 'object',
 'properties': {'input': {'title': 'Input'}}}

In [55]:
analyst_agent_openai.get_output_schema().schema()

{'title': 'ChainOutput',
 'type': 'object',
 'properties': {'output': {'title': 'Output'}}}

In [57]:
analyst_agent_openai.run("what is the Device type having the highest Average view duration?")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_table_column_distr` with `{'database': 'device_type_table_data', 'table': 'device_type_table_data', 'column': 'Average view duration', 'n': 1}`


[0mAverage view duration
[('0:05:19', 1)]
(Column(name='Average view duration', type_code=25), Column(name='count', type_code=20))
[38;5;200m[1;3m['0:05:19'][0m[32;1m[1;3m
Invoking: `execute_sql` with `{'query': 'SELECT "Device type" FROM device_type_table_data WHERE "Average view duration" = \'0:05:19\''}`


[0m[36;1m[1;3m([('Computer',)], (Column(name='Device type', type_code=25),))[0m[32;1m[1;3mThe device type with the highest average view duration is "Computer".[0m

[1m> Finished chain.[0m


'The device type with the highest average view duration is "Computer".'