In [1]:
from google.cloud import bigquery
from sqlalchemy import *
from sqlalchemy.engine import create_engine
from sqlalchemy.schema import *
from pybigquery.api import ApiClient
import os

from langchain.agents import create_sql_agent, AgentExecutor
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import AzureOpenAI

In [2]:
# BigQuery connection settings.
# The SQLAlchemy URL points at a single dataset inside one GCP project and
# passes the service-account key file to the BigQuery dialect via the
# `credentials_path` query parameter.
project = "gcp-project-id"
dataset = "langchain_tutorial"
table = "churn-dataset"  # NOTE(review): not referenced by any cell shown here
service_account = "gcp-service-account.json"  # path to the service-account key file

sqlalchemy_url = "bigquery://{}/{}?credentials_path={}".format(
    project,
    dataset,
    service_account,
)

In [3]:
# Initialize Azure OpenAI credentials.
# The API key is read from the first line of a local text file (kept out of the
# notebook itself) and exported, together with the Azure endpoint settings, as
# environment variables — presumably picked up by the AzureOpenAI LLM created
# in the agent-setup cell, which is constructed without explicit credentials.
path = 'azure_openai_api_key.txt'
with open(path, encoding='utf-8') as f:
    # readlines()[0] kept the line terminator, so the exported key ended in
    # "\n" and authentication failed; strip surrounding whitespace instead.
    API_key = f.readline().strip()

if not API_key:
    # Fail with a clear message instead of an IndexError / empty credential.
    raise ValueError(f"No API key found on the first line of {path!r}")

os.environ['OPENAI_API_KEY'] = API_key
os.environ['OPENAI_API_TYPE'] = 'azure'
os.environ['OPENAI_API_VERSION'] = '2023-03-15-preview'
os.environ['OPENAI_API_BASE'] = 'openai-base-url'  # placeholder: set to the Azure resource endpoint

In [None]:
# Set up the LangChain SQL agent.
# The same Azure-hosted completion model is handed to both the SQL toolkit and
# the agent; verbose=True prints the intermediate Action/Observation steps.
db = SQLDatabase.from_uri(sqlalchemy_url)

llm = AzureOpenAI(
    deployment_name="text-davinci-003",
    model_name="text-davinci-003",
)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    top_k=1000,  # limit on the maximum number of results
)

In [None]:
# Single-question query; the agent's answer is shown as the cell output.
male_churn_question = "How many male users churned? "
agent_executor.run(male_churn_question)

In [6]:
# Three-part question: churned vs. not churned, split by internet service.
# The answer is shown as the cell output.
internet_churn_question = (
    "How many users churned which had internet service?\n"
    "How many churned which had no internet service?\n"
    "And for each of the groups, how many did not churn?"
)
agent_executor.run(internet_churn_question)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m
Observation: [38;5;200m[1;3mchurn-dataset[0m
Thought:[32;1m[1;3m I should query the schema of the churn-dataset table
Action: sql_db_schema
Action Input: "churn-dataset"[0m
Observation: [33;1m[1;3m
CREATE TABLE `churn-dataset` (
	`customerID` STRING, 
	`gender` STRING, 
	`SeniorCitizen` INT64, 
	`Partner` BOOL, 
	`Dependents` BOOL, 
	`tenure` INT64, 
	`PhoneService` BOOL, 
	`MultipleLines` STRING, 
	`InternetService` STRING, 
	`OnlineSecurity` STRING, 
	`OnlineBackup` STRING, 
	`DeviceProtection` STRING, 
	`TechSupport` STRING, 
	`StreamingTV` STRING, 
	`StreamingMovies` STRING, 
	`Contract` STRING, 
	`PaperlessBilling` BOOL, 
	`PaymentMethod` STRING, 
	`MonthlyCharges` FLOAT64, 
	`TotalCharges` STRING, 
	`Churn` BOOL OPTIONS(description='')
)

/*
3 rows from churn-dataset table:
customerID	gender	SeniorCitizen	Partner	Dependents	tenure	PhoneService	MultipleLines	Interne

'1756 users churned with internet service, 113 users churned without internet service, 3761 users did not churn with internet service and 1413 users did not churn without internet service.'